1use super::prelude::*;
3
4verus! {
5
6#[cfg(verus_keep_ghost)]
7use super::arithmetic::power::pow;
8#[cfg(verus_keep_ghost)]
9use super::arithmetic::power2::{
10 pow2,
11 lemma_pow2_unfold,
12 lemma_pow2_adds,
13 lemma_pow2_pos,
14 lemma2_to64,
15 lemma2_to64_rest,
16 lemma_pow2_strictly_increases,
17};
18#[cfg(verus_keep_ghost)]
19use super::arithmetic::div_mod::{
20 lemma_div_by_multiple,
21 lemma_div_denominator,
22 lemma_div_is_ordered,
23 lemma_mod_breakdown,
24 lemma_mod_multiples_vanish,
25 lemma_remainder_lower,
26};
27#[cfg(verus_keep_ghost)]
28use super::arithmetic::mul::{
29 lemma_mul_inequality,
30 lemma_mul_is_commutative,
31 lemma_mul_is_associative,
32};
33#[cfg(verus_keep_ghost)]
34use super::calc_macro::*;
35
} // verus!

// Generates a broadcast proof function `$name` stating that, for `x` and
// `shift` of the integer type `$uN`, `x >> shift` equals
// `x as nat / pow2(shift as nat)`.
//
// Macro parameters:
// - `$name`: identifier of the generated proof function.
// - `$uN`:   integer type the lemma is stated for (instantiated below for
//            `u64`, `u32`, `u16` and `u8`).
//
// Precondition of the generated lemma: `0 <= shift < <$uN>::BITS`.
//
// The proof is by induction on `shift` (`decreases shift`): the base case is
// `shift == 0`, and the inductive step peels off one bit of shift as a
// division by 2 and then recombines the two divisors into `pow2(shift)`.
macro_rules! lemma_shr_is_div {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and n of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x right by n is equivalent to division of x by 2^n."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                #[trigger] (x >> shift) == x as nat / pow2(shift as nat),
            decreases shift,
        {
            reveal(pow2);
            if shift == 0 {
                // Base case: a shift by zero leaves `x` unchanged, and the
                // divisor `pow2(0)` is 1.
                assert(x >> 0 == x) by (bit_vector);
                reveal(pow);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Bit-vector fact: shifting by `shift` is shifting by
                // `shift - 1` followed by one halving.
                assert(x >> shift == (x >> ((sub(shift, 1)) as $uN)) / 2) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                calc!{ (==)
                    (x >> shift) as nat;
                        {}
                    ((x >> ((sub(shift, 1)) as $uN)) / 2) as nat;
                        // Induction hypothesis for `shift - 1`.
                        { $name(x, (shift - 1) as $uN); }
                    (x as nat / pow2((shift - 1) as nat)) / 2;
                        {
                            // Fold the nested division into a single division
                            // by the product of the two divisors (2 is written
                            // as `pow2(1)` in the next line of the chain).
                            lemma_pow2_pos((shift - 1) as nat);
                            lemma2_to64();
                            lemma_div_denominator(x as int, pow2((shift - 1) as nat) as int, 2);
                        }
                    x as nat / (pow2((shift - 1) as nat) * pow2(1));
                        {
                            // pow2(shift - 1) * pow2(1) == pow2(shift).
                            lemma_pow2_adds((shift - 1) as nat, 1);
                        }
                    x as nat / pow2(shift as nat);
                }
            }
        }
        }
    };
}

lemma_shr_is_div!(lemma_u64_shr_is_div, u64);
lemma_shr_is_div!(lemma_u32_shr_is_div, u32);
lemma_shr_is_div!(lemma_u16_shr_is_div, u16);
lemma_shr_is_div!(lemma_u8_shr_is_div, u8);
89
// Generates a broadcast proof function `$name` stating that `pow2(n)` is
// strictly positive and strictly below `<$uN>::MAX` whenever
// `0 <= n < <$uN>::BITS`.
//
// Macro parameters:
// - `$name`: identifier of the generated proof function.
// - `$uN`:   integer type whose `BITS` / `MAX` bound the statement
//            (instantiated below for `u64`, `u32`, `u16` and `u8`).
//
// The body only invokes helper lemmas from `arithmetic::power2`.
// NOTE(review): `lemma2_to64` / `lemma2_to64_rest` presumably tabulate the
// concrete values of `pow2` for exponents up to 64 -- confirm in that module.
macro_rules! lemma_pow2_no_overflow {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that 2^n does not overflow "]
        #[doc = stringify!($uN)]
        #[doc = " for an exponent n."]
        pub broadcast proof fn $name(n: nat)
            requires
                0 <= n < <$uN>::BITS,
            ensures
                0 < #[trigger] pow2(n) < <$uN>::MAX,
        {
            lemma_pow2_pos(n);
            lemma2_to64();
            lemma2_to64_rest();
        }
        }
    };
}

