Skip to content

Commit ba600a7

Browse files
author
Frankie Robertson
committed
Add all sorts of mean and dispersion estimation stuff
* Delete a lot of comments * Remove vestigial dependency on RDataGet * Use derivatives stuff from PsychometricsBazaarBase
1 parent 69a8e66 commit ba600a7

12 files changed

Lines changed: 309 additions & 342 deletions

File tree

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
2828
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2929
PsychometricsBazaarBase = "b0d9cada-d963-45e9-a4c6-4746243987f1"
3030
QuickHeaps = "30b38841-0f52-47f8-a5f8-18d5d4064379"
31-
RDataGet = "a115732e-4334-4ecb-8ea3-f683e7f66d4d"
3231
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3332
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3433
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -68,7 +67,6 @@ PrettyPrinting = "0.4.2"
6867
PrettyTables = "3"
6968
PsychometricsBazaarBase = "^0.8.6"
7069
QuickHeaps = "0.2.2"
71-
RDataGet = "0.1.0"
7270
Random = "^1.11"
7371
Reexport = "1"
7472
Setfield = "^1"

src/Aggregators/ability_estimator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function pdf(::LikelihoodAbilityEstimator,
2626
AbilityLikelihood(tracked_responses)
2727
end
2828

29-
function power_summary(io::IO, ability_estimator::LikelihoodAbilityEstimator)
29+
function power_summary(io::IO, ::LikelihoodAbilityEstimator)
3030
println(io, "Ability likelihood distribution")
3131
end
3232

src/ComputerAdaptiveTesting.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include("./NextItemRules/NextItemRules.jl")
3030
include("./TerminationConditions.jl")
3131

3232
# Combining / running
33+
include("./DerivedMeasures.jl")
3334
include("./Rules.jl")
3435
include("./Sim/Sim.jl")
3536
include("./DecisionTree/DecisionTree.jl")
@@ -53,4 +54,4 @@ function require_testext()
5354
return TestExt
5455
end
5556

56-
end
57+
end

src/DerivedMeasures.jl

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
module DerivedMeasures
2+
3+
import PsychometricsBazaarBase: power_summary, GridSummary
4+
using ..Aggregators: TrackedResponses,
5+
Aggregators,
6+
AbilityIntegrator,
7+
AbilityOptimizer,
8+
AbilityEstimator,
9+
ModeAbilityEstimator,
10+
MeanAbilityEstimator,
11+
LikelihoodAbilityEstimator,
12+
DistributionAbilityEstimator,
13+
get_integrator,
14+
expectation
15+
using FittedItemBanks: domdims
16+
using ..NextItemRules: AbilityVariance, compute_criteria, compute_criterion, best_item
17+
using PsychometricsBazaarBase.Integrators: AnyGridIntegrator, get_grid, normdenom
18+
using PsychometricsBazaarBase.IndentWrappers: indent
19+
using PsychometricsBazaarBase: IntegralCoeffs
20+
using PsychometricsBazaarBase: Differentiation
21+
22+
export PointAndSpreadEstimator, MeanAndStdDevEstimator, LaplaceApproxEstimator, SpreadEstimator
23+
24+
abstract type PointAndSpreadEstimator end
25+
26+
# TODO: These all recalculate everything at the moment, but they should reuse the results generated during the CAT
27+
28+
struct MeanAndStdDevEstimator{
29+
DistEstT <: DistributionAbilityEstimator,
30+
IntegratorT <: AbilityIntegrator
31+
} <: PointAndSpreadEstimator
32+
dist_est::DistEstT
33+
integrator::IntegratorT
34+
end
35+
36+
MeanAndStdDevEstimator(ability_estimator::MeanAbilityEstimator) = MeanAndStdDevEstimator(ability_estimator.dist_est, ability_estimator.integrator)
37+
MeanAndStdDevEstimator(ability_variance::AbilityVariance) = MeanAndStdDevEstimator(ability_variance.dist_est, ability_variance.integrator)
38+
39+
function (est::MeanAndStdDevEstimator)(tracked_responses::TrackedResponses)
40+
denom = normdenom(est.integrator,
41+
est.dist_est,
42+
tracked_responses)
43+
mean = expectation(IntegralCoeffs.id,
44+
domdims(tracked_responses.item_bank),
45+
est.integrator,
46+
est.dist_est,
47+
tracked_responses,
48+
denom)
49+
return (
50+
mean,
51+
sqrt(expectation(IntegralCoeffs.SqDev(mean),
52+
domdims(tracked_responses.item_bank),
53+
est.integrator,
54+
est.dist_est,
55+
tracked_responses,
56+
denom))
57+
)
58+
end
59+
60+
function power_summary(io::IO, est::MeanAndStdDevEstimator)
61+
println(io, "Mean and standard deviation estimator")
62+
indent_io = indent(io, 2)
63+
power_summary(indent_io, est.dist_est)
64+
power_summary(indent_io, est.integrator)
65+
end
66+
67+
show(io::IO, ::MIME"text/plain", est::MeanAndStdDevEstimator) = power_summary(io, est)
68+
69+
struct LaplaceApproxEstimator{
70+
DistEstT <: DistributionAbilityEstimator,
71+
OptimizerT <: AbilityOptimizer
72+
} <: PointAndSpreadEstimator
73+
dist_est::DistEstT
74+
optimizer::OptimizerT
75+
end
76+
77+
LaplaceApproxEstimator(ability_estimator::ModeAbilityEstimator) = LaplaceApproxEstimator(ability_estimator.dist_est, ability_estimator.optim)
78+
79+
function (est::LaplaceApproxEstimator)(tracked_responses::TrackedResponses)
80+
# TODO: Numerical stability: Should directly access the log-pdf here
81+
mode = est.optimizer(IntegralCoeffs.one, est.dist_est, tracked_responses)
82+
return (
83+
mode,
84+
-Differentiation.double_derivative((ability -> log(pdf(est, tracked_responses))), mode)
85+
)
86+
end
87+
88+
function power_summary(io::IO, est::LaplaceApproxEstimator)
89+
println(io, "Laplace approximation based mean and standard deviation estimator")
90+
indent_io = indent(io, 2)
91+
power_summary(indent_io, est.dist_est)
92+
power_summary(indent_io, est.optimizer)
93+
end
94+
95+
struct SpreadEstimator{InnerT <: PointAndSpreadEstimator}
96+
inner::InnerT
97+
end
98+
99+
function (est::SpreadEstimator)(tracked_responses::TrackedResponses)
100+
_, stddev = est.inner(tracked_responses)
101+
return stddev
102+
end
103+
104+
struct DistributionSampler{
105+
DistEst <: DistributionAbilityEstimator,
106+
IntegratorT <: AbilityIntegrator,
107+
ContainerT <: Union{Vector{Float64}, Vector{Vector{Float64}}}
108+
}
109+
dist_est::DistEst
110+
integrator::IntegratorT
111+
points::ContainerT
112+
end
113+
114+
_get_estimator_and_integrator(ability_estimator::MeanAbilityEstimator) = (ability_estimator.dist_est, ability_estimator.integrator)
115+
_get_estimator_and_integrator(ability_variance::AbilityVariance) = (ability_variance.dist_est, ability_variance.integrator)
116+
117+
function DistributionSampler(composite::Union{MeanAbilityEstimator, AbilityVariance}, points=nothing)
118+
dist_est, integrator = _get_estimator_and_integrator(composite)
119+
return DistributionSampler(dist_est, integrator, points)
120+
end
121+
122+
function DistributionSampler(dist_est::DistributionAbilityEstimator, integrator::Union{AbilityIntegrator, Nothing}=nothing, points::Nothing=nothing)
123+
@info "DistributionSampler" dist_est integrator points
124+
if isnothing(integrator)
125+
return nothing
126+
end
127+
inner_integrator = get_integrator(integrator)
128+
if !isnothing(points)
129+
return DistributionSampler(dist_est, integrator, points)
130+
elseif inner_integrator isa AnyGridIntegrator
131+
return DistributionSampler(dist_est, integrator, get_grid(inner_integrator))
132+
else
133+
return nothing
134+
end
135+
end
136+
137+
function eachmatcol(xs::AbstractMatrix)
138+
eachcol(xs)
139+
end
140+
141+
function eachmatcol(xs::AbstractVector)
142+
xs
143+
end
144+
145+
function (est::DistributionSampler)(tracked_responses::TrackedResponses)
146+
num = Aggregators.pdf.(
147+
est.dist_est,
148+
tracked_responses,
149+
eachmatcol(est.points)
150+
)
151+
denom = normdenom(est.integrator, est.dist_est, tracked_responses)
152+
return num ./ denom
153+
end
154+
155+
function power_summary(io::IO, est::DistributionSampler)
156+
println(io, "Distribution sampler")
157+
indent_io = indent(io, 2)
158+
power_summary(indent_io, est.dist_est)
159+
power_summary(indent_io, est.integrator)
160+
power_summary(indent_io, GridSummary(est.points))
161+
end
162+
163+
end

src/NextItemRules/NextItemRules.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome,
2222
find1_instance, find1_type
2323
using PsychometricsBazaarBase.Integrators: Integrator, intval
2424
using PsychometricsBazaarBase: Integrators
25+
using PsychometricsBazaarBase: Differentiation
2526
using PsychometricsBazaarBase.IndentWrappers: indent
2627
import PsychometricsBazaarBase.IntegralCoeffs
2728
using FittedItemBanks: AbstractItemBank, DiscreteDomain, DomainType,
@@ -59,6 +60,9 @@ export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCrite
5960
export InformationMatrixCriteria
6061
export ScalarizedStateCriterion, ScalarizedItemCriterion
6162
export DRuleItemCriterion, TRuleItemCriterion
63+
export ObservedInformationPointwiseItemCriterion
64+
export RawEmpiricalInformationPointwiseItemCriterion
65+
export EmpiricalInformationPointwiseItemCriterion
6266

6367
# Prelude
6468
include("./prelude/abstract.jl")

src/NextItemRules/criteria/pointwise/information.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function compute_criterion(
99
ability,
1010
category
1111
)
12-
actual = -double_derivative((ability -> log_resp(ir, category, ability)), ability) .* resp(ir, category, ability)
12+
actual = -Differentiation.double_derivative((ability -> log_resp(ir, category, ability)), ability) .* resp(ir, category, ability)
1313
-actual
1414
end
1515

@@ -18,7 +18,7 @@ function compute_criterion_vec(
1818
ir::ItemResponse,
1919
ability
2020
)
21-
actual = -double_derivative((ability -> log_resp_vec(ir, ability)), ability) .* resp_vec(ir, ability)
21+
actual = -Differentiation.double_derivative((ability -> log_resp_vec(ir, ability)), ability) .* resp_vec(ir, ability)
2222
-actual
2323
end
2424

@@ -87,7 +87,7 @@ function compute_criterion(
8787
ir,
8888
ability,
8989
category
90-
) .- double_derivative((ability -> resp(ir, category, ability)), ability)
90+
) .- Differentiation.double_derivative((ability -> resp(ir, category, ability)), ability)
9191
-actual
9292
end
9393

