Skip to content

Commit 432a9cc

Browse files
afshinZsailerpre-commit-ci[bot]
authored
Return HTTP 400 when attempting to post an event with an unregistered schema (#1463)
Co-authored-by: Zachary Sailer <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d3a6c60 commit 432a9cc

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

jupyter_server/services/events/handlers.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,25 @@ def on_close(self):
7171
self.event_logger.remove_listener(listener=self.event_listener)
7272

7373

74-
def validate_model(data: dict[str, Any]) -> None:
75-
"""Validates for required fields in the JSON request body"""
74+
def validate_model(
75+
data: dict[str, Any], registry: jupyter_events.schema_registry.SchemaRegistry
76+
) -> None:
77+
"""Validates for required fields in the JSON request body and verifies that
78+
a registered schema/version exists"""
7679
required_keys = {"schema_id", "version", "data"}
7780
for key in required_keys:
7881
if key not in data:
79-
raise web.HTTPError(400, f"Missing `{key}` in the JSON request body.")
82+
message = f"Missing `{key}` in the JSON request body."
83+
raise Exception(message)
84+
schema_id = cast(str, data.get("schema_id"))
85+
# The case where a given schema_id isn't found,
86+
# jupyter_events raises a useful error, so there's no need to
87+
# handle that case here.
88+
schema = registry.get(schema_id)
89+
version = int(cast(int, data.get("version")))
90+
if schema.version != version:
91+
message = f"Unregistered version: {version}{schema.version} for `{schema_id}`"
92+
raise Exception(message)
8093

8194

8295
def get_timestamp(data: dict[str, Any]) -> Optional[datetime]:
@@ -111,18 +124,18 @@ async def post(self):
111124
raise web.HTTPError(400, "No JSON data provided")
112125

113126
try:
114-
validate_model(payload)
127+
validate_model(payload, self.event_logger.schemas)
115128
self.event_logger.emit(
116129
schema_id=cast(str, payload.get("schema_id")),
117130
data=cast("Dict[str, Any]", payload.get("data")),
118131
timestamp_override=get_timestamp(payload),
119132
)
120133
self.set_status(204)
121134
self.finish()
122-
except web.HTTPError:
123-
raise
124135
except Exception as e:
125-
raise web.HTTPError(500, str(e)) from e
136+
# All known exceptions are raised by bad requests, e.g., bad
137+
# version, unregistered schema, invalid emission data payload, etc.
138+
raise web.HTTPError(400, str(e)) from e
126139

127140

128141
default_handlers = [

tests/services/events/test_api.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -117,16 +117,17 @@ async def test_post_event(jp_fetch, event_logger_sink, payload):
117117
}
118118
"""
119119

120-
121-
@pytest.mark.parametrize("payload", [payload_3, payload_4, payload_5, payload_6])
122-
async def test_post_event_400(jp_fetch, event_logger, payload):
123-
with pytest.raises(tornado.httpclient.HTTPClientError) as e:
124-
await jp_fetch("api", "events", method="POST", body=payload)
125-
126-
assert expected_http_error(e, 400)
127-
128-
129120
payload_7 = """\
121+
{
122+
"schema_id": "http://event.mock.jupyter.org/UNREGISTERED-SCHEMA",
123+
"version": 1,
124+
"data": {
125+
"event_message": "Hello, world!"
126+
}
127+
}
128+
"""
129+
130+
payload_8 = """\
130131
{
131132
"schema_id": "http://event.mock.jupyter.org/message",
132133
"version": 1,
@@ -136,20 +137,23 @@ async def test_post_event_400(jp_fetch, event_logger, payload):
136137
}
137138
"""
138139

139-
payload_8 = """\
140+
payload_9 = """\
140141
{
141142
"schema_id": "http://event.mock.jupyter.org/message",
142143
"version": 2,
143144
"data": {
144-
"message": "Hello, world!"
145+
"event_message": "Hello, world!"
145146
}
146147
}
147148
"""
148149

149150

150-
@pytest.mark.parametrize("payload", [payload_7, payload_8])
151-
async def test_post_event_500(jp_fetch, event_logger, payload):
151+
@pytest.mark.parametrize(
152+
"payload",
153+
[payload_3, payload_4, payload_5, payload_6, payload_7, payload_8, payload_9],
154+
)
155+
async def test_post_event_400(jp_fetch, event_logger, payload):
152156
with pytest.raises(tornado.httpclient.HTTPClientError) as e:
153157
await jp_fetch("api", "events", method="POST", body=payload)
154158

155-
assert expected_http_error(e, 500)
159+
assert expected_http_error(e, 400)

0 commit comments

Comments
 (0)