Skip to content

Commit

Permalink
fix(rpc module): close subscription task when a subscription is `unsu…
Browse files Browse the repository at this point in the history
…bscribed` via the `unsubscribe call` (#743)

* refactor: remove SubscriptionSink::inner_send

* fix: close running task if unsubscribed

* Update core/src/server/rpc_module.rs

* Update core/src/server/rpc_module.rs

* fix nits

* Update core/src/server/rpc_module.rs

* add test for canceling subscriptions

* print subscription info; once per minute

* revert closure stuff

* Revert "print subscription info; once per minute"

This reverts commit 366176a.

* use tokio::sync::watch instead of oneshot

The receiver is clonable and it's possible to check whether the sender is still alive

* Update tests/tests/helpers.rs

Co-authored-by: David <dvdplm@gmail.com>

* Update core/src/server/rpc_module.rs

Co-authored-by: David <dvdplm@gmail.com>

* grumbles: use unwrap in tests

* add test for reuse pipe_from_stream

Co-authored-by: David <dvdplm@gmail.com>
  • Loading branch information
niklasad1 and dvdplm authored Apr 29, 2022
1 parent 9decd23 commit 8e945de
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 134 deletions.
50 changes: 26 additions & 24 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Notify;
use tokio::sync::{watch, Notify};

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand Down Expand Up @@ -98,7 +98,7 @@ impl<'a> std::fmt::Debug for ConnState<'a> {
}
}

type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, Arc<()>)>>>;
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, watch::Sender<()>)>>>;

/// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -794,9 +794,9 @@ impl PendingSubscription {
let InnerPendingSubscription { sink, close_notify, method, uniq_sub, subscribers, id } = inner;

if sink.send_response(id, &uniq_sub.sub_id) {
let active_sub = Arc::new(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), active_sub.clone()));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, active_sub })
let (tx, rx) = watch::channel(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), tx));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, unsubscribe: rx })
} else {
None
}
Expand Down Expand Up @@ -826,7 +826,8 @@ pub struct SubscriptionSink {
uniq_sub: SubscriptionKey,
/// Shared Mutex of subscriptions for this method.
subscribers: Subscribers,
active_sub: Arc<()>,
/// Future that returns when the unsubscribe method has been called.
unsubscribe: watch::Receiver<()>,
}

impl SubscriptionSink {
Expand All @@ -843,7 +844,7 @@ impl SubscriptionSink {
}

let msg = self.build_message(result)?;
Ok(self.inner_send(msg))
Ok(self.inner.send_raw(msg).is_ok())
}

