1use super::prelude::*;
3
4verus! {
5
6#[cfg(verus_keep_ghost)]
7use super::arithmetic::power::pow;
8#[cfg(verus_keep_ghost)]
9use super::arithmetic::power2::{
10 pow2,
11 lemma_pow2_unfold,
12 lemma_pow2_adds,
13 lemma_pow2_pos,
14 lemma2_to64,
15 lemma2_to64_rest,
16 lemma_pow2_strictly_increases,
17};
18#[cfg(verus_keep_ghost)]
19use super::arithmetic::div_mod::{
20 lemma_div_by_multiple,
21 lemma_div_denominator,
22 lemma_div_is_ordered,
23 lemma_mod_breakdown,
24 lemma_mod_multiples_vanish,
25 lemma_remainder_lower,
26};
27#[cfg(verus_keep_ghost)]
28use super::arithmetic::mul::{
29 lemma_mul_inequality,
30 lemma_mul_is_commutative,
31 lemma_mul_is_associative,
32};
33#[cfg(verus_keep_ghost)]
34use super::calc_macro::*;
35
} // verus!

// Generates, for the unsigned integer type `$uN`, a broadcast lemma `$name`
// proving that a right shift `x >> shift` equals `x / pow2(shift)` for every
// shift strictly below the bit width of `$uN`.
macro_rules! lemma_shr_is_div {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and n of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x right by n is equivalent to division of x by 2^n."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                #[trigger] (x >> shift) == x as nat / pow2(shift as nat),
            decreases shift,
        {
            if shift == 0 {
                // Base case: x >> 0 == x and pow2(0) == 1.
                assert(x >> 0 == x) by (bit_vector);
                // NOTE(review): lemma_shl_is_mul's base case proves the same
                // pow2(0) == 1 fact by compute_only without this reveal; confirm
                // whether it is still required here.
                reveal(pow);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Inductive step: shifting right by `shift` is the same as
                // shifting right by `shift - 1` and then halving.
                assert(x >> shift == (x >> ((sub(shift, 1)) as $uN)) / 2) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                calc!{ (==)
                    (x >> shift) as nat;
                    {}
                    ((x >> ((sub(shift, 1)) as $uN)) / 2) as nat;
                    // Induction hypothesis at `shift - 1`.
                    { $name(x, (shift - 1) as $uN); }
                    (x as nat / pow2((shift - 1) as nat)) / 2;
                    {
                        // (x / p) / 2 == x / (p * 2) needs p > 0; lemma2_to64
                        // presumably supplies pow2(1) == 2 to match the `/ 2`.
                        lemma_pow2_pos((shift - 1) as nat);
                        lemma2_to64();
                        lemma_div_denominator(x as int, pow2((shift - 1) as nat) as int, 2);
                    }
                    x as nat / (pow2((shift - 1) as nat) * pow2(1));
                    {
                        // pow2(shift - 1) * pow2(1) == pow2(shift).
                        lemma_pow2_adds((shift - 1) as nat, 1);
                    }
                    x as nat / pow2(shift as nat);
                }
            }
        }
        }
    };
}

// One instantiation per unsigned integer type.
lemma_shr_is_div!(lemma_u128_shr_is_div, u128);
lemma_shr_is_div!(lemma_u64_shr_is_div, u64);
lemma_shr_is_div!(lemma_u32_shr_is_div, u32);
lemma_shr_is_div!(lemma_u16_shr_is_div, u16);
lemma_shr_is_div!(lemma_u8_shr_is_div, u8);
lemma_shr_is_div!(lemma_usize_shr_is_div, usize);
90
// Emits, for the unsigned integer type `$uN`, a broadcast lemma `$name`
// bounding pow2(n) strictly between 0 and `<$uN>::MAX` for every exponent n
// below the bit width of `$uN`. Only types up to 64 bits are instantiated
// below (there is no u128 instance).
macro_rules! lemma_pow2_no_overflow {
    ($name:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that 2^n does not overflow "]
        #[doc = stringify!($uN)]
        #[doc = " for an exponent n."]
        pub broadcast proof fn $name(n: nat)
            requires
                0 <= n < <$uN>::BITS,
            ensures
                0 < #[trigger] pow2(n) < <$uN>::MAX,
        {
            // Upper bound: these two lemmas presumably pin down the concrete
            // values of pow2(k) for every k up to 64, which covers each exponent
            // allowed by the precondition.
            lemma2_to64();
            lemma2_to64_rest();
            // Lower bound: powers of two are always positive.
            lemma_pow2_pos(n);
        }
        }
    };
}

