Skip to main content

vstd/
bits.rs

1//! Properties of bitwise operators.
2use super::prelude::*;
3
4verus! {
5
6#[cfg(verus_keep_ghost)]
7use super::arithmetic::power::pow;
8#[cfg(verus_keep_ghost)]
9use super::arithmetic::power2::{
10    pow2,
11    lemma_pow2_unfold,
12    lemma_pow2_adds,
13    lemma_pow2_pos,
14    lemma2_to64,
15    lemma2_to64_rest,
16    lemma_pow2_strictly_increases,
17};
18#[cfg(verus_keep_ghost)]
19use super::arithmetic::div_mod::{
20    lemma_div_by_multiple,
21    lemma_div_denominator,
22    lemma_div_is_ordered,
23    lemma_mod_breakdown,
24    lemma_mod_multiples_vanish,
25    lemma_remainder_lower,
26};
27#[cfg(verus_keep_ghost)]
28use super::arithmetic::mul::{
29    lemma_mul_inequality,
30    lemma_mul_is_commutative,
31    lemma_mul_is_associative,
32};
33#[cfg(verus_keep_ghost)]
34use super::calc_macro::*;
35
36} // verus!
// Proofs that shift right is equivalent to division by a power of 2.
//
// Each `lemma_shr_is_div!` invocation below generates one broadcast lemma for
// the given unsigned type; the generated item is gated on
// `#[cfg(verus_keep_ghost)]`.
macro_rules! lemma_shr_is_div {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and n of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x right by n is equivalent to division of x by 2^n."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                #[trigger] (x >> shift) == x as nat / pow2(shift as nat),
            decreases shift,
        {
            // Step by 4 to reduce recursion depth (divisor 16 fits in all unsigned types).
            reveal(pow);
            // Base cases (shift < 4): each shift/division identity is a closed
            // bit-vector fact, and pow2 of a literal exponent is evaluated directly.
            if shift == 0 {
                assert(x >> 0 == x) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
            } else if shift == 1 {
                assert(x >> 1 == x / 2) by (bit_vector);
                assert(pow2(1) == 2) by (compute_only);
            } else if shift == 2 {
                assert(x >> 2 == x / 4) by (bit_vector);
                assert(pow2(2) == 4) by (compute_only);
            } else if shift == 3 {
                assert(x >> 3 == x / 8) by (bit_vector);
                assert(pow2(3) == 8) by (compute_only);
            } else {
                // Inductive step (shift >= 4): at the machine level, shifting by
                // `shift` is the same as shifting by `shift - 4` and dividing by 16.
                assert(x >> shift == (x >> (sub(shift, 4) as $uN)) / 16) by (bit_vector)
                    requires
                        4 <= shift < <$uN>::BITS,
                ;
                calc!{ (==)
                    (x >> shift) as nat;
                        {}
                    ((x >> (sub(shift, 4) as $uN)) / 16) as nat;
                        {
                            // Induction hypothesis at `shift - 4`.
                            $name(x, (shift - 4) as $uN);
                        }
                    (x as nat / pow2((shift - 4) as nat)) / 16;
                        {
                            // (x / a) / 16 == x / (a * 16) with a = 2^(shift - 4) > 0,
                            // and 16 == 2^4.
                            lemma_pow2_pos((shift - 4) as nat);
                            lemma2_to64();
                            assert(pow2(4) == 16) by (compute_only);
                            lemma_div_denominator(x as int, pow2((shift - 4) as nat) as int, 16);
                        }
                    x as nat / (pow2((shift - 4) as nat) * pow2(4));
                        {
                            // 2^(shift - 4) * 2^4 == 2^shift.
                            lemma_pow2_adds((shift - 4) as nat, 4);
                        }
                    x as nat / pow2(shift as nat);
                }
            }
        }
        }
    };
}

