refactor: implement graceful shutdown for IndexerExecutor
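Shutdown is now coordinated through a single `CancellationToken` passed to `IndexerExecutor::new` instead of `IndexerExecutor::run`. The executor hands child tokens to the `CheckpointReader` and to every `WorkerPool`, replacing the previous `oneshot` exit channels, and on cancellation it awaits the spawned reader and worker-pool tasks before returning its progress stats.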
sergiupopescu199 committed Dec 31, 2024
1 parent 5dd444a commit 11bba52
Showing 7 changed files with 77 additions and 60 deletions.
54 changes: 35 additions & 19 deletions crates/iota-data-ingestion-core/src/executor.rs
@@ -11,7 +11,7 @@ use iota_types::{
full_checkpoint_content::CheckpointData, messages_checkpoint::CheckpointSequenceNumber,
};
use prometheus::Registry;
-use tokio::sync::mpsc;
+use tokio::{sync::mpsc, task::JoinHandle};
use tokio_util::sync::CancellationToken;

use crate::{
@@ -30,10 +30,16 @@ pub struct IndexerExecutor<P> {
pool_progress_sender: mpsc::Sender<(String, CheckpointSequenceNumber)>,
pool_progress_receiver: mpsc::Receiver<(String, CheckpointSequenceNumber)>,
metrics: DataIngestionMetrics,
+token: CancellationToken,
}

impl<P: ProgressStore> IndexerExecutor<P> {
-pub fn new(progress_store: P, number_of_jobs: usize, metrics: DataIngestionMetrics) -> Self {
+pub fn new(
+progress_store: P,
+number_of_jobs: usize,
+metrics: DataIngestionMetrics,
+token: CancellationToken,
+) -> Self {
let (pool_progress_sender, pool_progress_receiver) =
mpsc::channel(number_of_jobs * MAX_CHECKPOINTS_IN_PROGRESS);
Self {
@@ -43,6 +49,7 @@ impl<P: ProgressStore> IndexerExecutor<P> {
pool_progress_sender,
pool_progress_receiver,
metrics,
+token,
}
}

@@ -54,6 +61,7 @@ impl<P: ProgressStore> IndexerExecutor<P> {
checkpoint_number,
receiver,
self.pool_progress_sender.clone(),
+self.token.child_token(),
)));
self.pool_senders.push(sender);
Ok(())
@@ -66,25 +74,33 @@ impl<P: ProgressStore> IndexerExecutor<P> {
remote_store_url: Option<String>,
remote_store_options: Vec<(String, String)>,
reader_options: ReaderOptions,
-token: CancellationToken,
) -> Result<ExecutorProgress> {
let mut reader_checkpoint_number = self.progress_store.min_watermark()?;
-let (checkpoint_reader, mut checkpoint_recv, gc_sender, _exit_sender) =
-CheckpointReader::initialize(
-path,
-reader_checkpoint_number,
-remote_store_url,
-remote_store_options,
-reader_options,
-);
-spawn_monitored_task!(checkpoint_reader.run());
+let (checkpoint_reader, mut checkpoint_recv, gc_sender) = CheckpointReader::initialize(
+path,
+reader_checkpoint_number,
+remote_store_url,
+remote_store_options,
+reader_options,
+self.token.child_token(),
+);
+
+let checkpoint_reader_handle = spawn_monitored_task!(checkpoint_reader.run());
+
+let worker_pools = std::mem::take(&mut self.pools)
+.into_iter()
+.map(|pool| spawn_monitored_task!(pool))
+.collect::<Vec<JoinHandle<()>>>();

-for pool in std::mem::take(&mut self.pools) {
-spawn_monitored_task!(pool);
-}
loop {
tokio::select! {
-_ = token.cancelled() => break,
+_ = self.token.cancelled() => {
+for worker in worker_pools {
+worker.await?;
+}
+checkpoint_reader_handle.await??;
+break;
+}
Some((task_name, sequence_number)) = self.pool_progress_receiver.recv() => {
self.progress_store.save(task_name.clone(), sequence_number).await?;
let seq_number = self.progress_store.min_watermark()?;
@@ -101,6 +117,7 @@ impl<P: ProgressStore> IndexerExecutor<P> {
}
}
}

Ok(self.progress_store.stats())
}
}
@@ -115,10 +132,10 @@ pub async fn setup_single_workflow<W: Worker + 'static>(
impl Future<Output = Result<ExecutorProgress>>,
CancellationToken,
)> {
+let token = CancellationToken::new();
let metrics = DataIngestionMetrics::new(&Registry::new());
let progress_store = ShimProgressStore(initial_checkpoint_number);
-let mut executor = IndexerExecutor::new(progress_store, 1, metrics);
-let token = CancellationToken::new();
+let mut executor = IndexerExecutor::new(progress_store, 1, metrics, token.child_token());
let worker_pool = WorkerPool::new(worker, "workflow".to_string(), concurrency);
executor.register(worker_pool).await?;
Ok((
@@ -127,7 +144,6 @@ pub async fn setup_single_workflow<W: Worker + 'static>(
Some(remote_store_url),
vec![],
reader_options.unwrap_or_default(),
-token.child_token(),
),
token,
))
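The net effect for callers of `setup_single_workflow`: it still returns the executor future together with the parent token, but the token is now wired in at construction time. A minimal end-to-end sketch of the new shutdown flow follows; the worker type, the store URL, and the exact `Worker` trait shape are illustrative assumptions, not code from this commit:

```rust
use anyhow::Result;
use async_trait::async_trait;
use iota_data_ingestion_core::{Worker, setup_single_workflow};
use iota_types::full_checkpoint_content::CheckpointData;

// Hypothetical worker; the Worker trait is assumed to look like the crate's tests.
struct StdoutWorker;

#[async_trait]
impl Worker for StdoutWorker {
    async fn process_checkpoint(&self, checkpoint: CheckpointData) -> Result<()> {
        println!("checkpoint {}", checkpoint.checkpoint_summary.sequence_number);
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let (executor, token) = setup_single_workflow(
        StdoutWorker,
        "https://checkpoints.example.org".to_string(), // hypothetical remote store URL
        0,    // initial checkpoint sequence number
        1,    // concurrency
        None, // default ReaderOptions
    )
    .await?;

    // Drive the pipeline on its own task so the caller keeps the parent token.
    let handle = tokio::spawn(executor);

    // ... later: one cancel drains the reader and the worker pools.
    token.cancel();
    handle.await??;
    Ok(())
}
```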
14 changes: 7 additions & 7 deletions crates/iota-data-ingestion-core/src/reader.rs
@@ -17,9 +17,10 @@ use notify::{RecursiveMode, Watcher};
use object_store::{ObjectStore, path::Path};
use tap::pipe::Pipe;
use tokio::{
-sync::{mpsc, mpsc::error::TryRecvError, oneshot},
+sync::{mpsc, mpsc::error::TryRecvError},
time::timeout,
};
+use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};

use crate::{create_remote_store_client, executor::MAX_CHECKPOINTS_IN_PROGRESS};
@@ -36,7 +37,7 @@ pub struct CheckpointReader {
checkpoint_sender: mpsc::Sender<CheckpointData>,
processed_receiver: mpsc::Receiver<CheckpointSequenceNumber>,
remote_fetcher_receiver: Option<mpsc::Receiver<Result<(CheckpointData, usize)>>>,
-exit_receiver: oneshot::Receiver<()>,
+token: CancellationToken,
options: ReaderOptions,
data_limiter: DataLimiter,
}
@@ -316,15 +317,14 @@ impl CheckpointReader {
remote_store_url: Option<String>,
remote_store_options: Vec<(String, String)>,
options: ReaderOptions,
+token: CancellationToken,
) -> (
Self,
mpsc::Receiver<CheckpointData>,
mpsc::Sender<CheckpointSequenceNumber>,
-oneshot::Sender<()>,
) {
let (checkpoint_sender, checkpoint_recv) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
let (processed_sender, processed_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
-let (exit_sender, exit_receiver) = oneshot::channel();
let reader = Self {
path,
remote_store_url,
@@ -334,11 +334,11 @@
checkpoint_sender,
processed_receiver,
remote_fetcher_receiver: None,
-exit_receiver,
+token,
data_limiter: DataLimiter::new(options.data_limit),
options,
};
-(reader, checkpoint_recv, processed_sender, exit_sender)
+(reader, checkpoint_recv, processed_sender)
}

pub async fn run(mut self) -> Result<()> {
@@ -362,7 +362,7 @@

loop {
tokio::select! {
-_ = &mut self.exit_receiver => break,
+_ = self.token.cancelled() => break,
Some(gc_checkpoint_number) = self.processed_receiver.recv() => {
self.gc_processed_files(gc_checkpoint_number).expect("Failed to clean the directory");
}
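With the `oneshot` exit channel gone, the reader terminates through the same child-token mechanism as the rest of the pipeline. The key property, sketched below in isolation (not code from this repository), is that cancelling a parent token wakes every `cancelled()` future on its children, while a cancelled child never propagates upward:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let parent = CancellationToken::new();
    let child = parent.child_token();

    let reader_task = tokio::spawn(async move {
        loop {
            tokio::select! {
                // Same exit pattern as CheckpointReader::run above.
                _ = child.cancelled() => break,
                _ = tokio::time::sleep(Duration::from_millis(100)) => {
                    // poll the checkpoint directory, fetch from remote, ...
                }
            }
        }
        "reader exited cleanly"
    });

    parent.cancel(); // cancels the parent and, transitively, every child token
    println!("{}", reader_task.await.unwrap());
}
```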
27 changes: 15 additions & 12 deletions crates/iota-data-ingestion-core/src/tests.rs
@@ -42,6 +42,7 @@ async fn run(
indexer: IndexerExecutor<FileProgressStore>,
path: Option<PathBuf>,
duration: Option<Duration>,
+token: CancellationToken,
) -> Result<ExecutorProgress> {
let options = ReaderOptions {
tick_interval_ms: 10,
@@ -52,24 +53,15 @@
match duration {
None => {
indexer
-.run(
-path.unwrap_or_else(temp_dir),
-None,
-vec![],
-options,
-CancellationToken::new(),
-)
+.run(path.unwrap_or_else(temp_dir), None, vec![], options)
.await
}
Some(duration) => {
-let token = CancellationToken::new();
-let token_child = token.child_token();
let handle = tokio::task::spawn(indexer.run(
path.unwrap_or_else(temp_dir),
None,
vec![],
options,
-token_child,
));
tokio::time::sleep(duration).await;
token.cancel();
@@ -81,6 +73,7 @@
struct ExecutorBundle {
executor: IndexerExecutor<FileProgressStore>,
_progress_file: NamedTempFile,
+token: CancellationToken,
}

#[derive(Clone)]
@@ -96,7 +89,7 @@ impl Worker for TestWorker {
#[tokio::test]
async fn empty_pools() {
let bundle = create_executor_bundle();
-let result = run(bundle.executor, None, None).await;
+let result = run(bundle.executor, None, None, bundle.token).await;
assert!(result.is_err());
if let Err(err) = result {
assert!(err.to_string().contains("pools can't be empty"));
@@ -114,7 +107,13 @@ async fn basic_flow() {
let bytes = mock_checkpoint_data_bytes(checkpoint_number);
std::fs::write(path.join(format!("{}.chk", checkpoint_number)), bytes).unwrap();
}
-let result = run(bundle.executor, Some(path), Some(Duration::from_secs(1))).await;
+let result = run(
+bundle.executor,
+Some(path),
+Some(Duration::from_secs(1)),
+bundle.token,
+)
+.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().get("test"), Some(&20));
}
@@ -130,14 +129,18 @@ fn create_executor_bundle() -> ExecutorBundle {
let path = progress_file.path().to_path_buf();
std::fs::write(path.clone(), "{}").unwrap();
let progress_store = FileProgressStore::new(path);
+let token = CancellationToken::new();
+let child_token = token.child_token();
let executor = IndexerExecutor::new(
progress_store,
1,
DataIngestionMetrics::new(&Registry::new()),
+child_token,
);
ExecutorBundle {
executor,
_progress_file: progress_file,
+token,
}
}

28 changes: 13 additions & 15 deletions crates/iota-data-ingestion-core/src/worker_pool.rs
@@ -12,7 +12,8 @@ use iota_metrics::spawn_monitored_task;
use iota_types::{
full_checkpoint_content::CheckpointData, messages_checkpoint::CheckpointSequenceNumber,
};
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::mpsc;
+use tokio_util::sync::CancellationToken;
use tracing::info;

use crate::{Worker, executor::MAX_CHECKPOINTS_IN_PROGRESS};
@@ -36,6 +37,7 @@ impl<W: Worker + 'static> WorkerPool<W> {
mut current_checkpoint_number: CheckpointSequenceNumber,
mut checkpoint_receiver: mpsc::Receiver<CheckpointData>,
executor_progress_sender: mpsc::Sender<(String, CheckpointSequenceNumber)>,
+token: CancellationToken,
) {
info!(
"Starting indexing pipeline {} with concurrency {}. Current watermark is {}.",
@@ -54,16 +56,17 @@ impl<W: Worker + 'static> WorkerPool<W> {
for worker_id in 0..self.concurrency {
let (worker_sender, mut worker_recv) =
mpsc::channel::<CheckpointData>(MAX_CHECKPOINTS_IN_PROGRESS);
-let (term_sender, mut term_receiver) = oneshot::channel::<()>();
let cloned_progress_sender = progress_sender.clone();
let task_name = self.task_name.clone();
-workers.push((worker_sender, term_sender));
+workers.push(worker_sender);

+let token = token.child_token();

let worker = self.worker.clone();
let join_handle = spawn_monitored_task!(async move {
loop {
tokio::select! {
-_ = &mut term_receiver => break,
+_ = token.cancelled() => break,
Some(checkpoint) = worker_recv.recv() => {
let sequence_number = checkpoint.checkpoint_summary.sequence_number;
info!("received checkpoint for processing {} for workflow {}", sequence_number, task_name);
@@ -97,6 +100,7 @@ impl<W: Worker + 'static> WorkerPool<W> {
// main worker pool loop
loop {
tokio::select! {
+_ = token.cancelled() => break,
Some((worker_id, status_update, progress_watermark)) = progress_receiver.recv() => {
idle.insert(worker_id);
updates.insert(status_update, progress_watermark);
@@ -121,7 +125,7 @@
while !checkpoints.is_empty() && !idle.is_empty() {
let checkpoint = checkpoints.pop_front().unwrap();
let worker_id = idle.pop_first().unwrap();
-if workers[worker_id].0.send(checkpoint).await.is_err() {
+if workers[worker_id].send(checkpoint).await.is_err() {
// The worker channel closing is a sign we need to exit this loop.
break;
}
@@ -137,23 +141,17 @@
checkpoints.push_back(checkpoint);
} else {
let worker_id = idle.pop_first().unwrap();
-if workers[worker_id].0.send(checkpoint).await.is_err() {
+if workers[worker_id].send(checkpoint).await.is_err() {
// The worker channel closing is a sign we need to exit this loop.
break;
};
}
}
}
}

-// Clean up code for graceful termination
-
-// Notify the exit handles of all workers to terminate
-drop(workers);

// Wait for all workers to finish
for join_handle in join_handles {
join_handle.await.expect("worker thread panicked");
}
}

+tracing::info!("Worker pool `{}` terminated gracefully", self.task_name);
}
}
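Worker bookkeeping simplifies accordingly: the per-worker `oneshot` terminate channel is gone, `workers` keeps only the checkpoint senders, and each spawned worker selects on its own child token. In isolation, the cancel-then-drain pattern the pool and executor now share looks roughly like this (a sketch with stubbed task bodies, not this crate's code):

```rust
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let mut join_handles = Vec::new();

    for worker_id in 0..4usize {
        let token = token.child_token();
        join_handles.push(tokio::spawn(async move {
            // A real worker would select between the token and its checkpoint
            // channel; this stub only waits for cancellation.
            token.cancelled().await;
            println!("worker {worker_id} stopped");
        }));
    }

    token.cancel(); // signal every worker through its child token

    // Wait for all workers to finish before declaring the pool terminated.
    for join_handle in join_handles {
        join_handle.await.expect("worker thread panicked");
    }
}
```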
6 changes: 3 additions & 3 deletions crates/iota-data-ingestion/src/main.rs
@@ -88,7 +88,7 @@ fn setup_env(token: CancellationToken) {
#[tokio::main]
async fn main() -> Result<()> {
let token = CancellationToken::new();
-let token_child = token.child_token();
+let child_token = token.child_token();
setup_env(token);

let args: Vec<String> = env::args().collect();
@@ -113,7 +113,8 @@
config.progress_store.table_name,
)
.await;
-let mut executor = IndexerExecutor::new(progress_store, config.tasks.len(), metrics);
+let mut executor =
+IndexerExecutor::new(progress_store, config.tasks.len(), metrics, child_token);
for task_config in config.tasks {
match task_config.task {
Task::Archival(archival_config) => {
@@ -152,7 +153,6 @@
config.remote_store_url,
config.remote_store_options,
reader_options,
-token_child,
)
.await?;
Ok(())
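In the binary, the parent token goes to `setup_env` and a child token to the executor, so process-level signal handling can stop the whole pipeline. The diff does not show `setup_env`'s body; a typical wiring — an assumption for illustration, not this repository's code — forwards Ctrl-C to `token.cancel()`:

```rust
use tokio_util::sync::CancellationToken;

// Hypothetical sketch; the real setup_env body is outside this diff.
fn setup_env(token: CancellationToken) {
    tokio::spawn(async move {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install ctrl-c handler");
        // One cancel here reaches the executor, reader, and every worker pool
        // through their child tokens.
        token.cancel();
    });
}
```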