Skip to content

Commit d32eb3f

Browse files
committed
Add function to skip enqueued work in thread_pool
1 parent a7bca33 commit d32eb3f

File tree

4 files changed

+67
-17
lines changed

4 files changed

+67
-17
lines changed

base/task.cpp

+25-9
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,33 @@ task_token& task::start(thread_pool& pool)
3333
m_state = state::ENQUEUED;
3434
m_token.reset();
3535

36-
pool.execute([this] { in_worker_thread(); });
36+
m_token.m_work = pool.execute([this] { in_worker_thread(); });
3737
return m_token;
3838
}
3939

40+
bool task::try_skip(thread_pool& pool)
41+
{
42+
bool skipped = pool.try_skip(m_token.m_work);
43+
if (skipped) {
44+
m_token.m_canceled = true;
45+
call_finished();
46+
}
47+
48+
return skipped;
49+
}
50+
51+
void task::call_finished()
52+
{
53+
if (m_finished) {
54+
try {
55+
m_finished(m_token);
56+
}
57+
catch (const std::exception& ex) {
58+
LOG(ERROR, "Exception executing 'finished' callback: %s\n", ex.what());
59+
}
60+
}
61+
}
62+
4063
void task::in_worker_thread()
4164
{
4265
m_state = state::RUNNING;
@@ -50,14 +73,7 @@ void task::in_worker_thread()
5073

5174
m_state = state::FINISHED;
5275

53-
if (m_finished) {
54-
try {
55-
m_finished(m_token);
56-
}
57-
catch (const std::exception& ex) {
58-
LOG(ERROR, "Exception executing 'finished' callback: %s\n", ex.what());
59-
}
60-
}
76+
call_finished();
6177
}
6278

6379
} // namespace base

base/task.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class task_token {
2727

2828
bool canceled() const { return m_canceled; }
2929
float progress() const { return m_progress; }
30-
bool finished() const { return m_canceled || m_progress == m_progress_max; }
3130

3231
void cancel() { m_canceled = true; }
3332
void set_progress(float p)
@@ -51,6 +50,7 @@ class task_token {
5150
std::atomic<bool> m_canceled;
5251
std::atomic<float> m_progress;
5352
float m_progress_min, m_progress_max;
53+
const thread_pool::work* m_work = nullptr;
5454
};
5555

5656
class task {
@@ -71,6 +71,7 @@ class task {
7171
void on_finished(func_t&& f) { m_finished = std::move(f); }
7272

7373
task_token& start(thread_pool& pool);
74+
bool try_skip(thread_pool& pool);
7475

7576
bool running() const { return m_state == state::RUNNING; }
7677

@@ -86,6 +87,7 @@ class task {
8687

8788
private:
8889
void in_worker_thread();
90+
void call_finished();
8991

9092
std::atomic<state> m_state;
9193
task_token m_token;

base/thread_pool.cpp

+19-4
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,27 @@ thread_pool::~thread_pool()
2626
join_all();
2727
}
2828

29-
void thread_pool::execute(std::function<void()>&& func)
29+
const thread_pool::work* thread_pool::execute(std::function<void()>&& func)
3030
{
31+
thread_pool::work_ptr work = std::make_unique<thread_pool::work>(std::move(func));
32+
const thread_pool::work* result = work.get();
3133
const std::unique_lock lock(m_mutex);
3234
ASSERT(m_running);
33-
m_work.push(std::move(func));
35+
m_work.push_back(std::move(work));
3436
m_cv.notify_one();
37+
return result;
38+
}
39+
40+
bool thread_pool::try_skip(const work* w)
41+
{
42+
std::unique_lock<std::mutex> lock(m_mutex);
43+
for (auto it = m_work.begin(); it != m_work.end(); ++it) {
44+
if (w == it->get()) {
45+
m_work.erase(it);
46+
return true;
47+
}
48+
}
49+
return false;
3550
}
3651

3752
void thread_pool::wait_all()
@@ -79,9 +94,9 @@ void thread_pool::worker()
7994
m_cv.wait(lock, [this]() -> bool { return !m_running || !m_work.empty(); });
8095
running = m_running;
8196
if (m_running && !m_work.empty()) {
82-
func = std::move(m_work.front());
97+
func = std::move(m_work.front()->m_func);
8398
++m_doingWork;
84-
m_work.pop();
99+
m_work.pop_front();
85100
}
86101
}
87102
try {

base/thread_pool.h

+20-3
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,37 @@
99
#pragma once
1010

1111
#include <condition_variable>
12+
#include <deque>
1213
#include <functional>
1314
#include <mutex>
14-
#include <queue>
1515
#include <thread>
1616
#include <vector>
1717

1818
namespace base {
1919

2020
class thread_pool {
2121
public:
22+
class work {
23+
friend class thread_pool;
24+
25+
public:
26+
work(std::function<void()>&& func) { m_func = std::move(func); }
27+
28+
private:
29+
std::function<void()> m_func = nullptr;
30+
};
31+
32+
typedef std::unique_ptr<work> work_ptr;
33+
2234
thread_pool(const size_t n);
2335
~thread_pool();
2436

25-
void execute(std::function<void()>&& func);
37+
const work* execute(std::function<void()>&& func);
38+
39+
// Tries to skip the work if it was not started yet, in other words, it
40+
// removes the specified work from the queue if possible. Returns true if it
41+
// was able to do so, or false otherwise.
42+
bool try_skip(const work* w);
2643

2744
// Waits until the queue is empty.
2845
void wait_all();
@@ -39,7 +56,7 @@ class thread_pool {
3956
std::mutex m_mutex;
4057
std::condition_variable m_cv;
4158
std::condition_variable m_cvWait;
42-
std::queue<std::function<void()>> m_work;
59+
std::deque<work_ptr> m_work;
4360
int m_doingWork;
4461
};
4562

0 commit comments

Comments
 (0)