# RSA (Rational Speech Act) model

from memo import memo
import jax
import jax.numpy as np
from enum import IntEnum

# Show which JAX release this script is running against.
print(f"{jax.__version__}")

class U(IntEnum):
    """Utterance space: the single-word messages available to the speaker.

    The integer value of each member is the row index used by ``denotes``.
    """

    GREEN = 0   # "green"
    PINK = 1    # "pink"
    SQUARE = 2  # "square"
    ROUND = 3   # "round"

class R(IntEnum):
    """Referent space: the objects a speaker may be talking about.

    The integer value of each member is the column index used by ``denotes``.
    """

    GREEN_SQUARE = 0  # green, square
    GREEN_CIRCLE = 1  # green, round
    PINK_CIRCLE = 2   # pink, round

@jax.jit
def denotes(u, r):
    """Literal semantics: 1 if utterance ``u`` is true of referent ``r``, else 0.

    ``u`` selects a row (``U`` order) and ``r`` a column (``R`` order) of a
    fixed 0/1 truth table. Both may be integers or integer arrays; the lookup
    follows ``jax.numpy`` advanced-indexing rules.
    """
    # Columns:        GREEN_SQUARE  GREEN_CIRCLE  PINK_CIRCLE
    truth_table = np.array([
        [1, 1, 0],  # U.GREEN:  true of both green objects
        [0, 0, 1],  # U.PINK:   true of the pink circle only
        [1, 0, 0],  # U.SQUARE: true of the green square only
        [0, 1, 1],  # U.ROUND:  true of both circles
    ])
    return truth_table[u, r]
# Recursive RSA listener, written in the memo DSL. Use only `#` comments in
# here: memo parses this function's AST, so extra statements (e.g. a
# docstring) may not be accepted.
#
# L(beta, t) gives, for each utterance u in U and referent r in R, the
# probability that a listener who hears u picks r.
#   beta -- multiplies the depth-(t - 1) listener probability inside exp();
#           a larger beta makes the speaker favour utterances that the
#           depth-(t - 1) listener maps to the intended referent with
#           higher probability.
#   t    -- recursion depth. At t == 0 the speaker's weight for u is just
#           denotes(u, r), i.e. equal weight on every literally true
#           utterance. For t > 0 it is
#           denotes(u, r) * exp(beta * L[u, r](beta, t - 1)).
# NOTE(review): the result's axis order is presumably [u, r] as declared in
# the signature (loss() indexes L(beta, 1)[0] as a distribution over R);
# confirm against the memo documentation.
@memo
def L[u: U, r: R](beta, t):
    # The listener reasons about a simulated speaker...
    listener: thinks[
        # ...who is given a referent r, with equal weight on every r in R,
        speaker: given(r in R, wpp=1),
        # ...and chooses an utterance u with weight denotes(u, r) (0 for false
        # utterances), scaled by exp(beta * depth-(t-1) listener) when t > 0.
        speaker: chooses(u in U, wpp=
            denotes(u, r) * (1 if t == 0 else exp(beta * L[u, r](beta, t - 1))))
    ]
    # Condition the listener's model on the speaker having uttered u.
    listener: observes [speaker.u] is u
    # The listener picks r with weight equal to its posterior belief that the
    # speaker had r in mind.
    listener: chooses(r in R, wpp=Pr[speaker.r == r])
    return Pr[listener.r == r]

beta = 1.0  # speaker rationality used for the demo tables printed below

# Listener tables at recursion depth 0 and depth 1.
print(L(beta, 0))
print(L(beta, 1))

## Fitting the model to data...
# Target distribution over R (GREEN_SQUARE, GREEN_CIRCLE, PINK_CIRCLE):
# counts out of 180, data from Qing & Franke 2015. loss() compares it
# with the listener's row for utterance index 0 (U.GREEN).
Y = np.array([65, 115, 0]) / 180
@jax.jit
def loss(beta):
    """Return the mean squared error between the depth-1 listener and ``Y``.

    The prediction is the listener's distribution over referents R for
    utterance index 0 (``U.GREEN``), computed with rationality ``beta``.
    Written to be differentiable in ``beta`` (it is fed to
    ``jax.value_and_grad`` in the fitting loop).
    """
    predicted = L(beta, 1)[int(U.GREEN)]
    residual = predicted - Y
    return np.mean(residual ** 2)

from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))

## Fitting by gradient descent!
# Fit the rationality parameter beta to Y by plain gradient descent on loss().
LEARNING_RATE = 12.0  # step size applied to d(loss)/d(beta)
N_STEPS = 26          # number of gradient steps (= number of recorded losses)

vg = jax.value_and_grad(loss)  # beta -> (loss(beta), d loss / d beta)
losses = []
beta = 0.0  # starting point for the fit
for _ in range(N_STEPS):
    loss_value, dbeta = vg(beta)
    # Store plain Python floats rather than device arrays for plotting.
    losses.append(float(loss_value))
    beta = beta - dbeta * LEARNING_RATE

plt.plot(np.arange(len(losses)), losses)
plt.ylabel('MSE (%)')
plt.xlabel('Step #')
plt.yticks([0, 0.02], [0, 2])  # tick at MSE 0.02 is labelled "2" (percent)
plt.title('Gradient descent')

plt.tight_layout()
# NOTE(review): nothing here calls plt.show() or plt.savefig(). The figure
# only appears in notebook-style environments. Add one of these if this is
# run as a plain script, unless later code already saves the figure.