Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add response body in middlewares before and after request #297

Closed
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,25 @@ def return_int_status_code():
return {"status_code": 202, "body": "hello", "type": "text"}


@app.before_request("/")
async def hello_before_request(request):
@app.before_request("/post_with_body")
async def hello_before_request(request,response):
global callCount
callCount += 1
print("this is before request")
print(request)
print(response)
return ""


@app.after_request("/")
async def hello_after_request(request):
@app.after_request("/post_with_body")
async def hello_after_request(request, response):
global callCount
callCount += 1
print("this is after request")
print(request)
print(response)
return ""



@app.get("/test/:id")
Expand Down Expand Up @@ -122,6 +127,9 @@ async def post():

@app.post("/post_with_body")
async def postreq_with_body(request):
print("This is the main function")
print(request)

return bytearray(request["body"]).decode("utf-8")


Expand Down
20 changes: 20 additions & 0 deletions integration_tests/test_middlware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import requests



BASE_URL = "http://127.0.0.1:5000"



def test_post_with_middleware(session):

res = requests.post(f"{BASE_URL}/post_with_body", data = {
"hello": "world"
})



assert res.text=="hello=world"
assert (res.status_code == 200)


20 changes: 13 additions & 7 deletions robyn/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def __init__(self) -> None:
super().__init__()
self.routes = []

def add_route(self, route_type: str, endpoint: str, handler: Callable) -> Callable:
number_of_params = len(signature(handler).parameters)

def add_route(self, route_type: str, endpoint: str, handler: Callable, number_of_params=0) -> None:
self.routes.append(
(
route_type,
Expand All @@ -100,8 +100,9 @@ def add_route(self, route_type: str, endpoint: str, handler: Callable) -> Callab
# and returns the arguments.
# Arguments are returned as they could be modified by the middlewares.
def add_after_request(self, endpoint: str) -> Callable[..., None]:

def inner(handler):
@wraps(handler)

async def async_inner_handler(*args):
await handler(*args)
return args
Expand All @@ -111,10 +112,13 @@ def inner_handler(*args):
handler(*args)
return args

number_of_params = len(signature(handler).parameters)


if iscoroutinefunction(handler):
self.add_route("AFTER_REQUEST", endpoint, async_inner_handler)
self.add_route("AFTER_REQUEST", endpoint, async_inner_handler, number_of_params)
else:
self.add_route("AFTER_REQUEST", endpoint, inner_handler)
self.add_route("AFTER_REQUEST", endpoint, inner_handler, number_of_params)

return inner

Expand All @@ -130,10 +134,12 @@ def inner_handler(*args):
handler(*args)
return args

number_of_params = len(signature(handler).parameters)

if iscoroutinefunction(handler):
self.add_route("BEFORE_REQUEST", endpoint, async_inner_handler)
self.add_route("BEFORE_REQUEST", endpoint, async_inner_handler, number_of_params)
else:
self.add_route("BEFORE_REQUEST", endpoint, inner_handler)
self.add_route("BEFORE_REQUEST", endpoint, inner_handler, number_of_params)

return inner

Expand Down
84 changes: 35 additions & 49 deletions src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,47 @@ use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;

use actix_web::{http::Method, web, HttpRequest};
use anyhow::{bail, Result};
use actix_web::HttpResponse;
use anyhow::Result;
use log::debug;

use pyo3_asyncio::TaskLocals;
// pyO3 module
use crate::types::PyFunction;
use futures_util::stream::StreamExt;

use pyo3::prelude::*;
use pyo3::types::PyDict;

/// @TODO make configurable
const MAX_SIZE: usize = 10_000;

pub async fn execute_middleware_function<'a>(
function: PyFunction,
payload: &mut web::Payload,
payload: &mut [u8],
headers: &HashMap<String, String>,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: Rc<RefCell<HashMap<String, String>>>,
number_of_params: u8,
res: Option<&HttpResponse>,
) -> Result<HashMap<String, HashMap<String, String>>> {
// TODO:
// add body in middlewares too

let mut data: Vec<u8> = Vec::new();

if req.method() == Method::POST
|| req.method() == Method::PUT
|| req.method() == Method::PATCH
|| req.method() == Method::DELETE
{
let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
bail!("Body content Overflow");
}
body.extend_from_slice(&chunk);
}

data = body.to_vec()
let data = payload.to_owned();
let temp_response = &HttpResponse::Ok().finish();

// make response object accessible while creating routes
let response = match res {
Some(res) => res,
// do nothing if none
None => temp_response,
};
let mut response_headers = HashMap::new();
for (key, val) in response.headers() {
response_headers.insert(key.to_string(), val.to_str().unwrap().to_string());
}
let mut response_dict: HashMap<&str, Py<PyAny>> = HashMap::new();
let response_status_code = response.status().as_u16();
let response_body = data.clone();

// request object accessible while creating routes
let mut request = HashMap::new();
Expand All @@ -69,14 +66,18 @@ pub async fn execute_middleware_function<'a>(
request.insert("queries", queries_clone.into_py(py));
// is this a bottleneck again?
request.insert("headers", headers.clone().into_py(py));
// request.insert("body", data.into_py(py));
request.insert("body", data.into_py(py));
response_dict.insert("headers", response_headers.into_py(py));
response_dict.insert("status", response_status_code.into_py(py));
response_dict.insert("body", response_body.into_py(py));

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((request,)),
1 => handler.call1((response_dict,)),
2 => handler.call1((request, response_dict)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((request,)),
3_u8..=u8::MAX => handler.call1((request, response_dict)),
};
pyo3_asyncio::tokio::into_future(coro?)
})?;
Expand Down Expand Up @@ -105,12 +106,16 @@ pub async fn execute_middleware_function<'a>(
// is this a bottleneck again?
request.insert("headers", headers.clone().into_py(py));
request.insert("body", data.into_py(py));
response_dict.insert("headers", response_headers.into_py(py));
response_dict.insert("status", response_status_code.into_py(py));
response_dict.insert("body", response_body.into_py(py));

