Abusing Rust's type system
Published on ·rust
type system
recursion
Introduction
In this blog post, I will tell you about all the type system tricks I know and how I incorporated them into https://github.com/thatmagicalcat/tcrts.
The Goal
Lol, it's fairly obvious: we'll abuse Rust's type system. What else?
The Plan
The most important thing for abusing the type system is a good understanding of recursion. We are going to use recursion to create a set of types that can represent values and data structures like lists, tuples, etc.
The Basics
Let's start with the numbers. We are going to represent numbers using types, beginning with zero.
struct Zero;
No, we aren't going to create a struct for every number; that would be stupid. Instead, we are going to use a recursive type to represent the numbers.
What I mean is, 1 can be written as 0 + 1, 2 can be written as 0 + 1 + 1, and so on.
// we'll have to use PhantomData marker as
// we are not going to use the generic type T
struct Next<T>(std::marker::PhantomData<T>);
And now we can represent the numbers as follows:
type One = Next<Zero>;
type Two = Next<One>;
type Three = Next<Two>;
// and so on
Currently Next<T> and Zero are two separate structs. We'll define a trait called Num to unify them. This trait won't do anything on its own, but it'll help us implement other traits.
trait Num {}
// and implement the trait for both types
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}
// We're also gonna change the `Next` struct's
// definition to include the `Num` trait bound
struct Next<T: Num>(std::marker::PhantomData<T>);
OK, but let's say we have a very long type chain, for example Next<Next<Next<..Next<Zero>..>>>. How do we get the actual value from this type?
Surely we are not going to count the number of Next types manually. To check the value of a type, let's define a trait called ToVal that converts the type to a value we can print.
trait ToVal {
const VALUE: usize;
}
Yup, you guessed it: it is going to be a recursive definition.
Any recursive definition needs a base case, and in this implementation the base case is Zero.
impl ToVal for Zero {
const VALUE: usize = 0;
}
And now we can implement the ToVal trait for Next<T> as follows:
// `Next` requires its parameter to be `Num`, so the impl has to repeat that bound
impl<N: Num + ToVal> ToVal for Next<N> {
const VALUE: usize = 1 + N::VALUE;
}
We can test that this works by running the following code:
fn main() {
type One = Next<Zero>;
type Two = Next<One>;
type Three = Next<Two>;
assert_eq!(One::VALUE, 1);
assert_eq!(Two::VALUE, 2);
assert_eq!(Three::VALUE, 3);
}
And it works! Now we’re done with the basics, let’s move on to the next part and define some operations which we can perform on these number types.
Arithmetic operations
For our arithmetic operations, we are going to use Peano Axioms, which are a set of axioms for the natural numbers that can be used to define operations like addition, subtraction and multiplication in a recursive way.
Addition
Before jumping into the type system implementation, let’s define the addition operation using the Peano Axioms. The addition operation is defined as follows:
Let's define a function S(x) = x + 1 (this is our Next struct from above).
Our base cases are going to be:
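These are presumably the standard Peano rules for addition, which match the PeanoAdd impls further down. The two cases are:

$$
\begin{aligned}
a + 0 &= a \\
a + S(b) &= S(a + b)
\end{aligned}
$$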
This is fairly obvious, because addition is associative: a + (b + 1) is the same as (a + b) + 1. We did the same thing above.
Let’s try to create a recursive definition for simple examples first:
Now we can see a pattern, and we can generalize it to a + n as follows:
Here's another example to help you understand. Let's use our recursive definition to find 10 + 4:
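Here is one way to write out that expansion, using S for Next and the rule a + S(b) = S(a + b). We peel one S off the right-hand side at each step until it reaches 0:

$$
\begin{aligned}
10 + 4 &= 10 + S(3) = S(10 + 3) \\
&= S(S(10 + 2)) \\
&= S(S(S(10 + 1))) \\
&= S(S(S(S(10 + 0)))) \\
&= S(S(S(S(10)))) = 14
\end{aligned}
$$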
With that understanding, we’re ready to implement the addition operation in Rust’s type system.
Let's start by defining a trait called PeanoAdd that will represent the addition operation.
trait PeanoAdd<Rhs: Num> {
type Output: Num;
}
Base case:
// N + 0 = N
impl<N: Num> PeanoAdd<Zero> for N {
type Output = N;
}
Next, we generalize the addition operation for Next<T> as follows:
// N + Next<M> = Next<N + M>
impl<N, M> PeanoAdd<Next<M>> for N
where
N: Num + PeanoAdd<M>,
M: Num,
{
type Output = Next<<N as PeanoAdd<M>>::Output>;
}
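Notice that we're doing <N as PeanoAdd<M>>, but our trait is PeanoAdd<Next<M>>. Since the right-hand side is Next<M>, recursing on M is basically doing "right-hand side - 1". We can't write that subtraction directly in Rust's type system, but pattern matching on Next<M> does it for us.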
We can test our addition operation by running the following code:
type _0 = Zero;
type _1 = Next<_0>;
type _2 = Next<_1>;
type _3 = Next<_2>;
type _4 = Next<_3>;
type _5 = Next<_4>;
type _6 = Next<_5>;
type _7 = Next<_6>;
type _8 = Next<_7>;
type _9 = Next<_8>;
fn main() {
assert_eq!(<_5 as PeanoAdd<_4>>::Output::VALUE, _9::VALUE);
// Remember `::VALUE` comes from the `ToVal` trait
}
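If you want to run the addition code, here is a minimal single-file sketch. It gathers the pieces defined so far (Num, Zero, Next, ToVal and PeanoAdd). It assumes a single binary crate and only needs std::marker::PhantomData. The aliases and asserted values are just examples.

```rust
use std::marker::PhantomData;

// Marker trait shared by every type-level number.
trait Num {}

// Zero is the base case; Next<N> stands for N + 1.
struct Zero;
struct Next<N: Num>(PhantomData<N>);

impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

// Turns a type-level number into a runtime usize.
trait ToVal {
    const VALUE: usize;
}

impl ToVal for Zero {
    const VALUE: usize = 0;
}

impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

// Type-level addition: Self + Rhs.
trait PeanoAdd<Rhs: Num> {
    type Output: Num;
}

// N + 0 = N
impl<N: Num> PeanoAdd<Zero> for N {
    type Output = N;
}

// N + Next<M> = Next<N + M>
impl<N, M> PeanoAdd<Next<M>> for N
where
    N: Num + PeanoAdd<M>,
    M: Num,
{
    type Output = Next<<N as PeanoAdd<M>>::Output>;
}

type _0 = Zero;
type _1 = Next<_0>;
type _2 = Next<_1>;
type _3 = Next<_2>;
type _4 = Next<_3>;
type _5 = Next<_4>;

fn main() {
    assert_eq!(<_2 as PeanoAdd<_3>>::Output::VALUE, 5);
    assert_eq!(<_5 as PeanoAdd<_0>>::Output::VALUE, _5::VALUE);
    println!("2 + 3 = {}", <_2 as PeanoAdd<_3>>::Output::VALUE);
}
```

All of the work happens at compile time. By the time main runs, each VALUE is just a constant.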
If you’re wondering how this recursive definition will actually work, I’ve got you covered.
Let's say we have two numbers:
type A = Next<Next<Next< ... Zero ... >>> // Next<...> applied A times
type B = Next<Next<Next< ... Zero ... >>> // Next<...> applied B times
Now our addition will look like this:
type Result = <A as PeanoAdd<B>>::Output;
I'll use a notation to simplify the type chain: Next<Next<Next<Zero>>> - 1 will be equivalent to Next<Next<Zero>>.
Ok we’re ready to expand our addition operation step by step.
<A as PeanoAdd<B>>::Output = Next<<A as PeanoAdd<B - 1>>::Output>
<A as PeanoAdd<B - 1>>::Output = Next<<A as PeanoAdd<B - 2>>::Output>
<A as PeanoAdd<B - 2>>::Output = Next<<A as PeanoAdd<B - 3>>::Output>
.
. repeated B times
.
<A as PeanoAdd<One>>::Output = Next<<A as PeanoAdd<Zero>>::Output>
<A as PeanoAdd<Zero>>::Output = A // this is our base case, its value is A itself
// now when you unwind the stack, you get
<A as PeanoAdd<One>>::Output = Next<A>
<A as PeanoAdd<Two>>::Output = Next<Next<A>>
<A as PeanoAdd<Three>>::Output = Next<Next<Next<A>>>
.
. repeated B times
.
// and finally we get
<A as PeanoAdd<B>>::Output = Next<Next<Next<... A ...>>> // Next applied B times
You can skip the above explanation if you understand the mathematical definition of addition that we defined above, but I think it is important to understand how the recursion works in Rust’s type system. And you should be able to dry run the type system code in your head because that’s the only way to debug your code.
Subtraction
Subtraction is kind of the same as addition, but instead of adding 1, we keep subtracting 1 from both terms until one of them becomes 0.
Here's what I mean:
Let's take 5 - 3 as an example:
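Written out with S, the steps for 5 - 3 would be as follows. We strip one S from each side per step and stop once the right-hand side is 0:

$$
\begin{aligned}
5 - 3 &= S(4) - S(2) = 4 - 2 \\
&= S(3) - S(1) = 3 - 1 \\
&= S(2) - S(0) = 2 - 0 \\
&= 2
\end{aligned}
$$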
As you can see, we're basically removing the S function from both terms until the right term becomes 0. If the left term reaches 0 first, the result would be a negative number. That's a problem, because we don't have any way to represent negative numbers. In the Rust version, that case simply has no matching impl, so it fails to compile.
Since I already explained the concept above, I'll jump straight into the code:
The trait:
trait PeanoSub<N: Num> {
type Output: Num;
}
Base case
// N - 0 = N
impl<N: Num> PeanoSub<Zero> for N {
type Output = N;
}
Recursive definition:
// Next<N> - Next<M> = N - M
impl<N, M> PeanoSub<Next<M>> for Next<N>
where
N: Num + PeanoSub<M>,
M: Num,
{
type Output = <N as PeanoSub<M>>::Output;
}
You can see the similarity: the implementation is defined on Next<M> and Next<N>, so using N and M in the recursive definition is the same as converting S(a) -> a.
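Here is a self-contained sketch of the subtraction code with a couple of checks. It repeats the Num, Next and ToVal definitions from earlier so that it compiles on its own, and the chosen numbers are arbitrary. The comment at the end shows where the "left side hits zero first" case stops type checking.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<N: Num>(PhantomData<N>);
impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

trait ToVal {
    const VALUE: usize;
}
impl ToVal for Zero {
    const VALUE: usize = 0;
}
impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

// Type-level subtraction: Self - N.
trait PeanoSub<N: Num> {
    type Output: Num;
}

// N - 0 = N
impl<N: Num> PeanoSub<Zero> for N {
    type Output = N;
}

// Next<N> - Next<M> = N - M
impl<N, M> PeanoSub<Next<M>> for Next<N>
where
    N: Num + PeanoSub<M>,
    M: Num,
{
    type Output = <N as PeanoSub<M>>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;
type _4 = Next<_3>;
type _5 = Next<_4>;

fn main() {
    // 5 - 3 strips three Next layers from both sides, leaving 2 - 0 = 2.
    assert_eq!(<_5 as PeanoSub<_3>>::Output::VALUE, 2);
    assert_eq!(<_3 as PeanoSub<_3>>::Output::VALUE, 0);
    // <_3 as PeanoSub<_5>>::Output would not compile: after three steps we
    // need Zero: PeanoSub<Next<Next<Zero>>>, and no impl covers that case.
}
```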
Multiplication
Multiplication is kind of the hardest one to figure out. It took me a while to come up with a recursive definition that would work with Rust's type system, but here it is:
Here’s the definition
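Reading it off the PeanoMul impls below (a base case for zero, then N * Next<M> = N + N * M), the definition is presumably:

$$
\begin{aligned}
a \cdot 0 &= 0 \\
a \cdot S(b) &= a + (a \cdot b)
\end{aligned}
$$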
Here's what it looks like:
Now we unwind the stack:
Here's an example: let's try to calculate 5 * 6.
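Using a · 0 = 0 and a · S(b) = a + (a · b), the expansion and the final sum can be sketched as follows. Each step turns one S on the right-hand side into another "5 +":

$$
\begin{aligned}
5 \cdot 6 &= 5 + (5 \cdot 5) \\
&= 5 + (5 + (5 \cdot 4)) \\
&= \dots \\
&= 5 + (5 + (5 + (5 + (5 + (5 + 5 \cdot 0))))) \\
&= 5 + 5 + 5 + 5 + 5 + 5 + 0 = 30
\end{aligned}
$$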
Now let's implement this in Rust.
The trait:
trait PeanoMul<N: Num> {
type Output: Num;
}
Base case:
// N * 0 = 0
impl<N: Num> PeanoMul<Zero> for N {
type Output = Zero;
}
Recursive definition:
impl<N, M> PeanoMul<Next<M>> for N
where
N: Num + PeanoMul<M> + PeanoAdd<<N as PeanoMul<M>>::Output>,
M: Num,
<N as PeanoMul<M>>::Output: Num,
{
// N * Next<M> = N + (N * M)
type Output = <N as PeanoAdd<<N as PeanoMul<M>>::Output>>::Output;
}
And now we can use it like the following:
fn main() {
type Product = <_3 as PeanoMul<_3>>::Output;
assert_eq!(Product::VALUE, _9::VALUE);
}
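Multiplication depends on addition, so a runnable version needs both. Below is a single-file sketch that repeats the earlier Num, Next, ToVal and PeanoAdd definitions next to PeanoMul. The bounds on PeanoMul are copied from the impl above. The test values are arbitrary.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<N: Num>(PhantomData<N>);
impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

trait ToVal {
    const VALUE: usize;
}
impl ToVal for Zero {
    const VALUE: usize = 0;
}
impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

// Addition, needed by the multiplication step.
trait PeanoAdd<Rhs: Num> {
    type Output: Num;
}
impl<N: Num> PeanoAdd<Zero> for N {
    type Output = N;
}
impl<N, M> PeanoAdd<Next<M>> for N
where
    N: Num + PeanoAdd<M>,
    M: Num,
{
    type Output = Next<<N as PeanoAdd<M>>::Output>;
}

trait PeanoMul<Rhs: Num> {
    type Output: Num;
}
// N * 0 = 0
impl<N: Num> PeanoMul<Zero> for N {
    type Output = Zero;
}
// N * Next<M> = N + (N * M)
impl<N, M> PeanoMul<Next<M>> for N
where
    N: Num + PeanoMul<M> + PeanoAdd<<N as PeanoMul<M>>::Output>,
    M: Num,
    <N as PeanoMul<M>>::Output: Num,
{
    type Output = <N as PeanoAdd<<N as PeanoMul<M>>::Output>>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;

fn main() {
    assert_eq!(<_3 as PeanoMul<_3>>::Output::VALUE, 9);
    assert_eq!(<_2 as PeanoMul<Zero>>::Output::VALUE, 0);
    assert_eq!(<Zero as PeanoMul<_3>>::Output::VALUE, 0);
}
```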
Boolean types
OK, this is going to be easy, so try to implement booleans and boolean operations yourself.
Trait and structs:
trait Boolean {
type Value: Num;
}
struct True;
struct False;
Implementations
impl Boolean for True {
type Value = Next<Zero>;
}
impl Boolean for False {
type Value = Zero;
}
impl<B> ToVal for B
where
B: Boolean,
B::Value: ToVal,
{
const VALUE: usize = <B::Value as ToVal>::VALUE;
}
The implementation of the binary operations is kind of self-explanatory, so I'll just drop the Rust code.
I'm implementing the operations in a very simple way, with one impl per input combination, because otherwise I'd have to add type constraints. As an exercise, try to implement a NAND gate and then define all the other binary operations using it.
NOT
trait Not {
type Output: Boolean;
}
impl Not for True { type Output = False; }
impl Not for False { type Output = True; }
AND
trait And<B: Boolean> {
type Output: Boolean;
}
impl And<False> for False { type Output = False; }
impl And<True> for False { type Output = False; }
impl And<False> for True { type Output = False; }
impl And<True> for True { type Output = True; }
OR
trait Or<B: Boolean> {
type Output: Boolean;
}
impl Or<False> for False { type Output = False; }
impl Or<True> for False { type Output = True; }
impl Or<False> for True { type Output = True; }
impl Or<True> for True { type Output = True; }
XOR
trait Xor<B: Boolean> {
type Output: Boolean;
}
impl Xor<False> for False { type Output = False; }
impl Xor<True> for True { type Output = False; }
impl Xor<False> for True { type Output = True; }
impl Xor<True> for False { type Output = True; }
NAND
trait NAnd<B: Boolean> {
type Output: Boolean;
}
impl NAnd<False> for False { type Output = True; }
impl NAnd<True> for False { type Output = True; }
impl NAnd<False> for True { type Output = True; }
impl NAnd<True> for True { type Output = False; }
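Here is one possible sketch for the NAND exercise. It builds NOT, AND and OR on top of NAnd alone, using NOT a = a NAND a, a AND b = NOT(a NAND b) and a OR b = (NOT a) NAND (NOT b). To keep it self-contained and easy to check, Boolean carries a plain const VALUE: bool instead of the Num-based Value type used above. The derived traits NandNot, NandAnd and NandOr are hypothetical names, chosen so they don't clash with Not, And and Or.

```rust
// Simplified for this sketch: booleans expose a plain bool.
trait Boolean { const VALUE: bool; }
struct True;
struct False;
impl Boolean for True { const VALUE: bool = true; }
impl Boolean for False { const VALUE: bool = false; }

// The only gate written out by hand.
trait NAnd<B: Boolean> { type Output: Boolean; }
impl NAnd<False> for False { type Output = True; }
impl NAnd<True> for False { type Output = True; }
impl NAnd<False> for True { type Output = True; }
impl NAnd<True> for True { type Output = False; }

// NOT a = a NAND a
trait NandNot { type Output: Boolean; }
impl<A> NandNot for A
where
    A: Boolean + NAnd<A>,
{
    type Output = <A as NAnd<A>>::Output;
}

// a AND b = NOT (a NAND b)
trait NandAnd<B: Boolean> { type Output: Boolean; }
impl<A, B> NandAnd<B> for A
where
    A: Boolean + NAnd<B>,
    B: Boolean,
    <A as NAnd<B>>::Output: NandNot,
{
    type Output = <<A as NAnd<B>>::Output as NandNot>::Output;
}

// a OR b = (NOT a) NAND (NOT b)
trait NandOr<B: Boolean> { type Output: Boolean; }
impl<A, B> NandOr<B> for A
where
    A: Boolean + NandNot,
    B: Boolean + NandNot,
    <A as NandNot>::Output: NAnd<<B as NandNot>::Output>,
{
    type Output = <<A as NandNot>::Output as NAnd<<B as NandNot>::Output>>::Output;
}

fn main() {
    assert!(!<True as NandNot>::Output::VALUE);
    assert!(<False as NandNot>::Output::VALUE);
    assert!(<True as NandAnd<True>>::Output::VALUE);
    assert!(!<True as NandAnd<False>>::Output::VALUE);
    assert!(<False as NandOr<True>>::Output::VALUE);
    assert!(!<False as NandOr<False>>::Output::VALUE);
}
```

XOR and the rest can be built the same way, by composing these traits in where clauses.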
Comparisons
Equality
Let Eq be a function defined recursively by:
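Going by the comments in the Rust code below, the rules are presumably:

$$
\begin{aligned}
\mathrm{Eq}(0, 0) &= \text{true} \\
\mathrm{Eq}(S(a), 0) &= \text{false} \\
\mathrm{Eq}(0, S(b)) &= \text{false} \\
\mathrm{Eq}(S(a), S(b)) &= \mathrm{Eq}(a, b)
\end{aligned}
$$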
Now we just have to translate this logic into Rust code:
trait PeanoEq<N> {
type Output: Boolean;
}
// Eq(S(a), 0) = false
impl<N: Num> PeanoEq<Zero> for Next<N> {
type Output = False;
}
// Eq(0, S(a)) = false
impl<N: Num> PeanoEq<Next<N>> for Zero {
type Output = False;
}
// Eq(0, 0) = true
impl PeanoEq<Zero> for Zero {
type Output = True;
}
// Eq(S(a), S(b)) = Eq(a, b)
impl<N, M> PeanoEq<Next<M>> for Next<N>
where
N: Num + PeanoEq<M>,
M: Num,
{
type Output = <N as PeanoEq<M>>::Output;
}
And with that, we can also define our PeanoNEq as follows:
trait PeanoNEq<N> {
type Output: Boolean;
}
impl<N: Num, M: Num> PeanoNEq<N> for M
where
N: PeanoEq<M>,
<N as PeanoEq<M>>::Output: Not,
{
type Output = <<N as PeanoEq<M>>::Output as Not>::Output;
}
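Here is a self-contained sketch to check PeanoEq and PeanoNEq. It copies the impls above and the Not gate. The one assumption is a simplified Boolean that exposes const VALUE: bool instead of the Num-based Value type, so the results can be asserted directly.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<N: Num>(PhantomData<N>);
impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

// Simplified for this sketch: booleans expose a plain bool.
trait Boolean { const VALUE: bool; }
struct True;
struct False;
impl Boolean for True { const VALUE: bool = true; }
impl Boolean for False { const VALUE: bool = false; }

trait Not { type Output: Boolean; }
impl Not for True { type Output = False; }
impl Not for False { type Output = True; }

trait PeanoEq<N> { type Output: Boolean; }
// Eq(0, 0) = true
impl PeanoEq<Zero> for Zero { type Output = True; }
// Eq(S(a), 0) = false
impl<N: Num> PeanoEq<Zero> for Next<N> { type Output = False; }
// Eq(0, S(a)) = false
impl<N: Num> PeanoEq<Next<N>> for Zero { type Output = False; }
// Eq(S(a), S(b)) = Eq(a, b)
impl<N, M> PeanoEq<Next<M>> for Next<N>
where
    N: Num + PeanoEq<M>,
    M: Num,
{
    type Output = <N as PeanoEq<M>>::Output;
}

trait PeanoNEq<N> { type Output: Boolean; }
impl<N: Num, M: Num> PeanoNEq<N> for M
where
    N: PeanoEq<M>,
    <N as PeanoEq<M>>::Output: Not,
{
    type Output = <<N as PeanoEq<M>>::Output as Not>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;

fn main() {
    assert!(<_2 as PeanoEq<_2>>::Output::VALUE);
    assert!(!<_2 as PeanoEq<_3>>::Output::VALUE);
    assert!(<_3 as PeanoNEq<_1>>::Output::VALUE);
    assert!(!<Zero as PeanoNEq<Zero>>::Output::VALUE);
}
```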
Less than, Less than or equal to
We'll define a function LessThan(a, b) = a < b, similar to Eq, as:
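Matching the comments in the implementation below, the rules are presumably:

$$
\begin{aligned}
\mathrm{LessThan}(0, 0) &= \text{false} \\
\mathrm{LessThan}(0, S(b)) &= \text{true} \\
\mathrm{LessThan}(S(a), 0) &= \text{false} \\
\mathrm{LessThan}(S(a), S(b)) &= \mathrm{LessThan}(a, b)
\end{aligned}
$$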
And now our Rust implementation:
trait PeanoLt<N> {
type Output: Boolean;
}
// LessThan(0, 0) = false
impl PeanoLt<Zero> for Zero {
type Output = False;
}
// LessThan(0, S(a)) = true
impl<N: Num> PeanoLt<Next<N>> for Zero {
type Output = True;
}
// LessThan(S(a), 0) = false
impl<N: Num> PeanoLt<Zero> for Next<N> {
type Output = False;
}
// LessThan(S(a), S(b)) = LessThan(a, b)
impl<M, N> PeanoLt<Next<M>> for Next<N>
where
N: Num + PeanoLt<M>,
M: Num,
{
type Output = <N as PeanoLt<M>>::Output;
}
And we can combine PeanoLt with our previous PeanoEq implementation, using Or, to create PeanoLEq:
trait PeanoLEq<N> {
type Output: Boolean;
}
impl<N, M> PeanoLEq<M> for N
where
N: Num + PeanoLt<M> + PeanoEq<M>,
M: Num,
<N as PeanoLt<M>>::Output: Or<<N as PeanoEq<M>>::Output>,
{
type Output = <
<N as PeanoLt<M>>::Output as
Or<
<N as PeanoEq<M>>::Output>
>::Output;
}
Greater than, Greater than or equal to
One way to implement the next two operations would be to redo everything we did for the previous ones, but we're smarter than that.
Notice that x > y is literally !(x <= y), and similarly x >= y is !(x < y). That's what I'll use here, but feel free to go with either implementation.
Greater than
trait PeanoGt<N> {
type Output: Boolean;
}
// x > y is !(x <= y), so we reuse PeanoLEq directly
impl<M, N> PeanoGt<M> for N
where
N: Num + PeanoLEq<M>,
M: Num,
<N as PeanoLEq<M>>::Output: Not,
{
type Output = <<N as PeanoLEq<M>>::Output as Not>::Output;
}
Greater than or equal to
trait PeanoGEq<N> {
type Output: Boolean;
}
// x >= y is !(x < y), so only PeanoLt is needed
impl<N, M> PeanoGEq<M> for N
where
N: Num + PeanoLt<M>,
M: Num,
<N as PeanoLt<M>>::Output: Not,
{
type Output = <<N as PeanoLt<M>>::Output as Not>::Output;
}
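Here is a self-contained sketch that checks PeanoLt and the !(x < y) trick behind PeanoGEq. As an assumption to keep it short, Boolean is simplified to expose const VALUE: bool. Only the traits that PeanoGEq actually needs are included.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<N: Num>(PhantomData<N>);
impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

// Simplified for this sketch: booleans expose a plain bool.
trait Boolean { const VALUE: bool; }
struct True;
struct False;
impl Boolean for True { const VALUE: bool = true; }
impl Boolean for False { const VALUE: bool = false; }

trait Not { type Output: Boolean; }
impl Not for True { type Output = False; }
impl Not for False { type Output = True; }

trait PeanoLt<N> { type Output: Boolean; }
impl PeanoLt<Zero> for Zero { type Output = False; }
impl<N: Num> PeanoLt<Next<N>> for Zero { type Output = True; }
impl<N: Num> PeanoLt<Zero> for Next<N> { type Output = False; }
impl<M, N> PeanoLt<Next<M>> for Next<N>
where
    N: Num + PeanoLt<M>,
    M: Num,
{
    type Output = <N as PeanoLt<M>>::Output;
}

// x >= y is !(x < y)
trait PeanoGEq<N> { type Output: Boolean; }
impl<N, M> PeanoGEq<M> for N
where
    N: Num + PeanoLt<M>,
    M: Num,
    <N as PeanoLt<M>>::Output: Not,
{
    type Output = <<N as PeanoLt<M>>::Output as Not>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;

fn main() {
    assert!(<_1 as PeanoLt<_3>>::Output::VALUE);
    assert!(!<_3 as PeanoLt<_3>>::Output::VALUE);
    assert!(<_3 as PeanoGEq<_2>>::Output::VALUE);
    assert!(<_2 as PeanoGEq<_2>>::Output::VALUE);
    assert!(!<_1 as PeanoGEq<_2>>::Output::VALUE);
}
```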
Lists
I think you've already guessed that we're basically doing functional programming: we don't have loops, and everything is immutable.
To implement a list, we're going to use something called a cons list. It is a specific type of linked list that is often used when immutability is desired. It is built recursively, meaning each cons node contains the next node.
A simple list of numbers such as [2, 4, 6, 8] will look like Cons(2, Cons(4, Cons(6, Cons(8, Nil)))) as a cons list.
Clearly, we have to define two structs: one is Cons itself, and the other is Nil, which will denote the end of the list.
struct Cons<H, T>(std::marker::PhantomData<(H, T)>);
struct Nil;
We will also create a List trait to unify them:
trait List {}
impl<H, T> List for Cons<H, T> {}
impl List for Nil {}
And just like that, we have a list! We can use it like this:
fn main() {
type List = Cons<_2, Cons<_4, Cons<_6, Cons<_8, Nil>>>>;
}
But this list alone is not particularly useful, so we have to implement some operations…
List operations
Index
Starting from the most basic list operation, we have indexing. And yes you guessed it, this will also be a recursive definition.
Let us define a function Idx(list, index) as:
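Matching the comments in the implementation below, with h the head and t the tail, the definition is presumably:

$$
\begin{aligned}
\mathrm{Idx}(\mathrm{Cons}(h, t), 0) &= h \\
\mathrm{Idx}(\mathrm{Cons}(h, t), S(n)) &= \mathrm{Idx}(t, n)
\end{aligned}
$$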
This is actually pretty simple. If the index is zero, we just return the first element of Cons, which is the head. Otherwise, we index again into the second element of Cons (the tail, which is another Cons) with index - 1. We keep repeating that until the index becomes 0.
Here's what the Rust implementation looks like:
trait GetIndex<Index> {
type Output;
}
// Idx(Cons(Head, Tail), 0) = Head
impl<H, T> GetIndex<Zero> for Cons<H, T> {
type Output = H;
}
// Idx(Cons(Head, Tail), S(a)) = Idx(Tail, a)
impl<H, T, Index> GetIndex<Next<Index>> for Cons<H, T>
where
Index: Num,
T: GetIndex<Index>,
{
type Output = <T as GetIndex<Index>>::Output;
}
And we can use it like this:
fn main() {
type List = Cons<_2, Cons<_4, Cons<_6, Cons<_8, Nil>>>>;
assert_eq!(<List as GetIndex<_1>>::Output::VALUE, _4::VALUE);
assert_eq!(<List as GetIndex<_2>>::Output::VALUE, _6::VALUE);
}
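Here is a standalone sketch of indexing that you can compile directly. It copies the number and ToVal definitions from earlier. The list name Digits and its contents are arbitrary examples. The final comment shows that going past the end is a compile error rather than a runtime panic.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<N: Num>(PhantomData<N>);
impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

trait ToVal { const VALUE: usize; }
impl ToVal for Zero { const VALUE: usize = 0; }
impl<N: Num + ToVal> ToVal for Next<N> { const VALUE: usize = 1 + N::VALUE; }

struct Cons<H, T>(PhantomData<(H, T)>);
struct Nil;

trait GetIndex<Index> { type Output; }
// Idx(Cons(Head, Tail), 0) = Head
impl<H, T> GetIndex<Zero> for Cons<H, T> { type Output = H; }
// Idx(Cons(Head, Tail), S(a)) = Idx(Tail, a)
impl<H, T, Index> GetIndex<Next<Index>> for Cons<H, T>
where
    Index: Num,
    T: GetIndex<Index>,
{
    type Output = <T as GetIndex<Index>>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;
type _4 = Next<_3>;

// The list [3, 1, 4]
type Digits = Cons<_3, Cons<_1, Cons<_4, Nil>>>;

fn main() {
    assert_eq!(<Digits as GetIndex<Zero>>::Output::VALUE, 3);
    assert_eq!(<Digits as GetIndex<_1>>::Output::VALUE, 1);
    assert_eq!(<Digits as GetIndex<_2>>::Output::VALUE, 4);
    // Indexing past the end (e.g. GetIndex<_3>) fails to compile,
    // because there is no GetIndex impl for Nil.
}
```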
If you’ve followed along till this point, I think you should try to implement more list operations like push, pop, etc. by yourself as an exercise.
Append
This operation will allow us to insert an element at the end of a list. It's as simple as replacing the last Nil with Cons<T, Nil>.
This is going to be our base case:
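Matching the comments in the implementation below, the base case and the recursive step are presumably:

$$
\begin{aligned}
\mathrm{Append}(\mathrm{Nil}, i) &= \mathrm{Cons}(i, \mathrm{Nil}) \\
\mathrm{Append}(\mathrm{Cons}(h, t), i) &= \mathrm{Cons}(h, \mathrm{Append}(t, i))
\end{aligned}
$$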
Yeah, we just keep appending the item to the tail until we reach Nil, which is then replaced by Cons(i, Nil).
The Rust implementation:
trait Append<Item> {
type Output;
}
// Append(Nil, i) = Cons(i, Nil)
impl<Item> Append<Item> for Nil {
type Output = Cons<Item, Nil>;
}
// Append(Cons(Head, Tail), i) = Cons(Head, Append(Tail, i))
impl<H, T, Item> Append<Item> for Cons<H, T>
where
T: Append<Item>,
{
type Output = Cons<H, <T as Append<Item>>::Output>;
}
Now I think you get the idea of how all this stuff works, so I'll just drop the implementations of length and pop.
Length
trait Length {
type Output;
}
// Base case: len(Nil) = 0
impl Length for Nil {
type Output = Zero;
}
// len(Cons(Head, Tail)) = 1 + len(tail)
impl<H, T> Length for Cons<H, T>
where
T: Length,
<T as Length>::Output: Num,
{
type Output = <<T as Length>::Output as PeanoAdd<Next<Zero>>>::Output;
}
Pop
pub trait Pop<Index> {
type Output;
}
impl<H, T> Pop<Zero> for Cons<H, T> {
type Output = T;
}
impl<H, T, Index> Pop<Next<Index>> for Cons<H, T>
where
Index: Num,
T: Pop<Index>,
{
type Output = Cons<H, <T as Pop<Index>>::Output>;
}
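Here is a small end-to-end sketch that exercises Append, Length and Pop together. It copies in the number, ToVal and GetIndex definitions so the file compiles on its own, and the list contents are arbitrary. One deliberate shortcut: Length declares type Output: Num and wraps the tail's length in Next, instead of going through PeanoAdd<Next<Zero>>. Both add one, but the shortcut avoids pulling in PeanoAdd. Pop<_1> removes the element at index 1.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<N: Num>(PhantomData<N>);
impl Num for Zero {}
impl<N: Num> Num for Next<N> {}

trait ToVal { const VALUE: usize; }
impl ToVal for Zero { const VALUE: usize = 0; }
impl<N: Num + ToVal> ToVal for Next<N> { const VALUE: usize = 1 + N::VALUE; }

struct Cons<H, T>(PhantomData<(H, T)>);
struct Nil;

trait GetIndex<Index> { type Output; }
impl<H, T> GetIndex<Zero> for Cons<H, T> { type Output = H; }
impl<H, T, I: Num> GetIndex<Next<I>> for Cons<H, T>
where
    T: GetIndex<I>,
{
    type Output = <T as GetIndex<I>>::Output;
}

trait Append<Item> { type Output; }
impl<Item> Append<Item> for Nil { type Output = Cons<Item, Nil>; }
impl<H, T, Item> Append<Item> for Cons<H, T>
where
    T: Append<Item>,
{
    type Output = Cons<H, <T as Append<Item>>::Output>;
}

// len(Nil) = 0, len(Cons(Head, Tail)) = S(len(Tail)).
// Wrapping in Next stands in for the PeanoAdd<Next<Zero>> step above.
trait Length { type Output: Num; }
impl Length for Nil { type Output = Zero; }
impl<H, T: Length> Length for Cons<H, T> { type Output = Next<<T as Length>::Output>; }

trait Pop<Index> { type Output; }
impl<H, T> Pop<Zero> for Cons<H, T> { type Output = T; }
impl<H, T, I: Num> Pop<Next<I>> for Cons<H, T>
where
    T: Pop<I>,
{
    type Output = Cons<H, <T as Pop<I>>::Output>;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;

type Start = Cons<_1, Cons<_2, Nil>>;          // [1, 2]
type Appended = <Start as Append<_3>>::Output; // [1, 2, 3]
type Popped = <Appended as Pop<_1>>::Output;   // [1, 3]

fn main() {
    assert_eq!(<Appended as Length>::Output::VALUE, 3);
    assert_eq!(<Appended as GetIndex<_2>>::Output::VALUE, 3);
    assert_eq!(<Popped as Length>::Output::VALUE, 2);
    assert_eq!(<Popped as GetIndex<_1>>::Output::VALUE, 3);
}
```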
TODO: This blog post is still a work in progress. More type system tricks and examples coming soon!
- arithmetic
  - addition
  - subtraction
  - multiplication
- boolean operations
- comparison
- lists
- list operations
  - index
  - append
  - pop
  - length
- type level functions
- more list operations (map, filter, etc.)
- macros to make life easier
- conditionals