
Commit 34340cd

Merge branch 'feature/mfr-worker' into feature/buff-worms
Closes #390
2 parents 1b3510e + ecfe0e9


10 files changed (+263, -2 lines)

mfr/__init__.py

Whitespace-only changes.

mfr/tasks/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
from mfr.tasks.app import app
from mfr.tasks.render import render
from mfr.tasks.core import celery_task
from mfr.tasks.core import backgrounded
from mfr.tasks.core import wait_on_celery
from mfr.tasks.exceptions import WaitTimeOutError

__all__ = [
    'app',
    'render',
    'celery_task',
    'backgrounded',
    'wait_on_celery',
    'WaitTimeOutError',
]

mfr/tasks/app.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import logging

from celery import Celery
from celery.signals import task_failure

import sentry_sdk
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.logging import LoggingIntegration

from mfr.settings import config
from mfr.version import __version__
from mfr.tasks import settings as tasks_settings

logger = logging.getLogger(__name__)

app = Celery()
app.config_from_object(tasks_settings)


def register_signal():
    """Adapted from `raven.contrib.celery.register_signal`. Remove args and
    kwargs from logs so that keys aren't leaked to Sentry.
    """
    def process_failure_signal(sender, task_id, *args, **kwargs):
        scope = sentry_sdk.get_current_scope()
        scope.set_tag('task_id', task_id)
        scope.set_tag('task', sender)
        sentry_sdk.capture_exception()

    task_failure.connect(process_failure_signal, weak=False)


sentry_dsn = config.get_nullable('SENTRY_DSN', None)
if sentry_dsn:
    sentry_logging = LoggingIntegration(
        level=logging.INFO,  # Capture INFO level and above as breadcrumbs
        event_level=None,    # Do not send logs of any level as events
    )
    sentry_sdk.init(
        sentry_dsn,
        release=__version__,
        integrations=[CeleryIntegration(), sentry_logging]
    )
    register_signal()
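
Note that the failure handler tags only the task id and the sending task on the Sentry scope; it deliberately does not forward the task's args or kwargs, which may contain provider keys. A minimal sketch of how a failure flows through this wiring (the `boom` task is hypothetical and for illustration only; events are only sent when `SENTRY_DSN` is configured):

    from mfr.tasks.app import app

    @app.task
    def boom():
        # hypothetical task used to exercise the failure path
        raise ValueError('expected failure')

    # When a worker executes `boom`, celery emits `task_failure`; the
    # connected handler tags the event with the task id and task object,
    # then captures the active ValueError -- without the task's args/kwargs.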

mfr/tasks/core.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import os
import pickle
import asyncio
import logging
import functools

from celery.backends.base import DisabledBackend

from mfr.tasks.app import app
from mfr.tasks import settings
from mfr.tasks import exceptions

logger = logging.getLogger(__name__)


def ensure_event_loop():
    """Ensure the existence of an event loop.

    Useful for contexts where get_event_loop() may raise an exception.

    :returns: The new event loop
    :rtype: BaseEventLoop
    """
    try:
        return asyncio.get_event_loop()
    except (AssertionError, RuntimeError):
        asyncio.set_event_loop(asyncio.new_event_loop())

    # Note: No clever tricks are used here to DRY up this code.
    # This avoids an infinite loop if setting the event loop ever fails.
    return asyncio.get_event_loop()


def __coroutine_unwrapper(func):
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        return ensure_event_loop().run_until_complete(func(*args, **kwargs))
    wrapped.as_async = func
    return wrapped


async def backgrounded(func, *args, **kwargs):
    """Run the given function with the given arguments in
    a background thread.
    """
    loop = asyncio.get_event_loop()
    if asyncio.iscoroutinefunction(func):
        func = __coroutine_unwrapper(func)

    return (await loop.run_in_executor(
        None,  # None uses the default executor, ThreadPoolExecutor
        functools.partial(func, *args, **kwargs)
    ))


def backgroundify(func):
    @functools.wraps(func)
    async def wrapped(*args, **kwargs):
        return await backgrounded(func, *args, **kwargs)
    return wrapped


def adhoc_file_backend(func, was_bound=False, basepath=None):
    basepath = basepath or settings.ADHOC_BACKEND_PATH

    @functools.wraps(func)
    def wrapped(task, *args, **kwargs):
        if was_bound:
            args = (task,) + args

        try:
            result = func(*args, **kwargs)
        except Exception as e:
            result = e

        with open(os.path.join(basepath, task.request.id), 'wb') as result_file:
            pickle.dump(result, result_file)

        if isinstance(result, Exception):
            raise result
        return result
    return wrapped


def celery_task(func, *args, **kwargs):
    """A wrapper around Celery.task. When the wrapped method is called, it is
    dispatched with Celery's Task.delay and run in a background thread.

    If the celery result backend is disabled, the task is wrapped in a function
    that writes its result to disk using the pickle serialization protocol.
    """
    task_func = __coroutine_unwrapper(func)

    if isinstance(app.backend, DisabledBackend):
        task_func = adhoc_file_backend(
            task_func,
            was_bound=kwargs.pop('bind', False)
        )
        kwargs['bind'] = True

    logger.debug(f'celery_task: task_func:({task_func})')

    task = app.task(task_func, **kwargs)
    task.adelay = backgroundify(task.delay)

    return task


@backgroundify
async def wait_on_celery(result, interval=None, timeout=None, basepath=None):
    timeout = timeout or settings.WAIT_TIMEOUT
    interval = interval or settings.WAIT_INTERVAL
    basepath = basepath or settings.ADHOC_BACKEND_PATH

    waited = 0

    while True:
        if isinstance(app.backend, DisabledBackend):
            try:
                with open(os.path.join(basepath, result.id), 'rb') as result_file:
                    data = pickle.load(result_file)
                if isinstance(data, Exception):
                    raise data
                return data
            except FileNotFoundError:
                pass
        else:
            if result.ready():
                if result.failed():
                    raise result.result
                return result.result

        if waited > timeout:
            raise exceptions.WaitTimeOutError
        await asyncio.sleep(interval)
        waited += interval
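
`backgrounded` is what lets coroutine-based server code hand blocking work to a thread without stalling the event loop. A minimal usage sketch (the `checksum` coroutine and its helper are hypothetical, not part of this diff):

    import hashlib

    from mfr.tasks import backgrounded

    async def checksum(path):
        def read_and_hash():
            # plain blocking call; runs on the default ThreadPoolExecutor
            with open(path, 'rb') as fp:
                return hashlib.sha256(fp.read()).hexdigest()
        return await backgrounded(read_and_hash)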

mfr/tasks/exceptions.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
class MfrTaskError(Exception):
    pass


class WaitTimeOutError(MfrTaskError):
    pass

mfr/tasks/render.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
import logging

from mfr.tasks import core

logger = logging.getLogger(__name__)


@core.celery_task
async def render(*args, **kwargs):
    logger.critical(f'Received task with {args=} and {kwargs=}')
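
Because `render` is built by `core.celery_task`, callers get both the registered task and the awaitable plumbing around it. A minimal round-trip sketch (the surrounding handler coroutine is hypothetical):

    from mfr.tasks import render, wait_on_celery, WaitTimeOutError

    async def handle_render(*args, **kwargs):
        result = await render.adelay(*args, **kwargs)  # Task.delay, run in a thread
        try:
            # polls every WAIT_INTERVAL seconds until the result is ready,
            # raising after WAIT_TIMEOUT seconds
            return await wait_on_celery(result)
        except WaitTimeOutError:
            raise  # hypothetical fallback; real handling is up to the caller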

mfr/tasks/settings.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
import os

from kombu import Queue, Exchange

from mfr import settings


config = settings.child('TASKS_CONFIG')

WAIT_TIMEOUT = int(config.get('WAIT_TIMEOUT', 20))
WAIT_INTERVAL = float(config.get('WAIT_INTERVAL', 0.5))
ADHOC_BACKEND_PATH = config.get('ADHOC_BACKEND_PATH', '/tmp')

broker_url = config.get(
    'BROKER_URL',
    'amqp://{}:{}//'.format(
        os.environ.get('RABBITMQ_PORT_5672_TCP_ADDR', ''),
        os.environ.get('RABBITMQ_PORT_5672_TCP_PORT', ''),
    )
)

task_default_queue = config.get('CELERY_DEFAULT_QUEUE', 'mfr')
task_queues = (
    Queue('mfr', Exchange('mfr'), routing_key='mfr'),
)

task_always_eager = config.get_bool('CELERY_ALWAYS_EAGER', False)
result_backend = config.get_nullable('CELERY_RESULT_BACKEND', 'rpc://')
result_persistent = config.get_bool('CELERY_RESULT_PERSISTENT', True)
worker_disable_rate_limits = config.get_bool('CELERY_DISABLE_RATE_LIMITS', True)
result_expires = int(config.get('CELERY_TASK_RESULT_EXPIRES', 60))
task_create_missing_queues = config.get_bool('CELERY_CREATE_MISSING_QUEUES', False)
task_acks_late = True
worker_hijack_root_logger = False
task_eager_propagates = True

imports = [
    'mfr.tasks.render',
]
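
When `BROKER_URL` is not set in config, the default is assembled from Docker-link style environment variables. An illustration with hypothetical values:

    # RABBITMQ_PORT_5672_TCP_ADDR=172.17.0.2
    # RABBITMQ_PORT_5672_TCP_PORT=5672
    #   => broker_url == 'amqp://172.17.0.2:5672//'
    # With neither variable set, the fallback collapses to 'amqp://://',
    # which is not a usable broker URL, so one of the two sources must be set.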

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ openpyxl = "^3.1"
 
 waterbutler = { git = "https://github.com/CenterForOpenScience/waterbutler.git", branch = "feature/buff-worms" }
 markupsafe = "2.0.1"
+celery = "5.5.0"
 
 [tool.poetry.group.dev]
 optional = true

tasks.py

Lines changed: 12 additions & 0 deletions
@@ -58,3 +58,15 @@ def server(ctx):
 
     from mfr.server.app import serve
     serve()
+
+@task
+def celery(ctx, loglevel='INFO', hostname='%h', concurrency=None):
+    from mfr.tasks.app import app
+    command = ['worker']
+    if loglevel:
+        command.extend(['--loglevel', loglevel])
+    if hostname:
+        command.extend(['--hostname', hostname])
+    if concurrency:
+        command.extend(['--concurrency', concurrency])
+    app.worker_main(command)
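
For a concrete sense of the argv this task builds (hypothetical invocation values):

    # `invoke celery --loglevel DEBUG --concurrency 4` ends up calling
    #   app.worker_main(['worker', '--loglevel', 'DEBUG',
    #                    '--hostname', '%h', '--concurrency', '4'])
    # roughly equivalent to running the `celery worker` CLI against mfr.tasks.app.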