lemma_pow2_no_overflow!(lemma_usize_pow2_no_overflow, usize);
lemma_pow2_no_overflow!(lemma_u8_pow2_no_overflow, u8);
lemma_pow2_no_overflow!(lemma_u16_pow2_no_overflow, u16);
lemma_pow2_no_overflow!(lemma_u32_pow2_no_overflow, u32);
lemma_pow2_no_overflow!(lemma_u64_pow2_no_overflow, u64);
118
// Generates, for the unsigned integer type `$uN`, a broadcast lemma `$name`
// proving that `x << shift` equals `x * pow2(shift)` whenever that product does
// not exceed `<$uN>::MAX`. `$no_overflow` must be the matching
// `lemma_*_pow2_no_overflow` lemma for `$uN`.
macro_rules! lemma_shl_is_mul {
    ($name:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x and n of type "]
        #[doc = stringify!($uN)]
        #[doc = ", shifting x left by n is equivalent to multiplication of x by 2^n (provided no overflow)."]
        pub broadcast proof fn $name(x: $uN, shift: $uN)
            requires
                0 <= shift < <$uN>::BITS,
                x * pow2(shift as nat) <= <$uN>::MAX,
            ensures
                #[trigger] (x << shift) == x * pow2(shift as nat),
            decreases shift,
        {
            // 0 < pow2(shift) < MAX for this shift.
            $no_overflow(shift as nat);
            if shift == 0 {
                // Base case: x << 0 == x and pow2(0) == 1.
                assert(x << 0 == x) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
            } else {
                // Inductive step: shifting left by `shift` is the same as
                // shifting left by `shift - 1` and then doubling.
                assert(x << shift == mul(x << ((sub(shift, 1)) as $uN), 2)) by (bit_vector)
                    requires
                        0 < shift < <$uN>::BITS,
                ;
                // Apply the induction hypothesis at `shift - 1`. Its overflow
                // precondition follows from pow2(shift - 1) < pow2(shift), so
                // x * pow2(shift - 1) <= x * pow2(shift) <= MAX.
                assert((x << (sub(shift, 1) as $uN)) == x * pow2(sub(shift, 1) as nat)) by {
                    lemma_pow2_strictly_increases((shift - 1) as nat, shift as nat);
                    lemma_mul_inequality(
                        pow2((shift - 1) as nat) as int,
                        pow2(shift as nat) as int,
                        x as int,
                    );
                    // Reorient the products from lemma_mul_inequality to x * p form.
                    lemma_mul_is_commutative(x as int, pow2((shift - 1) as nat) as int);
                    lemma_mul_is_commutative(x as int, pow2(shift as nat) as int);
                    $name(x, (shift - 1) as $uN);
                }
                calc!{ (==)
                    ((x << (sub(shift, 1) as $uN)) * 2);
                    {}
                    ((x * pow2(sub(shift, 1) as nat)) * 2);
                    {
                        lemma_mul_is_associative(x as int, pow2(sub(shift, 1) as nat) as int, 2);
                    }
                    x * ((pow2(sub(shift, 1) as nat)) * 2);
                    {
                        // pow2(shift - 1) * pow2(1) == pow2(shift); lemma2_to64
                        // presumably supplies pow2(1) == 2.
                        lemma_pow2_adds((shift - 1) as nat, 1);
                        lemma2_to64();
                    }
                    x * pow2(shift as nat);
                }
            }
        }
        }
    };
}

