Skip to content

Commit b2b211f

Browse files
committed
feat: add single flight and apply to QueryRegionsProvider
1 parent b0c2a7f commit b2b211f

File tree

3 files changed

+122
-1
lines changed

3 files changed

+122
-1
lines changed

qiniu/http/regions_provider.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .region import Region, ServiceName
1616
from .default_client import qn_http_client
1717
from .middleware import RetryDomainsMiddleware
18+
from .single_flight import SingleFlight
1819

1920

2021
class RegionsProvider:
@@ -70,6 +71,9 @@ def _get_region_from_query(data, **kwargs):
7071
)
7172

7273

74+
_query_regions_single_flight = SingleFlight()
75+
76+
7377
class QueryRegionsProvider(RegionsProvider):
7478
def __init__(
7579
self,
@@ -95,7 +99,15 @@ def __init__(
9599
self.max_retry_times_per_endpoint = max_retry_times_per_endpoint
96100

97101
def __iter__(self):
98-
regions = self.__fetch_regions()
102+
endpoints_md5 = io_md5([
103+
to_bytes(e.host) for e in self.endpoints_provider
104+
])
105+
flight_key = ':'.join([
106+
endpoints_md5,
107+
self.access_key,
108+
self.bucket_name
109+
])
110+
regions = _query_regions_single_flight.do(flight_key, self.__fetch_regions)
99111
# change to `yield from` when min version of python update to >= 3.3
100112
for r in regions:
101113
yield r

qiniu/http/single_flight.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import threading
2+
3+
4+
class _FlightLock:
5+
"""
6+
Do not use dataclass which caused the event created only once
7+
"""
8+
def __init__(self):
9+
self.event = threading.Event()
10+
self.result = None
11+
self.error = None
12+
13+
14+
class SingleFlight:
15+
def __init__(self):
16+
self._locks = {}
17+
self._lock = threading.Lock()
18+
19+
def do(self, key, fn, *args, **kwargs):
20+
# here does not use `with` statement
21+
# because need to wait by another object if it exists,
22+
# and reduce the `acquire` times if it not exists
23+
self._lock.acquire()
24+
if key in self._locks:
25+
flight_lock = self._locks[key]
26+
27+
self._lock.release()
28+
flight_lock.event.wait()
29+
30+
if flight_lock.error:
31+
raise flight_lock.error
32+
return flight_lock.result
33+
34+
flight_lock = _FlightLock()
35+
self._locks[key] = flight_lock
36+
self._lock.release()
37+
38+
try:
39+
flight_lock.result = fn(*args, **kwargs)
40+
except Exception as e:
41+
flight_lock.error = e
42+
finally:
43+
flight_lock.event.set()
44+
45+
with self._lock:
46+
del self._locks[key]
47+
48+
if flight_lock.error:
49+
raise flight_lock.error
50+
return flight_lock.result
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
import time
3+
from multiprocessing.pool import ThreadPool
4+
5+
from qiniu.http.single_flight import SingleFlight
6+
7+
class TestSingleFlight:
8+
def test_single_flight_success(self):
9+
sf = SingleFlight()
10+
11+
def fn():
12+
return "result"
13+
14+
result = sf.do("key1", fn)
15+
assert result == "result"
16+
17+
def test_single_flight_exception(self):
18+
sf = SingleFlight()
19+
20+
def fn():
21+
raise ValueError("error")
22+
23+
with pytest.raises(ValueError, match="error"):
24+
sf.do("key2", fn)
25+
26+
def test_single_flight_concurrent(self):
27+
sf = SingleFlight()
28+
share_state = []
29+
results = []
30+
31+
def fn():
32+
time.sleep(1)
33+
share_state.append('share_state')
34+
return "result"
35+
36+
def worker(_n):
37+
result = sf.do("key3", fn)
38+
results.append(result)
39+
40+
ThreadPool(2).map(worker, range(5))
41+
42+
assert len(share_state) == 3
43+
assert all(result == "result" for result in results)
44+
45+
def test_single_flight_different_keys(self):
46+
sf = SingleFlight()
47+
results = []
48+
49+
def fn():
50+
time.sleep(1)
51+
return "result"
52+
53+
def worker(n):
54+
result = sf.do("key{}".format(n), fn)
55+
results.append(result)
56+
57+
ThreadPool(2).map(worker, range(2))
58+
assert len(results) == 2
59+
assert all(result == "result" for result in results)

0 commit comments

Comments
 (0)