Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/qa-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ jobs:
app_port: 8080
sleep_before_test: 30
config_update_delay: 100
skip_tests: test_bypassed_ip_for_geo_blocking,test_demo_apps_generic_tests,test_path_traversal,test_outbound_domain_blocking,test_bypassed_ip,test_wave_attack,test_block_traffic_by_countries,test_user_rate_limiting_1_minute
skip_tests: test_demo_apps_generic_tests,test_path_traversal,test_wave_attack,test_user_rate_limiting_1_minute
40 changes: 40 additions & 0 deletions aikido_zen/context/apply_or_bypass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Helper that either sets the request context as current OR marks the request
as bypassed (when the remote IP is in the bypass list).
"""

from aikido_zen.helpers.logging import logger
from aikido_zen.storage import bypassed_context_store
from aikido_zen.thread.thread_cache import get_cache
from . import current_context


def apply_context_or_bypass(context):
"""
For bypassed IPs: clears the current context and sets the bypassed flag.
For other IPs: sets the context as current and clears the bypassed flag.

Mirrors the firewall-java BypassedContextStore approach so almost every
blocking site short-circuits naturally on `if not context: return`.
"""
try:
cache = get_cache()
if (
cache
and context
and context.remote_address
and cache.is_bypassed_ip(context.remote_address)
):
current_context.set(None)
bypassed_context_store.set_bypassed(True)
return

bypassed_context_store.set_bypassed(False)
if context:
context.set_as_current_context()
except Exception as e:
logger.debug("Exception in apply_context_or_bypass: %s", e)
# On error, fall back to the previous behaviour (set context).
bypassed_context_store.set_bypassed(False)
if context:
context.set_as_current_context()
75 changes: 75 additions & 0 deletions aikido_zen/context/apply_or_bypass_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest

from aikido_zen.context import current_context, get_current_context
from aikido_zen.context.apply_or_bypass import apply_context_or_bypass
from aikido_zen.storage import bypassed_context_store
from aikido_zen.test_utils.context_utils import generate_context
from aikido_zen.thread.thread_cache import get_cache


@pytest.fixture(autouse=True)
def _reset_state():
yield
current_context.set(None)
bypassed_context_store.clear()
cache = get_cache()
if cache:
cache.reset()


def _set_bypass_list(ips):
cache = get_cache()
cache.config.bypassed_ips = _IPMatcherStub(ips)


class _IPMatcherStub:
def __init__(self, ips):
self._ips = set(ips)

def has(self, ip):
return ip in self._ips


def test_non_bypassed_ip_sets_context_and_clears_flag():
_set_bypass_list({"9.9.9.9"})
bypassed_context_store.set_bypassed(True) # stale value from previous request

ctx = generate_context(ip="1.2.3.4")
apply_context_or_bypass(ctx)

assert get_current_context() is ctx
assert bypassed_context_store.is_bypassed() is False


def test_bypassed_ip_clears_context_and_sets_flag():
_set_bypass_list({"1.2.3.4"})

ctx = generate_context(ip="1.2.3.4")
apply_context_or_bypass(ctx)

assert get_current_context() is None
assert bypassed_context_store.is_bypassed() is True


def test_no_remote_address_falls_through_to_set_context():
_set_bypass_list({"1.2.3.4"})

ctx = generate_context()
ctx.remote_address = None
apply_context_or_bypass(ctx)

assert get_current_context() is ctx
assert bypassed_context_store.is_bypassed() is False


def test_bypass_then_non_bypass_resets_flag():
_set_bypass_list({"1.2.3.4"})

ctx_bypassed = generate_context(ip="1.2.3.4")
apply_context_or_bypass(ctx_bypassed)
assert bypassed_context_store.is_bypassed() is True

ctx_normal = generate_context(ip="9.9.9.9")
apply_context_or_bypass(ctx_normal)
assert get_current_context() is ctx_normal
assert bypassed_context_store.is_bypassed() is False
6 changes: 6 additions & 0 deletions aikido_zen/sinks/socket/should_block_outbound_domain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from aikido_zen.storage import bypassed_context_store
from aikido_zen.thread.thread_cache import get_cache


Expand All @@ -6,6 +7,11 @@ def should_block_outbound_domain(hostname, port):
if not process_cache:
return False

if bypassed_context_store.is_bypassed():
# Bypassed IPs are trusted — don't report their outbound hostnames
# in heartbeats and don't block their outbound traffic.
return False

# We store the hostname before checking the blocking status
# This is because if we are in lockdown mode and blocking all new hostnames, it should still
# show up in the dashboard. This allows the user to allow traffic to newly detected hostnames.
Expand Down
56 changes: 56 additions & 0 deletions aikido_zen/sinks/socket/should_block_outbound_domain_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from aikido_zen.sinks.socket.should_block_outbound_domain import (
should_block_outbound_domain,
)
from aikido_zen.storage import bypassed_context_store
from aikido_zen.thread.thread_cache import get_cache


@pytest.fixture(autouse=True)
def _reset_state():
cache = get_cache()
if cache:
cache.reset()
bypassed_context_store.clear()
yield
if cache:
cache.reset()
bypassed_context_store.clear()


def test_unknown_domain_not_blocked():
assert should_block_outbound_domain("safe.example.com", 80) is False
assert any(
h["hostname"] == "safe.example.com"
for h in get_cache().hostnames.as_array()
)


def test_blocked_domain_is_blocked_and_recorded():
cache = get_cache()
cache.config.update_outbound_domains([{"hostname": "evil.example.com", "mode": "block"}])

assert should_block_outbound_domain("evil.example.com", 80) is True
# Blocked domains are still recorded so they show up in the dashboard.
assert any(
h["hostname"] == "evil.example.com"
for h in cache.hostnames.as_array()
)


def test_bypassed_request_does_not_block_or_record():
cache = get_cache()
cache.config.update_outbound_domains([{"hostname": "evil.example.com", "mode": "block"}])
bypassed_context_store.set_bypassed(True)

assert should_block_outbound_domain("evil.example.com", 80) is False
# No hostname pollution from bypassed-IP requests.
assert cache.hostnames.as_array() == []


def test_bypassed_request_does_not_record_unknown_domain():
bypassed_context_store.set_bypassed(True)

assert should_block_outbound_domain("anything.example.com", 80) is False
assert get_cache().hostnames.as_array() == []
3 changes: 2 additions & 1 deletion aikido_zen/sources/django/run_init_stage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Exports run_init_stage function"""