// One instantiation per unsigned integer type that has a pow2_no_overflow lemma.
lemma_shl_is_mul!(lemma_u64_shl_is_mul, lemma_u64_pow2_no_overflow, u64);
lemma_shl_is_mul!(lemma_u32_shl_is_mul, lemma_u32_pow2_no_overflow, u32);
lemma_shl_is_mul!(lemma_u16_shl_is_mul, lemma_u16_pow2_no_overflow, u16);
lemma_shl_is_mul!(lemma_u8_shl_is_mul, lemma_u8_pow2_no_overflow, u8);
lemma_shl_is_mul!(lemma_usize_shl_is_mul, lemma_usize_pow2_no_overflow, usize);
180
// Generates, for the unsigned integer type `$uN`, a lemma `$name` relating the
// bound `x * pow2(shift) <= max` to the shift-based check `x <= max >> shift`.
// `$shr_is_div` must be the matching `lemma_*_shr_is_div` lemma for `$uN`.
macro_rules! lemma_mul_pow2_le_max_iff_max_shr {
    ($name:ident, $shr_is_div:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for x, n and max of type "]
        #[doc = stringify!($uN)]
        #[doc = ", multiplication of x by 2^n is less than or equal to max if and only if x is less than or equal to shifting max right by n."]
        pub proof fn $name(x: $uN, shift: $uN, max: $uN)
            requires
                0 <= shift < <$uN>::BITS,
            ensures
                x * pow2(shift as nat) <= max <==> x <= (max >> shift),
        {
            // Rewrite the shift as a division: max >> shift == max / pow2(shift).
            assert(max >> shift == max as nat / pow2(shift as nat)) by {
                $shr_is_div(max, shift as $uN);
            };

            // pow2(shift) > 0, required by the division lemmas below.
            lemma_pow2_pos(shift as nat);

            // (==>) Dividing both sides of x * p <= max by p gives
            // x == (x * p) / p <= max / p.
            if x * pow2(shift as nat) <= max {
                assert(x <= (max as nat) / pow2(shift as nat)) by {
                    lemma_div_is_ordered(x as int * pow2(shift as nat) as int, max as int, pow2(shift as nat) as int);
                    lemma_div_by_multiple(x as int, pow2(shift as nat) as int);
                };
            }
            // (<==) From x <= max / p: x * p <= (max / p) * p <= max.
            if x <= (max >> shift) {
                assert(x * pow2(shift as nat) <= max as nat) by {
                    lemma_mul_inequality(x as int, max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                    lemma_remainder_lower(max as int, pow2(shift as nat) as int);
                    lemma_mul_is_commutative(max as int / pow2(shift as nat) as int, pow2(shift as nat) as int);
                };
            }
        }
        }
    };
}

