# Patch summary
# -------------
# Adds an optional `seed:` keyword (default nil) to the constructors of
# IrtRuby::RaschModel, IrtRuby::TwoParameterModel and IrtRuby::ThreeParameterModel.
#
# Behavior introduced by the added (+) lines:
#   * seed: nil    -> @random is nil, and the new private helper `random_between`
#                     calls the bare `rand(range)`, i.e. the same call the removed
#                     (-) lines made directly.
#   * seed: given  -> @random = Random.new(seed); every initial parameter is drawn
#                     with @random.rand(range) instead of the bare `rand`.
#   * Initial draw ranges are unchanged: abilities and difficulties use
#     -0.25..0.25, discriminations use 0.5..1.5, guessings (3PL only) use 0.0..0.3.
#
# Files touched: README.md (new "Reproducible Initialization" section), the three
# model files under lib/irt_ruby/, the three model specs (each opts in to the new
# shared examples with its own parameter list), and spec/spec_helper.rb (new shared
# examples "seeded model initialization").
#
# NOTE(review): `random_between` is added with an identical body to all three model
# classes; a shared module may be preferable -- confirm with the maintainers.
# NOTE(review): in the visible hunks `seed` is handed straight to Random.new and the
# ModelOptionsValidator.validate! context lines are unchanged, so it looks like the
# seed is not validated -- confirm whether that is intended.
# NOTE(review): the shared example "does not reset or consume Ruby's global random
# number generator" calls srand(98_765) twice and does not restore the previous
# global seed afterwards; presumably later examples in the same process see the
# reseeded generator -- confirm.
# NOTE(review): the line breaks and indentation of this patch appear to have been
# collapsed; it presumably needs to be regenerated before it can be applied.
diff --git a/README.md b/README.md index 6558f6c..f84e8a4 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,20 @@ IrtRuby::TwoParameterModel.new( decay_factor: 0.5 ) ``` + +### Reproducible Initialization +Each model initializes parameters randomly. By default, constructors use Ruby's global random number generator, preserving the historical behavior and honoring any external `srand` calls. For reproducible model initialization without resetting or consuming global RNG state, pass `seed:`: + +```ruby +model_a = IrtRuby::ThreeParameterModel.new(data, seed: 1234) +model_b = IrtRuby::ThreeParameterModel.new(data, seed: 1234) + +# Same data, options, and seed produce identical fitted results. +model_a.fit == model_b.fit #=> true +``` + +The `seed:` keyword is available for `RaschModel`, `TwoParameterModel`, and `ThreeParameterModel`. + ### Parameter Clamping For 2PL and 3PL: diff --git a/lib/irt_ruby/rasch_model.rb b/lib/irt_ruby/rasch_model.rb index a499092..c6f44ff 100644 --- a/lib/irt_ruby/rasch_model.rb +++ b/lib/irt_ruby/rasch_model.rb @@ -18,7 +18,8 @@ def initialize(data, param_tolerance: 1e-6, learning_rate: 0.01, decay_factor: 0.5, - missing_strategy: :ignore) + missing_strategy: :ignore, + seed: nil) # data: A Matrix or array-of-arrays of responses (0/1 or nil for missing). # missing_strategy: :ignore (skip), :treat_as_incorrect, :treat_as_correct @@ -36,10 +37,11 @@ def initialize(data, raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy) @missing_strategy = missing_strategy + @random = seed.nil? ? 
nil : Random.new(seed) # Initialize parameters near zero - @abilities = Array.new(num_rows) { rand(-0.25..0.25) } - @difficulties = Array.new(num_cols) { rand(-0.25..0.25) } + @abilities = Array.new(num_rows) { random_between(-0.25..0.25) } + @difficulties = Array.new(num_cols) { random_between(-0.25..0.25) } @max_iter = max_iter @tolerance = tolerance @@ -52,6 +54,11 @@ def sigmoid(x) 1.0 / (1.0 + Math.exp(-x)) end + def random_between(range) + @random ? @random.rand(range) : rand(range) + end + private :random_between + def resolve_missing(resp) return [resp, false] unless resp.nil? diff --git a/lib/irt_ruby/three_parameter_model.rb b/lib/irt_ruby/three_parameter_model.rb index 9b3410b..28a58c3 100644 --- a/lib/irt_ruby/three_parameter_model.rb +++ b/lib/irt_ruby/three_parameter_model.rb @@ -20,7 +20,8 @@ def initialize(data, param_tolerance: 1e-6, learning_rate: 0.01, decay_factor: 0.5, - missing_strategy: :ignore) + missing_strategy: :ignore, + seed: nil) ModelOptionsValidator.validate!(max_iter: max_iter, tolerance: tolerance, param_tolerance: param_tolerance, @@ -35,12 +36,13 @@ def initialize(data, raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy) @missing_strategy = missing_strategy + @random = seed.nil? ? 
nil : Random.new(seed) # Initialize parameters - @abilities = Array.new(num_rows) { rand(-0.25..0.25) } - @difficulties = Array.new(num_cols) { rand(-0.25..0.25) } - @discriminations = Array.new(num_cols) { rand(0.5..1.5) } - @guessings = Array.new(num_cols) { rand(0.0..0.3) } + @abilities = Array.new(num_rows) { random_between(-0.25..0.25) } + @difficulties = Array.new(num_cols) { random_between(-0.25..0.25) } + @discriminations = Array.new(num_cols) { random_between(0.5..1.5) } + @guessings = Array.new(num_cols) { random_between(0.0..0.3) } @max_iter = max_iter @tolerance = tolerance @@ -53,6 +55,11 @@ def sigmoid(x) 1.0 / (1.0 + Math.exp(-x)) end + def random_between(range) + @random ? @random.rand(range) : rand(range) + end + private :random_between + # Probability for the 3PL model: c + (1-c)*sigmoid(a*(θ - b)) def probability(theta, a, b, c) c + ((1.0 - c) * sigmoid(a * (theta - b))) diff --git a/lib/irt_ruby/two_parameter_model.rb b/lib/irt_ruby/two_parameter_model.rb index 04e1c6f..1491fb5 100644 --- a/lib/irt_ruby/two_parameter_model.rb +++ b/lib/irt_ruby/two_parameter_model.rb @@ -16,7 +16,7 @@ class TwoParameterModel def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6, learning_rate: 0.01, decay_factor: 0.5, - missing_strategy: :ignore) + missing_strategy: :ignore, seed: nil) ModelOptionsValidator.validate!(max_iter: max_iter, tolerance: tolerance, param_tolerance: param_tolerance, @@ -31,12 +31,13 @@ def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6, raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy) @missing_strategy = missing_strategy + @random = seed.nil? ? 
nil : Random.new(seed) # Initialize parameters # Typically: ability ~ 0, difficulty ~ 0, discrimination ~ 1 - @abilities = Array.new(num_rows) { rand(-0.25..0.25) } - @difficulties = Array.new(num_cols) { rand(-0.25..0.25) } - @discriminations = Array.new(num_cols) { rand(0.5..1.5) } + @abilities = Array.new(num_rows) { random_between(-0.25..0.25) } + @difficulties = Array.new(num_cols) { random_between(-0.25..0.25) } + @discriminations = Array.new(num_cols) { random_between(0.5..1.5) } @max_iter = max_iter @tolerance = tolerance @@ -49,6 +50,11 @@ def sigmoid(x) 1.0 / (1.0 + Math.exp(-x)) end + def random_between(range) + @random ? @random.rand(range) : rand(range) + end + private :random_between + def resolve_missing(resp) return [resp, false] unless resp.nil? diff --git a/spec/irt_ruby/rasch_model_spec.rb b/spec/irt_ruby/rasch_model_spec.rb index aa468e4..be9ddbd 100644 --- a/spec/irt_ruby/rasch_model_spec.rb +++ b/spec/irt_ruby/rasch_model_spec.rb @@ -5,6 +5,7 @@ RSpec.describe IrtRuby::RaschModel do it_behaves_like "response data validation" it_behaves_like "model optimization option validation" + it_behaves_like "seeded model initialization", %i[abilities difficulties] let(:data_array) do [ diff --git a/spec/irt_ruby/three_parameter_model_spec.rb b/spec/irt_ruby/three_parameter_model_spec.rb index 38fcab8..786a899 100644 --- a/spec/irt_ruby/three_parameter_model_spec.rb +++ b/spec/irt_ruby/three_parameter_model_spec.rb @@ -5,6 +5,7 @@ RSpec.describe IrtRuby::ThreeParameterModel do it_behaves_like "response data validation" it_behaves_like "model optimization option validation" + it_behaves_like "seeded model initialization", %i[abilities difficulties discriminations guessings] let(:data_array) do [ diff --git a/spec/irt_ruby/two_parameter_model_spec.rb b/spec/irt_ruby/two_parameter_model_spec.rb index 11ef270..9fb598e 100644 --- a/spec/irt_ruby/two_parameter_model_spec.rb +++ b/spec/irt_ruby/two_parameter_model_spec.rb @@ -5,6 +5,7 @@ RSpec.describe 
IrtRuby::TwoParameterModel do it_behaves_like "response data validation" it_behaves_like "model optimization option validation" + it_behaves_like "seeded model initialization", %i[abilities difficulties discriminations] let(:data_array) do [ diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index c1ec94a..444abab 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -59,6 +59,46 @@ end end +RSpec.shared_examples "seeded model initialization" do |parameter_names| + let(:seeded_fit_options) { { max_iter: 50, learning_rate: 0.05 } } + + def seeded_parameter_snapshot(model, parameter_names) + parameter_names.to_h do |parameter_name| + [parameter_name, model.instance_variable_get("@#{parameter_name}").dup] + end + end + + it "produces identical initial and fitted parameters with the same seed" do + model1 = described_class.new(data_array, **seeded_fit_options, seed: 12_345) + model2 = described_class.new(data_array, **seeded_fit_options, seed: 12_345) + + expect(seeded_parameter_snapshot(model1, parameter_names)).to eq( + seeded_parameter_snapshot(model2, parameter_names) + ) + expect(model1.fit).to eq(model2.fit) + end + + it "produces different initial parameters with different seeds" do + model1 = described_class.new(data_array, **seeded_fit_options, seed: 12_345) + model2 = described_class.new(data_array, **seeded_fit_options, seed: 54_321) + + expect(seeded_parameter_snapshot(model1, parameter_names)).not_to eq( + seeded_parameter_snapshot(model2, parameter_names) + ) + end + + it "does not reset or consume Ruby's global random number generator" do + srand(98_765) + expected_values = Array.new(5) { rand } + + srand(98_765) + described_class.new(data_array, **seeded_fit_options, seed: 12_345) + actual_values = Array.new(5) { rand } + + expect(actual_values).to eq(expected_values) + end +end + RSpec.configure do |config| # Enable flags like --only-failures and --next-failure config.example_status_persistence_file_path = ".rspec_status"