@@ -100,7 +100,7 @@ function compute_criterion_vec(
100100
RawEmpiricalInformationPointwiseItemCategoryCriterion(),
101101
ir,
102102
ability
103-
) .- double_derivative((ability -> resp_vec(ir, ability)), ability)
103+
) .- Differentiation.double_derivative((ability -> resp_vec(ir, ability)), ability)
104104
-actual
105105
end
106106

src/NextItemRules/criteria/pointwise/information_support.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,9 @@ log_resp(ir::ItemResponse{<:GuessAndSlipItemBank}, response, θ) = log(resp(ir,
7070
log_resp(ir::ItemResponse{<:GuessAndSlipItemBank}, θ) = log(resp(ir, θ))
7171
log_resp_vec(ir::ItemResponse{<:GuessAndSlipItemBank}, θ) = log.(resp_vec(ir, θ))
7272

73-
function vector_hessian(f, x, n)
74-
out = ForwardDiff.jacobian(x -> ForwardDiff.jacobian(f, x), x)
75-
return reshape(out, n, n, n)
76-
end
77-
78-
function double_derivative(f, x)
79-
ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x)
80-
end
81-
8273
function expected_item_information(ir::ItemResponse, θ::Number)
8374
exp_resp = resp_vec(ir, θ)
84-
= double_derivative((θ -> log_resp_vec(ir, θ)), θ)
75+
= Differentiation.double_derivative((θ -> log_resp_vec(ir, θ)), θ)
8576
-sum(exp_resp .* d²)
8677
end
8778

@@ -90,7 +81,7 @@ end
9081
function expected_item_information(ir::ItemResponse, θ::Vector)
9182
exp_resp = resp_vec(ir, θ)
9283
n = domdims(ir.item_bank)
93-
hess = vector_hessian-> log_resp_vec(ir, θ), θ, n)
84+
hess = Differentiation.vector_hessian-> log_resp_vec(ir, θ), θ, n)
9485
return -sum(eachslice(hess, dims=1) .* exp_resp)
9586
end
9687