lemma_shr_is_div!(lemma_u128_shr_is_div, u128);
lemma_shr_is_div!(lemma_u64_shr_is_div, u64);
lemma_shr_is_div!(lemma_u32_shr_is_div, u32);
lemma_shr_is_div!(lemma_u16_shr_is_div, u16);
lemma_shr_is_div!(lemma_u8_shr_is_div, u8);
lemma_shr_is_div!(lemma_usize_shr_is_div, usize);
101
// Proofs of when a power of 2 fits in an unsigned type.
macro_rules! lemma_pow2_no_overflow {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that 2^n does not overflow "]
        #[doc = stringify!($uN)]
        #[doc = " for an exponent n smaller than the bit width of the type (0 < 2^n < MAX)."]
        pub broadcast proof fn $name(n: nat)
            requires
                0 <= n < <$uN>::BITS,
            ensures
                0 < #[trigger] pow2(n) < <$uN>::MAX,
        {
            // Lower bound: 2^n is positive.
            lemma_pow2_pos(n);
            // Upper bound: concrete values of pow2 for the exponents that can occur.
            // NOTE(review): these presumably cover exponents up to 64, which would
            // explain why there is no u128 instantiation below — confirm.
            lemma2_to64();
            lemma2_to64_rest();
        }
        }
    };
}

lemma_pow2_no_overflow!(lemma_u64_pow2_no_overflow, u64);
lemma_pow2_no_overflow!(lemma_u32_pow2_no_overflow, u32);
lemma_pow2_no_overflow!(lemma_u16_pow2_no_overflow, u16);
lemma_pow2_no_overflow!(lemma_u8_pow2_no_overflow, u8);
lemma_pow2_no_overflow!(lemma_usize_pow2_no_overflow, usize);
129
// Proofs that shift left is equivalent to multiplication by a power of 2.
macro_rules! lemma_shl_is_mul {
    ($name:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and n of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x left by n is equivalent to multiplication of x by 2^n (provided no overflow)."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
                x * pow2(shift as nat) <= <$uN>::MAX,
            ensures
                #[trigger] (x << shift) == x * pow2(shift as nat),
            decreases shift,
        {
            // 2^shift is bounded by the type's maximum.
            $no_overflow(shift as nat);
            if shift == 0 {
                // Base case: x << 0 == x and 2^0 == 1.
                assert(x << 0 == x) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Machine-level step: shifting by `shift` is shifting by `shift - 1`
                // and then multiplying by 2.
                assert(x << shift == mul(x << ((sub(shift, 1)) as $uN), 2)) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                // Induction hypothesis at `shift - 1`. Its overflow precondition
                // follows from 2^(shift - 1) < 2^shift, hence
                // x * 2^(shift - 1) <= x * 2^shift <= MAX.
                assert((x << (sub(shift, 1) as $uN)) == x * pow2(sub(shift, 1) as nat)) by {
                    lemma_pow2_strictly_increases((shift - 1) as nat, shift as nat);
                    lemma_mul_inequality(
                        pow2((shift - 1) as nat) as int,
                        pow2(shift as nat) as int,
                        x as int,
                    );
                    lemma_mul_is_commutative(x as int, pow2((shift - 1) as nat) as int);
                    lemma_mul_is_commutative(x as int, pow2(shift as nat) as int);
                    $name(x, (shift - 1) as $uN);
                }
                calc!{ (==)
                    ((x << (sub(shift, 1) as $uN)) * 2);
                        {}
                    ((x * pow2(sub(shift, 1) as nat)) * 2);
                        {
                            lemma_mul_is_associative(x as int, pow2(sub(shift, 1) as nat) as int, 2);
                        }
                    x * ((pow2(sub(shift, 1) as nat)) * 2);
                        {
                            // 2^(shift - 1) * 2^1 == 2^shift, with 2^1 == 2.
                            lemma_pow2_adds((shift - 1) as nat, 1);
                            lemma2_to64();
                        }
                    x * pow2(shift as nat);
                }
            }
        }
        }
    };
}