/// Reads data from the `stream` and sends back data on the subscription
Expand Down Expand Up @@ -881,7 +882,7 @@ impl SubscriptionSink {
/// SubscriptionClosed::Failed(e) => {
/// sink.close(e);
/// }
/// };
/// }
/// });
/// });
/// ```
Expand All @@ -891,14 +892,23 @@ impl SubscriptionSink {
T: Serialize,
E: std::fmt::Display,
{
let close_notify = match self.close_notify.clone() {
let conn_closed = match self.close_notify.clone() {
Some(close_notify) => close_notify,
None => return SubscriptionClosed::RemotePeerAborted,
None => {
return SubscriptionClosed::RemotePeerAborted;
}
};

let mut sub_closed = self.unsubscribe.clone();
let sub_closed_fut = sub_closed.changed();

let conn_closed_fut = conn_closed.notified();
pin_mut!(conn_closed_fut);
pin_mut!(sub_closed_fut);

let mut stream_item = stream.try_next();
let closed_fut = close_notify.notified();
pin_mut!(closed_fut);
let mut closed_fut = futures_util::future::select(conn_closed_fut, sub_closed_fut);

loop {
match futures_util::future::select(stream_item, closed_fut).await {
// The app sent us a value to send back to the subscribers
Expand All @@ -922,7 +932,7 @@ impl SubscriptionSink {
break SubscriptionClosed::Failed(err);
}
Either::Left((Ok(None), _)) => break SubscriptionClosed::Success,
Either::Right(((), _)) => {
Either::Right((_, _)) => {
break SubscriptionClosed::RemotePeerAborted;
}
}
Expand Down Expand Up @@ -956,13 +966,13 @@ impl SubscriptionSink {
self.pipe_from_try_stream::<_, _, Error>(stream.map(|item| Ok(item))).await
}

/// Returns whether this channel is closed without needing a context.
/// Returns whether the subscription is closed.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close_notify.is_none()
self.inner.is_closed() || self.close_notify.is_none() || !self.is_active_subscription()
}

fn is_active_subscription(&self) -> bool {
Arc::strong_count(&self.active_sub) > 1
!self.unsubscribe.has_changed().is_err()
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, serde_json::Error> {
Expand All @@ -981,14 +991,6 @@ impl SubscriptionSink {
.map_err(Into::into)
}

fn inner_send(&mut self, msg: String) -> bool {
if self.is_active_subscription() {
self.inner.send_raw(msg).is_ok()
} else {
false
}
}

/// Close the subscription, sending a notification with a special `error` field containing the provided error.
///
/// This can be used to signal an actual error, or just to signal that the subscription has been closed,
Expand Down
116 changes: 100 additions & 16 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@
use std::net::SocketAddr;
use std::time::Duration;

use futures::{SinkExt, StreamExt};
use jsonrpsee::core::error::SubscriptionClosed;
use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle};
use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR};
use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle};
use jsonrpsee::RpcModule;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle) {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
Expand All @@ -40,10 +44,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle

module
.register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
if let Ok(false) = sink.send(&"hello from subscription") {
break;
Expand All @@ -55,10 +56,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle

module
.register_subscription("subscribe_foo", "subscribe_foo", "unsubscribe_foo", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
if let Ok(false) = sink.send(&1337_usize) {
break;
Expand All @@ -75,10 +73,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
_ => return,
};

let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();

std::thread::spawn(move || loop {
count = count.wrapping_add(1);
Expand All @@ -92,10 +87,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle

module
.register_subscription("subscribe_noop", "subscribe_noop", "unsubscribe_noop", |_, pending, _| {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let sink = pending.accept().unwrap();
std::thread::spawn(move || {
std::thread::sleep(Duration::from_secs(1));
let err = ErrorObject::owned(
Expand All @@ -108,6 +100,73 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
})
.unwrap();

module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| {
let mut sink = pending.accept().unwrap();

tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

match sink.pipe_from_stream(stream).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
}
});
})
.unwrap();

module
.register_subscription("can_reuse_subscription", "n", "u_can_reuse_subscription", |_, pending, _| {
let mut sink = pending.accept().unwrap();

tokio::spawn(async move {
let stream1 = IntervalStream::new(interval(Duration::from_millis(50)))
.zip(futures::stream::iter(1..=5))
.map(|(_, c)| c);
let stream2 = IntervalStream::new(interval(Duration::from_millis(50)))
.zip(futures::stream::iter(6..=10))
.map(|(_, c)| c);

let result = sink.pipe_from_stream(stream1).await;
assert!(matches!(result, SubscriptionClosed::Success));

match sink.pipe_from_stream(stream2).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
}
});
})
.unwrap();

module
.register_subscription(
"subscribe_with_err_on_stream",
"n",
"unsubscribe_with_err_on_stream",
move |_, pending, _| {
let mut sink = pending.accept().unwrap();

let err: &'static str = "error on the stream";

// create stream that produce an error which will cancel the subscription.
let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]);
tokio::spawn(async move {
match sink.pipe_from_try_stream(stream).await {
SubscriptionClosed::Failed(e) => {
sink.close(e);
}
_ => unreachable!(),
}
});
},
)
.unwrap();

let addr = server.local_addr().unwrap();
let server_handle = server.start(module).unwrap();

Expand All @@ -133,6 +192,31 @@ pub async fn websocket_server() -> SocketAddr {
addr
}

/// Yields one item then sleeps for an hour.
pub async fn websocket_server_with_sleeping_subscription(tx: futures::channel::mpsc::Sender<()>) -> SocketAddr {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let addr = server.local_addr().unwrap();

let mut module = RpcModule::new(tx);

module
.register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, pending, mut tx| {
let mut sink = pending.accept().unwrap();

tokio::spawn(async move {
let interval = interval(Duration::from_secs(60 * 60));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

sink.pipe_from_stream(stream).await;
let send_back = std::sync::Arc::make_mut(&mut tx);
send_back.send(()).await.unwrap();
});
})
.unwrap();
server.start(module).unwrap();
addr
}

pub async fn http_server() -> (SocketAddr, HttpServerHandle) {
http_server_with_access_control(AccessControl::default()).await
}
Expand Down
Loading

0 comments on commit 8e945de

Please sign in to comment.