from memo import memo
import jax
import jax.numpy as np
from enum import IntEnum
print(jax.__version__)  # sanity check: show which JAX build is in use
class U(IntEnum):
    """Utterance space: the four things the speaker can say.

    Two color words and two shape words; integer values index rows of
    the meaning matrix in `denotes`.
    """
    GREEN, PINK, SQUARE, ROUND = range(4)
class R(IntEnum):
    """Referent space: the three objects that can be referred to.

    Integer values index columns of the meaning matrix in `denotes`.
    """
    GREEN_SQUARE, GREEN_CIRCLE, PINK_CIRCLE = range(3)
@jax.jit
def denotes(u, r):
    """Return 1 if utterance `u` truthfully applies to referent `r`, else 0.

    The truth-conditional meaning matrix: rows are indexed by utterance
    (U), columns by referent (R). Fix: a stray "RSA" markdown header had
    been fused onto the return line, making the file unparseable; it is
    removed here.
    """
    return np.array([
        # green square
        # |  green circle
        # |  |  pink circle
        # |  |  |
        [1, 1, 0],  # "green"
        [0, 0, 1],  # "pink"
        [1, 0, 0],  # "square"
        [0, 1, 1],  # "round"
    ])[u, r]
@memo
def L[u: U, r: R](beta, t):
    # Recursive RSA listener written in the memo DSL: L(beta, t) yields
    # the table P(listener picks referent r | utterance u).
    # t is the recursion depth: t == 0 is the literal listener; t > 0
    # models a listener reasoning about a soft-max speaker (scale beta,
    # via exp(beta * ...)) who in turn reasons about the depth-(t-1)
    # listener.
    listener: thinks[
        # The listener imagines a speaker with a uniform prior over
        # referents (wpp=1)...
        speaker: given(r in R, wpp=1),
        # ...who chooses only true utterances (denotes(u, r) factor),
        # weighted at depth t > 0 by how well the utterance identifies
        # r for the shallower listener.
        speaker: chooses(u in U, wpp=
            denotes(u, r) * (1 if t == 0 else exp(beta * L[u, r](beta, t - 1))))
    ]
    # Condition on the observed utterance, then infer the referent.
    listener: observes [speaker.u] is u
    listener: chooses(r in R, wpp=Pr[speaker.r == r])
    return Pr[listener.r == r]
beta = 1.
# Show the listener's referent distribution at depth 0 (literal) and
# depth 1 (one round of pragmatic reasoning).
for depth in (0, 1):
    print(L(beta, depth))

## Fitting the model to data...
# Human choice proportions for the "green" utterance, from Qing & Franke 2015
# (counts out of 180 trials).
Y = np.array([65, 115, 0]) / 180

@jax.jit
def loss(beta):
    """Mean squared error between the depth-1 listener's distribution for
    utterance GREEN (row 0) and the empirical data Y."""
    residual = L(beta, 1)[0] - Y
    return np.mean(np.square(residual))
from matplotlib import pyplot as plt

plt.figure(figsize=(5, 4))

## Fitting by gradient descent!
# Differentiate the MSE loss with respect to beta and take fixed-size
# gradient steps, recording the loss after each step.
vg = jax.value_and_grad(loss)
losses = []
beta = 0.
for step in range(26):
    value, grad = vg(beta)
    losses.append(value)
    beta = beta - grad * 12.  # fixed learning rate of 12

plt.plot(np.arange(len(losses)), losses)
plt.ylabel('MSE (%)')
plt.xlabel('Step #')
plt.yticks([0, 0.02], [0, 2])
plt.title('Gradient descent')
plt.tight_layout()