Skip to content

Commit

Permalink
Only call set_as_current_context() if context is found
Browse files Browse the repository at this point in the history
  • Loading branch information
kapyteinaikido authored and bitterpanda63 committed Dec 20, 2024
1 parent 1bb3ec7 commit dd46025
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 13 deletions.
4 changes: 1 addition & 3 deletions aikido_zen/sources/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def extract_form_data_from_flask_request_and_save_data(req):
context.set_body(req.form)
else:
context.set_body(req.data.decode("utf-8"))

context.set_as_current_context()

context.set_as_current_context()
except Exception as e:
logger.debug("Exception occured whilst extracting flask body data: %s", e)

Expand Down
71 changes: 61 additions & 10 deletions aikido_zen/sources/flask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@
"REMOTE_ADDR": "198.51.100.23",
}

sample_environ_malformed_cookie = {
"REQUEST_METHOD": "POST",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "\u0000" * 10,
"wsgi.url_scheme": "https",
"HTTP_HOST": "example.com",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"body": '{"asd": 1}',
"CONTENT_TYPE": "application/x-www-form-urlencoded",
"REMOTE_ADDR": "198.51.100.23",
}

sample_environ_malformed_json = {
"REQUEST_METHOD": "POST",
"HTTP_HEADER_1": "header 1 value",
Expand All @@ -42,8 +57,8 @@
"HTTP_HEADER_2": "Header 2 value",
"HTTP_HOST": "example.com",
"CONTENT_TYPE": "application/json",
"PATH_INFO": "/hello/JohnDoe/30",
"QUERY_STRING": "",
"PATH_INFO": "/hello/JohnDoe/30",
"QUERY_STRING": "",
"body": '{"invalid_json": true',
"REMOTE_ADDR": "198.51.100.23",
"wsgi.url_scheme": "https",
Expand All @@ -60,26 +75,29 @@ def timeout_handler(signum, frame):

signal.signal(signal.SIGALRM, timeout_handler)


def test_flask_all_3_func_with_view_args_and_invalid_json_body():
with patch("aikido_zen.sources.functions.request_handler.request_handler") as mock_request_handler:
with patch(
"aikido_zen.sources.functions.request_handler.request_handler"
) as mock_request_handler:
reset_comms()
current_context.set(None)
mock_request_handler.return_value = None

from flask import Flask

app = Flask(__name__)

@app.route("/hello/<user>/<age>", methods=["POST"])
def hello(user, age):
return f"User: {user}, Age: {age}"

try:
signal.alarm(1)
signal.alarm(1)

app(sample_environ_view_args_and_malformed_json, lambda x, y: x)
app.run()

except TimeoutException:
pass

Expand All @@ -101,7 +119,40 @@ def hello(user, age):

assert get_current_context().route_params["user"] == "JohnDoe"
assert get_current_context().route_params["age"] == "30"



def test_flask_all_3_func_with_malformed_cookie():
"""When the flask body can not be parsed (because it contains invalid json for example), we should still parse the cookies of the endpoint"""
with patch(
"aikido_zen.sources.functions.request_handler.request_handler"
) as mock_request_handler:
reset_comms()
current_context.set(None)
mock_request_handler.return_value = None

from flask import Flask

app = Flask(__name__)
try:
signal.alarm(1)
app(sample_environ_malformed_cookie, lambda x, y: x)
app.run()
except TimeoutException:
pass

print(get_current_context())

assert get_current_context().method == "POST"
assert get_current_context().cookies == {"\u0000" * 10: ""}

calls = mock_request_handler.call_args_list
assert len(calls) == 3
assert calls[0][1]["stage"] == "init"
assert calls[1][1]["stage"] == "pre_response"
assert calls[2][1]["stage"] == "post_response"
assert calls[2][1]["status_code"] == 404


def test_flask_all_3_func_with_invalid_body():
"""When the flask body can not be parsed (because it contains invalid json for example), we should still parse the cookies of the endpoint"""
with patch(
Expand Down Expand Up @@ -138,7 +189,7 @@ def test_flask_all_3_func_with_invalid_body():
assert calls[1][1]["stage"] == "pre_response"
assert calls[2][1]["stage"] == "post_response"
assert calls[2][1]["status_code"] == 404


def test_flask_all_3_func():
with patch(
Expand Down

0 comments on commit dd46025

Please sign in to comment.