lemma_pow2_no_overflow!(lemma_u64_pow2_no_overflow, u64);
lemma_pow2_no_overflow!(lemma_u32_pow2_no_overflow, u32);
lemma_pow2_no_overflow!(lemma_u16_pow2_no_overflow, u16);
lemma_pow2_no_overflow!(lemma_u8_pow2_no_overflow, u8);
116
// Generates a broadcast proof function `$name` stating that, for `x` and
// `shift` of the integer type `$uN`, `x << shift` equals
// `x * pow2(shift as nat)`, provided the product fits in `$uN`.
//
// Macro parameters:
// - `$name`:        identifier of the generated proof function.
// - `$no_overflow`: name of the lemma bounding `pow2(n)` for `$uN`; it is
//                   called with `shift as nat` at the start of the proof.
// - `$uN`:          integer type the lemma is stated for (instantiated below
//                   for `u64`, `u32`, `u16` and `u8`).
//
// Preconditions of the generated lemma:
// - `0 <= shift < <$uN>::BITS`
// - `x * pow2(shift as nat) <= <$uN>::MAX`
//
// The proof is by induction on `shift` (`decreases shift`).
macro_rules! lemma_shl_is_mul {
    ($name:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and n of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x left by n is equivalent to multiplication of x by 2^n (provided no overflow)."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
                x * pow2(shift as nat) <= <$uN>::MAX,
            ensures
                #[trigger] (x << shift) == x * pow2(shift as nat),
            decreases shift,
        {
            $no_overflow(shift as nat);
            if shift == 0 {
                // Base case: a shift by zero leaves `x` unchanged, and the
                // multiplier `pow2(0)` is 1.
                assert(x << 0 == x) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Bit-vector fact: shifting by `shift` is shifting by
                // `shift - 1` followed by one (wrapping) doubling.
                assert(x << shift == mul(x << ((sub(shift, 1)) as $uN), 2)) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                // Induction hypothesis for `shift - 1`. The calls before the
                // recursive invocation appear to establish its no-overflow
                // precondition via
                // x * pow2(shift - 1) <= x * pow2(shift) <= MAX.
                assert((x << (sub(shift, 1) as $uN)) == x * pow2(sub(shift, 1) as nat)) by {
                    lemma_pow2_strictly_increases((shift - 1) as nat, shift as nat);
                    lemma_mul_inequality(
                        pow2((shift - 1) as nat) as int,
                        pow2(shift as nat) as int,
                        x as int,
                    );
                    lemma_mul_is_commutative(x as int, pow2((shift - 1) as nat) as int);
                    lemma_mul_is_commutative(x as int, pow2(shift as nat) as int);
                    $name(x, (shift - 1) as $uN);
                }
                calc!{ (==)
                    ((x << (sub(shift, 1) as $uN)) * 2);
                        {}
                    ((x * pow2(sub(shift, 1) as nat)) * 2);
                        {
                            // Re-associate so the factor 2 sits next to pow2.
                            lemma_mul_is_associative(x as int, pow2(sub(shift, 1) as nat) as int, 2);
                        }
                    x * ((pow2(sub(shift, 1) as nat)) * 2);
                        {
                            // pow2(shift - 1) * 2 == pow2(shift).
                            lemma_pow2_adds((shift - 1) as nat, 1);
                            lemma2_to64();
                        }
                    x * pow2(shift as nat);
                }
            }
        }
        }
    };
}

