Expectation-Maximization

class fitr.inference.em.EM(loglik_func, params, name='EMModel')
Expectation-Maximization with the Laplace Approximation [Huys2011], [HuysEMCode].
Attributes
name : str
Name of the model being fit. We suggest naming the model after its free parameters.
loglik_func : function
The log-likelihood function to be used for model fitting.
params : list
List of parameters from the rlparams module.
nparams : int
Number of free parameters in the model.
param_rng : list
List of strings denoting the parameter ranges (see the rlparams module for further details).
prior : scipy.stats distribution
The prior distribution over parameter estimates. Here this is fixed to a multivariate normal.
mu : ndarray(shape=nparams)
The prior mean over parameters.
cov : ndarray(shape=(nparams, nparams))
The covariance matrix of the prior over parameter estimates.

Methods
fit(data, n_iterations=1000, c_limit=0.001, opt_algorithm='L-BFGS-B', init_grid=False, grid_reinit=True, n_grid_points=5, n_reinit=1, dofull=True, early_stopping=True, verbose=True)
Run the model-fitting algorithm.
logposterior(x, states, actions, rewards)
Computes the log-posterior probability.
group_level_estimate(param_est, hess_inv, dofull, verbose=True)
Updates the hyperparameters of the group-level prior.
__printfitstart(self, n_iterations, c_limit, algorithm, init_grid, grid_reinit, dofull, early_stopping, verbose)
(Private) function to print optimization info to console.
__printupdate(self, opt_iter, subject_i, posterior_ll, verbose)
(Private) function to print update on fit iteration to console.

fit(data, n_iterations=1000, c_limit=0.001, opt_algorithm='L-BFGS-B', init_grid=False, grid_reinit=True, n_grid_points=5, n_reinit=1, dofull=True, early_stopping=True, verbose=True)
Performs maximum a posteriori estimation of subject-level parameters
Parameters: data : dict
Dictionary of data from all subjects.
n_iterations : int
Maximum number of iterations to allow.
c_limit : float
Threshold at which convergence is determined
opt_algorithm : {'BFGS', 'L-BFGS-B'}
Algorithm to use for optimization
init_grid : bool
Whether to initialize the optimizer using brute-force grid search. If False, initial values are sampled from a normal distribution with mean 0 and standard deviation 1.
grid_reinit : bool
Whether to reinitialize with values from grid search if optimization does not converge
n_grid_points : int
Number of points along each axis to evaluate during grid-search initialization (only meaningful if init_grid is True).
n_reinit : int
Number of times to reinitialize the optimizer if not converged
dofull : bool
Whether to update the full covariance matrix of the prior. If False, the covariance matrix is constrained to be diagonal (off-diagonal elements are set to zero).
early_stopping : bool
Whether to stop the EM procedure if the log-model-evidence begins decreasing, in which case the results revert to those of the previous iteration.
verbose : bool
Whether to print progress of model fitting
Returns: ModelFitResult
Representation of the model fitting results
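As a rough illustration of the per-subject step, the sketch below finds a MAP estimate with `scipy.optimize.minimize` using `L-BFGS-B`, reinitializes from a standard normal draw if the optimizer does not converge, and reads off the inverse-Hessian estimate that a Laplace approximation would use. The toy model (Gaussian likelihood with a Gaussian prior on one parameter) and the names `neg_logposterior` and `map_estimate` are assumptions made for illustration; they are not fitr's internals.

```python
import numpy as np
from scipy.optimize import minimize


def neg_logposterior(x, data, prior_mu, prior_var):
    """Negative log-posterior for a toy model (hypothetical, not fitr's)."""
    # Toy likelihood: data ~ Normal(x[0], 1), up to an additive constant
    loglik = -0.5 * np.sum((data - x[0]) ** 2)
    # Gaussian prior on the single parameter, up to an additive constant
    logprior = -0.5 * (x[0] - prior_mu) ** 2 / prior_var
    return -(loglik + logprior)


def map_estimate(data, prior_mu=0.0, prior_var=1.0, n_reinit=1, seed=0):
    """Minimize the negative log-posterior, reinitializing if not converged."""
    rng = np.random.default_rng(seed)
    res = None
    for _ in range(n_reinit + 1):
        # Draw starting values from N(0, 1), as described for init_grid=False
        x0 = rng.normal(0.0, 1.0, size=1)
        res = minimize(neg_logposterior, x0,
                       args=(data, prior_mu, prior_var),
                       method="L-BFGS-B")
        if res.success:
            break
    # L-BFGS-B exposes the inverse Hessian as a linear operator; densify it
    hess_inv = res.hess_inv.todense()
    return res.x, hess_inv, res.success


data = np.array([0.8, 1.2, 1.0, 0.6])
x_map, hess_inv, converged = map_estimate(data)
```

For this toy model the MAP estimate has a closed form, `sum(data) / (len(data) + 1)`, which gives a quick sanity check on the optimizer output.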

group_level_estimate(param_est, hess_inv, dofull, verbose=True)
Updates the group-level hyperparameters
Parameters: param_est : ndarray(shape=(nsubjects, nparams))
Current parameter estimates for each subject
hess_inv : ndarray(shape=(nparams, nparams, nsubjects))
Inverse Hessian matrix estimate for each subject from the iteration with highest log-posterior probability
dofull : bool
Whether to update the full covariance matrix of the prior. If False, the covariance matrix is constrained to be diagonal (off-diagonal elements are set to zero).
verbose : bool
Whether to print results of the update
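The hyperparameter update can be sketched with NumPy alone. The version below assumes the usual moment-matching update for a Gaussian prior: the new mean is the average of the subject estimates, and the new covariance is the average per-subject second moment (outer product of the estimate plus its inverse Hessian) minus the outer product of the new mean. The function name `group_level_update` is hypothetical, and fitr's actual implementation may differ in detail.

```python
import numpy as np


def group_level_update(param_est, hess_inv, dofull=True):
    """Sketch of a Gaussian hyperparameter update (hypothetical helper).

    param_est : ndarray(shape=(nsubjects, nparams))
    hess_inv  : ndarray(shape=(nparams, nparams, nsubjects))
    """
    nsubjects = param_est.shape[0]
    mu = param_est.mean(axis=0)
    # Average over subjects of the outer products m_i m_i^T
    outer = np.einsum("ij,ik->jk", param_est, param_est) / nsubjects
    # Add the average posterior uncertainty, subtract the squared mean
    cov = outer + hess_inv.mean(axis=2) - np.outer(mu, mu)
    if not dofull:
        # Keep only the variances; zero the off-diagonal elements
        cov = np.diag(np.diag(cov))
    return mu, cov


param_est = np.array([[0.0, 1.0], [2.0, 3.0]])
hess_inv = np.stack([0.5 * np.eye(2)] * 2, axis=2)

mu, cov_full = group_level_update(param_est, hess_inv, dofull=True)
_, cov_diag = group_level_update(param_est, hess_inv, dofull=False)
```

With `dofull=False` the returned covariance retains the same diagonal as the full update but discards correlations between parameters.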

initialize_opt(fn=None, grid=False, Ns=None)
Returns initial values for the optimization
Parameters: fn : function
Function over which grid search takes place
grid : bool
Whether to return initialization values from grid search
Ns : int
Number of points per axis over which to evaluate during grid search
Returns: x0 : ndarray
1 × N vector of initial values, one per parameter
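A minimal sketch of the two initialization strategies follows, using `scipy.optimize.brute` for the grid search and a standard normal draw otherwise. The helper name `initial_values` and its `nparams`, `bounds`, and `seed` arguments are assumptions added for illustration; the documented method takes only `fn`, `grid`, and `Ns`.

```python
import numpy as np
from scipy.optimize import brute


def initial_values(nparams, fn=None, grid=False, Ns=5,
                   bounds=(-2.0, 2.0), seed=None):
    """Return starting values (hypothetical helper, not fitr's method)."""
    if grid:
        # Same (low, high) range on every axis, Ns points per axis
        ranges = [bounds] * nparams
        # finish=None returns the best grid point without local polishing
        x0 = brute(fn, ranges, Ns=Ns, finish=None)
        return np.atleast_1d(x0)
    # Otherwise sample each parameter from N(0, 1)
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, 1.0, size=nparams)


def objective(x):
    # Toy objective with its minimum at (1, 1)
    return np.sum((x - 1.0) ** 2)


x0_grid = initial_values(2, fn=objective, grid=True, Ns=5)
x0_rand = initial_values(2, seed=0)
```

Grid search costs `Ns ** nparams` function evaluations, so it is only practical for models with few free parameters.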

logposterior(x, states, actions, rewards)
Represents the log-posterior probability function
Parameters: x : ndarray(shape=nparams)
Array of parameters for a single subject
states : ndarray(shape=[ntrials, nsteps])
Array of states encountered by subject
actions : ndarray(shape=[ntrials, nsteps])
Array of actions taken by subject
rewards : ndarray(shape=[ntrials, nsteps])
Array of rewards received by the subject.
Returns: float
Log-posterior probability
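Conceptually, the log-posterior is the subject's log-likelihood plus the log-density of the multivariate normal prior at `x` (up to a normalizing constant that does not depend on `x`). The sketch below shows that sum with `scipy.stats.multivariate_normal`. The function `toy_loglik` is a hypothetical stand-in for `loglik_func` that models only a fixed action bias and ignores states and rewards; passing `loglik_func`, `mu`, and `cov` as arguments (rather than reading them from the class) is also an assumption of this example.

```python
import numpy as np
from scipy.stats import multivariate_normal


def toy_loglik(x, states, actions, rewards):
    """Hypothetical log-likelihood: P(action = 1) = sigmoid(x[0])."""
    p = 1.0 / (1.0 + np.exp(-x[0]))
    return np.sum(actions * np.log(p) + (1 - actions) * np.log(1 - p))


def log_posterior(x, states, actions, rewards, loglik_func, mu, cov):
    """Log-likelihood plus log-prior under a multivariate normal."""
    logprior = multivariate_normal.logpdf(x, mean=mu, cov=cov)
    return loglik_func(x, states, actions, rewards) + logprior


# Four trials with one step each
states = np.zeros((4, 1), dtype=int)
actions = np.array([[0], [1], [1], [0]])
rewards = np.zeros((4, 1))

lp = log_posterior(np.array([0.0]), states, actions, rewards,
                   toy_loglik, mu=np.zeros(1), cov=np.eye(1))
```

An optimizer that minimizes rather than maximizes would be given the negative of this value.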