jaxlogit.scikit_wrapper.MixedLogitEstimator

class jaxlogit.scikit_wrapper.MixedLogitEstimator(alternatives=(), varnames=(), randvars=(), weights=None, avail=None, panels=None, init_coeff=None, maxiter=2000, random_state=None, n_draws=1000, halton=True, halton_opts=None, tol_opts=None, num_hess=False, set_vars=None, optim_method='L-BFGS-scipy', skip_std_errs=False, include_correlations=False, force_positive_chol_diag=True, hessian_by_row=True, finite_diff_hessian=False, batch_size=None, verbose=1)
__init__(alternatives=(), varnames=(), randvars=(), weights=None, avail=None, panels=None, init_coeff=None, maxiter=2000, random_state=None, n_draws=1000, halton=True, halton_opts=None, tol_opts=None, num_hess=False, set_vars=None, optim_method='L-BFGS-scipy', skip_std_errs=False, include_correlations=False, force_positive_chol_diag=True, hessian_by_row=True, finite_diff_hessian=False, batch_size=None, verbose=1)

Initialises a jaxlogit estimator with configurations for the fit and predict functions.

Parameters

alternatives : list-like

Names of valid alternatives that were chosen.

varnames : list-like of shape (n_features,), required

Names of explanatory variables that must match the number and order of columns in X.

randvars : dict, required

Names (keys) and mixing distributions (values) of variables that have random parameters as coefficients. Possible mixing distributions are:

  • 'n': normal

  • 'ln': lognormal

  • 't': triangular

  • 'tn': truncated normal
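As a minimal sketch of how a `randvars` dictionary relates to `varnames`, the snippet below checks a configuration before it is handed to the estimator. The helper `check_randvars` and the variable names are hypothetical and not part of jaxlogit; the sketch also assumes that every key of `randvars` must appear in `varnames`.

```python
# Mixing-distribution codes as listed in the randvars description.
MIXING_CODES = {
    "n": "normal",
    "ln": "lognormal",
    "t": "triangular",
    "tn": "truncated normal",
}


def check_randvars(varnames, randvars):
    """Hypothetical helper: validate a randvars dict and spell out its codes."""
    unknown_vars = [name for name in randvars if name not in varnames]
    if unknown_vars:
        raise ValueError(f"randvars keys not in varnames: {unknown_vars}")
    bad_codes = {k: v for k, v in randvars.items() if v not in MIXING_CODES}
    if bad_codes:
        raise ValueError(f"unknown mixing distributions: {bad_codes}")
    return {name: MIXING_CODES[code] for name, code in randvars.items()}


varnames = ["price", "time", "comfort"]   # illustrative variable names
randvars = {"price": "ln", "time": "n"}   # comfort is left out of randvars
summary = check_randvars(varnames, randvars)
print(summary)
```

Variables that are listed in `varnames` but not in `randvars` carry no mixing distribution in this sketch.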

weights : array-like, shape (n_samples,), optional

Sample weights in long format.

avail : array-like, shape (n_samples*n_alts,), optional

Availability of alternatives for the choice situations. One when available or zero otherwise.
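The following sketch shows one way to build a flat availability vector of length `n_samples*n_alts`. The data and alternative names are illustrative, and the ordering (all alternatives of the first choice situation, then all alternatives of the second, and so on) is an assumption rather than something this page states.

```python
alternatives = ["car", "bus", "train"]   # illustrative alternative names
n_alts = len(alternatives)

# One row per choice situation, one column per alternative (illustrative data).
avail_wide = [
    [1, 1, 1],
    [1, 0, 1],   # bus unavailable in the second choice situation
    [0, 1, 1],   # car unavailable in the third choice situation
]
n_samples = len(avail_wide)

# Flatten row by row so that each choice situation occupies n_alts entries.
avail = [flag for row in avail_wide for flag in row]
assert len(avail) == n_samples * n_alts
print(avail)
```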

panels : array-like, shape (n_samples*n_alts,), optional

Identifiers in long format to create panels in combination with ids.

init_coeff : numpy.ndarray, shape (n_variables,), optional

Initial coefficients for estimation.

maxiter : int, default=2000

Maximum number of iterations.

random_state : int, optional

Random seed for numpy random generator.

n_draws : int, default=1000

Number of random draws to approximate the mixing distributions of the random coefficients.
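To illustrate what approximating a mixing distribution with draws means, here is a minimal pure-Python sketch of simulated choice probabilities: logit probabilities are averaged over draws of a single normally distributed coefficient. The function names, attribute values, and distribution parameters are illustrative assumptions; this is not jaxlogit's implementation.

```python
import math
import random


def logit_probs(utilities):
    """Softmax over utilities, shifted by the maximum for numerical stability."""
    top = max(utilities)
    exps = [math.exp(u - top) for u in utilities]
    total = sum(exps)
    return [e / total for e in exps]


def simulated_probs(prices, mean, sd, n_draws, seed=0):
    """Average logit probabilities over draws of a normal price coefficient."""
    rng = random.Random(seed)
    acc = [0.0] * len(prices)
    for _ in range(n_draws):
        beta = rng.gauss(mean, sd)   # one draw of the random coefficient
        probs = logit_probs([beta * p for p in prices])
        acc = [a + p for a, p in zip(acc, probs)]
    return [a / n_draws for a in acc]


prices = [1.0, 2.0, 3.0]   # illustrative attribute values, one per alternative
probs = simulated_probs(prices, mean=-1.0, sd=0.5, n_draws=1000)
print(probs)
```

More draws reduce the simulation noise in the averaged probabilities at the cost of more computation per likelihood evaluation.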

halton : bool, default=True

Whether the estimation uses Halton draws.

halton_opts : dict, optional

Options for generation of halton draws. The dictionary accepts the following options (keys):

  • shuffle : bool, default=False

    Whether the Halton draws should be shuffled.

  • drop : int, default=100

    Number of initial Halton draws to discard to minimize correlations between Halton sequences.

  • primes : list

    List of primes to be used as base for generation of Halton sequences.
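The three options above can be pictured with a small, self-contained Halton generator. This is a sketch of the general technique (the radical inverse in a prime base), not jaxlogit's own code; the function names and the `seed` argument are assumptions made for the example.

```python
import random


def radical_inverse(index, base):
    """Van der Corput radical inverse of a positive integer in the given base."""
    result, fraction = 0.0, 1.0 / base
    while index > 0:
        index, digit = divmod(index, base)
        result += digit * fraction
        fraction /= base
    return result


def halton_draws(n_draws, primes, drop=100, shuffle=False, seed=None):
    """Return one list of n_draws points per prime, skipping the first `drop`."""
    rng = random.Random(seed)
    sequences = []
    for base in primes:
        start = drop + 1   # the sequence is indexed from 1
        seq = [radical_inverse(i, base) for i in range(start, start + n_draws)]
        if shuffle:
            rng.shuffle(seq)   # permute the points within this sequence
        sequences.append(seq)
    return sequences


# With drop=0 the first points of the base-2 sequence are 1/2, 1/4, 3/4.
draws = halton_draws(3, primes=[2, 3], drop=0)
print(draws)
```

Each prime yields one sequence, so a model with several random coefficients would typically use one prime per coefficient.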

tol_opts : dict, optional

Tolerance options for the optimization routine. The dictionary accepts the following options (keys):

  • ftol : float, default=1e-10

    Tolerance for objective function (log-likelihood).

  • gtol : float, default=1e-5

    Tolerance for gradient function.
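As a rough illustration of how objective and gradient tolerances usually act as stopping rules, the toy gradient descent below stops when the objective changes by at most `ftol` or the gradient magnitude falls to `gtol`. The exact criteria applied during estimation depend on the chosen optimizer and are not specified on this page, so the rules here are assumptions.

```python
def minimize_1d(f, grad, x0, step=0.1, ftol=1e-10, gtol=1e-5, maxiter=2000):
    """Toy gradient descent with ftol, gtol, and maxiter stopping rules."""
    x, fx = x0, f(x0)
    for iteration in range(1, maxiter + 1):
        g = grad(x)
        if abs(g) <= gtol:                 # gradient is small enough
            return x, iteration, "gtol"
        x_new = x - step * g
        f_new = f(x_new)
        if abs(fx - f_new) <= ftol:        # objective barely changed
            return x_new, iteration, "ftol"
        x, fx = x_new, f_new
    return x, maxiter, "maxiter"


x_opt, n_iter, reason = minimize_1d(
    lambda x: (x - 3.0) ** 2,      # objective with its minimum at x = 3
    lambda x: 2.0 * (x - 3.0),     # its derivative
    x0=0.0,
)
print(x_opt, n_iter, reason)
```

Tighter tolerances give a more precise optimum but need more iterations, which is why `maxiter` acts as a final safeguard.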

num_hess : bool, default=False

Whether a numerical Hessian should be used to estimate standard errors.

set_vars : dict, optional

Names (keys) of variables to be set to the given values (values).

optim_method : {'trust-region', 'BFGS', 'L-BFGS-B', 'L-BFGS-scipy'}, default='L-BFGS-scipy'

Optimization method to use for model estimation.

skip_std_errs : bool, default=False

Whether estimation of standard errors should be skipped.

include_correlations : bool, default=False

Whether correlations between variables should be included as explanatory variables.

force_positive_chol_diag : bool, default=True

Whether to force positive diagonal elements in Cholesky decomposition.
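The sketch below shows why a positive Cholesky diagonal is useful: a lower-triangular factor with a positive diagonal maps to a positive-definite covariance matrix and is unique for that matrix. The softplus transform used here to enforce positivity is an assumption chosen for illustration; this page does not say which transform jaxlogit applies.

```python
import math


def softplus(x):
    """Smooth map from the real line to strictly positive numbers."""
    return math.log1p(math.exp(x))


def build_cholesky(diag_raw, offdiag, force_positive=True):
    """Build a 2x2 lower-triangular factor from unconstrained parameters."""
    if force_positive:
        d0, d1 = softplus(diag_raw[0]), softplus(diag_raw[1])
    else:
        d0, d1 = diag_raw
    return [[d0, 0.0], [offdiag, d1]]


def covariance(chol):
    """Return chol @ chol^T for a 2x2 lower-triangular factor."""
    return [
        [sum(chol[i][k] * chol[j][k] for k in range(2)) for j in range(2)]
        for i in range(2)
    ]


chol = build_cholesky(diag_raw=[-1.0, 0.5], offdiag=0.3)
cov = covariance(chol)
print(chol)
print(cov)
```

Without the constraint, a factor and its sign-flipped counterpart produce the same covariance, which can make the estimated parameters harder to interpret.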

hessian_by_row : bool, default=True

Whether to calculate the Hessian row by row in a for loop to save memory at the expense of runtime.

finite_diff_hessian : bool, default=False

Whether the Hessian should be computed using finite differences. If True, the computation stays within memory limits.
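For readers unfamiliar with the technique, a finite-difference Hessian needs only function evaluations, which keeps memory use low. The following is a generic central-difference sketch on a toy function; the step size and function names are assumptions, and this is not the routine jaxlogit uses.

```python
def finite_diff_hessian(f, x, eps=1e-4):
    """Approximate the Hessian of f at x with central differences."""
    n = len(x)
    hess = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i, n):

            def shifted(di, dj):
                # Evaluate f with coordinates i and j moved by di*eps and dj*eps.
                point = list(x)
                point[i] += di * eps
                point[j] += dj * eps
                return f(point)

            value = (
                shifted(1, 1) - shifted(1, -1) - shifted(-1, 1) + shifted(-1, -1)
            ) / (4 * eps**2)
            hess[i][j] = hess[j][i] = value   # the Hessian is symmetric
    return hess


def toy(v):
    """Quadratic with exact Hessian [[2, 3], [3, 4]]."""
    return v[0] ** 2 + 3 * v[0] * v[1] + 2 * v[1] ** 2


hess = finite_diff_hessian(toy, [1.0, -2.0])
print(hess)
```

Each entry is computed independently from four function values, so no large intermediate arrays are needed.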

batch_size : int, optional

Size of batches used to avoid GPU memory overflow.
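The idea behind batching can be sketched in a few lines: a quantity is accumulated one chunk at a time so that only one chunk is held in memory at once. What exactly jaxlogit splits into batches is not stated on this page, so the example works on a plain list and the function name is hypothetical.

```python
def batched_mean(values, batch_size=None):
    """Mean of a non-empty list, accumulated one batch at a time."""
    if batch_size is None:
        batch_size = len(values)   # a single batch covering everything
    total = 0.0
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]   # only this slice is processed
        total += sum(batch)
    return total / len(values)


data = list(range(10))
print(batched_mean(data, batch_size=3))
print(batched_mean(data))
```

The result does not depend on the batch size; only the peak memory use and the number of passes through the loop change.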

Methods

__init__([alternatives, varnames, randvars, ...])

Initialises a jaxlogit estimator with configurations for the fit and predict functions.

fit(X, y)

Fit Mixed Logit model.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

predict(X)

Generate probabilities of each alternative for each choice situation in X.

score(X, y[, sample_weight])

Return accuracy on provided data and labels.

set_params(**params)

Set the parameters of this estimator.

set_score_request(*[, sample_weight])

Configure whether metadata should be requested to be passed to the score method.
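A minimal usage sketch, assuming jaxlogit is installed: the keyword names come from the constructor signature, while the values, the alternative and variable names, and the helper `fit_and_predict` are illustrative. The required layout of `X` and `y` is not described on this page beyond the column order of `X` matching `varnames`, so the data are left as arguments.

```python
# Keyword arguments taken from the constructor signature; values are illustrative.
estimator_kwargs = {
    "alternatives": ["car", "bus", "train"],
    "varnames": ["price", "time"],
    "randvars": {"time": "n"},
    "n_draws": 500,
    "halton": True,
    "halton_opts": {"shuffle": False, "drop": 100},
    "tol_opts": {"ftol": 1e-10, "gtol": 1e-5},
    "random_state": 0,
    "skip_std_errs": True,
}


def fit_and_predict(X_train, y_train, X_new):
    """Fit the estimator and return predictions for X_new (requires jaxlogit)."""
    # Imported lazily so that defining this function does not need jaxlogit.
    from jaxlogit.scikit_wrapper import MixedLogitEstimator

    model = MixedLogitEstimator(**estimator_kwargs)
    model.fit(X_train, y_train)
    return model.predict(X_new)
```

Because the class exposes `get_params` and `set_params`, the same keyword names can also be read back or changed after construction.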