diff --git a/wrappercommon/src/main/java/com/genexus/cors/CORSHelper.java b/wrappercommon/src/main/java/com/genexus/cors/CORSHelper.java index a24f00b16..02b9a6c15 100644 --- a/wrappercommon/src/main/java/com/genexus/cors/CORSHelper.java +++ b/wrappercommon/src/main/java/com/genexus/cors/CORSHelper.java @@ -3,72 +3,130 @@ import com.genexus.common.interfaces.SpecificImplementation; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; public class CORSHelper { - public static String REQUEST_METHOD_HEADER_NAME = "Access-Control-Request-Method"; - public static String REQUEST_HEADERS_HEADER_NAME = "Access-Control-Request-Headers"; + public static final String REQUEST_METHOD_HEADER_NAME = "Access-Control-Request-Method"; + public static final String REQUEST_HEADERS_HEADER_NAME = "Access-Control-Request-Headers"; + public static final String ORIGIN_HEADER_NAME = "Origin"; - private static String CORS_ALLOWED_ORIGIN = "CORS_ALLOW_ORIGIN"; - private static String CORS_MAX_AGE_SECONDS = "86400"; - private static String PREFLIGHT_REQUEST = "OPTIONS"; + private static final String CORS_ALLOWED_ORIGIN_PROPERTY = "CORS_ALLOW_ORIGIN"; + private static final String CORS_MAX_AGE_SECONDS = "86400"; + private static final String PREFLIGHT_REQUEST = "OPTIONS"; + private static final String WILDCARD = "*"; + + // Test seam: tests can replace this to avoid wiring SpecificImplementation. + static Supplier allowedOriginSupplier = CORSHelper::readAllowedOriginFromConfig; public static boolean corsSupportEnabled() { - return getAllowedOrigin() != null; + return getConfiguredAllowedOrigin() != null; } + /** Build CORS headers from a multi-valued header map (JAX-RS style). */ public static HashMap getCORSHeaders(String httpMethod, Map> headers) { - if (getAllowedOrigin() == null) { - return null; - } + return corsHeaders(httpMethod, + getHeaderValue(ORIGIN_HEADER_NAME, headers), + getHeaderValue(REQUEST_METHOD_HEADER_NAME, headers), + getHeaderValue(REQUEST_HEADERS_HEADER_NAME, headers)); + } - String requestedMethod = getHeaderValue(REQUEST_METHOD_HEADER_NAME, headers); - String requestedHeaders = getHeaderValue(REQUEST_HEADERS_HEADER_NAME, headers); + /** Build CORS headers from individual header values (Servlet style). */ + public static HashMap getCORSHeaders(String httpMethod, String origin, String requestedMethod, String requestedHeaders) { + return corsHeaders(httpMethod, origin, requestedMethod, requestedHeaders); + } - return corsHeaders(httpMethod, requestedMethod, requestedHeaders); + /** True iff this request looks like a CORS preflight (OPTIONS + Origin + Access-Control-Request-Method). */ + public static boolean isPreflight(String httpMethod, String origin, String requestedMethod) { + return httpMethod != null + && PREFLIGHT_REQUEST.equalsIgnoreCase(httpMethod) + && origin != null && !origin.isEmpty() + && requestedMethod != null && !requestedMethod.isEmpty(); } - public static HashMap getCORSHeaders(String httpMethod, String requestedMethod, String requestedHeaders) { - return corsHeaders(httpMethod, requestedMethod, requestedHeaders); + private static String getConfiguredAllowedOrigin() { + String value = allowedOriginSupplier.get(); + return (value == null || value.isEmpty()) ? null : value; } - private static String getAllowedOrigin() { - String corsAllowedOrigin = SpecificImplementation.Application.getClientPreferences().getProperty(CORS_ALLOWED_ORIGIN, ""); - if (corsAllowedOrigin == null || corsAllowedOrigin.isEmpty()) { + private static String readAllowedOriginFromConfig() { + if (SpecificImplementation.Application == null) { return null; } - return corsAllowedOrigin; + return SpecificImplementation.Application.getClientPreferences().getProperty(CORS_ALLOWED_ORIGIN_PROPERTY, ""); } - private static HashMap corsHeaders(String httpMethodName, String requestedMethod, String requestedHeaders) { - String corsAllowedOrigin = getAllowedOrigin(); - if (corsAllowedOrigin == null) { + /** + * Resolve the value to send in Access-Control-Allow-Origin, or null when the + * request origin is not in the configured allowlist (no CORS headers should be emitted). + * + * Configuration accepts: + * "*" -> allow any origin (without credentials, per spec) + * "https://a.example" -> single origin + * "https://a.example,https://b.test" -> allowlist + */ + private static String resolveAllowedOrigin(String configuredOrigin, String requestOrigin) { + if (requestOrigin == null || requestOrigin.isEmpty()) { return null; } + if (WILDCARD.equals(configuredOrigin.trim())) { + return WILDCARD; + } + for (String allowed : configuredOrigin.split(",")) { + String candidate = allowed.trim(); + if (!candidate.isEmpty() && candidate.equals(requestOrigin)) { + return candidate; + } + } + return null; + } - boolean isPreflightRequest = httpMethodName.equalsIgnoreCase(PREFLIGHT_REQUEST); + private static HashMap corsHeaders(String httpMethodName, String origin, String requestedMethod, String requestedHeaders) { + String configuredOrigin = getConfiguredAllowedOrigin(); + if (configuredOrigin == null) return null; - HashMap corsHeaders = new HashMap<>(); - corsHeaders.put("Access-Control-Allow-Origin", corsAllowedOrigin); - corsHeaders.put("Access-Control-Allow-Credentials", "true"); - corsHeaders.put("Access-Control-Max-Age", CORS_MAX_AGE_SECONDS); + String allowOriginValue = resolveAllowedOrigin(configuredOrigin, origin); + if (allowOriginValue == null) return null; - if (isPreflightRequest && requestedHeaders != null && !requestedHeaders.isEmpty()) { - corsHeaders.put("Access-Control-Allow-Headers", requestedHeaders); + boolean isWildcard = WILDCARD.equals(allowOriginValue); + boolean isPreflight = httpMethodName != null && PREFLIGHT_REQUEST.equalsIgnoreCase(httpMethodName); + + HashMap corsHeaders = new LinkedHashMap<>(); + corsHeaders.put("Access-Control-Allow-Origin", allowOriginValue); + if (!isWildcard) { + // Vary lets caches differentiate responses per Origin. + corsHeaders.put("Vary", "Origin"); + // "*" + credentials is forbidden by the CORS spec, so credentials only when echoing a real origin. + corsHeaders.put("Access-Control-Allow-Credentials", "true"); } - if (isPreflightRequest && requestedMethod != null && !requestedMethod.isEmpty()) { - corsHeaders.put("Access-Control-Allow-Methods", requestedMethod); + + if (isPreflight) { + corsHeaders.put("Access-Control-Max-Age", CORS_MAX_AGE_SECONDS); + if (requestedMethod != null && !requestedMethod.isEmpty()) { + corsHeaders.put("Access-Control-Allow-Methods", requestedMethod); + } + if (requestedHeaders != null && !requestedHeaders.isEmpty()) { + corsHeaders.put("Access-Control-Allow-Headers", requestedHeaders); + } } return corsHeaders; } private static String getHeaderValue(String headerName, Map> headers) { + if (headers == null) return null; List value = headers.get(headerName); - if (value != null && value.size() > 0) { - return value.get(0); + if (value == null) { + for (Map.Entry> e : headers.entrySet()) { + if (e.getKey() != null && headerName.equalsIgnoreCase(e.getKey())) { + value = e.getValue(); + break; + } + } } + if (value != null && !value.isEmpty()) return value.get(0); return null; } } diff --git a/wrappercommon/src/test/java/com/genexus/cors/CORSHelperTest.java b/wrappercommon/src/test/java/com/genexus/cors/CORSHelperTest.java new file mode 100644 index 000000000..920ccd5b0 --- /dev/null +++ b/wrappercommon/src/test/java/com/genexus/cors/CORSHelperTest.java @@ -0,0 +1,189 @@ +package com.genexus.cors; + +import org.junit.After; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class CORSHelperTest { + + private final Supplier originalSupplier = CORSHelper.allowedOriginSupplier; + + @After + public void restoreSupplier() { + CORSHelper.allowedOriginSupplier = originalSupplier; + } + + private void configureAllowedOrigin(final String value) { + CORSHelper.allowedOriginSupplier = new Supplier() { + @Override public String get() { return value; } + }; + } + + @Test + public void corsSupportDisabledWhenNotConfigured() { + configureAllowedOrigin(""); + assertFalse(CORSHelper.corsSupportEnabled()); + assertNull(CORSHelper.getCORSHeaders("GET", "https://app.example", null, null)); + } + + @Test + public void corsSupportDisabledWhenSupplierReturnsNull() { + configureAllowedOrigin(null); + assertFalse(CORSHelper.corsSupportEnabled()); + } + + @Test + public void corsSupportEnabledWhenConfigured() { + configureAllowedOrigin("https://app.example"); + assertTrue(CORSHelper.corsSupportEnabled()); + } + + @Test + public void noHeadersWhenRequestHasNoOrigin() { + configureAllowedOrigin("https://app.example"); + assertNull(CORSHelper.getCORSHeaders("GET", null, null, null)); + assertNull(CORSHelper.getCORSHeaders("GET", "", null, null)); + } + + @Test + public void noHeadersWhenOriginNotInAllowlist() { + configureAllowedOrigin("https://app.example"); + assertNull(CORSHelper.getCORSHeaders("GET", "https://evil.example", null, null)); + } + + @Test + public void singleAllowedOriginSimpleRequest() { + configureAllowedOrigin("https://app.example"); + HashMap headers = CORSHelper.getCORSHeaders("GET", "https://app.example", null, null); + + assertNotNull(headers); + assertEquals("https://app.example", headers.get("Access-Control-Allow-Origin")); + assertEquals("Origin", headers.get("Vary")); + assertEquals("true", headers.get("Access-Control-Allow-Credentials")); + assertFalse("Max-Age belongs only on preflight responses", headers.containsKey("Access-Control-Max-Age")); + assertFalse(headers.containsKey("Access-Control-Allow-Methods")); + assertFalse(headers.containsKey("Access-Control-Allow-Headers")); + } + + @Test + public void preflightIncludesMaxAgeAndRequestedMethodAndHeaders() { + configureAllowedOrigin("https://app.example"); + HashMap headers = CORSHelper.getCORSHeaders( + "OPTIONS", "https://app.example", "PUT", "Content-Type, X-Custom"); + + assertNotNull(headers); + assertEquals("https://app.example", headers.get("Access-Control-Allow-Origin")); + assertEquals("Origin", headers.get("Vary")); + assertEquals("true", headers.get("Access-Control-Allow-Credentials")); + assertEquals("86400", headers.get("Access-Control-Max-Age")); + assertEquals("PUT", headers.get("Access-Control-Allow-Methods")); + assertEquals("Content-Type, X-Custom", headers.get("Access-Control-Allow-Headers")); + } + + @Test + public void wildcardOriginNeverCombinesWithCredentials() { + configureAllowedOrigin("*"); + HashMap headers = CORSHelper.getCORSHeaders("GET", "https://anything.example", null, null); + + assertNotNull(headers); + assertEquals("*", headers.get("Access-Control-Allow-Origin")); + assertFalse("'*' must not be sent with credentials per the CORS spec", + headers.containsKey("Access-Control-Allow-Credentials")); + assertFalse("Vary: Origin is unnecessary when emitting '*'", + headers.containsKey("Vary")); + } + + @Test + public void wildcardOriginPreflightIncludesMaxAge() { + configureAllowedOrigin("*"); + HashMap headers = CORSHelper.getCORSHeaders( + "OPTIONS", "https://anything.example", "POST", "Content-Type"); + + assertNotNull(headers); + assertEquals("*", headers.get("Access-Control-Allow-Origin")); + assertEquals("86400", headers.get("Access-Control-Max-Age")); + assertEquals("POST", headers.get("Access-Control-Allow-Methods")); + } + + @Test + public void allowlistMatchesOneOfMultiple() { + configureAllowedOrigin("https://a.example, https://b.example ,https://c.example"); + + HashMap b = CORSHelper.getCORSHeaders("GET", "https://b.example", null, null); + assertNotNull(b); + assertEquals("https://b.example", b.get("Access-Control-Allow-Origin")); + assertEquals("Origin", b.get("Vary")); + assertEquals("true", b.get("Access-Control-Allow-Credentials")); + + assertNull(CORSHelper.getCORSHeaders("GET", "https://d.example", null, null)); + } + + @Test + public void mapOverloadReadsOriginAndIsCaseInsensitive() { + configureAllowedOrigin("https://app.example"); + Map> requestHeaders = new LinkedHashMap<>(); + requestHeaders.put("origin", Collections.singletonList("https://app.example")); + requestHeaders.put("access-control-request-method", Collections.singletonList("DELETE")); + requestHeaders.put("access-control-request-headers", Arrays.asList("X-A, X-B")); + + HashMap headers = CORSHelper.getCORSHeaders("OPTIONS", requestHeaders); + assertNotNull(headers); + assertEquals("https://app.example", headers.get("Access-Control-Allow-Origin")); + assertEquals("DELETE", headers.get("Access-Control-Allow-Methods")); + assertEquals("X-A, X-B", headers.get("Access-Control-Allow-Headers")); + } + + @Test + public void mapOverloadReturnsNullWithoutOrigin() { + configureAllowedOrigin("https://app.example"); + Map> requestHeaders = new LinkedHashMap<>(); + requestHeaders.put("Access-Control-Request-Method", Collections.singletonList("POST")); + + assertNull(CORSHelper.getCORSHeaders("OPTIONS", requestHeaders)); + } + + @Test + public void isPreflightSemantics() { + assertTrue(CORSHelper.isPreflight("OPTIONS", "https://x", "GET")); + assertTrue(CORSHelper.isPreflight("options", "https://x", "GET")); + assertFalse(CORSHelper.isPreflight("GET", "https://x", "GET")); + assertFalse(CORSHelper.isPreflight("OPTIONS", null, "GET")); + assertFalse(CORSHelper.isPreflight("OPTIONS", "", "GET")); + assertFalse(CORSHelper.isPreflight("OPTIONS", "https://x", null)); + assertFalse(CORSHelper.isPreflight("OPTIONS", "https://x", "")); + assertFalse(CORSHelper.isPreflight(null, "https://x", "GET")); + } + + @Test + public void nullHttpMethodDoesNotThrow() { + configureAllowedOrigin("https://app.example"); + HashMap headers = CORSHelper.getCORSHeaders(null, "https://app.example", null, null); + assertNotNull(headers); + assertEquals("https://app.example", headers.get("Access-Control-Allow-Origin")); + assertFalse(headers.containsKey("Access-Control-Max-Age")); + } + + @Test + public void preflightWithoutRequestedMethodOrHeadersOmitsThem() { + configureAllowedOrigin("https://app.example"); + HashMap headers = CORSHelper.getCORSHeaders( + "OPTIONS", "https://app.example", null, null); + assertNotNull(headers); + assertEquals("86400", headers.get("Access-Control-Max-Age")); + assertFalse(headers.containsKey("Access-Control-Allow-Methods")); + assertFalse(headers.containsKey("Access-Control-Allow-Headers")); + } +} diff --git a/wrapperjakarta/src/main/java/com/genexus/servlet/CorsFilter.java b/wrapperjakarta/src/main/java/com/genexus/servlet/CorsFilter.java index 52e5c2e65..497d23914 100644 --- a/wrapperjakarta/src/main/java/com/genexus/servlet/CorsFilter.java +++ b/wrapperjakarta/src/main/java/com/genexus/servlet/CorsFilter.java @@ -22,15 +22,23 @@ public void init(FilterConfig filterConfig) throws ServletException { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) servletRequest; + HttpServletResponse response = (HttpServletResponse) servletResponse; - HashMap corsHeaders = CORSHelper.getCORSHeaders(request.getMethod(), request.getHeader(CORSHelper.REQUEST_METHOD_HEADER_NAME), request.getHeader(CORSHelper.REQUEST_HEADERS_HEADER_NAME)); + String origin = request.getHeader(CORSHelper.ORIGIN_HEADER_NAME); + String requestedMethod = request.getHeader(CORSHelper.REQUEST_METHOD_HEADER_NAME); + String requestedHeaders = request.getHeader(CORSHelper.REQUEST_HEADERS_HEADER_NAME); + + HashMap corsHeaders = CORSHelper.getCORSHeaders(request.getMethod(), origin, requestedMethod, requestedHeaders); if (corsHeaders != null) { - HttpServletResponse response = (HttpServletResponse) servletResponse; for (String headerName : corsHeaders.keySet()) { if (!response.containsHeader(headerName)) { response.setHeader(headerName, corsHeaders.get(headerName)); } } + if (CORSHelper.isPreflight(request.getMethod(), origin, requestedMethod)) { + response.setStatus(HttpServletResponse.SC_NO_CONTENT); + return; + } } filterChain.doFilter(servletRequest, servletResponse); } diff --git a/wrapperjakarta/src/main/java/com/genexus/ws/JAXRSCorsFilter.java b/wrapperjakarta/src/main/java/com/genexus/ws/JAXRSCorsFilter.java index d087d58f4..510c1debc 100644 --- a/wrapperjakarta/src/main/java/com/genexus/ws/JAXRSCorsFilter.java +++ b/wrapperjakarta/src/main/java/com/genexus/ws/JAXRSCorsFilter.java @@ -2,15 +2,40 @@ import com.genexus.cors.CORSHelper; import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; import jakarta.ws.rs.container.ContainerResponseContext; import jakarta.ws.rs.container.ContainerResponseFilter; +import jakarta.ws.rs.container.PreMatching; +import jakarta.ws.rs.core.Response; import jakarta.ws.rs.ext.Provider; -import java.util.Collections; import java.util.HashMap; @Provider -public class JAXRSCorsFilter implements ContainerResponseFilter { +@PreMatching +public class JAXRSCorsFilter implements ContainerRequestFilter, ContainerResponseFilter { + + @Override + public void filter(ContainerRequestContext requestContext) { + String method = requestContext.getMethod(); + String origin = requestContext.getHeaderString(CORSHelper.ORIGIN_HEADER_NAME); + String requestedMethod = requestContext.getHeaderString(CORSHelper.REQUEST_METHOD_HEADER_NAME); + + if (!CORSHelper.isPreflight(method, origin, requestedMethod)) { + return; + } + String requestedHeaders = requestContext.getHeaderString(CORSHelper.REQUEST_HEADERS_HEADER_NAME); + HashMap corsHeaders = CORSHelper.getCORSHeaders(method, origin, requestedMethod, requestedHeaders); + if (corsHeaders == null) { + return; + } + Response.ResponseBuilder builder = Response.noContent(); + for (String headerName : corsHeaders.keySet()) { + builder.header(headerName, corsHeaders.get(headerName)); + } + requestContext.abortWith(builder.build()); + } + @Override public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext) { @@ -19,7 +44,7 @@ public void filter(ContainerRequestContext requestContext, return; } for (String headerName : corsHeaders.keySet()) { - responseContext.getHeaders().putSingle(headerName,corsHeaders.get(headerName)); + responseContext.getHeaders().putSingle(headerName, corsHeaders.get(headerName)); } } } diff --git a/wrapperjavax/src/main/java/com/genexus/servlet/CorsFilter.java b/wrapperjavax/src/main/java/com/genexus/servlet/CorsFilter.java index 9c914be05..3d5faf645 100644 --- a/wrapperjavax/src/main/java/com/genexus/servlet/CorsFilter.java +++ b/wrapperjavax/src/main/java/com/genexus/servlet/CorsFilter.java @@ -22,14 +22,23 @@ public void init(FilterConfig filterConfig) throws ServletException { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) servletRequest; - HashMap corsHeaders = CORSHelper.getCORSHeaders(request.getMethod(), request.getHeader(CORSHelper.REQUEST_METHOD_HEADER_NAME), request.getHeader(CORSHelper.REQUEST_HEADERS_HEADER_NAME)); + HttpServletResponse response = (HttpServletResponse) servletResponse; + + String origin = request.getHeader(CORSHelper.ORIGIN_HEADER_NAME); + String requestedMethod = request.getHeader(CORSHelper.REQUEST_METHOD_HEADER_NAME); + String requestedHeaders = request.getHeader(CORSHelper.REQUEST_HEADERS_HEADER_NAME); + + HashMap corsHeaders = CORSHelper.getCORSHeaders(request.getMethod(), origin, requestedMethod, requestedHeaders); if (corsHeaders != null) { - HttpServletResponse response = (HttpServletResponse) servletResponse; for (String headerName : corsHeaders.keySet()) { if (!response.containsHeader(headerName)) { response.setHeader(headerName, corsHeaders.get(headerName)); } } + if (CORSHelper.isPreflight(request.getMethod(), origin, requestedMethod)) { + response.setStatus(HttpServletResponse.SC_NO_CONTENT); + return; + } } filterChain.doFilter(servletRequest, servletResponse); } diff --git a/wrapperjavax/src/main/java/com/genexus/ws/JAXRSCorsFilter.java b/wrapperjavax/src/main/java/com/genexus/ws/JAXRSCorsFilter.java index c1a5e9b1b..c0592d0f8 100644 --- a/wrapperjavax/src/main/java/com/genexus/ws/JAXRSCorsFilter.java +++ b/wrapperjavax/src/main/java/com/genexus/ws/JAXRSCorsFilter.java @@ -3,13 +3,38 @@ import com.genexus.cors.CORSHelper; import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.container.ContainerRequestFilter; import javax.ws.rs.container.ContainerResponseContext; import javax.ws.rs.container.ContainerResponseFilter; +import javax.ws.rs.container.PreMatching; +import javax.ws.rs.core.Response; import javax.ws.rs.ext.Provider; import java.util.HashMap; @Provider -public class JAXRSCorsFilter implements ContainerResponseFilter { +@PreMatching +public class JAXRSCorsFilter implements ContainerRequestFilter, ContainerResponseFilter { + + @Override + public void filter(ContainerRequestContext requestContext) { + String method = requestContext.getMethod(); + String origin = requestContext.getHeaderString(CORSHelper.ORIGIN_HEADER_NAME); + String requestedMethod = requestContext.getHeaderString(CORSHelper.REQUEST_METHOD_HEADER_NAME); + + if (!CORSHelper.isPreflight(method, origin, requestedMethod)) { + return; + } + String requestedHeaders = requestContext.getHeaderString(CORSHelper.REQUEST_HEADERS_HEADER_NAME); + HashMap corsHeaders = CORSHelper.getCORSHeaders(method, origin, requestedMethod, requestedHeaders); + if (corsHeaders == null) { + return; + } + Response.ResponseBuilder builder = Response.noContent(); + for (String headerName : corsHeaders.keySet()) { + builder.header(headerName, corsHeaders.get(headerName)); + } + requestContext.abortWith(builder.build()); + } @Override public void filter(ContainerRequestContext requestContext, @@ -19,7 +44,7 @@ public void filter(ContainerRequestContext requestContext, return; } for (String headerName : corsHeaders.keySet()) { - responseContext.getHeaders().putSingle(headerName,corsHeaders.get(headerName)); + responseContext.getHeaders().putSingle(headerName, corsHeaders.get(headerName)); } } }