// vstd/bits.rs

1//! Properties of bitwise operators.
2use super::prelude::*;
3
4verus! {
5
6#[cfg(verus_keep_ghost)]
7use super::arithmetic::power::pow;
8#[cfg(verus_keep_ghost)]
9use super::arithmetic::power2::{
10    pow2,
11    lemma_pow2_unfold,
12    lemma_pow2_adds,
13    lemma_pow2_pos,
14    lemma2_to64,
15    lemma2_to64_rest,
16    lemma_pow2_strictly_increases,
17};
18#[cfg(verus_keep_ghost)]
19use super::arithmetic::div_mod::{
20    lemma_div_by_multiple,
21    lemma_div_denominator,
22    lemma_div_is_ordered,
23    lemma_mod_breakdown,
24    lemma_mod_multiples_vanish,
25    lemma_remainder_lower,
26};
27#[cfg(verus_keep_ghost)]
28use super::arithmetic::mul::{
29    lemma_mul_inequality,
30    lemma_mul_is_commutative,
31    lemma_mul_is_associative,
32};
33#[cfg(verus_keep_ghost)]
34use super::calc_macro::*;
35
36} // verus!
// Proofs that shift right is equivalent to division by power of 2.
//
// `lemma_shr_is_div!(name, uN)` defines a broadcast lemma `name(x, shift)`
// stating `x >> shift == x / 2^shift` for every `shift` below the bit width of
// `uN`. The proof is by induction on `shift`: a bit-vector fact peels off one
// position of the shift (`x >> shift == (x >> (shift - 1)) / 2`), and the
// nested division is then folded into a single division by a power of two.
macro_rules! lemma_shr_is_div {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and shift of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x right by shift is equivalent to division of x by 2^shift."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                #[trigger] (x >> shift) == x as nat / pow2(shift as nat),
            decreases shift,
        {
            if shift == 0 {
                // Base case: x >> 0 == x and 2^0 == 1.
                assert(x >> 0 == x) by (bit_vector);
                reveal(pow);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Peel one position off the shift; this is a pure bit-vector fact.
                assert(x >> shift == (x >> ((sub(shift, 1)) as $uN)) / 2) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                calc!{ (==)
                    (x >> shift) as nat;
                        {}
                    ((x >> ((sub(shift, 1)) as $uN)) / 2) as nat;
                        {
                            // Induction hypothesis for shift - 1.
                            $name(x, (shift - 1) as $uN);
                        }
                    (x as nat / pow2((shift - 1) as nat)) / 2;
                        {
                            // Merge the two divisions into one. lemma_pow2_pos gives a
                            // positive divisor; lemma2_to64 presumably supplies
                            // pow2(1) == 2, which is needed for the next line.
                            lemma_pow2_pos((shift - 1) as nat);
                            lemma2_to64();
                            lemma_div_denominator(x as int, pow2((shift - 1) as nat) as int, 2);
                        }
                    x as nat / (pow2((shift - 1) as nat) * pow2(1));
                        {
                            // 2^(shift - 1) * 2^1 == 2^shift.
                            lemma_pow2_adds((shift - 1) as nat, 1);
                        }
                    x as nat / pow2(shift as nat);
                }
            }
        }
        }
    };
}

lemma_shr_is_div!(lemma_u128_shr_is_div, u128);
lemma_shr_is_div!(lemma_u64_shr_is_div, u64);
lemma_shr_is_div!(lemma_u32_shr_is_div, u32);
lemma_shr_is_div!(lemma_u16_shr_is_div, u16);
lemma_shr_is_div!(lemma_u8_shr_is_div, u8);
lemma_shr_is_div!(lemma_usize_shr_is_div, usize);
90
// Proofs of when a power of 2 fits in an unsigned type.
//
// `lemma_pow2_no_overflow!(lemma, ty)` defines a broadcast lemma showing that,
// for any exponent below the bit width of `ty`, 2^n is strictly positive and
// strictly below `ty::MAX`. The bound is discharged from the table of concrete
// powers of two provided by `lemma2_to64` / `lemma2_to64_rest`.
macro_rules! lemma_pow2_no_overflow {
    ($lemma:ident, $ty:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that 2^n does not overflow "]
        #[doc = stringify!($ty)]
        #[doc = " for an exponent n."]
        pub broadcast proof fn $lemma(n: nat)
            requires
                0 <= n < <$ty>::BITS,
            ensures
                0 < #[trigger] pow2(n) < <$ty>::MAX,
        {
            // Concrete powers of two bound 2^n from above ...
            lemma2_to64();
            lemma2_to64_rest();
            // ... and positivity bounds it from below.
            lemma_pow2_pos(n);
        }
        }
    };
}

