Comparison of time taken and Brier scores for jaxlogit, xlogit, and biogeme#
Estimation uses 500 draws (suboptimal, but the highest count that does not run Biogeme out of memory on this data set), and the data is split into separate training and test sets. Times are in seconds except Biogeme's model-construction time, shown as minutes:seconds; figures are from a representative run and vary between runs.
|   | jaxlogit-scipy | jaxlogit-jax | xlogit | biogeme |
|---|---|---|---|---|
| Making Model | 33.1s | 22.2s | 18.5s | 4:30 |
| Estimating | 1.6s | 0.2s | 0.0s | 14.3s |
| Brier Score | 0.624247 | 0.624247 | 0.624570 | 0.624163 |
Setup#
[1]:
import pandas as pd
import numpy as np
import jax
import pathlib
import xlogit
import sklearn.metrics
import sklearn.model_selection
from jaxlogit.mixed_logit import MixedLogit, ConfigData
from jaxlogit.utils import wide_to_long
import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
from biogeme import models
from biogeme.expressions import Beta, Draws, log, MonteCarlo, PanelLikelihoodTrajectory, Variable
import biogeme.database as db
logger = blog.get_screen_logger()
logger.setLevel(blog.INFO)
# Enable 64-bit floating point precision in JAX
jax.config.update("jax_enable_x64", True)
Set the variable names and number of draws used by all three packages. Adjusting n_draws can improve accuracy, but Biogeme cannot handle 700 or more draws with this data set.
[2]:
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
n_draws = 500
Reshape the data into wide format so that it can be passed to train_test_split. xlogit and jaxlogit require long format, while Biogeme requires wide format, so the training and test sets are converted back to long format after the split. A sketch of the two layouts is shown below.
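One choice situation in each layout (values are illustrative, not from the data set):

Long format (electricity_long.csv): one row per (choice situation, alternative); choice is 1 for the chosen alternative.

id  chid  alt  pf   ...  choice
1   1     1    0.5  ...  0
1   1     2    0.7  ...  1
1   1     3    0.3  ...  0
1   1     4    0.6  ...  0

Wide format (after pivoting): one row per choice situation, one column per (variable, alternative) pair; here choice holds the label of the chosen alternative.

id  chid  pf_1  pf_2  pf_3  pf_4  ...  choice
1   1     0.5   0.7   0.3   0.6   ...  2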
[3]:
df_long = pd.read_csv(pathlib.Path.cwd().parent.parent / "examples" / "electricity_long.csv")
# Record the label of the chosen alternative for each choice situation
choice_df = df_long.loc[df_long['choice'] == 1, ['id', 'chid', 'alt']]
choice_df = choice_df.rename(columns={'alt': 'choice'})
# Pivot to wide format: one column per (variable, alternative) pair
df_wide = df_long.pivot(index=['id', 'chid'], columns='alt', values=varnames)
df_wide.columns = [f'{var}_{alt}' for var, alt in df_wide.columns]
df_wide = df_wide.reset_index()
df = df_wide.merge(
    choice_df,
    on=['id', 'chid'],
    how='inner',
    validate='one_to_one'
)
df_wide_train, df_wide_test = sklearn.model_selection.train_test_split(df, train_size=0.2)
df_train = wide_to_long(df_wide_train, "chid", [1,2,3,4], "alt", varying=varnames, panels=True)
df_train = df_train.sort_values(['chid', 'alt'])
df_test = wide_to_long(df_wide_test, "chid", [1,2,3,4], "alt", varying=varnames, panels=True)
df_test = df_test.sort_values(['chid', 'alt'])
df_wide_train = df_wide_train.sort_values('chid')
# Biogeme reads the wide-format data directly; panel('id') declares the panel structure
database_train = db.Database('electricity', df_wide_train)
database_train.panel('id')
database_test = db.Database('electricity', df_wide_test)
jaxlogit and xlogit setup:
[4]:
X_train = df_train[varnames]
y_train = df_train['choice']
ids_train = df_train['chid']
alts_train = df_train['alt']
panels_train = df_train['id']
X_test = df_test[varnames]
y_test = df_test['choice']
ids_test = df_test['chid']
alts_test = df_test['alt']
panels_test = df_test['id']
[5]:
randvars = {'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}
model_jax = MixedLogit()
model_x = xlogit.MixedLogit()
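Here 'n' specifies a normally distributed coefficient for each variable, mirroring the mean-plus-scaled-normal Draws specification built for Biogeme below.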
Biogeme setup:
[6]:
X = {
    name: {
        j: Variable(f"{name}_{j}")
        for j in [1, 2, 3, 4]
    }
    for name in varnames
}
# Alternative-specific constants; alt_4 is fixed as the reference (note: these do not enter V below)
alt_1 = Beta('alt_1', 0, None, None, 0)
alt_2 = Beta('alt_2', 0, None, None, 0)
alt_3 = Beta('alt_3', 0, None, None, 0)
alt_4 = Beta('alt_4', 0, None, None, 1)
# Mean and standard deviation parameters for each random coefficient
pf_mean = Beta('pf_mean', 0, None, None, 0)
pf_sd = Beta('pf_sd', 1, None, None, 0)
cl_mean = Beta('cl_mean', 0, None, None, 0)
cl_sd = Beta('cl_sd', 1, None, None, 0)
loc_mean = Beta('loc_mean', 0, None, None, 0)
loc_sd = Beta('loc_sd', 1, None, None, 0)
wk_mean = Beta('wk_mean', 0, None, None, 0)
wk_sd = Beta('wk_sd', 1, None, None, 0)
tod_mean = Beta('tod_mean', 0, None, None, 0)
tod_sd = Beta('tod_sd', 1, None, None, 0)
seas_mean = Beta('seas_mean', 0, None, None, 0)
seas_sd = Beta('seas_sd', 1, None, None, 0)
# Random coefficients: mean plus standard deviation times a standard normal draw
pf_rnd = pf_mean + pf_sd * Draws('pf_rnd', 'NORMAL')
cl_rnd = cl_mean + cl_sd * Draws('cl_rnd', 'NORMAL')
loc_rnd = loc_mean + loc_sd * Draws('loc_rnd', 'NORMAL')
wk_rnd = wk_mean + wk_sd * Draws('wk_rnd', 'NORMAL')
tod_rnd = tod_mean + tod_sd * Draws('tod_rnd', 'NORMAL')
seas_rnd = seas_mean + seas_sd * Draws('seas_rnd', 'NORMAL')
choice = Variable('choice')
V = {
    j: pf_rnd * X['pf'][j] + cl_rnd * X['cl'][j] + loc_rnd * X['loc'][j]
       + wk_rnd * X['wk'][j] + tod_rnd * X['tod'][j] + seas_rnd * X['seas'][j]
    for j in [1, 2, 3, 4]
}
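In standard mixed logit notation, the utility specified above for individual $n$, choice situation $t$, and alternative $j$ is

$$V_{ntj} = \sum_{k} \beta_{nk}\, x_{ntjk}, \qquad \beta_{nk} = \mu_k + \sigma_k\, \xi_{nk}, \quad \xi_{nk} \sim N(0, 1),$$

where each $\xi_{nk}$ is a standard normal draw held fixed across all choice situations of the same individual once the panel structure is declared.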
Make the models#
Jaxlogit:
[7]:
model_jax_scipy = MixedLogit()
config = ConfigData(
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-scipy",
)
model_jax_scipy.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    config=config,
)
model_jax_scipy.summary()
init_coeff_scipy = model_jax_scipy.coeff_
Message: Operation terminated successfully
Iterations: 103
Function evaluations: 112
Estimation time= 8.8 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
pf -0.9034407 1.0000000 -0.9034407 0.367
cl -0.2651879 1.0000000 -0.2651879 0.791
loc 2.3902863 1.0000000 2.3902863 0.017 *
wk 1.5523322 1.0000000 1.5523322 0.121
tod -9.0580513 1.0000000 -9.0580513 8.74e-19 ***
seas -9.2757488 1.0000000 -9.2757488 1.39e-19 ***
sd.pf -12.3754010 1.0000000 -12.3754010 1.71e-32 ***
sd.cl -0.6778763 1.0000000 -0.6778763 0.498
sd.loc 0.4420251 1.0000000 0.4420251 0.659
sd.wk 0.5413921 1.0000000 0.5413921 0.588
sd.tod 3.0335262 1.0000000 3.0335262 0.00249 **
sd.seas 1.5317662 1.0000000 1.5317662 0.126
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -899.391
AIC= 1822.781
BIC= 1879.878
[8]:
model_jax = MixedLogit()
config = ConfigData(
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
model_jax.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    config=config,
)
model_jax.summary()
init_coeff_jax = model_jax.coeff_
**** The optimization did not converge after 4 iterations. ****
Convergence not reached. The estimates may not be reliable.
Message: max line search iters reached
Iterations: 4
Function evaluations: 13
Estimation time= 4.8 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
pf 0.0819541 1.0000000 0.0819541 0.935
cl -0.3373719 1.0000000 -0.3373719 0.736
loc 1.3096022 1.0000000 1.3096022 0.191
wk 0.4720790 1.0000000 0.4720790 0.637
tod 0.0489491 1.0000000 0.0489491 0.961
seas -0.6395835 1.0000000 -0.6395835 0.523
sd.pf -2.1613754 1.0000000 -2.1613754 0.0309 *
sd.cl -0.1939785 1.0000000 -0.1939785 0.846
sd.loc 0.3585869 1.0000000 0.3585869 0.72
sd.wk 0.3581432 1.0000000 0.3581432 0.72
sd.tod 0.2222653 1.0000000 0.2222653 0.824
sd.seas 0.1537772 1.0000000 0.1537772 0.878
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -1056.143
AIC= 2136.285
BIC= 2193.383
xlogit:
[9]:
model_x.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)
model_x.summary()
init_coeff_x = model_x.coeff_
Optimization terminated successfully.
Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
Iterations: 92
Function evaluations: 101
Estimation time= 5.9 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
pf -0.9324873 1.0000000 -0.9324873 0.351
cl -0.2655834 1.0000000 -0.2655834 0.791
loc 2.3868240 1.0000000 2.3868240 0.0172 *
wk 1.5585239 1.0000000 1.5585239 0.119
tod -9.2178090 1.0000000 -9.2178090 2.27e-19 ***
seas -9.4389188 1.0000000 -9.4389188 3.42e-20 ***
sd.pf 0.1227657 1.0000000 0.1227657 0.902
sd.cl 0.4070082 1.0000000 0.4070082 0.684
sd.loc 0.9470355 1.0000000 0.9470355 0.344
sd.wk 1.0274017 1.0000000 1.0274017 0.305
sd.tod 2.9194248 1.0000000 2.9194248 0.0036 **
sd.seas 1.4444963 1.0000000 1.4444963 0.149
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -898.761
AIC= 1821.522
BIC= 1878.619
Biogeme:
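The expression log(MonteCarlo(PanelLikelihoodTrajectory(prob))) below computes the simulated panel log-likelihood contribution of each individual: the product of logit probabilities over that individual's choice situations, averaged over the $R$ draws,

$$\mathcal{LL}_n = \log \frac{1}{R} \sum_{r=1}^{R} \prod_{t} P_{nt}\!\left(j^{*}_{nt} \mid \beta_n^{(r)}\right),$$

where $j^{*}_{nt}$ is the alternative chosen by individual $n$ in situation $t$.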
[10]:
prob = models.logit(V, None, choice)
logprob = log(MonteCarlo(PanelLikelihoodTrajectory(prob)))
the_biogeme = bio.BIOGEME(
    database_train, logprob, number_of_draws=n_draws, seed=999, generate_yaml=False, generate_html=False
)
the_biogeme.model_name = 'model_b'
results = the_biogeme.estimate()
print(results)
File biogeme.toml has been created
The number of draws (500) is low. The results may not be meaningful.
Results for model model_b
Nbr of parameters: 12
Sample size: 339
Observations: 861
Excluded data: 0
Final log likelihood: -896.35
Akaike Information Criterion: 1816.7
Bayesian Information Criterion: 1862.612
Compare parameters#
[11]:
print("{:>9} {:>20} {:>15} {:>13} {:>13}".format("Estimate", "Jaxlogit-scipy", "Jaxlogit-jax", "Xlogit", "Biogeme"))
print("-" * 76)
fmt = "{:9} {:18.7f} {:16.7f} {:15.7f} {:13.7f}"
biogeme_values = results.get_beta_values()
# Map jaxlogit/xlogit coefficient names to the corresponding Biogeme parameter names
coeff_names = {'pf': 'pf_mean', 'sd.pf': 'pf_sd', 'cl': 'cl_mean', 'sd.cl': 'cl_sd', 'loc': 'loc_mean', 'sd.loc': 'loc_sd', 'wk': 'wk_mean', 'sd.wk': 'wk_sd', 'tod': 'tod_mean', 'sd.tod': 'tod_sd', 'seas': 'seas_mean', 'sd.seas': 'seas_sd'}
for i in range(len(model_jax.coeff_)):
    name = model_jax.coeff_names[i]
    print(fmt.format(name[:13],
                     model_jax_scipy.coeff_[i],
                     model_jax.coeff_[i],
                     model_x.coeff_[i],
                     biogeme_values[coeff_names[name]]))
print("-" * 76)
Estimate Jaxlogit-scipy Jaxlogit-jax Xlogit Biogeme
----------------------------------------------------------------------------
pf -0.9034407 0.0819541 -0.9324873 -0.9290651
cl -0.2651879 -0.3373719 -0.2655834 -0.2658865
loc 2.3902863 1.3096022 2.3868240 2.3662199
wk 1.5523322 0.4720790 1.5585239 1.5556951
tod -9.0580513 0.0489491 -9.2178090 -9.2291964
seas -9.2757488 -0.6395835 -9.4389188 -9.4137359
sd.pf -12.3754010 -2.1613754 0.1227657 0.1298152
sd.cl -0.6778763 -0.1939785 0.4070082 0.4029175
sd.loc 0.4420251 0.3585869 0.9470355 1.0222264
sd.wk 0.5413921 0.3581432 1.0274017 0.9232866
sd.tod 3.0335262 0.2222653 2.9194248 2.9030425
sd.seas 1.5317662 0.1537772 1.4444963 1.3707404
----------------------------------------------------------------------------
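The sign of a standard deviation parameter is not identified (a normal distribution with scale $-\sigma$ is identical to one with scale $\sigma$), so the sd rows should be compared by magnitude. The jaxlogit-jax column differs throughout because that run stopped before converging, as noted in its output above.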
Predict#
jaxlogit:
[12]:
model = model_jax_scipy
config = ConfigData(
    panels=panels_test,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-scipy",
)
# Predict using the coefficients estimated above
config.init_coeff = init_coeff_scipy
prob_j_scipy = model.predict(X_test, varnames, alts_test, ids_test, randvars, config)
[13]:
model = model_jax
config = ConfigData(
    panels=panels_test,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
# Predict using the coefficients estimated above
config.init_coeff = init_coeff_jax
prob_j_jax = model.predict(X_test, varnames, alts_test, ids_test, randvars, config)
xlogit:
[14]:
_, prob_xx = model_x.predict(X_test, varnames, alts_test, ids_test, isvars=None, panels=panels_test, n_draws=n_draws, return_proba=True)
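With return_proba=True, xlogit's predict returns both the predicted choices and the matrix of per-alternative probabilities; only the probabilities are kept here.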
Biogeme:
[15]:
P = {
    j: MonteCarlo(models.logit(V, None, j))
    for j in [1, 2, 3, 4]
}
simulate = {
    f'Prob_alt{j}': P[j]
    for j in [1, 2, 3, 4]
}
biogeme_sim = bio.BIOGEME(database_test, simulate)
biogeme_sim.model_name = 'per_choice_probs'
probs = biogeme_sim.simulate(results.get_beta_values())
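biogeme_sim.simulate evaluates the four per-alternative probability expressions at the estimated parameter values and returns a data frame with one probability column per alternative.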
Compute the Brier score:
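The Brier score compares each predicted alternative probability against the indicator of the alternative actually chosen; lower is better:

$$\mathrm{BS} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{J} \left( p_{ij} - y_{ij} \right)^2,$$

with $N$ choice situations, $J = 4$ alternatives, and $y_{ij} = 1$ if alternative $j$ was chosen in situation $i$ and $0$ otherwise.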
[16]:
print("{:>9} {:>9} {:>9} {:>9}".format("Jaxlogit-scipy", "Jaxlogit-jax", "xlogit", "Biogeme"))
print("-" * 48)
fmt = "{:9f} {:9f} {:9f} {:9f}"
print(fmt.format(sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_j_scipy.shape[0], -1)), prob_j_scipy),
                 sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_j_jax.shape[0], -1)), prob_j_jax),
                 sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_xx.shape[0], -1)), prob_xx),
                 sklearn.metrics.brier_score_loss(df_wide_test['choice'], probs)))
print("-" * 48)
Jaxlogit-scipy Jaxlogit-jax xlogit Biogeme
------------------------------------------------
0.630974 0.716454 0.630776 0.630947
------------------------------------------------