Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rpc module): close subscription task when a subscription is unsubscribed via the unsubscribe call #743

Merged
merged 16 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Notify;
use tokio::sync::{watch, Notify};

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand Down Expand Up @@ -98,7 +98,7 @@ impl<'a> std::fmt::Debug for ConnState<'a> {
}
}

type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, Arc<()>)>>>;
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, watch::Sender<()>)>>>;

/// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -794,9 +794,9 @@ impl PendingSubscription {
let InnerPendingSubscription { sink, close_notify, method, uniq_sub, subscribers, id } = inner;

if sink.send_response(id, &uniq_sub.sub_id) {
let active_sub = Arc::new(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), active_sub.clone()));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, active_sub })
let (tx, rx) = watch::channel(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), tx));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, unsubscribe: rx })
} else {
None
}
Expand Down Expand Up @@ -826,7 +826,8 @@ pub struct SubscriptionSink {
uniq_sub: SubscriptionKey,
/// Shared Mutex of subscriptions for this method.
subscribers: Subscribers,
active_sub: Arc<()>,
/// Future that returns when the unsubscribe method has been called.
unsubscribe: watch::Receiver<()>,
}

impl SubscriptionSink {
Expand All @@ -843,7 +844,7 @@ impl SubscriptionSink {
}

let msg = self.build_message(result)?;
Ok(self.inner_send(msg))
Ok(self.inner.send_raw(msg).is_ok())
}

