Skip to content

Commit a951814

Browse files
committed
Implement SemLock with eventfd where available (3.10+/Linux)
Otherwise SemLock is semaphore-based and blocks when RLock eventfd-based solution doesn't work with `forkserver` fixes #16
1 parent 02fe299 commit a951814

File tree

6 files changed

+352
-8
lines changed

6 files changed

+352
-8
lines changed

.github/workflows/build.yml

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
- '3.6'
2626
- 'pypy-3.7'
2727
- 'pypy-3.8'
28+
# - 'pypy-3.9'
2829
with-venv:
2930
- 'true'
3031
- 'false'

.idea/geventmp.iml

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/integrationtest/python/_mp_test_gevent.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818

1919
from gevent import spawn, monkey
2020
from gevent.util import assert_switches
21-
from geventmp.monkey import GEVENT_SAVED_MODULE_SETTINGS
2221
from multiprocessing.util import get_logger
2322

23+
from geventmp.monkey import GEVENT_SAVED_MODULE_SETTINGS
24+
2425
logger = get_logger()
2526

2627

@@ -72,3 +73,34 @@ def count():
7273
task.kill()
7374
logger.info("exiting")
7475
sys.exit(10)
76+
77+
78+
def test_joinable_queues(r_q, w_q):
79+
def count():
80+
while True:
81+
sleep(0.01)
82+
83+
task = spawn(count)
84+
task.start()
85+
86+
with assert_switches():
87+
sleep(1)
88+
89+
with assert_switches():
90+
logger.info(r_q.get(timeout=5))
91+
92+
with assert_switches():
93+
r_q.task_done()
94+
95+
with assert_switches():
96+
sleep(1)
97+
98+
with assert_switches():
99+
w_q.put(test_queues.__name__, timeout=5)
100+
101+
with assert_switches():
102+
sleep(1)
103+
104+
task.kill()
105+
logger.info("exiting")
106+
sys.exit(10)

src/integrationtest/python/monkey_mp_tests.py

+53-4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
if not getattr(current_process(), "_inheriting", False):
2121
monkey.patch_all()
2222

23-
# from multiprocessing.util import log_to_stderr
24-
# log_to_stderr(1)
23+
# from multiprocessing.util import log_to_stderr
24+
# log_to_stderr(1)
2525

2626
from unittest import TestCase, main
2727
import trace
@@ -59,8 +59,11 @@ def test_mp_queues_fork(self):
5959
def test_mp_queues_spawn(self):
6060
self.run_test_mp_queues("spawn", _mp_test_gevent.test_queues)
6161

62-
def test_mp_queues_forkserver(self):
63-
self.run_test_mp_queues("forkserver", _mp_test.test_queues)
62+
def test_mp_jqueues_fork(self):
63+
self.run_test_mp_jqueues("fork", _mp_test_gevent.test_queues)
64+
65+
def test_mp_jqueues_spawn(self):
66+
self.run_test_mp_jqueues("spawn", _mp_test_gevent.test_queues)
6467

6568
def test_mp_no_args_fork(self):
6669
self.run_test_mp_no_args("fork", _mp_test_gevent.test_no_args)
@@ -118,6 +121,16 @@ def run_test_mp_queues(self, context, func, do_trace=False):
118121
else:
119122
self._test_mp_queues(p, r_q, w_q)
120123

124+
def run_test_mp_jqueues(self, context, func, do_trace=False):
125+
ctx = mp.get_context(context)
126+
r_q = ctx.JoinableQueue()
127+
w_q = ctx.JoinableQueue()
128+
p = ctx.Process(target=func, args=(w_q, r_q))
129+
if do_trace:
130+
trace.Trace(count=0).runfunc(self._test_mp_jqueues, p, r_q, w_q)
131+
else:
132+
self._test_mp_jqueues(p, r_q, w_q)
133+
121134
def _test_mp_queues(self, p, r_q, w_q):
122135
async_counter = [0]
123136

@@ -151,6 +164,42 @@ def count():
151164
logger.info("Async counter counted to %d" % async_counter[0])
152165
self.assertGreater(async_counter[0], 0)
153166

167+
def _test_mp_jqueues(self, p, r_q, w_q):
168+
async_counter = [0]
169+
170+
def count():
171+
while True:
172+
idle()
173+
async_counter[0] += 1
174+
sleep(0.001)
175+
176+
task = spawn(count)
177+
178+
p.start()
179+
self.assertTrue(p.pid > 0)
180+
181+
with assert_switches():
182+
w_q.put("master", timeout=5)
183+
with assert_switches():
184+
self.assertEqual(r_q.get(timeout=5), "test_queues")
185+
r_q.task_done()
186+
187+
with assert_switches():
188+
start = clock()
189+
p.join(15)
190+
end = clock()
191+
logger.info("Waited for child to die for %f" % (end - start))
192+
task.kill()
193+
194+
# This is to ensure a greenlet flip
195+
sleep(0.001)
196+
197+
logger.info(f"checking {p} is alive")
198+
self.assertFalse(p.is_alive())
199+
self.assertEqual(p.exitcode, 10)
200+
logger.info("Async counter counted to %d" % async_counter[0])
201+
self.assertGreater(async_counter[0], 0)
202+
154203

