Mixed Logit#
Based on the xlogit example Mixed Logit.
[1]:
import pandas as pd
import numpy as np
import jax
from jaxlogit.mixed_logit import MixedLogit, ConfigData
[2]:
# 64bit precision
jax.config.update("jax_enable_x64", True)
Electricity Dataset#
The electricity dataset contains 4,308 choices among four electricity suppliers based on the attributes of the offered plans, which include prices (pf), contract lengths (cl), time-of-day rates (tod), and seasonal rates (seas), as well as attributes of the suppliers, which include whether the supplier is local (loc) and well-known (wk). The data was collected through a survey in which 12 different choice situations were presented to each participant. The multiple responses per participant were organized into panels. Given that some participants answered fewer than 12 of the choice situations, some panels are unbalanced, which jaxlogit is able to handle. Revelt and Train (1999) provide a detailed description of this dataset.
Read data#
The dataset is already in long format, so no reshaping is necessary and it can be used directly in jaxlogit.
[3]:
df = pd.read_csv("https://raw.githubusercontent.com/outerl/jaxlogit/refs/heads/main/examples/electricity_long.csv")
df
[3]:
| | choice | id | alt | pf | cl | loc | wk | tod | seas | chid |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 1 | 7 | 5 | 0 | 1 | 0 | 0 | 1 |
| 1 | 0 | 1 | 2 | 9 | 1 | 1 | 0 | 0 | 0 | 1 |
| 2 | 0 | 1 | 3 | 0 | 0 | 0 | 0 | 0 | 1 | 1 |
| 3 | 1 | 1 | 4 | 0 | 5 | 0 | 1 | 1 | 0 | 1 |
| 4 | 0 | 1 | 1 | 7 | 0 | 0 | 1 | 0 | 0 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 17227 | 0 | 361 | 4 | 0 | 1 | 1 | 0 | 0 | 1 | 4307 |
| 17228 | 1 | 361 | 1 | 9 | 0 | 0 | 1 | 0 | 0 | 4308 |
| 17229 | 0 | 361 | 2 | 7 | 0 | 0 | 0 | 0 | 0 | 4308 |
| 17230 | 0 | 361 | 3 | 0 | 1 | 0 | 1 | 0 | 1 | 4308 |
| 17231 | 0 | 361 | 4 | 0 | 5 | 1 | 0 | 1 | 0 | 4308 |
17232 rows × 10 columns
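Whether the panel is balanced can be checked by counting the choice situations each respondent answered. A minimal sketch on synthetic data (hypothetical values); the same `groupby` works directly on the electricity dataframe as `df.groupby('id')['chid'].nunique()`:

```python
import pandas as pd

# Synthetic long-format data mirroring the electricity layout (hypothetical
# values): respondent 1 answers two choice situations, respondent 2 only one,
# with four alternatives per situation.
toy = pd.DataFrame({
    "id":   [1] * 8 + [2] * 4,
    "chid": [1] * 4 + [2] * 4 + [3] * 4,
    "alt":  [1, 2, 3, 4] * 3,
})

# Choice situations answered per respondent (panel); unequal counts mean
# the panel is unbalanced.
situations_per_panel = toy.groupby("id")["chid"].nunique()
print(situations_per_panel.to_dict())  # {1: 2, 2: 1}
```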
Fit the model#
Note that the panels parameter was included in the ConfigData configuration in order to account for the panel structure of this dataset during estimation.
[4]:
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
model = MixedLogit()
config = ConfigData(
n_draws=600,
panels=df['id'],
)
res = model.fit(
X=df[varnames],
y=df['choice'],
varnames=varnames,
ids=df['chid'],
alts=df['alt'],
randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
config=config
)
model.summary()
Message: Operation terminated successfully
Iterations: 80
Function evaluations: 97
Estimation time= 38.2 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
pf -0.9972244 0.0552512 -18.0489244 2.94e-70 ***
cl -0.2196763 0.0298504 -7.3592361 2.2e-13 ***
loc 2.2901926 0.1375747 16.6469044 2.36e-60 ***
wk 1.6943196 0.1103646 15.3520167 7.91e-52 ***
tod -9.6753913 0.4904010 -19.7295489 5.04e-83 ***
seas -9.6962087 0.4876893 -19.8819399 3.14e-84 ***
sd.pf -1.3984445 0.1250274 -11.1851049 1.19e-28 ***
sd.cl -0.6750223 0.0933470 -7.2313201 5.63e-13 ***
sd.loc 1.6001551 0.1526084 10.4853699 2.04e-25 ***
sd.wk 0.8837811 0.1486714 5.9445258 2.99e-09 ***
sd.tod 2.1673560 0.2205245 9.8281855 1.47e-22 ***
sd.seas 1.2295732 0.2222076 5.5334435 3.33e-08 ***
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -3888.465
AIC= 7800.930
BIC= 7877.349
[5]:
# Note: the sd. parameters in jaxlogit are softplus-transformed by default so that they are always positive.
# To compare them with xlogit's results at https://github.com/arteagac/xlogit/blob/master/examples/mixed_logit_model.ipynb,
# apply jax.nn.softplus to the raw sd. parameters, or refit without the transform:
model = MixedLogit()
config = ConfigData(
force_positive_chol_diag=False,
panels=df['id'],
n_draws=600,
)
res = model.fit(
X=df[varnames],
y=df['choice'],
varnames=varnames,
ids=df['chid'],
alts=df['alt'],
randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
config=config
)
model.summary()
Message: Operation terminated successfully
Iterations: 79
Function evaluations: 88
Estimation time= 32.7 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
pf -0.9972195 0.0552517 -18.0486708 2.95e-70 ***
cl -0.2196723 0.0298517 -7.3587774 2.21e-13 ***
loc 2.2902120 0.1375768 16.6467857 2.37e-60 ***
wk 1.6943254 0.1103640 15.3521591 7.89e-52 ***
tod -9.6753001 0.4904065 -19.7291442 5.07e-83 ***
seas -9.6963453 0.4877121 -19.8812892 3.18e-84 ***
sd.pf 0.2207239 0.0247733 8.9097514 7.41e-19 ***
sd.cl 0.4115687 0.0314965 13.0671437 2.69e-38 ***
sd.loc 1.7840564 0.1269839 14.0494644 7.12e-44 ***
sd.wk 1.2296080 0.1052025 11.6880097 4.31e-31 ***
sd.tod 2.2757668 0.1978899 11.5001677 3.61e-30 ***
sd.seas 1.4864533 0.1719911 8.6426148 7.63e-18 ***
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -3888.465
AIC= 7800.930
BIC= 7877.349
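Both fits describe the same model (note the identical log-likelihood, AIC, and BIC): the first estimates the standard deviations through a softplus transform, the second directly. As a cross-check, applying softplus to the raw sd. estimates from the first summary recovers the directly estimated values of the second, up to optimizer tolerance. A minimal sketch with a pure-Python softplus (the raw values are copied from the first summary above):

```python
import math

def softplus(x):
    # softplus(x) = log(1 + exp(x)); the guard avoids overflow for large x.
    return math.log1p(math.exp(x)) if x < 30 else x

# Raw sd. estimates from the first (softplus-parameterized) fit.
raw_sd = {"pf": -1.3984445, "cl": -0.6750223, "loc": 1.6001551,
          "wk": 0.8837811, "tod": 2.1673560, "seas": 1.2295732}

for name, val in raw_sd.items():
    # Each value maps onto the corresponding sd. in the second summary,
    # e.g. sd.pf -> ~0.2207, up to optimizer tolerance.
    print(f"sd.{name}: {softplus(val):.4f}")
```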
Fishing Dataset#
This example illustrates the estimation of a Mixed Logit model for the sport fishing mode choices of 1,182 individuals using jaxlogit. The goal is to analyse the market shares of four alternatives (i.e., beach, pier, boat, and charter) based on their cost and fish catch rate. Cameron and Trivedi (2005) provide additional details about this dataset. The following code illustrates how to use jaxlogit to estimate the model parameters.
Read data#
The data to be analyzed can be imported to Python using any preferred method.
[6]:
import pandas as pd
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/fishing_long.csv")
df
[6]:
| | id | alt | choice | income | price | catch |
|---|---|---|---|---|---|---|
| 0 | 1 | beach | 0 | 7083.33170 | 157.930 | 0.0678 |
| 1 | 1 | boat | 0 | 7083.33170 | 157.930 | 0.2601 |
| 2 | 1 | charter | 1 | 7083.33170 | 182.930 | 0.5391 |
| 3 | 1 | pier | 0 | 7083.33170 | 157.930 | 0.0503 |
| 4 | 2 | beach | 0 | 1249.99980 | 15.114 | 0.1049 |
| ... | ... | ... | ... | ... | ... | ... |
| 4723 | 1181 | pier | 0 | 416.66668 | 36.636 | 0.4522 |
| 4724 | 1182 | beach | 0 | 6250.00130 | 339.890 | 0.2537 |
| 4725 | 1182 | boat | 1 | 6250.00130 | 235.436 | 0.6817 |
| 4726 | 1182 | charter | 0 | 6250.00130 | 260.436 | 2.3014 |
| 4727 | 1182 | pier | 0 | 6250.00130 | 339.890 | 0.1498 |
4728 rows × 6 columns
Fit model#
Once the data is in the Python environment, jaxlogit can be used to fit the model, as shown below. The MixedLogit class is imported from jaxlogit, and its constructor is used to initialise a new model. The fit method estimates the model using the input data and estimation criteria provided as arguments to the method’s call. The arguments of the fit method are described in jaxlogit’s documentation.
[7]:
varnames = ['price', 'catch']
model = MixedLogit()
config = ConfigData(
n_draws=2000, # Note: with 1000 draws sd.catch collapses to zero; more draws are needed to find the minimum at a positive standard deviation
)
model.fit(
X=df[varnames],
y=df['choice'],
varnames=varnames,
alts=df['alt'],
ids=df['id'],
randvars={'price': 'n', 'catch': 'n'},
config=config
)
model.summary()
Message: Operation terminated successfully
Iterations: 40
Function evaluations: 59
Estimation time= 10.3 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
price -0.0272479 0.0021055 -12.9413936 6.38e-36 ***
catch 1.3258114 0.1838717 7.2105248 9.92e-13 ***
sd.price -4.5741734 0.2091743 -21.8677566 2.74e-89 ***
sd.catch 1.3316619 0.3791517 3.5122142 0.000461 ***
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -1300.582
AIC= 2609.163
BIC= 2629.463
[8]:
# The sd. values agree with the xlogit results except for the sign of sd.catch. This is because xlogit
# does not restrict the standard deviations to be positive, and the log-likelihood is symmetric with
# respect to the sign of a normal standard deviation for uncorrelated parameters.
jax.nn.softplus(model.coeff_[len(model._rvidx):])
[8]:
Array([0.01026199, 1.56597334], dtype=float64)
Car Dataset#
This example uses a stated preference panel dataset for choice of car. Three alternatives are considered, with up to 6 choice situations per individual. This is again an unbalanced panel, as some individuals responded to fewer than 6 situations. The dataset contains 8 explanatory variables: price, operating cost, range, and binary indicators for whether the car is electric (ev), gas-powered (gas), or hybrid (hybrid), and for its performance level (hiperf, medhiperf). This dataset was taken from Kenneth Train’s MATLAB codes for estimating Mixed Logit models, available at: https://eml.berkeley.edu/Software/abstracts/train1006mxlmsl.html
Read data#
[9]:
import pandas as pd
import numpy as np
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/car100_long.csv")
Since price and operating cost need to be estimated with negative coefficients, we reverse the variable signs in the dataframe.
[10]:
df['price'] = -df['price']/10000
df['opcost'] = -df['opcost']
df
[10]:
| | person_id | choice_id | alt | choice | price | opcost | range | ev | gas | hybrid | hiperf | medhiperf |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 1 | 0 | -4.6763 | -47.43 | 0.0 | 0 | 0 | 1 | 0 | 0 |
| 1 | 1 | 1 | 2 | 1 | -5.7209 | -27.43 | 1.3 | 1 | 0 | 0 | 1 | 1 |
| 2 | 1 | 1 | 3 | 0 | -8.7960 | -32.41 | 1.2 | 1 | 0 | 0 | 0 | 1 |
| 3 | 1 | 2 | 1 | 1 | -3.3768 | -4.89 | 1.3 | 1 | 0 | 0 | 1 | 1 |
| 4 | 1 | 2 | 2 | 0 | -9.0336 | -30.19 | 0.0 | 0 | 0 | 1 | 0 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4447 | 100 | 1483 | 2 | 0 | -2.8036 | -14.45 | 1.6 | 1 | 0 | 0 | 0 | 0 |
| 4448 | 100 | 1483 | 3 | 0 | -1.9360 | -54.76 | 0.0 | 0 | 1 | 0 | 1 | 1 |
| 4449 | 100 | 1484 | 1 | 1 | -2.4054 | -50.57 | 0.0 | 0 | 1 | 0 | 0 | 0 |
| 4450 | 100 | 1484 | 2 | 0 | -5.2795 | -21.25 | 0.0 | 0 | 0 | 1 | 0 | 1 |
| 4451 | 100 | 1484 | 3 | 0 | -6.0705 | -25.41 | 1.4 | 1 | 0 | 0 | 0 | 0 |
4452 rows × 12 columns
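The sign reversal above works because a lognormal ('ln') random coefficient is exp(mu + sigma * z) and therefore strictly positive; estimating it on the negated variable is equivalent to a strictly negative coefficient on the original. A minimal sketch with hypothetical values:

```python
import numpy as np

rng = np.random.default_rng(1)

# A 'ln' (lognormal) random coefficient is exp(mu + sigma * z) with
# z ~ N(0, 1), so every draw is strictly positive regardless of mu.
mu, sigma = -0.74, 0.45   # hypothetical values, for illustration only
z = rng.normal(size=10_000)
beta_price = np.exp(mu + sigma * z)
print(beta_price.min() > 0)  # True: a lognormal coefficient cannot be negative

# A positive coefficient on -price contributes the same utility as a
# negative coefficient on price, which is why the signs were flipped
# in the dataframe above.
price = np.array([4.6763, 5.7209, 8.7960])
assert np.allclose(beta_price[0] * (-price), -(beta_price[0] * price))
```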
Fit the model#
[11]:
varnames = ['hiperf', 'medhiperf', 'price', 'opcost', 'range', 'ev', 'hybrid']
model = MixedLogit()
config = ConfigData(
n_draws = 1000,
panels=df['person_id'],
)
model.fit(
X=df[varnames],
y=df['choice'],
varnames=varnames,
alts=df['alt'],
ids=df['choice_id'],
randvars = {'price': 'ln', 'opcost': 'n', 'range': 'ln', 'ev':'n', 'hybrid': 'n'},
config=config
)
model.summary()
Message: Operation terminated successfully
Iterations: 86
Function evaluations: 103
Estimation time= 26.0 seconds
---------------------------------------------------------------------------
Coefficient Estimate Std.Err. z-val P>|z|
---------------------------------------------------------------------------
hiperf 0.1058218 0.1088161 0.9724833 0.331
medhiperf 0.5712811 0.1120844 5.0968819 3.9e-07 ***
price -0.7406271 0.1584242 -4.6749617 3.21e-06 ***
opcost 0.0119904 0.0055115 2.1755190 0.0297 *
range -0.6709790 0.4198036 -1.5983165 0.11
ev -1.5937473 0.3743450 -4.2574291 2.2e-05 ***
hybrid 0.7055276 0.1814268 3.8887717 0.000105 ***
sd.price 0.4546790 0.1594400 2.8517246 0.00441 **
sd.opcost -3.2384613 0.1451508 -22.3110177 2.53e-95 ***
sd.range -0.2468878 0.2763423 -0.8934131 0.372
sd.ev 0.5359967 0.3389797 1.5812060 0.114
sd.hybrid 0.1024516 0.3019864 0.3392589 0.734
---------------------------------------------------------------------------
Significance: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Log-Likelihood= -1298.176
AIC= 2620.351
BIC= 2683.981
Apply softplus to the standard deviations to make them positive#
[12]:
jax.nn.softplus(model.coeff_[len(model._rvidx):])
[12]:
Array([0.94610873, 0.03847448, 0.57730321, 0.99663525, 0.74568443], dtype=float64)
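As with the electricity example, these values can be reproduced without JAX by applying softplus to the raw sd. estimates from the summary above; a minimal pure-Python cross-check:

```python
import math

def softplus(x):
    # softplus(x) = log(1 + exp(x))
    return math.log1p(math.exp(x))

# Raw sd. estimates from the summary, in randvars order
# (price, opcost, range, ev, hybrid).
raw = [0.4546790, -3.2384613, -0.2468878, 0.5359967, 0.1024516]
print([round(softplus(v), 4) for v in raw])
# [0.9461, 0.0385, 0.5773, 0.9966, 0.7457]
```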
References#
Bierlaire, M. (2018). PandasBiogeme: a short introduction. EPFL (Transport and Mobility Laboratory, ENAC).
Brathwaite, T., & Walker, J. L. (2018). Asymmetric, closed-form, finite-parameter models of multinomial choice. Journal of Choice Modelling, 29, 78–112.
Cameron, A. C., & Trivedi, P. K. (2005). Microeconometrics: methods and applications. Cambridge university press.
Croissant, Y. (2020). Estimation of Random Utility Models in R: The mlogit Package. Journal of Statistical Software, 95(1), 1–41.
Revelt, D., & Train, K. (1999). Customer-specific taste parameters and mixed logit. University of California, Berkeley.