Scikit-learn interface and Cross Validation

This example uses the Swissmetro data. It builds on the previous example for this dataset, which is in turn based on the xlogit Mixed Logit example.

Note that this wrapper can be used with scikit-learn tools such as cross-validation, as in this example, but it is not a proper scikit-learn estimator and does not pass sklearn.utils.estimator_checks.check_estimator. The estimator needs to know which columns belong to which alternative and which variable, and this information is supplied through the pandas DataFrame column names and the arguments passed to the estimator, whereas scikit-learn's validation tool only passes in generated NumPy arrays of floats. The number of alternatives and variables could in principle be inferred from such an array, but the result would sometimes be ambiguous.
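To make the ambiguity concrete, here is a small sketch. The helper functions are hypothetical and are not how jaxlogit parses its input: a bare array with 12 float columns is consistent with several (alternatives, variables) layouts, while named columns of the form ALTERNATIVE_VARIABLE pin the layout down once the alternatives are known.

```python
def factorizations(n_columns):
    """Every (n_alternatives, n_variables) split consistent with n_columns columns."""
    return [(j, n_columns // j) for j in range(1, n_columns + 1) if n_columns % j == 0]


def parse_columns(columns, alternatives, sep="_"):
    """Map each ALTERNATIVE<sep>VARIABLE column name to an (alternative, variable) pair.

    The alternative is matched as a known prefix, so variable names that themselves
    contain the separator (e.g. ASC_CAR) are kept intact.
    """
    parsed = {}
    for col in columns:
        for alt in alternatives:
            if col.startswith(alt + sep):
                parsed[col] = (alt, col[len(alt) + len(sep):])
                break
        else:
            raise ValueError(f"column {col!r} does not start with a known alternative")
    return parsed


# A bare 12-column float array admits several layouts...
print(factorizations(12))
# ...whereas named columns identify alternatives and variables directly.
print(parse_columns(["TRAIN_CO", "CAR_ASC_TRAIN", "SM_ASC_CAR"], ["TRAIN", "CAR", "SM"]))
```

Matching on the known alternative prefix, rather than splitting on the first or last separator blindly, is what keeps a name like SM_ASC_CAR from being misread.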

[1]:
import pandas as pd
import numpy as np
import jax

# Enable 64-bit floating-point precision in JAX
jax.config.update("jax_enable_x64", True)
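JAX defaults to 32-bit floats; the flag above switches the default to 64-bit, which matters for numerically delicate work such as maximizing a simulated log-likelihood. A minimal check of the effect (this snippet is illustrative and independent of jaxlogit):

```python
import jax

jax.config.update("jax_enable_x64", True)  # must be set before arrays are created
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
print(x.dtype)  # default float dtype is now float64

# A tiny increment is lost in float32 but preserved in float64
a32 = jnp.asarray(1.0, dtype=jnp.float32) + jnp.asarray(1e-8, dtype=jnp.float32)
a64 = jnp.asarray(1.0, dtype=jnp.float64) + 1e-8
print(float(a32) == 1.0, float(a64) == 1.0)
```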

Import Swissmetro Dataset

The alternatives are car, train, and SM (the Swissmetro). The explanatory variables are cost, travel time, and alternative-specific constants for the train and car options. See the previous Swissmetro example for a detailed explanation of the dataset.
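Written out, the deterministic utilities implied by these variables are sketched below. This is a reading of the columns built in the data-preparation step rather than a formula stated by the original example: cost and time are assumed to have generic (shared) coefficients, and SM is the reference alternative with no constant. In the mixed logit fitted later, the time coefficient is additionally treated as random.

$$
\begin{aligned}
V_{n,\mathrm{TRAIN}} &= \beta^{\mathrm{ASC}}_{\mathrm{TRAIN}} + \beta_{\mathrm{CO}}\,\mathrm{CO}_{n,\mathrm{TRAIN}} + \beta_{\mathrm{TT}}\,\mathrm{TT}_{n,\mathrm{TRAIN}} \\
V_{n,\mathrm{CAR}} &= \beta^{\mathrm{ASC}}_{\mathrm{CAR}} + \beta_{\mathrm{CO}}\,\mathrm{CO}_{n,\mathrm{CAR}} + \beta_{\mathrm{TT}}\,\mathrm{TT}_{n,\mathrm{CAR}} \\
V_{n,\mathrm{SM}} &= \beta_{\mathrm{CO}}\,\mathrm{CO}_{n,\mathrm{SM}} + \beta_{\mathrm{TT}}\,\mathrm{TT}_{n,\mathrm{SM}}
\end{aligned}
$$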

Read data

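The dataset is downloaded and filtered to commute and business trips with a known choice, and the numeric choice codes are mapped to alternative names.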

[2]:
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep="\t")

# Keep only observations for commute and business purposes that contain known choices
df_wide = df_wide[df_wide["PURPOSE"].isin([1, 3]) & (df_wide["CHOICE"] != 0)].copy()
df_wide["CHOICE"] = df_wide["CHOICE"].map({1: "TRAIN", 2: "SM", 3: "CAR"})

df_wide["custom_id"] = np.arange(len(df_wide))  # Add unique identifier
df_wide
[2]:
GROUP SURVEY SP ID PURPOSE FIRST TICKET WHO LUGGAGE AGE ... TRAIN_CO TRAIN_HE SM_TT SM_CO SM_HE SM_SEATS CAR_TT CAR_CO CHOICE custom_id
0 2 0 1 1 1 0 1 1 0 3 ... 48 120 63 52 20 0 117 65 SM 0
1 2 0 1 1 1 0 1 1 0 3 ... 48 30 60 49 10 0 117 84 SM 1
2 2 0 1 1 1 0 1 1 0 3 ... 48 60 67 58 30 0 117 52 SM 2
3 2 0 1 1 1 0 1 1 0 3 ... 40 30 63 52 20 0 72 52 SM 3
4 2 0 1 1 1 0 1 1 0 3 ... 36 60 63 42 20 0 90 84 SM 4
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
8446 3 1 1 939 3 1 7 3 1 5 ... 13 30 50 17 30 0 130 64 TRAIN 6763
8447 3 1 1 939 3 1 7 3 1 5 ... 12 30 53 16 10 0 80 80 TRAIN 6764
8448 3 1 1 939 3 1 7 3 1 5 ... 16 60 50 16 20 0 80 64 TRAIN 6765
8449 3 1 1 939 3 1 7 3 1 5 ... 16 30 53 17 30 0 80 104 TRAIN 6766
8450 3 1 1 939 3 1 7 3 1 5 ... 13 60 53 21 30 0 100 80 TRAIN 6767

6768 rows × 29 columns
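The filtering and relabelling logic can be checked on a tiny made-up frame (the values below are invented for illustration). Rows are kept only when PURPOSE is 1 or 3 and CHOICE is not 0, and the remaining codes are mapped to labels. Taking a .copy() of the filtered frame avoids pandas' SettingWithCopyWarning, which assigning a column on a filtered slice can trigger; codes missing from the mapping dictionary would become NaN.

```python
import pandas as pd

# Made-up rows, for illustration only
toy = pd.DataFrame({"PURPOSE": [1, 2, 3, 3, 1], "CHOICE": [2, 1, 0, 3, 1]})

# Same mask as above: commute/business purposes with a known choice
kept = toy[toy["PURPOSE"].isin([1, 3]) & (toy["CHOICE"] != 0)].copy()
kept["CHOICE"] = kept["CHOICE"].map({1: "TRAIN", 2: "SM", 3: "CAR"})
print(kept)
```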

Reshape data

This scikit-learn interface uses data in wide format. The cell below applies data transformations and adds alternative-specific constants using pandas DataFrames. Column names for each alternative/variable pair take the form ALTERNATIVE_VARIABLE; for example, the cost of the train option is TRAIN_CO.

[3]:
varnames = ["CO", "TT"]
alternatives = ["TRAIN", "CAR", "SM"]
separator = "_"
alt_is_prefix = True

for alternative in alternatives:
    # Alternative-specific constants for train and car (SM has none, so it is the reference)
    for alternative_constant in ["TRAIN", "CAR"]:
        column = alternative + separator + "ASC" + separator + alternative_constant
        # 1 when the constant belongs to this alternative, 0 otherwise
        df_wide[column] = 1.0 if alternative_constant == alternative else 0.0

    # Scale time and cost (divide by 100)
    for var in varnames:
        df_wide[alternative + separator + var] = df_wide[alternative + separator + var] / 100


varnames = ["CO", "TT", "ASC_TRAIN", "ASC_CAR"]
all_varnames = [alternative + separator + varname for alternative in alternatives for varname in varnames]
all_varnames
[3]:
['TRAIN_CO',
 'TRAIN_TT',
 'TRAIN_ASC_TRAIN',
 'TRAIN_ASC_CAR',
 'CAR_CO',
 'CAR_TT',
 'CAR_ASC_TRAIN',
 'CAR_ASC_CAR',
 'SM_CO',
 'SM_TT',
 'SM_ASC_TRAIN',
 'SM_ASC_CAR']
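Because all_varnames is ordered alternative-first, variable-second, the wide columns can be stacked into an (observations × alternatives × variables) array, a layout discrete-choice estimators commonly work with internally. Whether jaxlogit builds exactly this array is an assumption; the helper name wide_to_3d is hypothetical and the toy values are made up.

```python
import numpy as np
import pandas as pd


def wide_to_3d(df, alternatives, varnames, sep="_"):
    """Stack ALTERNATIVE<sep>VARIABLE columns into shape (n_obs, n_alternatives, n_variables)."""
    # Alternative-major, variable-minor order, the same order as all_varnames
    cols = [alt + sep + var for alt in alternatives for var in varnames]
    values = df[cols].to_numpy(dtype=float)
    return values.reshape(len(df), len(alternatives), len(varnames))


# Made-up toy values, for illustration only
toy = pd.DataFrame({
    "TRAIN_CO": [1.0, 2.0], "TRAIN_TT": [3.0, 4.0],
    "SM_CO": [5.0, 6.0], "SM_TT": [7.0, 8.0],
})
X3 = wide_to_3d(toy, ["TRAIN", "SM"], ["CO", "TT"])
print(X3.shape)  # (2, 2, 2)
print(X3[0, 1])  # row 0, SM: [CO, TT]
```

Applied to the notebook's data, wide_to_3d(df_wide, alternatives, varnames) would give an array of shape (6768, 3, 4).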

Creating and fitting a model

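Model options are given when the estimator is created. The variable names (without the alternative prefix) must be passed here; randvars specifies which coefficients are random and their distribution, and n_draws sets the number of simulation draws. Panel data is not currently supported by this interface.

The model is then fitted by passing the data to fit.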
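Following the xlogit convention, 'n' requests a normally distributed coefficient, so here the travel-time coefficient varies across observations. In the standard mixed logit formulation, the choice probability is an integral over that distribution, approximated by averaging logit probabilities over R simulation draws (R = n_draws = 1500 below). This is the textbook form; details of jaxlogit's implementation, such as the type of draws, are not stated here.

$$
\begin{aligned}
\beta_{\mathrm{TT}} &= \mu_{\mathrm{TT}} + \sigma_{\mathrm{TT}}\,\xi, \qquad \xi \sim \mathcal{N}(0, 1) \\
P_{ni} &= \int \frac{\exp V_{ni}(\beta)}{\sum_{j} \exp V_{nj}(\beta)}\, f(\beta)\, d\beta
\;\approx\; \frac{1}{R} \sum_{r=1}^{R} \frac{\exp V_{ni}(\beta^{(r)})}{\sum_{j} \exp V_{nj}(\beta^{(r)})}
\end{aligned}
$$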

[4]:
from jaxlogit.scikit_wrapper import MixedLogitEstimator

mixed_logit_estimator = MixedLogitEstimator(
    varnames=varnames,
    randvars={'TT': 'n'},  # normally distributed travel-time coefficient
    n_draws=1500,
)
X = df_wide[all_varnames]
y = df_wide["CHOICE"]

mixed_logit_estimator.fit(X, y)
[4]:
MixedLogitEstimator(n_draws=1500, randvars={'TT': 'n'},
                    varnames=['CO', 'TT', 'ASC_TRAIN', 'ASC_CAR'])
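For intuition about what the fitted model computes, the following NumPy sketch evaluates simulated mixed logit probabilities for one normally distributed coefficient and picks the most probable alternative per observation. It is not jaxlogit's implementation: the function name, the pseudo-random normal draws, and the coefficient values are all hypothetical, and it is only an assumption that the wrapper's predict returns the highest-probability alternative.

```python
import numpy as np


def simulated_probabilities(X, beta, rand_idx, mu, sigma, n_draws=1500, seed=0):
    """Simulated mixed logit probabilities with one normally distributed coefficient.

    X         : array (n_obs, n_alternatives, n_variables)
    beta      : fixed coefficients, shape (n_variables,); entry rand_idx is replaced by draws
    mu, sigma : mean and standard deviation of the random coefficient
    """
    rng = np.random.default_rng(seed)
    xi = rng.standard_normal(n_draws)                              # standard normal draws, (R,)
    betas = np.tile(np.asarray(beta, dtype=float), (n_draws, 1))   # one coefficient row per draw, (R, K)
    betas[:, rand_idx] = mu + sigma * xi                           # random coefficient for each draw
    V = np.einsum("njk,rk->nrj", X, betas)                         # utilities, (N, R, J)
    V -= V.max(axis=2, keepdims=True)                              # stabilise the exponentials
    expV = np.exp(V)
    P = expV / expV.sum(axis=2, keepdims=True)                     # logit probabilities per draw
    return P.mean(axis=1)                                          # average over draws, (N, J)


# Hypothetical inputs: random attributes; made-up coefficients for CO, TT, ASC_TRAIN, ASC_CAR
rng = np.random.default_rng(1)
X = rng.normal(size=(4, 3, 4))
P = simulated_probabilities(X, beta=[-1.0, 0.0, 0.2, 0.1], rand_idx=1, mu=-0.5, sigma=0.3, n_draws=200)
alternatives = np.array(["TRAIN", "CAR", "SM"])
print(alternatives[P.argmax(axis=1)])  # most probable alternative per observation
```

With sigma set to 0 every draw is identical and the result reduces to a plain multinomial logit, which is a convenient sanity check.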

Scikit-learn utilities

Through this interface, scikit-learn utilities for splitting data into training and test sets, and for cross-validation, can be used.
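Conceptually, cross_val_score clones the estimator, fits it on each training fold, and calls its score method on the held-out fold. The sketch below reproduces that loop by hand, using scikit-learn's DummyClassifier and random toy labels as a stand-in, since MixedLogitEstimator needs the real data. With an integer cv such as cv=5, scikit-learn uses StratifiedKFold for estimators it recognizes as classifiers and KFold otherwise, both unshuffled by default; which of these applies to MixedLogitEstimator is not stated here, so the sketch passes an explicit KFold to both sides.

```python
import numpy as np
from sklearn.base import clone
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import KFold, cross_val_score

# Random toy data, for illustration only
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(50, 3))
y_toy = rng.choice(["TRAIN", "SM", "CAR"], size=50)

estimator = DummyClassifier(strategy="most_frequent")
cv = KFold(n_splits=5)  # unshuffled, contiguous folds

# Manual equivalent of cross_val_score: fresh clone per fold, fit, then score
manual = []
for train_idx, test_idx in cv.split(X_toy, y_toy):
    model = clone(estimator).fit(X_toy[train_idx], y_toy[train_idx])
    manual.append(model.score(X_toy[test_idx], y_toy[test_idx]))

auto = cross_val_score(estimator, X_toy, y_toy, cv=cv)
print(np.round(manual, 3), np.round(auto, 3))
```

When X is a pandas DataFrame, scikit-learn indexes the folds positionally and keeps them as DataFrames, so the column names the wrapper relies on are presumably preserved in each fold.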

[5]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
mixed_logit_estimator.fit(X_train, y_train)

mixed_logit_estimator.predict(X_test)
[5]:
array(['SM', 'SM', 'SM', ..., 'SM', 'SM', 'SM'],
      shape=(2708,), dtype='<U3')
[6]:
mixed_logit_estimator.score(X_test, y_test)
[6]:
0.6004431314623339
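The visible predictions above are all 'SM' (the printed array is elided in the middle), so a useful reference point for this score is the accuracy of always predicting the most frequent training class. The score is presumably mean accuracy, as for scikit-learn classifiers, though that is an assumption about the wrapper. The helper below is hypothetical and uses toy labels; its value for this particular split is not reported in the original example.

```python
import pandas as pd


def majority_baseline_accuracy(y_train, y_test):
    """Accuracy of always predicting the most frequent class in the training labels."""
    majority = pd.Series(y_train).mode().iloc[0]
    return float((pd.Series(y_test) == majority).mean())


# Toy labels, for illustration only
print(majority_baseline_accuracy(["SM", "SM", "CAR"], ["SM", "TRAIN", "SM", "CAR"]))  # 0.5
```

In the notebook, majority_baseline_accuracy(y_train, y_test) would give the baseline for the same split used by score.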
[7]:
from sklearn.model_selection import cross_val_score
scores = cross_val_score(mixed_logit_estimator, X, y, cv=5)
scores
[7]:
array([0.59379616, 0.49039882, 0.59748892, 0.60458241, 0.60458241])
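The fold scores can be summarized by their mean and spread. The second fold is noticeably lower than the others; since the default folds are contiguous and the rows are ordered by respondent ID, this could reflect differences between respondents, although the example does not investigate it.

```python
import numpy as np

# Fold scores reported by cross_val_score above
scores = np.array([0.59379616, 0.49039882, 0.59748892, 0.60458241, 0.60458241])
print(f"mean = {scores.mean():.3f}, std = {scores.std():.3f}")
print("lowest-scoring fold:", scores.argmin())
```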