// One instantiation per unsigned integer type (u128 is not instantiated).
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u64_mul_pow2_le_max_iff_max_shr,
    lemma_u64_shr_is_div,
    u64
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u32_mul_pow2_le_max_iff_max_shr,
    lemma_u32_shr_is_div,
    u32
);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_u16_mul_pow2_le_max_iff_max_shr,
    lemma_u16_shr_is_div,
    u16
);
lemma_mul_pow2_le_max_iff_max_shr!(lemma_u8_mul_pow2_le_max_iff_max_shr, lemma_u8_shr_is_div, u8);
lemma_mul_pow2_le_max_iff_max_shr!(
    lemma_usize_mul_pow2_le_max_iff_max_shr,
    lemma_usize_shr_is_div,
    usize
);
239
240verus! {
241
/// The mask with the low `n` bits set, i.e. `pow2(n) - 1`
/// (so `low_bits_mask(0) == 0`, `low_bits_mask(8) == 0xff`).
pub open spec fn low_bits_mask(n: nat) -> nat {
    // pow2(n) - 1 is computed in int; the cast back to nat is exact because
    // pow2(n) >= 1.
    (pow2(n) - 1) as nat
}
246
/// Recursive characterization of the mask: for `n > 0`, the `n`-bit mask is
/// the `(n - 1)`-bit mask shifted left by one with the low bit set.
pub broadcast proof fn lemma_low_bits_mask_unfold(n: nat)
    requires
        n > 0,
    ensures
        #[trigger] low_bits_mask(n) == 2 * low_bits_mask((n - 1) as nat) + 1,
{
    calc! {
        (==)
        low_bits_mask(n); {}
        (pow2(n) - 1) as nat; {
            // pow2(n) == 2 * pow2(n - 1).
            lemma_pow2_unfold(n);
        }
        (2 * pow2((n - 1) as nat) - 1) as nat; {}
        (2 * (pow2((n - 1) as nat) - 1) + 1) as nat; {
            // pow2(n - 1) > 0, so pow2(n - 1) - 1 is non-negative and matches
            // low_bits_mask(n - 1) exactly.
            lemma_pow2_pos((n - 1) as nat);
        }
        (2 * low_bits_mask((n - 1) as nat) + 1) as nat;
    }
}
267
/// For `n > 0`, the `n`-bit mask is odd (its lowest bit is set).
pub broadcast proof fn lemma_low_bits_mask_is_odd(n: nat)
    requires
        n > 0,
    ensures
        #[trigger] (low_bits_mask(n) % 2) == 1,
{
    calc! {
        (==)
        low_bits_mask(n) % 2; {
            lemma_low_bits_mask_unfold(n);
        }
        (2 * low_bits_mask((n - 1) as nat) + 1) % 2; {
            // (2 * a + 1) % 2 == 1 % 2: the multiple of 2 vanishes.
            lemma_mod_multiples_vanish(low_bits_mask((n - 1) as nat) as int, 1, 2);
        }
        1nat % 2;
    }
}
286
287pub broadcast proof fn lemma_low_bits_mask_div2(n: nat)
289 requires
290 n > 0,
291 ensures
292 #[trigger] (low_bits_mask(n) / 2) == low_bits_mask((n - 1) as nat),
293{
294 lemma_low_bits_mask_unfold(n);
295}
296
/// Concrete values of `low_bits_mask(n)` for every `n` in `0..=32` and for
/// `n == 64`.
pub proof fn lemma_low_bits_mask_values()
    ensures
        low_bits_mask(0) == 0x0,
        low_bits_mask(1) == 0x1,
        low_bits_mask(2) == 0x3,
        low_bits_mask(3) == 0x7,
        low_bits_mask(4) == 0xf,
        low_bits_mask(5) == 0x1f,
        low_bits_mask(6) == 0x3f,
        low_bits_mask(7) == 0x7f,
        low_bits_mask(8) == 0xff,
        low_bits_mask(9) == 0x1ff,
        low_bits_mask(10) == 0x3ff,
        low_bits_mask(11) == 0x7ff,
        low_bits_mask(12) == 0xfff,
        low_bits_mask(13) == 0x1fff,
        low_bits_mask(14) == 0x3fff,
        low_bits_mask(15) == 0x7fff,
        low_bits_mask(16) == 0xffff,
        low_bits_mask(17) == 0x1ffff,
        low_bits_mask(18) == 0x3ffff,
        low_bits_mask(19) == 0x7ffff,
        low_bits_mask(20) == 0xfffff,
        low_bits_mask(21) == 0x1fffff,
        low_bits_mask(22) == 0x3fffff,
        low_bits_mask(23) == 0x7fffff,
        low_bits_mask(24) == 0xffffff,
        low_bits_mask(25) == 0x1ffffff,
        low_bits_mask(26) == 0x3ffffff,
        low_bits_mask(27) == 0x7ffffff,
        low_bits_mask(28) == 0xfffffff,
        low_bits_mask(29) == 0x1fffffff,
        low_bits_mask(30) == 0x3fffffff,
        low_bits_mask(31) == 0x7fffffff,
        low_bits_mask(32) == 0xffffffff,
        low_bits_mask(64) == 0xffffffffffffffff,
{
    // All values are checked in a single compute_only query over one big
    // conjunction that mirrors the ensures clause.
    #[verusfmt::skip]
    assert(
        low_bits_mask(0) == 0x0 &&
        low_bits_mask(1) == 0x1 &&
        low_bits_mask(2) == 0x3 &&
        low_bits_mask(3) == 0x7 &&
        low_bits_mask(4) == 0xf &&
        low_bits_mask(5) == 0x1f &&
        low_bits_mask(6) == 0x3f &&
        low_bits_mask(7) == 0x7f &&
        low_bits_mask(8) == 0xff &&
        low_bits_mask(9) == 0x1ff &&
        low_bits_mask(10) == 0x3ff &&
        low_bits_mask(11) == 0x7ff &&
        low_bits_mask(12) == 0xfff &&
        low_bits_mask(13) == 0x1fff &&
        low_bits_mask(14) == 0x3fff &&
        low_bits_mask(15) == 0x7fff &&
        low_bits_mask(16) == 0xffff &&
        low_bits_mask(17) == 0x1ffff &&
        low_bits_mask(18) == 0x3ffff &&
        low_bits_mask(19) == 0x7ffff &&
        low_bits_mask(20) == 0xfffff &&
        low_bits_mask(21) == 0x1fffff &&
        low_bits_mask(22) == 0x3fffff &&
        low_bits_mask(23) == 0x7fffff &&
        low_bits_mask(24) == 0xffffff &&
        low_bits_mask(25) == 0x1ffffff &&
        low_bits_mask(26) == 0x3ffffff &&
        low_bits_mask(27) == 0x7ffffff &&
        low_bits_mask(28) == 0xfffffff &&
        low_bits_mask(29) == 0x1fffffff &&
        low_bits_mask(30) == 0x3fffffff &&
        low_bits_mask(31) == 0x7fffffff &&
        low_bits_mask(32) == 0xffffffff &&
        low_bits_mask(64) == 0xffffffffffffffff
    ) by (compute_only);
}
374
} // verus!