@@ -118,4 +109,4 @@ end
118109

119110
function log_resp(ir::ItemResponse{<:ItemBanks.LogItemBank}, resp, θ)
120111
log(resp(ItemBanks.inner_ir(ir), resp, θ))
121-
end
112+
end

src/NextItemRules/criteria/state/ability_variance.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ function compute_criterion(
5555
)::Float64
5656
# XXX: Not quite sure about this --- is it useful, the MIRT rules cover this case
5757
mean = expectation(IntegralCoeffs.id,
58-
ndims(tracked_responses.item_bank),
58+
domdims(tracked_responses.item_bank),
5959
criterion.integrator,
6060
criterion.dist_est,
6161
tracked_responses,
6262
denom)
6363
expectation(IntegralCoeffs.SqDev(mean),
64-
ndims(tracked_responses.item_bank),
64+
domdims(tracked_responses.item_bank),
6565
criterion.integrator,
6666
criterion.dist_est,
6767
tracked_responses,

src/NextItemRules/porcelain/porcelain.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,15 @@ function TRuleItemCriterion(ability_estimator)
99
InformationMatrixCriteria(ability_estimator),
1010
TraceScalarizer())
1111
end
12+
13+
function ObservedInformationPointwiseItemCriterion()
14+
TotalItemInformation(ObservedInformationPointwiseItemCategoryCriterion())
15+
end
16+
17+
function RawEmpiricalInformationPointwiseItemCriterion()
18+
TotalItemInformation(RawEmpiricalInformationPointwiseItemCategoryCriterion())
19+
end
20+
21+
function EmpiricalInformationPointwiseItemCriterion()
22+
TotalItemInformation(EmpiricalInformationPointwiseItemCategoryCriterion())
23+
end

src/Sim/Sim.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using DocStringExtensions
77
using StatsBase
88
using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, domdims
99
using PsychometricsBazaarBase: show_into_buf, power_summary_into_buf
10+
using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome
1011
using PsychometricsBazaarBase.Integrators
1112
using PsychometricsBazaarBase.IndentWrappers: indent
1213
using PsychometricsBazaarBase: GridSummary
@@ -25,6 +26,7 @@ using ..Aggregators: TrackedResponses,
2526
MeanAbilityEstimator,
2627
LikelihoodAbilityEstimator,
2728
RiemannEnumerationIntegrator
29+
using ..DerivedMeasures: MeanAndStdDevEstimator, LaplaceApproxEstimator, DistributionSampler
2830
using ..NextItemRules: AbilityVariance, compute_criteria, compute_criterion, best_item
2931
import Base: show
3032
import PsychometricsBazaarBase: power_summary

0 commit comments

Comments
 (0)