diff --git a/.github/workflows/qa-tests.yml b/.github/workflows/qa-tests.yml index ce01cadd6..39dbeaa92 100644 --- a/.github/workflows/qa-tests.yml +++ b/.github/workflows/qa-tests.yml @@ -44,9 +44,8 @@ jobs: cp firewall-java/.github/workflows/Dockerfile.qa zen-demo-java/Dockerfile - name: Run Firewall QA Tests - uses: AikidoSec/firewall-tester-action@v1.0.0 + uses: AikidoSec/firewall-tester-action@v1.0.12 with: dockerfile_path: ./zen-demo-java/Dockerfile app_port: 8080 sleep_before_test: 30 - skip_tests: test_ssrf,test_demo_apps_generic_tests diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/InetAddressWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/InetAddressWrapper.java index e95caf193..d2ce47204 100644 --- a/agent/src/main/java/dev/aikido/agent/wrappers/InetAddressWrapper.java +++ b/agent/src/main/java/dev/aikido/agent/wrappers/InetAddressWrapper.java @@ -8,6 +8,7 @@ import java.lang.reflect.Executable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.net.IDN; import java.net.InetAddress; import java.net.MalformedURLException; import java.net.URL; @@ -32,6 +33,29 @@ public static class InetAdvice { // Since we have to wrap a native Java Class stuff gets more complicated // The classpath is not the same anymore, and we can't import our modules directly. // To bypass this issue we load collectors from a .jar file + + // Java's system resolver rejects non-ASCII hostnames, so convert IDN to Punycode + // before the real lookup runs. Without this, getAllByName throws UnknownHostException + // and OnMethodExit never fires — meaning DNSRecordCollector can't block or track + // the hostname. + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(value = 0, readOnly = false) String hostname + ) { + if (hostname == null) { + return; + } + for (int i = 0; i < hostname.length(); i++) { + if (hostname.charAt(i) > 0x7F) { + try { + hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED); + } catch (IllegalArgumentException ignored) { + } + return; + } + } + } + @Advice.OnMethodExit public static void after( @Advice.Argument(0) String hostname, diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java index d33c165c9..cfcb0ff95 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java @@ -1,6 +1,7 @@ package dev.aikido.agent_api.collectors; import dev.aikido.agent_api.context.Context; +import dev.aikido.agent_api.storage.BypassedContextStore; import dev.aikido.agent_api.storage.HostnamesStore; import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; @@ -34,20 +35,24 @@ public static void report(String hostname, InetAddress[] inetAddresses) { // store stats StatisticsStore.registerCall("java.net.InetAddress.getAllByName", OperationKind.OUTGOING_HTTP_OP); + boolean bypassed = BypassedContextStore.isBypassed(); + // Consume pending ports recorded by URLCollector for this hostname. // Removing them here ensures each (hostname, port) pair is counted exactly once. Set ports = PendingHostnamesStore.getAndRemove(hostname); - if (!ports.isEmpty()) { - for (int port : ports) { - HostnamesStore.incrementHits(hostname, port); + if (!bypassed) { + // Bypassed IPs are trusted — don't report their outbound hostnames in heartbeats. + if (!ports.isEmpty()) { + for (int port : ports) { + HostnamesStore.incrementHits(hostname, port); + } + } else { + HostnamesStore.incrementHits(hostname, 0); } - } else { - // We still need to report a hit to the hostname for outbound domain blocking - HostnamesStore.incrementHits(hostname, 0); } // Block if the hostname is in the blocked domains list - if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname)) { + if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname) && !bypassed) { logger.debug("Blocking DNS lookup for domain: %s", hostname); throw BlockedOutboundException.get(); } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java index 1d66a2126..805728488 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java @@ -8,6 +8,7 @@ import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; import dev.aikido.agent_api.storage.AttackQueue; +import dev.aikido.agent_api.storage.BypassedContextStore; import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; import dev.aikido.agent_api.storage.ServiceConfiguration; @@ -44,8 +45,10 @@ public static Res report(ContextObject newContext) { // Flush pending hostnames on every context change to prevent the store from // growing unboundedly when a thread is reused across multiple requests. PendingHostnamesStore.clear(); + BypassedContextStore.clear(); if (config.isIpBypassed(newContext.getRemoteAddress())) { + BypassedContextStore.setBypassed(true); return null; // do not set context when the IP address is bypassed (zen = off) } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/helpers/extraction/StringExtractor.java b/agent_api/src/main/java/dev/aikido/agent_api/helpers/extraction/StringExtractor.java index d70f54eae..88a2bce7e 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/helpers/extraction/StringExtractor.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/helpers/extraction/StringExtractor.java @@ -1,6 +1,7 @@ package dev.aikido.agent_api.helpers.extraction; import dev.aikido.agent_api.helpers.patterns.LooksLikeJWT; +import dev.aikido.agent_api.vulnerabilities.DangerousBodyException; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.*; @@ -9,12 +10,16 @@ import static dev.aikido.agent_api.helpers.patterns.PrimitiveType.isPrimitiveType; public class StringExtractor { + private static final int MAX_DEPTH = 1024; // Ensures that we don't get recursion : Set scanned = new HashSet<>(); public static Map extractStringsFromObject(Object obj) { - return new StringExtractor().extractStringsRecursive(obj, new ArrayList<>()); + return new StringExtractor().extractStringsRecursive(obj, new ArrayList<>(), 0); } - private Map extractStringsRecursive(Object target, ArrayList pathToPayload) { + private Map extractStringsRecursive(Object target, ArrayList pathToPayload, int depth) { + if (depth > MAX_DEPTH) { + throw DangerousBodyException.bodyTooDeep(); + } HashMap result = new HashMap<>(); if (target == null || scanned.contains(target)) { return Map.of(); // Do not rescan objects, because this might lead to recursion. @@ -22,18 +27,18 @@ private Map extractStringsRecursive(Object target, ArrayList || target.getClass().isArray()) { - result.putAll(extractStringsFromArray(target, pathToPayload)); + result.putAll(extractStringsFromArray(target, pathToPayload, depth)); } else if (target instanceof Map targetMap) { - result.putAll(extractStringsFromMap(targetMap, pathToPayload)); + result.putAll(extractStringsFromMap(targetMap, pathToPayload, depth)); } else if (!isPrimitiveType(target)) { // Stop algorithm if it's a primitive type. - result.putAll(extractStringsFromStructure(target, pathToPayload)); + result.putAll(extractStringsFromStructure(target, pathToPayload, depth)); } return result; } - private Map extractStringsFromString(String target, ArrayList pathToPayload) { + private Map extractStringsFromString(String target, ArrayList pathToPayload, int depth) { HashMap result = new HashMap<>(); result.put(target, buildPathToPayload(pathToPayload)); @@ -42,7 +47,7 @@ private Map extractStringsFromString(String target, ArrayList newPathToPayload = new ArrayList<>(pathToPayload); newPathToPayload.add(new PathBuilder.PathPart("jwt")); - Map resultsFromJWT = extractStringsRecursive(jwtResult.payload(), newPathToPayload); + Map resultsFromJWT = extractStringsRecursive(jwtResult.payload(), newPathToPayload, depth + 1); for (Map.Entry entry : resultsFromJWT.entrySet()) { String key = entry.getKey(); String value = entry.getValue(); @@ -57,27 +62,27 @@ private Map extractStringsFromString(String target, ArrayList extractStringsFromArray(Object target, ArrayList pathToPayload) { + private Map extractStringsFromArray(Object target, ArrayList pathToPayload, int depth) { HashMap result = new HashMap<>(); if (target instanceof Collection targetCollection) { int index = 0; for (Object element : (Collection) targetCollection) { ArrayList newPathToPayload = new ArrayList<>(pathToPayload); newPathToPayload.add(new PathBuilder.PathPart("array", index)); - result.putAll(extractStringsRecursive(element, newPathToPayload)); + result.putAll(extractStringsRecursive(element, newPathToPayload, depth + 1)); index++; } } else if (target instanceof Object[] targetArray) { for (int i = 0; i < targetArray.length; i++) { ArrayList newPathToPayload = new ArrayList<>(pathToPayload); newPathToPayload.add(new PathBuilder.PathPart("array", i)); - result.putAll(extractStringsRecursive(targetArray[i], newPathToPayload)); + result.putAll(extractStringsRecursive(targetArray[i], newPathToPayload, depth + 1)); } } return result; } - private Map extractStringsFromMap(Map target, ArrayList pathToPayload) { + private Map extractStringsFromMap(Map target, ArrayList pathToPayload, int depth) { HashMap result = new HashMap<>(); for (Object key : target.keySet()) { if (key instanceof String stringKey) { @@ -89,12 +94,12 @@ private Map extractStringsFromMap(Map target, ArrayList extractStringsFromStructure(Object target, ArrayList pathToPayload) { + private Map extractStringsFromStructure(Object target, ArrayList pathToPayload, int depth) { HashMap result = new HashMap<>(); Field[] fields = target.getClass().getDeclaredFields(); for (Field field : fields) { @@ -106,7 +111,9 @@ private Map extractStringsFromStructure(Object target, ArrayList Object fieldValue = field.get(target); ArrayList newPathToPayload = new ArrayList<>(pathToPayload); newPathToPayload.add(new PathBuilder.PathPart("object", field.getName())); - result.putAll(extractStringsRecursive(fieldValue, newPathToPayload)); + result.putAll(extractStringsRecursive(fieldValue, newPathToPayload, depth + 1)); + } catch (DangerousBodyException e) { + throw e; } catch (IllegalAccessException | RuntimeException e) { // pass-through } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/helpers/patterns/LooksLikeJWT.java b/agent_api/src/main/java/dev/aikido/agent_api/helpers/patterns/LooksLikeJWT.java index e50b1779b..71727eb0e 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/helpers/patterns/LooksLikeJWT.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/helpers/patterns/LooksLikeJWT.java @@ -4,12 +4,16 @@ import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; +import dev.aikido.agent_api.vulnerabilities.DangerousBodyException; + import java.util.Map; import java.util.Objects; public final class LooksLikeJWT { private LooksLikeJWT() {} + public static final int MAX_JWT_PAYLOAD_BYTES = 8 * 1024; + public static Result tryDecodeAsJwt(String jwt) { // Check if the JWT contains the required parts if (jwt == null || !jwt.contains(".")) { @@ -23,14 +27,24 @@ public static Result tryDecodeAsJwt(String jwt) { return new Result(false, null); } + byte[] decoded; try { - // Decode the middle part (payload) of the JWT - String payload = new String(Base64.getUrlDecoder().decode(parts[1])); + decoded = Base64.getUrlDecoder().decode(parts[1]); + } catch (IllegalArgumentException ignored) { + return new Result(false, null); + } + if (decoded.length > MAX_JWT_PAYLOAD_BYTES) { + throw DangerousBodyException.jwtTooLarge(); + } + + String payload = new String(decoded); + try { Gson gson = new Gson(); Map jwtPayload = gson.fromJson(payload, new TypeToken>(){}.getType()); - return new Result(true, jwtPayload); + } catch (StackOverflowError soe) { + throw DangerousBodyException.jwtTooLarge(); } catch (Exception ignored) { return new Result(false, null); } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/BypassedContextStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/BypassedContextStore.java new file mode 100644 index 000000000..d3dbd8f87 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/BypassedContextStore.java @@ -0,0 +1,23 @@ +package dev.aikido.agent_api.storage; + +/** + * Thread-local flag recording whether the current request's remote IP is in the bypass list. + * Needed because bypassed requests intentionally do not set a context, but for Stored SSRF we still want to skip. + */ +public final class BypassedContextStore { + private BypassedContextStore() {} + + private static final ThreadLocal store = ThreadLocal.withInitial(() -> false); + + public static void setBypassed(boolean bypassed) { + store.set(bypassed); + } + + public static boolean isBypassed() { + return store.get(); + } + + public static void clear() { + store.remove(); + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java index 095d52a62..681d25255 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/ServiceConfiguration.java @@ -97,7 +97,7 @@ public BlockedResult isIpBlocked(String ip) { // Check for allowed ip addresses (i.e. only one country is allowed to visit the site) // Always allow access from private IP addresses (those include local IP addresses) if (!isPrivateIp(ip) && !firewallLists.matchesAllowedIps(ip)) { - return new BlockedResult(true, "not in allowlist"); + return new BlockedResult(true, "not allowed"); } // Check for monitored IP addresses diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/DangerousBodyException.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/DangerousBodyException.java new file mode 100644 index 000000000..bd71c50d6 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/DangerousBodyException.java @@ -0,0 +1,15 @@ +package dev.aikido.agent_api.vulnerabilities; + +public class DangerousBodyException extends AikidoException { + public DangerousBodyException(String reason) { + super(generateDefaultMessage("Dangerous Body") + ": " + reason); + } + + public static DangerousBodyException jwtTooLarge() { + return new DangerousBodyException("JWT payload too large"); + } + + public static DangerousBodyException bodyTooDeep() { + return new DangerousBodyException("Body is too deeply nested to scan"); + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/Scanner.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/Scanner.java index bb49c7bbc..b064fdbfb 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/Scanner.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/Scanner.java @@ -50,6 +50,8 @@ public static void scanForGivenVulnerability(Vulnerabilities.Vulnerability vulne break; } } + } catch (AikidoException ae) { + exception = Optional.of(ae); } catch (Throwable e) { logger.debug(e); } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/SkipVulnerabilityScanDecider.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/SkipVulnerabilityScanDecider.java index 49ea7f922..41d5636f1 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/SkipVulnerabilityScanDecider.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/SkipVulnerabilityScanDecider.java @@ -2,6 +2,7 @@ import dev.aikido.agent_api.background.Endpoint; import dev.aikido.agent_api.context.ContextObject; +import dev.aikido.agent_api.storage.BypassedContextStore; import dev.aikido.agent_api.storage.ServiceConfiguration; import java.util.List; @@ -12,6 +13,11 @@ public final class SkipVulnerabilityScanDecider { private SkipVulnerabilityScanDecider() {} public static boolean shouldSkipVulnerabilityScan(ContextObject context, boolean defaultIfNoContext) { + // Check if ip is bypassed, important still for stored ssrf where it runs without a context. + if (BypassedContextStore.isBypassed()) { + return true; + } + if (context == null) { return defaultIfNoContext; } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java index da55cb32f..f7bd1ed69 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/BlockedOutboundException.java @@ -8,7 +8,7 @@ public BlockedOutboundException(String msg) { } public static BlockedOutboundException get() { - String defaultMsg = generateDefaultMessage("an outbound request"); + String defaultMsg = generateDefaultMessage("an outbound connection"); return new BlockedOutboundException(defaultMsg); } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java index 84e8bc281..1350a25e5 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/vulnerabilities/outbound_blocking/OutboundDomains.java @@ -2,8 +2,10 @@ import dev.aikido.agent_api.storage.service_configuration.Domain; +import java.net.IDN; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; public class OutboundDomains { @@ -14,14 +16,14 @@ public void update(List newDomains, boolean blockNewOutgoingRequests) { if (newDomains != null) { this.domains = new HashMap<>(); for (Domain domain : newDomains) { - this.domains.put(domain.hostname(), domain.mode()); + this.domains.put(normalize(domain.hostname()), domain.mode()); } } this.blockNewOutgoingRequests = blockNewOutgoingRequests; } public boolean shouldBlockOutgoingRequest(String hostname) { - String mode = this.domains.get(hostname); + String mode = this.domains.get(normalize(hostname)); if (this.blockNewOutgoingRequests) { // Only allow outgoing requests if the mode is "allow" @@ -32,4 +34,18 @@ public boolean shouldBlockOutgoingRequest(String hostname) { // Only block outgoing requests if the mode is "block" return "block".equals(mode); } + + // Normalize to lowercased Unicode form so Punycode (xn--...) and Unicode + // variants of the same hostname compare equal. + private static String normalize(String hostname) { + if (hostname == null) { + return null; + } + String lower = hostname.toLowerCase(Locale.ROOT); + try { + return IDN.toUnicode(lower, IDN.ALLOW_UNASSIGNED); + } catch (IllegalArgumentException e) { + return lower; + } + } } diff --git a/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java b/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java index c7cdd4b3b..6606161f6 100644 --- a/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java +++ b/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java @@ -6,6 +6,7 @@ import dev.aikido.agent_api.context.Context; import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.storage.AttackQueue; +import dev.aikido.agent_api.storage.BypassedContextStore; import dev.aikido.agent_api.storage.Hostnames; import dev.aikido.agent_api.storage.HostnamesStore; import dev.aikido.agent_api.storage.PendingHostnamesStore; @@ -37,6 +38,7 @@ void setup() throws UnknownHostException { AttackQueue.clear(); HostnamesStore.clear(); PendingHostnamesStore.clear(); + BypassedContextStore.clear(); } @AfterEach @@ -45,6 +47,7 @@ public void cleanup() { PendingHostnamesStore.clear(); Context.set(null); AttackQueue.clear(); + BypassedContextStore.clear(); // Reset domain config ServiceConfigStore.updateFromAPIResponse(new APIResponse( true, null, 0L, null, null, null, false, List.of(), true, false, List.of() @@ -134,6 +137,85 @@ public void testAllowedDomainNotBlocked() { ); } + @Test + public void testBlockedDomainNotBlockedWhenIpBypassed() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, List.of(new Domain("blocked.example.com", "block")), true, true, List.of() + )); + BypassedContextStore.setBypassed(true); + assertDoesNotThrow(() -> + DNSRecordCollector.report("blocked.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testUnicodeDomainBlockedForPunycodeRequest() { + // Blocklist stores Unicode; a Punycode request for the same domain must still be blocked. + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, List.of(new Domain("böse.example.com", "block")), true, true, List.of() + )); + assertThrows(BlockedOutboundException.class, () -> + DNSRecordCollector.report("xn--bse-sna.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testUnicodeDomainBlockedForUnicodeRequest() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, List.of(new Domain("münchen.example.com", "block")), true, true, List.of() + )); + assertThrows(BlockedOutboundException.class, () -> + DNSRecordCollector.report("münchen.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testAllowedUnicodeDomainNotBlockedForPunycodeRequest() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + true, List.of(new Domain("münchen-allowed.example.com", "allow")), true, true, List.of() + )); + assertDoesNotThrow(() -> + DNSRecordCollector.report("xn--mnchen-allowed-gsb.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testAllowedUnicodeDomainNotBlockedForUnicodeRequest() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + true, List.of(new Domain("münchen-allowed.example.com", "allow")), true, true, List.of() + )); + assertDoesNotThrow(() -> + DNSRecordCollector.report("münchen-allowed.example.com", new InetAddress[]{inetAddress1}) + ); + } + + @Test + public void testBypassedIpDoesNotRecordHostname() { + BypassedContextStore.setBypassed(true); + DNSRecordCollector.report("domain1.example.com", new InetAddress[]{inetAddress1}); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(0, entries.length); + } + + @Test + public void testBlockedDomainStillRecordedWhenNotBypassed() { + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, + false, List.of(new Domain("blocked.example.com", "block")), true, true, List.of() + )); + assertThrows(BlockedOutboundException.class, () -> + DNSRecordCollector.report("blocked.example.com", new InetAddress[]{inetAddress1}) + ); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("blocked.example.com", entries[0].getHostname()); + } + @Test public void testUnknownDomainBlockedWhenBlockNewOutgoingRequests() { ServiceConfigStore.updateFromAPIResponse(new APIResponse( diff --git a/agent_api/src/test/java/collectors/WebRequestCollectorTest.java b/agent_api/src/test/java/collectors/WebRequestCollectorTest.java index f54f06c71..4a3341d17 100644 --- a/agent_api/src/test/java/collectors/WebRequestCollectorTest.java +++ b/agent_api/src/test/java/collectors/WebRequestCollectorTest.java @@ -8,6 +8,7 @@ import dev.aikido.agent_api.context.Context; import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.storage.AttackQueue; +import dev.aikido.agent_api.storage.BypassedContextStore; import dev.aikido.agent_api.storage.ServiceConfigStore; import dev.aikido.agent_api.storage.statistics.StatisticsStore; import org.junit.jupiter.api.BeforeEach; @@ -34,6 +35,7 @@ void setUp() { ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse); ServiceConfigStore.updateFromAPIListsResponse(emptyAPIListsResponse); AttackQueue.clear(); + BypassedContextStore.clear(); } @Test @@ -157,7 +159,7 @@ void testReport_ipNotAllowedUsingLists() { contextObject.setIp("4.4.4.4"); response = WebRequestCollector.report(contextObject); assertNotNull(response); - assertEquals("Your IP address is blocked. Reason: not in allowlist (Your IP: 4.4.4.4)", response.msg()); + assertEquals("Your IP address is blocked. Reason: not allowed (Your IP: 4.4.4.4)", response.msg()); assertEquals(403, response.status()); } diff --git a/agent_api/src/test/java/helpers/LooksLikeJWTTest.java b/agent_api/src/test/java/helpers/LooksLikeJWTTest.java index 39ec097fc..5c7bfb60f 100644 --- a/agent_api/src/test/java/helpers/LooksLikeJWTTest.java +++ b/agent_api/src/test/java/helpers/LooksLikeJWTTest.java @@ -1,12 +1,15 @@ package helpers; import dev.aikido.agent_api.helpers.patterns.LooksLikeJWT; +import dev.aikido.agent_api.vulnerabilities.DangerousBodyException; import org.junit.jupiter.api.Test; +import java.util.Base64; import java.util.HashMap; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; public class LooksLikeJWTTest { @@ -52,4 +55,35 @@ void testReturnsDecodedJwtForValidJwtWithBearerPrefix() { expectedPayload.put("iat", 1.516239022E9); assertEquals(new LooksLikeJWT.Result(true, expectedPayload), LooksLikeJWT.tryDecodeAsJwt(validJwtWithBearer)); } + + private static String buildJwt(String payloadJson) { + String payloadB64 = Base64.getUrlEncoder().withoutPadding() + .encodeToString(payloadJson.getBytes()); + return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + payloadB64 + ".sig"; + } + + @Test + void testThrowsDangerousBodyExceptionForOversizedPayload() { + StringBuilder sb = new StringBuilder("{\"k\":\""); + for (int i = 0; i < LooksLikeJWT.MAX_JWT_PAYLOAD_BYTES + 10; i++) { + sb.append('a'); + } + sb.append("\"}"); + String jwt = buildJwt(sb.toString()); + assertThrows(DangerousBodyException.class, () -> LooksLikeJWT.tryDecodeAsJwt(jwt)); + } + + @Test + void testThrowsDangerousBodyExceptionForDeeplyNestedPayload() { + int depth = 7000; + StringBuilder open = new StringBuilder(); + StringBuilder close = new StringBuilder(); + for (int i = 0; i < depth; i++) { + open.append("{\"a\":"); + close.append("}"); + } + String payload = open.toString() + "1" + close.toString(); + String jwt = buildJwt(payload); + assertThrows(DangerousBodyException.class, () -> LooksLikeJWT.tryDecodeAsJwt(jwt)); + } } diff --git a/agent_api/src/test/java/helpers/StringExtractorTest.java b/agent_api/src/test/java/helpers/StringExtractorTest.java index 9bc794778..4adc91bed 100644 --- a/agent_api/src/test/java/helpers/StringExtractorTest.java +++ b/agent_api/src/test/java/helpers/StringExtractorTest.java @@ -4,6 +4,7 @@ import dev.aikido.agent_api.api_discovery.DataSchemaGenerator; import dev.aikido.agent_api.api_discovery.DataSchemaItem; import dev.aikido.agent_api.api_discovery.DataSchemaType; +import dev.aikido.agent_api.vulnerabilities.DangerousBodyException; import org.junit.jupiter.api.Test; import static dev.aikido.agent_api.helpers.extraction.StringExtractor.extractStringsFromObject; @@ -407,6 +408,42 @@ public void testItChecksScannedClasses() { assertEquals(".important_record.a", result.get("Hello World")); } + @Test + public void testThrowsDangerousBodyExceptionForTooDeepNesting() { + Map root = new HashMap<>(); + Map current = root; + for (int i = 0; i < 1100; i++) { + Map next = new HashMap<>(); + current.put("a", next); + current = next; + } + current.put("leaf", "x"); + assertThrows(DangerousBodyException.class, () -> extractStringsFromObject(root)); + } + + @Test + public void testDoesNotThrowForModestNesting() { + Map root = new HashMap<>(); + Map current = root; + for (int i = 0; i < 500; i++) { + Map next = new HashMap<>(); + current.put("a", next); + current = next; + } + current.put("leaf", "x"); + Map result = extractStringsFromObject(root); + assertEquals("x", findKeyEndingWithLeafValue(result)); + } + + private static String findKeyEndingWithLeafValue(Map result) { + for (Map.Entry e : result.entrySet()) { + if ("x".equals(e.getKey())) { + return e.getKey(); + } + } + return null; + } + @Test public void testItChecksScannedObjects() { Map input = new HashMap<>(); diff --git a/agent_api/src/test/java/storage/BypassedContextStoreTest.java b/agent_api/src/test/java/storage/BypassedContextStoreTest.java new file mode 100644 index 000000000..1bcb3759b --- /dev/null +++ b/agent_api/src/test/java/storage/BypassedContextStoreTest.java @@ -0,0 +1,59 @@ +package storage; + +import dev.aikido.agent_api.storage.BypassedContextStore; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.*; + +public class BypassedContextStoreTest { + + @BeforeEach + public void setUp() { + BypassedContextStore.clear(); + } + + @AfterEach + public void tearDown() { + BypassedContextStore.clear(); + } + + @Test + public void testDefaultIsFalse() { + assertFalse(BypassedContextStore.isBypassed()); + } + + @Test + public void testSetBypassed() { + BypassedContextStore.setBypassed(true); + assertTrue(BypassedContextStore.isBypassed()); + + BypassedContextStore.setBypassed(false); + assertFalse(BypassedContextStore.isBypassed()); + } + + @Test + public void testClear() { + BypassedContextStore.setBypassed(true); + assertTrue(BypassedContextStore.isBypassed()); + + BypassedContextStore.clear(); + assertFalse(BypassedContextStore.isBypassed()); + } + + @Test + public void testThreadIsolation() throws InterruptedException { + BypassedContextStore.setBypassed(true); + AtomicBoolean observedInOtherThread = new AtomicBoolean(true); + + Thread t = new Thread(() -> observedInOtherThread.set(BypassedContextStore.isBypassed())); + t.start(); + t.join(); + + assertFalse(observedInOtherThread.get()); + assertTrue(BypassedContextStore.isBypassed()); + } +} diff --git a/agent_api/src/test/java/vulnerabilities/SkipVulnerabilityScanDeciderTest.java b/agent_api/src/test/java/vulnerabilities/SkipVulnerabilityScanDeciderTest.java index 2b7e136bc..bf56ef255 100644 --- a/agent_api/src/test/java/vulnerabilities/SkipVulnerabilityScanDeciderTest.java +++ b/agent_api/src/test/java/vulnerabilities/SkipVulnerabilityScanDeciderTest.java @@ -2,7 +2,10 @@ import dev.aikido.agent_api.background.Endpoint; import dev.aikido.agent_api.context.ContextObject; +import dev.aikido.agent_api.storage.BypassedContextStore; import dev.aikido.agent_api.vulnerabilities.SkipVulnerabilityScanDecider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import utils.EmptyAPIResponses; import utils.EmptySampleContextObject; @@ -14,6 +17,16 @@ import static org.junit.jupiter.api.Assertions.*; public class SkipVulnerabilityScanDeciderTest { + @BeforeEach + public void setUp() { + BypassedContextStore.clear(); + } + + @AfterEach + public void tearDown() { + BypassedContextStore.clear(); + } + private List createEndpoints(boolean protectionForcedOff1, boolean protectionForcedOff2) { List endpoints = new ArrayList<>(); endpoints.add(new Endpoint("POST", "/api/login", 3, 1000, Collections.emptyList(), false, protectionForcedOff1, true)); @@ -157,6 +170,33 @@ public void testShouldSkipVulnerabilityScan_NoConditionsMet() { )); } + @Test + public void testShouldSkipVulnerabilityScan_BypassedIp_NullContext() { + BypassedContextStore.setBypassed(true); + // Even with defaultIfNoContext=false (the Stored SSRF path), a bypassed IP must skip. + assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, false)); + assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, true)); + } + + @Test + public void testShouldSkipVulnerabilityScan_BypassedIp_WithContext() { + EmptyAPIResponses.setEmptyConfigWithEndpointList(createEndpoints(false, false)); + ContextObject ctx = new EmptySampleContextObject("", "/api/login", "POST"); + // Without bypass flag this context would return false (no forced protection off). + assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(ctx)); + + BypassedContextStore.setBypassed(true); + ContextObject freshCtx = new EmptySampleContextObject("", "/api/login", "POST"); + assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(freshCtx)); + } + + @Test + public void testShouldSkipVulnerabilityScan_NotBypassed_NullContext() { + // Sanity check: default behavior unchanged when bypass flag is not set. + assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, false)); + assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, true)); + } + @Test public void testUsesCache() { ContextObject ctx = new EmptySampleContextObject("", "/api/login", "POST");