lemma_shl_is_mul!(lemma_u64_shl_is_mul, lemma_u64_pow2_no_overflow, u64);
lemma_shl_is_mul!(lemma_u32_shl_is_mul, lemma_u32_pow2_no_overflow, u32);
lemma_shl_is_mul!(lemma_u16_shl_is_mul, lemma_u16_pow2_no_overflow, u16);
lemma_shl_is_mul!(lemma_u8_shl_is_mul, lemma_u8_pow2_no_overflow, u8);
lemma_shl_is_mul!(lemma_usize_shl_is_mul, lemma_usize_pow2_no_overflow, usize);
191
// Proofs that x * 2^n <= max holds exactly when x <= (max >> n).
macro_rules! lemma_mul_pow2_le_max_iff_max_shr {
    ($name:ident, $shr_is_div:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x, n and max of type "]
        #[doc = stringify!($uN)]
        #[doc = ", multiplication of x by 2^n is less than or equal to max if and only if x is less than or equal to shifting max right by n."]
        pub proof fn $name(x: $uN, shift: $uN, max: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                x * pow2(shift as nat) <= max <==> x <= (max >> shift),
        {
            // Replace the shift by the equivalent division: max >> shift == max / 2^shift.
            assert(max >> shift == max as nat / pow2(shift as nat)) by {
                $shr_is_div(max, shift);
            };

            // 2^shift > 0, needed by the division lemmas below.
            lemma_pow2_pos(shift as nat);

            // (==>) Dividing both sides by 2^shift preserves <=, and
            // (x * 2^shift) / 2^shift == x.
            if x * pow2(shift as nat) <= max {
                assert(x <= (max as nat) / pow2(shift as nat)) by {
                    lemma_div_is_ordered(x as int * pow2(shift as nat) as int, max as int, pow2(shift as nat) as int);
                    lemma_div_by_multiple(x as int, pow2(shift as nat) as int);
                };
            }
            // (<==) Multiplying both sides by 2^shift preserves <=, and
            // (max / 2^shift) * 2^shift <= max.
            if x <= (max >> shift) {
                assert(x * pow2(shift as nat) <= max as nat) by {
                    lemma_mul_inequality(x as int, max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                    lemma_remainder_lower(max as int, pow2(shift as nat) as int);
                    lemma_mul_is_commutative(max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                };
            }
        }
        }
    };
}

lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u64_mul_pow2_le_max_iff_max_shr,
    lemma_u64_shr_is_div,
    u64
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u32_mul_pow2_le_max_iff_max_shr,
    lemma_u32_shr_is_div,
    u32
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u16_mul_pow2_le_max_iff_max_shr,
    lemma_u16_shr_is_div,
    u16
);
lemma_mul_pow2_le_max_iff_max_shr!(lemma_u8_mul_pow2_le_max_iff_max_shr, lemma_u8_shr_is_div, u8);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_usize_mul_pow2_le_max_iff_max_shr,
    lemma_usize_shr_is_div,
    usize
);
250
251verus! {
252
253/// Mask with low n bits set.
254pub open spec fn low_bits_mask(n: nat) -> nat {
255    (pow2(n) - 1) as nat
256}
257
258/// Proof relating the n-bit mask to a function of the (n-1)-bit mask.
259pub broadcast proof fn lemma_low_bits_mask_unfold(n: nat)
260    requires
261        n > 0,
262    ensures
263        #[trigger] low_bits_mask(n) == 2 * low_bits_mask((n - 1) as nat) + 1,
264{
265    calc! {
266        (==)
267        low_bits_mask(n); {}
268        (pow2(n) - 1) as nat; {
269            lemma_pow2_unfold(n);
270        }
271        (2 * pow2((n - 1) as nat) - 1) as nat; {}
272        (2 * (pow2((n - 1) as nat) - 1) + 1) as nat; {
273            lemma_pow2_pos((n - 1) as nat);
274        }
275        (2 * low_bits_mask((n - 1) as nat) + 1) as nat;
276    }
277}
278
279/// Proof that low_bits_mask(n) is odd.
280pub broadcast proof fn lemma_low_bits_mask_is_odd(n: nat)
281    requires
282        n > 0,
283    ensures
284        #[trigger] (low_bits_mask(n) % 2) == 1,
285{
286    calc! {
287        (==)
288        low_bits_mask(n) % 2; {
289            lemma_low_bits_mask_unfold(n);
290        }
291        (2 * low_bits_mask((n - 1) as nat) + 1) % 2; {
292            lemma_mod_multiples_vanish(low_bits_mask((n - 1) as nat) as int, 1, 2);
293        }
294        1nat % 2;
295    }
296}
297
298/// Proof that dividing the low n bit mask by 2 gives the low n-1 bit mask.
299pub broadcast proof fn lemma_low_bits_mask_div2(n: nat)
300    requires
301        n > 0,
302    ensures
303        #[trigger] (low_bits_mask(n) / 2) == low_bits_mask((n - 1) as nat),
304{
305    lemma_low_bits_mask_unfold(n);
306}
307
308/// Proof establishing the concrete values of all masks of bit sizes from 0 to
309/// 32, and 64.
310pub proof fn lemma_low_bits_mask_values()
311    ensures
312        low_bits_mask(0) == 0x0,
313        low_bits_mask(1) == 0x1,
314        low_bits_mask(2) == 0x3,
315        low_bits_mask(3) == 0x7,
316        low_bits_mask(4) == 0xf,
317        low_bits_mask(5) == 0x1f,
318        low_bits_mask(6) == 0x3f,
319        low_bits_mask(7) == 0x7f,
320        low_bits_mask(8) == 0xff,
321        low_bits_mask(9) == 0x1ff,
322        low_bits_mask(10) == 0x3ff,
323        low_bits_mask(11) == 0x7ff,
324        low_bits_mask(12) == 0xfff,
325        low_bits_mask(13) == 0x1fff,
326        low_bits_mask(14) == 0x3fff,
327        low_bits_mask(15) == 0x7fff,
328        low_bits_mask(16) == 0xffff,
329        low_bits_mask(17) == 0x1ffff,
330        low_bits_mask(18) == 0x3ffff,
331        low_bits_mask(19) == 0x7ffff,
332        low_bits_mask(20) == 0xfffff,
333        low_bits_mask(21) == 0x1fffff,
334        low_bits_mask(22) == 0x3fffff,
335        low_bits_mask(23) == 0x7fffff,
336        low_bits_mask(24) == 0xffffff,
337        low_bits_mask(25) == 0x1ffffff,
338        low_bits_mask(26) == 0x3ffffff,
339        low_bits_mask(27) == 0x7ffffff,
340        low_bits_mask(28) == 0xfffffff,
341        low_bits_mask(29) == 0x1fffffff,
342        low_bits_mask(30) == 0x3fffffff,
343        low_bits_mask(31) == 0x7fffffff,
344        low_bits_mask(32) == 0xffffffff,
345        low_bits_mask(64) == 0xffffffffffffffff,
346{
347    #[verusfmt::skip]
348    assert(
349        low_bits_mask(0) == 0x0 &&
350        low_bits_mask(1) == 0x1 &&
351        low_bits_mask(2) == 0x3 &&
352        low_bits_mask(3) == 0x7 &&
353        low_bits_mask(4) == 0xf &&
354        low_bits_mask(5) == 0x1f &&
355        low_bits_mask(6) == 0x3f &&
356        low_bits_mask(7) == 0x7f &&
357        low_bits_mask(8) == 0xff &&
358        low_bits_mask(9) == 0x1ff &&
359        low_bits_mask(10) == 0x3ff &&
360        low_bits_mask(11) == 0x7ff &&
361        low_bits_mask(12) == 0xfff &&
362        low_bits_mask(13) == 0x1fff &&
363        low_bits_mask(14) == 0x3fff &&
364        low_bits_mask(15) == 0x7fff &&
365        low_bits_mask(16) == 0xffff &&
366        low_bits_mask(17) == 0x1ffff &&
367        low_bits_mask(18) == 0x3ffff &&
368        low_bits_mask(19) == 0x7ffff &&
369        low_bits_mask(20) == 0xfffff &&
370        low_bits_mask(21) == 0x1fffff &&
371        low_bits_mask(22) == 0x3fffff &&
372        low_bits_mask(23) == 0x7fffff &&
373        low_bits_mask(24) == 0xffffff &&
374        low_bits_mask(25) == 0x1ffffff &&
375        low_bits_mask(26) == 0x3ffffff &&
376        low_bits_mask(27) == 0x7ffffff &&
377        low_bits_mask(28) == 0xfffffff &&
378        low_bits_mask(29) == 0x1fffffff &&
379        low_bits_mask(30) == 0x3fffffff &&
380        low_bits_mask(31) == 0x7fffffff &&
381        low_bits_mask(32) == 0xffffffff &&
382        low_bits_mask(64) == 0xffffffffffffffff
383    ) by (compute_only);
384}
385
386} // verus!
// Proofs that bitwise-and with a low-bits mask is equivalent to modulo by a power of two.
macro_rules! lemma_low_bits_mask_is_mod {
    ($name:ident, $and_split_low_bit:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for natural n and x of type "]
        #[doc = stringify!($uN)]
        #[doc = ", bitwise-and with the low n-bit mask is equivalent to modulo 2^n (for n below the bit width)."]
        pub broadcast proof fn $name(x: $uN, n: nat)
            requires
                n < <$uN>::BITS,
            ensures
                #[trigger] (x & (low_bits_mask(n) as $uN)) == x % (pow2(n) as $uN),
            decreases n,
        {
            // Bounds: 0 < 2^n < MAX, so the `as $uN` casts of 2^n and of the
            // mask below are exact.
            $no_overflow(n);
            lemma_pow2_pos(n);

            // Inductive proof.
            if n == 0 {
                // Base case: mask(0) == 0 and 2^0 == 1, so both sides are 0.
                assert(low_bits_mask(0) == 0) by (compute_only);
                assert(x & 0 == 0) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
                assert(x % 1 == 0);
            } else {
                // 2^n == 2 * 2^(n - 1).
                lemma_pow2_unfold(n);
                // Masking the low bit of x with 1 leaves it unchanged.
                assert((x % 2) == ((x % 2) & 1)) by (bit_vector);
                calc!{ (==)
                    x % (pow2(n) as $uN);
                        {}
                    x % ((2 * pow2((n-1) as nat)) as $uN);
                        {
                            // x % (2 * c) == 2 * ((x / 2) % c) + x % 2.
                            lemma_pow2_pos((n-1) as nat);
                            lemma_mod_breakdown(x as int, 2, pow2((n-1) as nat) as int);
                        }
                    add(mul(2, (x / 2) % (pow2((n-1) as nat) as $uN)), x % 2);
                        {
                            // Induction hypothesis for x / 2 at n - 1.
                            $name(x/2, (n-1) as nat);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask((n-1) as nat) as $uN)), x % 2);
                        {
                            // low_bits_mask(n - 1) == low_bits_mask(n) / 2.
                            lemma_low_bits_mask_div2(n);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), x % 2);
                        {
                            // low_bits_mask(n) % 2 == 1, combined with
                            // (x % 2) & 1 == x % 2 from above.
                            lemma_low_bits_mask_is_odd(n);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), (x % 2) & ((low_bits_mask(n) as $uN) % 2));
                        {
                            // Reassemble the high bits and the low bit into one `&`.
                            $and_split_low_bit(x, low_bits_mask(n) as $uN);
                        }
                    x & (low_bits_mask(n) as $uN);
                }
            }
        }

        // Helper lemma breaking a bitwise-and operation into the low bit and the rest:
        // x & m == 2 * ((x / 2) & (m / 2)) + ((x % 2) & (m % 2)), proved in bit-vector mode.
        proof fn $and_split_low_bit(x: $uN, m: $uN)
            by (bit_vector)
            ensures
                x & m == add(mul(((x / 2) & (m / 2)), 2), (x % 2) & (m % 2)),
        {
        }
        }
    };
}

lemma_low_bits_mask_is_mod!(
    lemma_u64_low_bits_mask_is_mod,
    lemma_u64_and_split_low_bit,
    lemma_u64_pow2_no_overflow,
    u64
);
lemma_low_bits_mask_is_mod!(
    lemma_u32_low_bits_mask_is_mod,
    lemma_u32_and_split_low_bit,
    lemma_u32_pow2_no_overflow,
    u32
);
lemma_low_bits_mask_is_mod!(
    lemma_u16_low_bits_mask_is_mod,
    lemma_u16_and_split_low_bit,
    lemma_u16_pow2_no_overflow,
    u16
);
lemma_low_bits_mask_is_mod!(
    lemma_u8_low_bits_mask_is_mod,
    lemma_u8_and_split_low_bit,
    lemma_u8_pow2_no_overflow,
    u8
);
lemma_low_bits_mask_is_mod!(
    lemma_usize_low_bits_mask_is_mod,
    lemma_usize_and_split_low_bit,
    lemma_usize_pow2_no_overflow,
    usize
);