diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 3c19ffe651..a95688b0fd 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -48,7 +48,7 @@ use jsonrpsee_types::{ use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::{de::DeserializeOwned, Serialize}; -use tokio::sync::Notify; +use tokio::sync::{watch, Notify}; /// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request, /// implemented as a function pointer to a `Fn` function taking four arguments: @@ -98,7 +98,7 @@ impl<'a> std::fmt::Debug for ConnState<'a> { } } -type Subscribers = Arc)>>>; +type Subscribers = Arc)>>>; /// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`]. #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -794,9 +794,9 @@ impl PendingSubscription { let InnerPendingSubscription { sink, close_notify, method, uniq_sub, subscribers, id } = inner; if sink.send_response(id, &uniq_sub.sub_id) { - let active_sub = Arc::new(()); - subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), active_sub.clone())); - Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, active_sub }) + let (tx, rx) = watch::channel(()); + subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), tx)); + Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, unsubscribe: rx }) } else { None } @@ -826,7 +826,8 @@ pub struct SubscriptionSink { uniq_sub: SubscriptionKey, /// Shared Mutex of subscriptions for this method. subscribers: Subscribers, - active_sub: Arc<()>, + /// Future that returns when the unsubscribe method has been called. + unsubscribe: watch::Receiver<()>, } impl SubscriptionSink { @@ -843,7 +844,7 @@ impl SubscriptionSink { } let msg = self.build_message(result)?; - Ok(self.inner_send(msg)) + Ok(self.inner.send_raw(msg).is_ok()) } /// Reads data from the `stream` and sends back data on the subscription @@ -881,7 +882,7 @@ impl SubscriptionSink { /// SubscriptionClosed::Failed(e) => { /// sink.close(e); /// } - /// }; + /// } /// }); /// }); /// ``` @@ -891,14 +892,23 @@ impl SubscriptionSink { T: Serialize, E: std::fmt::Display, { - let close_notify = match self.close_notify.clone() { + let conn_closed = match self.close_notify.clone() { Some(close_notify) => close_notify, - None => return SubscriptionClosed::RemotePeerAborted, + None => { + return SubscriptionClosed::RemotePeerAborted; + } }; + let mut sub_closed = self.unsubscribe.clone(); + let sub_closed_fut = sub_closed.changed(); + + let conn_closed_fut = conn_closed.notified(); + pin_mut!(conn_closed_fut); + pin_mut!(sub_closed_fut); + let mut stream_item = stream.try_next(); - let closed_fut = close_notify.notified(); - pin_mut!(closed_fut); + let mut closed_fut = futures_util::future::select(conn_closed_fut, sub_closed_fut); + loop { match futures_util::future::select(stream_item, closed_fut).await { // The app sent us a value to send back to the subscribers @@ -922,7 +932,7 @@ impl SubscriptionSink { break SubscriptionClosed::Failed(err); } Either::Left((Ok(None), _)) => break SubscriptionClosed::Success, - Either::Right(((), _)) => { + Either::Right((_, _)) => { break SubscriptionClosed::RemotePeerAborted; } } @@ -956,13 +966,13 @@ impl SubscriptionSink { self.pipe_from_try_stream::<_, _, Error>(stream.map(|item| Ok(item))).await } - /// Returns whether this channel is closed without needing a context. + /// Returns whether the subscription is closed. pub fn is_closed(&self) -> bool { - self.inner.is_closed() || self.close_notify.is_none() + self.inner.is_closed() || self.close_notify.is_none() || !self.is_active_subscription() } fn is_active_subscription(&self) -> bool { - Arc::strong_count(&self.active_sub) > 1 + !self.unsubscribe.has_changed().is_err() } fn build_message(&self, result: &T) -> Result { @@ -981,14 +991,6 @@ impl SubscriptionSink { .map_err(Into::into) } - fn inner_send(&mut self, msg: String) -> bool { - if self.is_active_subscription() { - self.inner.send_raw(msg).is_ok() - } else { - false - } - } - /// Close the subscription, sending a notification with a special `error` field containing the provided error. /// /// This can be used to signal an actual error, or just to signal that the subscription has been closed, diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 978ab1fc34..15cd3a9d5a 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -27,10 +27,14 @@ use std::net::SocketAddr; use std::time::Duration; +use futures::{SinkExt, StreamExt}; +use jsonrpsee::core::error::SubscriptionClosed; use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle}; use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR}; use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle}; use jsonrpsee::RpcModule; +use tokio::time::interval; +use tokio_stream::wrappers::IntervalStream; pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle) { let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); @@ -40,10 +44,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle module .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, pending, _| { - let mut sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; + let mut sink = pending.accept().unwrap(); std::thread::spawn(move || loop { if let Ok(false) = sink.send(&"hello from subscription") { break; @@ -55,10 +56,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle module .register_subscription("subscribe_foo", "subscribe_foo", "unsubscribe_foo", |_, pending, _| { - let mut sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; + let mut sink = pending.accept().unwrap(); std::thread::spawn(move || loop { if let Ok(false) = sink.send(&1337_usize) { break; @@ -75,10 +73,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle _ => return, }; - let mut sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; + let mut sink = pending.accept().unwrap(); std::thread::spawn(move || loop { count = count.wrapping_add(1); @@ -92,10 +87,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle module .register_subscription("subscribe_noop", "subscribe_noop", "unsubscribe_noop", |_, pending, _| { - let sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; + let sink = pending.accept().unwrap(); std::thread::spawn(move || { std::thread::sleep(Duration::from_secs(1)); let err = ErrorObject::owned( @@ -108,6 +100,73 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle }) .unwrap(); + module + .register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| { + let mut sink = pending.accept().unwrap(); + + tokio::spawn(async move { + let interval = interval(Duration::from_millis(50)); + let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); + + match sink.pipe_from_stream(stream).await { + SubscriptionClosed::Success => { + sink.close(SubscriptionClosed::Success); + } + _ => unreachable!(), + } + }); + }) + .unwrap(); + + module + .register_subscription("can_reuse_subscription", "n", "u_can_reuse_subscription", |_, pending, _| { + let mut sink = pending.accept().unwrap(); + + tokio::spawn(async move { + let stream1 = IntervalStream::new(interval(Duration::from_millis(50))) + .zip(futures::stream::iter(1..=5)) + .map(|(_, c)| c); + let stream2 = IntervalStream::new(interval(Duration::from_millis(50))) + .zip(futures::stream::iter(6..=10)) + .map(|(_, c)| c); + + let result = sink.pipe_from_stream(stream1).await; + assert!(matches!(result, SubscriptionClosed::Success)); + + match sink.pipe_from_stream(stream2).await { + SubscriptionClosed::Success => { + sink.close(SubscriptionClosed::Success); + } + _ => unreachable!(), + } + }); + }) + .unwrap(); + + module + .register_subscription( + "subscribe_with_err_on_stream", + "n", + "unsubscribe_with_err_on_stream", + move |_, pending, _| { + let mut sink = pending.accept().unwrap(); + + let err: &'static str = "error on the stream"; + + // create stream that produce an error which will cancel the subscription. + let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]); + tokio::spawn(async move { + match sink.pipe_from_try_stream(stream).await { + SubscriptionClosed::Failed(e) => { + sink.close(e); + } + _ => unreachable!(), + } + }); + }, + ) + .unwrap(); + let addr = server.local_addr().unwrap(); let server_handle = server.start(module).unwrap(); @@ -133,6 +192,31 @@ pub async fn websocket_server() -> SocketAddr { addr } +/// Yields one item then sleeps for an hour. +pub async fn websocket_server_with_sleeping_subscription(tx: futures::channel::mpsc::Sender<()>) -> SocketAddr { + let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let addr = server.local_addr().unwrap(); + + let mut module = RpcModule::new(tx); + + module + .register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, pending, mut tx| { + let mut sink = pending.accept().unwrap(); + + tokio::spawn(async move { + let interval = interval(Duration::from_secs(60 * 60)); + let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); + + sink.pipe_from_stream(stream).await; + let send_back = std::sync::Arc::make_mut(&mut tx); + send_back.send(()).await.unwrap(); + }); + }) + .unwrap(); + server.start(module).unwrap(); + addr +} + pub async fn http_server() -> (SocketAddr, HttpServerHandle) { http_server_with_access_control(AccessControl::default()).await } diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 15a808971c..381b00cd50 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -30,17 +30,14 @@ use std::sync::Arc; use std::time::Duration; -use futures::TryStreamExt; +use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{http_server, http_server_with_access_control, websocket_server, websocket_server_with_subscription}; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; -use jsonrpsee::core::error::SubscriptionClosed; use jsonrpsee::core::{Error, JsonValue}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::rpc_params; use jsonrpsee::types::error::ErrorObject; use jsonrpsee::ws_client::WsClientBuilder; -use tokio::time::interval; -use tokio_stream::wrappers::IntervalStream; mod helpers; @@ -386,41 +383,14 @@ async fn ws_server_should_stop_subscription_after_client_drop() { #[tokio::test] async fn ws_server_cancels_subscriptions_on_reset_conn() { - use futures::{channel::mpsc, SinkExt, StreamExt}; - use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; - - let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); - let server_url = format!("ws://{}", server.local_addr().unwrap()); - let (tx, rx) = mpsc::channel(1); - let mut module = RpcModule::new(tx); - - module - .register_subscription("subscribe_for_ever", "n", "unsubscribe_for_ever", |_, pending, mut tx| { - // Create stream that produce one item then sleeps for an hour. - let interval = interval(Duration::from_secs(60 * 60)); - let stream = IntervalStream::new(interval).map(move |_| 0_usize); - - let mut sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; - - tokio::spawn(async move { - sink.pipe_from_stream(stream).await; - let send_back = Arc::make_mut(&mut tx); - send_back.send(()).await.unwrap(); - }); - }) - .unwrap(); - - server.start(module).unwrap(); + let server_url = format!("ws://{}", helpers::websocket_server_with_sleeping_subscription(tx).await); let client = WsClientBuilder::default().build(&server_url).await.unwrap(); let mut subs = Vec::new(); for _ in 0..10 { - subs.push(client.subscribe::("subscribe_for_ever", None, "unsubscribe_for_ever").await.unwrap()); + subs.push(client.subscribe::("subscribe_sleep", None, "unsubscribe_sleep").await.unwrap()); } // terminate connection. @@ -433,38 +403,8 @@ async fn ws_server_cancels_subscriptions_on_reset_conn() { #[tokio::test] async fn ws_server_cancels_sub_stream_after_err() { - use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; - - let err: &'static str = "error on the stream"; - let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); - let server_url = format!("ws://{}", server.local_addr().unwrap()); - - let mut module = RpcModule::new(()); - - module - .register_subscription( - "subscribe_with_err_on_stream", - "n", - "unsubscribe_with_err_on_stream", - move |_, pending, _| { - let mut sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; - - // create stream that produce an error which will cancel the subscription. - let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]); - tokio::spawn(async move { - match sink.pipe_from_try_stream(stream).await { - SubscriptionClosed::Failed(e) => sink.close(e), - _ => unreachable!(), - }; - }); - }, - ) - .unwrap(); - - server.start(module).unwrap(); + let (addr, _handle) = websocket_server_with_subscription().await; + let server_url = format!("ws://{}", addr); let client = WsClientBuilder::default().build(&server_url).await.unwrap(); let mut sub: Subscription = @@ -477,35 +417,8 @@ async fn ws_server_cancels_sub_stream_after_err() { #[tokio::test] async fn ws_server_subscribe_with_stream() { - use futures::StreamExt; - use jsonrpsee::{ws_server::WsServerBuilder, RpcModule}; - - let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); - let server_url = format!("ws://{}", server.local_addr().unwrap()); - - let mut module = RpcModule::new(()); - - module - .register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| { - let mut sink = match pending.accept() { - Some(sink) => sink, - _ => return, - }; - - tokio::spawn(async move { - let interval = interval(Duration::from_millis(50)); - let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c); - - match sink.pipe_from_stream(stream).await { - SubscriptionClosed::Success => { - sink.close(SubscriptionClosed::Success); - } - _ => unreachable!(), - }; - }); - }) - .unwrap(); - server.start(module).unwrap(); + let (addr, _handle) = websocket_server_with_subscription().await; + let server_url = format!("ws://{}", addr); let client = WsClientBuilder::default().build(&server_url).await.unwrap(); let mut sub1: Subscription = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap(); @@ -530,6 +443,37 @@ async fn ws_server_subscribe_with_stream() { assert!(sub1.next().await.is_none()); } +#[tokio::test] +async fn ws_server_pipe_from_stream_should_cancel_tasks_immediately() { + let (tx, rx) = mpsc::channel(1); + let server_url = format!("ws://{}", helpers::websocket_server_with_sleeping_subscription(tx).await); + + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + let mut subs = Vec::new(); + + for _ in 0..10 { + subs.push(client.subscribe::("subscribe_sleep", None, "unsubscribe_sleep").await.unwrap()) + } + + // This will call the `unsubscribe method`. + drop(subs); + + let rx_len = rx.take(10).fold(0, |acc, _| async move { acc + 1 }).await; + + assert_eq!(rx_len, 10); +} + +#[tokio::test] +async fn ws_server_pipe_from_stream_can_be_reused() { + let (addr, _handle) = websocket_server_with_subscription().await; + let client = WsClientBuilder::default().build(&format!("ws://{}", addr)).await.unwrap(); + let sub = client.subscribe::("can_reuse_subscription", None, "u_can_reuse_subscription").await.unwrap(); + + let items = sub.fold(0, |acc, _| async move { acc + 1 }).await; + + assert_eq!(items, 10); +} + #[tokio::test] async fn ws_batch_works() { let server_addr = websocket_server().await;