Passing functions as values
In Rust, functions can be passed as values through the FnOnce, FnMut, and Fn traits.
Just as for normal functions, Verus supports reasoning about the preconditions
and postconditions of such function values.
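As a plain Rust refresher (no Verus involved), a named function can be passed wherever any of these three traits is expected, because function items implement Fn, FnMut, and FnOnce. The sketch below uses a plain Rust version of the double function from later in this chapter, without its Verus precondition. The helper names call_fn, call_fn_mut, and call_fn_once are hypothetical and were chosen for this illustration:

```rust
// Plain Rust version of `double`; the Verus precondition is omitted here.
fn double(x: u8) -> u8 {
    2 * x
}

// Fn: can be called any number of times through a shared reference.
fn call_fn(f: impl Fn(u8) -> u8, x: u8) -> u8 {
    f(x)
}

// FnMut: may mutate captured state, so calling it needs mutable access.
fn call_fn_mut(mut f: impl FnMut(u8) -> u8, x: u8) -> u8 {
    f(x)
}

// FnOnce: may consume captured state, so it can be called at most once.
fn call_fn_once(f: impl FnOnce(u8) -> u8, x: u8) -> u8 {
    f(x)
}

fn main() {
    // The same function item satisfies all three bounds.
    println!("{}", call_fn(double, 5)); // 10
    println!("{}", call_fn_mut(double, 6)); // 12
    println!("{}", call_fn_once(double, 7)); // 14
}
```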
Reasoning about preconditions and postconditions
Verus allows you to reason about the preconditions and postconditions of function values
via two builtin spec functions: call_requires and call_ensures.
call_requires(f, args) represents the precondition. It takes two arguments: the function object and the arguments as a tuple. If it returns true, then it is possible to call f with the given args.
call_ensures(f, args, output) represents the postcondition. It takes three arguments: the function object, the arguments, and the return value. It represents the valid input-output pairs for f.
The vstd library also provides aliases, f.requires(args) and f.ensures(args, output).
These mean the same thing as call_requires and call_ensures.
As with any normal call, Verus demands that the precondition be satisfied when you call a function object. This is demonstrated by the following example:
fn double(x: u8) -> (res: u8)
    requires
        0 <= x < 128,
    ensures
        res == 2 * x,
{
    2 * x
}

fn higher_order_fn(f: impl Fn(u8) -> u8) -> (res: u8) {
    f(50)
}

fn test() {
    higher_order_fn(double);
}
Here, test calls higher_order_fn, passing in double,
and higher_order_fn then calls its argument with 50.
According to the requires clause of double, this call should be allowed;
however, higher_order_fn knows nothing about the precondition of f,
so it does not have the information to establish that the call is correct.
Verus gives an error:
error: Call to non-static function fails to satisfy `callee.requires(args)`
  --> vec_map.rs:25:5
   |
25 |     f(50)
   |     ^^^^^
To fix this, we can add a precondition to higher_order_fn that gives information on
the precondition of f:
fn double(x: u8) -> (res: u8)
    requires
        0 <= x < 128,
    ensures
        res == 2 * x,
{
    2 * x
}

fn higher_order_fn(f: impl Fn(u8) -> u8) -> (res: u8)
    requires
        call_requires(f, (50,)),
{
    f(50)
}

fn test() {
    higher_order_fn(double);
}
The (50,) may look a little odd: it is a 1-tuple.
call_requires and call_ensures always take their “args” as a tuple.
If f takes 0 arguments, the args are the unit tuple ();
if f takes 2 arguments, they are a pair; and so on.
Here, f takes 1 argument, so the args are a 1-tuple, which is written with
a trailing comma, as in (50,).
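The tuple convention is ordinary Rust tuple syntax. The plain Rust sketch below packs a function’s arguments into a tuple in the same way the args parameter of call_requires and call_ensures does. The helpers call0, call1, and call2 and the functions answer and add are hypothetical, made up only for this illustration:

```rust
// Hypothetical helpers: each calls `f` with its arguments packed into a tuple,
// mirroring how call_requires / call_ensures receive their "args".
fn call0<R>(f: impl Fn() -> R, _args: ()) -> R {
    f()
}

fn call1<A, R>(f: impl Fn(A) -> R, args: (A,)) -> R {
    f(args.0)
}

fn call2<A, B, R>(f: impl Fn(A, B) -> R, args: (A, B)) -> R {
    f(args.0, args.1)
}

fn answer() -> u8 {
    42
}

fn double(x: u8) -> u8 {
    2 * x
}

fn add(x: u8, y: u8) -> u8 {
    x + y
}

fn main() {
    println!("{}", call0(answer, ())); // 0 arguments: unit tuple, prints 42
    println!("{}", call1(double, (50,))); // 1 argument: 1-tuple, prints 100
    println!("{}", call2(add, (3, 4))); // 2 arguments: a pair, prints 7
    // Without the trailing comma, `(50)` is just the u8 value 50, not a
    // 1-tuple, so `call1(double, (50))` would be rejected by the type checker.
}
```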
Verus now accepts this code, as the precondition of higher_order_fn now guarantees that
f accepts the input 50.
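The obligation has not disappeared; it has moved to the caller. In test, Verus must now show that call_requires(double, (50,)) holds, which follows from double’s own precondition 0 <= x < 128. That bound exists because 2 * x must fit in a u8, whose maximum is 255. The plain Rust sketch below makes the boundary visible using the standard u8::checked_mul; the helper name double_checked is hypothetical:

```rust
// Hypothetical helper: doubles `x` only if the result fits in a u8.
fn double_checked(x: u8) -> Option<u8> {
    x.checked_mul(2)
}

fn main() {
    // 2 * 50 = 100 fits, so double may be called with 50.
    println!("{:?}", double_checked(50)); // Some(100)
    // 127 is the largest input whose double still fits: 254 <= u8::MAX (255).
    println!("{:?}", double_checked(127)); // Some(254)
    // 2 * 128 = 256 does not fit, which is why double requires x < 128.
    println!("{:?}", double_checked(128)); // None
}
```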
We can go further and allow higher_order_fn to reason about the output value of f:
fn double(x: u8) -> (res: u8)
    requires
        0 <= x < 128,
    ensures
        res == 2 * x,
{
    2 * x
}

fn higher_order_fn(f: impl Fn(u8) -> u8) -> (res: u8)
    requires
        call_requires(f, (50,)),
        forall|x, y| call_ensures(f, x, y) ==> y % 2 == 0,
    ensures
        res % 2 == 0,
{
    let ret = f(50);
    return ret;
}

fn test() {
    higher_order_fn(double);
}
Observe that the precondition of higher_order_fn places a constraint on the postcondition
of f.
As a result, higher_order_fn learns information about the return value of f(50).
Specifically, it learns that call_ensures(f, (50,), ret) holds, which, by higher_order_fn’s
precondition, implies that ret % 2 == 0.
An important note
The above examples show the idiomatic way to constrain the preconditions and postconditions
of a function argument. Observe that call_requires is used in a positive position,
i.e., “call_requires holds for this value”.
Meanwhile, call_ensures is used in a negative position, i.e., on the left-hand side
of an implication: “if call_ensures holds for a given value, then that value satisfies this particular constraint”.
It is very common to need a guarantee that f(args) will return one specific value,
say expected_return_value.
In this situation, it can be tempting to write
requires call_ensures(f, args, expected_return_value),
as your constraint. However, this is almost never what you actually want,
and in fact, Verus may not even let you prove it.
The proposition call_ensures(f, args, expected_return_value) says that expected_return_value
is a possible return value of f(args);
however, it says nothing about other possible return values.
In general, f may be nondeterministic!
Just because expected_return_value is one possible return
value does not mean it is the only one.
When faced with this situation, what you really want is to write:
requires forall |ret| call_ensures(f, args, ret) ==> ret == expected_return_value
This is the proposition that you really want, i.e., “if f(args) returns a value ret,
then that value is equal to expected_return_value”.
Of course, this is flipped around when you write a postcondition, as we’ll see in the next example.
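To make the nondeterminism point concrete, here is a plain Rust sketch (no Verus involved) of a named function whose result depends on hidden state as well as its argument. The names COUNTER, add_counter, and call_twice are hypothetical:

```rust
use std::sync::atomic::{AtomicU8, Ordering};

// Hidden state shared across calls (hypothetical, for illustration only).
static COUNTER: AtomicU8 = AtomicU8::new(0);

// The result depends on COUNTER as well as on `x`, so two calls with the
// same argument can return different values.
fn add_counter(x: u8) -> u8 {
    // fetch_add returns the value COUNTER held before the increment.
    let c = COUNTER.fetch_add(1, Ordering::SeqCst);
    x + c
}

// Calls `f` twice with the same argument; tuple elements are evaluated left to right.
fn call_twice(f: impl Fn(u8) -> u8, x: u8) -> (u8, u8) {
    (f(x), f(x))
}

fn main() {
    let (a, b) = call_twice(add_counter, 10);
    println!("{} {}", a, b); // prints "10 11" on a fresh run
}
```

Informally, both 10 and 11 are possible return values of add_counter(10). Knowing that 10 is one possible result says nothing about the next call, whereas the forall formulation above would rule out 11.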
Example: vec_map
Let’s take what we learned and write a simple function, vec_map, which applies a given
function to each element of a vector and returns a new vector.
The key challenge is to determine the right specification to use.
The signature we want is:
fn vec_map<T, U>(v: &Vec<T>, f: impl Fn(T) -> U) -> (result: Vec<U>) where
    T: Copy,
First, what do we need to require? We need to require that it’s okay to call f
with any element of the vector as input.
requires
    forall|i| 0 <= i < v.len() ==> call_requires(f, (v[i],)),
Next, what ought we to ensure? Naturally, we want the returned vector to have the same
length as the input. Furthermore, we want to guarantee that each element of the output
vector is a possible output of the provided function f when it is called on the corresponding
element of the input vector.
ensures
    result.len() == v.len(),
    forall|i| 0 <= i < v.len() ==> call_ensures(f, (v[i],), #[trigger] result[i]),
Now that we have a specification, the implementation and loop invariant should fall into place:
fn vec_map<T, U>(v: &Vec<T>, f: impl Fn(T) -> U) -> (result: Vec<U>) where
    T: Copy,
    requires
        forall|i| 0 <= i < v.len() ==> call_requires(f, (v[i],)),
    ensures
        result.len() == v.len(),
        forall|i| 0 <= i < v.len() ==> call_ensures(f, (v[i],), #[trigger] result[i]),
{
    let mut result = Vec::new();
    let mut j = 0;
    while j < v.len()
        invariant
            forall|i| 0 <= i < v.len() ==> call_requires(f, (v[i],)),
            0 <= j <= v.len(),
            j == result.len(),
            forall|i| 0 <= i < j ==> call_ensures(f, (v[i],), #[trigger] result[i]),
    {
        result.push(f(v[j]));
        j += 1;
    }
    result
}
Finally, we can try it out with an example:
fn double(x: u8) -> (res: u8)
    requires
        0 <= x < 128,
    ensures
        res == 2 * x,
{
    2 * x
}

fn test_vec_map() {
    let mut v = Vec::new();
    v.push(0);
    v.push(10);
    v.push(20);
    let w = vec_map(&v, double);
    assert(w[2] == 40);
}
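For comparison, the sketch below is an unverified plain Rust version of the same function (no Verus involved). It has no requires or ensures, so nothing is proved about its result, and nothing prevents calling double with an input of 128 or more: such a call panics at run time in a debug build and wraps silently in a release build. Taking a slice instead of &Vec<T> is a stylistic choice for this sketch, not part of the original signature:

```rust
// Plain Rust `double` with no precondition: 2 * x panics on overflow in
// debug builds and wraps in release builds.
fn double(x: u8) -> u8 {
    2 * x
}

// Unverified counterpart of vec_map; a slice is the usual parameter type.
fn vec_map<T: Copy, U>(v: &[T], f: impl Fn(T) -> U) -> Vec<U> {
    // `copied` turns the iterator of &T into T; `f` is passed to `map` by name.
    v.iter().copied().map(f).collect()
}

fn main() {
    let v: Vec<u8> = vec![0, 10, 20];
    let w = vec_map(&v, double);
    assert_eq!(w[2], 40);
    println!("{:?}", w); // [0, 20, 40]
}
```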
Conclusion
In this chapter, we learned how to write higher-order functions with higher-order specifications, i.e., specifications that constrain the specifications of functions that are passed around as values.
All of the examples from this chapter passed functions by referring to them directly by name,
e.g., passing the function double by writing double.
In Rust, a more common way to work with higher-order functions is to pass closures.
In the next chapter, we’ll learn how to use closures.