Skip to content

Commit

Permalink
feat(node)!: Implement graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
oblique committed Aug 13, 2024
1 parent 7ba4d0b commit 5a7d254
Show file tree
Hide file tree
Showing 13 changed files with 671 additions and 218 deletions.
18 changes: 13 additions & 5 deletions node/src/daser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use tracing::{debug, error, warn};
use web_time::{Duration, Instant};

use crate::events::{EventPublisher, NodeEvent};
use crate::executor::spawn;
use crate::executor::{spawn, JoinHandle};
use crate::p2p::shwap::sample_cid;
use crate::p2p::{P2p, P2pError};
use crate::store::{BlockRanges, SamplingStatus, Store, StoreError};
Expand Down Expand Up @@ -71,6 +71,7 @@ pub enum DaserError {
/// Component responsible for data availability sampling of blocks from the network.
pub(crate) struct Daser {
cancellation_token: CancellationToken,
join_handle: JoinHandle,
}

/// Arguments used to configure the [`Daser`].
Expand All @@ -96,7 +97,7 @@ impl Daser {
let event_pub = args.event_pub.clone();
let mut worker = Worker::new(args, cancellation_token.child_token())?;

spawn(async move {
let join_handle = spawn(async move {
if let Err(e) = worker.run().await {
error!("Daser stopped because of a fatal error: {e}");

Expand All @@ -106,20 +107,27 @@ impl Daser {
}
});

Ok(Daser { cancellation_token })
Ok(Daser {
cancellation_token,
join_handle,
})
}

/// Stop the worker.
pub(crate) fn stop(&self) {
// Singal the Worker to stop.
// TODO: Should we wait for the Worker to stop?
self.cancellation_token.cancel();
}

/// Wait until worker is completely stopped.
pub(crate) async fn join(&self) {
self.join_handle.join().await;
}
}

impl Drop for Daser {
fn drop(&mut self) {
self.cancellation_token.cancel();
self.stop();
}
}

Expand Down
5 changes: 0 additions & 5 deletions node/src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,6 @@ impl EventPublisher {
file_line: location.line(),
});
}

/// Returns if there are any active subscribers or not.
pub(crate) fn has_subscribers(&self) -> bool {
self.tx.receiver_count() > 0
}
}

impl EventSubscriber {
Expand Down
93 changes: 87 additions & 6 deletions node/src/executor.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
use std::fmt::{self, Debug};
use std::future::Future;

use tokio::select;
use tokio_util::sync::CancellationToken;

use crate::utils::Token;

#[allow(unused_imports)]
pub(crate) use self::imp::{
sleep, spawn, spawn_cancellable, timeout, yield_now, Elapsed, Interval,
};

/// Naive `JoinHandle` implementation.
pub(crate) struct JoinHandle(Token);

impl JoinHandle {
pub(crate) async fn join(&self) {
self.0.triggered().await;
}
}

impl Debug for JoinHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("JoinHandle { .. }")
}
}

