1
1
#pragma once
2
2
3
3
/* *
4
- * @file RouterThreadPool .h
4
+ * @file vtr_thread_pool .h
5
5
* @brief A generic thread pool for parallel task execution
6
6
*/
7
7
@@ -27,89 +27,70 @@ class thread_pool {
27
27
std::mutex queue_mutex;
28
28
std::condition_variable cv;
29
29
bool stop = false ;
30
- size_t thread_id; // For debugging
31
- size_t tasks_completed{0 }; // Track number of completed tasks
30
+ size_t thread_id;
32
31
33
32
ThreadData (size_t id) : thread_id(id) {}
34
33
};
35
34
36
35
std::vector<std::unique_ptr<ThreadData>> threads;
37
- std::atomic<size_t > next_thread{0 }; // For round-robin assignment
36
+ std::atomic<size_t > next_thread{0 };
38
37
std::atomic<size_t > total_tasks_queued{0 };
39
- vtr::Timer pool_timer; // Track pool lifetime
40
38
std::atomic<size_t > active_tasks{0 };
39
+
40
+ /* Condition variable for wait_for_all*/
41
41
std::mutex completion_mutex;
42
42
std::condition_variable completion_cv;
43
43
44
44
public:
45
45
thread_pool (size_t thread_count) {
46
- // VTR_LOG("Creating thread pool with %zu threads\n", thread_count);
47
46
threads.reserve (thread_count);
48
47
49
48
for (size_t i = 0 ; i < thread_count; i++) {
50
49
auto thread_data = std::make_unique<ThreadData>(i);
51
50
52
- thread_data->thread = std::thread ([td = thread_data. get () ]() {
53
- // VTR_LOG("Thread %zu started\n", td->thread_id );
51
+ thread_data->thread = std::thread ([& ]() {
52
+ ThreadData* td = thread_data. get ( );
54
53
55
54
while (true ) {
56
55
std::function<void ()> task;
57
56
{
58
57
std::unique_lock<std::mutex> lock (td->queue_mutex );
59
- // if (!td->task_queue.empty()) {
60
- // VTR_LOG("Thread %zu has %zu tasks queued\n",
61
- // td->thread_id, td->task_queue.size());
62
- // }
63
-
58
+
64
59
td->cv .wait (lock, [td]() {
65
60
return td->stop || !td->task_queue .empty ();
66
61
});
67
-
62
+
68
63
if (td->stop && td->task_queue .empty ()) {
69
- // VTR_LOG("Thread %zu stopping after completing %zu tasks\n",
70
- // td->thread_id, td->tasks_completed);
71
64
return ;
72
65
}
73
-
66
+
74
67
task = std::move (td->task_queue .front ());
75
68
td->task_queue .pop ();
76
69
}
77
-
70
+
78
71
vtr::Timer task_timer;
79
72
task ();
80
- td->tasks_completed ++;
81
- // VTR_LOG("Thread %zu completed task %zu in %.3f seconds\n",
82
- // td->thread_id, td->tasks_completed, task_timer.elapsed_sec());
83
73
}
84
74
});
85
-
75
+
86
76
threads.push_back (std::move (thread_data));
87
77
}
88
- // VTR_LOG("Thread pool initialization completed in %.3f seconds\n",
89
- // pool_timer.elapsed_sec());
90
78
}
91
79
92
- // Schedule work and get future for result
93
80
template <typename F>
94
81
void schedule_work (F&& f) {
95
82
active_tasks++;
96
83
97
- // Round-robin thread assignment
84
+ /* Round-robin thread assignment */
98
85
size_t thread_idx = (next_thread++) % threads.size ();
99
86
auto thread_data = threads[thread_idx].get ();
100
87
size_t task_id = ++total_tasks_queued;
101
88
102
- // VTR_LOG("Scheduling task %zu to thread %zu\n", task_id, thread_data->thread_id);
103
-
104
- // Wrap the work with task completion tracking
105
89
auto task = [this , f = std::forward<F>(f), thread_id = thread_data->thread_id , task_id]() {
106
90
vtr::Timer task_timer;
107
- // VTR_LOG("Thread %zu starting task %zu\n", thread_id, task_id);
108
-
91
+
109
92
try {
110
93
f ();
111
- // VTR_LOG("Thread %zu completed task %zu successfully in %.3f seconds\n",
112
- // thread_id, task_id, task_timer.elapsed_sec());
113
94
} catch (const std::exception & e) {
114
95
VTR_LOG_ERROR (" Thread %zu failed task %zu with error: %s\n " ,
115
96
thread_id, task_id, e.what ());
@@ -120,19 +101,16 @@ class thread_pool {
120
101
throw ;
121
102
}
122
103
123
- // Track task completion
124
104
size_t remaining = --active_tasks;
125
105
if (remaining == 0 ) {
126
106
completion_cv.notify_all ();
127
107
}
128
108
};
129
109
130
- // Queue the task
110
+ /* Queue new task */
131
111
{
132
112
std::lock_guard<std::mutex> lock (thread_data->queue_mutex );
133
113
thread_data->task_queue .push (std::move (task));
134
- // VTR_LOG("Task %zu queued to thread %zu (queue size: %zu)\n",
135
- // task_id, thread_data->thread_id, thread_data->task_queue.size());
136
114
}
137
115
thread_data->cv .notify_one ();
138
116
}
@@ -143,26 +121,18 @@ class thread_pool {
143
121
}
144
122
145
123
~thread_pool () {
146
- // VTR_LOG("Shutting down thread pool after %.3f seconds, processed %zu total tasks\n",
147
- // pool_timer.elapsed_sec(), total_tasks_queued.load());
148
-
149
- // Signal all threads to stop
124
+ /* Stop all threads */
150
125
for (auto & thread_data : threads) {
151
126
{
152
127
std::lock_guard<std::mutex> lock (thread_data->queue_mutex );
153
128
thread_data->stop = true ;
154
- // VTR_LOG("Signaling thread %zu to stop (remaining tasks: %zu)\n",
155
- // thread_data->thread_id, thread_data->task_queue.size());
156
129
}
157
130
thread_data->cv .notify_one ();
158
131
}
159
132
160
- // Join all threads
161
133
for (auto & thread_data : threads) {
162
134
if (thread_data->thread .joinable ()) {
163
135
thread_data->thread .join ();
164
- // VTR_LOG("Thread %zu joined after completing %zu tasks\n",
165
- // thread_data->thread_id, thread_data->tasks_completed);
166
136
}
167
137
}
168
138
}
0 commit comments