let output: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((request,)),
2 => handler.call1((request, response_dict)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((request,)),
3_u8..=u8::MAX => handler.call1((request, response_dict)),
Comment on lines +110 to +119
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything looks great. Just one last nit:

request and response_dict do not match. Can you rename them to request and response or req and res or request_dict and response_dict?

};

let output: Vec<HashMap<String, HashMap<String, String>>> = output?.extract()?;
Expand Down Expand Up @@ -184,34 +189,15 @@ pub async fn execute_function(
#[inline]
pub async fn execute_http_function(
function: PyFunction,
payload: &mut web::Payload,
payload: &mut [u8],
headers: HashMap<String, String>,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: Rc<RefCell<HashMap<String, String>>>,
number_of_params: u8,
// need to change this to return a response struct
// create a custom struct for this
) -> Result<HashMap<String, String>> {
let mut data: Vec<u8> = Vec::new();

if req.method() == Method::POST
|| req.method() == Method::PUT
|| req.method() == Method::PATCH
|| req.method() == Method::DELETE
{
let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
bail!("Body content Overflow");
}
body.extend_from_slice(&chunk);
}

data = body.to_vec()
}
let data: Vec<u8> = payload.to_owned();

// request object accessible while creating routes
let mut request = HashMap::new();
Expand Down
12 changes: 5 additions & 7 deletions src/request_handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::rc::Rc;
use std::str::FromStr;
use std::{cell::RefCell, collections::HashMap};

use actix_web::{web, HttpRequest, HttpResponse, HttpResponseBuilder};
use actix_web::{HttpResponse, HttpResponseBuilder};
// pyO3 module
use crate::types::PyFunction;

Expand All @@ -31,16 +31,14 @@ pub async fn handle_http_request(
function: PyFunction,
number_of_params: u8,
headers: HashMap<String, String>,
payload: &mut web::Payload,
req: &HttpRequest,
payload: &mut [u8],
route_params: HashMap<String, String>,
queries: Rc<RefCell<HashMap<String, String>>>,
) -> HttpResponse {
let contents = match execute_http_function(
function,
payload,
headers.clone(),
req,
route_params,
queries,
number_of_params,
Expand Down Expand Up @@ -93,19 +91,19 @@ pub async fn handle_http_middleware_request(
function: PyFunction,
number_of_params: u8,
headers: &HashMap<String, String>,
payload: &mut web::Payload,
req: &HttpRequest,
payload: &mut [u8],
route_params: HashMap<String, String>,
queries: Rc<RefCell<HashMap<String, String>>>,
res: Option<&HttpResponse>,
) -> HashMap<String, HashMap<String, String>> {
let contents = match execute_middleware_function(
function,
payload,
headers,
req,
route_params,
queries,
number_of_params,
res,
)
.await
{
Expand Down
Loading