from aikido_zen.context import Context
from aikido_zen.context.apply_or_bypass import apply_context_or_bypass
from aikido_zen.helpers.logging import logger
from .extract_body import extract_body_from_django_request
from .extract_cookies import extract_cookies_from_django_request
Expand All @@ -24,7 +25,7 @@ def run_init_stage(request):
else:
return
context.set_cookies(cookies)
context.set_as_current_context()
apply_context_or_bypass(context)

# Init stage needs to be run with context already set :
request_handler(stage="init")
Expand Down
3 changes: 2 additions & 1 deletion aikido_zen/sources/flask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from aikido_zen.context import Context
from aikido_zen.context.apply_or_bypass import apply_context_or_bypass
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.sinks import on_import, patch_function, before_modify_return, after
from .extract_cookies import extract_cookies_from_flask_request_and_save_data
Expand Down Expand Up @@ -33,7 +34,7 @@ def _call(func, instance, args, kwargs):
start_response = get_argument(args, kwargs, 1, "start_response")

context1 = Context(req=environ, source="flask")
context1.set_as_current_context()
apply_context_or_bypass(context1)
funcs.request_handler(stage="init")

# Checks for blocked IPs, blocked UAs, ...
Expand Down
3 changes: 2 additions & 1 deletion aikido_zen/sources/quart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from aikido_zen.context import Context, get_current_context
from aikido_zen.context.apply_or_bypass import apply_context_or_bypass
from .functions.request_handler import request_handler
from ..helpers.get_argument import get_argument
from ..sinks import on_import, patch_function, before, before_async
Expand All @@ -11,7 +12,7 @@ def _call(func, instance, args, kwargs):
return

