Skip to content

Commit

Permalink
fix async py workflow (#2046)
Browse files Browse the repository at this point in the history
  • Loading branch information
DatGuyJonathan authored Feb 18, 2025
1 parent d09e930 commit 44b7425
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from temporalio import activity
from dataclasses import dataclass
from typing import Optional, Any, Callable
import asyncio
import os
import sys
import json
Expand Down Expand Up @@ -58,7 +59,10 @@ async def dynamic_activity(execution_input: ScriptExecutionInput) -> WorkflowSte
# Pass the input data directly if it exists
input_data = execution_input.input_data if execution_input.input_data else {}
log.info(f"Processed input_data for task: {input_data}")
result = task_func(data=input_data)
if asyncio.iscoroutinefunction(task_func):
result = await task_func(data=input_data)
else:
result = task_func(data=input_data)

# Validate and encode result
if not isinstance(result, dict):
Expand Down
36 changes: 22 additions & 14 deletions packages/py-moose-lib/moose_lib/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from typing import TypeVar, Callable, Any
from typing import TypeVar, Callable, Any, Awaitable
from .commons import Logger
import asyncio

T = TypeVar('T')

Expand All @@ -11,20 +12,27 @@ def task(func: Callable[..., T] = None, *, retries: int = 3) -> Callable[..., T]
func: The function to decorate
retries: Number of times to retry the task if it fails
"""
def validate_result(result: Any) -> None:
"""Ensure proper return format"""
if not isinstance(result, dict):
raise ValueError("Task must return a dictionary with 'step' and 'data' keys")
if "step" not in result or "data" not in result:
raise ValueError("Task result must contain 'step' and 'data' keys")

def decorator(f: Callable[..., T]) -> Callable[..., T]:
@wraps(f)
def wrapper(*args, **kwargs) -> T:
result = f(*args, **kwargs)

# Ensure proper return format
if not isinstance(result, dict):
raise ValueError("Task must return a dictionary with 'step' and 'data' keys")

if "step" not in result or "data" not in result:
raise ValueError("Task result must contain 'step' and 'data' keys")

return result
if asyncio.iscoroutinefunction(f):
@wraps(f)
async def wrapper(*args, **kwargs) -> T:
result = await f(*args, **kwargs)
validate_result(result)
return result
else:
@wraps(f)
def wrapper(*args, **kwargs) -> T:
result = f(*args, **kwargs)
validate_result(result)
return result

# Add the markers to the wrapper
wrapper._is_moose_task = True
wrapper._retries = retries
Expand Down

0 comments on commit 44b7425

Please sign in to comment.