Abusing Rust's type system

Published on · rust type system recursion

Introduction

In this blog post, I will walk you through all the type system tricks I know and how I incorporated them into https://github.com/thatmagicalcat/tcrts.

The Goal

Lol it’s fairly obvious: we’ll abuse Rust’s type system, what else?

The Plan

The most important thing for abusing the type system is to have a good understanding of recursion. We are going to use recursion to build types that represent values and data structures such as lists and tuples.

The Basics

Let’s start with the numbers, we are going to represent the numbers using types. Let’s start with the number zero.

struct Zero;

No, we aren’t going to create a struct for every number, that would be stupid. Instead, we are going to use a recursive type to represent the numbers.

What I mean is, 1 can be written as 0 + 1, 2 can be written as 0 + 1 + 1, and so on.

// we'll have to use PhantomData marker as
// we are not going to use the generic type T
struct Next<T>(std::marker::PhantomData<T>);

And now we can represent the numbers as follows:

type One = Next<Zero>;
type Two = Next<One>;
type Three = Next<Two>;
// and so on

Currently Next<T> and Zero are two separate structs. We’ll define a trait called Num to unify them. This trait won’t do anything by itself, but it’ll help us with the implementation of other traits.

trait Num {}

// and implement the trait for both types
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

// We're also gonna change the `Next` struct's
// definition to include the `Num` trait bound
struct Next<T: Num>(std::marker::PhantomData<T>);

Ok but let’s say we have a very long type chain, for example Next<Next<Next<..Next<Zero>..>>>, how do we get the actual value from this type? Surely we are not going to count the number of Next types manually. For checking the value of the type, let’s define a trait called ToVal that will convert the type to a value which we can print.

trait ToVal {
    const VALUE: usize;
}

Yup, you guessed it, it is going to be a recursive definition. In any recursive definition we need a base case, and in this implementation the base case is Zero.

impl ToVal for Zero {
    const VALUE: usize = 0;
}

And now we can implement the ToVal trait for Next<T> as follows:

// `Next<N>` requires `N: Num`, so the bound has to be repeated here
impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

We can test that this works by running the following code:

fn main() {
    type One = Next<Zero>;
    type Two = Next<One>;
    type Three = Next<Two>;

    assert_eq!(One::VALUE, 1);
    assert_eq!(Two::VALUE, 2);
    assert_eq!(Three::VALUE, 3);
}

And it works! Now we’re done with the basics, let’s move on to the next part and define some operations which we can perform on these number types.
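Since VALUE is an associated constant, all of this is computed while compiling. Here is a minimal sketch of how that could be checked, assuming a helper function called `value_of` (my name, not part of the code above) and a `const` assertion that turns a wrong value into a build error:

```rust
use std::marker::PhantomData;

trait Num {}

struct Zero;
struct Next<T: Num>(PhantomData<T>);

impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

trait ToVal {
    const VALUE: usize;
}

impl ToVal for Zero {
    const VALUE: usize = 0;
}

impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

type Three = Next<Next<Next<Zero>>>;

// Checked during compilation: if the value were wrong,
// the build would fail instead of panicking at runtime.
const _: () = assert!(Three::VALUE == 3);

// Hypothetical helper: reads the value of any type-level number.
fn value_of<N: ToVal>() -> usize {
    N::VALUE
}
```

Calling `value_of::<Three>()` just returns the constant that the compiler already worked out.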

Arithmetic operations

For our arithmetic operations, we are going to use Peano Axioms, which are a set of axioms for the natural numbers that can be used to define operations like addition, subtraction and multiplication in a recursive way.

Addition

Before jumping into the type system implementation, let’s define the addition operation using the Peano Axioms. The addition operation is defined as follows:

Let’s define a function S(x) = x + 1 (this is our Next struct from above).

Addition is then defined by two rules, a base case (1) and a recursive case (2):

$$
\begin{align*}
a + 0 &= a \tag{1} \\
a + S(b) &= S(a + b) \tag{2}
\end{align*}
$$

Rule (2) is fairly obvious, as addition is associative: a + (b + 1) is the same as (a + b) + 1. This is the same successor trick we used for the number types above.

Let’s try to create a recursive definition for simple examples first:

$$
\begin{align*}
a + 1 &= a + S(0) \\
&= S(a + 0) &\text{using (2)} \\
&= S(a) &\text{using (1)} \\
a + 1 &= S(a) \tag{3}
\end{align*}
$$

$$
\begin{align*}
a + 2 &= a + S(1) \\
&= S(a + 1) &\text{using (2)} \\
&= S(S(a)) &\text{using (3)} \\
a + 2 &= S(S(a)) \tag{4}
\end{align*}
$$

$$
\begin{align*}
a + 3 &= a + S(2) \\
&= S(a + 2) &\text{using (2)} \\
&= S(S(S(a))) &\text{using (4)} \\
a + 3 &= S(S(S(a))) \tag{5}
\end{align*}
$$

Now we can see a pattern here, and we can generalize this for a + n as follows:

$$
\begin{align*}
a + n &= a + S(n - 1) \\
&= S(a + (n - 1)) \\
&= S(S(S(\dots S(a) \dots))) \\
&\text{where } S \text{ is applied } n \text{ times}
\end{align*}
$$

Here’s another example to help you understand. Let’s use our recursive definition to find 10 + 4:

$$
\begin{align*}
10 + 4 &= 10 + S(3) \\
&= S(10 + 3) \\
&= S(10 + S(2)) \\
&= S(S(10 + 2)) \\
&= S(S(10 + S(1))) \\
&= S(S(S(10 + 1))) \\
&= S(S(S(10 + S(0)))) \\
&= S(S(S(S(10 + 0)))) \\
&= S(S(S(S(10)))) \\
&= 10 + 1 + 1 + 1 + 1 \\
&= 14
\end{align*}
$$
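Before moving to types, the same two rules can be tried at the value level. This is only a sketch to make the recursion concrete: the `Nat` enum and the `add`, `from_u32` and `to_u32` functions are names I made up for illustration and are not used anywhere else in this post.

```rust
// Value-level Peano numbers: either zero or the successor of another number.
#[derive(Debug, Clone, PartialEq)]
enum Nat {
    Zero,
    S(Box<Nat>),
}

// Rule (1): a + 0 = a
// Rule (2): a + S(b) = S(a + b)
fn add(a: Nat, b: Nat) -> Nat {
    match b {
        Nat::Zero => a,
        Nat::S(pred) => Nat::S(Box::new(add(a, *pred))),
    }
}

// Build a `Nat` by applying S to zero `n` times.
fn from_u32(n: u32) -> Nat {
    (0..n).fold(Nat::Zero, |acc, _| Nat::S(Box::new(acc)))
}

// Count how many times S was applied.
fn to_u32(n: &Nat) -> u32 {
    match n {
        Nat::Zero => 0,
        Nat::S(pred) => 1 + to_u32(pred),
    }
}
```

The `match` on `b` plays the role that pattern matching on Zero and Next<M> will play in the trait implementations.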

With that understanding, we’re ready to implement the addition operation in Rust’s type system. Let’s start by defining a trait called PeanoAdd that will represent the addition operation.

trait PeanoAdd<Rhs: Num> {
    type Output: Num;
}

Base case:

// N + 0 = N
impl<N: Num> PeanoAdd<Zero> for N {
    type Output = N;
}

Next, we generalize the addition operation for Next<T> as follows:

// N + Next<M> = Next<N + M>
impl<N, M> PeanoAdd<Next<M>> for N
where
    N: Num + PeanoAdd<M>,
    M: Num,
{
    type Output = Next<<N as PeanoAdd<M>>::Output>;
}

Notice that we’re using <N as PeanoAdd<M>> while the trait being implemented is PeanoAdd<Next<M>>. Matching on Next<M> and recursing on M is basically how we do M - 1, which we can’t write directly in Rust’s type system.

We can test our addition operation by running the following code:

type _0 = Zero;
type _1 = Next<_0>;
type _2 = Next<_1>;
type _3 = Next<_2>;
type _4 = Next<_3>;
type _5 = Next<_4>;
type _6 = Next<_5>;
type _7 = Next<_6>;
type _8 = Next<_7>;
type _9 = Next<_8>;

fn main() {
    // 5 + 4 = 9
    assert_eq!(<_5 as PeanoAdd<_4>>::Output::VALUE, _9::VALUE);

    // Remember `::VALUE` comes from the `ToVal` trait
}

If you’re wondering how this recursive definition will actually work, I’ve got you covered.

Let’s say we have two numbers

type A = Next<Next<Next< ... Zero ... >>> // Next<...> applied A times
type B = Next<Next<Next< ... Zero ... >>> // Next<...> applied B times

Now our addition will look like this

type Result = <A as PeanoAdd<B>>::Output;

I’ll use a notation to simplify the type chain: Next<Next<Next<Zero>>> - 1 will be equivalent to Next<Next<Zero>>.

Ok we’re ready to expand our addition operation step by step.

<A as PeanoAdd<B>>::Output     = Next<<A as PeanoAdd<B - 1>>::Output>
<A as PeanoAdd<B - 1>>::Output = Next<<A as PeanoAdd<B - 2>>::Output>
<A as PeanoAdd<B - 2>>::Output = Next<<A as PeanoAdd<B - 3>>::Output>
<A as PeanoAdd<B - 3>>::Output = Next<<A as PeanoAdd<B - 4>>::Output>
   .
   . repeated B times
   .
<A as PeanoAdd<One>>::Output   = Next<<A as PeanoAdd<Zero>>::Output>
<A as PeanoAdd<Zero>>::Output  = A
// this is our base case, its value is A itself

// now when you unwind the stack, you get
<A as PeanoAdd<One>>::Output   = Next<A>
<A as PeanoAdd<Two>>::Output   = Next<Next<A>>
<A as PeanoAdd<Three>>::Output = Next<Next<Next<A>>>
   .
   . repeated B times
   .

// and finally we get
<A as PeanoAdd<B>>::Output = Next<Next<Next<... A ...>>>
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                             Next applied B times

You can skip the above explanation if you understand the mathematical definition of addition that we defined above, but I think it is important to understand how the recursion works in Rust’s type system. And you should be able to dry run the type system code in your head because that’s the only way to debug your code.
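If you don’t trust your dry run, you can ask the compiler to confirm it. The sketch below assumes a small helper trait `Same` and a function `assert_same` (both names are mine, not from tcrts) that only compile when two types are identical, so it checks the unfolded type itself rather than just its VALUE.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<T: Num>(PhantomData<T>);
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

trait PeanoAdd<Rhs: Num> {
    type Output: Num;
}

// N + 0 = N
impl<N: Num> PeanoAdd<Zero> for N {
    type Output = N;
}

// N + Next<M> = Next<N + M>
impl<N, M> PeanoAdd<Next<M>> for N
where
    N: Num + PeanoAdd<M>,
    M: Num,
{
    type Output = Next<<N as PeanoAdd<M>>::Output>;
}

// Hypothetical helper: `A: Same<B>` only holds when A and B are the same type.
trait Same<T> {}
impl<T> Same<T> for T {}

// Type-checks only if both type arguments are identical.
fn assert_same<A: Same<B>, B>() {}

type Two = Next<Next<Zero>>;
type Three = Next<Two>;
type Five = Next<Next<Three>>;

// The interesting work happens in the type checker; the body does nothing at runtime.
fn two_plus_three_is_five() -> bool {
    // The unwound form from the trace: Next applied three times around Two.
    assert_same::<<Two as PeanoAdd<Three>>::Output, Next<Next<Next<Two>>>>();
    // Which is the very same type as Five.
    assert_same::<<Two as PeanoAdd<Three>>::Output, Five>();
    true
}
```

If the two types differed, this would stop compiling with an unsatisfied `Same` bound, which is about as close to a type-level assert_eq! as this sketch gets.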

Subtraction

Subtraction is kinda the same as addition, but instead of adding 1, we’ll keep subtracting 1 from both of the terms until one of them becomes 0. Here’s what I mean, taking 5 - 3 as an example:

$$
\begin{align*}
5 - 3 &= S(S(S(S(S(0))))) - S(S(S(0))) \\
&= S(S(S(S(0)))) - S(S(0)) \\
&= S(S(S(0))) - S(0) \\
&= S(S(0)) - 0 \\
&= S(S(0)) \\
&= 2
\end{align*}
$$

As you can see, we’re basically removing one S from both terms until the right term becomes 0. If the left term reaches 0 first, the result would be a negative number, which is a problem because we don’t really have a way to represent one.

Since I already explained the concept above, I’ll jump straight into the code:

The trait:

trait PeanoSub<N: Num> {
    type Output: Num;
}

Base case

// N - 0 = N
impl<N: Num> PeanoSub<Zero> for N {
    type Output = N;
}

Recursive definition:

// Next<N> - Next<M> = N - M
impl<N, M> PeanoSub<Next<M>> for Next<N>
where
    N: Num + PeanoSub<M>,
    M: Num,
{
    type Output = <N as PeanoSub<M>>::Output;
}

You can see the similarity: the implementation is defined on Next<M> and Next<N>, so using N and M in the recursive definition is the same as converting S(a) -> a.
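Here is a small self-contained sketch that puts the subtraction pieces together. The alias `Diff` and the function `five_minus_three` are hypothetical names used only for this example. The commented-out alias at the end shows the negative case: as far as I can tell, using it is rejected by the compiler because there is no PeanoSub implementation for Zero minus a Next<...>.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<T: Num>(PhantomData<T>);
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

trait ToVal {
    const VALUE: usize;
}
impl ToVal for Zero {
    const VALUE: usize = 0;
}
impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

trait PeanoSub<Rhs: Num> {
    type Output: Num;
}

// N - 0 = N
impl<N: Num> PeanoSub<Zero> for N {
    type Output = N;
}

// Next<N> - Next<M> = N - M
impl<N, M> PeanoSub<Next<M>> for Next<N>
where
    N: Num + PeanoSub<M>,
    M: Num,
{
    type Output = <N as PeanoSub<M>>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;
type _4 = Next<_3>;
type _5 = Next<_4>;

// 5 - 3, computed by peeling one Next off both sides three times.
type Diff = <_5 as PeanoSub<_3>>::Output;

fn five_minus_three() -> usize {
    <Diff as ToVal>::VALUE
}

// 3 - 5 would bottom out at `Zero: PeanoSub<Next<_1>>`, which has no impl,
// so any use of this type is expected to fail to compile:
// type Negative = <_3 as PeanoSub<_5>>::Output;
```

So instead of a runtime underflow we get a missing trait implementation, which is a reasonable stand-in for "negative numbers don’t exist here".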

Multiplication

Multiplication is kinda the hardest one to figure out. It took me a while to come up with a recursive definition that would work with Rust’s type system.

Here’s the definition:

$$
\begin{align*}
a \times n &= a + (a \times (n - 1)) \tag{6} \\
a \times 0 &= 0
\end{align*}
$$

Here’s what the expansion looks like:

$$
\begin{align*}
a \times b &= a + (a \times (b - 1)) \\
a \times (b - 1) &= a + (a \times (b - 2)) \\
a \times (b - 2) &= a + (a \times (b - 3)) \\
a \times (b - 3) &= a + (a \times (b - 4)) \\
&\vdots \\
a \times 3 &= a + (a \times 2) \\
a \times 2 &= a + (a \times 1) \\
a \times 1 &= a + (a \times 0) \\
a \times 0 &= 0
\end{align*}
$$

Now we unwind the stack:

$$
\begin{align*}
a \times 1 &= a \\
a \times 2 &= a + a \\
a \times 3 &= a + a + a \\
a \times 4 &= a + a + a + a \\
&\vdots \\
a \times b &= \underbrace{a + a + \dots + a + a}_{b\text{ terms}}
\end{align*}
$$

Here’s an example, let’s try to calculate 5 * 6:

$$
\begin{align*}
5 \times 6 &= 5 + (5 \times 5) \\
&= 5 + 5 + (5 \times 4) \\
&= 5 + 5 + 5 + (5 \times 3) \\
&= 5 + 5 + 5 + 5 + (5 \times 2) \\
&= 5 + 5 + 5 + 5 + 5 + (5 \times 1) \\
&= 5 + 5 + 5 + 5 + 5 + 5 + (5 \times 0) \\
&= 5 + 5 + 5 + 5 + 5 + 5 \\
&= 30
\end{align*}
$$

Now let’s implement this in rust

The trait:

trait PeanoMul<N: Num> {
    type Output: Num;
}

Base case:

// N * 0 = 0
impl<N: Num> PeanoMul<Zero> for N {
    type Output = Zero;
}

Recursive definition:

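// N * Next<M> = N + (N * M)
impl<N, M> PeanoMul<Next<M>> for N
where
    N: Num + PeanoMul<M> + PeanoAdd<<N as PeanoMul<M>>::Output>,
    M: Num,
    <N as PeanoMul<M>>::Output: Num,
{
    // `M` is already the predecessor of the right operand,
    // so `N * M` here is the `a * (n - 1)` from the definition.
    type Output = <N as PeanoAdd<<N as PeanoMul<M>>::Output>>::Output;
}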

And now we can use it like the following:

fn main() {
    type Product = <_3 as PeanoMul<_3>>::Output;

    assert_eq!(Product::VALUE, _9::VALUE);
}

Boolean types

Ok, this is going to be easy. Try to implement booleans and boolean operations yourself before reading on.

Trait and structs:

trait Boolean {
    type Value: Num;
}

struct True;
struct False;

Implementations

impl Boolean for True {
    type Value = Next<Zero>;
}

impl Boolean for False {
    type Value = Zero;
}

impl<B> ToVal for B
where
    B: Boolean,
    B::Value: ToVal,
{
    const VALUE: usize = <B::Value as ToVal>::VALUE;
}

The implementation of the binary operations is kinda self-explanatory, so I guess I’ll just drop the Rust code.

I am implementing the operations in a very simple way because otherwise I’d have to add type constraints. As an exercise, try to implement a NAND gate and then define all other binary operations using that.

NOT

trait Not {
    type Output: Boolean;
}

impl Not for True  { type Output = False; }
impl Not for False { type Output = True;  }

AND

trait And<B: Boolean> {
    type Output: Boolean;
}

impl And<False> for False { type Output = False; }
impl And<True>  for False { type Output = False; }
impl And<False> for True  { type Output = False; }
impl And<True>  for True  { type Output = True;  }

OR

trait Or<B: Boolean> {
    type Output: Boolean;
}

impl Or<False> for False { type Output = False; }
impl Or<True>  for False { type Output = True;  }
impl Or<False> for True  { type Output = True;  }
impl Or<True>  for True  { type Output = True;  }

XOR

trait Xor<B: Boolean> {
    type Output: Boolean;
}

impl Xor<False> for False { type Output = False; }
impl Xor<True>  for True  { type Output = False; }
impl Xor<False> for True  { type Output = True;  }
impl Xor<True>  for False { type Output = True;  }

NAND

trait NAnd<B: Boolean> {
    type Output: Boolean;
}

impl NAnd<False> for False { type Output = True;  }
impl NAnd<True>  for False { type Output = True;  }
impl NAnd<False> for True  { type Output = True;  }
impl NAnd<True>  for True  { type Output = False; }
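If you want a starting point for the NAND exercise, here is one possible sketch. To keep it short it assumes a simplified `Boolean` trait with a plain `BOOL` constant instead of the `Value: Num` associated type above, and the aliases `NotOf`, `AndOf` and `OrOf` plus the helper `truth` are hypothetical names of mine.

```rust
// Simplified for this sketch: a `bool` constant instead of `type Value: Num`.
trait Boolean {
    const BOOL: bool;
}

struct True;
struct False;

impl Boolean for True {
    const BOOL: bool = true;
}
impl Boolean for False {
    const BOOL: bool = false;
}

trait NAnd<B: Boolean> {
    type Output: Boolean;
}

impl NAnd<False> for False { type Output = True;  }
impl NAnd<True>  for False { type Output = True;  }
impl NAnd<False> for True  { type Output = True;  }
impl NAnd<True>  for True  { type Output = False; }

// Everything below is built from NAND only.

// NOT a = a NAND a
type NotOf<A> = <A as NAnd<A>>::Output;

// a AND b = NOT (a NAND b)
type AndOf<A, B> = NotOf<<A as NAnd<B>>::Output>;

// a OR b = (NOT a) NAND (NOT b)
type OrOf<A, B> = <NotOf<A> as NAnd<NotOf<B>>>::Output;

// Hypothetical helper: turn a type-level boolean into a runtime `bool`.
fn truth<B: Boolean>() -> bool {
    B::BOOL
}
```

Type aliases keep this short because they don’t need the where clauses a trait-based version would. XOR is left out on purpose so there is still something to do.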

Comparisons

Equality

Let Eq be a function defined recursively by:

$$
\begin{align*}
Eq(0, 0) &= \text{true} \\
Eq(S(a), 0) &= \text{false} \\
Eq(0, S(b)) &= \text{false} \\
Eq(S(a), S(b)) &= Eq(a, b)
\end{align*}
$$

Now we just have to translate this logic to our rust code:

trait PeanoEq<N> {
    type Output: Boolean;
}

// Eq(S(a), 0) = false
impl<N: Num> PeanoEq<Zero> for Next<N> {
    type Output = False;
}

// Eq(0, S(b)) = false
impl<N: Num> PeanoEq<Next<N>> for Zero {
    type Output = False;
}

// Eq(0, 0) = true
impl PeanoEq<Zero> for Zero {
    type Output = True;
}

// Eq(S(a), S(b)) = Eq(a, b)
impl<N, M> PeanoEq<Next<M>> for Next<N>
where
    N: Num + PeanoEq<M>,
    M: Num,
{
    type Output = <N as PeanoEq<M>>::Output;
}

And with that, we can also define our PeanoNEq as follows:

trait PeanoNEq<N> {
    type Output: Boolean;
}

impl<N: Num, M: Num> PeanoNEq<N> for M
where
    N: PeanoEq<M>,
    <N as PeanoEq<M>>::Output: Not,
{
    type Output = <<N as PeanoEq<M>>::Output as Not>::Output;
}

Less than, Less than or equal to

We’ll define a function LessThan(a, b) = a < b similar to Eq as:

$$
\begin{align*}
\text{LessThan}(0, 0) &= \text{false} \\
\text{LessThan}(0, S(a)) &= \text{true} \\
\text{LessThan}(S(a), 0) &= \text{false} \\
\text{LessThan}(S(a), S(b)) &= \text{LessThan}(a, b)
\end{align*}
$$

and now our rust implementation

trait PeanoLt<N> {
    type Output: Boolean;
}

// LessThan(0, 0) = false
impl PeanoLt<Zero> for Zero {
    type Output = False;
}

// LessThan(0, S(a)) = true
impl<N: Num> PeanoLt<Next<N>> for Zero {
    type Output = True;
}

// LessThan(S(a), 0) = false
impl<N: Num> PeanoLt<Zero> for Next<N> {
    type Output = False;
}

// LessThan(S(a), S(b)) = LessThan(a, b)
impl<M, N> PeanoLt<Next<M>> for Next<N>
where
    N: Num + PeanoLt<M>,
    M: Num,
{
    type Output = <N as PeanoLt<M>>::Output;
}

And we can use our previous PeanoEq implementation to create PeanoLEq

trait PeanoLEq<N> {
    type Output: Boolean;
}

impl<N, M> PeanoLEq<M> for N
where
    N: Num + PeanoLt<M> + PeanoEq<M>,
    M: Num,
    <N as PeanoLt<M>>::Output: Or<<N as PeanoEq<M>>::Output>,
{
    type Output = <
        <N as PeanoLt<M>>::Output as
        Or<
        <N as PeanoEq<M>>::Output>
    >::Output;
}

Greater than, Greater than or equal to

Now, one way to implement the next two operations is to repeat everything we did for the previous ones, but we’re smart.

Notice that x > y is literally !(x <= y), and similarly x >= y is !(x < y). I will use this, but feel free to go with either implementation.

Greater than

pub trait PeanoGt<N> {
    type Output: Boolean;
}

impl<M, N> PeanoGt<M> for N
where
    N: Num + PeanoLt<M> + PeanoEq<M>,
    M: Num,
    <N as PeanoLt<M>>::Output: Or<<N as PeanoEq<M>>::Output>,
    <<N as PeanoLt<M>>::Output as Or<<N as PeanoEq<M>>::Output>>::Output: Not,
{
    type Output = <<N as PeanoLEq<M>>::Output as Not>::Output;
}

Greater than or equal to

trait PeanoGEq<N> {
    type Output: Boolean;
}

// x >= y is !(x < y), so only PeanoLt and Not are needed here
impl<N, M> PeanoGEq<M> for N
where
    N: Num + PeanoLt<M>,
    M: Num,
    <N as PeanoLt<M>>::Output: Not,
{
    type Output = <<N as PeanoLt<M>>::Output as Not>::Output;
}
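There is also a third option that avoids Not and Or completely: x > y is the same question as y < x, so we can just swap the operands of PeanoLt. The sketch below is an alternative I’m suggesting, not the tcrts implementation. It assumes a simplified `Boolean` with a `BOOL` constant, and `SwapGt` and `gt` are hypothetical names.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<T: Num>(PhantomData<T>);
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

// Simplified for this sketch: a `bool` constant instead of `type Value: Num`.
trait Boolean {
    const BOOL: bool;
}
struct True;
struct False;
impl Boolean for True {
    const BOOL: bool = true;
}
impl Boolean for False {
    const BOOL: bool = false;
}

trait PeanoLt<Rhs> {
    type Output: Boolean;
}
// LessThan(0, 0) = false
impl PeanoLt<Zero> for Zero {
    type Output = False;
}
// LessThan(0, S(a)) = true
impl<N: Num> PeanoLt<Next<N>> for Zero {
    type Output = True;
}
// LessThan(S(a), 0) = false
impl<N: Num> PeanoLt<Zero> for Next<N> {
    type Output = False;
}
// LessThan(S(a), S(b)) = LessThan(a, b)
impl<N, M> PeanoLt<Next<M>> for Next<N>
where
    N: Num + PeanoLt<M>,
    M: Num,
{
    type Output = <N as PeanoLt<M>>::Output;
}

// Hypothetical alternative to PeanoGt: N > M is answered by asking M < N.
trait SwapGt<Rhs> {
    type Output: Boolean;
}
impl<N, M> SwapGt<M> for N
where
    N: Num,
    M: Num + PeanoLt<N>,
{
    type Output = <M as PeanoLt<N>>::Output;
}

type Two = Next<Next<Zero>>;
type Three = Next<Two>;

// Hypothetical helper: evaluates A > B as a runtime `bool`.
fn gt<A, B>() -> bool
where
    A: SwapGt<B>,
{
    <<A as SwapGt<B>>::Output as Boolean>::BOOL
}
```

The same swap would give x >= y from x <= y, at the cost of not reusing the Not trait.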

Lists

I think you’ve already guessed we’re basically doing functional programming, we don’t have loops and everything is immutable.

For implementing a list we’re going to use something called a cons list. It is a specific type of linked list often used when immutability is desired. It is built recursively, meaning each cons node contains the next node.

A simple list of numbers such as [2, 4, 6, 8] will look like Cons(2, Cons(4, Cons(6, Cons(8, nil)))) in a cons list. Clearly, we have to define two structs, one is the Cons itself and other is Nil which will denote the end of the list.

struct Cons<H, T>(std::marker::PhantomData<(H, T)>);
struct Nil;

We will also create a List trait to unify them:

trait List {}

impl<H, T> List for Cons<H, T> {}
impl List for Nil {}

And just like that! now we have a list, we can use it like this:

fn main() {
    type List = Cons<_2, Cons<_4, Cons<_6, Cons<_8, Nil>>>>;
}

But this list alone is not particularly useful, we’ve to implement some operations…

List operations

Index

Starting from the most basic list operation, we have indexing. And yes you guessed it, this will also be a recursive definition.

Let us define a function Idx(list, index) as:

$$
\begin{align*}
\text{Idx}(\text{Cons}(\text{Head}, \text{Tail}), 0) &= \text{Head} \\
\text{Idx}(\text{Cons}(\text{Head}, \text{Tail}), S(a)) &= \text{Idx}(\text{Tail}, a)
\end{align*}
$$

This is actually pretty simple. If the index is zero, we just return the first element of Cons, which is Head. Otherwise, we call Idx again on the second element of Cons (the Tail, which is another Cons) with index - 1, and we keep repeating that until the index becomes 0.

Here’s how the rust implementation looks like:

trait GetIndex<Index> {
    type Output;
}

// Idx(Cons(Head, Tail), 0) = Head
impl<H, T> GetIndex<Zero> for Cons<H, T> {
    type Output = H;
}

// Idx(Cons(Head, Tail), S(a)) = Idx(Tail, a)
impl<H, T, Index> GetIndex<Next<Index>> for Cons<H, T>
where
    Index: Num,
    T: GetIndex<Index>,
{
    type Output = <T as GetIndex<Index>>::Output;
}

And we can use it like following:

fn main() {
    type List = Cons<_2, Cons<_4, Cons<_6, Cons<_8, Nil>>>>;

    assert_eq!(<List as GetIndex<_1>>::Output::VALUE, _4::VALUE);
    assert_eq!(<List as GetIndex<_2>>::Output::VALUE, _6::VALUE);
}

If you’ve followed along till this point, I think you should try to implement more list operations like push, pop, etc. by yourself as an exercise.

Append

This operation allows us to insert an element at the end of a list, which is as simple as replacing the last Nil with Cons<Item, Nil>.

Here is the recursive definition, with Nil as the base case:

$$
\begin{align*}
\text{Append}(\text{Nil}, i) &= \text{Cons}(i, \text{Nil}) \\
\text{Append}(\text{Cons}(\text{Head}, \text{Tail}), i) &= \text{Cons}(\text{Head}, \text{Append}(\text{Tail}, i))
\end{align*}
$$

Yeah we just keep appending the item to the tail until we reach Nil which is then replaced by Cons(i, Nil)

rust implementation:

trait Append<Item> {
    type Output;
}

// Append(Nil, i) = Cons(i, Nil)
impl<Item> Append<Item> for Nil {
    type Output = Cons<Item, Nil>;
}

// Append(Cons(Head, Tail), i) = Cons(Head, Append(Tail, i))
impl<H, T, Item> Append<Item> for Cons<H, T>
where
    T: Append<Item>,
{
    type Output = Cons<H, <T as Append<Item>>::Output>;
}
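To see Append in action, here is a short self-contained sketch that appends to a two-element list and then reads the new last element back with GetIndex. The aliases `Numbers` and `Longer` and the function `last_of_longer` are hypothetical names for this example only.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<T: Num>(PhantomData<T>);
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

trait ToVal {
    const VALUE: usize;
}
impl ToVal for Zero {
    const VALUE: usize = 0;
}
impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

struct Cons<H, T>(PhantomData<(H, T)>);
struct Nil;

trait Append<Item> {
    type Output;
}
// Append(Nil, i) = Cons(i, Nil)
impl<Item> Append<Item> for Nil {
    type Output = Cons<Item, Nil>;
}
// Append(Cons(Head, Tail), i) = Cons(Head, Append(Tail, i))
impl<H, T, Item> Append<Item> for Cons<H, T>
where
    T: Append<Item>,
{
    type Output = Cons<H, <T as Append<Item>>::Output>;
}

trait GetIndex<Index> {
    type Output;
}
// Idx(Cons(Head, Tail), 0) = Head
impl<H, T> GetIndex<Zero> for Cons<H, T> {
    type Output = H;
}
// Idx(Cons(Head, Tail), S(a)) = Idx(Tail, a)
impl<H, T, Index> GetIndex<Next<Index>> for Cons<H, T>
where
    Index: Num,
    T: GetIndex<Index>,
{
    type Output = <T as GetIndex<Index>>::Output;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _3 = Next<_2>;

// [1, 2]
type Numbers = Cons<_1, Cons<_2, Nil>>;
// [1, 2, 3]
type Longer = <Numbers as Append<_3>>::Output;

// Reads index 2 of the appended list, which is the element we just added.
fn last_of_longer() -> usize {
    <<Longer as GetIndex<_2>>::Output as ToVal>::VALUE
}
```

Note that `Numbers` itself is untouched: Append produces a brand new type, which is the immutability mentioned earlier.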

Now I think you get the idea of how all this stuff works, so I guess I’ll just drop the implementations of length and pop.

Length

trait Length {
    type Output;
}

// Base case: len(Nil) = 0
impl Length for Nil {
    type Output = Zero;
}

// len(Cons(Head, Tail)) = 1 + len(tail)
impl<H, T> Length for Cons<H, T>
where
    T: Length,
    <T as Length>::Output: Num,
{
    type Output = <<T as Length>::Output as PeanoAdd<Next<Zero>>>::Output;
}

Pop

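// Removes the element at `Index` and gives back the remaining list
pub trait Pop<Index> {
    type Output;
}

// Pop(Cons(Head, Tail), 0) = Tail
impl<H, T> Pop<Zero> for Cons<H, T> {
    type Output = T;
}

// Pop(Cons(Head, Tail), S(a)) = Cons(Head, Pop(Tail, a))
impl<H, T, Index> Pop<Next<Index>> for Cons<H, T>
where
    Index: Num,
    T: Pop<Index>,
{
    type Output = Cons<H, <T as Pop<Index>>::Output>;
}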
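A quick way to convince yourself that both work is to pop an element and compare lengths. This is a self-contained sketch with two liberties: Length wraps the tail’s length in Next directly instead of going through PeanoAdd (same result, fewer traits to copy), and `len_of`, `Numbers` and `Popped` are hypothetical names.

```rust
use std::marker::PhantomData;

trait Num {}
struct Zero;
struct Next<T: Num>(PhantomData<T>);
impl Num for Zero {}
impl<T: Num> Num for Next<T> {}

trait ToVal {
    const VALUE: usize;
}
impl ToVal for Zero {
    const VALUE: usize = 0;
}
impl<N: Num + ToVal> ToVal for Next<N> {
    const VALUE: usize = 1 + N::VALUE;
}

struct Cons<H, T>(PhantomData<(H, T)>);
struct Nil;

trait Length {
    type Output;
}
// len(Nil) = 0
impl Length for Nil {
    type Output = Zero;
}
// len(Cons(Head, Tail)) = 1 + len(Tail), written here as Next<len(Tail)>
impl<H, T> Length for Cons<H, T>
where
    T: Length,
    <T as Length>::Output: Num,
{
    type Output = Next<<T as Length>::Output>;
}

trait Pop<Index> {
    type Output;
}
// Pop(Cons(Head, Tail), 0) = Tail
impl<H, T> Pop<Zero> for Cons<H, T> {
    type Output = T;
}
// Pop(Cons(Head, Tail), S(a)) = Cons(Head, Pop(Tail, a))
impl<H, T, Index> Pop<Next<Index>> for Cons<H, T>
where
    Index: Num,
    T: Pop<Index>,
{
    type Output = Cons<H, <T as Pop<Index>>::Output>;
}

type _1 = Next<Zero>;
type _2 = Next<_1>;
type _4 = Next<Next<_2>>;
type _6 = Next<Next<_4>>;

// [2, 4, 6]
type Numbers = Cons<_2, Cons<_4, Cons<_6, Nil>>>;
// [2, 6]: the element at index 1 is removed
type Popped = <Numbers as Pop<_1>>::Output;

// Hypothetical helper: length of a type-level list as a runtime number.
fn len_of<L>() -> usize
where
    L: Length,
    L::Output: ToVal,
{
    <L::Output as ToVal>::VALUE
}
```

Popping index 1 from a three-element list should leave two elements, and an empty list has length zero.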

TODO: This blog post is still a work in progress. More type system tricks and examples coming soon!