Example of running jaxlogit with batched draws#

jaxlogit’s default way of processing random draws for simulation is to generate them once at the beginning and then evaluate the log-likelihood with these draws at each step of the optimization routine. The corresponding array has shape (number_of_observations x number_of_random_variables x number_of_draws), which can get very large. If it is too large for local memory, jaxlogit can dynamically generate draws on each iteration. The advantage of this is that calculations can then be batched, i.e., processed on smaller subsets and summed up. This reduces memory load at the cost of runtime. Note that JAX still calculates gradients, so this method also has memory limits.
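The batching idea can be sketched in a few lines of NumPy. This is a minimal illustration, not jaxlogit’s actual implementation: `batch_loglik` here is a hypothetical stand-in for the simulated per-observation log-likelihood. Because the log-likelihood is a sum of independent per-observation terms, evaluating it on slices of the draws array and summing the results gives the same value while only materializing one batch at a time.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_vars, n_draws = 100, 3, 50
draws = rng.standard_normal((n_obs, n_vars, n_draws))  # full draws array

def batch_loglik(batch_draws):
    # Hypothetical stand-in for a simulated log-likelihood on a subset of
    # observations: average a positive quantity over draws, take logs,
    # and sum over the observations in the batch.
    sim_prob = np.mean(np.exp(-0.5 * batch_draws**2), axis=(1, 2))
    return np.log(sim_prob).sum()

# Evaluate on the full array at once ...
full = batch_loglik(draws)

# ... or batch over observations and add the contributions up, so only a
# (batch_size x n_vars x n_draws) slice needs to be in memory at a time.
batch_size = 32
batched = sum(batch_loglik(draws[i:i + batch_size])
              for i in range(0, n_obs, batch_size))

print(np.isclose(full, batched))
```

The equality holds exactly (up to floating-point rounding) precisely because the per-observation terms are independent; this is what makes the memory/runtime trade-off possible.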

[1]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit, ConfigData
[2]:
# enable 64-bit precision
jax.config.update("jax_enable_x64", True)

Electricity Dataset batching example#

From xlogit’s examples. Since this example shows how batching reduces memory load, we skip the calculation of standard errors and reduce the maximum number of iterations to 10 to speed up test times.

[3]:
df = pd.read_csv("https://raw.githubusercontent.com/outerl/jaxlogit/refs/heads/main/examples/electricity_long.csv")
[4]:
n_obs = df['chid'].unique().shape[0]
n_vars = 6
n_draws = 5000
maxiter = 10

size_in_ram = (n_obs * n_vars * n_draws * 8) / (1024 ** 3)  # in GB

print(
    f"Data has {n_obs} observations, we use {n_vars} random variables in the model. We work in 64 bit precision, so each element is 8 bytes."
    + f" For {n_draws} draws, the array of draws is about {size_in_ram:.2f} GB."
)

varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
Data has 4308 observations, we use 6 random variables in the model. We work in 64 bit precision, so each element is 8 bytes. For 5000 draws, the array of draws is about 0.96 GB.

Four batches#

First we try four batches.

[5]:
n_batches = 4
batch_size = np.ceil(n_obs / n_batches)
print(f"For {n_batches} batches and {n_obs} observations, batch size is {batch_size}")

model = MixedLogit()

config = ConfigData(
    panels=df['id'],
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=batch_size,
    optim_method="L-BFGS-scipy",  # alternative solvers: "L-BFGS-B", "BFGS"
    maxiter=maxiter,
)

res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    alts=df['alt'],
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    config=config
)
display(model.summary())
For 4 batches and 4308 observations, batch size is 1077.0
**** The optimization did not converge after 10 iterations. ****
Convergence not reached. The estimates may not be reliable.
    Message: max BFGS iters reached
    Iterations: 10
    Function evaluations: 13
Estimation time= 28.3 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.6058103     1.0000000    -0.6058103         0.545
cl                     -0.1305006     1.0000000    -0.1305006         0.896
loc                     1.3781105     1.0000000     1.3781105         0.168
wk                      1.2574063     1.0000000     1.2574063         0.209
tod                    -5.7968454     1.0000000    -5.7968454      7.24e-09 ***
seas                   -6.2281762     1.0000000    -6.2281762      5.17e-10 ***
sd.pf                  -3.5817888     1.0000000    -3.5817888      0.000345 ***
sd.cl                  -0.7694025     1.0000000    -0.7694025         0.442
sd.loc                  1.9279293     1.0000000     1.9279293        0.0539 .
sd.wk                   0.9267313     1.0000000     0.9267313         0.354
sd.tod                  4.6087321     1.0000000     4.6087321      4.17e-06 ***
sd.seas                 2.2479243     1.0000000     2.2479243        0.0246 *
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4033.856
AIC= 8091.711
BIC= 8168.130
None

No batches#

For comparison, we now fit the same model without batching (batch_size=None).

[6]:
model = MixedLogit()

config = ConfigData(
    panels=df['id'],
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-scipy",
    maxiter=maxiter,
)

res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    alts=df['alt'],
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    config=config
)
display(model.summary())
**** The optimization did not converge after 10 iterations. ****
Convergence not reached. The estimates may not be reliable.
    Message: max BFGS iters reached
    Iterations: 10
    Function evaluations: 13
Estimation time= 21.1 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.6058103     1.0000000    -0.6058103         0.545
cl                     -0.1305006     1.0000000    -0.1305006         0.896
loc                     1.3781105     1.0000000     1.3781105         0.168
wk                      1.2574063     1.0000000     1.2574063         0.209
tod                    -5.7968454     1.0000000    -5.7968454      7.24e-09 ***
seas                   -6.2281762     1.0000000    -6.2281762      5.17e-10 ***
sd.pf                  -3.5817888     1.0000000    -3.5817888      0.000345 ***
sd.cl                  -0.7694025     1.0000000    -0.7694025         0.442
sd.loc                  1.9279293     1.0000000     1.9279293        0.0539 .
sd.wk                   0.9267313     1.0000000     0.9267313         0.354
sd.tod                  4.6087321     1.0000000     4.6087321      4.17e-06 ***
sd.seas                 2.2479243     1.0000000     2.2479243        0.0246 *
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4033.856
AIC= 8091.711
BIC= 8168.130
None