From 75947f2e4087fb6cafa9de153c2caaba359db75a Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 17 Apr 2025 09:37:18 +0300 Subject: [PATCH 1/3] Move add_response!/pop_response! to Responses --- src/Responses.jl | 23 +++++++++++++++++++++++ src/Stateful.jl | 6 +++--- src/aggregators/Aggregators.jl | 5 ++--- src/aggregators/ability_tracker.jl | 24 +++--------------------- src/decision_tree/DecisionTree.jl | 2 +- 5 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/Responses.jl b/src/Responses.jl index 8a7560f..086c513 100644 --- a/src/Responses.jl +++ b/src/Responses.jl @@ -7,6 +7,7 @@ using FittedItemBanks: AbstractItemBank, using AutoHashEquals: @auto_hash_equals export Response, BareResponses, AbilityLikelihood, function_xs, function_ys +export add_response!, pop_response! concrete_response_type(::BooleanResponse) = Bool concrete_response_type(::MultinomialResponse) = Int @@ -69,6 +70,28 @@ function Base.iterate(::BareResponses, gen_gen_state) return _iter_helper(gen, iterate(gen, gen_state)) end +function Base.empty!(responses::BareResponses) + Base.empty!(responses.indices) + Base.empty!(responses.values) +end + +function add_response!(responses::BareResponses, response::Response)::BareResponses + push!(responses.indices, response.index) + push!(responses.values, response.value) + responses +end + +function pop_response!(responses::BareResponses)::BareResponses + pop!(responses.indices) + pop!(responses.values) + responses +end + +function Base.sizehint!(bare_responses::BareResponses, n) + sizehint!(bare_responses.indices, n) + sizehint!(bare_responses.values, n) +end + struct AbilityLikelihood{ItemBankT <: AbstractItemBank, BareResponsesT <: BareResponses} item_bank::ItemBankT responses::BareResponsesT diff --git a/src/Stateful.jl b/src/Stateful.jl index b47f3d0..b506088 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -10,7 +10,7 @@ using DocStringExtensions using FittedItemBanks: AbstractItemBank, ResponseType using ..Aggregators: TrackedResponses, Aggregators using ..CatConfig: CatLoopConfig, CatRules -using ..Responses: BareResponses, Response +using ..Responses: BareResponses, Response, Responses using ..NextItemRules: compute_criteria, best_item using ..Sim: Sim, item_label @@ -190,13 +190,13 @@ end function add_response!(config::StatefulCatConfig, index, response) tracked_responses = config.tracked_responses[] - Aggregators.add_response!( + Responses.add_response!( tracked_responses, Response( ResponseType(tracked_responses.item_bank), index, response)) end function rollback!(config::StatefulCatConfig) - pop_response!(config.tracked_responses[]) + Responses.pop_response!(config.tracked_responses[]) end function reset!(config::StatefulCatConfig) diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index 5f28725..b73e188 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -17,7 +17,7 @@ using FittedItemBanks: AbstractItemBank, ContinuousDomain, PointsItemBank, ResponseType, VectorContinuousDomain, domdims, item_params, resp, resp_vec, responses using ..Responses -using ..Responses: concrete_response_type, function_xs, function_ys +using ..Responses: concrete_response_type, function_xs, function_ys, Responses using ..ConfigBase using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type, @@ -37,8 +37,7 @@ import PsychometricsBazaarBase.IntegralCoeffs export AbilityEstimator, TrackedResponses export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker export ClosedFormNormalAbilityTracker, track! -export response_expectation, - add_response!, pop_response!, expectation, distribution_estimator +export response_expectation, expectation, distribution_estimator export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator export ModeAbilityEstimator, MeanAbilityEstimator export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate diff --git a/src/aggregators/ability_tracker.jl b/src/aggregators/ability_tracker.jl index 6605bbf..8958702 100644 --- a/src/aggregators/ability_tracker.jl +++ b/src/aggregators/ability_tracker.jl @@ -1,37 +1,19 @@ -function sizehint!(bare_responses::BareResponses, n) - sizehint!(bare_responses.indices, n) - sizehint!(bare_responses.values, n) -end - function track!(responses) track!(responses, responses.ability_tracker) end -function add_response!(responses::BareResponses, response::Response)::BareResponses - push!(responses.indices, response.index) - push!(responses.values, response.value) - responses -end - -function add_response!(tracked_responses::TrackedResponses, response::Response) +function Responses.add_response!(tracked_responses::TrackedResponses, response::Response) add_response!(tracked_responses.responses, response) track!(tracked_responses) end -function pop_response!(responses::BareResponses)::BareResponses - pop!(responses.indices) - pop!(responses.values) - responses -end - -function pop_response!(tracked_responses::TrackedResponses)::TrackedResponses +function Responses.pop_response!(tracked_responses::TrackedResponses)::TrackedResponses pop_response!(tracked_responses.responses) tracked_responses end function Base.empty!(tracked_responses::TrackedResponses) - Base.empty!(tracked_responses.responses.indices) - Base.empty!(tracked_responses.responses.values) + Base.empty!(tracked_responses.responses) end function response_expectation(ability_estimator::DistributionAbilityEstimator, diff --git a/src/decision_tree/DecisionTree.jl b/src/decision_tree/DecisionTree.jl index 6fe3ffb..b42ad58 100644 --- a/src/decision_tree/DecisionTree.jl +++ b/src/decision_tree/DecisionTree.jl @@ -6,7 +6,7 @@ using ComputerAdaptiveTesting.ConfigBase: CatConfigBase using ComputerAdaptiveTesting.PushVectors using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.Aggregators -using ComputerAdaptiveTesting.Responses: BareResponses, Response +using ComputerAdaptiveTesting.Responses: BareResponses, Response, add_response!, pop_response! using FittedItemBanks: AbstractItemBank, BooleanResponse, ResponseType # TODO: Remove ability tracking from here? From 73b933d8e1866427d59ce0321baebcdab6fbd13f Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 17 Apr 2025 09:38:47 +0300 Subject: [PATCH 2/3] Add TestExt for stateful interface test --- Project.toml | 6 +++ ext/TestExt.jl | 94 ++++++++++++++++++++++++++++++++++ src/ComputerAdaptiveTesting.jl | 11 ++++ test/stateful.jl | 24 ++++++++- 4 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 ext/TestExt.jl diff --git a/Project.toml b/Project.toml index d905178..24ec8c9 100644 --- a/Project.toml +++ b/Project.toml @@ -31,6 +31,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +[weakdeps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[extensions] +TestExt = "Test" + [compat] Accessors = "^0.1.12" Aqua = "0.8" diff --git a/ext/TestExt.jl b/ext/TestExt.jl new file mode 100644 index 0000000..3722c74 --- /dev/null +++ b/ext/TestExt.jl @@ -0,0 +1,94 @@ +module TestExt + +using Test +using ComputerAdaptiveTesting: Stateful + +export test_stateful_cat_1d_dich_ib + +function test_stateful_cat_1d_dich_ib( + cat::Stateful.StatefulCat, + item_bank_length; + supports_ranked_and_criteria = true, + supports_rollback = true + ) + if item_bank_length < 3 + error("Item bank length must be at least 3.") + end + @testset "response round trip" begin + responses_before = Stateful.get_responses(cat) + @test length(responses_before.indices) == 0 + @test length(responses_before.values) == 0 + + Stateful.add_response!(cat, 1, false) + Stateful.add_response!(cat, 2, true) + + responses_after_add = Stateful.get_responses(cat) + @test length(responses_after_add.indices) == 2 + @test length(responses_after_add.values) == 2 + + Stateful.reset!(cat) + responses_after_reset = Stateful.get_responses(cat) + @test length(responses_after_reset.indices) == 0 + @test length(responses_after_reset.values) == 0 + end + + # Test the next_item function + @testset "basic next_item tests" begin + Stateful.add_response!(cat, 1, false) + Stateful.add_response!(cat, 2, true) + + item = Stateful.next_item(cat) + @test isa(item, Integer) + @test item >= 1 + @test item >= 3 + @test item <= item_bank_length + end + + if supports_ranked_and_criteria + @testset "basic ranked/criteria tests" begin + items = Stateful.ranked_items(cat) + @test length(items) == item_bank_length + + criteria = Stateful.item_criteria(cat) + @test length(criteria) == item_bank_length + end + end + + if supports_rollback + @testset "basic rollback tests" begin + Stateful.reset!(cat) + Stateful.add_response!(cat, 1, false) + Stateful.add_response!(cat, 2, true) + Stateful.rollback!(cat) + responses_after_rollback = Stateful.get_responses(cat) + @test length(responses_after_rollback.indices) == 1 + @test length(responses_after_rollback.values) == 1 + end + end + + Stateful.reset!(cat) + + @testset "basic get_ability tests" begin + Stateful.add_response!(cat, 1, false) + Stateful.add_response!(cat, 2, true) + ability = Stateful.get_ability(cat) + @test isa(ability, Tuple) + @test length(ability) == 2 + @test isa(ability[1], Float64) + end + + if supports_rollback + @testset "rollback ability tests" begin + Stateful.add_response!(cat, 1, false) + ability1 = Stateful.get_ability(cat) + Stateful.add_response!(cat, 2, true) + ability2 = Stateful.get_ability(cat) + Stateful.rollback!(cat) + @test Stateful.get_ability(cat) == ability1 + Stateful.add_response!(cat, 2, true) + @test Stateful.get_ability(cat) == ability2 + end + end +end + +end \ No newline at end of file diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index be76ee9..328a71c 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -10,6 +10,9 @@ export NextItemRules, TerminationConditions export CatConfig, Sim, DecisionTree export Stateful, Comparison +# Extension modules +public require_testext + # Vendored dependencies include("./vendor/PushVectors.jl") @@ -44,4 +47,12 @@ include("./Comparison.jl") include("./precompiles.jl") +function require_testext() + TestExt = Base.get_extension(@__MODULE__, :TestExt) + if TestExt === nothing + error("Failed to load extension module TestExt.") + end + return TestExt end + +end \ No newline at end of file diff --git a/test/stateful.jl b/test/stateful.jl index 131e084..f8419a3 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -6,6 +6,7 @@ using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule using ComputerAdaptiveTesting: Stateful + using ComputerAdaptiveTesting: require_testext using ResumableFunctions using Test: @test, @testset @@ -26,7 +27,7 @@ @testset "StatefulCatConfig basic usage" begin rules = CatRules( FixedItemsTerminationCondition(2), - Dummy.DummyAbilityEstimator(0), + Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) @@ -54,7 +55,7 @@ @testset "Stateful next item selection" begin rules = CatRules( FixedItemsTerminationCondition(2), - Dummy.DummyAbilityEstimator(0), + Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) cat_config = Stateful.StatefulCatConfig(rules, item_bank) @@ -69,4 +70,23 @@ @test 1 <= second_item <= 4 @test second_item != first_item # Should select different item end + + @testset "Standard interface tests" begin + rules = CatRules( + FixedItemsTerminationCondition(2), + Dummy.DummyAbilityEstimator(0.0), + RandomNextItemRule() + ) + + # Initialize config + cat_config = Stateful.StatefulCatConfig(rules, item_bank) + + # Run the standard interface tests + TestExt = require_testext() + TestExt.test_stateful_cat_1d_dich_ib( + cat_config, + 4; + supports_ranked_and_criteria = false, + ) + end end From 0bb3741b642af2edda6f6c3c7235f95a81b9540d Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 18 Apr 2025 08:07:30 +0300 Subject: [PATCH 3/3] Add item bank funcs to stateful + interface test --- ext/TestExt.jl | 25 ++++++++++++++++++++++--- src/Stateful.jl | 31 ++++++++++++++++++++++++++++++- test/stateful.jl | 1 + 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/ext/TestExt.jl b/ext/TestExt.jl index 3722c74..c305dfc 100644 --- a/ext/TestExt.jl +++ b/ext/TestExt.jl @@ -2,8 +2,9 @@ module TestExt using Test using ComputerAdaptiveTesting: Stateful +using FittedItemBanks: AbstractItemBank, ItemResponse, resp -export test_stateful_cat_1d_dich_ib +export test_stateful_cat_1d_dich_ib, test_stateful_cat_item_bank_1d_dich_ib function test_stateful_cat_1d_dich_ib( cat::Stateful.StatefulCat, @@ -66,9 +67,8 @@ function test_stateful_cat_1d_dich_ib( end end - Stateful.reset!(cat) - @testset "basic get_ability tests" begin + Stateful.reset!(cat) Stateful.add_response!(cat, 1, false) Stateful.add_response!(cat, 2, true) ability = Stateful.get_ability(cat) @@ -79,6 +79,7 @@ function test_stateful_cat_1d_dich_ib( if supports_rollback @testset "rollback ability tests" begin + Stateful.reset!(cat) Stateful.add_response!(cat, 1, false) ability1 = Stateful.get_ability(cat) Stateful.add_response!(cat, 2, true) @@ -91,4 +92,22 @@ function test_stateful_cat_1d_dich_ib( end end +function test_stateful_cat_item_bank_1d_dich_ib( + cat::Stateful.StatefulCat, + item_bank::AbstractItemBank, + points=[-.78, 0.0, .78], + margin=0.05, +) + if length(item_bank) != Stateful.item_bank_size(cat) + error("Item bank length does not match the cat's item bank size.") + end + for i in 1:length(item_bank) + for point in points + cat_prob = Stateful.item_response_function(cat, i, true, point) + ib_prob = resp(ItemResponse(item_bank, i), true, point) + @test cat_prob ≈ ib_prob rtol=margin + end + end +end + end \ No newline at end of file diff --git a/src/Stateful.jl b/src/Stateful.jl index b506088..eb7d535 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -7,7 +7,7 @@ module Stateful using DocStringExtensions -using FittedItemBanks: AbstractItemBank, ResponseType +using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp using ..Aggregators: TrackedResponses, Aggregators using ..CatConfig: CatLoopConfig, CatRules using ..Responses: BareResponses, Response, Responses @@ -124,6 +124,25 @@ but should attempt to interoperate with ComputerAdaptiveTesting.jl. """ function get_ability end +""" +```julia +$(FUNCTIONNAME)(config::StatefulCat) +```` + +Return number of items in the current item bank. +""" +function item_bank_size end + +""" +```julia +$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, response::ResponseT, ability::AbilityT) -> Float +```` + +Return the probability of a `response` to item at `index` for someone with +a certain `ability` according to the IRT model backing the CAT. +""" +function item_response_function end + ## Running the CAT function Sim.run_cat(cat_config::CatLoopConfig{RulesT}, ib_labels = nothing) where {RulesT <: StatefulCat} @@ -220,6 +239,16 @@ function get_ability(config::StatefulCatConfig) return (config.rules.ability_estimator(config.tracked_responses[]), nothing) end +function item_bank_size(config::StatefulCatConfig) + return length(config.tracked_responses[].item_bank) +end + +function item_response_function(config::StatefulCatConfig, index, response, ability) + item_bank = config.tracked_responses[].item_bank + item_response = ItemResponse(item_bank, index) + return resp(item_response, response, ability) +end + ## TODO: Implementation for MaterializedDecisionTree end diff --git a/test/stateful.jl b/test/stateful.jl index f8419a3..2bb84c8 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -88,5 +88,6 @@ 4; supports_ranked_and_criteria = false, ) + TestExt.test_stateful_cat_item_bank_1d_dich_ib(cat_config, item_bank) end end