Q-Learning in RASP

After my previous RASP experiments, I wondered: just how far can we take RASP? For instance, can we make a transformer that solves a tabular Q-learning environment?

Well, let's find out.

!pip install -qqq git+https://github.com/chalk-diagrams/chalk git+https://github.com/srush/RASPy

Same helper functions as last time:

from raspy import *
from raspy.rasp import Seq, Sel, SOp, Key, Query
from raspy.visualize import draw_all, draw, draw_sel
from chalk import *
from colour import Color
from raspy.visualize import word
import random
import matplotlib.pyplot as plt

def draw(c_inp=Color("white"), c_att=Color("white"), c_back=Color("white"), c_ffn=Color("white")):

    d = (
        box("Input", c_inp).named("inp")
        / vstrut(1)
        / (rectangle(3, 4).fill_color(c_back).named("main")
           + (box("Feed Forward", c_ffn).named("ffn") / vstrut(1) / box("Attention", c_att).named("att")).center_xy())
        / vstrut(1)
        / box("Final").named("final")
    )
    return d.connect_outside("inp", "main").connect_outside("ffn", "att").connect_outside("main", "final")

def draw_att():
    d = rectangle(2.5, 2.5)
    d = d.beside(box2("key", green).rotate_by(0.25).named("key"), -unit_x)
    d = d.beside(box2("query", orange).named("query"), -unit_y)
    d = d.beside(box2("value", red).rotate_by(-0.25).named("value"), unit_x)
    d = d.beside(box2("output").named("output"), unit_y)
    d = d + rectangle(0.4,0.4).fill_color(black).named("mid").translate(0, -0.5)
    d = d + rectangle(0.4,0.4).fill_color(black).named("mid2").translate(0, 0.5)
    d = d.connect_perim("key", "mid", unit_x, -unit_x).connect_outside("query", "mid").connect_outside("mid", "mid2").connect_perim("value", "mid2", -unit_x, unit_x).connect_outside("mid2", "output")
    return d

before = key(indices) < query(indices)

def atoi(seq=tokens):
    # Digit character -> int, e.g. "7" -> 7
    return seq.map(lambda x: ord(x) - ord('0'))

def itatoi(seq=tokens):
    # Anything int() accepts -> int
    return seq.map(lambda x: int(x))

def ititoa(seq=tokens):
    # Int -> the character with that code point
    return seq.map(lambda x: chr(x))

def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")

def index(i, seq=tokens):
    x = (key(indices) == query(i)).value(seq)
    return x.name("index")

def shift(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices - i)).value(seq, default)
    return x.name("shift")

def shiftl(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices + i)).value(seq, default)
    return x.name("shiftl")

def shift_to_one(default="_", seq=tokens):
    x = (key(indices) == query(indices + 35)).value(seq, default)
    return x.name("shift_to_one")

def shiftls(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices + i - 1)).value(seq, default)
    return x.name("shiftls")

def shake_shift(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices - i + 3)).value(seq, default)
    x = (key(indices) == query(indices + i - 3)).value(x, default)
    return x.name("shakeshift")

def lfsr_shift(seq=tokens):
    # Shift left by 10: for an 11-element sequence, only position 0 gets a
    # value (whatever was at index 10); every other position takes the default 0.
    x = (key(indices) == query(indices + 10)).value(seq, 0)
    return x.name("lfsr_shift")

def minimum(seq=tokens):
    sel1 = before & (key(seq) == query(seq))
    sel2 = key(seq) < query(seq)
    less = (sel1 | sel2).value(1)
    x = (key(less) == query(0)).value(seq)
    return x.name("min")

def first(q, seq=tokens):
    return minimum(where(seq == q, indices, 99))

def ralign(default="-", sop=tokens):
    c = (key(sop) == query("_")).value(1)
    x = (key(indices + c) == query(indices)).value(sop, default)
    return x.name("ralign")

def split(v, i, sop=tokens, default="0"):
    split_point = (key(sop) == query(v)).value(indices)
    if i:
        x = ralign(default, where(indices < split_point, sop, "_"))
        return x
    else:
        x = where(indices > split_point, sop, default)
        return x

def slide(match, seq=tokens):
    x = cumsum(match)
    y = ((key(x) == query(x + 1)) & (key(match) == query(True))).value(seq)
    seq =  where(match, seq, y)
    return seq.name("slide")

def add(sop=tokens):
    # 0) Parse and add
    x = atoi(split("+", True, sop)) \
        + atoi(split("+", False, sop))

    # 1) Check for carries
    gets_carry = shift(-1, "0", where(x > 9, "1", where(x == 9, "<", "0")))

    # 2) In parallel, slide carries to their column
    gets_carry = atoi(slide(gets_carry != "<", gets_carry))

    # 3) Add in carries.
    return (x + gets_carry) % 10

def get_length(sop=tokens):
  # Every query matches every key (1 == 1), so summing a 1 for each selected
  # position gives the sequence length at every index.
  length = (key(1) == query(1)).value(1)
  return length.name("length")

def indexof(v, sop=tokens):
  length = get_length(sop)
  replaced = where((sop == v), indices, length)  # Replace everything but the token with the length, and the token with its index
  return minimum(replaced)  # the minimum value is the first index!

Ok, let's define the actual game! The setup is as follows: a 2x3 world where the agent starts in the bottom-left corner, the goal is in the bottom-right corner, and the bottom-middle tile is an impassable rock. It looks something like this:

environment = [[0, 0, 0], [2, 1, 3]]
plt.imshow(environment)

Where the light green is the agent, the blue is the rock, and the yellow is the goal.

What do we need in order to do Q-learning on this? Well, first of all, we need a Q-table: that is, a table which represents our estimate of how good every action is given a certain state.
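For reference, here's what that looks like in ordinary Python: a 2D list plus the standard Bellman update. This is just a sketch of plain tabular Q-learning for orientation (the names are mine, not part of the RASP code); the RASP version we'll build effectively uses a learning rate and gamma of 1, and stores per-step costs instead of rewards.

n_states, n_actions = 6, 4  # six tiles, four actions
Q = [[0.0] * n_actions for _ in range(n_states)]  # Q[state][action]

def q_update(state, action, reward, next_state, lr=1.0, gamma=1.0):
    # Standard tabular Q-learning update
    best_next = max(Q[next_state])
    Q[state][action] += lr * (reward + gamma * best_next - Q[state][action])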

So, let's create a string representing the entire game state:

initial_game_state = "000210400004000040000400004000040000"

This string might look a bit confusing at first glance. In reality, we just split it up using the 4 as a delimiter, getting us the following sections:
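In plain Python, that split looks like this:

initial_game_state.split("4")
# ['000210', '0000', '0000', '0000', '0000', '0000', '0000']

The first chunk is the flattened 2x3 grid (the 2 is the agent, the 1 is the rock), and each remaining chunk of four zeros is one Q-table row: one value per action for each of the tiles, all initialized to zero.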

So that's it! Game state plus Q table. We're ready to go.

We're going to be implementing the simplest possible tabular Q-learning algorithm, using epsilon-greedy...oh wait a second, how on earth are we going to get random numbers inside a transformer???
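For comparison, epsilon-greedy action selection is a couple of lines in ordinary Python (a throwaway sketch, with an epsilon of 0.125 to match what we'll use later), and the call to random.random() is exactly the thing a transformer doesn't give us:

import random

def epsilon_greedy(q_row, epsilon=0.125):
    # With probability epsilon, explore with a random action;
    # otherwise exploit the best-looking one.
    if random.random() < epsilon:
        return random.randrange(len(q_row))
    return max(range(len(q_row)), key=lambda a: q_row[a])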

Well, of course we could just generate a random string of numbers in Python and feed it in after the game state, so the transformer has a source of randomness. On the other hand, we could implement some decent pseudorandom number generator, like the Mersenne Twister. A good compromise between coolness and implementation difficulty is the linear-feedback shift register (LFSR): basically, you start with a register of bits (eleven, in our case), shift them all one position to the right, and XOR a couple of "tap" bits together to fill the newly empty zeroth index. This doesn't give you truly random numbers, but the bit sequence has a long enough cycle to be random enough for our purposes. Ok, let's do it!
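For intuition, here's one step of such a register in plain Python, with the same length and taps as the RASP version we're about to build (a reference sketch only; lfsr_py is not part of the transformer):

def lfsr_py(bits):
    # 11-bit Fibonacci LFSR with taps at positions 11 and 9 (indices 10 and 8):
    # the new bit is the XOR of the two taps, pushed in at index 0,
    # while the oldest bit falls off the right-hand end.
    new_bit = bits[10] ^ bits[8]
    return [new_bit] + bits[:-1]

seed = [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1]
seed = lfsr_py(seed)  # one step of the register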

The first thing we'll need to do is implement XOR:

def xor(comp, sop=tokens):
  # For single bits, XOR is just addition mod 2.
  result = where((sop == 7), sop, sop)  # pass-through: both branches return sop
  sum = comp + result
  xor_result = sum % 2
  return xor_result.name("xor")
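(Quick sanity check in plain Python: on single bits, addition mod 2 agrees with XOR.)

for a in (0, 1):
    for b in (0, 1):
        assert (a + b) % 2 == (a ^ b)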

And now we're ready to implement the linear-feedback shift register:

def lfsr(sop=tokens):
  """
  Length-11 Linear-feedback shift register, with taps at positions 11 and 9.
  The period of the cycle is 2047, which is good enough for our purposes!
  """

  seed = itatoi(sop)

  right_shifted = shift(1, 0, seed)

  tap_11 = index(10, seed)
  tap_9 = index(8, seed)

  xored = xor(tap_11, tap_9)

  xor_result_only_in_first_index = lfsr_shift(xored)

  new_seed = xor_result_only_in_first_index + right_shifted

  return new_seed.name("lfsr")

Let's test it:

seed = "00010100001"
lfsr()(seed)

rng_bits = ""
for _ in range(100):
  seed = ''.join(str(i) for i in lfsr()(seed).val)
  rng_bits += seed[9]

rng_bits
0001010001001000101011010100001100001001111001011100111001011110111001001010111011000010101110010000

Looks pretty random to me!!

Ok, now we're going to implement a few helper functions. They're not particularly interesting, but they'll be useful for our main game "loop" (hahaha, there are no loops in RASP!)

First things first, we're gonna need a simple ID representing the current state (i.e. where the agent is in the grid):

def get_state_id(sop=tokens):
  """
  Assumes a string of the full game state
  """

  # In the flattened game state, the state id is just the index of the 2 (the agent)
  state_id = indexof(2, sop)

  return state_id

get_state_id()([0, 0, 0, 0, 2, 0, 0, 0, 0, 0])

def c_and(other_bit, sop=tokens):
  """
  Element-wise AND of two 0/1 bits
  """

  # The sum of two bits is 2 exactly when both are 1
  sum = other_bit + sop

  # Map 2 -> 1 and everything else -> 0
  x = where((sum == 2), sum, 1) - 1

  return x

Two functions that add a certain value to a certain index of a sequence (don't ask why we need two)

def add_to_index(index, value, sop=tokens):
  # Shifting right by the full sequence length pushes every real token out of
  # range, so every position takes the default, giving a constant sequence of `value`
  value_to_add = shift(get_length(sop), value, sop)

  # Shift left by (length - 1) so only position 0 keeps `value`, then shift
  # right by `index` so it lands at the target position (zeros everywhere else)
  value_to_add_in_correct_place = itatoi(shift(index, 0, shiftl(get_length(sop) - 1, 0, value_to_add)))

  return value_to_add_in_correct_place + sop

def ati2(index, value, sop=tokens):
  # Same idea as add_to_index, but the constant `value` sequence is built with
  # a `where` (the condition is always true) instead of an out-of-range shift
  value_to_add = where(indices >= 0, value, sop)

  value_to_add_in_correct_place = itatoi(shift(index, 0, shiftl(get_length(sop) - 1, 0, value_to_add)))

  return value_to_add_in_correct_place + sop

ati2(3, 10)([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  # should add 10 at index 3, giving [0, 1, 2, 13, 4, 5, 6, 7, 8, 9]

def update_game_state(new_state, sop=tokens):
  # We want to find the current location of the agent (the 2), and set it to zero
  # Then, we want to set the value of the updated agent location index to 2

  current_agent_location = indexof(2, sop)
  x = add_to_index(new_state, 2, sop)
  x = add_to_index(current_agent_location, -2, x)

  return x

import sys
# Building the big step function below takes more recursion depth than
# Python's default allows, so raise the limit.
sys.setrecursionlimit(10000)

And with all that out of the way, we're ready to implement our Q-learning RL agent!!! First, let's define the "transition matrix" and the "reward function":

transitions = "00130121122530333333"
rewards = "11111111111111110000"

In the above, transitions[state * 4 + action] tells us which state we end up in next, given the current state and the action taken. rewards[state * 4 + action] similarly tells us what reward we get.
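To make that convention concrete, here's the same lookup in plain Python (not RASP; lookup is just a throwaway helper), checking every action from the starting tile, state 3:

def lookup(state, action):
    i = state * 4 + action
    return int(transitions[i]), int(rewards[i])

for a in range(4):
    print(a, lookup(3, a))  # (next state, reward) for each action from state 3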

And finally, the gorgeous joint environment update/training loop:

def step(next_states, rewards, sop=tokens):
  # Split up the game state into a sequence of ints, rather than characters
  x = atoi(sop)
  # x is now [0, 0, 0, 2, 1, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 5, 4, 1, 4, 1, 4]

  # However, actually using 4s as delimiters is a bit difficult--it would be
  # way nicer if each delimiter was unique, so we could easily access the
  # relevant part of the game state. So, let's add one million to each delimiter.
  # We don't want to add a super small number, lest it get confused with state
  # value estimates (normally, they would be negative, but we store them
  # using positive values here due to what I initially thought was a limitation
  # of RASPy but later turned out to be false, and is now an aesthetic choice.)
  divider_addition = [0, 0, 0, 0, 0, 0, 1000000, 0, 0, 0, 0, 1000001, 0, 0, 0, 0, 1000002, 0, 0, 0, 0, 1000003, 0, 0, 0, 0, 1000004, 0, 0, 0, 0, 1000005, 0, 0, 0, 0, 1000006, 0, 1000007, 0, 1000008, 0, 1000009]

  full_game_state = itatoi(x + divider_addition)
  # The full game state is now the eyewatering [0, 0, 0, 2, 1, 0, 1000004, 0, 0, 0, 0, 1000005, 0, 0, 0, 0, 1000006, 0, 0, 0, 0, 1000007, 0, 0, 0, 0, 1000008, 0, 0, 0, 0, 1000009, 0, 0, 0, 0, 1000010, 5, 1000011, 1, 1000012, 1, 1000013]

  # Ok, now it's time to initialize our RNG.
  # Shifting by -100 pushes every value out of range, leaving an all-zero
  # sequence (the default) that we can add the seed bits into.
  seed = shift(-100, 0, full_game_state)
  seed_init = [0,0,0,1,0,1,0,0,0,0,1]
  seed = seed + seed_init

  # Time: 22s

  # This is not a real loop, in the sense that there's no loop logic involved.
  # It just saves me the trouble of copying and pasting the layer definitions
  # many, many times in a row, which is how I would've had to do it otherwise.
  # for i in range(1):

  # First, let's get the current game state.
  # 0 is (0, 0)
  # 1 is (0, 1)
  # 2 is (0, 2)
  # 3 is (1, 0)
  # 4 is (1, 1), the rock, so the agent can never actually be there
  # 5 is (1, 2)
  state_id = get_state_id(full_game_state)

  # Decide whether to take a random action or not.
  # To do this, read three bits off our LFSR and check whether they're all 1,
  # which gets us an epsilon of 0.5^3 = 0.125.
  seed = lfsr(seed)
  rng1 = where(index(15, seed) == 1, 1, 0)
  seed = lfsr(seed)
  rng2 = where(index(15, seed) == 1, 1, 0)
  seed = lfsr(seed)
  rng3 = where(index(15, seed) == 1, 1, 0)

  rng_less_than_epsilon = c_and(rng1, c_and(rng2, rng3))

  # Figure out which action to take

  # First, generate a random action (we always do this, even if we don't end up taking it)
  seed = lfsr(seed)
  rng4 = where(index(15, seed) == 1, 1, 0)
  seed = lfsr(seed)
  rng5 = where(index(15, seed) == 1, 1, 0)
  random_action = rng4 * 2 + rng5 + 2  # [0/1] * 2 + [0/1] is uniform from {0, 1, 2, 3}

  # Time: 7m

  # Next, get the best action
  # First, extract the relevant section of the q-table
  state_lookup_index = indexof(state_id + 1000004, full_game_state)  # this is why we did all of that delimiter nonsense
  action_0_value = index(state_lookup_index + 1, full_game_state)
  action_1_value = index(state_lookup_index + 2, full_game_state)
  action_2_value = index(state_lookup_index + 3, full_game_state)
  action_3_value = index(state_lookup_index + 4, full_game_state)

  # Then, get the best action to take (in this case, the one with minimum value)
  shifted_a0v = shift(-42, 999999999, action_0_value)  # [val(a0), inf, inf, inf, ...]
  shifted_a1v = shift(1, 999999999, shift(-42, 999999999, action_1_value)) # [inf, val(a1), inf, inf, ...]
  shifted_a2v = shift(2, 999999999, shift(-42, 999999999, action_2_value)) # [inf, inf, val(a2), inf, ...]
  shifted_a3v = shift(3, 999999999, shift(-42, 999999999, action_3_value)) # [inf, inf, inf, val(a3), ...]

  action_sum = shifted_a0v + shifted_a1v + shifted_a2v + shifted_a3v
  # [3*inf + val(a0), 3*inf + val(a1), 3*inf + val(a2), 3*inf + val(a3), 4*inf, 4*inf, 4*inf, ...]

  # So the index of the minimum is the best action
  best_action = indexof(minimum(action_sum), action_sum) + 1

  shifted_best_action = shift(-35, 0, best_action)  # [B, 0, 0, 0, ...]
  shifted_random_action = shift(1, 0, shift(-35, 0, random_action))  # [0, R, 0, 0, ...]

  combined_actions = shifted_best_action + shifted_random_action
  # [B, R, 0, 0, 0, ...]

  action = index(rng_less_than_epsilon, combined_actions)
  # Take the zeroth index if rng is not less than epsilon (i.e. the best action B)
  # Take the first index if rng is less than epsilon (i.e. the random action R)

  # Lookup index for next state and reward as described earlier
  lookup_index = state_id * 4 + action

  next_state = atoi(index(lookup_index, next_states))
  reward = atoi(index(lookup_index, rewards))

  # Get the best value of the next action
  next_state_lookup_index = indexof(next_state + 1000004, full_game_state)
  best_value_of_next_action = minimum(shift(-32, 999999, shift(32, 999999, shiftl(next_state_lookup_index + 1, 999999, full_game_state))))

  bellman_update = reward - best_value_of_next_action  # assuming learning rate and gamma are both 1, lmao
  # also we store rewards as negative rewards for "simplicity"

  # Update the Q table
  full_game_state = ati2(state_lookup_index + action + 1, bellman_update, full_game_state)

  # Update the environment
  full_game_state = update_game_state(next_state, full_game_state)

  return full_game_state

step(transitions, rewards)(initial_game_state)