diff --git a/lib/irt_ruby.rb b/lib/irt_ruby.rb index 8fb5314..6479b61 100644 --- a/lib/irt_ruby.rb +++ b/lib/irt_ruby.rb @@ -2,6 +2,7 @@ require "irt_ruby/version" require "matrix" +require "irt_ruby/model_options_validator" require "irt_ruby/response_data_validator" require "irt_ruby/rasch_model" require "irt_ruby/two_parameter_model" diff --git a/lib/irt_ruby/model_options_validator.rb b/lib/irt_ruby/model_options_validator.rb new file mode 100644 index 0000000..56973b2 --- /dev/null +++ b/lib/irt_ruby/model_options_validator.rb @@ -0,0 +1,38 @@ +# frozen_string_literal: true + +module IrtRuby + # Validates optimization hyperparameters shared by IRT model implementations. + module ModelOptionsValidator + module_function + + def validate!(max_iter:, tolerance:, param_tolerance:, learning_rate:, decay_factor:) + validate_positive_integer!(:max_iter, max_iter) + validate_positive_finite_numeric!(:tolerance, tolerance) + validate_positive_finite_numeric!(:param_tolerance, param_tolerance) + validate_positive_finite_numeric!(:learning_rate, learning_rate) + validate_decay_factor!(decay_factor) + end + + def validate_positive_integer!(name, value) + return if value.is_a?(Integer) && value.positive? + + raise ArgumentError, "#{name} must be a positive Integer" + end + + def validate_positive_finite_numeric!(name, value) + return if finite_numeric?(value) && value.positive? + + raise ArgumentError, "#{name} must be a positive finite Numeric" + end + + def validate_decay_factor!(value) + return if finite_numeric?(value) && value.positive? && value < 1 + + raise ArgumentError, "decay_factor must be a finite Numeric strictly between 0 and 1" + end + + def finite_numeric?(value) + value.is_a?(Numeric) && !value.is_a?(Complex) && value.finite? + end + end +end diff --git a/lib/irt_ruby/rasch_model.rb b/lib/irt_ruby/rasch_model.rb index 086e70e..a499092 100644 --- a/lib/irt_ruby/rasch_model.rb +++ b/lib/irt_ruby/rasch_model.rb @@ -1,5 +1,6 @@ # frozen_string_literal: true +require "irt_ruby/model_options_validator" require "irt_ruby/response_data_validator" module IrtRuby @@ -21,6 +22,12 @@ def initialize(data, # data: A Matrix or array-of-arrays of responses (0/1 or nil for missing). # missing_strategy: :ignore (skip), :treat_as_incorrect, :treat_as_correct + ModelOptionsValidator.validate!(max_iter: max_iter, + tolerance: tolerance, + param_tolerance: param_tolerance, + learning_rate: learning_rate, + decay_factor: decay_factor) + @data = data @data_array = ResponseDataValidator.validate!(data) num_rows = @data_array.size diff --git a/lib/irt_ruby/three_parameter_model.rb b/lib/irt_ruby/three_parameter_model.rb index f172de2..9b3410b 100644 --- a/lib/irt_ruby/three_parameter_model.rb +++ b/lib/irt_ruby/three_parameter_model.rb @@ -1,5 +1,6 @@ # frozen_string_literal: true +require "irt_ruby/model_options_validator" require "irt_ruby/response_data_validator" module IrtRuby @@ -20,6 +21,12 @@ def initialize(data, learning_rate: 0.01, decay_factor: 0.5, missing_strategy: :ignore) + ModelOptionsValidator.validate!(max_iter: max_iter, + tolerance: tolerance, + param_tolerance: param_tolerance, + learning_rate: learning_rate, + decay_factor: decay_factor) + @data = data @data_array = ResponseDataValidator.validate!(data) num_rows = @data_array.size diff --git a/lib/irt_ruby/two_parameter_model.rb b/lib/irt_ruby/two_parameter_model.rb index e9a5a98..04e1c6f 100644 --- a/lib/irt_ruby/two_parameter_model.rb +++ b/lib/irt_ruby/two_parameter_model.rb @@ -1,5 +1,6 @@ # frozen_string_literal: true +require "irt_ruby/model_options_validator" require "irt_ruby/response_data_validator" module IrtRuby @@ -16,6 +17,12 @@ class TwoParameterModel def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6, learning_rate: 0.01, decay_factor: 0.5, missing_strategy: :ignore) + ModelOptionsValidator.validate!(max_iter: max_iter, + tolerance: tolerance, + param_tolerance: param_tolerance, + learning_rate: learning_rate, + decay_factor: decay_factor) + @data = data @data_array = ResponseDataValidator.validate!(data) num_rows = @data_array.size diff --git a/spec/irt_ruby/rasch_model_spec.rb b/spec/irt_ruby/rasch_model_spec.rb index b4fa443..aa468e4 100644 --- a/spec/irt_ruby/rasch_model_spec.rb +++ b/spec/irt_ruby/rasch_model_spec.rb @@ -4,6 +4,7 @@ RSpec.describe IrtRuby::RaschModel do it_behaves_like "response data validation" + it_behaves_like "model optimization option validation" let(:data_array) do [ diff --git a/spec/irt_ruby/three_parameter_model_spec.rb b/spec/irt_ruby/three_parameter_model_spec.rb index e120ea6..38fcab8 100644 --- a/spec/irt_ruby/three_parameter_model_spec.rb +++ b/spec/irt_ruby/three_parameter_model_spec.rb @@ -4,6 +4,7 @@ RSpec.describe IrtRuby::ThreeParameterModel do it_behaves_like "response data validation" + it_behaves_like "model optimization option validation" let(:data_array) do [ diff --git a/spec/irt_ruby/two_parameter_model_spec.rb b/spec/irt_ruby/two_parameter_model_spec.rb index 4c9e906..11ef270 100644 --- a/spec/irt_ruby/two_parameter_model_spec.rb +++ b/spec/irt_ruby/two_parameter_model_spec.rb @@ -4,6 +4,7 @@ RSpec.describe IrtRuby::TwoParameterModel do it_behaves_like "response data validation" + it_behaves_like "model optimization option validation" let(:data_array) do [ diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index 0ddcf0c..c1ec94a 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -41,6 +41,24 @@ end end +RSpec.shared_examples "model optimization option validation" do + let(:valid_data) { [[1, 0], [0, 1]] } + + { + max_iter: [0, -1, 1.5, "100", nil], + tolerance: [0, -1e-6, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(1, 0), "1e-6", nil], + param_tolerance: [0, -1e-6, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(1, 0), "1e-6", nil], + learning_rate: [0, -0.01, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(0.01, 0), "0.01", nil], + decay_factor: [0, 1, -0.1, 1.1, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(0.5, 0), "0.5", nil] + }.each do |option, invalid_values| + invalid_values.each do |value| + it "rejects #{option}=#{value.inspect}" do + expect { described_class.new(valid_data, option => value) }.to raise_error(ArgumentError, /\A#{option} /) + end + end + end +end + RSpec.configure do |config| # Enable flags like --only-failures and --next-failure config.example_status_persistence_file_path = ".rspec_status"