155204
if __name__ == '__main__':
156205
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2022 Karellen, Inc. and contributors
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from multiprocessing.process import current_process
17+
18+
from gevent import monkey
19+
20+
if not getattr(current_process(), "_inheriting", False):
21+
monkey.patch_all()
22+
23+
# from multiprocessing.util import log_to_stderr
24+
# log_to_stderr(1)
25+
26+
from unittest import TestCase, main
27+
28+
from geventmp.monkey import GEVENT_SAVED_MODULE_SETTINGS
29+
30+
import multiprocessing as mp
31+
from multiprocessing.util import get_logger
32+
import sys
33+
from time import monotonic
34+
35+
logger = get_logger()
36+
37+
38+
class TestSynchronizers(TestCase):
39+
def setUp(self):
40+
self.assertTrue(monkey.saved[GEVENT_SAVED_MODULE_SETTINGS].get("geventmp"),
41+
"GeventMP patch has not run!")
42+
self.tearDown()
43+
44+
def tearDown(self):
45+
sys.stdout.flush()
46+
sys.stderr.flush()
47+
logger.info("=====================")
48+
sys.stdout.flush()
49+
50+
def test_semaphore_fork(self):
51+
self._test_semaphore("fork")
52+
53+
def test_semaphore_spawn(self):
54+
self._test_semaphore("spawn")
55+
56+
def test_semaphore_forkserver(self):
57+
self._test_semaphore("forkserver")
58+
59+
def test_bounded_semaphore_fork(self):
60+
self._test_bounded_semaphore("fork")
61+
62+
def test_bounded_semaphore_spawn(self):
63+
self._test_bounded_semaphore("spawn")
64+
65+
def test_bounded_semaphore_forkserver(self):
66+
self._test_bounded_semaphore("forkserver")
67+
68+
def _test_semaphore(self, ctx_name):
69+
ctx = mp.get_context(ctx_name)
70+
s = ctx.Semaphore()
71+
self.assertEqual(s.get_value(), 1)
72+
self.assertEqual(s._semlock._count(), 0)
73+
repr(s)
74+
75+
s.release()
76+
self.assertEqual(s.get_value(), 2)
77+
self.assertEqual(s._semlock._count(), -1)
78+
repr(s)
79+
80+
self.assertTrue(s.acquire())
81+
self.assertEqual(s.get_value(), 1)
82+
self.assertEqual(s._semlock._count(), 0)
83+
repr(s)
84+
85+
self.assertTrue(s.acquire())
86+
self.assertEqual(s.get_value(), 0)
87+
self.assertEqual(s._semlock._count(), 1)
88+
repr(s)
89+
90+
self.assertFalse(s.acquire(block=False))
91+
self.assertEqual(s.get_value(), 0)
92+
self.assertEqual(s._semlock._count(), 1)
93+
repr(s)
94+
95+
start = monotonic()
96+
self.assertFalse(s.acquire(timeout=1))
97+
duration = monotonic() - start
98+
self.assertGreater(duration, 1)
99+
self.assertLess(duration, 3)
100+
s.release()
101+
self.assertEqual(s.get_value(), 1)
102+
self.assertEqual(s._semlock._count(), 0)
103+
104+
def _test_bounded_semaphore(self, ctx_name):
105+
ctx = mp.get_context(ctx_name)
106+
s = ctx.BoundedSemaphore(2)
107+
repr(s)
108+
109+
self.assertEqual(s.get_value(), 2)
110+
self.assertEqual(s._semlock._count(), 0)
111+
112+
self.assertTrue(s.acquire())
113+
self.assertEqual(s.get_value(), 1)
114+
self.assertEqual(s._semlock._count(), 1)
115+
repr(s)
116+
117+
self.assertTrue(s.acquire())
118+
self.assertEqual(s.get_value(), 0)
119+
self.assertEqual(s._semlock._count(), 2)
120+
repr(s)
121+
122+
self.assertFalse(s.acquire(block=False))
123+
self.assertEqual(s.get_value(), 0)
124+
self.assertEqual(s._semlock._count(), 2)
125+
repr(s)
126+
127+
start = monotonic()
128+
self.assertFalse(s.acquire(timeout=1))
129+
duration = monotonic() - start
130+
self.assertGreater(duration, 1)
131+
self.assertLess(duration, 3)
132+
s.release()
133+
self.assertEqual(s.get_value(), 1)
134+
self.assertEqual(s._semlock._count(), 1)
135+
s.release()
136+
self.assertEqual(s.get_value(), 2)
137+
self.assertEqual(s._semlock._count(), 0)
138+
with self.assertRaises(ValueError):
139+
s.release()
140+
141+
142+
if __name__ == '__main__':
143+
main()

0 commit comments

Comments
 (0)