From f08f8900a202c6d15b2ca895ac9ace2a0af5afe2 Mon Sep 17 00:00:00 2001 From: Antoine Romero-Romero <ant.romero2@orange.fr> Date: Thu, 10 Nov 2022 23:34:36 +0000 Subject: [PATCH 1/4] fix(routers): add a Router trait along with some minor improvements --- src/routers/const_router.rs | 132 ++++++++++++---------------- src/routers/middleware_router.rs | 86 +++++++++--------- src/routers/mod.rs | 23 +++++ src/routers/router.rs | 146 ++++++++++++------------------- src/routers/types.rs | 2 +- src/server.rs | 67 +++++++------- 6 files changed, 210 insertions(+), 246 deletions(-) diff --git a/src/routers/const_router.rs b/src/routers/const_router.rs index f8f65bb7e..0a43a1f54 100644 --- a/src/routers/const_router.rs +++ b/src/routers/const_router.rs @@ -1,120 +1,71 @@ +use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use std::sync::RwLock; // pyo3 modules use crate::executors::execute_function; +use anyhow::Context; use log::debug; use pyo3::prelude::*; use pyo3::types::PyAny; use actix_web::http::Method; -use matchit::Router; -use anyhow::{bail, Error, Result}; +use anyhow::{Error, Result}; -/// Contains the thread safe hashmaps of different routes +use super::RouteType; +use super::Router; + +type RouteMap = RwLock<matchit::Router<String>>; +/// Contains the thread safe hashmaps of different routes pub struct ConstRouter { - get_routes: Arc<RwLock<Router<String>>>, - post_routes: Arc<RwLock<Router<String>>>, - put_routes: Arc<RwLock<Router<String>>>, - delete_routes: Arc<RwLock<Router<String>>>, - patch_routes: Arc<RwLock<Router<String>>>, - head_routes: Arc<RwLock<Router<String>>>, - options_routes: Arc<RwLock<Router<String>>>, - connect_routes: Arc<RwLock<Router<String>>>, - trace_routes: Arc<RwLock<Router<String>>>, + routes: HashMap<Method, Arc<RouteMap>>, } -impl ConstRouter { - pub fn new() -> Self { - Self { - get_routes: Arc::new(RwLock::new(Router::new())), - post_routes: Arc::new(RwLock::new(Router::new())), - put_routes: Arc::new(RwLock::new(Router::new())), - delete_routes: Arc::new(RwLock::new(Router::new())), - patch_routes: Arc::new(RwLock::new(Router::new())), - head_routes: Arc::new(RwLock::new(Router::new())), - options_routes: Arc::new(RwLock::new(Router::new())), - connect_routes: Arc::new(RwLock::new(Router::new())), - trace_routes: Arc::new(RwLock::new(Router::new())), - } - } - - #[inline] - fn get_relevant_map(&self, route: Method) -> Option<Arc<RwLock<Router<String>>>> { - match route { - Method::GET => Some(self.get_routes.clone()), - Method::POST => Some(self.post_routes.clone()), - Method::PUT => Some(self.put_routes.clone()), - Method::PATCH => Some(self.patch_routes.clone()), - Method::DELETE => Some(self.delete_routes.clone()), - Method::HEAD => Some(self.head_routes.clone()), - Method::OPTIONS => Some(self.options_routes.clone()), - Method::CONNECT => Some(self.connect_routes.clone()), - Method::TRACE => Some(self.trace_routes.clone()), - _ => None, - } - } - - #[inline] - fn get_relevant_map_str(&self, route: &str) -> Option<Arc<RwLock<Router<String>>>> { - if route != "WS" { - let method = match Method::from_bytes(route.as_bytes()) { - Ok(res) => res, - Err(_) => return None, - }; - - self.get_relevant_map(method) - } else { - None - } - } - - /// Checks if the functions is an async function - /// Inserts them in the router according to their nature(CoRoutine/SyncFunction) +impl Router<String, Method> for ConstRouter { /// Doesn't allow query params/body/etc as variables cannot be "memoized"/"const"ified - pub fn add_route( + fn add_route( &self, route_type: &str, // we can just have route type as WS route: &str, function: Py<PyAny>, is_async: bool, number_of_params: u8, - event_loop: &PyAny, + event_loop: Option<&PyAny>, ) -> Result<(), Error> { - let table = match self.get_relevant_map_str(route_type) { - Some(table) => table, - None => bail!("No relevant map"), - }; + let table = self + .get_relevant_map_str(route_type) + .context("No relevant map")? + .clone(); + let route = route.to_string(); + let event_loop = + event_loop.context("Event loop must be provided to add a route to the const router")?; + pyo3_asyncio::tokio::run_until_complete(event_loop, async move { let output = execute_function(function, number_of_params, is_async) .await .unwrap(); debug!("This is the result of the output {:?}", output); table - .clone() .write() .unwrap() .insert(route, output.get("body").unwrap().to_string()) .unwrap(); - Ok(()) - }) - .unwrap(); + })?; Ok(()) } - // Checks if the functions is an async function - // Inserts them in the router according to their nature(CoRoutine/SyncFunction) - pub fn get_route( + fn get_route( &self, - route_method: Method, + route_method: RouteType<Method>, route: &str, // check for the route method here ) -> Option<String> { // need to split this function in multiple smaller functions - let table = self.get_relevant_map(route_method)?; + let table = self.routes.get(&route_method.0)?; let route_map = table.read().ok()?; match route_map.at(route) { @@ -123,3 +74,36 @@ impl ConstRouter { } } } + +impl ConstRouter { + pub fn new() -> Self { + let mut routes = HashMap::new(); + routes.insert(Method::GET, Arc::new(RwLock::new(matchit::Router::new()))); + routes.insert(Method::POST, Arc::new(RwLock::new(matchit::Router::new()))); + routes.insert(Method::PUT, Arc::new(RwLock::new(matchit::Router::new()))); + routes.insert( + Method::DELETE, + Arc::new(RwLock::new(matchit::Router::new())), + ); + routes.insert(Method::PATCH, Arc::new(RwLock::new(matchit::Router::new()))); + routes.insert(Method::HEAD, Arc::new(RwLock::new(matchit::Router::new()))); + routes.insert( + Method::OPTIONS, + Arc::new(RwLock::new(matchit::Router::new())), + ); + routes.insert( + Method::CONNECT, + Arc::new(RwLock::new(matchit::Router::new())), + ); + routes.insert(Method::TRACE, Arc::new(RwLock::new(matchit::Router::new()))); + Self { routes } + } + + #[inline] + fn get_relevant_map_str(&self, route: &str) -> Option<&Arc<RouteMap>> { + match route { + "WS" => None, + _ => self.routes.get(&Method::from_str(route).ok()?), + } + } +} diff --git a/src/routers/middleware_router.rs b/src/routers/middleware_router.rs index 4064baef1..5e9e1a1cc 100644 --- a/src/routers/middleware_router.rs +++ b/src/routers/middleware_router.rs @@ -5,57 +5,54 @@ use crate::types::PyFunction; use pyo3::prelude::*; use pyo3::types::PyAny; -use matchit::Router; - -use anyhow::{bail, Error, Result}; +use anyhow::{Context, Error, Result}; use crate::routers::types::MiddlewareRoute; -/// Contains the thread safe hashmaps of different routes +use super::{RouteType, Router}; +type RouteMap = RwLock<matchit::Router<(PyFunction, u8)>>; + +/// Contains the thread safe hashmaps of different routes pub struct MiddlewareRouter { - before_request: RwLock<Router<(PyFunction, u8)>>, - after_request: RwLock<Router<(PyFunction, u8)>>, + routes: HashMap<MiddlewareRoute, RouteMap>, } impl MiddlewareRouter { pub fn new() -> Self { - Self { - before_request: RwLock::new(Router::new()), - after_request: RwLock::new(Router::new()), - } - } - - #[inline] - fn get_relevant_map( - &self, - route: MiddlewareRoute, - ) -> Option<&RwLock<Router<(PyFunction, u8)>>> { - match route { - MiddlewareRoute::BeforeRequest => Some(&self.before_request), - MiddlewareRoute::AfterRequest => Some(&self.after_request), - } + let mut routes = HashMap::new(); + routes.insert( + MiddlewareRoute::BeforeRequest, + RwLock::new(matchit::Router::new()), + ); + routes.insert( + MiddlewareRoute::AfterRequest, + RwLock::new(matchit::Router::new()), + ); + Self { routes } } +} +impl Router<((PyFunction, u8), HashMap<String, String>), MiddlewareRoute> for MiddlewareRouter { // Checks if the functions is an async function // Inserts them in the router according to their nature(CoRoutine/SyncFunction) - pub fn add_route( + fn add_route( &self, - route_type: MiddlewareRoute, + route_type: &str, route: &str, handler: Py<PyAny>, is_async: bool, number_of_params: u8, + _event_loop: Option<&PyAny>, ) -> Result<(), Error> { - let table = match self.get_relevant_map(route_type) { - Some(table) => table, - None => bail!("No relevant map"), - }; + let table = self + .routes + .get(&MiddlewareRoute::from_str(route_type)) + .context("No relevant map")?; - let function = if is_async { - PyFunction::CoRoutine(handler) - } else { - PyFunction::SyncFunction(handler) + let function = match is_async { + true => PyFunction::CoRoutine(handler), + false => PyFunction::SyncFunction(handler), }; table @@ -66,25 +63,20 @@ impl MiddlewareRouter { Ok(()) } - pub fn get_route( + fn get_route( &self, - route_method: MiddlewareRoute, - route: &str, // check for the route method here + route_method: RouteType<MiddlewareRoute>, + route: &str, ) -> Option<((PyFunction, u8), HashMap<String, String>)> { - // need to split this function in multiple smaller functions - let table = self.get_relevant_map(route_method)?; + let table = self.routes.get(&route_method.0)?; - match table.read().unwrap().at(route) { - Ok(res) => { - let mut route_params = HashMap::new(); - - for (key, value) in res.params.iter() { - route_params.insert(key.to_string(), value.to_string()); - } - - Some((res.value.clone(), route_params)) - } - Err(_) => None, + let table_lock = table.read().ok()?; + let res = table_lock.at(route).ok()?; + let mut route_params = HashMap::new(); + for (key, value) in res.params.iter() { + route_params.insert(key.to_string(), value.to_string()); } + + Some((res.value.to_owned(), route_params)) } } diff --git a/src/routers/mod.rs b/src/routers/mod.rs index 77dd7c416..0aa1fe52d 100644 --- a/src/routers/mod.rs +++ b/src/routers/mod.rs @@ -1,5 +1,28 @@ +use anyhow::Result; +use pyo3::{Py, PyAny}; + pub mod const_router; pub mod middleware_router; pub mod router; pub mod types; pub mod web_socket_router; + +pub struct RouteType<T>(pub T); + +pub trait Router<T, U> { + /// Checks if the functions is an async function + /// Inserts them in the router according to their nature(CoRoutine/SyncFunction) + fn add_route( + &self, + route_type: &str, + route: &str, + handler: Py<PyAny>, + is_async: bool, + number_of_params: u8, + event_loop: Option<&PyAny>, + ) -> Result<()>; + + /// Checks if the functions is an async function + /// Inserts them in the router according to their nature(CoRoutine/SyncFunction) + fn get_route(&self, route_method: RouteType<U>, route: &str) -> Option<T>; +} diff --git a/src/routers/router.rs b/src/routers/router.rs index dcf9cb904..f4295ab5d 100644 --- a/src/routers/router.rs +++ b/src/routers/router.rs @@ -1,5 +1,5 @@ -use std::collections::HashMap; use std::sync::RwLock; +use std::{collections::HashMap, str::FromStr}; // pyo3 modules use crate::types::PyFunction; use pyo3::prelude::*; @@ -8,89 +8,34 @@ use pyo3::types::PyAny; use actix_web::http::Method; use matchit::Router as MatchItRouter; -use anyhow::{bail, Error, Result}; +use anyhow::{Context, Result}; -/// Contains the thread safe hashmaps of different routes - -pub struct Router { - get_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - post_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - put_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - delete_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - patch_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - head_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - options_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - connect_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, - trace_routes: RwLock<MatchItRouter<(PyFunction, u8)>>, -} - -impl Router { - pub fn new() -> Self { - Self { - get_routes: RwLock::new(MatchItRouter::new()), - post_routes: RwLock::new(MatchItRouter::new()), - put_routes: RwLock::new(MatchItRouter::new()), - delete_routes: RwLock::new(MatchItRouter::new()), - patch_routes: RwLock::new(MatchItRouter::new()), - head_routes: RwLock::new(MatchItRouter::new()), - options_routes: RwLock::new(MatchItRouter::new()), - connect_routes: RwLock::new(MatchItRouter::new()), - trace_routes: RwLock::new(MatchItRouter::new()), - } - } - - #[inline] - fn get_relevant_map(&self, route: Method) -> Option<&RwLock<MatchItRouter<(PyFunction, u8)>>> { - match route { - Method::GET => Some(&self.get_routes), - Method::POST => Some(&self.post_routes), - Method::PUT => Some(&self.put_routes), - Method::PATCH => Some(&self.patch_routes), - Method::DELETE => Some(&self.delete_routes), - Method::HEAD => Some(&self.head_routes), - Method::OPTIONS => Some(&self.options_routes), - Method::CONNECT => Some(&self.connect_routes), - Method::TRACE => Some(&self.trace_routes), - _ => None, - } - } +use super::{RouteType, Router}; - #[inline] - fn get_relevant_map_str( - &self, - route: &str, - ) -> Option<&RwLock<MatchItRouter<(PyFunction, u8)>>> { - if route != "WS" { - let method = match Method::from_bytes(route.as_bytes()) { - Ok(res) => res, - Err(_) => return None, - }; +type RouteMap = RwLock<MatchItRouter<(PyFunction, u8)>>; - self.get_relevant_map(method) - } else { - None - } - } +/// Contains the thread safe hashmaps of different routes +pub struct DynRouter { + routes: HashMap<Method, RouteMap>, +} - // Checks if the functions is an async function - // Inserts them in the router according to their nature(CoRoutine/SyncFunction) - pub fn add_route( +impl Router<((PyFunction, u8), HashMap<String, String>), Method> for DynRouter { + fn add_route( &self, - route_type: &str, // we can just have route type as WS + route_type: &str, // We can just have route type as WS route: &str, handler: Py<PyAny>, is_async: bool, number_of_params: u8, - ) -> Result<(), Error> { - let table = match self.get_relevant_map_str(route_type) { - Some(table) => table, - None => bail!("No relevant map"), - }; - - let function = if is_async { - PyFunction::CoRoutine(handler) - } else { - PyFunction::SyncFunction(handler) + _event_loop: Option<&PyAny>, + ) -> Result<()> { + let table = self + .get_relevant_map_str(route_type) + .context("No relevant map")?; + + let function = match is_async { + true => PyFunction::CoRoutine(handler), + false => PyFunction::SyncFunction(handler), }; // try removing unwrap here @@ -102,27 +47,48 @@ impl Router { Ok(()) } - // Checks if the functions is an async function - // Inserts them in the router according to their nature(CoRoutine/SyncFunction) - pub fn get_route( + fn get_route( &self, - route_method: Method, - route: &str, // check for the route method here + route_method: RouteType<Method>, + route: &str, ) -> Option<((PyFunction, u8), HashMap<String, String>)> { // need to split this function in multiple smaller functions - let table = self.get_relevant_map(route_method)?; + let table = self.routes.get(&route_method.0)?; + + let table_lock = table.read().ok()?; + let res = table_lock.at(route).ok()?; + let mut route_params = HashMap::new(); + for (key, value) in res.params.iter() { + route_params.insert(key.to_string(), value.to_string()); + } - match table.read().unwrap().at(route) { - Ok(res) => { - let mut route_params = HashMap::new(); + Some((res.value.to_owned(), route_params)) + } +} - for (key, value) in res.params.iter() { - route_params.insert(key.to_string(), value.to_string()); - } +impl DynRouter { + pub fn new() -> Self { + let mut routes = HashMap::new(); + routes.insert(Method::GET, RwLock::new(MatchItRouter::new())); + routes.insert(Method::POST, RwLock::new(MatchItRouter::new())); + routes.insert(Method::PUT, RwLock::new(MatchItRouter::new())); + routes.insert(Method::DELETE, RwLock::new(MatchItRouter::new())); + routes.insert(Method::PATCH, RwLock::new(MatchItRouter::new())); + routes.insert(Method::HEAD, RwLock::new(MatchItRouter::new())); + routes.insert(Method::OPTIONS, RwLock::new(MatchItRouter::new())); + routes.insert(Method::CONNECT, RwLock::new(MatchItRouter::new())); + routes.insert(Method::TRACE, RwLock::new(MatchItRouter::new())); + Self { routes } + } - Some((res.value.clone(), route_params)) - } - Err(_) => None, + #[inline] + fn get_relevant_map_str( + &self, + route: &str, + ) -> Option<&RwLock<MatchItRouter<(PyFunction, u8)>>> { + match route { + "WS" => None, + _ => self.routes.get(&Method::from_str(route).ok()?), } } } diff --git a/src/routers/types.rs b/src/routers/types.rs index 6d1da332e..d8017e916 100644 --- a/src/routers/types.rs +++ b/src/routers/types.rs @@ -1,4 +1,4 @@ -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub enum MiddlewareRoute { BeforeRequest, AfterRequest, diff --git a/src/server.rs b/src/server.rs index f924b79ed..23f02817d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,8 +3,9 @@ use crate::io_helpers::apply_headers; use crate::request_handler::{handle_http_middleware_request, handle_http_request}; use crate::routers::const_router::ConstRouter; +use crate::routers::{RouteType, Router}; -use crate::routers::router::Router; +use crate::routers::router::DynRouter; use crate::routers::types::MiddlewareRoute; use crate::routers::{middleware_router::MiddlewareRouter, web_socket_router::WebSocketRouter}; use crate::shared_socket::SocketHeld; @@ -43,7 +44,7 @@ struct Directory { #[pyclass] pub struct Server { - router: Arc<Router>, + router: Arc<DynRouter>, const_router: Arc<ConstRouter>, websocket_router: Arc<WebSocketRouter>, middleware_router: Arc<MiddlewareRouter>, @@ -58,7 +59,7 @@ impl Server { #[new] pub fn new() -> Self { Self { - router: Arc::new(Router::new()), + router: Arc::new(DynRouter::new()), const_router: Arc::new(ConstRouter::new()), websocket_router: Arc::new(WebSocketRouter::new()), middleware_router: Arc::new(MiddlewareRouter::new()), @@ -160,7 +161,7 @@ impl Server { app = app.route( &route.clone(), web::get().to( - move |_router: web::Data<Arc<Router>>, + move |_router: web::Data<Arc<DynRouter>>, _global_headers: web::Data<Arc<Headers>>, stream: web::Payload, req: HttpRequest| { @@ -176,7 +177,7 @@ impl Server { } app.default_service(web::route().to( - move |router, + move |router: web::Data<Arc<DynRouter>>, const_router: web::Data<Arc<ConstRouter>>, middleware_router: web::Data<Arc<MiddlewareRouter>>, global_headers, @@ -277,12 +278,12 @@ impl Server { handler, is_async, number_of_params, - event_loop, + Some(event_loop), ) .unwrap(); } else { self.router - .add_route(route_type, route, handler, is_async, number_of_params) + .add_route(route_type, route, handler, is_async, number_of_params, None) .unwrap(); } } @@ -298,11 +299,8 @@ impl Server { number_of_params: u8, ) { debug!("MiddleWare Route added for {} {} ", route_type, route); - - let route_type = MiddlewareRoute::from_str(route_type); - self.middleware_router - .add_route(route_type, route, handler, is_async, number_of_params) + .add_route(route_type, route, handler, is_async, number_of_params, None) .unwrap(); } @@ -372,7 +370,7 @@ async fn merge_headers( /// This is our service handler. It receives a Request, routes on it /// path, and returns a Future of a Response. async fn index( - router: web::Data<Arc<Router>>, + router: web::Data<Arc<DynRouter>>, const_router: web::Data<Arc<ConstRouter>>, middleware_router: web::Data<Arc<MiddlewareRouter>>, global_headers: web::Data<Arc<Headers>>, @@ -394,24 +392,25 @@ async fn index( let headers = merge_headers(&global_headers, req.headers()).await; // need a better name for this - let tuple_params = - match middleware_router.get_route(MiddlewareRoute::BeforeRequest, req.uri().path()) { - Some(((handler_function, number_of_params), route_params)) => { - let x = handle_http_middleware_request( - handler_function, - number_of_params, - &headers, - &mut payload, - &req, - route_params, - queries.clone(), - ) - .await; - debug!("Middleware contents {:?}", x); - x - } - None => HashMap::new(), - }; + let tuple_params = match middleware_router + .get_route(RouteType(MiddlewareRoute::BeforeRequest), req.uri().path()) + { + Some(((handler_function, number_of_params), route_params)) => { + let x = handle_http_middleware_request( + handler_function, + number_of_params, + &headers, + &mut payload, + &req, + route_params, + queries.clone(), + ) + .await; + debug!("Middleware contents {:?}", x); + x + } + None => HashMap::new(), + }; debug!("These are the tuple params {:?}", tuple_params); @@ -424,18 +423,18 @@ async fn index( debug!("These are the request headers {:?}", headers_dup); let response = if const_router - .get_route(req.method().clone(), req.uri().path()) + .get_route(RouteType(req.method().clone()), req.uri().path()) .is_some() { let mut response = HttpResponse::Ok(); apply_headers(&mut response, headers_dup.clone()); response.body( const_router - .get_route(req.method().clone(), req.uri().path()) + .get_route(RouteType(req.method().clone()), req.uri().path()) .unwrap(), ) } else { - match router.get_route(req.method().clone(), req.uri().path()) { + match router.get_route(RouteType(req.method().clone()), req.uri().path()) { Some(((handler_function, number_of_params), route_params)) => { handle_http_request( handler_function, @@ -457,7 +456,7 @@ async fn index( }; if let Some(((handler_function, number_of_params), route_params)) = - middleware_router.get_route(MiddlewareRoute::AfterRequest, req.uri().path()) + middleware_router.get_route(RouteType(MiddlewareRoute::AfterRequest), req.uri().path()) { let x = handle_http_middleware_request( handler_function, From bc398f36db9114ae60b7dcadb99bf9c5274199df Mon Sep 17 00:00:00 2001 From: Antoine Romero-Romero <ant.romero2@orange.fr> Date: Thu, 10 Nov 2022 23:43:48 +0000 Subject: [PATCH 2/4] fix: clippy warnings --- src/server.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index 23f02817d..437e203cb 100644 --- a/src/server.rs +++ b/src/server.rs @@ -86,10 +86,7 @@ impl Server { return Ok(()); } - let borrow = socket.try_borrow_mut()?; - let held_socket: &SocketHeld = &*borrow; - - let raw_socket = held_socket.get_socket(); + let raw_socket = socket.try_borrow_mut()?.get_socket(); let router = self.router.clone(); let const_router = self.const_router.clone(); From 456f8722c323e6a6a5ca61bdacc719acd27acc42 Mon Sep 17 00:00:00 2001 From: Antoine Romero-Romero <ant.romero2@orange.fr> Date: Fri, 11 Nov 2022 21:48:13 +0000 Subject: [PATCH 3/4] fix: fixes after review --- src/routers/const_router.rs | 10 ++---- src/routers/middleware_router.rs | 8 ++--- src/routers/mod.rs | 7 ++--- src/routers/router.rs | 7 ++--- src/server.rs | 54 +++++++++++++++----------------- 5 files changed, 36 insertions(+), 50 deletions(-) diff --git a/src/routers/const_router.rs b/src/routers/const_router.rs index 0a43a1f54..cd5f30c7f 100644 --- a/src/routers/const_router.rs +++ b/src/routers/const_router.rs @@ -13,7 +13,6 @@ use actix_web::http::Method; use anyhow::{Error, Result}; -use super::RouteType; use super::Router; type RouteMap = RwLock<matchit::Router<String>>; @@ -59,13 +58,8 @@ impl Router<String, Method> for ConstRouter { Ok(()) } - fn get_route( - &self, - route_method: RouteType<Method>, - route: &str, // check for the route method here - ) -> Option<String> { - // need to split this function in multiple smaller functions - let table = self.routes.get(&route_method.0)?; + fn get_route(&self, route_method: Method, route: &str) -> Option<String> { + let table = self.routes.get(&route_method)?; let route_map = table.read().ok()?; match route_map.at(route) { diff --git a/src/routers/middleware_router.rs b/src/routers/middleware_router.rs index 5e9e1a1cc..ea84f6af7 100644 --- a/src/routers/middleware_router.rs +++ b/src/routers/middleware_router.rs @@ -9,7 +9,7 @@ use anyhow::{Context, Error, Result}; use crate::routers::types::MiddlewareRoute; -use super::{RouteType, Router}; +use super::Router; type RouteMap = RwLock<matchit::Router<(PyFunction, u8)>>; @@ -34,8 +34,6 @@ impl MiddlewareRouter { } impl Router<((PyFunction, u8), HashMap<String, String>), MiddlewareRoute> for MiddlewareRouter { - // Checks if the functions is an async function - // Inserts them in the router according to their nature(CoRoutine/SyncFunction) fn add_route( &self, route_type: &str, @@ -65,10 +63,10 @@ impl Router<((PyFunction, u8), HashMap<String, String>), MiddlewareRoute> for Mi fn get_route( &self, - route_method: RouteType<MiddlewareRoute>, + route_method: MiddlewareRoute, route: &str, ) -> Option<((PyFunction, u8), HashMap<String, String>)> { - let table = self.routes.get(&route_method.0)?; + let table = self.routes.get(&route_method)?; let table_lock = table.read().ok()?; let res = table_lock.at(route).ok()?; diff --git a/src/routers/mod.rs b/src/routers/mod.rs index 0aa1fe52d..9a15bba3c 100644 --- a/src/routers/mod.rs +++ b/src/routers/mod.rs @@ -7,8 +7,6 @@ pub mod router; pub mod types; pub mod web_socket_router; -pub struct RouteType<T>(pub T); - pub trait Router<T, U> { /// Checks if the functions is an async function /// Inserts them in the router according to their nature(CoRoutine/SyncFunction) @@ -22,7 +20,6 @@ pub trait Router<T, U> { event_loop: Option<&PyAny>, ) -> Result<()>; - /// Checks if the functions is an async function - /// Inserts them in the router according to their nature(CoRoutine/SyncFunction) - fn get_route(&self, route_method: RouteType<U>, route: &str) -> Option<T>; + /// Retrieve the correct function from the previously inserted routes + fn get_route(&self, route_method: U, route: &str) -> Option<T>; } diff --git a/src/routers/router.rs b/src/routers/router.rs index f4295ab5d..b4598dbcb 100644 --- a/src/routers/router.rs +++ b/src/routers/router.rs @@ -10,7 +10,7 @@ use matchit::Router as MatchItRouter; use anyhow::{Context, Result}; -use super::{RouteType, Router}; +use super::Router; type RouteMap = RwLock<MatchItRouter<(PyFunction, u8)>>; @@ -49,11 +49,10 @@ impl Router<((PyFunction, u8), HashMap<String, String>), Method> for DynRouter { fn get_route( &self, - route_method: RouteType<Method>, + route_method: Method, route: &str, ) -> Option<((PyFunction, u8), HashMap<String, String>)> { - // need to split this function in multiple smaller functions - let table = self.routes.get(&route_method.0)?; + let table = self.routes.get(&route_method)?; let table_lock = table.read().ok()?; let res = table_lock.at(route).ok()?; diff --git a/src/server.rs b/src/server.rs index 437e203cb..e01471da1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,7 @@ use crate::io_helpers::apply_headers; use crate::request_handler::{handle_http_middleware_request, handle_http_request}; use crate::routers::const_router::ConstRouter; -use crate::routers::{RouteType, Router}; +use crate::routers::Router; use crate::routers::router::DynRouter; use crate::routers::types::MiddlewareRoute; @@ -388,31 +388,29 @@ async fn index( let headers = merge_headers(&global_headers, req.headers()).await; - // need a better name for this - let tuple_params = match middleware_router - .get_route(RouteType(MiddlewareRoute::BeforeRequest), req.uri().path()) - { - Some(((handler_function, number_of_params), route_params)) => { - let x = handle_http_middleware_request( - handler_function, - number_of_params, - &headers, - &mut payload, - &req, - route_params, - queries.clone(), - ) - .await; - debug!("Middleware contents {:?}", x); - x - } - None => HashMap::new(), - }; + let modified_request = + match middleware_router.get_route(MiddlewareRoute::BeforeRequest, req.uri().path()) { + Some(((handler_function, number_of_params), route_params)) => { + let x = handle_http_middleware_request( + handler_function, + number_of_params, + &headers, + &mut payload, + &req, + route_params, + queries.clone(), + ) + .await; + debug!("Middleware contents {:?}", x); + x + } + None => HashMap::new(), + }; - debug!("These are the tuple params {:?}", tuple_params); + debug!("These are the tuple params {:?}", modified_request); - let headers_dup = if !tuple_params.is_empty() { - tuple_params.get("headers").unwrap().clone() + let headers_dup = if !modified_request.is_empty() { + modified_request.get("headers").unwrap().clone() } else { headers }; @@ -420,18 +418,18 @@ async fn index( debug!("These are the request headers {:?}", headers_dup); let response = if const_router - .get_route(RouteType(req.method().clone()), req.uri().path()) + .get_route(req.method().clone(), req.uri().path()) .is_some() { let mut response = HttpResponse::Ok(); apply_headers(&mut response, headers_dup.clone()); response.body( const_router - .get_route(RouteType(req.method().clone()), req.uri().path()) + .get_route(req.method().clone(), req.uri().path()) .unwrap(), ) } else { - match router.get_route(RouteType(req.method().clone()), req.uri().path()) { + match router.get_route(req.method().clone(), req.uri().path()) { Some(((handler_function, number_of_params), route_params)) => { handle_http_request( handler_function, @@ -453,7 +451,7 @@ async fn index( }; if let Some(((handler_function, number_of_params), route_params)) = - middleware_router.get_route(RouteType(MiddlewareRoute::AfterRequest), req.uri().path()) + middleware_router.get_route(MiddlewareRoute::AfterRequest, req.uri().path()) { let x = handle_http_middleware_request( handler_function, From c0850a789f8582c068e38b4162e45833e27a1065 Mon Sep 17 00:00:00 2001 From: Antoine Romero-Romero <ant.romero2@orange.fr> Date: Sat, 12 Nov 2022 12:44:19 +0000 Subject: [PATCH 4/4] fix: rename DynRouter to HttpRouter --- src/routers/{router.rs => http_router.rs} | 6 +++--- src/routers/mod.rs | 2 +- src/server.rs | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) rename src/routers/{router.rs => http_router.rs} (97%) diff --git a/src/routers/router.rs b/src/routers/http_router.rs similarity index 97% rename from src/routers/router.rs rename to src/routers/http_router.rs index b4598dbcb..dec7ed332 100644 --- a/src/routers/router.rs +++ b/src/routers/http_router.rs @@ -15,11 +15,11 @@ use super::Router; type RouteMap = RwLock<MatchItRouter<(PyFunction, u8)>>; /// Contains the thread safe hashmaps of different routes -pub struct DynRouter { +pub struct HttpRouter { routes: HashMap<Method, RouteMap>, } -impl Router<((PyFunction, u8), HashMap<String, String>), Method> for DynRouter { +impl Router<((PyFunction, u8), HashMap<String, String>), Method> for HttpRouter { fn add_route( &self, route_type: &str, // We can just have route type as WS @@ -65,7 +65,7 @@ impl Router<((PyFunction, u8), HashMap<String, String>), Method> for DynRouter { } } -impl DynRouter { +impl HttpRouter { pub fn new() -> Self { let mut routes = HashMap::new(); routes.insert(Method::GET, RwLock::new(MatchItRouter::new())); diff --git a/src/routers/mod.rs b/src/routers/mod.rs index 9a15bba3c..7210167a7 100644 --- a/src/routers/mod.rs +++ b/src/routers/mod.rs @@ -2,8 +2,8 @@ use anyhow::Result; use pyo3::{Py, PyAny}; pub mod const_router; +pub mod http_router; pub mod middleware_router; -pub mod router; pub mod types; pub mod web_socket_router; diff --git a/src/server.rs b/src/server.rs index e01471da1..cd4ada16e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,7 @@ use crate::request_handler::{handle_http_middleware_request, handle_http_request use crate::routers::const_router::ConstRouter; use crate::routers::Router; -use crate::routers::router::DynRouter; +use crate::routers::http_router::HttpRouter; use crate::routers::types::MiddlewareRoute; use crate::routers::{middleware_router::MiddlewareRouter, web_socket_router::WebSocketRouter}; use crate::shared_socket::SocketHeld; @@ -44,7 +44,7 @@ struct Directory { #[pyclass] pub struct Server { - router: Arc<DynRouter>, + router: Arc<HttpRouter>, const_router: Arc<ConstRouter>, websocket_router: Arc<WebSocketRouter>, middleware_router: Arc<MiddlewareRouter>, @@ -59,7 +59,7 @@ impl Server { #[new] pub fn new() -> Self { Self { - router: Arc::new(DynRouter::new()), + router: Arc::new(HttpRouter::new()), const_router: Arc::new(ConstRouter::new()), websocket_router: Arc::new(WebSocketRouter::new()), middleware_router: Arc::new(MiddlewareRouter::new()), @@ -158,7 +158,7 @@ impl Server { app = app.route( &route.clone(), web::get().to( - move |_router: web::Data<Arc<DynRouter>>, + move |_router: web::Data<Arc<HttpRouter>>, _global_headers: web::Data<Arc<Headers>>, stream: web::Payload, req: HttpRequest| { @@ -174,7 +174,7 @@ impl Server { } app.default_service(web::route().to( - move |router: web::Data<Arc<DynRouter>>, + move |router: web::Data<Arc<HttpRouter>>, const_router: web::Data<Arc<ConstRouter>>, middleware_router: web::Data<Arc<MiddlewareRouter>>, global_headers, @@ -367,7 +367,7 @@ async fn merge_headers( /// This is our service handler. It receives a Request, routes on it /// path, and returns a Future of a Response. async fn index( - router: web::Data<Arc<DynRouter>>, + router: web::Data<Arc<HttpRouter>>, const_router: web::Data<Arc<ConstRouter>>, middleware_router: web::Data<Arc<MiddlewareRouter>>, global_headers: web::Data<Arc<Headers>>,