from memo import memo
# Domains bound to the `x` and `y` axes of `my_model`:
# X is a three-element list, Y covers the integers 0..9.
X = [1, 2, 3]
Y = range(10)
@memo
def my_model[x: X, y: Y](a, b=2, c=None):
return x + y
my_model(1, b=2, c=3)

2/3rds Game
Anatomy of a memo model
memo models are decorated functions
You can convert the output array into a pandas DataFrame or an xarray for easy indexing:
# Request auxiliary outputs alongside the raw array. With return_aux=True the
# result exposes `.data` and `.aux`, as used below.
res = my_model(
    1,
    b=2,
    c=3,
    return_aux=True,
    return_pandas=True,
    return_xarray=True,
)
data = res.data  # raw output array
df = res.aux.pandas  # pandas view (requested via return_pandas=True)
xa = res.aux.xarray  # xarray view (requested via return_xarray=True)

print("===JAX array===")
print(data)
print("\n===pandas===")
print(df.head())
print("\n===xarray===")
print(xa)

chooses()
import jax.numpy as np
# Domain for the `z` axis: the integers 1..100 (arange yields 0..99, shifted by one).
Z = np.arange(100) + 1
@memo
def my_model[z: Z]():
kartik: chooses(z in Z, wpp=1)
return E[kartik.z == z]
my_model()

expressions
import jax.numpy as np
# Domain for the `z` axis: the integers 1..100 (arange yields 0..99, shifted by one).
Z = np.arange(100) + 1
@memo
def my_model[z: Z]():
kartik: chooses(z in Z, wpp=1)
return Pr[exp(kartik.z) > 5]
my_model()@memo
def my_model[z: Z]():
kartik: chooses(z in Z, wpp=3 if z % 2 == 0 else 1)
return E[kartik.z == z]
my_model()from matplotlib import pyplot as plt
import jax
from jax.scipy.stats.norm import pdf as normpdf
# JIT-compiled normal pdf; used as the choice weight (`wpp`) in the model below.
gaussianpdf = jax.jit(normpdf)
@memo
def my_model[z: Z]():
kartik: chooses(z in Z, wpp=gaussianpdf(z, 20, 5))
return E[kartik.z == z]
plt.plot(Z, my_model())

thinks[]
@memo
def my_model[z: Z]():
kartik: thinks[
maxkw: chooses(z in Z, wpp=3 if z % 2 == 0 else 1)
]
return kartik[ E[maxkw.z == 1]]
my_model()

The 2/3rds game
import jax.numpy as np
# Set of numbers a player may pick: the integers 1..100 (arange yields 0..99, shifted by one).
N = np.arange(100) + 1
# The reader models the population as picking n from N, then reports the
# expected value of the population's pick. The model declares no axes and
# takes no parameters.
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=1)  # back to a uniform prior
    ]
    # Query what the reader thinks about the population.
    return reader[E[population.n]]
reader_thinks()


# The reader models the population as picking n from N with equal weight on
# every option, then reports two thirds of the expected population pick.
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=1)
    ]
    return reader[(2/3) * E[population.n]]


reader_thinks()

from matplotlib import pyplot as plt
@memo
def reader_chooses[n: N](k):
reader: thinks[
population: chooses(
n in N,
wpp=reader_chooses[n](k-1) if k > 0 else 1
)
]
reader: chooses(n in N, wpp=exp(-abs((2/3)*E[population.n] - n)))
return Pr[reader.n == n]
plt.plot(N, reader_chooses(0))

# Overlay the reader's choice distribution for several reasoning depths k.
fig, ax = plt.subplots()
k_vals = [0, 1, 2, 5, 7]
for k_ in k_vals:
    ax.plot(N, reader_chooses(k_), alpha=0.8, label=f"k={k_}")
ax.set_ylim((0, 0.5))
ax.legend()