Comparison of run times and Brier scores for jaxlogit, xlogit, and biogeme#

Estimation uses 500 draws (suboptimal, but the highest that does not run biogeme out of memory), with the data split into separate training and test sets.

| | jaxlogit-scipy | jaxlogit-jax | xlogit | biogeme |
| --- | --- | --- | --- | --- |
| Making model | 33.1 s | 22.2 s | 18.5 s | 4 min 30 s |
| Estimating | 1.6 s | 0.2 s | 0.0 s | 14.3 s |
| Brier score | 0.624247 | 0.624247 | 0.624570 | 0.624163 |
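For reference, the Brier score measures the squared distance between the predicted choice probabilities and the one-hot indicator of the alternative actually chosen (lower is better). A minimal sketch under one common multiclass convention (squared errors summed over alternatives, averaged over observations); the function below is illustrative and not part of any of the libraries compared here:

import numpy as np

def brier_score(chosen_alt, probs, alts=(1, 2, 3, 4)):
    """Multiclass Brier score: squared errors between the one-hot choice
    indicator and the predicted probabilities, summed over alternatives
    and averaged over observations."""
    one_hot = (np.asarray(chosen_alt)[:, None] == np.asarray(alts)).astype(float)
    return float(np.mean(np.sum((one_hot - np.asarray(probs)) ** 2, axis=1)))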

Setup#

[1]:

import pathlib

import pandas as pd
import numpy as np
import jax
import xlogit
import sklearn
import sklearn.metrics
import sklearn.model_selection

from jaxlogit.mixed_logit import MixedLogit, ConfigData
from jaxlogit.utils import wide_to_long

import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
import biogeme.database as db
from biogeme import models
from biogeme.expressions import (
    Beta, Draws, Variable, log, MonteCarlo, PanelLikelihoodTrajectory
)

logger = blog.get_screen_logger()
logger.setLevel(blog.INFO)

# Use 64-bit floating point precision in JAX
jax.config.update("jax_enable_x64", True)

Variable names and number of draws, used for jaxlogit and xlogit. Adjusting n_draws can improve accuracy, but Biogeme cannot handle 700 or more draws with this data set.

[2]:
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
n_draws = 500

Reshape the data to wide format so it can be passed to train_test_split. xlogit and jaxlogit then require long format, while biogeme requires wide format.
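For intuition: in long format every (choice situation, alternative) pair is one row, while in wide format every choice situation is one row with a separate column per attribute and alternative. A toy illustration with made-up values, using only the pf attribute and two alternatives:

import pandas as pd

# Long format: one row per alternative within each choice situation
toy_long = pd.DataFrame({
    'chid':   [1, 1, 2, 2],
    'alt':    [1, 2, 1, 2],
    'pf':     [0.5, 0.7, 0.4, 0.9],
    'choice': [1, 0, 0, 1],
})

# Wide format: one row per choice situation, one column per (attribute, alternative)
toy_wide = toy_long.pivot(index='chid', columns='alt', values='pf')
toy_wide.columns = [f'pf_{alt}' for alt in toy_wide.columns]
print(toy_wide)  # columns pf_1 and pf_2, one row per chid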

[3]:
df_long = pd.read_csv(pathlib.Path.cwd().parent.parent / "examples" / "electricity_long.csv")
# Record the chosen alternative for each choice situation
choice_df = df_long.loc[df_long['choice'] == 1, ['id', 'chid', 'alt']]
choice_df = choice_df.rename(columns={'alt': 'choice'})
# Pivot the attributes to wide format: one column per (variable, alternative)
df_wide = df_long.pivot(index=['id', 'chid'], columns='alt', values=varnames)
df_wide.columns = [f'{var}_{alt}' for var, alt in df_wide.columns]
df_wide = df_wide.reset_index()
df = df_wide.merge(
    choice_df,
    on=['id', 'chid'],
    how='inner',
    validate='one_to_one'
)

# Split by choice situation, then reshape each split back to long format
# for xlogit and jaxlogit
df_wide_train, df_wide_test = sklearn.model_selection.train_test_split(df, train_size=0.2)
df_train = wide_to_long(df_wide_train, "chid", [1,2,3,4], "alt", varying=varnames, panels=True)
df_train = df_train.sort_values(['chid', 'alt'])
df_test = wide_to_long(df_wide_test, "chid", [1,2,3,4], "alt", varying=varnames, panels=True)
df_test = df_test.sort_values(['chid', 'alt'])

# biogeme uses the wide-format data directly, with 'id' marking the panel structure
df_wide_train = df_wide_train.sort_values('chid')
database_train = db.Database('electricity', df_wide_train)
database_train.panel('id')
database_test = db.Database('electricity', df_wide_test)

jaxlogit and xlogit setup:

[4]:
X_train = df_train[varnames]
y_train = df_train['choice']

ids_train = df_train['chid']
alts_train = df_train['alt']
panels_train = df_train['id']

X_test = df_test[varnames]
y_test = df_test['choice']

ids_test = df_test['chid']
alts_test = df_test['alt']
panels_test = df_test['id']
[5]:
randvars = {'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}

model_jax = MixedLogit()
model_x = xlogit.MixedLogit()

init_coeff = None

Biogeme setup:

[6]:
X = {
    name: {
        j: Variable(f"{name}_{j}")
        for j in [1,2,3,4]
    }
    for name in varnames
}

# Alternative-specific constants, with alt_4 fixed as the reference
# (note these are not entered into the utilities V below)
alt_1 = Beta('alt_1', 0, None, None, 0)
alt_2 = Beta('alt_2', 0, None, None, 0)
alt_3 = Beta('alt_3', 0, None, None, 0)
alt_4 = Beta('alt_4', 0, None, None, 1)

pf_mean = Beta('pf_mean', 0, None, None, 0)
pf_sd = Beta('pf_sd', 1, None, None, 0)
cl_mean = Beta('cl_mean', 0, None, None, 0)
cl_sd = Beta('cl_sd', 1, None, None, 0)
loc_mean = Beta('loc_mean', 0, None, None, 0)
loc_sd = Beta('loc_sd', 1, None, None, 0)
wk_mean = Beta('wk_mean', 0, None, None, 0)
wk_sd = Beta('wk_sd', 1, None, None, 0)
tod_mean = Beta('tod_mean', 0, None, None, 0)
tod_sd = Beta('tod_sd', 1, None, None, 0)
seas_mean = Beta('seas_mean', 0, None, None, 0)
seas_sd = Beta('seas_sd', 1, None, None, 0)

# Random coefficients: mean plus standard deviation times a standard normal draw
pf_rnd = pf_mean + pf_sd * Draws('pf_rnd', 'NORMAL')
cl_rnd = cl_mean + cl_sd * Draws('cl_rnd', 'NORMAL')
loc_rnd = loc_mean + loc_sd * Draws('loc_rnd', 'NORMAL')
wk_rnd = wk_mean + wk_sd * Draws('wk_rnd', 'NORMAL')
tod_rnd = tod_mean + tod_sd * Draws('tod_rnd', 'NORMAL')
seas_rnd = seas_mean + seas_sd * Draws('seas_rnd', 'NORMAL')

choice = Variable('choice')

# Utility of each alternative: linear in the six random coefficients
V = {
    j: (pf_rnd * X['pf'][j] + cl_rnd * X['cl'][j] + loc_rnd * X['loc'][j]
        + wk_rnd * X['wk'][j] + tod_rnd * X['tod'][j] + seas_rnd * X['seas'][j])
    for j in [1,2,3,4]
}

Make the models#

jaxlogit:

[7]:
model_jax_scipy = MixedLogit()
config = ConfigData(
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-scipy",
)
model_jax_scipy.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    config=config
)
display(model_jax_scipy.summary())
init_coeff_scipy = model_jax_scipy.coeff_
    Message: Operation terminated successfully
    Iterations: 103
    Function evaluations: 112
Estimation time= 8.8 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9034407     1.0000000    -0.9034407         0.367
cl                     -0.2651879     1.0000000    -0.2651879         0.791
loc                     2.3902863     1.0000000     2.3902863         0.017 *
wk                      1.5523322     1.0000000     1.5523322         0.121
tod                    -9.0580513     1.0000000    -9.0580513      8.74e-19 ***
seas                   -9.2757488     1.0000000    -9.2757488      1.39e-19 ***
sd.pf                 -12.3754010     1.0000000   -12.3754010      1.71e-32 ***
sd.cl                  -0.6778763     1.0000000    -0.6778763         0.498
sd.loc                  0.4420251     1.0000000     0.4420251         0.659
sd.wk                   0.5413921     1.0000000     0.5413921         0.588
sd.tod                  3.0335262     1.0000000     3.0335262       0.00249 **
sd.seas                 1.5317662     1.0000000     1.5317662         0.126
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -899.391
AIC= 1822.781
BIC= 1879.878
None
[8]:
model_jax = MixedLogit()
config = ConfigData(
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
model_jax.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    config=config
)
display(model_jax.summary())
init_coeff_jax = model_jax.coeff_
**** The optimization did not converge after 4 iterations. ****
Convergence not reached. The estimates may not be reliable.
    Message: max line search iters reached
    Iterations: 4
    Function evaluations: 13
Estimation time= 4.8 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                      0.0819541     1.0000000     0.0819541         0.935
cl                     -0.3373719     1.0000000    -0.3373719         0.736
loc                     1.3096022     1.0000000     1.3096022         0.191
wk                      0.4720790     1.0000000     0.4720790         0.637
tod                     0.0489491     1.0000000     0.0489491         0.961
seas                   -0.6395835     1.0000000    -0.6395835         0.523
sd.pf                  -2.1613754     1.0000000    -2.1613754        0.0309 *
sd.cl                  -0.1939785     1.0000000    -0.1939785         0.846
sd.loc                  0.3585869     1.0000000     0.3585869          0.72
sd.wk                   0.3581432     1.0000000     0.3581432          0.72
sd.tod                  0.2222653     1.0000000     0.2222653         0.824
sd.seas                 0.1537772     1.0000000     0.1537772         0.878
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -1056.143
AIC= 2136.285
BIC= 2193.383
None

xlogit:

[9]:
model_x.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)
display(model_x.summary())
init_coeff_x = model_x.coeff_
Optimization terminated successfully.
    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 92
    Function evaluations: 101
Estimation time= 5.9 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9324873     1.0000000    -0.9324873         0.351
cl                     -0.2655834     1.0000000    -0.2655834         0.791
loc                     2.3868240     1.0000000     2.3868240        0.0172 *
wk                      1.5585239     1.0000000     1.5585239         0.119
tod                    -9.2178090     1.0000000    -9.2178090      2.27e-19 ***
seas                   -9.4389188     1.0000000    -9.4389188      3.42e-20 ***
sd.pf                   0.1227657     1.0000000     0.1227657         0.902
sd.cl                   0.4070082     1.0000000     0.4070082         0.684
sd.loc                  0.9470355     1.0000000     0.9470355         0.344
sd.wk                   1.0274017     1.0000000     1.0274017         0.305
sd.tod                  2.9194248     1.0000000     2.9194248        0.0036 **
sd.seas                 1.4444963     1.0000000     1.4444963         0.149
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -898.761
AIC= 1821.522
BIC= 1878.619
None

Biogeme:

[10]:
# Logit probability of the observed choice, conditional on the draws;
# PanelLikelihoodTrajectory multiplies probabilities over each individual's
# choice situations, and MonteCarlo integrates over the random draws
prob = models.logit(V, None, choice)
logprob = log(MonteCarlo(PanelLikelihoodTrajectory(prob)))

the_biogeme = bio.BIOGEME(
    database_train, logprob, number_of_draws=n_draws, seed=999, generate_yaml=False, generate_html=False
)
the_biogeme.model_name = 'model_b'
results = the_biogeme.estimate()
print(results)
File biogeme.toml has been created
The number of draws (500) is low. The results may not be meaningful.
Results for model model_b
Nbr of parameters:              12
Sample size:                    339
Observations:                   861
Excluded data:                  0
Final log likelihood:           -896.35
Akaike Information Criterion:   1816.7
Bayesian Information Criterion: 1862.612
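With the same 500 draws, biogeme's final log likelihood of -896.35 is close to xlogit (-898.76) and jaxlogit-scipy (-899.39), while the non-converged jaxlogit-jax run trails at -1056.14.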

Compare parameters:#

[11]:
print("{:>9} {:>20} {:>15} {:>13} {:>13}".format("Estimate", "Jaxlogit-scipy", "Jaxlogit-jax", "Xlogit", "Biogeme"))
print("-" * 76)
fmt = "{:9} {:18.7f} {:16.7f} {:15.7f} {:13.7f}"
biogeme_values = results.get_beta_values()
coeff_names = {'pf': 'pf_mean', 'sd.pf': 'pf_sd', 'cl': 'cl_mean', 'sd.cl': 'cl_sd', 'loc': 'loc_mean', 'sd.loc': 'loc_sd', 'wk': 'wk_mean', 'sd.wk': 'wk_sd', 'tod': 'tod_mean', 'sd.tod': 'tod_sd', 'seas': 'seas_mean', 'sd.seas': 'seas_sd'}
for i in range(len(model_jax.coeff_)):
    name = model_jax.coeff_names[i]
    print(fmt.format(name[:13],
                     model_jax_scipy.coeff_[i],
                     model_jax.coeff_[i],
                     model_x.coeff_[i],
                     biogeme_values[coeff_names[name]]))
print("-" * 76)
 Estimate       Jaxlogit-scipy    Jaxlogit-jax        Xlogit       Biogeme
----------------------------------------------------------------------------
pf                -0.9034407        0.0819541      -0.9324873    -0.9290651
cl                -0.2651879       -0.3373719      -0.2655834    -0.2658865
loc                2.3902863        1.3096022       2.3868240     2.3662199
wk                 1.5523322        0.4720790       1.5585239     1.5556951
tod               -9.0580513        0.0489491      -9.2178090    -9.2291964
seas              -9.2757488       -0.6395835      -9.4389188    -9.4137359
sd.pf            -12.3754010       -2.1613754       0.1227657     0.1298152
sd.cl             -0.6778763       -0.1939785       0.4070082     0.4029175
sd.loc             0.4420251        0.3585869       0.9470355     1.0222264
sd.wk              0.5413921        0.3581432       1.0274017     0.9232866
sd.tod             3.0335262        0.2222653       2.9194248     2.9030425
sd.seas            1.5317662        0.1537772       1.4444963     1.3707404
----------------------------------------------------------------------------
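Two things to keep in mind when reading this table: the sign of the sd parameters is not identified in a mixed logit with normally distributed coefficients (the likelihood depends only on their absolute values), so opposite signs across columns are not by themselves meaningful; and the jaxlogit-jax column differs throughout because, as reported above, its optimizer stopped after 4 iterations without converging.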

Predict#

jaxlogit:

[12]:
model = model_jax_scipy
config = ConfigData(
    panels=panels_test,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-scipy",
)
config.init_coeff = init_coeff_scipy
prob_j_scipy = model.predict(X_test, varnames, alts_test, ids_test, randvars, config)
[13]:
model = model_jax
config = ConfigData(
    panels=panels_test,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
config.init_coeff = init_coeff_jax
prob_j_jax = model.predict(X_test, varnames, alts_test, ids_test, randvars, config)

xlogit:

[14]:
_, prob_xx = model_x.predict(X_test, varnames, alts_test, ids_test, isvars=None, panels=panels_test, n_draws=n_draws, return_proba=True)

Biogeme:

[15]:
P = {
    j: MonteCarlo(models.logit(V, None, j))
    for j in [1, 2, 3, 4]
}

simulate = {
    f'Prob_alt{j}': P[j]
    for j in [1, 2, 3, 4]
}

biogeme_sim = bio.BIOGEME(database_test, simulate)
biogeme_sim.model_name = 'per_choice_probs'

probs = biogeme_sim.simulate(results.get_beta_values())

Compute the Brier score on the test set:

[16]:
print("{:>9} {:>9} {:>9} {:>9}".format("Jaxlogit-scipy", "Jaxlogit-jax", "xlogit", "Biogeme"))
print("-" * 48)
fmt = "{:9f} {:9f} {:9f} {:9f}"
print(fmt.format(sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_j_scipy.shape[0], -1)), prob_j_scipy),
                 sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_j_jax.shape[0], -1)), prob_j_jax),
                 sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_xx.shape[0], -1)), prob_xx),
                 sklearn.metrics.brier_score_loss(df_wide_test['choice'], probs)))
print("-" * 48)
Jaxlogit-scipy Jaxlogit-jax    xlogit   Biogeme
------------------------------------------------
 0.630974  0.716454  0.630776  0.630947
------------------------------------------------
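As the parameter comparison suggested, the three converged models produce nearly identical Brier scores on the test set, while the non-converged jaxlogit-jax run is clearly worse.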