Skip to content

Commit

Permalink
Add support to stop SpinGroup and show interrupt debrief (#569)
Browse files Browse the repository at this point in the history
  • Loading branch information
tjwp authored Nov 25, 2024
1 parent 6eb513c commit 4d3fcf8
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 3 deletions.
52 changes: 50 additions & 2 deletions lib/cli/ui/spinner/spin_group.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def pause_spinners(&block)
# ==== Options
#
# * +:auto_debrief+ - Automatically debrief exceptions or through success_debrief? Default to true
# * +:interrupt_debrief+ - Automatically debrief on interrupt. Default to false
# * +:max_concurrent+ - Maximum number of concurrent tasks. Default is 0 (effectively unlimited)
# * +:work_queue+ - Custom WorkQueue instance. If not provided, a new one will be created
#
Expand All @@ -67,15 +68,18 @@ def pause_spinners(&block)
sig do
params(
auto_debrief: T::Boolean,
interrupt_debrief: T::Boolean,
max_concurrent: Integer,
work_queue: T.nilable(WorkQueue),
).void
end
def initialize(auto_debrief: true, max_concurrent: 0, work_queue: nil)
def initialize(auto_debrief: true, interrupt_debrief: false, max_concurrent: 0, work_queue: nil)
@m = Mutex.new
@tasks = []
@auto_debrief = auto_debrief
@interrupt_debrief = interrupt_debrief
@start = Time.new
@stopped = false
@internal_work_queue = work_queue.nil?
@work_queue = T.let(
work_queue || WorkQueue.new(max_concurrent.zero? ? 1024 : max_concurrent),
Expand All @@ -96,6 +100,9 @@ class Task
sig { returns(T::Boolean) }
attr_reader :success

sig { returns(T::Boolean) }
attr_reader :done

sig { returns(T.nilable(Exception)) }
attr_reader :exception

Expand Down Expand Up @@ -140,6 +147,11 @@ def initialize(title, final_glyph:, merged_output:, duplicate_output_to:, work_q
@success = false
end

sig { params(block: T.proc.params(task: Task).void).void }
def on_done(&block)
@on_done = block
end

# Checks if a task is finished
#
sig { returns(T::Boolean) }
Expand All @@ -157,6 +169,8 @@ def check
@success = false
end

@on_done&.call(self)

@done
end

Expand Down Expand Up @@ -305,6 +319,34 @@ def add(
end
end

sig { void }
def stop
# If we already own the mutex (called from within another synchronized block),
# set stopped directly to avoid deadlock
if @m.owned?
return if @stopped

@stopped = true
else
@m.synchronize do
return if @stopped

@stopped = true
end
end
# Interrupt is thread-safe on its own, so we can call it outside the mutex
@work_queue.interrupt
end

sig { returns(T::Boolean) }
def stopped?
if @m.owned?
@stopped
else
@m.synchronize { @stopped }
end
end

# Tells the group you're done adding tasks and to wait for all of them to finish
#
# ==== Example Usage:
Expand All @@ -324,6 +366,8 @@ def wait
tasks_seen_done = @tasks.map { false }

loop do
break if stopped?

done_count = 0

width = CLI::UI::Terminal.width
Expand Down Expand Up @@ -374,7 +418,8 @@ def wait
end
rescue Interrupt
@work_queue.interrupt
raise
debrief if @interrupt_debrief
stopped? ? false : raise
end

# Provide an alternative debriefing for failed tasks
Expand Down Expand Up @@ -410,6 +455,8 @@ def all_succeeded?
def debrief
@m.synchronize do
@tasks.each do |task|
next unless task.done

title = task.title
out = task.stdout
err = task.stderr
Expand All @@ -418,6 +465,7 @@ def debrief
next @success_debrief&.call(title, out, err)
end

# exception will not be set if the wait loop is stopped before the task is checked
e = task.exception
next @failure_debrief.call(title, e, out, err) if @failure_debrief

Expand Down
152 changes: 151 additions & 1 deletion test/cli/ui/spinner/spin_group_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ def test_spin_group_interrupt

sg.add('Interruptible task') do
started_queue.push(true)
10.times { sleep(0.1) }
sleep(1)
task_completed = true
rescue Interrupt
task_interrupted = true
raise
end

t = Thread.new { sg.wait }
t.report_on_exception = false

# Wait for task to start
started_queue.pop
Expand All @@ -155,6 +156,155 @@ def test_spin_group_interrupt
assert(task_interrupted, 'Task should have been interrupted')
end
end

def test_spin_group_stop
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new

task_started = false
task_completed = false

sg.add('Stoppable task') do
task_started = true
sleep(1)
task_completed = true
end

t = Thread.new { sg.wait }

# Wait for task to start
sleep(0.1) until task_started

# Stop the spin group
sg.stop

t.join

refute(task_completed, 'Task should not complete after stop')
assert(sg.stopped?, 'SpinGroup should be marked as stopped')
refute(sg.all_succeeded?, 'Tasks should not be marked as succeeded after stop')
end
end

def test_spin_group_nested_stop
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new

sg.add('Outer task') do
sg.stop
true
end

refute(sg.wait, 'SpinGroup#wait should return false when stopped')
assert(sg.stopped?, 'SpinGroup should be marked as stopped')
end
end

def test_spin_group_interrupt_with_debrief
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new(interrupt_debrief: true)
task_interrupted = false
debrief_called = false

# Use Queue for thread-safe signaling
started_queue = Queue.new

sg.failure_debrief do |title, _exception, _out, _err|
assert_equal('Failed task', title)
debrief_called = true
end

sg.add('Failed task') do
TASK_FAILED
end

sg.add('Interruptible task') do
started_queue.push(true)
sleep(1)
rescue Interrupt
task_interrupted = true
raise
end

t = Thread.new { sg.wait }
t.report_on_exception = false

# Wait for task to start
started_queue.pop
sleep(0.1) # Small delay to ensure we're in sleep
t.raise(Interrupt)

# The interrupt should propagate since we didn't stop
assert_raises(Interrupt) { t.join }
assert(task_interrupted, 'Task should have been interrupted')
assert(debrief_called, 'Debrief should have been called')
end
end

def test_spin_group_interrupt_without_debrief
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new(interrupt_debrief: false)

# Use Queue for thread-safe signaling
started_queue = Queue.new

debrief_called = false
sg.failure_debrief do
debrief_called = true
end

sg.add('Failed task') do
TASK_FAILED
end
sg.add('Interruptible task') do
started_queue.push(true)
sleep(1)
false
end

t = Thread.new { sg.wait }
t.report_on_exception = false

# Wait for task to actually start
started_queue.pop
sleep(0.1) # Small delay to ensure we're in sleep

# Interrupt should be raised through
t.raise(Interrupt)
assert_raises(Interrupt) { t.join }

refute(debrief_called, 'failure_debrief should not be called when interrupt_debrief is false')
end
end

def test_task_on_done_callback
capture_io do
CLI::UI::StdoutRouter.ensure_activated
sg = SpinGroup.new

callback_executed = false
task_completed = false

sg.add('Task with callback') do |task|
task.on_done do |completed_task|
callback_executed = true
assert_equal('Task with callback', completed_task.title)
assert(completed_task.done)
assert(completed_task.success)
end
task_completed = true
true
end

assert(sg.wait)
assert(task_completed, 'Task should have completed')
assert(callback_executed, 'on_done callback should have been executed')
end
end
end
end
end
Expand Down

0 comments on commit 4d3fcf8

Please sign in to comment.