2/3rds Game



Anatomy of a memo model

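memo models are Python functions decorated with `@memo`. The bracketed parameters (here `[x: X, y: Y]`) name the axes of the output array and the domain each axis ranges over. The parenthesized parameters are ordinary arguments.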

# A memo model is an ordinary Python function decorated with @memo.
from memo import memo

# Domains for the model's axes: x ranges over X, y ranges over Y.
X = [1, 2, 3]
Y = range(10)

# [x: X, y: Y] declares the axes of the array this model returns
# (presumably one dimension per axis, i.e. len(X) x len(Y) — confirm in the
# memo docs). (a, b=2, c=None) are ordinary call arguments; this toy model
# does not use them.
@memo
def my_model[x: X, y: Y](a, b=2, c=None):
    # Value stored at each grid point (x, y).
    return x + y

# Evaluate the model, passing the ordinary arguments explicitly.
my_model(1, b=2, c=3)

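You can convert the output array into a pandas DataFrame or an xarray for easy indexing. Pass `return_aux=True` together with `return_pandas=True` and `return_xarray=True`. The raw array is then in `res.data`, and the converted versions are in `res.aux`: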

# Request auxiliary outputs alongside the result: the raw JAX array stays
# on res.data, while the pandas and xarray conversions are attached to res.aux.
res = my_model(
    1,
    b=2,
    c=3,
    return_aux=True,
    return_pandas=True,
    return_xarray=True,
)
data, df, xa = res.data, res.aux.pandas, res.aux.xarray

# Show each representation under its own banner (only the first rows of the DataFrame).
for banner, shown in (
    ("===JAX array===", data),
    ("\n===pandas===", df.head()),
    ("\n===xarray===", xa),
):
    print(banner)
    print(shown)

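chooses()

`agent: chooses(v in V, wpp=...)` has an agent pick a value `v` from the domain `V`, with probability proportional to the `wpp` weight.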

# NOTE: this is jax.numpy, aliased as np (not NumPy itself).
import jax.numpy as np

# Domain of candidate values: the integers 1..100.
Z = np.arange(100) + 1

@memo
def my_model[z: Z]():
    # kartik picks z from Z; wpp=1 gives every option the same weight (uniform).
    kartik: chooses(z in Z, wpp=1)
    # For each axis value z: expectation of the indicator [kartik.z == z],
    # i.e. the probability that kartik chose z (1/100 here).
    return E[kartik.z == z]

my_model()

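Expressions

The `wpp` weight and the query can contain expressions. The cells below use a built-in function (`exp`), a Python conditional, and a JAX-jitted function.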

# NOTE: jax.numpy, aliased as np.
import jax.numpy as np

# Domain of candidate values: the integers 1..100.
Z = np.arange(100) + 1

@memo
def my_model[z: Z]():
    # Uniform choice over Z.
    kartik: chooses(z in Z, wpp=1)
    # Pr[...] is the probability that the condition holds; queries may apply
    # functions such as exp to a choice. exp(z) > 5 holds for every z >= 2,
    # so under the uniform choice this is 99/100. The condition does not
    # involve the axis z, so the same value is presumably repeated along z.
    return Pr[exp(kartik.z) > 5]

my_model()


# wpp may depend on the candidate value: even numbers get weight 3 and odd
# numbers weight 1. Over Z = 1..100 the total weight is 200, so each even z
# has probability 0.015 and each odd z 0.005. (Z is defined in an earlier cell.)
@memo
def my_model[z: Z]():
    kartik: chooses(z in Z, wpp=3 if z % 2 == 0 else 1)
    # Probability that kartik chose each z.
    return E[kartik.z == z]

my_model()


from matplotlib import pyplot as plt
import jax
from jax.scipy.stats.norm import pdf as normpdf

# JIT-compiled Normal density, normpdf(x, loc, scale), so the model can call it
# inside wpp. NOTE(review): memo appears to require externally defined
# functions to be jax.jit-wrapped — confirm against the memo docs.
gaussianpdf = jax.jit(normpdf)

@memo
def my_model[z: Z]():
    # Weight each z by a Normal density with mean 20 and standard deviation 5.
    kartik: chooses(z in Z, wpp=gaussianpdf(z, 20, 5))
    # Probability that kartik chose each z.
    return E[kartik.z == z]

# Plot the resulting choice distribution over Z (Z comes from an earlier cell).
plt.plot(Z, my_model())

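thinks[]

`agent: thinks[...]` describes one agent's mental model of other agents' choices. A query wrapped in `agent[...]` is evaluated from that agent's point of view.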

@memo
def my_model[z: Z]():
    # kartik's belief: another agent, maxkw, picks z from Z with even numbers
    # weighted 3 and odd numbers weighted 1.
    kartik: thinks[
        maxkw: chooses(z in Z, wpp=3 if z % 2 == 0 else 1)
    ]
    # Query from kartik's perspective: kartik's expectation that maxkw chose 1
    # (weight 1 out of 200, i.e. 0.005). NOTE(review): the axis z is unused
    # here; if the intent was kartik's whole belief distribution, the query
    # would compare against z instead of 1 — confirm.
    return kartik[ E[maxkw.z == 1]]

my_model()



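The 2/3rds game

Everyone guesses a whole number from 1 to 100. The winner is whoever comes closest to two-thirds of the average guess. Below, a "reader" reasons about what the rest of the "population" will guess.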

# NOTE: jax.numpy, aliased as np.
import jax.numpy as np

# Possible guesses: the integers 1..100.
N = np.arange(100) + 1

# The reader imagines the population guessing uniformly over N and asks for
# the population's expected guess: the mean of 1..100, i.e. 50.5.
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=1)  ### back to a uniform prior
    ]

    return reader[ E[population.n] ]  ### query what the reader thinks about the population

reader_thinks()



# Same uniform belief as before, but the query now returns the game's target:
# two-thirds of the reader's expected population guess, (2/3) * 50.5 ≈ 33.67.
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=1)
    ]

    return reader[ (2/3) * E[population.n] ]

reader_thinks()



from matplotlib import pyplot as plt

# Level-k reader for the 2/3rds game.
#   axis n: a candidate guess in N
#   k:      reasoning depth
# At k == 0 the reader imagines the population guessing uniformly. For k > 0
# the reader imagines a population of level-(k-1) readers, via a recursive
# call to this same model.
# NOTE(review): termination relies on memo not evaluating the recursive
# branch when k == 0, i.e. the ternary on the parameter k must short-circuit —
# confirm against memo's recursion docs.
@memo
def reader_chooses[n: N](k):
    reader: thinks[
        population: chooses(
            n in N,
            wpp=reader_chooses[n](k-1) if k > 0 else 1
        )
    ]
    # The reader then guesses with weight exp(-|target - n|), where target is
    # two-thirds of the reader's expected population guess. Each unit of
    # distance from the target multiplies the weight by e^-1.
    reader: chooses(n in N, wpp=exp(-abs((2/3)*E[population.n] - n)))
    # Probability that the reader guesses each n.
    return Pr[reader.n == n]

# Plot the level-0 reader's guess distribution.
plt.plot(N, reader_chooses(0))



# Overlay the reader's guess distribution for several reasoning depths k,
# with the y-axis fixed to [0, 0.5] so the curves share a common scale.
fig, ax = plt.subplots()
k_vals = [0, 1, 2, 5, 7]
for k_ in k_vals:
    guess_dist = reader_chooses(k_)
    ax.plot(N, guess_dist, label=f"k={k_}", alpha=0.8)
ax.set_ylim(bottom=0, top=0.5)
ax.legend()