/// Reads data from the `stream` and sends back data on the subscription
Expand Down Expand Up @@ -881,7 +882,7 @@ impl SubscriptionSink {
/// SubscriptionClosed::Failed(e) => {
/// sink.close(e);
/// }
/// };
/// }
/// });
/// });
/// ```
Expand All @@ -891,14 +892,23 @@ impl SubscriptionSink {
T: Serialize,
E: std::fmt::Display,
{
let close_notify = match self.close_notify.clone() {
let conn_closed = match self.close_notify.clone() {
Some(close_notify) => close_notify,
None => return SubscriptionClosed::RemotePeerAborted,
None => {
return SubscriptionClosed::RemotePeerAborted;
}
};

let mut sub_closed = self.unsubscribe.clone();
let sub_closed_fut = sub_closed.changed();

let conn_closed_fut = conn_closed.notified();
pin_mut!(conn_closed_fut);
pin_mut!(sub_closed_fut);

let mut stream_item = stream.try_next();
let closed_fut = close_notify.notified();
pin_mut!(closed_fut);
let mut closed_fut = futures_util::future::select(conn_closed_fut, sub_closed_fut);

loop {
match futures_util::future::select(stream_item, closed_fut).await {
// The app sent us a value to send back to the subscribers
Expand All @@ -922,7 +932,7 @@ impl SubscriptionSink {
break SubscriptionClosed::Failed(err);
}
Either::Left((Ok(None), _)) => break SubscriptionClosed::Success,
Either::Right(((), _)) => {
Either::Right((_, _)) => {
break SubscriptionClosed::RemotePeerAborted;
}
}
Expand Down Expand Up @@ -956,13 +966,13 @@ impl SubscriptionSink {
self.pipe_from_try_stream::<_, _, Error>(stream.map(|item| Ok(item))).await
}

/// Returns whether this channel is closed without needing a context.
/// Returns whether the subscription is closed.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close_notify.is_none()
self.inner.is_closed() || self.close_notify.is_none() || !self.is_active_subscription()
}

fn is_active_subscription(&self) -> bool {
Arc::strong_count(&self.active_sub) > 1
!self.unsubscribe.has_changed().is_err()
}

fn build_message<T: Serialize>(&self, result: &T) -> Result<String, serde_json::Error> {
Expand All @@ -981,14 +991,6 @@ impl SubscriptionSink {
.map_err(Into::into)
}

fn inner_send(&mut self, msg: String) -> bool {
if self.is_active_subscription() {
self.inner.send_raw(msg).is_ok()
} else {
false
}
}

/// Close the subscription, sending a notification with a special `error` field containing the provided error.
///
/// This can be used to signal an actual error, or just to signal that the subscription has been closed,
Expand Down
116 changes: 100 additions & 16 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@
use std::net::SocketAddr;
use std::time::Duration;

use futures::{SinkExt, StreamExt};
use jsonrpsee::core::error::SubscriptionClosed;
use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle};
use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR};
use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle};
use jsonrpsee::RpcModule;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;

pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle) {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
Expand All @@ -40,10 +44,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle

module
.register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
if let Ok(false) = sink.send(&"hello from subscription") {
break;
Expand All @@ -55,10 +56,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle

module
.register_subscription("subscribe_foo", "subscribe_foo", "unsubscribe_foo", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
if let Ok(false) = sink.send(&1337_usize) {
break;
Expand All @@ -75,10 +73,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
_ => return,
};

let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();

std::thread::spawn(move || loop {
count = count.wrapping_add(1);
Expand All @@ -92,10 +87,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle

module
.register_subscription("subscribe_noop", "subscribe_noop", "unsubscribe_noop", |_, pending, _| {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let sink = pending.accept().unwrap();
std::thread::spawn(move || {
std::thread::sleep(Duration::from_secs(1));
let err = ErrorObject::owned(
Expand All @@ -108,6 +100,73 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
})
.unwrap();

module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we can call pipe_from_stream more than once, should we also have a test to make sure that we can?

let mut sink = pending.accept().unwrap();

tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

match sink.pipe_from_stream(stream).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
}
});
})
.unwrap();

module
.register_subscription("can_reuse_subscription", "n", "u_can_reuse_subscription", |_, pending, _| {
let mut sink = pending.accept().unwrap();

tokio::spawn(async move {
let stream1 = IntervalStream::new(interval(Duration::from_millis(50)))
.zip(futures::stream::iter(1..=5))
.map(|(_, c)| c);
let stream2 = IntervalStream::new(interval(Duration::from_millis(50)))
.zip(futures::stream::iter(6..=10))
.map(|(_, c)| c);

let result = sink.pipe_from_stream(stream1).await;
assert!(matches!(result, SubscriptionClosed::Success));

match sink.pipe_from_stream(stream2).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
}
});
})
.unwrap();

module
.register_subscription(
"subscribe_with_err_on_stream",
"n",
"unsubscribe_with_err_on_stream",
move |_, pending, _| {
let mut sink = pending.accept().unwrap();

let err: &'static str = "error on the stream";

// create stream that produce an error which will cancel the subscription.
let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]);
tokio::spawn(async move {
match sink.pipe_from_try_stream(stream).await {
SubscriptionClosed::Failed(e) => {
sink.close(e);
}
_ => unreachable!(),
}
});
},
)
.unwrap();

let addr = server.local_addr().unwrap();
let server_handle = server.start(module).unwrap();

Expand All @@ -133,6 +192,31 @@ pub async fn websocket_server() -> SocketAddr {
addr
}

/// Yields one item then sleeps for an hour.
pub async fn websocket_server_with_sleeping_subscription(tx: futures::channel::mpsc::Sender<()>) -> SocketAddr {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let addr = server.local_addr().unwrap();

let mut module = RpcModule::new(tx);

module
.register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, pending, mut tx| {
let mut sink = pending.accept().unwrap();

tokio::spawn(async move {
let interval = interval(Duration::from_secs(60 * 60));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);

sink.pipe_from_stream(stream).await;
let send_back = std::sync::Arc::make_mut(&mut tx);
send_back.send(()).await.unwrap();
});
})
.unwrap();
server.start(module).unwrap();
addr
}

pub async fn http_server() -> (SocketAddr, HttpServerHandle) {
http_server_with_access_control(AccessControl::default()).await
}
Expand Down
Loading