Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/irt_ruby.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

require "irt_ruby/version"
require "matrix"
require "irt_ruby/model_options_validator"
require "irt_ruby/response_data_validator"
require "irt_ruby/rasch_model"
require "irt_ruby/two_parameter_model"
Expand Down
38 changes: 38 additions & 0 deletions lib/irt_ruby/model_options_validator.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# frozen_string_literal: true

module IrtRuby
# Validates optimization hyperparameters shared by IRT model implementations.
module ModelOptionsValidator
module_function

def validate!(max_iter:, tolerance:, param_tolerance:, learning_rate:, decay_factor:)
validate_positive_integer!(:max_iter, max_iter)
validate_positive_finite_numeric!(:tolerance, tolerance)
validate_positive_finite_numeric!(:param_tolerance, param_tolerance)
validate_positive_finite_numeric!(:learning_rate, learning_rate)
validate_decay_factor!(decay_factor)
end

def validate_positive_integer!(name, value)
return if value.is_a?(Integer) && value.positive?

raise ArgumentError, "#{name} must be a positive Integer"
end

def validate_positive_finite_numeric!(name, value)
return if finite_numeric?(value) && value.positive?

raise ArgumentError, "#{name} must be a positive finite Numeric"
end

def validate_decay_factor!(value)
return if finite_numeric?(value) && value.positive? && value < 1

raise ArgumentError, "decay_factor must be a finite Numeric strictly between 0 and 1"
end

def finite_numeric?(value)
value.is_a?(Numeric) && !value.is_a?(Complex) && value.finite?
end
end
end
7 changes: 7 additions & 0 deletions lib/irt_ruby/rasch_model.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# frozen_string_literal: true

require "irt_ruby/model_options_validator"
require "irt_ruby/response_data_validator"

module IrtRuby
Expand All @@ -21,6 +22,12 @@ def initialize(data,
# data: A Matrix or array-of-arrays of responses (0/1 or nil for missing).
# missing_strategy: :ignore (skip), :treat_as_incorrect, :treat_as_correct

ModelOptionsValidator.validate!(max_iter: max_iter,
tolerance: tolerance,
param_tolerance: param_tolerance,
learning_rate: learning_rate,
decay_factor: decay_factor)

@data = data
@data_array = ResponseDataValidator.validate!(data)
num_rows = @data_array.size
Expand Down
7 changes: 7 additions & 0 deletions lib/irt_ruby/three_parameter_model.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# frozen_string_literal: true

require "irt_ruby/model_options_validator"
require "irt_ruby/response_data_validator"

module IrtRuby
Expand All @@ -20,6 +21,12 @@ def initialize(data,
learning_rate: 0.01,
decay_factor: 0.5,
missing_strategy: :ignore)
ModelOptionsValidator.validate!(max_iter: max_iter,
tolerance: tolerance,
param_tolerance: param_tolerance,
learning_rate: learning_rate,
decay_factor: decay_factor)

@data = data
@data_array = ResponseDataValidator.validate!(data)
num_rows = @data_array.size
Expand Down
7 changes: 7 additions & 0 deletions lib/irt_ruby/two_parameter_model.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# frozen_string_literal: true

require "irt_ruby/model_options_validator"
require "irt_ruby/response_data_validator"

module IrtRuby
Expand All @@ -16,6 +17,12 @@ class TwoParameterModel
def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6,
learning_rate: 0.01, decay_factor: 0.5,
missing_strategy: :ignore)
ModelOptionsValidator.validate!(max_iter: max_iter,
tolerance: tolerance,
param_tolerance: param_tolerance,
learning_rate: learning_rate,
decay_factor: decay_factor)

@data = data
@data_array = ResponseDataValidator.validate!(data)
num_rows = @data_array.size
Expand Down
1 change: 1 addition & 0 deletions spec/irt_ruby/rasch_model_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

RSpec.describe IrtRuby::RaschModel do
it_behaves_like "response data validation"
it_behaves_like "model optimization option validation"

let(:data_array) do
[
Expand Down
1 change: 1 addition & 0 deletions spec/irt_ruby/three_parameter_model_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

RSpec.describe IrtRuby::ThreeParameterModel do
it_behaves_like "response data validation"
it_behaves_like "model optimization option validation"

let(:data_array) do
[
Expand Down
1 change: 1 addition & 0 deletions spec/irt_ruby/two_parameter_model_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

RSpec.describe IrtRuby::TwoParameterModel do
it_behaves_like "response data validation"
it_behaves_like "model optimization option validation"

let(:data_array) do
[
Expand Down
18 changes: 18 additions & 0 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@
end
end

RSpec.shared_examples "model optimization option validation" do
let(:valid_data) { [[1, 0], [0, 1]] }

{
max_iter: [0, -1, 1.5, "100", nil],
tolerance: [0, -1e-6, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(1, 0), "1e-6", nil],
param_tolerance: [0, -1e-6, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(1, 0), "1e-6", nil],
learning_rate: [0, -0.01, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(0.01, 0), "0.01", nil],
decay_factor: [0, 1, -0.1, 1.1, Float::INFINITY, -Float::INFINITY, Float::NAN, Complex(0.5, 0), "0.5", nil]
}.each do |option, invalid_values|
invalid_values.each do |value|
it "rejects #{option}=#{value.inspect}" do
expect { described_class.new(valid_data, option => value) }.to raise_error(ArgumentError, /\A#{option} /)
end
end
end
end

RSpec.configure do |config|
# Enable flags like --only-failures and --next-failure
config.example_status_persistence_file_path = ".rspec_status"
Expand Down
Loading