diff --git a/src/server/_common.py b/src/server/_common.py index 8633d07fd..33a3f9c48 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -68,6 +68,7 @@ def log_info_with_request(message, **kwargs): remote_addr=request.remote_addr, real_remote_addr=get_real_ip_addr(request), user_agent=request.user_agent.string, + referrer=request.referrer or request.origin, api_key=resolve_auth_token(), user_id=(current_user and current_user.id), **kwargs @@ -114,19 +115,7 @@ def before_request_execute(): user = current_user api_key = resolve_auth_token() - # TODO: replace this next call with: log_info_with_request("Received API request") - get_structured_logger("server_api").info( - "Received API request", - method=request.method, - url=request.url, - form_args=request.form, - req_length=request.content_length, - remote_addr=request.remote_addr, - real_remote_addr=get_real_ip_addr(request), - user_agent=request.user_agent.string, - api_key=api_key, - user_id=(user and user.id) - ) + log_info_with_request("Received API request") if not _is_public_route() and api_key and not user: # if this is a privleged endpoint, and an api key was given but it does not look up to a user, raise exception: @@ -150,28 +139,10 @@ def after_request_execute(response): # Convert to milliseconds total_time *= 1000 - api_key = resolve_auth_token() - update_key_last_time_used(current_user) - # TODO: replace this next call with: log_info_with_request_and_response("Served API request", response, elapsed_time_ms=total_time) - get_structured_logger("server_api").info( - "Served API request", - method=request.method, - url=request.url, - form_args=request.form, - req_length=request.content_length, - remote_addr=request.remote_addr, - real_remote_addr=get_real_ip_addr(request), - user_agent=request.user_agent.string, - api_key=api_key, - values=request.values.to_dict(flat=False), - blueprint=request.blueprint, - endpoint=request.endpoint, - response_status=response.status, - content_length=response.calculate_content_length(), - elapsed_time_ms=total_time, - ) + log_info_with_request_and_response("Served API request", response, elapsed_time_ms=total_time) + return response diff --git a/tests/server/test_validate.py b/tests/server/test_validate.py index f06e9e997..eff7e9c9e 100644 --- a/tests/server/test_validate.py +++ b/tests/server/test_validate.py @@ -26,6 +26,7 @@ def setUp(self): app.config["TESTING"] = True app.config["WTF_CSRF_ENABLED"] = False app.config["DEBUG"] = False + self.client = app.test_client() def test_require_all(self): with self.subTest("all given"): @@ -60,3 +61,39 @@ def test_require_any(self): with self.subTest("one options given with is empty but ok"): with app.test_request_context("/?abc="): self.assertTrue(require_any(request, "abc", empty=True)) + + def test_origin_headers(self): + with self.subTest("referer only"): + with self.assertLogs("server_api", level='INFO') as logs: + self.client.get("/signal_dashboard_status", headers={ + "Referer": "https://test.com/test" + }) + output = logs.output + self.assertEqual(len(output), 2) # [before_request, after_request] + self.assertIn("Received API request", output[0]) + self.assertIn("\"referrer\": \"https://test.com/test\"", output[0]) + self.assertIn("Served API request", output[1]) + self.assertIn("\"referrer\": \"https://test.com/test\"", output[1]) + with self.subTest("origin only"): + with self.assertLogs("server_api", level='INFO') as logs: + self.client.get("/signal_dashboard_status", headers={ + "Origin": "https://test.com" + }) + output = logs.output + self.assertEqual(len(output), 2) # [before_request, after_request] + self.assertIn("Received API request", output[0]) + self.assertIn("\"referrer\": \"https://test.com\"", output[0]) + self.assertIn("Served API request", output[1]) + self.assertIn("\"referrer\": \"https://test.com\"", output[1]) + with self.subTest("referer overrides origin"): + with self.assertLogs("server_api", level='INFO') as logs: + self.client.get("/signal_dashboard_status", headers={ + "Referer": "https://test.com/test", + "Origin": "https://test.com" + }) + output = logs.output + self.assertEqual(len(output), 2) # [before_request, after_request] + self.assertIn("Received API request", output[0]) + self.assertIn("\"referrer\": \"https://test.com/test\"", output[0]) + self.assertIn("Served API request", output[1]) + self.assertIn("\"referrer\": \"https://test.com/test\"", output[1])