diff --git a/http-server/src/server.rs b/http-server/src/server.rs index 1bbd019b29..e276b64706 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -61,6 +61,7 @@ pub struct Builder { /// Custom tokio runtime to run the server on. tokio_runtime: Option, middleware: M, + health_api: Option, } impl Default for Builder { @@ -73,6 +74,7 @@ impl Default for Builder { access_control: AccessControl::default(), tokio_runtime: None, middleware: (), + health_api: None, } } } @@ -119,6 +121,7 @@ impl Builder { access_control: self.access_control, tokio_runtime: self.tokio_runtime, middleware, + health_api: self.health_api, } } @@ -166,6 +169,14 @@ impl Builder { self } + /// Enable health endpoint. + /// Allows you to expose one of the methods under GET / The method will be invoked with no parameters. Error returned from the method will be converted to status 500 response. + /// Expects a tuple with (, ). + pub fn health_api(mut self, path: impl Into, method: impl Into) -> Self { + self.health_api = Some(HealthApi { path: path.into(), method: method.into() }); + self + } + /// Finalizes the configuration of the server with customized TCP settings on the socket and on hyper. /// /// ```rust @@ -213,6 +224,7 @@ impl Builder { resources: self.resources, tokio_runtime: self.tokio_runtime, middleware: self.middleware, + health_api: self.health_api, }) } @@ -256,6 +268,7 @@ impl Builder { resources: self.resources, tokio_runtime: self.tokio_runtime, middleware: self.middleware, + health_api: self.health_api, }) } @@ -290,10 +303,17 @@ impl Builder { resources: self.resources, tokio_runtime: self.tokio_runtime, middleware: self.middleware, + health_api: self.health_api, }) } } +#[derive(Debug, Clone)] +struct HealthApi { + path: String, + method: String, +} + /// Handle used to run or stop the server. #[derive(Debug)] pub struct ServerHandle { @@ -345,6 +365,7 @@ pub struct Server { /// Custom tokio runtime to run the server on. tokio_runtime: Option, middleware: M, + health_api: Option, } impl Server { @@ -364,12 +385,14 @@ impl Server { let middleware = self.middleware; let batch_requests_supported = self.batch_requests_supported; let methods = methods.into().initialize_resources(&resources)?; + let health_api = self.health_api; let make_service = make_service_fn(move |_| { let methods = methods.clone(); let access_control = access_control.clone(); let resources = resources.clone(); let middleware = middleware.clone(); + let health_api = health_api.clone(); async move { Ok::<_, HyperError>(service_fn(move |request| { @@ -377,6 +400,7 @@ impl Server { let access_control = access_control.clone(); let resources = resources.clone(); let middleware = middleware.clone(); + let health_api = health_api.clone(); // Run some validation on the http request, then read the body and try to deserialize it into one of // two cases: a single RPC request or a batch of RPC requests. @@ -430,6 +454,12 @@ impl Server { } Ok(res) } + Method::GET => match health_api.as_ref() { + Some(health) if health.path.as_str() == request.uri().path() => { + process_health_request(health, middleware, methods, max_response_body_size).await + } + _ => Ok(response::method_not_allowed()), + }, // Error scenarios: Method::POST => Ok(response::unsupported_content_type()), _ => Ok(response::method_not_allowed()), @@ -687,3 +717,44 @@ async fn process_validated_request( middleware.on_response(request_start); Ok(response::ok_response(response)) } + +async fn process_health_request( + health_api: &HealthApi, + middleware: impl Middleware, + methods: Methods, + max_response_body_size: u32, +) -> Result, HyperError> { + let (tx, mut rx) = mpsc::unbounded::(); + let sink = MethodSink::new_with_limit(tx, max_response_body_size); + + let request_start = middleware.on_request(); + + let success = match methods.method_with_name(&health_api.method) { + None => false, + Some((name, method_callback)) => match method_callback.inner() { + MethodKind::Sync(callback) => { + let res = (callback)(Id::Number(0), Params::new(None), &sink); + middleware.on_result(name, res, request_start); + res + } + MethodKind::Async(callback) => { + let res = (callback)(Id::Number(0), Params::new(None), sink.clone(), 0, None).await; + middleware.on_result(name, res, request_start); + res + } + + MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => { + middleware.on_result(name, false, request_start); + false + } + }, + }; + + let data = rx.next().await; + middleware.on_response(request_start); + + match data { + Some(resp) if success => Ok(response::ok_response(resp)), + _ => Ok(response::internal_error()), + } +} diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 15cd3a9d5a..86760db67f 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -222,12 +222,19 @@ pub async fn http_server() -> (SocketAddr, HttpServerHandle) { } pub async fn http_server_with_access_control(acl: AccessControl) -> (SocketAddr, HttpServerHandle) { - let server = HttpServerBuilder::default().set_access_control(acl).build("127.0.0.1:0").await.unwrap(); + let server = HttpServerBuilder::default() + .set_access_control(acl) + .health_api("/health", "system_health") + .build("127.0.0.1:0") + .await + .unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); module.register_method("notif", |_, _| Ok("")).unwrap(); + module.register_method("system_health", |_, _| Ok("im ok")).unwrap(); + let handle = server.start(module).unwrap(); (addr, handle) } diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index b0c78d59c3..c9dce9704f 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -750,3 +750,22 @@ async fn ws_subscribe_with_bad_params() { .unwrap_err(); assert!(matches!(err, Error::Call(_))); } + +#[tokio::test] +async fn http_health_api_works() { + use hyper::{Body, Client, Request}; + + let (server_addr, _handle) = http_server().await; + + let http_client = Client::new(); + let uri = format!("http://{}/health", server_addr); + + let req = Request::builder().method("GET").uri(&uri).body(Body::empty()).expect("request builder"); + let res = http_client.request(req).await.unwrap(); + + assert!(res.status().is_success()); + + let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); + let out = String::from_utf8(bytes.to_vec()).unwrap(); + assert_eq!(out, "{\"jsonrpc\":\"2.0\",\"result\":\"im ok\",\"id\":0}"); +}