Implementing an SHA transformer by hand

So the first question you (and any reasonable person) may have when seeing this post title is "why?? Why would you ever do this?" Well, I have a pretty good reason—Andis and I came up with a scheme to insert undetectable backdoors into transformers, but it depends on having a "submodule" which can compute a pretty strong hash. SHA-2 is a great candidate for this—not least because its Wikipedia article actually has pseudocode for its implementation—so now all we have to do is make a transformer which, given some input, places that input's hash into the residual stream.

But wait, you ask, isn't RASP a thing? Yes, it is, and I love it. For those of you who are unfamiliar, in 2021, Weiss, Goldberg, and Yahav introduced an amazing programming language called RASP, in which any program you write is theoretically implementable on a transformer. Thanks to Sasha Rush, I've had a lot of fun messing around with RASP code. Heck, I've even already implemented SHA-256 in RASP.

So why doesn't this blog post end here? Well, the original RASP paper only proposed the language—they didn't actually provide any sort of "compiler" which turns RASP code into an actual transformer, which is what I needed. Fortunately, Lindner, Kramár, Farquhar, Rahtz, McGrath, and Mikulik at DeepMind heroically solved this longstanding problem by introducing Tracr, a way of compiling RASP into JAX-Haiku.

So I load my RASP code into Tracr, and—haha, I wish! Tracr has its own, even more restricted, implementation of RASP, which, for instance, doesn't support things like ORing two selectors. Or using anything other than None as the default for a categorical aggregate. Or averaging over non-binary variables. After finally wrangling our code to be Tracr-compatible, Andis and I realized another horrible truth—for sequence lengths greater than about 400, Tracr gets numerically unstable to the point where 1s and 0s become indistinguishable! This makes it hard to implement any hashing algorithm that uses 512-bit chunks (as SHA-2 and most others do).

(To be clear, none of this is a criticism of the Tracr team—they did an incredible job writing a highly technically impressive compiler, and open-sourced it! For small programs, it's incredibly usable and well-made. Though I would be happy to see someone get a grant to write a more complete version. For some cool work with Tracr that I've been excited about recently, check out Iván Arcuschin Moreno's Tracr circuits benchmark.)

However, I was now stuck between a rock (improving the Tracr compiler to fix the numerical instability issue) and a hard place (manually coding each weight).

Welcome to the hard place.

Part 1: Transformers as Virtual Machines

Last July, Ben Newhouse did something incredible—he hand-crafted a transformer which does long-hand addition, and wrote up his approach in an amazing colab notebook. I highly recommend reading through this—it's very well-written, and for the rest of this project, I'll be building on top of his approach. (At the very least, you should look at the "transformers as virtual machines" section).

But, as a very short summary, we can embed each token in a pretty high-dimensional space, where the first n dimensions serve as a one-hot representation of the token, and all the extra dimensions are used to store extra information. Attention layers allow us to move information across tokens (e.g., "shift every token to the right"), while MLPs let us do calculations within a token (e.g., "multiply each token by 2"). Then, we can group sets of the extra dimensions into meaningful "registers". Let's start with an example:

# Various imports

import torch
import torch.nn.functional as F
import math
import torch.nn
import matplotlib.pyplot as plt
import numpy as np

# A constant which determines how fine-grained our positional embedding is
POS_STEP = 1e-3

# Our sequence length
INPUT_LENGTH = 7


# Random helper functions
def plot_tensor(tensor, embedding, title=None):
    fig, ax = plt.subplots(figsize=(12, 6))
    im = ax.imshow(tensor.detach().cpu(), cmap='viridis', aspect='auto')
    
    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel('Value', rotation=-90, va="bottom")
    
    # Create x-axis labels based on register information
    labels = ['input']
    label_positions = [1.5]
    line_positions = [len(tokens)]
    start = len(tokens)
    for register in embedding.registers:
        if register.name != 'input':
            end = start + register.size
            labels.append(register.name)
            label_positions.append((start + end - 1) / 2)
            line_positions.append(end)
            start = end
    
    # Set x-axis labels and tick positions for labels
    ax.set_xticks(label_positions)
    ax.set_xticklabels(labels)
    
    # Remove x-axis tick marks
    ax.tick_params(axis='x', length=0)
    
    # Move x-axis labels to the top
    ax.xaxis.set_label_position('top')
    
    # Add vertical lines between registers
    for pos in line_positions[1:]:
        ax.axvline(pos - 0.5, color='white', linestyle='--', linewidth=1.5)
    
    if title is not None:
        ax.set_title(title)
    
    plt.tight_layout()
    plt.show()

from io import BytesIO

def draw_tensor(tensor, embedding, title=None):
    fig, ax = plt.subplots(figsize=(12, 6))
    im = ax.imshow(tensor.detach().cpu(), cmap='viridis', aspect='auto')
    
    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel('Value', rotation=-90, va="bottom")
    
    # Create x-axis labels based on register information
    labels = ['input']
    label_positions = [1.5]
    line_positions = [len(tokens)]
    start = len(tokens)
    for register in embedding.registers:
        if register.name != 'input':
            end = start + register.size
            labels.append(register.name)
            label_positions.append((start + end - 1) / 2)
            line_positions.append(end)
            start = end
    
    # Set x-axis labels and tick positions for labels
    ax.set_xticks(label_positions)
    ax.set_xticklabels(labels)
    
    # Remove x-axis tick marks
    ax.tick_params(axis='x', length=0)
    
    # Move x-axis labels to the top
    ax.xaxis.set_label_position('top')
    
    # Add vertical lines between registers
    for pos in line_positions[1:]:
        ax.axvline(pos - 0.5, color='white', linestyle='--', linewidth=1.5)
    
    if title is not None:
        ax.set_title(title)
    
    plt.tight_layout()
    
    buffer = BytesIO()
    plt.savefig(buffer, format='png')
    plt.close()

    buffer.seek(0)
    
    return buffer
class Register(object):
    def __init__(self, name, size):
        self.name = name
        self.size = size
        self.offset = None
class EmbeddedState(object):
    def __init__(self, tokens: list[str], registers: list[Register]):
        self.tokens = tokens
        self.token_map = { t: i for i,t in enumerate(tokens) }
        self.registers = registers
        self.register_map = {}
        self.register_size = 0
        
        if len(registers) == 0 or registers[0].name != 'pos':
            raise Exception("First register must be 'pos'") 
        
        offset = len(tokens)
        for reg in registers:
            reg.offset = offset
            offset += reg.size
            self.register_size += reg.size
            self.register_map[reg.name] = reg
            
        self.dim = len(tokens) + self.register_size

    def tokenize(self, string: str):
        return F.one_hot(torch.tensor([self.token_map[c] for c in string]), num_classes=len(self.tokens)).float()
    
    def itokenize(self, string: str):
        return torch.tensor([self.token_map[c] for c in string]).float().unsqueeze(1)

    def embed(self, sequence, additional_constants):
        # We want to create additional space to store the registers
        extension_tensor = torch.zeros(*sequence.shape[:-1], self.register_size)

        # Encode position in the first extra embedding dimension
        for i in range(sequence.shape[0]):
            extension_tensor[i, 0] = math.sin(i*(2*math.pi)*POS_STEP)
            extension_tensor[i, 1] = math.cos(i*(2*math.pi)*POS_STEP)
        
        # Next columns of the extension tensor are the additional constants
        offset = 2
        for constant in additional_constants:
            extension_tensor[:, offset:offset + constant.shape[-1]] = constant
            offset += constant.shape[-1]

        sequence = torch.cat((sequence, extension_tensor), dim=-1)

        return sequence
    
    def predict(self, sequence):
        return self.tokens[torch.argmax(sequence[-1,:len(self.tokens)])]

Above, we've defined a Register and a way of embedding the input. Here is what this looks like in practice, where we load the input cafebeef into our transformer using a one-hot embedding:

tokens = list('abcdef')
pos = Register('pos', 2)

embedding = EmbeddedState(tokens, [pos])
example = embedding.embed(embedding.tokenize('cafebeef'), [])
plot_tensor(example, embedding)

Now, if you look at the plot above, you should be able to see the following: each row represents a token (there are eight rows because cafebeef has eight letters); the first six columns of each row are a one-hot encoding of which letter that token is; and the last two columns encode where within the input the letter sits (is it the first letter, the second letter, and so on).
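
If you'd rather check this numerically than squint at the plot, you can inspect the embedded tensor directly (this just assumes the cell above has been run):

print(example.shape)  # torch.Size([8, 8]): 8 tokens, each with 6 one-hot columns + 2 positional columns
print(example[0])     # first row: the one-hot encoding of 'c', followed by sin/cos of position 0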

Let's copy Ben Newhouse's AttentionLayer module, as well as a specific attention layer called GetRelativeToken which allows us to shift our input:

class AttentionLayer(torch.nn.Module):
    def __init__(self, instruction):
        super(AttentionLayer, self).__init__()
        
        self.key = torch.nn.Parameter(instruction.key)
        self.value = torch.nn.Parameter(instruction.value)
        self.query = torch.nn.Parameter(instruction.query)
        
        self.mask = instruction.mask
        
        self.softmax = torch.nn.Softmax(2)
        
    def forward(self, seq):
        batch_size, seq_length, dim = seq.shape
        
        query = seq @ self.query
        key = seq @ self.key
        value = seq @ self.value

        # Note that the causal mask is multiplied by 0, so attention here is effectively unmasked (bidirectional)
        causal_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)*0).to(seq.device)
        norm = np.sqrt(dim)
        
        kq = self.softmax(query @ key.transpose(-2, -1) / norm + causal_mask)
        
        s = (kq @ value) * self.mask
        
        return (seq + s)

    def reset(self):
        torch.nn.init.xavier_uniform_(self.key)
        torch.nn.init.xavier_uniform_(self.query)
        torch.nn.init.xavier_uniform_(self.value)
class GetRelativeToken(AttentionLayer):
    def __init__(self, embedding: EmbeddedState, pos_reg: Register, steps: int, out: Register):
        tpos_reg = embedding.register_map['pos']
        
        position_select = torch.zeros(embedding.dim, embedding.dim)
        position_select[tpos_reg.offset, tpos_reg.offset] = 1e10
        position_select[tpos_reg.offset + 1, tpos_reg.offset + 1] = 1e10

        i = -steps
        sin = math.sin(i*(2*math.pi)*POS_STEP)*1
        cos = math.cos(i*(2*math.pi)*POS_STEP)*1

        rotation = torch.zeros(embedding.dim, embedding.dim)
        rotation[pos_reg.offset, tpos_reg.offset] = cos
        rotation[pos_reg.offset + 1, tpos_reg.offset] = -sin
        rotation[pos_reg.offset, tpos_reg.offset + 1] = sin
        rotation[pos_reg.offset + 1, tpos_reg.offset + 1] = cos
        
        token_copy = torch.zeros(embedding.dim, embedding.dim)
        for i in range(len(embedding.tokens)):
            token_copy[i, i + out.offset] = 1.0
            
        self.query = rotation
        self.key = position_select
        self.value = token_copy
        
        self.mask = torch.zeros(embedding.dim)
        self.mask[out.offset:out.offset + out.size] = 1.0
                
        super(GetRelativeToken, self).__init__(self)

And let's see what happens when we apply a GetRelativeToken layer to shift the input by 1.

tokens = list('abcdef')
pos = Register('pos', 2)
shifted_input = Register('shifted', len(tokens))

embedding = EmbeddedState(tokens, [pos, shifted_input])

example = embedding.embed(embedding.tokenize('cafebeef'), [])
shift = GetRelativeToken(embedding, pos, 1, shifted_input)


plot_tensor(example, embedding, 'Starting tensor')
x = shift.forward(example.unsqueeze(0))[0]
plot_tensor(x, embedding, 'After shift.forward')

Look at that! We've added a register, called shifted. At the beginning, it starts out empty. Then, we call a GetRelativeToken(embedding, pos, 1, shifted_input) layer in order to copy the input into the shifted_input register, shifted left by 1.

For instance, the first row now consists of [0, 0, 1, 0, 0, 0, <pos>, 1, 0, 0, 0, 0, 0], which means that the input register is (still) storing a one-hot representation of c, while the shifted register is storing a one-hot representation of a! Similarly, in the second row, shifted is storing [0, 0, 0, 0, 0, 1], which is the embedding of f. You should be able to "read off" the shifted register to see what it stores; however, it would be nice to have a "print" function, which allows us to read the de-embedded contents of a register, instead of having to stare at the tensor each time:

def print_(embedding, seq, register, what):
    # Get the register
    register = seq[0, :, register.offset:register.offset + register.size]
    
    # For each row in the register, convert the one-hot encoding to an index
    indices = torch.argmax(register, dim=1)
    # Convert the indices to characters
    chars = [embedding.tokens[i] for i in indices]
    
    # Print the characters
    print(f"{what}: {{{''.join(chars)}}}")
print_(embedding, x.unsqueeze(0), shifted_input, "The shifted register's value")
The shifted register's value: {afebeeff}

Look at that! Indeed, the input cafebeef was shifted left by 1, giving us afebeeff.

So, let's get back to our original goal, which is implementing SHA. We need to have the following operations:

  1. rightshift, where we take a sequence and shift it to the right, with new bits on the left becoming zeros;
  2. rightrotate, where we take a sequence and rotate it to the right (e.g., same as rightshift, but the bits appearing on the left were the previously rightmost bits of the original sequence, instead of zeroes);
  3. and, where we compute the elementwise AND between two sequences;
  4. xor, where we compute the elementwise XOR between two sequences; and
  5. +, where we add two sequences together (as binary numbers modulo 2^32).

Now, in some sense, we already almost have the rightshift function complete. However, as you may have noticed, it doesn't make the new bits zeroes (or whatever the earliest token in our list is)—rather, it makes the new bits the same as whatever the last bit was. This is why shifting cafebeef to the left by 1 gives us afebeeff (note the second "f" at the end).

At this point, it probably makes sense to nail down the set of tokens that we're working with; in our case, that's going to be 0 and 1. Sure, SHA is usually thought of as a map from ASCII text to a hexadecimal output, but it's pretty easy to write a Tracr transformer which converts input text to the binary form necessary to run the SHA algorithm, and back from binary to hex. Thus, we'll only focus on the binary->binary part of the calculation. (In particular, we'll leave the length calculation and padding to length 512 to Tracr.)

So, if we have a sequence like [1, 1, 1, 0, 1, 1] that we rightshift by 1, we want to get the output [0, 1, 1, 1, 0, 1], but our current shift function would give us [1, 1, 1, 1, 0, 1]. There's a pretty simple solution to this—create a constant register that stores the sequence [0, 1, 1, 1, 1, 1], and after each shift operation, compute the AND between the result and the constant register. For simplicity, we'll refer to this constant register as the Tchaikovsky register, and then have the shift function, under the hood, calculate Tchaikovsky AND shift(1).
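
In plain Python, the trick from the previous paragraph looks like this (just a sketch on lists of bits to illustrate the idea, not transformer code):

def naive_shift_right(bits):
    # What the attention layer actually gives us: every position copies the bit
    # one step to its left, so the leftmost bit gets duplicated.
    return [bits[0]] + bits[:-1]

def shift_right(bits, tchaikovsky):
    # AND with the constant [0, 1, 1, ..., 1] register to zero out the new bit.
    return [x & t for x, t in zip(naive_shift_right(bits), tchaikovsky)]

bits = [1, 1, 1, 0, 1, 1]
tchaikovsky = [0, 1, 1, 1, 1, 1]
print(naive_shift_right(bits))         # [1, 1, 1, 1, 0, 1]
print(shift_right(bits, tchaikovsky))  # [0, 1, 1, 1, 0, 1]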

Ok, so it looks like we should start by implementing the bitwise AND operation. First, let's define the MLPLayer class, also copied from Ben Newhouse:

class MLPLayer(torch.nn.Module):
    def __init__(self, instruction, debug=False):
        super(MLPLayer, self).__init__()
        self.debug = debug
        
        self.first_weights = torch.nn.Parameter(instruction.first_weights)
        self.first_bias = torch.nn.Parameter(instruction.first_bias)
        self.second_weights = torch.nn.Parameter(instruction.second_weights)
        self.second_bias = torch.nn.Parameter(instruction.second_bias)
        
        self.gelu = torch.nn.ReLU()  # note: named `gelu`, but it's actually a ReLU
        
        self.mask = instruction.mask
        
    def forward(self, seq):
        if self.debug:
            plot_tensor(seq.squeeze(), 'seq')
        
        a = self.gelu(seq @ self.first_weights + self.first_bias)
        if self.debug:
            plot_tensor(a.squeeze(), 'a')
        b = (a @ self.second_weights)
        if self.debug:
            plot_tensor(b.squeeze(), 'b')
        x = b + self.second_bias
        if self.debug:
            plot_tensor(x.squeeze(), 'x')
        return seq + (x * self.mask)

    def reset(self):
        torch.nn.init.xavier_uniform_(self.first_weights)
        torch.nn.init.zeros_(self.first_bias)
        torch.nn.init.xavier_uniform_(self.second_weights)
        torch.nn.init.zeros_(self.second_bias)

And now, let's implement our elementwise AND function using two MLP layers:

class ANDPart1(MLPLayer):
    def __init__(self, embedding: EmbeddedState, pos_reg: Register, zero_register: Register, first_reg: Register, second_reg: Register, result_reg: Register):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        # The idea here is as follows:
        # The first register has a bunch of rows of the form [X1, Y1, Z1]
        # where the X=1, Y=0, Z=0 case corresponds to that row containing a 0
        # and the X=0, Y=1, Z=0 case corresponds to that row containing a 1
        # this is identically true for the second register, which we'll call, say, [X2, Y2, Z2]
        # The result register, [Xr, Yr, Zr], starts out as all zeros
        # What we do is every time Y1 or Y2 is 1, we add that value to Yr
        # For example:
        # If the first register was 1 and the second register was 1,
        # then Y1 is 1 and Y2 is 1, so Yr is 2
        # If the first register was 1 and the second register was 0,
        # then Y1 is 1 and Y2 is 0, so Yr is 1
        # If the first register was 0 and the second register was 1,
        # then Y1 is 0 and Y2 is 1, so Yr is 1
        # If the first register was 0 and the second register was 0,
        # then Y1 is 0 and Y2 is 0, so Yr is 0
        # We only want the first of the four cases to "survive",
        # so we set the bias to -1
        # This makes it so that Yr is 1 if and only if Y1 and Y2 were both 1
        
        self.first_weights[first_reg.offset + 1, result_reg.offset + 1] += 1
        self.first_weights[second_reg.offset + 1, result_reg.offset + 1] += 1
        
        self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = -1.0
        
        self.second_weights = torch.eye(embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.mask = torch.zeros(embedding.dim)
        for reg in [result_reg]:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(ANDPart1, self).__init__(self)
class ANDPart2(MLPLayer):
    def __init__(self, embedding: EmbeddedState, pos_reg: Register, zero_register: Register, first_reg: Register, second_reg: Register, result_reg: Register):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        # Now, setting the Yr's to 1 where Y1 and Y2 were both 1 isn't enough:
        # We need to *also* set all the other values to zero (e.g. by setting Xr to 1)
        # We do this by setting Xr to 1 everywhere
        # and then subtracting the value of Yr from it, which will zero it out in the cases where Yr=1
        # And thus, where the first register and the second register were both [0, 1, 0],
        # the result is [0, 1, 0]
        # and in all other cases,
        # the result is [1, 0, 0]!
        
        self.first_weights[zero_register.offset + 0, result_reg.offset + 0] += 1
        self.first_weights[result_reg.offset + 1, result_reg.offset + 0] += -1
        
        self.second_weights = torch.eye(embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.mask = torch.zeros(embedding.dim)
        for reg in [result_reg]:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(ANDPart2, self).__init__(self)

Let's see this in action! First, let's define two registers, a and b, so we can demonstrate the ANDing.

tokens = list('012')
pos = Register('pos', 2)
zeros = Register('zeros', len(tokens))
a = Register('a', len(tokens))
b = Register('b', len(tokens))
result = Register('result', len(tokens))

embedding = EmbeddedState(tokens, [pos, zeros, a, b, result])

example = embedding.embed(embedding.tokenize('0110101'), [embedding.tokenize('0000000')])
shift1 = GetRelativeToken(embedding, pos, 1, a)
shift2 = GetRelativeToken(embedding, pos, 2, b)

x = shift1.forward(example.unsqueeze(0))[0]
x = shift2.forward(x.unsqueeze(0))[0]
plot_tensor(x, embedding, 'Registers')

print_(embedding, x.unsqueeze(0), a, 'a')
print_(embedding, x.unsqueeze(0), b, 'b')
a: {1101011}
b: {1010111}

So, we expect the answer to be 1000011.

and1 = ANDPart1(embedding, pos, zeros, a, b, result)

x = and1.forward(x.unsqueeze(0))[0]
plot_tensor(x, embedding, 'AND, part 1')

and2 = ANDPart2(embedding, pos, zeros, a, b, result)

x = and2.forward(x.unsqueeze(0))[0]
plot_tensor(x, embedding, 'AND, part 2')

print_(embedding, x.unsqueeze(0), result, 'result')
result: {1000011}

and 1000011 it is! Cool, this was fun. However, as you may have realized while reading the above implementation, having the internal registers also use one-hot encoding, while nice and consistent, is pretty clunky—it would make more sense if we just used width-1 registers containing zeros and ones directly! To do this, let's implement a layer which will convert to our internal representation:

class ConvertToInternal(MLPLayer):
    def __init__(self, embedding: EmbeddedState, out: Register):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        self.first_weights[1, out.offset] += 1
        
        self.second_weights = torch.eye(embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.mask = torch.zeros(embedding.dim)
        for reg in [out]:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(ConvertToInternal, self).__init__(self)

Then, we need to slightly change our GetRelativeToken logic to be able to copy to and from "internal" registers:

class GRLT2(AttentionLayer):
    """
    Copy the tokens from the given register to the output register, with an optional rotation by `steps`
    """
    
    def __init__(self, embedding: EmbeddedState, pos_reg: Register, steps: int, copy_from: Register, out: Register):
        tpos_reg = embedding.register_map['pos']
        
        position_select = torch.zeros(embedding.dim, embedding.dim)
        position_select[tpos_reg.offset, tpos_reg.offset] = 1e10
        position_select[tpos_reg.offset + 1, tpos_reg.offset + 1] = 1e10
        
        i = -steps
        sin = math.sin(i*(2*math.pi)*POS_STEP)*1
        cos = math.cos(i*(2*math.pi)*POS_STEP)*1

        rotation = torch.zeros(embedding.dim, embedding.dim)
        rotation[pos_reg.offset, tpos_reg.offset] = cos
        rotation[pos_reg.offset + 1, tpos_reg.offset] = -sin
        rotation[pos_reg.offset, tpos_reg.offset + 1] = sin
        rotation[pos_reg.offset + 1, tpos_reg.offset + 1] = cos
        
        token_copy = torch.zeros(embedding.dim, embedding.dim)
        token_copy[copy_from.offset, out.offset] = 1.0
            
        self.query = rotation
        self.key = position_select
        self.value = token_copy
        
        self.mask = torch.zeros(embedding.dim)
        self.mask[out.offset:out.offset + out.size] = 1.0
                
        super(GRLT2, self).__init__(self)

and finally, we can implement an updated version of AND:

class AND(MLPLayer):
    def __init__(self, embedding: EmbeddedState, first_reg: Register, second_reg: Register, result_reg: Register):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        self.first_weights[first_reg.offset, result_reg.offset] += 1
        self.first_weights[second_reg.offset, result_reg.offset] += 1
        self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = -1.0
        
        self.second_weights = torch.eye(embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.mask = torch.zeros(embedding.dim)
        for reg in [result_reg]:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(AND, self).__init__(self)
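
Before running it, here's a quick sanity check of the arithmetic this layer relies on: for bits a, b in {0, 1}, ReLU(a + b - 1) is exactly a AND b (a check in plain PyTorch, not part of the model):

# Verify ReLU(a + b - 1) == a AND b on all four bit combinations
for a_bit in (0.0, 1.0):
    for b_bit in (0.0, 1.0):
        assert torch.relu(torch.tensor(a_bit + b_bit - 1.0)).item() == (int(a_bit) and int(b_bit))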

Here's the AND layer in action:

tokens = list('012')
pos = Register('pos', 2)
internal = Register('internal', 1)
a = Register('a', 1)
b = Register('b', 1)
result = Register('result', 1)

embedding = EmbeddedState(tokens, [pos, internal, a, b, result])

example = embedding.embed(embedding.tokenize('0110101'), [])
c2i = ConvertToInternal(embedding, internal)
shift1 = GRLT2(embedding, pos, 1, internal, a)
shift2 = GRLT2(embedding, pos, 2, internal, b)
anded = AND(embedding, a, b, result)

x = c2i.forward(example.unsqueeze(0))[0]
x = shift1.forward(x.unsqueeze(0))[0]
x = shift2.forward(x.unsqueeze(0))[0]
x = anded.forward(x.unsqueeze(0))[0]
plot_tensor(x, embedding, 'Registers')

Ahhh, much better. To clarify what you're seeing above, note how the input register stores a one-hot representation of the string 0110101 (the first row contains [1, 0, 0, ....] = 0, the second row contains [0, 1, 0, ....] = 1, etc.) Well now, the internal register just directly stores 0110101. Then, a is 1101011, b is 1010111, and result is 1000011, as expected.
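
One small note: the print_ helper above assumes a one-hot register, so it won't work on these width-1 registers. Here's a quick, minimal variant for reading them off directly (nothing transformer-y about it, it just reads the tensor):

def print_bits(seq, register, what):
    # Read a width-1 register directly: each row holds a value that is (approximately) 0 or 1
    bits = seq[0, :, register.offset].round().int().tolist()
    print(f"{what}: {{{''.join(str(b) for b in bits)}}}")

print_bits(x.unsqueeze(0), result, 'result')  # result: {1000011}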

Now we can finally implement Shift!

It would be nice to have Shift be an in-place operation, which we can just call on a register to shift it left/right by some number of steps. In order to do this, we'll introduce the concept of "work registers", which layers like Shift will require—they'll allow us to e.g. calculate intermediate shifted tchaikovsky registers, and then save the result of the AND back into the original register. To this end, we'll also introduce the Clear and Copy layers, which are pretty self-explanatory.

class Clear(MLPLayer):
    def __init__(self, embedding: EmbeddedState, registers: list[Register]):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        # Scale the registers to be cleared by 100, so the ReLU passes their
        # (non-negative) values straight through...
        for reg in registers:
            for i in range(reg.size):
                self.first_weights[reg.offset + i, reg.offset + i] = 100.0
        
        # ...then scale by -0.01, so the residual connection adds -value to each
        # position, zeroing the register out.
        self.second_weights = torch.zeros(embedding.dim, embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        for reg in registers:
            for i in range(reg.size):
                self.second_weights[reg.offset + i, reg.offset + i] = -0.01
    
        self.mask = torch.zeros(embedding.dim)
        for reg in registers:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(Clear, self).__init__(self)
class Copy(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos_reg: Register, copy_from: Register, copy_to: Register):
        super(Copy, self).__init__()
        
        self.copy = GRLT2(embedding, pos_reg, 0, copy_from, copy_to)
    
    def forward(self, seq):
        return self.copy.forward(seq)
    
class Shift(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, register_to_shift: Register, amount: int, work_registers: list[Register]):
        super(Shift, self).__init__()
        
        self.embedding = embedding
        
        self.shiftpt1 = GRLT2(embedding, pos, -amount, register_to_shift, work_registers[0])
        self.clear = Clear(embedding, [register_to_shift])
        self.shifted_tchaikovsky = GRLT2(embedding, pos, -(amount - 1), tchaikovsky, work_registers[1])
        self.shiftpt2 = AND(embedding, work_registers[1], work_registers[0], register_to_shift)
        self.cleannup = Clear(embedding, work_registers)
        
    def forward(self, seq, save_intermediate_steps=False):
        intermediate_steps = []
        x = self.shiftpt1.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Shift, part 1/5'))
        x = self.clear.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Shift, part 2/5'))
        x = self.shifted_tchaikovsky.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Shift, part 3/5'))
        x = self.shiftpt2.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Shift, part 4/5'))
        x = self.cleannup.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Shift, part 5/5'))
        if save_intermediate_steps:
            return x, True, intermediate_steps
        return x
    
class ShiftL(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, anti_tchaikovsky: Register, register_to_shift: Register, amount: int, work_registers: list[Register]):
        super(ShiftL, self).__init__()
        
        self.embedding = embedding
        
        self.shiftpt1 = GRLT2(embedding, pos, amount, register_to_shift, work_registers[0])
        self.clear = Clear(embedding, [register_to_shift])
        self.shifted_antitchaikovsky = GRLT2(embedding, pos, amount - 1, anti_tchaikovsky, work_registers[1])
        self.shiftpt2 = AND(embedding, work_registers[1], work_registers[0], register_to_shift)
        self.cleannup = Clear(embedding, work_registers)
        
    def forward(self, seq, save_intermediate_steps=False):
        intermediate_steps = []
        x = self.shiftpt1.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'ShiftL, part 1/5'))
        x = self.clear.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'ShiftL, part 2/5'))
        x = self.shifted_antitchaikovsky.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'ShiftL, part 3/5'))
        x = self.shiftpt2.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'ShiftL, part 4/5'))
        x = self.cleannup.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'ShiftL, part 5/5'))
        if save_intermediate_steps:
            return x, True, intermediate_steps
        return x
    

Now let's see it in action:

import imageio.v2 as imageio
from IPython.display import Image

tokens = list('012')
pos = Register('pos', 2)
tchaikovsky = Register('tchaik', 1)
anti_tchaikovsky = Register('antit', 1)
internal = Register('internal', 1)
a = Register('a', 1)
b = Register('b', 1)
work_registers = []
for i in range(4):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, tchaikovsky, anti_tchaikovsky, internal, a, b] + work_registers)

example = embedding.embed(embedding.tokenize('1110101'), [embedding.itokenize('0111111'), embedding.itokenize('1111110')])
c2i = ConvertToInternal(embedding, internal)
copy1 = Copy(embedding, pos, internal, a)
copy2 = Copy(embedding, pos, internal, b)
shift1 = Shift(embedding, pos, tchaikovsky, a, 1, work_registers)
shift2 = ShiftL(embedding, pos, anti_tchaikovsky, b, 1, work_registers)

tensors = []
x = c2i.forward(example.unsqueeze(0))[0]
tensors.append(draw_tensor(x, embedding, 'Starting state'))
x = copy1.forward(x.unsqueeze(0))[0]
tensors.append(draw_tensor(x, embedding, 'After first copy'))
x = copy2.forward(x.unsqueeze(0))[0]
tensors.append(draw_tensor(x, embedding, 'After second copy'))
x, _, intermediates = shift1.forward(x.unsqueeze(0), save_intermediate_steps=True)
x = x[0]
tensors.extend(intermediates)
x, _, intermediates = shift2.forward(x.unsqueeze(0), save_intermediate_steps=True)
x = x[0]
tensors.extend(intermediates)
tensors.append(draw_tensor(x, embedding, 'Final state'))

images = []

for idx, tensor in enumerate(tensors):
    image = imageio.imread(tensor)
    images.append(image)

# Create a GIF from the in-memory images
imageio.mimsave("animation.gif", images, duration=1000)
display(Image(filename="animation.gif"))

Wow, look at that! We copied the internal register (1110101) to registers a and b, shifted A to the right by 1 (getting 0111010), and shifted B to the left by 1 (getting 1101010), with both shifts happening in-place and being zero-padded! I think we can mark the right-shift operation complete.

The next unfinished thing on our list is rightrotate. This is pretty easy to do: to rotate a string of length 7 to the right by k steps, we rightshift it by k, and then OR it with the string shifted to the left by 7-k steps. But if we're going to implement OR for this anyway, let's just implement a bunch of useful binary functions, and test that they work on arbitrary registers, which we'll denote A and B:
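
Here's that identity in plain Python, on length-7 bit lists (again just a reference sketch, not transformer code):

def shiftr(bits, k):
    # zero-padded right shift
    return [0] * k + bits[:len(bits) - k]

def shiftl(bits, k):
    # zero-padded left shift
    return bits[k:] + [0] * k

def rotr(bits, k):
    # rightrotate(k) == rightshift(k) OR leftshift(n - k)
    return [x | y for x, y in zip(shiftr(bits, k), shiftl(bits, len(bits) - k))]

print(rotr([0, 1, 1, 0, 0, 0, 1], 2))  # [0, 1, 0, 1, 1, 0, 0]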

class NOT_To(MLPLayer):
    def __init__(self, embedding: EmbeddedState, from_reg: Register, result_reg: Register):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        self.first_weights[from_reg.offset, result_reg.offset] = -1
        self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = 1.0
        
        self.second_weights = torch.eye(embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.mask = torch.zeros(embedding.dim)
        for reg in [result_reg]:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(NOT_To, self).__init__(self)

class NOT(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, register: Register, work_registers: list[Register]):
        super(NOT, self).__init__()
        
        self.not_to = NOT_To(embedding, register, work_registers[0])
        self.clear = Clear(embedding, [register])
        self.copy = Copy(embedding, pos, work_registers[0], register)
        self.clear2 = Clear(embedding, work_registers)
    
    def forward(self, seq):
        x = self.not_to.forward(seq)
        x = self.clear.forward(x)
        x = self.copy.forward(x)
        x = self.clear2.forward(x)
        return x
tokens = list('012')
pos = Register('pos', 2)
a = Register('A', 1)
work_registers = []
for i in range(4):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, a] + work_registers)

example = embedding.embed(embedding.tokenize('0110001'), [embedding.itokenize('0110001')])
notted = NOT(embedding, pos, a, work_registers)

plot_tensor(example, embedding, 'Starting state')
x = notted.forward(example.unsqueeze(0))[0]
plot_tensor(x, embedding, '¬A')

class NOR(MLPLayer):
    def __init__(self, embedding: EmbeddedState, first_reg: Register, second_reg: Register, result_reg: Register):
        self.first_weights = torch.zeros(embedding.dim, embedding.dim)
        self.first_bias = torch.zeros(embedding.dim)
        
        self.first_weights[first_reg.offset, result_reg.offset] += -1
        self.first_weights[second_reg.offset, result_reg.offset] += -1
        self.first_bias[result_reg.offset:result_reg.offset + result_reg.size] = 1.0
        
        self.second_weights = torch.eye(embedding.dim)
        self.second_bias = torch.zeros(embedding.dim)
        
        self.mask = torch.zeros(embedding.dim)
        for reg in [result_reg]:
            self.mask[reg.offset:reg.offset + reg.size] = 1.0
        
        super(NOR, self).__init__(self)
tokens = list('012')
pos = Register('pos', 2)
a = Register('A', 1)
b = Register('B', 1)
result = Register('result', 1)
work_registers = []
for i in range(0):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, a, b, result] + work_registers)

example = embedding.embed(embedding.tokenize('0110001'), [embedding.itokenize('0110001'), embedding.itokenize('1010101')])
norred = NOR(embedding, a, b, result)

plot_tensor(example, embedding, 'Starting state')
x = norred.forward(example.unsqueeze(0))[0]
plot_tensor(x, embedding, 'NOR(A, B)')

class OR(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, first_reg: Register, second_reg: Register, result_reg: Register, work_registers: list[Register]):
        super(OR, self).__init__()
        
        self.part1 = NOR(embedding, first_reg, second_reg, result_reg)
        self.part2 = NOT(embedding, pos, result_reg, work_registers)
        self.cleanup = Clear(embedding, work_registers)
        
    def forward(self, seq):
        x = self.part1.forward(seq)
        x = self.part2.forward(x)
        x = self.cleanup.forward(x)
        return x
tokens = list('012')
pos = Register('pos', 2)
a = Register('A', 1)
b = Register('B', 1)
result = Register('result', 1)
work_registers = []
for i in range(1):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, a, b, result] + work_registers)

example = embedding.embed(embedding.tokenize('0110001'), [embedding.itokenize('0110001'), embedding.itokenize('1010101')])
orred = OR(embedding, pos, a, b, result, work_registers)

plot_tensor(example, embedding, 'Starting state')
x = orred.forward(example.unsqueeze(0))[0]
plot_tensor(x, embedding, 'OR(A, B)')

class XOR(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, first_reg: Register, second_reg: Register, result_reg: Register, work_registers: list[Register]):
        super(XOR, self).__init__()
        
        # A OR B
        self.part1 = OR(embedding, pos, first_reg, second_reg, work_registers[0], work_registers[1:])
        
        # A AND B
        self.part2 = AND(embedding, first_reg, second_reg, work_registers[1])
        
        # NOT (A AND B)
        self.part3 = NOT(embedding, pos, work_registers[1], work_registers[2:])
        
        # (A OR B) AND NOT (A AND B)
        self.part4 = AND(embedding, work_registers[0], work_registers[1], result_reg)
        
        # Clear the work registers
        self.part5 = Clear(embedding, work_registers)
        
    def forward(self, seq):
        x = self.part1.forward(seq)
        x = self.part2.forward(x)
        x = self.part3.forward(x)
        x = self.part4.forward(x)
        x = self.part5.forward(x)
        return x
tokens = list('012')
pos = Register('pos', 2)
a = Register('A', 1)
b = Register('B', 1)
result = Register('result', 1)
work_registers = []
for i in range(3):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, a, b, result] + work_registers)

example = embedding.embed(embedding.tokenize('0110001'), [embedding.itokenize('0110001'), embedding.itokenize('1010101')])
xorred = XOR(embedding, pos, a, b, result, work_registers)

plot_tensor(example, embedding, 'Starting state')
x = xorred.forward(example.unsqueeze(0))[0]
plot_tensor(x, embedding, 'XOR(A, B)')
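
And the usual sanity check that (A OR B) AND NOT(A AND B) really is XOR (plain Python, all four bit combinations):

for a_bit in (0, 1):
    for b_bit in (0, 1):
        assert ((a_bit | b_bit) & (1 - (a_bit & b_bit))) == (a_bit ^ b_bit)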

class Rotate(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, tchaikovsky: Register, anti_tchaikovsky: Register, register_to_rotate: Register, amount: int, work_registers: list[Register]):
        super(Rotate, self).__init__()
        
        self.embedding = embedding
        
        # First, we need to copy the register to two work registers
        # Thus, work registers 0 and 1 are currently in use
        self.copies = [Copy(embedding, pos, register_to_rotate, work_registers[i]) for i in range(2)]
        
        # Next, we shift the first work register to the right
        self.shift_right = Shift(embedding, pos, tchaikovsky, work_registers[0], amount, work_registers[2:])
        
        # Then, we shift the second work register to the left by (INPUT_LENGTH - amount)
        self.left_shifts = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[1], (INPUT_LENGTH - amount), work_registers[2:])
        
        # Now, we clear the original register
        self.clear = Clear(embedding, [register_to_rotate])
        
        # And finally, we OR work registers 0 and 1 to get the final result
        self.or_result = OR(embedding, pos, work_registers[0], work_registers[1], register_to_rotate, work_registers[2:])
        
        # Oh, and clear the work registers
        self.clear_work = Clear(embedding, work_registers)
        
    def forward(self, seq, save_intermediate_steps=False):
        intermediate_steps = []
        for copy in self.copies:
            seq = copy.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(seq.squeeze(), self.embedding, 'Rotate, part 1/6'))
        seq = self.shift_right.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(seq.squeeze(), self.embedding, 'Rotate, part 2/6'))
        seq = self.left_shifts.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(seq.squeeze(), self.embedding, 'Rotate, part 3/6'))
        seq = self.clear.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(seq.squeeze(), self.embedding, 'Rotate, part 4/6'))
        seq = self.or_result.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(seq.squeeze(), self.embedding, 'Rotate, part 5/6'))
        seq = self.clear_work.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(seq.squeeze(), self.embedding, 'Rotate, part 6/6'))
        if save_intermediate_steps:
            return seq, True, intermediate_steps
        return seq
    

Here is our rightrotate working successfully:

tokens = list('012')
pos = Register('pos', 2)
tchaikovsky = Register('tchaik', 1)
anti_tchaikovsky = Register('antit', 1)
a = Register('A', 1)
work_registers = []
for i in range(4):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, tchaikovsky, anti_tchaikovsky, a] + work_registers)

example = embedding.embed(embedding.tokenize('0110001'), [embedding.itokenize('0111111'), embedding.itokenize('1111110'), embedding.itokenize('0110001')])
rotated = Rotate(embedding, pos, tchaikovsky, anti_tchaikovsky, a, 2, work_registers)

tensors = []
plot_tensor(example, embedding, 'Starting state')
x, _, intermediates = rotated.forward(example.unsqueeze(0), True)
plot_tensor(x[0], embedding, 'rightrotate(A, 2)')
tensors.extend(intermediates)
tensors.append(draw_tensor(x[0], embedding, 'Final state'))

And here's an animation of the process:

images = []

for idx, tensor in enumerate(tensors):
    image = imageio.imread(tensor)
    images.append(image)

imageio.mimsave("rotate.gif", images, duration=1000)

display(Image(filename="rotate.gif"))

The last thing we have to do is addition (modulo 2^32). Since we can't have "if" statements in transformer logic, we compute the bitwise sum of the two numbers (XOR), compute the carry bits (AND), shift the carry bits, and fold them back in...thirty-two times. For most of those iterations, the carry_bits register is already all zeros, but we have no way to branch on that, so we just add in the zeros a constant number of times. Here it is, in its full glory:
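
As a plain-Python reference for the layer below (not transformer code, with Python ints standing in for the bit registers), the loop looks like this:

def add_mod_2_32(a, b):
    # XOR gives the bitwise sum, AND gives the carries; shift the carries left
    # and fold them back in, a fixed 32 times, since we can't branch on "carry == 0".
    mask = (1 << 32) - 1
    s, carry = (a ^ b) & mask, (a & b) & mask
    for _ in range(32):
        shifted_carry = (carry << 1) & mask
        s, carry = s ^ shifted_carry, s & shifted_carry
    return s

assert add_mod_2_32(0b1010, 0b0110) == 0b10000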

class Add(torch.nn.Module):
    def __init__(self, embedding: EmbeddedState, pos: Register, anti_tchaikovsky: Register, a: Register, b: Register, result: Register, work_registers: list[Register]):
        super(Add, self).__init__()
        
        self.embedding = embedding
        
        # work_registers[0] is `sum_`
        # work_registers[1] is `carry`
        self.first_sum = XOR(embedding, pos, a, b, work_registers[0], work_registers[2:])
        self.first_carry = AND(embedding, a, b, work_registers[1])
        
        self.next_operations = []
        
        for _ in range(32):
            # Copy `carry` to work_registers[2]
            copy_of_carry = Copy(embedding, pos, work_registers[1], work_registers[2])
            # Shift this copy of `carry` to the left. Now `work_registers[2]` contains `shifted_carry`
            shifted_carry = ShiftL(embedding, pos, anti_tchaikovsky, work_registers[2], 1, work_registers[3:])
            # XOR `sum_` with `shifted_carry`. Now `work_registers[3]` contains `new_sum`
            new_sum = XOR(embedding, pos, work_registers[0], work_registers[2], work_registers[3], work_registers[4:])
            # Clear `carry`
            clear_carry = Clear(embedding, [work_registers[1]])
            # AND `sum_` with `shifted_carry`. Now `work_registers[1]` contains `carry` again
            carry = AND(embedding, work_registers[0], work_registers[2], work_registers[1])
            # Clear `sum`
            clear_sum = Clear(embedding, [work_registers[0]])
            # Copy `new_sum` to `sum_`
            sum = Copy(embedding, pos, work_registers[3], work_registers[0])
            # Clear the work registers
            clear_work = Clear(embedding, work_registers[2:])
            self.next_operations.append((copy_of_carry, shifted_carry, new_sum, clear_carry, carry, clear_sum, sum, clear_work))
            
        self.copy_to_result = Copy(embedding, pos, work_registers[0], result)
        
        self.clear = Clear(embedding, work_registers)
    
    def forward(self, seq, save_intermediate_steps=False):
        intermediate_steps = []
        
        c = 0
        
        x = self.first_sum.forward(seq)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Add, part 1'))
        x = self.first_carry.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, 'Add, part 2'))
        
        for copy_of_carry, shifted_carry, new_sum, clear_carry, carry, clear_sum, sum, clear_work in self.next_operations:
            x = copy_of_carry.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = shifted_carry.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = new_sum.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = clear_carry.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = carry.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = clear_sum.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = sum.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
            x = clear_work.forward(x)
            if save_intermediate_steps:
                intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
                c += 1
        
        x = self.copy_to_result.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
            c += 1
        x = self.clear.forward(x)
        if save_intermediate_steps:
            intermediate_steps.append(draw_tensor(x.squeeze(), self.embedding, f'Add, part {c}'))
            c += 1
        
        if save_intermediate_steps:
            return x, True, intermediate_steps
        return x

And here it is, adding two 32-bit numbers together:

INPUT_LENGTH = 32
tokens = list('012')
pos = Register('pos', 2)
tchaikovsky = Register('tchaik', 1)
anti_tchaikovsky = Register('antit', 1)
a = Register('a', 1)
internal = Register('internal', 1)
result = Register('result', 1)
work_registers = []
for i in range(7):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, tchaikovsky, anti_tchaikovsky, a, internal, result] + work_registers)

binary1 = '10101010010101011101001010010011'
binary2 = '11001100001100110000011111001010'

example = embedding.embed(embedding.tokenize(binary1), [embedding.itokenize('01111111111111111111111111111111'), embedding.itokenize('11111111111111111111111111111110'), embedding.itokenize(binary2)])
c2i = ConvertToInternal(embedding, internal)
add = Add(embedding, pos, anti_tchaikovsky, a, internal, result, work_registers)

c2ix = c2i.forward(example.unsqueeze(0))[0]
x = add.forward(c2ix.unsqueeze(0), False)
x = x[0]

python_binary_sum = bin(int(binary1, 2) + int(binary2, 2))[2:].zfill(32)[-32:]
transformer_binary_sum = ''.join(str(c) for c in [int(q) for q in x[:, result.offset:result.offset + result.size].detach().flatten()])

print(python_binary_sum)
print(transformer_binary_sum)
01110110100010001101101001011101
01110110100010001101101001011101

Wahoo!! It works!!!

For those of you who like pretty pictures (me), here's the ANIMATED addition process struggling to add these two numbers!

tokens = list('012')
pos = Register('pos', 2)
tchaikovsky = Register('tchaik', 1)
anti_tchaikovsky = Register('antit', 1)
a = Register('a', 1)
internal = Register('internal', 1)
result = Register('result', 1)
work_registers = []
for i in range(7):
    work_registers.append(Register(f'work{i}', 1))

embedding = EmbeddedState(tokens, [pos, tchaikovsky, anti_tchaikovsky, a, internal, result] + work_registers)

# This lets us see all 32 carries actually do something
binary1 = '10001000100010101110111110100011'
binary2 = '01110111011101010001000001011101'

example = embedding.embed(embedding.tokenize(binary1), [embedding.itokenize('01111111111111111111111111111111'), embedding.itokenize('11111111111111111111111111111110'), embedding.itokenize(binary2)])
c2i = ConvertToInternal(embedding, internal)
add = Add(embedding, pos, anti_tchaikovsky, a, internal, result, work_registers)

c2ix = c2i.forward(example.unsqueeze(0))[0]

x, _, intermediates = add.forward(c2ix.unsqueeze(0), True)

images = []

for idx, tensor in enumerate(intermediates):
    image = imageio.imread(tensor)
    images.append(image)

imageio.mimsave("add.gif", images, duration=200)
display(Image(filename="add.gif"))

Now that we have rightshift, rightrotate, and, xor, and + all implemented, we're ready to go to...

Part 2: Implementing SHA

Now, the problem with all of our code above is that it's very hard to use. For instance, different functions take different numbers of parameters—some require the zero register, others don't; some require work registers, others don't—and besides, it's really hard to keep track of all the registers we need to implement something as complex as SHA-256. Thus, we implement a little "template language" that lets us "program" transformers, and which supports functions like copy, rotate_, xor, print_, and everything else we'll need. One simple program might look something like this:

% Required declarations
TOKENS = /0\1\2/
NUM_WORK_REGISTERS = 10
INPUT = 1101011101011010

% Actual program

PROGRAM_START
w0 = copy_input()
print_(w0)
w1 = copy(w0)
print_(w1)
w2 = copy(w0)
rotate_with_limit_(w2, 1, 4)
print_(w2)
w3 = copy(w0)
rotate_with_limit_(w3, 2, 4)
print_(w3)
w4 = copy(w0)
rotate_with_limit_(w4, 3, 4)
print_(w4)

s0_a = copy(w1)
rotate_(s0_a, 7)
print_(s0_a)
s0_b = copy(w1)
rotate_(s0_b, 18)
print_(s0_b)
s0_c = copy(w1)
shift_(s0_c, 3)
print_(s0_c)
s0_d = xor(s0_a, s0_b)
print_(s0_d)
s0 = xor(s0_c, s0_d)
print_(s0)

PROGRAM_END

Now, our "compiler" takes this program, automatically calculates the total number of registers we need, and converts it to a Python implementation of the transformer layers. The most important part is the function templates:

func_templates = {
    "copy": "Copy(embedding, pos, <a>, <b>)",
    "copy_input": "ConvertToInternal(embedding, <a>)",
    "keep_": "Keep(<a>, <b>)",
    "rotate_": "Rotate(embedding, pos, tchaikovsky, anti_tchaikovsky, <a>, <b>, work_registers)",
    "rotate_with_limit_": "RotateWithLimit(embedding, pos, tchaikovsky, anti_tchaikovsky, <a>, <b>, <c>, work_registers)",
    "shiftr_": "Shift(embedding, pos, tchaikovsky, <a>, <b>, work_registers)",
    "shiftl_": "ShiftL(embedding, pos, anti_tchaikovsky, <a>, <b>, work_registers)",
    "xor": "XOR(embedding, pos, <a>, <b>, <c>, work_registers)",
    "and": "AND(embedding, <a>, <b>, <c>)",
    "not_": "NOT(embedding, pos, <a>, work_registers)",
    "print_": "Print(embedding, <a>)",
    "add": "Add(embedding, pos, anti_tchaikovsky, <a>, <b>, <c>, work_registers)"
}

So, the compiler is essentially a glorified "find-and-replace", with no support for loops, dynamic variable allocations, or anything else you might expect (give me a break, I did math in undergrad). But it works!!
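
To give a flavor of what that find-and-replace amounts to, here's a minimal reconstruction of the substitution step (my sketch of the idea, not the actual compiler, which also handles the declarations, register bookkeeping, and print output): each program line gets matched against the templates above, the call arguments fill <a>, <b>, ... in order, and the assignment target (if there is one) fills the last placeholder.

import re

def compile_line(line, op_index):
    # Reuses the func_templates dict from the previous cell.
    m = re.match(r"(?:(\w+)\s*=\s*)?(\w+)\((.*)\)", line.strip())
    dest, func, raw_args = m.group(1), m.group(2), m.group(3)
    slots = [a.strip() for a in raw_args.split(",") if a.strip()]
    if dest is not None:
        slots.append(dest)  # the assignment target fills the final placeholder
    out = func_templates[func]
    for placeholder, value in zip(["<a>", "<b>", "<c>"], slots):
        out = out.replace(placeholder, value)
    return f"op_{op_index} = {out}"

print(compile_line("s0_d = xor(s0_a, s0_b)", 22))
# op_22 = XOR(embedding, pos, s0_a, s0_b, s0_d, work_registers)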

The example program above is indicative of the type of operations we'll be using to implement SHA-256. In particular, we'll make heavy use of the rotate_with_limit_ operation: in general, we will be operating on length-512 inputs, but the variables we actually care about are all 32 bits long. (To be clear, in a transformer the input length is a hard constraint on how many "tokens" you have—both an upper and a lower limit—so you're stuck with EVERY variable you have being 512 bits long.) Instead of doing this the smart way (a phrase you will hear a lot in this section), we will simply use only the first 32 of the 512 bits of each register to store relevant information, and the rest of the bits will be "junk". In any case, the above program compiles into the following Python file:

tokens = ['0', '1', '2']
pos = Register('pos', 2)
tchaikovsky = Register('tchaikovsky', 1)
anti_tchaikovsky = Register('anti_tchaikovsky', 1)
zeros = Register('zeros', 1)
ones = Register('ones', 1)
w0 = Register('w0', 1)
w1 = Register('w1', 1)
w2 = Register('w2', 1)
w3 = Register('w3', 1)
w4 = Register('w4', 1)
s0_a = Register('s0_a', 1)
s0_b = Register('s0_b', 1)
s0_c = Register('s0_c', 1)
s0_d = Register('s0_d', 1)
s0 = Register('s0', 1)

work_registers = []
for i in range(10):
    work_registers.append(Register(f'work_{i}', len(tokens)))
embedding = EmbeddedState(tokens, [pos, tchaikovsky, anti_tchaikovsky, zeros, ones, w0, w1, w2, w3, w4, s0_a, s0_b, s0_c, s0_d, s0] + work_registers)
first_input = embedding.embed(embedding.tokenize('1101011101011010'), [embedding.itokenize('0111111111111111'), embedding.itokenize('1111111111111110'), embedding.itokenize('0000000000000000'), embedding.itokenize('1111111111111111')])
op_0 = ConvertToInternal(embedding, w0)
op_1 = Print(embedding, w0)
op_2 = Copy(embedding, pos, w0, w1)
op_3 = Print(embedding, w1)
op_4 = Copy(embedding, pos, w0, w2)
op_5 = RotateWithLimit(embedding, pos, tchaikovsky, anti_tchaikovsky, w2, 1, 4, work_registers)
op_6 = Print(embedding, w2)
op_7 = Copy(embedding, pos, w0, w3)
op_8 = RotateWithLimit(embedding, pos, tchaikovsky, anti_tchaikovsky, w3, 2, 4, work_registers)
op_9 = Print(embedding, w3)
op_10 = Copy(embedding, pos, w0, w4)
op_11 = RotateWithLimit(embedding, pos, tchaikovsky, anti_tchaikovsky, w4, 3, 4, work_registers)
op_12 = Print(embedding, w4)
op_13 = Copy(embedding, pos, w1, s0_a)
op_14 = Rotate(embedding, pos, tchaikovsky, anti_tchaikovsky, s0_a, 7, work_registers)
op_15 = Print(embedding, s0_a)
op_16 = Copy(embedding, pos, w1, s0_b)
op_17 = Rotate(embedding, pos, tchaikovsky, anti_tchaikovsky, s0_b, 18, work_registers)
op_18 = Print(embedding, s0_b)
op_19 = Copy(embedding, pos, w1, s0_c)
op_20 = Shift(embedding, pos, tchaikovsky, s0_c, 3, work_registers)
op_21 = Print(embedding, s0_c)
op_22 = XOR(embedding, pos, s0_a, s0_b, s0_d, work_registers)
op_23 = Print(embedding, s0_d)
op_24 = XOR(embedding, pos, s0_c, s0_d, s0, work_registers)
op_25 = Print(embedding, s0)

It's so nice to not have to write all of that code manually! And finally, the output of the program is, as expected:

1101011101011010
1101011101011010
1110101110101101
0111010111010110
1011101011101011
1011010110101110
1111010111010110
0001101011101011
0100000001111000
0101101010010011

The final algorithm

This took me some time, but finally, here it is: the full SHA-256 algorithm implemented in our cursed language! It compiles to an even more cursed 26,000-line Python file. But when we run it, the final outputs (truncated to the first 32 bits, which are the ones we care about) are:

h0_final = 0b10111001010011010010011110111001
h1_final = 0b10010011010011010011111000001000
h2_final = 0b10100101001011100101001011010111
h3_final = 0b11011010011111011010101111111010
h4_final = 0b11000100100001001110111111100011
h5_final = 0b01111010010100111000000011101110
h6_final = 0b10010000100010001111011110101100
h7_final = 0b11100010111011111100110111101001

Which, when concatenated and converted to hex, give us

hex(int(h0_final))[2:] + hex(int(h1_final))[2:] + hex(int(h2_final))[2:] + hex(int(h3_final))[2:] + hex(int(h4_final))[2:] + hex(int(h5_final))[2:] + hex(int(h6_final))[2:] + hex(int(h7_final))[2:]
'b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9'
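
(A small aside on that join: hex(...)[2:] drops leading zeros, so if any of the eight 32-bit words happened to start with four zero bits, the concatenated digest would come out short. That doesn't happen for this input, but padding each word to exactly 8 hex digits is the safer join.)

# Pad every 32-bit word to exactly 8 hex digits before concatenating
digest = ''.join(format(w, '08x') for w in [h0_final, h1_final, h2_final, h3_final, h4_final, h5_final, h6_final, h7_final])
print(digest)  # b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9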

Whereas the expected output is....

from hashlib import sha256

input_ = 'hello world'
print(sha256(input_.encode('utf-8')).hexdigest())
b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9

WE DID IT!!!!!! A transformer that calculates SHA-256, correctly, 100% of the time 😎

Final thoughts

This was a really fun project—writing a transformer, by hand, that calculates a function that's pretty impossible to learn using gradient descent is supremely satisfying. It was also a lot more engineering-heavy than I had anticipated—debugging my "compiler", writing the SHA-256 implementation, and counting work register assignments all took way more time than I expected. As such, the resulting model isn't actually "good" in a lot of senses, but rather, very much "the first thing that worked".

If you've read this far and you want to continue on your masochistic streak, please feel free to contribute to my code here :)