Skip to content

Commit 15fe2ed

Browse files
Fix add_middleware enum comparison (#1698)
Fixes #1697 Because of a wrong comparison against the position `Enum`, middleware was not actually being added to the stack via `add_middleware`. This PR fixes this, adds a warning when the middleware position cannot be found, and adds a test.
1 parent 128a8e0 commit 15fe2ed

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

connexion/middleware/main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import dataclasses
23
import enum
34
import logging
@@ -180,7 +181,9 @@ def __init__(
180181
self.app = app
181182
self.lifespan = lifespan
182183
self.middlewares = (
183-
middlewares if middlewares is not None else self.default_middlewares
184+
middlewares
185+
if middlewares is not None
186+
else copy.copy(self.default_middlewares)
184187
)
185188
self.middleware_stack: t.Optional[t.Iterable[ASGIApp]] = None
186189
self.apis: t.List[API] = []
@@ -223,11 +226,16 @@ def add_middleware(
223226
if isinstance(middleware, partial):
224227
middleware = middleware.func
225228

226-
if middleware == position:
229+
if middleware == position.value:
227230
self.middlewares.insert(
228231
m, t.cast(ASGIApp, partial(middleware_class, **options))
229232
)
230233
break
234+
else:
235+
raise ValueError(
236+
f"Could not insert middleware at position {position.name}. "
237+
f"Please make sure you have a {position.value} in your stack."
238+
)
231239

232240
def _build_middleware_stack(self) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
233241
"""Apply all middlewares to the provided app.

tests/api/test_errors.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
import json
2-
3-
import flask
4-
5-
61
def fix_data(data):
72
return data.replace(b'\\"', b'"')
83

tests/test_middleware.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import sys
2-
from unittest import mock
3-
41
import pytest
5-
from connexion.middleware import ConnexionMiddleware
2+
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition
3+
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
64
from starlette.datastructures import MutableHeaders
75

86
from conftest import build_app_from_fixture
@@ -49,3 +47,37 @@ def test_routing_middleware(middleware_app):
4947
assert (
5048
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
5149
), response.status_code
50+
51+
52+
def test_add_middleware(spec, app_class):
53+
"""Test adding middleware via the `add_middleware` method."""
54+
app = build_app_from_fixture("simple", app_class=app_class, spec_file=spec)
55+
app.add_middleware(TestMiddleware)
56+
57+
app_client = app.test_client()
58+
response = app_client.post("/v1.0/greeting/robbe")
59+
60+
assert (
61+
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
62+
), response.status_code
63+
64+
65+
def test_position(spec, app_class):
66+
"""Test adding middleware via the `add_middleware` method."""
67+
middlewares = [
68+
middleware
69+
for middleware in ConnexionMiddleware.default_middlewares
70+
if middleware != SwaggerUIMiddleware
71+
]
72+
app = build_app_from_fixture(
73+
"simple", app_class=app_class, spec_file=spec, middlewares=middlewares
74+
)
75+
76+
with pytest.raises(ValueError) as exc_info:
77+
app.add_middleware(TestMiddleware, position=MiddlewarePosition.BEFORE_SWAGGER)
78+
79+
assert (
80+
exc_info.value.args[0]
81+
== f"Could not insert middleware at position BEFORE_SWAGGER. "
82+
f"Please make sure you have a {SwaggerUIMiddleware} in your stack."
83+
)

0 commit comments

Comments
 (0)