Mixed Logit with correlations#

This example uses the swissmetro data and compares results to Biogeme (Bierlaire, M. (2018). PandasBiogeme: a short introduction. EPFL, Transport and Mobility Laboratory, ENAC). It is based on the xlogit Mixed Logit example.

[1]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit, ConfigData

# 64bit precision
jax.config.update("jax_enable_x64", True)

Swissmetro Dataset#

The swissmetro dataset contains stated preferences for three alternative transportation modes: car, train, and a newly introduced mode, the swissmetro. This dataset is commonly used for estimation examples with the Biogeme and PyLogit packages. The dataset is available at http://transp-or.epfl.ch/data/swissmetro.dat, and Bierlaire et al. (2001) provide a detailed discussion of the data as well as its context and collection process. The explanatory variables in this example are the travel time (TT) and cost (CO) for each of the three alternative modes.

This example also adds alternative-specific constants (ASCs) to represent unobserved factors and shows how parameters can be fixed to known values. In addition, it demonstrates how to estimate correlations between normally distributed random parameters.

Read data#

The dataset is imported into the Python environment using pandas. Then, two types of samples are filtered out: those with a trip purpose other than commute or business, and those with an unknown choice. The original dataset contains 10,728 records; after filtering, 6,768 records remain for the following analysis. Finally, a new column that uniquely identifies each sample is added to the dataframe, and the CHOICE column, which originally contains a numerical coding of the choices, is mapped to a description that is consistent with the alternatives in the column names.

[2]:
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep="\t")

# Keep only observations for commute and business purposes that contain known choices
df_wide = df_wide[(df_wide["PURPOSE"].isin([1, 3]) & (df_wide["CHOICE"] != 0))]

df_wide["custom_id"] = np.arange(len(df_wide))  # Add unique identifier
df_wide["CHOICE"] = df_wide["CHOICE"].map({1: "TRAIN", 2: "SM", 3: "CAR"})
df_wide
[2]:
GROUP SURVEY SP ID PURPOSE FIRST TICKET WHO LUGGAGE AGE ... TRAIN_CO TRAIN_HE SM_TT SM_CO SM_HE SM_SEATS CAR_TT CAR_CO CHOICE custom_id
0 2 0 1 1 1 0 1 1 0 3 ... 48 120 63 52 20 0 117 65 SM 0
1 2 0 1 1 1 0 1 1 0 3 ... 48 30 60 49 10 0 117 84 SM 1
2 2 0 1 1 1 0 1 1 0 3 ... 48 60 67 58 30 0 117 52 SM 2
3 2 0 1 1 1 0 1 1 0 3 ... 40 30 63 52 20 0 72 52 SM 3
4 2 0 1 1 1 0 1 1 0 3 ... 36 60 63 42 20 0 90 84 SM 4
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
8446 3 1 1 939 3 1 7 3 1 5 ... 13 30 50 17 30 0 130 64 TRAIN 6763
8447 3 1 1 939 3 1 7 3 1 5 ... 12 30 53 16 10 0 80 80 TRAIN 6764
8448 3 1 1 939 3 1 7 3 1 5 ... 16 60 50 16 20 0 80 64 TRAIN 6765
8449 3 1 1 939 3 1 7 3 1 5 ... 16 30 53 17 30 0 80 104 TRAIN 6766
8450 3 1 1 939 3 1 7 3 1 5 ... 13 60 53 21 30 0 100 80 TRAIN 6767

6768 rows × 29 columns

Reshape data#

The imported dataframe is in wide format and needs to be reshaped to long format for processing by jaxlogit, which offers the wide_to_long utility (adopted from xlogit) for this reshaping. The user specifies the column that uniquely identifies each sample, the names of the alternatives, the columns that vary across alternatives, and whether the alternative names are a prefix or suffix of the column names. Additionally, the user can specify a value (empty_val) to be used by default when an alternative is not available for a certain variable. Additional usage examples for the wide_to_long function are available in xlogit’s documentation, and details about the function parameters are available in the API reference.

[3]:
from jaxlogit.utils import wide_to_long

df = wide_to_long(
    df_wide,
    id_col="custom_id",
    alt_name="alt",
    sep="_",
    alt_list=["TRAIN", "SM", "CAR"],
    empty_val=0,
    varying=["TT", "CO", "HE", "AV", "SEATS"],
    alt_is_prefix=True,
)
df
[3]:
custom_id alt TT CO HE AV SEATS GROUP SURVEY SP ... TICKET WHO LUGGAGE AGE MALE INCOME GA ORIGIN DEST CHOICE
0 0 TRAIN 112 48 120 1 0 2 0 1 ... 1 1 0 3 0 2 0 2 1 SM
1 0 SM 63 52 20 1 0 2 0 1 ... 1 1 0 3 0 2 0 2 1 SM
2 0 CAR 117 65 0 1 0 2 0 1 ... 1 1 0 3 0 2 0 2 1 SM
3 1 TRAIN 103 48 30 1 0 2 0 1 ... 1 1 0 3 0 2 0 2 1 SM
4 1 SM 60 49 10 1 0 2 0 1 ... 1 1 0 3 0 2 0 2 1 SM
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
20299 6766 SM 53 17 30 1 0 3 1 1 ... 7 3 1 5 1 2 0 1 2 TRAIN
20300 6766 CAR 80 104 0 1 0 3 1 1 ... 7 3 1 5 1 2 0 1 2 TRAIN
20301 6767 TRAIN 108 13 60 1 0 3 1 1 ... 7 3 1 5 1 2 0 1 2 TRAIN
20302 6767 SM 53 21 30 1 0 3 1 1 ... 7 3 1 5 1 2 0 1 2 TRAIN
20303 6767 CAR 100 80 0 1 0 3 1 1 ... 7 3 1 5 1 2 0 1 2 TRAIN

20304 rows × 23 columns

Create model specification#

Following the reshaping, users can create or update the dataset’s columns to accommodate their model specification needs. The code below shows how the columns ASC_TRAIN and ASC_CAR are created to incorporate alternative-specific constants in the model. In addition, the example illustrates an effective way of establishing variable interactions (e.g., trip costs for commuters with an annual pass) by updating existing columns conditional on the values of other columns. Column operations provide users with an intuitive and highly flexible mechanism to incorporate model specification aspects, such as variable transformations, interactions, and alternative-specific coefficients and constants. By operating on the dataframe columns, any utility specification can be accommodated in jaxlogit.

[4]:
df["ASC_TRAIN"] = np.ones(len(df)) * (df["alt"] == "TRAIN")
df["ASC_CAR"] = np.ones(len(df)) * (df["alt"] == "CAR")
df["TT"], df["CO"] = df["TT"] / 100, df["CO"] / 100  # Scale variables
annual_pass = (df["GA"] == 1) & (df["alt"].isin(["TRAIN", "SM"]))
df.loc[annual_pass, "CO"] = 0  # Cost zero for pass holders

Estimate model parameters#

The fit method estimates the model, taking as input the data from the previous step along with additional specification criteria, such as the distribution of the random parameters (randvars), the number of random draws (n_draws), and the availability of alternatives for the choice situations (avail). The optimization method is L-BFGS-B, a robust routine that usually helps resolve convergence issues. Once the estimation routine has completed, the summary method can be used to display the estimation results.

The ConfigData class is used to store optional arguments to the fit method.

[5]:
varnames = ["ASC_CAR", "ASC_TRAIN", "CO", "TT"]
model = MixedLogit()

config = ConfigData(
    n_draws=1500,
    avail=df["AV"],
    panels=df["ID"],
)

res = model.fit(
    X=df[varnames],
    y=df["CHOICE"],
    varnames=varnames,
    alts=df["alt"],
    ids=df["custom_id"],
    randvars={"TT": "n"},
    config=config,
)
model.summary()
    Message: Operation terminated successfully
    Iterations: 17
    Function evaluations: 18
Estimation time= 13.0 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                 0.2831170     0.1061792     2.6664065       0.00768 **
ASC_TRAIN              -0.5722739     0.1404305    -4.0751410      4.65e-05 ***
CO                     -1.6601696     0.2918887    -5.6876811      1.34e-08 ***
TT                     -3.2289982     0.2035933   -15.8600453      1.19e-55 ***
sd.TT                   3.6221853     0.2379918    15.2197917      1.85e-51 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4359.218
AIC= 8728.436
BIC= 8762.536
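
As a quick interpretation aid (not part of the original example), a normally distributed coefficient implies that some share of the population lies on either side of zero. A minimal sketch using scipy.stats (not imported above) with the TT estimates from the summary:

from scipy.stats import norm

# Share of the population with a positive travel-time coefficient,
# implied by the estimated N(-3.229, 3.622) distribution of TT
mean_tt, sd_tt = -3.2289982, 3.6221853
print(1 - norm.cdf(0, loc=mean_tt, scale=sd_tt))  # approx. 0.186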

Example of fixing parameters#

Here we add the alternative-specific constant for the swissmetro and fix it to 0. Since only utility differences matter, at most two of the three constants are identified; fixing ASC_SM to 0 is a pure normalization, so the estimates below match the previous model.

[6]:
# We left this one out before; add it now and fix its parameter to 0
df["ASC_SM"] = np.ones(len(df)) * (df["alt"] == "SM")
[7]:
varnames = ["ASC_CAR", "ASC_TRAIN", "ASC_SM", "CO", "TT"]
set_vars = {"ASC_SM": 0.0}  # Fixing parameters
model = MixedLogit()

config = ConfigData(
    avail=df["AV"],
    panels=df["ID"],
    set_vars=set_vars,
    n_draws=1500,
)

res = model.fit(
    X=df[varnames],
    y=df["CHOICE"],
    varnames=varnames,
    alts=df["alt"],
    ids=df["custom_id"],
    randvars={"TT": "n"},
    config=config,
)
model.summary()
    Message: Operation terminated successfully
    Iterations: 17
    Function evaluations: 18
Estimation time= 12.6 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                 0.2831170     0.1061792     2.6664065       0.00768 **
ASC_TRAIN              -0.5722739     0.1404305    -4.0751410      4.65e-05 ***
ASC_SM                  0.1000000     0.0000000           inf             0 ***
CO                     -1.6601696     0.2918887    -5.6876811      1.34e-08 ***
TT                     -3.2289982     0.2035933   -15.8600453      1.19e-55 ***
sd.TT                   3.6221853     0.2379918    15.2197917      1.85e-51 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4359.218
AIC= 8728.436
BIC= 8762.536

Error components with correlations#

By default, allowing correlation adds parameters for the correlations of all normally and log-normally distributed variables. For variables x and y, it adds a new parameter called chol.x.y. Cholesky terms for pairs of variables that should remain uncorrelated can be fixed to 0. Here, some parameters are excluded or fixed to known values to identify the error components, following J. Walker’s PhD thesis (MIT, 2001).
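
For intuition, the chol.x.y terms can be read as the below-diagonal entries of a lower-triangular Cholesky factor L of the random parameters’ covariance, with the sd.x terms on its diagonal (an assumption about jaxlogit’s parameterization, consistent with the set_vars entries below). A minimal NumPy sketch of how such a factor turns independent standard-normal draws into correlated draws:

import numpy as np

# Illustrative values only: sd.ASC_TRAIN and sd.ASC_SM on the diagonal,
# chol.ASC_TRAIN.ASC_SM below it
L = np.array([[1.0, 0.0],
              [0.7, 2.4]])
z = np.random.default_rng(0).standard_normal((2, 100_000))
beta = L @ z  # correlated draws; their covariance approximates L @ L.T
print(np.cov(beta))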

[8]:
varnames = ["ASC_CAR", "ASC_TRAIN", "ASC_SM", "CO", "TT"]

randvars = {"ASC_CAR": "n", "ASC_TRAIN": "n", "ASC_SM": "n"}
set_vars = {
    "ASC_SM": 0.0,
    "sd.ASC_TRAIN": 1.0,
    "sd.ASC_CAR": 0.0,
    "chol.ASC_CAR.ASC_TRAIN": 0.0,
    "chol.ASC_CAR.ASC_SM": 0.0,
}  # Identification of error components, see J. Walker's PhD thesis (MIT 2001)

config = ConfigData(
    avail=df["AV"],
    panels=df["ID"],
    set_vars=set_vars,
    include_correlations=True,  # Enable correlation between random parameters
)


model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df["CHOICE"],
    varnames=varnames,
    alts=df["alt"],
    ids=df["custom_id"],
    randvars=randvars,
    config=config,
)
model.summary()
E0203 07:15:39.052278    3312 slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %gather.4 = f64[3,6768,1,1000]{3,2,1,0} gather(%constant.155, %constant.129), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={1}, index_vector_dim=1, slice_sizes={6768,1,1000}, metadata={op_name="jit(function_with_args)/jvp()/gather" stack_frame_id=55}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
E0203 07:15:39.292143    3277 slow_operation_alarm.cc:140] The operation took 1.240264744s
    Message: Operation terminated successfully
    Iterations: 17
    Function evaluations: 21
Estimation time= 31.6 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                -0.2279764     0.4655508    -0.4896917         0.624
ASC_TRAIN              -1.1966263     0.5873328    -2.0373905        0.0416 *
ASC_SM                  0.1000000     0.0000000           inf             0 ***
CO                     -2.0490264     0.3444999    -5.9478289      2.85e-09 ***
TT                     -2.1429727     0.6607183    -3.2433986       0.00119 **
sd.ASC_CAR              0.1000000     0.0000000           inf             0 ***
sd.ASC_TRAIN            0.1000000     0.0000000           inf             0 ***
sd.ASC_SM               2.4459162     0.2814464     8.6905216      4.47e-18 ***
chol.ASC_CAR.ASC_TR     0.1000000     0.0000000           inf             0 ***
chol.ASC_CAR.ASC_SM     0.1000000     0.0000000           inf             0 ***
chol.ASC_TRAIN.ASC_     0.6937107     0.2260149     3.0693135       0.00215 **
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4071.438
AIC= 8154.876
BIC= 8195.795
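
As a closing illustration (not part of the original example), the implied covariance and correlation of ASC_TRAIN and ASC_SM can be reconstructed from the Cholesky entries, under the same assumption as above that sd.* and chol.* are the diagonal and below-diagonal entries of the lower-triangular factor. The sketch uses the fixed value sd.ASC_TRAIN = 1.0 from set_vars (the summary appears to print a 0.1 placeholder for fixed parameters) and the estimates from the summary:

import numpy as np

# Lower-triangular Cholesky factor for (ASC_TRAIN, ASC_SM)
L = np.array([[1.0,       0.0],         # sd.ASC_TRAIN (fixed in set_vars)
              [0.6937107, 2.4459162]])  # chol.ASC_TRAIN.ASC_SM, sd.ASC_SM
Sigma = L @ L.T                         # implied covariance matrix
sd = np.sqrt(np.diag(Sigma))
print(Sigma / np.outer(sd, sd))         # correlation; off-diagonal approx. 0.27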