new_context = Context(req=scope, source="quart")
new_context.set_as_current_context()
apply_context_or_bypass(new_context)
request_handler(stage="init")


Expand Down
3 changes: 2 additions & 1 deletion aikido_zen/sources/starlette/starlette_applications.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Wraps starlette.applications for initial request_handler"""

from aikido_zen.context import Context
from aikido_zen.context.apply_or_bypass import apply_context_or_bypass
from ..functions.request_handler import request_handler
from ...helpers.get_argument import get_argument
from ...sinks import on_import, patch_function, before
Expand All @@ -13,7 +14,7 @@ def _call(func, instance, args, kwargs):
return

new_context = Context(req=scope, source="starlette")
new_context.set_as_current_context()
apply_context_or_bypass(new_context)
request_handler(stage="init")


Expand Down
27 changes: 27 additions & 0 deletions aikido_zen/storage/bypassed_context_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Records whether the current request's remote IP is in the bypass list.

Bypassed requests intentionally do not set a Context (so all per-request
protection short-circuits), but checks that run without a Context — for
example outbound DNS reporting — still need a way to detect "this work
was triggered by a bypassed request".
"""

import contextvars

_bypassed = contextvars.ContextVar("aikido_bypassed_ip", default=False)


def set_bypassed(value: bool) -> None:
_bypassed.set(bool(value))


def is_bypassed() -> bool:
try:
return _bypassed.get()
except Exception:
return False


def clear() -> None:
_bypassed.set(False)
37 changes: 37 additions & 0 deletions aikido_zen/storage/bypassed_context_store_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from aikido_zen.storage import bypassed_context_store


@pytest.fixture(autouse=True)
def _reset_after_each_test():
yield
bypassed_context_store.clear()


def test_default_is_false():
assert bypassed_context_store.is_bypassed() is False


def test_set_bypassed_true_then_false():
bypassed_context_store.set_bypassed(True)
assert bypassed_context_store.is_bypassed() is True

bypassed_context_store.set_bypassed(False)
assert bypassed_context_store.is_bypassed() is False


def test_clear_resets_to_false():
bypassed_context_store.set_bypassed(True)
assert bypassed_context_store.is_bypassed() is True

bypassed_context_store.clear()
assert bypassed_context_store.is_bypassed() is False


def test_truthy_values_coerced_to_bool():
bypassed_context_store.set_bypassed("yes")
assert bypassed_context_store.is_bypassed() is True

bypassed_context_store.set_bypassed(0)
assert bypassed_context_store.is_bypassed() is False
10 changes: 6 additions & 4 deletions aikido_zen/vulnerabilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aikido_zen.helpers.is_protection_forced_off_cached import (
is_protection_forced_off_cached,
)
from aikido_zen.storage import bypassed_context_store
from aikido_zen.thread.thread_cache import get_cache
from .sql_injection.context_contains_sql_injection import context_contains_sql_injection
from .nosql_injection.check_context import check_context_for_nosql_injection
Expand All @@ -44,6 +45,11 @@ def run_vulnerability_scan(kind, op, args):
if is_protection_forced_off_cached(context):
return

if bypassed_context_store.is_bypassed():
# Bypassed IPs are trusted across all vulnerability kinds, including
# the context-less SSRF path (e.g. stored SSRF).
return

comms = comm.get_comms()
thread_cache = get_cache()
if not context and kind != "ssrf":
Expand All @@ -55,10 +61,6 @@ def run_vulnerability_scan(kind, op, args):
# Make a special exception for SSRF, which checks itself if thread cache is set.
# This is because some scans/tests for SSRF do not require a thread cache to be set.
return
if thread_cache and context:
if thread_cache.is_bypassed_ip(context.remote_address):
# This IP is on the bypass list, not scanning
return

error_type = AikidoException # Default error
error_args = tuple()
Expand Down
Loading
Loading