Skip to content

Commit 5e7f10a

Browse files
committedMar 17, 2025·
🐛 Custom tags for Middlewares (#185)
1 parent 72174bd commit 5e7f10a

File tree

3 files changed

+83
-7
lines changed

3 files changed

+83
-7
lines changed
 

‎flama/authentication/middleware.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818

1919

2020
class AuthenticationMiddleware:
21-
def __init__(self, app: "types.App", *, ignored: list[str] = []):
21+
def __init__(self, app: "types.App", *, tag: str = "permissions", ignored: list[str] = []):
2222
self.app: Flama = t.cast("Flama", app)
23+
self._tag = tag
2324
self._ignored = [re.compile(x) for x in ignored]
2425

2526
async def __call__(self, scope: "types.Scope", receive: "types.Receive", send: "types.Send") -> None:
@@ -34,7 +35,7 @@ async def __call__(self, scope: "types.Scope", receive: "types.Receive", send: "
3435
def _get_permissions(self, app: "Flama", scope: "types.Scope") -> set[str]:
3536
try:
3637
route, _ = app.router.resolve_route(scope)
37-
permissions = set(route.tags.get("permissions", []))
38+
permissions = set(route.tags.get(self._tag, []))
3839
except (exceptions.MethodNotAllowedException, exceptions.NotFoundException):
3940
permissions = []
4041

‎flama/telemetry/middleware.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import typing as t
55

6-
from flama import Flama, concurrency, types
6+
from flama import Flama, concurrency, exceptions, types
77
from flama.telemetry.data_structures import Error, Response, TelemetryData
88

99
logger = logging.getLogger(__name__)
@@ -126,12 +126,14 @@ def __init__(
126126
log_level: int = logging.NOTSET,
127127
before: t.Optional[HookFunction] = None,
128128
after: t.Optional[HookFunction] = None,
129+
tag: str = "telemetry",
129130
ignored: list[str] = [],
130131
) -> None:
131132
self.app: Flama = t.cast(Flama, app)
132133
self._log_level = log_level
133134
self._before = before
134135
self._after = after
136+
self._tag = tag
135137
self._ignored = [re.compile(x) for x in ignored]
136138

137139
async def before(self, data: TelemetryData):
@@ -142,8 +144,20 @@ async def after(self, data: TelemetryData):
142144
if self._after:
143145
await concurrency.run(self._after, data)
144146

147+
def _get_tag(self, scope: "types.Scope") -> bool:
148+
try:
149+
app: Flama = scope["app"]
150+
route, _ = app.router.resolve_route(scope)
151+
return route.tags.get(self._tag, True)
152+
except (exceptions.MethodNotAllowedException, exceptions.NotFoundException):
153+
return False
154+
145155
async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
146-
if scope["type"] not in ("http", "websocket") or any(pattern.match(scope["path"]) for pattern in self._ignored):
156+
if (
157+
scope["type"] not in ("http", "websocket")
158+
or any(pattern.match(scope["path"]) for pattern in self._ignored)
159+
or not self._get_tag(scope)
160+
):
147161
await self.app(scope, receive, send)
148162
return
149163

‎tests/telemetry/test_middleware.py

+64-3
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,17 @@ def error():
4747
def ignored():
4848
return "ignored"
4949

50+
@app.post("/explicit-off/", name="explicit-off", tags={"telemetry": False})
51+
def explicit_off():
52+
return "explicit_off"
53+
5054
@pytest.mark.parametrize(
5155
[
5256
"path",
5357
"request_params",
5458
"request_body",
5559
"request_cookies",
60+
"status_code",
5661
"response",
5762
"exception",
5863
"before",
@@ -65,6 +70,7 @@ def ignored():
6570
{"y": 1},
6671
b"body",
6772
{"access_token": TOKEN},
73+
http.HTTPStatus.OK,
6874
{"x": 1, "y": 1, "body": "body"},
6975
None,
7076
None,
@@ -77,6 +83,7 @@ def ignored():
7783
{"y": 1},
7884
b"body",
7985
{"access_token": TOKEN},
86+
http.HTTPStatus.OK,
8087
{"x": 1, "y": 1, "body": "body"},
8188
None,
8289
MagicMock(),
@@ -128,6 +135,7 @@ def ignored():
128135
{"y": 1},
129136
b"body",
130137
{"access_token": TOKEN},
138+
http.HTTPStatus.OK,
131139
{"x": 1, "y": 1, "body": "body"},
132140
None,
133141
AsyncMock(),
@@ -179,6 +187,7 @@ def ignored():
179187
{},
180188
None,
181189
{},
190+
http.HTTPStatus.OK,
182191
None,
183192
ValueError("foo"),
184193
None,
@@ -191,6 +200,7 @@ def ignored():
191200
{},
192201
None,
193202
{},
203+
http.HTTPStatus.OK,
194204
None,
195205
ValueError("foo"),
196206
MagicMock(),
@@ -222,6 +232,7 @@ def ignored():
222232
{},
223233
None,
224234
{},
235+
http.HTTPStatus.OK,
225236
None,
226237
ValueError("foo"),
227238
AsyncMock(),
@@ -248,12 +259,62 @@ def ignored():
248259
),
249260
id="error_async_hooks",
250261
),
251-
pytest.param("/ignored/", {}, None, {}, "ignored", None, AsyncMock(), AsyncMock(), None, id="ignored"),
262+
pytest.param(
263+
"/ignored/",
264+
{},
265+
None,
266+
{},
267+
http.HTTPStatus.OK,
268+
"ignored",
269+
None,
270+
AsyncMock(),
271+
AsyncMock(),
272+
None,
273+
id="ignored",
274+
),
275+
pytest.param(
276+
"/explicit-off/",
277+
{},
278+
None,
279+
{},
280+
http.HTTPStatus.OK,
281+
"explicit_off",
282+
None,
283+
AsyncMock(),
284+
AsyncMock(),
285+
None,
286+
id="explicit_off",
287+
),
288+
pytest.param(
289+
"/not-found/",
290+
{},
291+
None,
292+
{},
293+
http.HTTPStatus.NOT_FOUND,
294+
{"status_code": 404, "detail": "Not Found", "error": "HTTPException"},
295+
None,
296+
None,
297+
None,
298+
None,
299+
id="not_found",
300+
),
252301
],
253302
indirect=["exception"],
254303
)
255304
async def test_request(
256-
self, app, client, path, request_params, request_body, request_cookies, response, exception, before, after, data
305+
self,
306+
app,
307+
client,
308+
path,
309+
request_params,
310+
request_body,
311+
request_cookies,
312+
status_code,
313+
response,
314+
exception,
315+
before,
316+
after,
317+
data,
257318
):
258319
app.add_middleware(Middleware(TelemetryMiddleware, before=before, after=after, ignored=[r"/ignored.*"]))
259320

@@ -271,7 +332,7 @@ async def test_request(
271332
with exception, patch("datetime.datetime", MagicMock(now=MagicMock(return_value=now))):
272333
r = await client.post(path, params=request_params, content=request_body)
273334

274-
assert r.status_code == http.HTTPStatus.OK
335+
assert r.status_code == status_code
275336
assert r.json() == response
276337

277338
if before:

0 commit comments

Comments
 (0)
Please sign in to comment.