#[cfg(not(target_arch = "wasm32"))]
mod imp {
use super::*;
Expand All @@ -17,28 +35,48 @@ mod imp {
pub(crate) use tokio::time::{sleep, timeout};

#[track_caller]
pub(crate) fn spawn<F>(future: F)
pub(crate) fn spawn<F>(future: F) -> JoinHandle
where
F: Future<Output = ()> + Send + 'static,
{
tokio::spawn(future);
let token = Token::new();
let guard = token.trigger_drop_guard();

tokio::spawn(async move {
let _guard = guard;
future.await;
});

JoinHandle(token)
}

/// Spawn a cancellable task.
///
/// This will cancel the task in the highest layer and should not be used
/// if cancellation must happen in a point.
#[track_caller]
pub(crate) fn spawn_cancellable<F>(cancelation_token: CancellationToken, future: F)
pub(crate) fn spawn_cancellable<F>(
cancelation_token: CancellationToken,
future: F,
) -> JoinHandle
where
F: Future<Output = ()> + Send + 'static,
{
let token = Token::new();
let guard = token.trigger_drop_guard();

tokio::spawn(async move {
let _guard = guard;
select! {
// Run branches in order.
biased;

_ = cancelation_token.cancelled() => {}
_ = future => {}
}
});

JoinHandle(token)
}

pub(crate) struct Interval(tokio::time::Interval);
Expand Down Expand Up @@ -80,27 +118,47 @@ mod imp {

use super::*;

pub(crate) fn spawn<F>(future: F)
pub(crate) fn spawn<F>(future: F) -> JoinHandle
where
F: Future<Output = ()> + 'static,
{
wasm_bindgen_futures::spawn_local(future);
let token = Token::new();
let guard = token.trigger_drop_guard();

wasm_bindgen_futures::spawn_local(async move {
let _guard = guard;
future.await;
});

JoinHandle(token)
}

/// Spawn a cancellable task.
///
/// This will cancel the task in the highest layer and should not be used
/// if cancellation must happen in a point.
pub(crate) fn spawn_cancellable<F>(cancelation_token: CancellationToken, future: F)
pub(crate) fn spawn_cancellable<F>(
cancelation_token: CancellationToken,
future: F,
) -> JoinHandle
where
F: Future<Output = ()> + 'static,
{
let token = Token::new();
let guard = token.trigger_drop_guard();

wasm_bindgen_futures::spawn_local(async move {
let _guard = guard;
select! {
// Run branches in order.
biased;

_ = cancelation_token.cancelled() => {}
_ = future => {}
}
});

JoinHandle(token)
}

pub(crate) struct Interval(SendWrapper<IntervalStream>);
Expand Down Expand Up @@ -217,3 +275,26 @@ mod imp {
.await;
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::async_test;
use std::time::Duration;
use web_time::Instant;

#[async_test]
async fn join_handle() {
let now = Instant::now();

let join_handle = spawn(async {
sleep(Duration::from_millis(10)).await;
});

join_handle.join().await;
assert!(now.elapsed() >= Duration::from_millis(10));

// This must return immediately.
join_handle.join().await;
}
}
61 changes: 37 additions & 24 deletions node/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ use celestia_types::ExtendedHeader;
use libp2p::identity::Keypair;
use libp2p::swarm::NetworkInfo;
use libp2p::{Multiaddr, PeerId};
use tokio::select;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing::warn;

use crate::daser::{Daser, DaserArgs};
use crate::events::{EventChannel, EventSubscriber, NodeEvent};
use crate::executor::spawn;
use crate::executor::spawn_cancellable;
use crate::p2p::{P2p, P2pArgs};
use crate::store::{SamplingMetadata, Store, StoreError};
use crate::syncer::{Syncer, SyncerArgs};
Expand Down Expand Up @@ -89,7 +88,7 @@ where
p2p: Arc<P2p>,
store: Arc<S>,
syncer: Arc<Syncer<S>>,
_daser: Arc<Daser>,
daser: Arc<Daser>,
tasks_cancellation_token: CancellationToken,
}

Expand Down Expand Up @@ -144,32 +143,26 @@ where
event_pub: event_channel.publisher(),
})?);

// spawn the task that will stop the services when the fraud is detected
let network_compromised_token = p2p.get_network_compromised_token().await?;
let tasks_cancellation_token = CancellationToken::new();

spawn({
// spawn the task that will stop the services when the fraud is detected
spawn_cancellable(tasks_cancellation_token.child_token(), {
let network_compromised_token = p2p.get_network_compromised_token().await?;
let syncer = syncer.clone();
let daser = daser.clone();
let tasks_cancellation_token = tasks_cancellation_token.child_token();
let event_pub = event_channel.publisher();

async move {
select! {
_ = tasks_cancellation_token.cancelled() => (),
_ = network_compromised_token.cancelled() => {
syncer.stop();
daser.stop();

if event_pub.has_subscribers() {
event_pub.send(NodeEvent::NetworkCompromised);
} else {
// This is a very important message and we want to log it if user
// does not consume our events.
warn!("{}", NodeEvent::NetworkCompromised);
}
}
}
network_compromised_token.triggered().await;

// Network compromised! Stop Syncer and Daser.
syncer.stop();
daser.stop();

event_pub.send(NodeEvent::NetworkCompromised);
// This is a very important message and we want to log it even
// if user consumes our events.
warn!("{}", NodeEvent::NetworkCompromised);
}
});

Expand All @@ -178,13 +171,30 @@ where
p2p,
store,
syncer,
_daser: daser,
daser,
tasks_cancellation_token,
};

Ok((node, event_sub))
}

/// Stop the node.
pub async fn stop(&self) {
// Cancel Node's tasks
self.tasks_cancellation_token.cancel();

// Stop all components that use P2p.
self.daser.stop();
self.syncer.stop();

self.daser.join().await;
self.syncer.join().await;

// Now stop P2p component.
self.p2p.stop();
self.p2p.join().await;
}

/// Returns a new `EventSubscriber`.
pub fn event_subscriber(&self) -> EventSubscriber {
self.event_channel.subscribe()
Expand Down Expand Up @@ -366,7 +376,10 @@ where
S: Store,
{
fn drop(&mut self) {
// we have to cancel the task to drop the Arc's passed to it
// Stop everything, but don't join them.
self.tasks_cancellation_token.cancel();
self.daser.stop();
self.syncer.stop();
self.p2p.stop();
}
}
Loading

0 comments on commit 5a7d254

Please sign in to comment.