lemma_shl_is_mul!(lemma_u64_shl_is_mul, lemma_u64_pow2_no_overflow, u64);
lemma_shl_is_mul!(lemma_u32_shl_is_mul, lemma_u32_pow2_no_overflow, u32);
lemma_shl_is_mul!(lemma_u16_shl_is_mul, lemma_u16_pow2_no_overflow, u16);
lemma_shl_is_mul!(lemma_u8_shl_is_mul, lemma_u8_pow2_no_overflow, u8);
177
// Generates a proof function `$name` stating that, for `x`, `shift` and `max`
// of the integer type `$uN`,
//     x * pow2(shift) <= max   <==>   x <= (max >> shift).
//
// Macro parameters:
// - `$name`:       identifier of the generated proof function.
// - `$shr_is_div`: name of the "shift right is division" lemma for `$uN`,
//                  used to rewrite `max >> shift` as `max / pow2(shift)`.
// - `$uN`:         integer type the lemma is stated for (instantiated below
//                  for `u64`, `u32`, `u16` and `u8`).
//
// Precondition of the generated lemma: `0 <= shift < <$uN>::BITS`.
//
// Each direction of the equivalence is proved in its own `if` block.
macro_rules! lemma_mul_pow2_le_max_iff_max_shr {
    ($name:ident, $shr_is_div:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x, n and max of type "]
        #[doc = stringify!($uN)]
        #[doc = ", multiplication of x by 2^n is less than or equal to max if and only if x is less than or equal to shifting max right by n."]
        pub proof fn $name(x: $uN, shift: $uN, max: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                x * pow2(shift as nat) <= max <==> x <= (max >> shift),
        {
            // Rewrite the shift on the right-hand side as a division.
            assert(max >> shift == max as nat / pow2(shift as nat)) by {
                $shr_is_div(max, shift as $uN);
            };

            lemma_pow2_pos(shift as nat);

            // Forward direction: divide both sides of the assumed inequality
            // by pow2(shift).
            if x * pow2(shift as nat) <= max {
                assert(x <= (max as nat) / pow2(shift as nat)) by {
                    lemma_div_is_ordered(x as int * pow2(shift as nat) as int, max as int, pow2(shift as nat) as int);
                    lemma_div_by_multiple(x as int, pow2(shift as nat) as int);
                };
            }
            // Backward direction: multiply both sides of the assumed
            // inequality by pow2(shift), then bound
            // (max / pow2(shift)) * pow2(shift) by max.
            if x <= (max >> shift) {
                assert(x * pow2(shift as nat) <= max as nat) by {
                    lemma_mul_inequality(x as int, max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                    lemma_remainder_lower(max as int, pow2(shift as nat) as int);
                    lemma_mul_is_commutative(max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                };
            }
        }
        }
    };
}

lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u64_mul_pow2_le_max_iff_max_shr,
    lemma_u64_shr_is_div,
    u64
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u32_mul_pow2_le_max_iff_max_shr,
    lemma_u32_shr_is_div,
    u32
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u16_mul_pow2_le_max_iff_max_shr,
    lemma_u16_shr_is_div,
    u16
);
lemma_mul_pow2_le_max_iff_max_shr!(lemma_u8_mul_pow2_le_max_iff_max_shr, lemma_u8_shr_is_div, u8);
231
232verus! {
233
234pub open spec fn low_bits_mask(n: nat) -> nat {
236 (pow2(n) - 1) as nat
237}
238
239pub broadcast proof fn lemma_low_bits_mask_unfold(n: nat)
241 requires
242 n > 0,
243 ensures
244 #[trigger] low_bits_mask(n) == 2 * low_bits_mask((n - 1) as nat) + 1,
245{
246 calc! {
247 (==)
248 low_bits_mask(n); {}
249 (pow2(n) - 1) as nat; {
250 lemma_pow2_unfold(n);
251 }
252 (2 * pow2((n - 1) as nat) - 1) as nat; {}
253 (2 * (pow2((n - 1) as nat) - 1) + 1) as nat; {
254 lemma_pow2_pos((n - 1) as nat);
255 }
256 (2 * low_bits_mask((n - 1) as nat) + 1) as nat;
257 }
258}
259
260pub broadcast proof fn lemma_low_bits_mask_is_odd(n: nat)
262 requires
263 n > 0,
264 ensures
265 #[trigger] (low_bits_mask(n) % 2) == 1,
266{
267 calc! {
268 (==)
269 low_bits_mask(n) % 2; {
270 lemma_low_bits_mask_unfold(n);
271 }
272 (2 * low_bits_mask((n - 1) as nat) + 1) % 2; {
273 lemma_mod_multiples_vanish(low_bits_mask((n - 1) as nat) as int, 1, 2);
274 }
275 1nat % 2;
276 }
277}
278
279pub broadcast proof fn lemma_low_bits_mask_div2(n: nat)
281 requires
282 n > 0,
283 ensures
284 #[trigger] (low_bits_mask(n) / 2) == low_bits_mask((n - 1) as nat),
285{
286 lemma_low_bits_mask_unfold(n);
287}
288
289pub proof fn lemma_low_bits_mask_values()
292 ensures
293 low_bits_mask(0) == 0x0,
294 low_bits_mask(1) == 0x1,
295 low_bits_mask(2) == 0x3,
296 low_bits_mask(3) == 0x7,
297 low_bits_mask(4) == 0xf,
298 low_bits_mask(5) == 0x1f,
299 low_bits_mask(6) == 0x3f,
300 low_bits_mask(7) == 0x7f,
301 low_bits_mask(8) == 0xff,
302 low_bits_mask(9) == 0x1ff,
303 low_bits_mask(10) == 0x3ff,
304 low_bits_mask(11) == 0x7ff,
305 low_bits_mask(12) == 0xfff,
306 low_bits_mask(13) == 0x1fff,
307 low_bits_mask(14) == 0x3fff,
308 low_bits_mask(15) == 0x7fff,
309 low_bits_mask(16) == 0xffff,
310 low_bits_mask(17) == 0x1ffff,
311 low_bits_mask(18) == 0x3ffff,
312 low_bits_mask(19) == 0x7ffff,
313 low_bits_mask(20) == 0xfffff,
314 low_bits_mask(21) == 0x1fffff,
315 low_bits_mask(22) == 0x3fffff,
316 low_bits_mask(23) == 0x7fffff,
317 low_bits_mask(24) == 0xffffff,
318 low_bits_mask(25) == 0x1ffffff,
319 low_bits_mask(26) == 0x3ffffff,
320 low_bits_mask(27) == 0x7ffffff,
321 low_bits_mask(28) == 0xfffffff,
322 low_bits_mask(29) == 0x1fffffff,
323 low_bits_mask(30) == 0x3fffffff,
324 low_bits_mask(31) == 0x7fffffff,
325 low_bits_mask(32) == 0xffffffff,
326 low_bits_mask(64) == 0xffffffffffffffff,
327{
328 reveal(pow2);
329 #[verusfmt::skip]
330 assert(
331 low_bits_mask(0) == 0x0 &&
332 low_bits_mask(1) == 0x1 &&
333 low_bits_mask(2) == 0x3 &&
334 low_bits_mask(3) == 0x7 &&
335 low_bits_mask(4) == 0xf &&
336 low_bits_mask(5) == 0x1f &&
337 low_bits_mask(6) == 0x3f &&
338 low_bits_mask(7) == 0x7f &&
339 low_bits_mask(8) == 0xff &&
340 low_bits_mask(9) == 0x1ff &&
341 low_bits_mask(10) == 0x3ff &&
342 low_bits_mask(11) == 0x7ff &&
343 low_bits_mask(12) == 0xfff &&
344 low_bits_mask(13) == 0x1fff &&
345 low_bits_mask(14) == 0x3fff &&
346 low_bits_mask(15) == 0x7fff &&
347 low_bits_mask(16) == 0xffff &&
348 low_bits_mask(17) == 0x1ffff &&
349 low_bits_mask(18) == 0x3ffff &&
350 low_bits_mask(19) == 0x7ffff &&
351 low_bits_mask(20) == 0xfffff &&
352 low_bits_mask(21) == 0x1fffff &&
353 low_bits_mask(22) == 0x3fffff &&
354 low_bits_mask(23) == 0x7fffff &&
355 low_bits_mask(24) == 0xffffff &&
356 low_bits_mask(25) == 0x1ffffff &&
357 low_bits_mask(26) == 0x3ffffff &&
358 low_bits_mask(27) == 0x7ffffff &&
359 low_bits_mask(28) == 0xfffffff &&
360 low_bits_mask(29) == 0x1fffffff &&
361 low_bits_mask(30) == 0x3fffffff &&
362 low_bits_mask(31) == 0x7fffffff &&
363 low_bits_mask(32) == 0xffffffff &&
364 low_bits_mask(64) == 0xffffffffffffffff
365 ) by (compute_only);
366}
367
} // verus!

// Generates two proof functions for the integer type `$uN`:
//
// - `$name` (public, broadcast): for `x: $uN` and `n < <$uN>::BITS`,
//   `x & (low_bits_mask(n) as $uN)` equals `x % (pow2(n) as $uN)`.
// - `$and_split_low_bit` (private helper): a bitwise AND can be split into
//   the AND of the halves (`x / 2`, `m / 2`), doubled, plus the AND of the
//   low bits (`x % 2`, `m % 2`).
//
// Macro parameters:
// - `$name`:              identifier of the generated main lemma.
// - `$and_split_low_bit`: identifier of the generated helper lemma.
// - `$no_overflow`:       name of the lemma bounding `pow2(n)` for `$uN`; it
//                         is called with `n` at the start of the proof.
// - `$uN`:                integer type the lemmas are stated for
//                         (instantiated below for `u64`, `u32`, `u16`, `u8`).
//
// The main proof is by induction on `n` (`decreases n`).
macro_rules! lemma_low_bits_mask_is_mod {
    ($name:ident, $and_split_low_bit:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for natural n and x of type "]
        #[doc = stringify!($uN)]
        #[doc = ", and with the low n-bit mask is equivalent to modulo 2^n."]
        pub broadcast proof fn $name(x: $uN, n: nat)
            requires
                n < <$uN>::BITS,
            ensures
                #[trigger] (x & (low_bits_mask(n) as $uN)) == x % (pow2(n) as $uN),
            decreases n,
        {
            // Bounds on pow2(n): positive and below <$uN>::MAX.
            $no_overflow(n);
            lemma_pow2_pos(n);

            // Inductive proof.
            if n == 0 {
                // Base case: the mask is 0 and the modulus is 1, so both
                // sides are 0.
                assert(low_bits_mask(0) == 0) by (compute_only);
                assert(x & 0 == 0) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
                assert(x % 1 == 0);
            } else {
                lemma_pow2_unfold(n);
                assert((x % 2) == ((x % 2) & 1)) by (bit_vector);
                calc!{ (==)
                    x % (pow2(n) as $uN);
                        {}
                    x % ((2 * pow2((n-1) as nat)) as $uN);
                        {
                            // Split the remainder modulo 2 * pow2(n - 1) into
                            // a high part and the low bit.
                            lemma_pow2_pos((n-1) as nat);
                            lemma_mod_breakdown(x as int, 2, pow2((n-1) as nat) as int);
                        }
                    add(mul(2, (x / 2) % (pow2((n-1) as nat) as $uN)), x % 2);
                        {
                            // Induction hypothesis applied to x / 2 and n - 1.
                            $name(x/2, (n-1) as nat);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask((n-1) as nat) as $uN)), x % 2);
                        {
                            // low_bits_mask(n - 1) == low_bits_mask(n) / 2.
                            lemma_low_bits_mask_div2(n);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), x % 2);
                        {
                            // The mask is odd, so AND-ing the low bit of x
                            // with the low bit of the mask keeps it unchanged.
                            lemma_low_bits_mask_is_odd(n);
                        }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), (x % 2) & ((low_bits_mask(n) as $uN) % 2));
                        {
                            // Recombine the high part and the low bit into a
                            // single AND.
                            $and_split_low_bit(x as $uN, low_bits_mask(n) as $uN);
                        }
                    x & (low_bits_mask(n) as $uN);
                }
            }
        }

        // Helper, discharged entirely by the bit-vector solver: splits
        // `x & m` into the AND of the halves (doubled) plus the AND of the
        // low bits.
        proof fn $and_split_low_bit(x: $uN, m: $uN)
            by (bit_vector)
            ensures
                x & m == add(mul(((x / 2) & (m / 2)), 2), (x % 2) & (m % 2)),
        {
        }
        }
    };
}

lemma_low_bits_mask_is_mod!(
    lemma_u64_low_bits_mask_is_mod,
    lemma_u64_and_split_low_bit,
    lemma_u64_pow2_no_overflow,
    u64
);
lemma_low_bits_mask_is_mod!(
    lemma_u32_low_bits_mask_is_mod,
    lemma_u32_and_split_low_bit,
    lemma_u32_pow2_no_overflow,
    u32
);
lemma_low_bits_mask_is_mod!(
    lemma_u16_low_bits_mask_is_mod,
    lemma_u16_and_split_low_bit,
    lemma_u16_pow2_no_overflow,
    u16
);
lemma_low_bits_mask_is_mod!(
    lemma_u8_low_bits_mask_is_mod,
    lemma_u8_and_split_low_bit,
    lemma_u8_pow2_no_overflow,
    u8
);