lemma_pow2_no_overflow!(lemma_u64_pow2_no_overflow, u64);
lemma_pow2_no_overflow!(lemma_u32_pow2_no_overflow, u32);
lemma_pow2_no_overflow!(lemma_u16_pow2_no_overflow, u16);
lemma_pow2_no_overflow!(lemma_u8_pow2_no_overflow, u8);
lemma_pow2_no_overflow!(lemma_usize_pow2_no_overflow, usize);
118
// Proofs that shift left is equivalent to multiplication by power of 2.
//
// `lemma_shl_is_mul!(name, no_overflow, uN)` defines a broadcast lemma
// `name(x, shift)` stating `x << shift == x * 2^shift` whenever that product
// fits in `uN`. `no_overflow` must be the `lemma_pow2_no_overflow!` instance
// for the same type. The proof is by induction on `shift`.
macro_rules! lemma_shl_is_mul {
    ($name:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and shift of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x left by shift is equivalent to multiplication of x by 2^shift (provided no overflow)."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
                x * pow2(shift as nat) <= <$uN>::MAX,
            ensures
                #[trigger] (x << shift) == x * pow2(shift as nat),
            decreases shift,
        {
            // Bounds 0 < 2^shift < MAX for this type.
            $no_overflow(shift as nat);
            if shift == 0 {
                // Base case: x << 0 == x and 2^0 == 1.
                assert(x << 0 == x) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Peel one position off the shift; this is a pure bit-vector fact.
                assert(x << shift == mul(x << ((sub(shift, 1)) as $uN), 2)) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                // Induction hypothesis for shift - 1. Its no-overflow precondition
                // follows from x * 2^(shift - 1) <= x * 2^shift <= MAX.
                assert((x << (sub(shift, 1) as $uN)) == x * pow2(sub(shift, 1) as nat)) by {
                    lemma_pow2_strictly_increases((shift - 1) as nat, shift as nat);
                    lemma_mul_inequality(
                        pow2((shift - 1) as nat) as int,
                        pow2(shift as nat) as int,
                        x as int,
                    );
                    lemma_mul_is_commutative(x as int, pow2((shift - 1) as nat) as int);
                    lemma_mul_is_commutative(x as int, pow2(shift as nat) as int);
                    $name(x, (shift - 1) as $uN);
                }
                calc!{ (==)
                    ((x << (sub(shift, 1) as $uN)) * 2);
                        {}
                    ((x * pow2(sub(shift, 1) as nat)) * 2);
                        {
                            lemma_mul_is_associative(x as int, pow2(sub(shift, 1) as nat) as int, 2);
                        }
                    x * ((pow2(sub(shift, 1) as nat)) * 2);
                        {
                            // 2^(shift - 1) * 2 == 2^shift (lemma2_to64 presumably
                            // supplies pow2(1) == 2).
                            lemma_pow2_adds((shift - 1) as nat, 1);
                            lemma2_to64();
                        }
                    x * pow2(shift as nat);
                }
            }
        }
        }
    };
}