// Generates, for the unsigned integer type `$uN`:
//   * `$name`: a broadcast lemma proving that masking `x` with the low `n`-bit
//     mask equals `x % pow2(n)`, for `n` below the bit width of `$uN`;
//   * `$and_split_low_bit`: a bit-vector helper used in the inductive step.
// `$no_overflow` must be the matching `lemma_*_pow2_no_overflow` lemma for `$uN`.
macro_rules! lemma_low_bits_mask_is_mod {
    ($name:ident, $and_split_low_bit:ident, $no_overflow:ident, $uN:ty) => {
        #[cfg(verus_keep_ghost)]
        verus! {
        #[doc = "Proof that for natural n and x of type "]
        #[doc = stringify!($uN)]
        #[doc = ", and with the low n-bit mask is equivalent to modulo 2^n."]
        pub broadcast proof fn $name(x: $uN, n: nat)
            requires
                n < <$uN>::BITS,
            ensures
                #[trigger] (x & (low_bits_mask(n) as $uN)) == x % (pow2(n) as $uN),
            decreases n,
        {
            // 0 < pow2(n) < MAX, so the `as $uN` casts in the postcondition
            // do not truncate.
            $no_overflow(n);
            lemma_pow2_pos(n);

            if n == 0 {
                // Base case: the 0-bit mask is 0 and pow2(0) == 1, so both
                // sides of the postcondition are 0.
                assert(low_bits_mask(0) == 0) by (compute_only);
                assert(x & 0 == 0) by (bit_vector);
                assert(pow2(0) == 1) by (compute_only);
                assert(x % 1 == 0);
            } else {
                // pow2(n) == 2 * pow2(n - 1).
                lemma_pow2_unfold(n);
                // x % 2 is a single bit, so masking it with 1 is a no-op.
                assert((x % 2) == ((x % 2) & 1)) by (bit_vector);
                calc!{ (==)
                    x % (pow2(n) as $uN);
                    {}
                    x % ((2 * pow2((n-1) as nat)) as $uN);
                    {
                        // x % (2 * p) == 2 * ((x / 2) % p) + x % 2.
                        lemma_pow2_pos((n-1) as nat);
                        lemma_mod_breakdown(x as int, 2, pow2((n-1) as nat) as int);
                    }
                    add(mul(2, (x / 2) % (pow2((n-1) as nat) as $uN)), x % 2);
                    {
                        // Induction hypothesis on (x / 2, n - 1).
                        $name(x/2, (n-1) as nat);
                    }
                    add(mul(2, (x / 2) & (low_bits_mask((n-1) as nat) as $uN)), x % 2);
                    {
                        // low_bits_mask(n) / 2 == low_bits_mask(n - 1).
                        lemma_low_bits_mask_div2(n);
                    }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), x % 2);
                    {
                        // low_bits_mask(n) % 2 == 1, which together with the
                        // bit_vector fact above rewrites x % 2.
                        lemma_low_bits_mask_is_odd(n);
                    }
                    add(mul(2, (x / 2) & (low_bits_mask(n) as $uN / 2)), (x % 2) & ((low_bits_mask(n) as $uN) % 2));
                    {
                        // Reassemble the high-bits AND and the low-bit AND.
                        $and_split_low_bit(x as $uN, low_bits_mask(n) as $uN);
                    }
                    x & (low_bits_mask(n) as $uN);
                }
            }
        }

        // Bit-vector fact: x & m is the AND of the high parts (x / 2, m / 2)
        // scaled back by 2, plus the AND of the low bits (x % 2, m % 2).
        proof fn $and_split_low_bit(x: $uN, m: $uN)
            by (bit_vector)
            ensures
                x & m == add(mul(((x / 2) & (m / 2)), 2), (x % 2) & (m % 2)),
        {
        }
        }
    };
}

// One instantiation per unsigned integer type that has a pow2_no_overflow lemma.
lemma_low_bits_mask_is_mod!(
    lemma_u64_low_bits_mask_is_mod,
    lemma_u64_and_split_low_bit,
    lemma_u64_pow2_no_overflow,
    u64
);
lemma_low_bits_mask_is_mod!(
    lemma_u32_low_bits_mask_is_mod,
    lemma_u32_and_split_low_bit,
    lemma_u32_pow2_no_overflow,
    u32
);
lemma_low_bits_mask_is_mod!(
    lemma_u16_low_bits_mask_is_mod,
    lemma_u16_and_split_low_bit,
    lemma_u16_pow2_no_overflow,
    u16
);
lemma_low_bits_mask_is_mod!(
    lemma_u8_low_bits_mask_is_mod,
    lemma_u8_and_split_low_bit,
    lemma_u8_pow2_no_overflow,
    u8
);
lemma_low_bits_mask_is_mod!(
    lemma_usize_low_bits_mask_is_mod,
    lemma_usize_and_split_low_bit,
    lemma_usize_pow2_no_overflow,
    usize
);