Expectation-Maximization
class fitr.inference.em.EM(loglik_func, params, name='EMModel')
Expectation-Maximization with the Laplace approximation [Huys2011], [HuysEMCode].
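The scheme of [Huys2011] can be summarized as follows. Here h is a parameter vector, D_i is subject i's data, N is the number of subjects, and mu and Sigma play the role of the `mu` and `cov` attributes. This is a summary of the general method, not a transcription of this module's code. The module may differ in details, such as how the arg max is found or how S_i is obtained. Restricting Sigma to its diagonal corresponds to `dofull=False`.

```latex
\begin{aligned}
\text{E-step:}\quad
m_i &= \arg\max_{h}\; \big[\log p(\mathcal{D}_i \mid h) + \log \mathcal{N}(h;\, \mu, \Sigma)\big], \\
S_i &= \Big(-\nabla_h^2 \big[\log p(\mathcal{D}_i \mid h) + \log \mathcal{N}(h;\, \mu, \Sigma)\big]\Big|_{h=m_i}\Big)^{-1}, \\
\text{M-step:}\quad
\mu &\leftarrow \frac{1}{N}\sum_{i=1}^{N} m_i, \\
\Sigma &\leftarrow \frac{1}{N}\sum_{i=1}^{N} \big(m_i m_i^\top + S_i\big) - \mu\mu^\top .
\end{aligned}
```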
Attributes: - name : str
Name of the model being fit. We suggest naming the model after its free parameters.
- loglik_func : function
The log-likelihood function to be used for model fitting.
- params : list
List of parameters from the rlparams module.
- nparams : int
Number of free parameters in the model.
- param_rng : list
List of strings denoting the parameter ranges (see the rlparams module for further details).
- prior : scipy.stats distribution
The prior distribution over parameter estimates. This is fixed to a multivariate normal distribution.
- mu : ndarray(shape=nparams)
The prior mean over parameters.
- cov : ndarray(shape=(nparams,nparams))
The covariance matrix of the prior over parameter estimates.
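The `prior`, `mu` and `cov` attributes fit together as sketched below. It assumes the prior is a frozen `scipy.stats.multivariate_normal` built from `mu` and `cov`. The two-parameter model and the candidate vector `x` are hypothetical.

```python
import numpy as np
from scipy.stats import multivariate_normal

nparams = 2                      # hypothetical two-parameter model
mu = np.zeros(nparams)           # prior mean, shape (nparams,)
cov = np.eye(nparams)            # prior covariance, shape (nparams, nparams)

# Frozen multivariate normal prior over (unconstrained) parameter vectors.
# After each group-level update, a new frozen prior would be built from the new mu and cov.
prior = multivariate_normal(mean=mu, cov=cov)

x = np.array([0.5, -1.0])        # candidate parameters for one subject
log_prior = prior.logpdf(x)      # log N(x; mu, cov)
```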
Methods
- fit(data, n_iterations=1000, c_limit=0.001, opt_algorithm='L-BFGS-B', init_grid=False, grid_reinit=True, n_grid_points=5, n_reinit=1, dofull=True, early_stopping=True, verbose=True)
Run the model-fitting algorithm.
- logposterior(x, states, actions, rewards)
Compute the log-posterior probability.
- group_level_estimate(param_est, hess_inv, dofull, verbose=True)
Update the hyperparameters of the group-level prior.
- initialize_opt(fn=None, grid=False, Ns=None)
Return initial values for the optimization.
- __printfitstart(n_iterations, c_limit, algorithm, init_grid, grid_reinit, dofull, early_stopping, verbose)
(Private) Print optimization info to the console.
- __printupdate(opt_iter, subject_i, posterior_ll, verbose)
(Private) Print an update on the fit iteration to the console.
fit(data, n_iterations=1000, c_limit=0.001, opt_algorithm='L-BFGS-B', init_grid=False, grid_reinit=True, n_grid_points=5, n_reinit=1, dofull=True, early_stopping=True, verbose=True)
Performs maximum a posteriori estimation of subject-level parameters.
Parameters: - data : dict
Dictionary of data from all subjects.
- n_iterations : int
Maximum number of iterations to allow.
- c_limit : float
Threshold at which convergence is determined.
- opt_algorithm : {'BFGS', 'L-BFGS-B'}
Algorithm to use for optimization.
- init_grid : bool
Whether to initialize the optimizer using a brute-force grid search. If False, initial values are sampled from a normal distribution with mean 0 and standard deviation 1.
- grid_reinit : bool
If optimization does not converge, whether to reinitialize the optimizer with values from a grid search.
- n_grid_points : int
Number of points along each axis to evaluate during grid-search initialization (only meaningful if init_grid is True).
- n_reinit : int
Number of times to reinitialize the optimizer if it does not converge.
- dofull : bool
Whether to update the full covariance matrix of the prior. If False, the covariance matrix is restricted to be diagonal (off-diagonal elements set to zero).
- early_stopping : bool
Whether to stop the EM procedure if the log model evidence begins to decrease, in which case the previous iteration's results are kept.
- verbose : bool
Whether to print the progress of model fitting.
Returns: - ModelFitResult
Representation of the model fitting results
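The sketch below shows the overall loop. It rests on these assumptions:

- It uses a generic per-subject log-likelihood `loglik(x, subject_data)` and a list of per-subject data, both hypothetical. The real `fit` takes a dict and a `loglik_func` over states, actions and rewards.
- It takes the optimizer's inverse-Hessian approximation as the Laplace covariance S_i, as the `hess_inv` parameter of `group_level_estimate` suggests.
- It omits grid initialization, reinitialization, early stopping and console output.
- It stops when the change in `mu` falls below `c_limit`. This criterion is an assumption, not necessarily what `fit` uses.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import multivariate_normal


def em_laplace(loglik, data, nparams, n_iterations=100, c_limit=1e-3,
               opt_algorithm="L-BFGS-B", dofull=True, seed=0):
    """Minimal EM with the Laplace approximation (a sketch, not fitr's code).

    loglik(x, subject_data) -> float is a hypothetical per-subject
    log-likelihood; data is a list with one entry per subject.
    """
    rng = np.random.default_rng(seed)
    nsubjects = len(data)
    # Broad initial group-level prior (an assumption of this sketch)
    mu, cov = np.zeros(nparams), 10.0 * np.eye(nparams)
    for _ in range(n_iterations):
        m = np.empty((nsubjects, nparams))
        S = np.empty((nparams, nparams, nsubjects))
        for i, d in enumerate(data):
            # E-step: MAP estimate of subject i's parameters under the current prior
            def neg_logpost(x):
                return -(loglik(x, d) + multivariate_normal.logpdf(x, mu, cov))
            res = minimize(neg_logpost, rng.normal(0.0, 1.0, nparams),
                           method=opt_algorithm)
            H_inv = res.hess_inv
            if hasattr(H_inv, "todense"):  # L-BFGS-B returns a LinearOperator
                H_inv = H_inv.todense()
            m[i], S[:, :, i] = res.x, H_inv
        # M-step: update the group-level mean and covariance
        mu_new = m.mean(axis=0)
        cov_new = (m.T @ m + S.sum(axis=2)) / nsubjects - np.outer(mu_new, mu_new)
        cov_new = 0.5 * (cov_new + cov_new.T)  # remove round-off asymmetry
        if not dofull:
            cov_new = np.diag(np.diag(cov_new))
        # Assumed convergence criterion: largest change in the group mean
        converged = np.max(np.abs(mu_new - mu)) < c_limit
        mu, cov = mu_new, cov_new
        if converged:
            break
    return m, mu, cov


# Usage on hypothetical toy data: each subject's observations are Gaussian
# around a subject-specific mean drawn from a group distribution.
rng = np.random.default_rng(1)
true_theta = rng.normal([1.0, -2.0], 0.5, size=(20, 2))
data = [t + rng.normal(0.0, 1.0, size=(30, 2)) for t in true_theta]
m, mu, cov = em_laplace(lambda x, y: -0.5 * np.sum((y - x) ** 2), data, nparams=2)
```

BFGS-family inverse Hessians are only approximations. An analytic or finite-difference Hessian evaluated at the MAP would give a more faithful Laplace covariance, at extra cost.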
group_level_estimate(param_est, hess_inv, dofull, verbose=True)
Updates the group-level hyperparameters.
Parameters: - param_est : ndarray(shape=(nsubjects, nparams))
Current parameter estimates for each subject
- hess_inv : ndarray(shape=(nparams, nparams, nsubjects))
Inverse Hessian matrix estimate for each subject from the iteration with the highest log-posterior probability.
- dofull : bool
Whether to update the full covariance matrix of the prior. If False, the covariance matrix is restricted to be diagonal (off-diagonal elements set to zero).
- verbose : bool
Whether to print results.
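A minimal sketch of the hyperparameter update, assuming it follows the standard Laplace-EM formulas of [Huys2011]. The mean is the average of the subject estimates m_i. The covariance is the average of m_i m_iᵀ + S_i minus mu muᵀ, with S_i taken from `hess_inv`. This standalone function is hypothetical: it omits `verbose` and may differ from the method's actual code. The argument shapes follow the parameter descriptions above.

```python
import numpy as np


def group_level_estimate(param_est, hess_inv, dofull=True):
    """Sketch of the M-step update (hypothetical standalone version).

    param_est : (nsubjects, nparams) MAP estimates
    hess_inv  : (nparams, nparams, nsubjects) inverse Hessians at the MAP
    """
    nsubjects = param_est.shape[0]
    mu = param_est.mean(axis=0)
    # Average of m_i m_i^T + S_i over subjects, minus mu mu^T
    second_moment = (param_est.T @ param_est + hess_inv.sum(axis=2)) / nsubjects
    cov = second_moment - np.outer(mu, mu)
    if not dofull:
        cov = np.diag(np.diag(cov))  # keep only the variances
    return mu, cov


# Two hypothetical subjects, each with inverse Hessian 0.5 * I
param_est = np.array([[0.0, 1.0], [2.0, 3.0]])
hess_inv = np.stack([0.5 * np.eye(2), 0.5 * np.eye(2)], axis=2)
mu, cov = group_level_estimate(param_est, hess_inv)
_, cov_diag = group_level_estimate(param_est, hess_inv, dofull=False)
```

In this example the covariance is the spread of the point estimates (variance 1, covariance 1) plus the average posterior uncertainty (0.5 on the diagonal).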
initialize_opt(fn=None, grid=False, Ns=None)
Returns initial values for the optimization.
Parameters: - fn : function
Function over which grid search takes place
- grid : bool
Whether to return initialization values from grid search
- Ns : int
Number of points per axis over which to evaluate during grid search
Returns: - x0 : ndarray
1 x N vector of initial values, one for each parameter.
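A minimal sketch of the two initialization modes as a hypothetical standalone function. `nparams` and `ranges` (a list of `(low, high)` tuples, one per parameter) are hypothetical arguments. They stand in for information the method presumably takes from the model, such as `param_rng`.

- Grid mode uses `scipy.optimize.brute`. Whether `initialize_opt` uses it internally is an assumption. Because `brute` minimizes, `fn` should be a negative log-posterior.
- Random mode follows the description of `init_grid`, which calls for a normal distribution with mean 0 and standard deviation 1.

The sketch returns a flat vector rather than a 1 x N array.

```python
import numpy as np
from scipy.optimize import brute


def initialize_opt(nparams, fn=None, grid=False, Ns=5, ranges=None, seed=None):
    """Return a starting vector for the optimizer (sketch).

    nparams and ranges are hypothetical arguments; ranges is a list of
    (low, high) tuples, one per parameter.
    """
    if grid:
        # Evaluate fn on Ns points per axis and return the best grid point;
        # finish=None stops brute from polishing it with a local optimizer.
        return brute(fn, ranges, Ns=Ns, finish=None)
    rng = np.random.default_rng(seed)
    return rng.normal(loc=0.0, scale=1.0, size=nparams)


def neg_logpost(x):
    # Hypothetical objective with its minimum at (0.5, -1.0)
    return (x[0] - 0.5) ** 2 + (x[1] + 1.0) ** 2


x0_grid = initialize_opt(2, fn=neg_logpost, grid=True, Ns=5, ranges=[(-1, 1), (-2, 2)])
x0_rand = initialize_opt(2, seed=0)
```

With five points per axis, the grids are -1, -0.5, 0, 0.5, 1 and -2, -1, 0, 1, 2. The toy minimum therefore lies exactly on a grid point.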
logposterior(x, states, actions, rewards)
Computes the log-posterior probability of the parameters x for a single subject.
- x : ndarray(nparams)
Array of parameters for a single subject.
- states : ndarray(shape=[ntrials, nsteps])
Array of states encountered by subject
- actions : ndarray(shape=[ntrials, nsteps])
Array of actions taken by subject
- rewards : ndarray(shape=[ntrials, nsteps])
Array of rewards received by the subject.
Returns: - float
Log-posterior probability
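The log-posterior is, up to a constant, the log-likelihood plus the log-prior. The sketch below shows this under stated assumptions:

- `rw_softmax_loglik` is a hypothetical two-armed bandit model (Rescorla-Wagner learning with softmax choice). It ignores states and uses only the first step of each trial.
- The prior is passed explicitly as `mu` and `cov`. The actual method takes only `x, states, actions, rewards` and presumably uses the model's own `loglik_func` and `prior`.

An optimizer would minimize the negative of this value.

```python
import numpy as np
from scipy.stats import multivariate_normal


def rw_softmax_loglik(x, states, actions, rewards):
    """Hypothetical log-likelihood for a two-armed bandit.

    x[0]: learning rate in unconstrained space (mapped through a logistic)
    x[1]: softmax inverse temperature
    Only the first step of each trial is used, and states are ignored.
    """
    lr = 1.0 / (1.0 + np.exp(-x[0]))
    beta = x[1]
    q = np.zeros(2)  # action values
    ll = 0.0
    for a, r in zip(actions[:, 0], rewards[:, 0]):
        logits = beta * q
        ll += logits[a] - np.logaddexp.reduce(logits)  # log softmax prob. of chosen action
        q[a] += lr * (r - q[a])  # Rescorla-Wagner update
    return ll


def logposterior(x, states, actions, rewards, mu, cov, loglik_func=rw_softmax_loglik):
    """Unnormalized log-posterior: log-likelihood plus multivariate normal log-prior."""
    return loglik_func(x, states, actions, rewards) + multivariate_normal.logpdf(x, mu, cov)


# Hypothetical data: 4 trials with 1 step each
states = np.zeros((4, 1), dtype=int)
actions = np.array([[0], [1], [0], [0]])
rewards = np.array([[1.0], [0.0], [1.0], [1.0]])
lp = logposterior(np.array([0.0, 2.0]), states, actions, rewards, mu=np.zeros(2), cov=np.eye(2))
```

With an inverse temperature of 0, every choice has probability 1/2, which makes a convenient sanity check for any log-likelihood of this form.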