lemma_shl_is_mul!(lemma_u64_shl_is_mul, lemma_u64_pow2_no_overflow, u64);
lemma_shl_is_mul!(lemma_u32_shl_is_mul, lemma_u32_pow2_no_overflow, u32);
lemma_shl_is_mul!(lemma_u16_shl_is_mul, lemma_u16_pow2_no_overflow, u16);
lemma_shl_is_mul!(lemma_u8_shl_is_mul, lemma_u8_pow2_no_overflow, u8);
lemma_shl_is_mul!(lemma_usize_shl_is_mul, lemma_usize_pow2_no_overflow, usize);
180
// Proofs relating the bound `x * 2^shift <= max` to `x <= max >> shift`.
//
// `lemma_mul_pow2_le_max_iff_max_shr!(name, shr_is_div, uN)` defines a lemma
// `name(x, shift, max)` showing the two comparisons are equivalent, so an
// overflow-style check on `x * 2^shift` can be phrased as `x <= max >> shift`.
// `shr_is_div` must be the `lemma_shr_is_div!` instance for the same type.
macro_rules! lemma_mul_pow2_le_max_iff_max_shr {
    ($name:ident, $shr_is_div:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x, shift and max of type "]
        #[doc = stringify!($uN)]
        #[doc = ", multiplication of x by 2^shift is less than or equal to max if and only if x is less than or equal to shifting max right by shift."]
        pub proof fn $name(x: $uN, shift: $uN, max: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                x * pow2(shift as nat) <= max <==> x <= (max >> shift),
        {
            // max >> shift == max / 2^shift.
            assert(max >> shift == max as nat / pow2(shift as nat)) by {
                $shr_is_div(max, shift);
            };

            // 2^shift > 0, so the divisions below are by a positive number.
            lemma_pow2_pos(shift as nat);

            // (==>) x * 2^shift <= max implies x == (x * 2^shift) / 2^shift <= max / 2^shift.
            if x * pow2(shift as nat) <= max {
                assert(x <= (max as nat) / pow2(shift as nat)) by {
                    lemma_div_is_ordered(x as int * pow2(shift as nat) as int, max as int, pow2(shift as nat) as int);
                    lemma_div_by_multiple(x as int, pow2(shift as nat) as int);
                };
            }
            // (<==) x <= max / 2^shift implies x * 2^shift <= (max / 2^shift) * 2^shift <= max.
            if x <= (max >> shift) {
                assert(x * pow2(shift as nat) <= max as nat) by {
                    lemma_mul_inequality(x as int, max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                    lemma_remainder_lower(max as int, pow2(shift as nat) as int);
                    lemma_mul_is_commutative(max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                };
            }
        }
        }
    };
}

lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u128_mul_pow2_le_max_iff_max_shr,
    lemma_u128_shr_is_div,
    u128
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u64_mul_pow2_le_max_iff_max_shr,
    lemma_u64_shr_is_div,
    u64
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u32_mul_pow2_le_max_iff_max_shr,
    lemma_u32_shr_is_div,
    u32
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u16_mul_pow2_le_max_iff_max_shr,
    lemma_u16_shr_is_div,
    u16
);
lemma_mul_pow2_le_max_iff_max_shr!(lemma_u8_mul_pow2_le_max_iff_max_shr, lemma_u8_shr_is_div, u8);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_usize_mul_pow2_le_max_iff_max_shr,
    lemma_usize_shr_is_div,
    usize
);
239
verus! {

/// Mask with low n bits set.
///
/// Defined as `2^n - 1` viewed as a natural number, e.g. `low_bits_mask(3)`
/// is `0b111`. `low_bits_mask(0)` is `0`.
pub open spec fn low_bits_mask(n: nat) -> nat {
    (pow2(n) - 1) as nat
}

/// Proof relating the n-bit mask to a function of the (n-1)-bit mask.
///
/// For `n > 0`, the n-bit mask is the (n-1)-bit mask shifted up by one with
/// a new low bit set: `low_bits_mask(n) == 2 * low_bits_mask(n - 1) + 1`.
pub broadcast proof fn lemma_low_bits_mask_unfold(n: nat)
    requires
        n > 0,
    ensures
        #[trigger] low_bits_mask(n) == 2 * low_bits_mask((n - 1) as nat) + 1,
{
    calc! {
        (==)
        low_bits_mask(n); {}
        (pow2(n) - 1) as nat; {
            // Used to rewrite pow2(n) as 2 * pow2(n - 1).
            lemma_pow2_unfold(n);
        }
        (2 * pow2((n - 1) as nat) - 1) as nat; {}
        (2 * (pow2((n - 1) as nat) - 1) + 1) as nat; {
            // pow2(n - 1) > 0, so pow2(n - 1) - 1 is non-negative and the
            // `as nat` cast inside low_bits_mask(n - 1) does not truncate.
            lemma_pow2_pos((n - 1) as nat);
        }
        (2 * low_bits_mask((n - 1) as nat) + 1) as nat;
    }
}

/// Proof that low_bits_mask(n) is odd.
///
/// Requires `n > 0`; the empty mask `low_bits_mask(0) == 0` is even.
pub broadcast proof fn lemma_low_bits_mask_is_odd(n: nat)
    requires
        n > 0,
    ensures
        #[trigger] (low_bits_mask(n) % 2) == 1,
{
    calc! {
        (==)
        low_bits_mask(n) % 2; {
            // Rewrite the mask as 2 * low_bits_mask(n - 1) + 1.
            lemma_low_bits_mask_unfold(n);
        }
        (2 * low_bits_mask((n - 1) as nat) + 1) % 2; {
            // Presumably (2 * a + 1) % 2 == 1 % 2: the multiple of 2 vanishes.
            lemma_mod_multiples_vanish(low_bits_mask((n - 1) as nat) as int, 1, 2);
        }
        1nat % 2;
    }
}

/// Proof that dividing the low n bit mask by 2 gives the low n-1 bit mask.
///
/// This is the arithmetic form of "shift the mask right by one".
pub broadcast proof fn lemma_low_bits_mask_div2(n: nat)
    requires
        n > 0,
    ensures
        #[trigger] (low_bits_mask(n) / 2) == low_bits_mask((n - 1) as nat),
{
    // With low_bits_mask(n) == 2 * m + 1, the solver can conclude
    // (2 * m + 1) / 2 == m on its own.
    lemma_low_bits_mask_unfold(n);
}

/// Proof establishing the concrete values of all masks of bit sizes from 0 to
/// 32, and 64.
///
/// The facts are evaluated by Verus's `compute_only` interpreter rather than
/// by the SMT solver. The conjunction in the body mirrors the `ensures` list
/// one-for-one.
pub proof fn lemma_low_bits_mask_values()
    ensures
        low_bits_mask(0) == 0x0,
        low_bits_mask(1) == 0x1,
        low_bits_mask(2) == 0x3,
        low_bits_mask(3) == 0x7,
        low_bits_mask(4) == 0xf,
        low_bits_mask(5) == 0x1f,
        low_bits_mask(6) == 0x3f,
        low_bits_mask(7) == 0x7f,
        low_bits_mask(8) == 0xff,
        low_bits_mask(9) == 0x1ff,
        low_bits_mask(10) == 0x3ff,
        low_bits_mask(11) == 0x7ff,
        low_bits_mask(12) == 0xfff,
        low_bits_mask(13) == 0x1fff,
        low_bits_mask(14) == 0x3fff,
        low_bits_mask(15) == 0x7fff,
        low_bits_mask(16) == 0xffff,
        low_bits_mask(17) == 0x1ffff,
        low_bits_mask(18) == 0x3ffff,
        low_bits_mask(19) == 0x7ffff,
        low_bits_mask(20) == 0xfffff,
        low_bits_mask(21) == 0x1fffff,
        low_bits_mask(22) == 0x3fffff,
        low_bits_mask(23) == 0x7fffff,
        low_bits_mask(24) == 0xffffff,
        low_bits_mask(25) == 0x1ffffff,
        low_bits_mask(26) == 0x3ffffff,
        low_bits_mask(27) == 0x7ffffff,
        low_bits_mask(28) == 0xfffffff,
        low_bits_mask(29) == 0x1fffffff,
        low_bits_mask(30) == 0x3fffffff,
        low_bits_mask(31) == 0x7fffffff,
        low_bits_mask(32) == 0xffffffff,
        low_bits_mask(64) == 0xffffffffffffffff,
{
    // One conjunction proved in a single compute_only pass; verusfmt is told
    // to leave the one-fact-per-line layout alone.
    #[verusfmt::skip]
    assert(
        low_bits_mask(0) == 0x0 &&
        low_bits_mask(1) == 0x1 &&
        low_bits_mask(2) == 0x3 &&
        low_bits_mask(3) == 0x7 &&
        low_bits_mask(4) == 0xf &&
        low_bits_mask(5) == 0x1f &&
        low_bits_mask(6) == 0x3f &&
        low_bits_mask(7) == 0x7f &&
        low_bits_mask(8) == 0xff &&
        low_bits_mask(9) == 0x1ff &&
        low_bits_mask(10) == 0x3ff &&
        low_bits_mask(11) == 0x7ff &&
        low_bits_mask(12) == 0xfff &&
        low_bits_mask(13) == 0x1fff &&
        low_bits_mask(14) == 0x3fff &&
        low_bits_mask(15) == 0x7fff &&
        low_bits_mask(16) == 0xffff &&
        low_bits_mask(17) == 0x1ffff &&
        low_bits_mask(18) == 0x3ffff &&
        low_bits_mask(19) == 0x7ffff &&
        low_bits_mask(20) == 0xfffff &&
        low_bits_mask(21) == 0x1fffff &&
        low_bits_mask(22) == 0x3fffff &&
        low_bits_mask(23) == 0x7fffff &&
        low_bits_mask(24) == 0xffffff &&
        low_bits_mask(25) == 0x1ffffff &&
        low_bits_mask(26) == 0x3ffffff &&
        low_bits_mask(27) == 0x7ffffff &&
        low_bits_mask(28) == 0xfffffff &&
        low_bits_mask(29) == 0x1fffffff &&
        low_bits_mask(30) == 0x3fffffff &&
        low_bits_mask(31) == 0x7fffffff &&
        low_bits_mask(32) == 0xffffffff &&
        low_bits_mask(64) == 0xffffffffffffffff
    ) by (compute_only);
}

} // verus!
// Proofs that bitwise-and with a low-bits mask is equivalent to modulo a power of two.
//
// `lemma_low_bits_mask_is_mod!(name, and_split_low_bit, no_overflow, uN)` defines:
//   - `name(x, n)`: broadcast lemma `x & low_bits_mask(n) == x % 2^n` for `n`
//     below the bit width of `uN`;
//   - `and_split_low_bit(x, m)`: a private bit-vector helper used by `name`.
// `no_overflow` must be the `lemma_pow2_no_overflow!` instance for `uN`; it
// keeps the `as $uN` casts of pow2(n) / low_bits_mask(n) in range.
macro_rules! lemma_low_bits_mask_is_mod {
    ($name:ident, $and_split_low_bit:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for natural n and x of type "]
        #[doc = stringify!($uN)]
        #[doc = ", and with the low n-bit mask is equivalent to modulo 2^n."]
        pub broadcast proof fn $name(x: $uN, n: nat)
            requires
                n < <$uN>::BITS,
            ensures
                #[trigger] (x & (low_bits_mask(n) as $uN)) == x % (pow2(n) as $uN),
            decreases n,
        {
            // Bounds: 0 < 2^n < MAX, so the casts to $uN above do not truncate.
            $no_overflow(n);
            lemma_pow2_pos(n);

            // Inductive proof on n.
            if n == 0 {
                // Base case: mask 0 and modulus 1 both send x to 0.
                assert(low_bits_mask(0) == 0) by (compute_only);
                assert(x & 0 == 0) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
                assert(x % 1 == 0);
            } else {
                // pow2(n) == 2 * pow2(n - 1), used in the first calc step.
                lemma_pow2_unfold(n);
                // The low bit of x is unchanged by and-ing it with 1.
                assert((x % 2) == ((x % 2) & 1)) by (bit_vector);
                calc!{ (==)
                    x % (pow2(n) as $uN);
                        {}
                    x % ((2 * pow2((n-1) as nat)) as $uN);
                        {
                            // Split x % (2 * p) into 2 * ((x / 2) % p) + x % 2.
                            lemma_pow2_pos((n-1) as nat);
                            lemma_mod_breakdown(x as int, 2, pow2((n-1) as nat) as int);
                        }
                    add(mul(2, (x / 2) % (pow2((n-1) as nat) as $uN)), x % 2);
                        {
                            // Induction hypothesis on x / 2 with n - 1.
                            $name(x/2, (n-1) as nat);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask((n-1) as nat) as $uN)), x % 2);
                        {
                            // low_bits_mask(n - 1) == low_bits_mask(n) / 2.
                            lemma_low_bits_mask_div2(n);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), x % 2);
                        {
                            // low_bits_mask(n) % 2 == 1; combined with the bit-vector
                            // fact above, x % 2 == (x % 2) & (mask % 2).
                            lemma_low_bits_mask_is_odd(n);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), (x % 2) & ((low_bits_mask(n) as $uN) % 2));
                        {
                            // Reassemble the high part and the low bit into one `&`.
                            $and_split_low_bit(x as $uN, low_bits_mask(n) as $uN);
                        }
                    x & (low_bits_mask(n) as $uN);
                }
            }
        }

        // Helper lemma breaking a bitwise-and operation into the low bit and the rest:
        // x & m == ((x / 2) & (m / 2)) * 2 + ((x % 2) & (m % 2)), proved by bit_vector.
        proof fn $and_split_low_bit(x: $uN, m: $uN)
            by (bit_vector)
            ensures
                x & m == add(mul(((x / 2) & (m / 2)), 2), (x % 2) & (m % 2)),
        {
        }
        }
    };
}

lemma_low_bits_mask_is_mod!(
    lemma_u64_low_bits_mask_is_mod,
    lemma_u64_and_split_low_bit,
    lemma_u64_pow2_no_overflow,
    u64
);
lemma_low_bits_mask_is_mod!(
    lemma_u32_low_bits_mask_is_mod,
    lemma_u32_and_split_low_bit,
    lemma_u32_pow2_no_overflow,
    u32
);
lemma_low_bits_mask_is_mod!(
    lemma_u16_low_bits_mask_is_mod,
    lemma_u16_and_split_low_bit,
    lemma_u16_pow2_no_overflow,
    u16
);
lemma_low_bits_mask_is_mod!(
    lemma_u8_low_bits_mask_is_mod,
    lemma_u8_and_split_low_bit,
    lemma_u8_pow2_no_overflow,
    u8
);
lemma_low_bits_mask_is_mod!(
    lemma_usize_low_bits_mask_is_mod,
    lemma_usize_and_split_low_bit,
    lemma_usize_pow2_no_overflow,
    usize
);