from memo import memoimport jaximport jax.numpy as npfrom enum import IntEnum## Scalar implicatureNN =10_000N = np.arange(NN +1) # number of nice peopleclass U(IntEnum): NONE =0 SOME =1 ALL =2@jax.jitdef meaning(n, u): # (none) (some) (all)return np.array([ n ==0, n >0, n == NN ])[u]@memodef scalar[n: N, u: U](): listener: thinks[ speaker: chooses(n in N, wpp=1), speaker: chooses(u in U, wpp=imagine[ listener: knows(u), listener: chooses(n in N, wpp=meaning(n, u)), Pr[listener.n == n] ]) ] listener: observes [speaker.u] is u listener: chooses(n in N, wpp=E[speaker.n == n])return Pr[listener.n == n]scalar() # warm up JITimport timet_s = time.time()scalar()print(time.time() - t_s)