Likelihood computation in ADDM.jl
In the previous tutorial we were not able to recover the true parameters used for the simulated data when using stateStep = 0.1. Reducing this to stateStep = 0.01 corrected the recovery. In this tutorial we will walk through how parameter estimation in ADDM.jl works to explain the effect of this change.
Brief overview of parameter estimation methods for sequential sampling models
How do we estimate parameters? We choose a measure to quantify the difference between observed/empirical data and data that would be generated
For sequential sampling models these can be ...[1]
A very common metric in all kinds of applications is likelihood
What is the likelihood in the context of sequential sampling models? It is the probability of observing the endorsed choice at the observed response time. This second part is what makes these models powerful. This is what we mean by the "joint" modeling of choice and response times.
There are a few ways of calculating this value for sequential sampling models:
- The analytical solution of the Wiener First Passage Time distribution
- Trialwise simulations
- Approximate Bayesian Computation
- Solving the Fokker Planck Equation
The likelihood functions in ADDM.jl use the last method.
Briefly, the FPE describes how a probability distribution changes over time. Since it is an expression of change, formally it is written as a partial differential equation. We'll skip the details of the math here but for an in depth dive, please see Shinn et al.
Here, we'll try to keep things intuitive.
[ADD Gabi's supplementary figure here]
Ok so what is the effect of the discretization step sizes (in both time and space)
julia> using ADDM, CSV, DataFrames, DataFramesMetajulia> using Plots, StatsPlots, Random, Plots.PlotMeasuresjulia> Random.seed!(38435)Random.TaskLocalRNG()julia> MyModel = ADDM.define_model(d = 0.007, σ = 0.03, θ = .6, barrier = 1, decay = 0, nonDecisionTime = 100, bias = 0.0)ADDM.aDDM(Dict{Symbol, Any}(:nonDecisionTime => 100, :σ => 0.03, :d => 0.007, :bias => 0.0, :barrier => 1, :decay => 0, :θ => 0.6, :η => 0.0))julia> data_path = "./data/""./data/"julia> data = ADDM.load_data_from_csv(data_path * "stimdata.csv", data_path * "fixations.csv"; stimsOnly = true);Error while reading experimental data file ./data/stimdata.csvjulia> nTrials = 1400;julia> MyStims = (valueLeft = reduce(vcat, [[i.valueLeft for i in data[j]] for j in keys(data)])[1:nTrials], valueRight = reduce(vcat, [[i.valueRight for i in data[j]] for j in keys(data)])[1:nTrials]);ERROR: MethodError: no method matching keys(::Nothing) Closest candidates are: keys(::Union{Tables.AbstractColumns, Tables.AbstractRow}) @ Tables ~/.julia/packages/Tables/NSGZI/src/Tables.jl:189 keys(::IndexStyle, ::AbstractArray, ::AbstractArray...) @ Base abstractarray.jl:397 keys(::Base.SkipMissing) @ Base missing.jl:264 ...julia> vDiffs = sort(unique([x.valueLeft - x.valueRight for x in data["1"]]));ERROR: MethodError: no method matching getindex(::Nothing, ::String)julia> MyFixationData = ADDM.process_fixations(data, fixDistType="fixation", valueDiffs = vDiffs);ERROR: UndefVarError: `vDiffs` not definedjulia> MyArgs = (timeStep = 10.0, cutOff = 20000, fixationData = MyFixationData);ERROR: UndefVarError: `MyFixationData` not definedjulia> SimData = ADDM.simulate_data(MyModel, MyStims, ADDM.aDDM_simulate_trial, MyArgs);ERROR: UndefVarError: `MyStims` not defined
We can look at a few things.
Save intermediate likelihoods for all trials with stepSize = .1 vs .01 for the correct and incorrect parameters
julia> param_grid = [(d = 0.007, sigma = 0.03, theta = 0.6), (d = 0.007, sigma = 0.05, theta = 0.6)];julia> output_large = ADDM.grid_search(SimData, param_grid, ADDM.aDDM_get_trial_likelihood, Dict(:η=>0.0, :barrier=>1, :decay=>0, :nonDecisionTime=>100, :bias=>0.0), likelihood_args = (timeStep = 10.0, stateStep = 0.1), save_intermediate_likelihoods = true , intermediate_likelihood_path="./outputs/", intermediate_likelihood_fn="large_stateStep_likelihoods");ERROR: UndefVarError: `SimData` not definedjulia> output_small = ADDM.grid_search(SimData, param_grid, ADDM.aDDM_get_trial_likelihood, Dict(:η=>0.0, :barrier=>1, :decay=>0, :nonDecisionTime=>100, :bias=>0.0), likelihood_args = (timeStep = 10.0, stateStep = 0.01), save_intermediate_likelihoods = true, intermediate_likelihood_path="./outputs/", intermediate_likelihood_fn="small_stateStep_likelihoods");ERROR: UndefVarError: `SimData` not defined
julia> fns = ["large", "small"];julia> trial_likelihoods_for_sigmas = DataFrame();julia> for fn in fns trial_likelihoods = DataFrame(CSV.File("./outputs/"* fn *"_stateStep_likelihoods.csv", delim=",")) cur_tlfs = unstack(trial_likelihoods, :trial_num, :sigma, :likelihood) cur_tlfs[!, :stateStep] .= fn * " stateStep" trial_likelihoods_for_sigmas = vcat(trial_likelihoods_for_sigmas, cur_tlfs) endERROR: ArgumentError: "./outputs/large_stateStep_likelihoods.csv" is not a valid file or doesn't existjulia> rename!(trial_likelihoods_for_sigmas, [Symbol(0.05), Symbol(0.03)] .=> [:incorrect_sigma, :correct_sigma])ERROR: ArgumentError: Tried renaming :0.05 to :incorrect_sigma, when data frame has no columns.julia> ax_lims = (minimum(vcat(trial_likelihoods_for_sigmas.incorrect_sigma, trial_likelihoods_for_sigmas.correct_sigma)), maximum(vcat(trial_likelihoods_for_sigmas.incorrect_sigma, trial_likelihoods_for_sigmas.correct_sigma)))ERROR: ArgumentError: column name "incorrect_sigma" not found in the data frame since it has no columnsjulia> @df trial_likelihoods_for_sigmas scatter(:correct_sigma, :incorrect_sigma, xlabel = "Likelihoods for true parameters", ylabel = "Likelihoods for incorrect parameters", lim = ax_lims, group = :stateStep, m = (0.5, [:x :+], 4))ERROR: UndefVarError: `ax_lims` not definedjulia> Plots.abline!(1, 0, line=:dash, color=:black, label="")ERROR: No current plot/subplot
Pick a few trials where the likelihoods differ a lot between the correct and incorrect parameters. Use the debug option in the aDDM_get_trial_likelihood to plot the propogation of the probability distribution across timeSteps
julia> # make new column for the difference in likelihoods for correct vs incorrect sigma @transform!(trial_likelihoods_for_sigmas, :diff_likelihood = :incorrect_sigma - :correct_sigma)ERROR: ArgumentError: column name "incorrect_sigma" not found in the data frame since it has no columnsjulia> # order by that difference column @orderby(trial_likelihoods_for_sigmas, -:diff_likelihood)ERROR: ArgumentError: column name "diff_likelihood" not found in the data frame since it has no columnsjulia> # Pick top 4 trials (or maybe just one) diff_trial_nums = [@orderby(trial_likelihoods_for_sigmas, -:diff_likelihood)[1,:trial_num]];ERROR: ArgumentError: column name "diff_likelihood" not found in the data frame since it has no columnsjulia> # extract these from the data diff_trials = SimData[diff_trial_nums];ERROR: UndefVarError: `SimData` not defined
Plot probStates for each trial with small vs large stateStep for correct and incorrect model
julia> # 2 x 2 plot # Rows are stepsize # Cols are models # Point is to show that likelihood value changes depending on stepsize # Colors must match across the four plots # Need a legend common to all # Why are the prStates plots with small stepsize so dark? # Because the values in each bin are very small. # They values in each bin are small because they are spread over 10 times as many bins. correct_model = MyModelADDM.aDDM(Dict{Symbol, Any}(:nonDecisionTime => 100, :σ => 0.03, :d => 0.007, :bias => 0.0, :barrier => 1, :decay => 0, :θ => 0.6, :η => 0.0))julia> incorrect_model = ADDM.define_model(d = 0.007, σ = 0.05, θ = .6, barrier = 1, decay = 0, nonDecisionTime = 100, bias = 0.0)ADDM.aDDM(Dict{Symbol, Any}(:nonDecisionTime => 100, :σ => 0.05, :d => 0.007, :bias => 0.0, :barrier => 1, :decay => 0, :θ => 0.6, :η => 0.0))julia> # Use aDDM_get_trial_likelihood with debug = true to get probStates and probUp and _, prStates_cm_ls, probUpCrossing_cm_ls, probDownCrossing_cm_ls = ADDM.aDDM_get_trial_likelihood(;model = correct_model, trial = diff_trials[1], timeStep = 10.0, stateStep = 0.1, debug = true)ERROR: UndefVarError: `diff_trials` not definedjulia> _, prStates_cm_ss, probUpCrossing_cm_ss, probDownCrossing_cm_ss = ADDM.aDDM_get_trial_likelihood(;model = correct_model, trial = diff_trials[1], timeStep = 10.0, stateStep = 0.01, debug = true)ERROR: UndefVarError: `diff_trials` not definedjulia> _, prStates_im_ls, probUpCrossing_im_ls, probDownCrossing_im_ls = ADDM.aDDM_get_trial_likelihood(;model = incorrect_model, trial = diff_trials[1], timeStep = 10.0, stateStep = 0.1, debug = true)ERROR: UndefVarError: `diff_trials` not definedjulia> _, prStates_im_ss, probUpCrossing_im_ss, probDownCrossing_im_ss = ADDM.aDDM_get_trial_likelihood(;model = incorrect_model, trial = diff_trials[1], timeStep = 10.0, stateStep = 0.01, debug = true)ERROR: UndefVarError: `diff_trials` not definedjulia> likMax = maximum(vcat(probUpCrossing_cm_ls, probDownCrossing_cm_ls, probUpCrossing_cm_ss, probDownCrossing_cm_ss, probUpCrossing_im_ls, probDownCrossing_im_ls, probUpCrossing_im_ss, probDownCrossing_im_ss))ERROR: UndefVarError: `probUpCrossing_cm_ls` not definedjulia> likelihoodLims = (0, likMax);ERROR: UndefVarError: `likMax` not definedjulia> prStateLims = (0, 0.05);julia> p1 = state_space_plot(prStates_cm_ls, probUpCrossing_cm_ls, probDownCrossing_cm_ls, 10, 0.1, likelihoodLims, prStateLims);ERROR: UndefVarError: `state_space_plot` not definedjulia> p2 = state_space_plot(prStates_cm_ss, probUpCrossing_cm_ss, probDownCrossing_cm_ss, 10, 0.01, likelihoodLims, prStateLims);ERROR: UndefVarError: `state_space_plot` not definedjulia> p3 = state_space_plot(prStates_im_ls, probUpCrossing_im_ls, probDownCrossing_im_ls, 10, 0.1, likelihoodLims, prStateLims);ERROR: UndefVarError: `state_space_plot` not definedjulia> p4 = state_space_plot(prStates_im_ss, probUpCrossing_im_ss, probDownCrossing_im_ss, 10, 0.01, likelihoodLims, prStateLims);ERROR: UndefVarError: `state_space_plot` not definedjulia> plot_array = Any[];julia> push!(plot_array, p1);ERROR: UndefVarError: `p1` not definedjulia> push!(plot_array, p2);ERROR: UndefVarError: `p2` not definedjulia> push!(plot_array, p3);ERROR: UndefVarError: `p3` not definedjulia> push!(plot_array, p4);ERROR: UndefVarError: `p4` not definedjulia> plot(plot_array...)Plot{Plots.GRBackend() n=0}
- 1For a more detailed overview see Shinn, M., Lam, N. H., & Murray, J. D. (2020). A flexible framework for simulating and fitting generalized drift-diffusion models. ELife, 9, e56938.