diff --git a/active_utils.gemspec b/active_utils.gemspec index e3b59c0..94779d2 100644 --- a/active_utils.gemspec +++ b/active_utils.gemspec @@ -18,6 +18,7 @@ Gem::Specification.new do |s| s.add_dependency('activesupport', '>= 4.2') s.add_dependency('i18n') + s.add_dependency('net-http-persistent', '~> 4.0') s.add_development_dependency('rake') s.add_development_dependency('minitest') diff --git a/lib/active_utils/posts_data.rb b/lib/active_utils/posts_data.rb index 48d8200..f2bc235 100644 --- a/lib/active_utils/posts_data.rb +++ b/lib/active_utils/posts_data.rb @@ -27,6 +27,37 @@ def self.included(base) base.proxy_address = Connection::PROXY_ADDRESS base.class_attribute :proxy_port + + base.class_attribute :persistent_connections + base.persistent_connections = false + + base.class_attribute :pool_size + base.pool_size = 100 + + base.class_attribute :pool_idle_timeout + base.pool_idle_timeout = 60 + + base.class_attribute :pool_keep_alive + base.pool_keep_alive = 60 + + base.class_attribute :pool_max_requests + base.pool_max_requests = 100 + + base.define_singleton_method(:connection_pool) do + @connection_pool ||= begin + require 'net/http/persistent' + pool = Net::HTTP::Persistent.new(name: name, pool_size: pool_size) + pool.idle_timeout = pool_idle_timeout + pool.keep_alive = pool_keep_alive + pool.max_requests = pool_max_requests + pool + end + end + + base.define_singleton_method(:clear_connection_pool!) do + @connection_pool&.shutdown + @connection_pool = nil + end end def ssl_get(endpoint, headers={}) @@ -45,26 +76,30 @@ def raw_ssl_request(method, endpoint, data, headers = {}) logger.warn "#{self.class} using ssl_strict=false, which is insecure" if logger unless ssl_strict logger.warn "#{self.class} posting to plaintext endpoint, which is insecure" if logger unless endpoint.to_s =~ /^https:/ - connection = new_connection(endpoint) - connection.open_timeout = open_timeout - connection.read_timeout = read_timeout - connection.retry_safe = retry_safe - connection.verify_peer = ssl_strict - connection.ssl_version = ssl_version - connection.logger = logger - connection.max_retries = max_retries - connection.tag = self.class.name - connection.wiredump_device = wiredump_device - - connection.pem = @options[:pem] if @options - connection.pem_password = @options[:pem_password] if @options - - connection.ignore_http_status = @options[:ignore_http_status] if @options - - connection.proxy_address = proxy_address - connection.proxy_port = proxy_port - - connection.request(method, data, headers) + if persistent_connections + persistent_ssl_request(method, endpoint, data, headers) + else + connection = new_connection(endpoint) + connection.open_timeout = open_timeout + connection.read_timeout = read_timeout + connection.retry_safe = retry_safe + connection.verify_peer = ssl_strict + connection.ssl_version = ssl_version + connection.logger = logger + connection.max_retries = max_retries + connection.tag = self.class.name + connection.wiredump_device = wiredump_device + + connection.pem = @options[:pem] if @options + connection.pem_password = @options[:pem_password] if @options + + connection.ignore_http_status = @options[:ignore_http_status] if @options + + connection.proxy_address = proxy_address + connection.proxy_port = proxy_port + + connection.request(method, data, headers) + end end private @@ -73,6 +108,43 @@ def new_connection(endpoint) Connection.new(endpoint) end + def persistent_ssl_request(method, endpoint, data, headers) + pool = self.class.connection_pool + uri = endpoint.is_a?(URI) ? endpoint : URI.parse(endpoint) + + pool.open_timeout = open_timeout + pool.read_timeout = read_timeout + + req = case method + when :get + raise ArgumentError, "GET requests do not support a request body" if data + Net::HTTP::Get.new(uri.request_uri, headers) + when :post + Net::HTTP::Post.new(uri.request_uri, Connection::RUBY_184_POST_HEADERS.merge(headers)).tap { |r| r.body = data } + when :put + Net::HTTP::Put.new(uri.request_uri, headers).tap { |r| r.body = data } + when :patch + Net::HTTP::Patch.new(uri.request_uri, headers).tap { |r| r.body = data } + when :delete + raise ArgumentError, "DELETE requests do not support a request body" if data + Net::HTTP::Delete.new(uri.request_uri, headers) + else + raise ArgumentError, "Unsupported request method #{method.to_s.upcase}" + end + + pool.request(uri, req) + rescue *NetworkConnectionRetries::DEFAULT_CONNECTION_ERRORS.keys => e + raise ActiveUtils::ConnectionError, NetworkConnectionRetries::DEFAULT_CONNECTION_ERRORS.fetch( + (NetworkConnectionRetries::DEFAULT_CONNECTION_ERRORS.keys & e.class.ancestors).first, + e.message + ) + rescue *NetworkConnectionRetries::DEFAULT_RETRY_ERRORS.keys => e + raise ActiveUtils::ConnectionError, NetworkConnectionRetries::DEFAULT_RETRY_ERRORS.fetch( + (NetworkConnectionRetries::DEFAULT_RETRY_ERRORS.keys & e.class.ancestors).first, + e.message + ) + end + def handle_response(response) case response.code.to_i when 200...300 diff --git a/lib/active_utils/version.rb b/lib/active_utils/version.rb index 13a4596..d78691b 100644 --- a/lib/active_utils/version.rb +++ b/lib/active_utils/version.rb @@ -1,3 +1,3 @@ module ActiveUtils - VERSION = "3.6.0" + VERSION = "3.7.0" end diff --git a/test/unit/posts_data_test.rb b/test/unit/posts_data_test.rb index ca19d64..b4c1b06 100644 --- a/test/unit/posts_data_test.rb +++ b/test/unit/posts_data_test.rb @@ -85,4 +85,202 @@ def test_respecting_environment_proxy_settings @poster.ssl_post('http://example.com', '') end end + + # --- Persistent connections tests --- + + def test_persistent_connections_default_off + assert_equal false, SSLPoster.persistent_connections + end + + def test_persistent_connections_uses_connection_when_off + SSLPoster.persistent_connections = false + Connection.any_instance.expects(:request).returns(stub(code: "200", body: "ok")) + + result = @poster.ssl_post("https://shopify.com", "data") + assert_equal "ok", result + ensure + SSLPoster.persistent_connections = false + end + + class PersistentPoster + include PostsData + + self.persistent_connections = true + self.pool_size = 5 + + attr_accessor :logger + end + + def teardown + # Reset pool state without calling shutdown (which would fail on mocks) + PersistentPoster.instance_variable_set(:@connection_pool, nil) + end + + def test_persistent_connections_enabled_uses_pool + pool = mock('pool') + pool.expects(:open_timeout=).with(2) + pool.expects(:read_timeout=).with(10) + pool.expects(:request).with( + instance_of(URI::HTTPS), + instance_of(Net::HTTP::Post) + ).returns(stub(code: "200", body: "pooled response")) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + result = poster.ssl_post("https://example.com", "data") + assert_equal "pooled response", result + end + + def test_persistent_connection_pool_is_per_class + pool_a = PersistentPoster.connection_pool + pool_b = PersistentPoster.connection_pool + assert_same pool_a, pool_b, "Same class should return the same pool instance" + ensure + PersistentPoster.connection_pool.shutdown + PersistentPoster.instance_variable_set(:@connection_pool, nil) + end + + def test_persistent_connection_per_request_timeout_override + pool = mock('pool') + pool.expects(:open_timeout=).with(5) + pool.expects(:read_timeout=).with(3) + pool.expects(:request).returns(stub(code: "200", body: "ok")) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + PersistentPoster.open_timeout = 5 + PersistentPoster.read_timeout = 3 + + poster = PersistentPoster.new + poster.ssl_post("https://example.com", "data") + ensure + PersistentPoster.open_timeout = 2 + PersistentPoster.read_timeout = 10 + end + + def test_persistent_connection_raises_connection_error_on_timeout + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).raises(Net::OpenTimeout, "execution expired") + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + error = assert_raises(ActiveUtils::ConnectionError) do + poster.ssl_post("https://example.com", "data") + end + assert_match(/timed out/, error.message) + end + + def test_persistent_connection_raises_connection_error_on_reset + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).raises(Errno::ECONNRESET) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + error = assert_raises(ActiveUtils::ConnectionError) do + poster.ssl_post("https://example.com", "data") + end + assert_match(/reset/, error.message) + end + + def test_persistent_connection_raises_connection_error_on_refused + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).raises(Errno::ECONNREFUSED) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + assert_raises(ActiveUtils::ConnectionError) do + poster.ssl_post("https://example.com", "data") + end + end + + def test_persistent_connection_raises_response_error_on_non_2xx + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).returns(stub(code: "422", body: "bad", message: "Unprocessable Entity")) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + error = assert_raises(ActiveUtils::ResponseError) do + poster.ssl_post("https://example.com", "data") + end + assert_equal "422", error.response.code + end + + def test_persistent_connection_ssl_get + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).with( + instance_of(URI::HTTPS), + instance_of(Net::HTTP::Get) + ).returns(stub(code: "200", body: "get response")) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + result = poster.ssl_get("https://example.com/path") + assert_equal "get response", result + end + + def test_persistent_connection_clear_pool + pool = PersistentPoster.connection_pool + refute_nil pool + PersistentPoster.clear_connection_pool! + assert_nil PersistentPoster.instance_variable_get(:@connection_pool) + end + + def test_persistent_connection_post_includes_content_type_header + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).with( + instance_of(URI::HTTPS), + instance_of(Net::HTTP::Post) + ) do |_uri, req| + assert_equal "application/x-www-form-urlencoded", req["Content-Type"] + true + end.returns(stub(code: "200", body: "ok")) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + poster.ssl_post("https://example.com", "data") + end + + def test_persistent_connection_merges_custom_headers + pool = mock('pool') + pool.expects(:open_timeout=) + pool.expects(:read_timeout=) + pool.expects(:request).with( + instance_of(URI::HTTPS), + instance_of(Net::HTTP::Post) + ) do |_uri, req| + assert_equal "application/json", req["Content-Type"] + assert_equal "abc123", req["X-Custom"] + true + end.returns(stub(code: "200", body: "ok")) + + PersistentPoster.instance_variable_set(:@connection_pool, pool) + + poster = PersistentPoster.new + poster.ssl_post("https://example.com", "data", { "Content-Type" => "application/json", "X-Custom" => "abc123" }) + end + + def test_pool_config_defaults + assert_equal 100, SSLPoster.pool_size + assert_equal 60, SSLPoster.pool_idle_timeout + assert_equal 60, SSLPoster.pool_keep_alive + assert_equal 100, SSLPoster.pool_max_requests + end end