Skip to content

Commit

Permalink
Refactoring agent.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Default2882 committed Nov 21, 2024
1 parent 6e6f558 commit 11f50f5
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 170 deletions.
282 changes: 134 additions & 148 deletions python-sdk/indexify/executor/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from . import image_dependency_installer
from .api_objects import ExecutorMetadata, Task
from .downloader import DownloadedInputs, Downloader
from .executor_tasks import DownloadTask, RunFunctionTask, TaskEnum
from .runtime_probes import ProbeInfo, RuntimeProbes
from .task_store import CompletedTask, TaskStore

Expand Down Expand Up @@ -64,7 +65,8 @@ def __init__(
image_version: Optional[int] = None,
):
event_loop = asyncio.get_event_loop()
event_loop.set_default_executor(ThreadPoolExecutor(max_workers=num_workers))
self._thread_pool = ThreadPoolExecutor(max_workers=num_workers)
event_loop.set_default_executor(self._thread_pool)
self.name_alias = name_alias
self.image_version = image_version
self._config_path = config_path
Expand Down Expand Up @@ -117,12 +119,13 @@ def __init__(
)

async def task_launcher(self):
async_tasks: List[asyncio.Task | asyncio.Future] = []
async_tasks: List[asyncio.Task] = []
fn_queue: List[FunctionInput] = []

async_tasks.append(
asyncio.create_task(
self._task_store.get_runnable_tasks(), name="get_runnable_tasks"
self._task_store.get_runnable_tasks(),
name=TaskEnum.GET_RUNNABLE_TASK.value,
)
)

Expand All @@ -132,54 +135,19 @@ async def task_launcher(self):
task: Task = self._task_store.get_task(fn.task_id)

if self._executor_bootstrap_failed:
completed_task = CompletedTask(
task=task,
outputs=[],
task_outcome="failure",
)
self._task_store.complete(outcome=completed_task)

self.mark_task_as_failed(task)
continue

# Bootstrap this executor. Fail the task if we can't.
if self._require_image_bootstrap:
try:
image_info = await _get_image_info_for_compute_graph(
task, self._protocol, self._server_addr, self._config_path
)
image_dependency_installer.executor_image_builder(
image_info, self.name_alias, self.image_version
)
self._require_image_bootstrap = False
except Exception as e:
console.print(
Text("Failed to bootstrap the executor ", style="red bold")
+ Text(f"Exception: {traceback.format_exc()}", style="red")
)

self._executor_bootstrap_failed = True

completed_task = CompletedTask(
task=task,
outputs=[],
task_outcome="failure",
)
self._task_store.complete(outcome=completed_task)

if not self._try_bootstrap(task):
continue

code_path = f"{self._code_path}/{task.namespace}/{task.compute_graph}.{task.graph_version}"
async_tasks.append(
self._function_worker.run_function(
task, fn.input, fn.init_value, code_path
)
# ExtractTask(
# function_worker=self._function_worker,
# task=task,
# input=fn.input,
# code_path=f"{self._code_path}/{task.namespace}/{task.compute_graph}.{task.graph_version}",
# init_value=fn.init_value,
# )
)

fn_queue = []
Expand All @@ -189,121 +157,139 @@ async def task_launcher(self):

async_tasks: List[asyncio.Task] = list(pending)
for async_task in done:
if async_task.get_name() == "get_runnable_tasks":
if async_task.exception():
console.print(
Text("Task Launcher Error: ", style="red bold")
+ Text(
f"Failed to get runnable tasks: {async_task.exception()}",
style="red",
task_name = TaskEnum.from_value(async_task.get_name())
match task_name:
case TaskEnum.GET_RUNNABLE_TASK:
if async_task.exception():
self._console_log_exception(async_task)
continue
result: Dict[str, Task] = await async_task
task: Task
for _, task in result.items():
async_tasks.append(
self._downloader.download(
task, TaskEnum.DOWNLOAD_GRAPH_TASK
)
)
)
continue
result: Dict[str, Task] = await async_task
task: Task
for _, task in result.items():
async_tasks.append(
self._downloader.download(task, "download_graph")
)
async_tasks.append(
asyncio.create_task(
self._task_store.get_runnable_tasks(),
name="get_runnable_tasks",
)
)
elif async_task.get_name() == "download_graph":
if async_task.exception():
console.print(
Text(
f"Failed to download graph for task {async_task.task.id}\n",
style="red bold",
asyncio.create_task(
self._task_store.get_runnable_tasks(),
name=TaskEnum.GET_RUNNABLE_TASK.value,
)
+ Text(f"Exception: {async_task.exception()}", style="red")
)
completed_task = CompletedTask(
task=async_task.task,
outputs=[],
task_outcome="failure",
)
self._task_store.complete(outcome=completed_task)
continue
async_tasks.append(
self._downloader.download(async_task.task, "download_input")
)
elif async_task.get_name() == "download_input":
if async_task.exception():
console.print(
Text(
f"Failed to download input for task {async_task.task.id}\n",
style="red bold",
case TaskEnum.DOWNLOAD_GRAPH_TASK:
async_task: DownloadTask
if async_task.exception():
self._console_log_exception(async_task)
self.mark_task_as_failed(async_task.task)
continue
async_tasks.append(
self._downloader.download(
async_task.task, TaskEnum.DOWNLOAD_INPUT_TASK
)
+ Text(f"Exception: {async_task.exception()}", style="red")
)
completed_task = CompletedTask(
task=async_task.task,
outputs=[],
task_outcome="failure",
)
self._task_store.complete(outcome=completed_task)
continue
downloaded_inputs: DownloadedInputs = await async_task
task: Task = async_task.task
fn_queue.append(
FunctionInput(
task_id=task.id,
namespace=task.namespace,
compute_graph=task.compute_graph,
function=task.compute_fn,
input=downloaded_inputs.input,
init_value=downloaded_inputs.init_value,
)
)
elif async_task.get_name() == "run_function":
if async_task.exception():
completed_task = CompletedTask(
task=async_task.task,
task_outcome="failure",
outputs=[],
stderr=str(async_task.exception()),
)
self._task_store.complete(outcome=completed_task)
continue
async_task: ExtractTask
try:
outputs: FunctionWorkerOutput = await async_task
if not outputs.success:
task_outcome = "failure"
else:
task_outcome = "success"

completed_task = CompletedTask(
task=async_task.task,
task_outcome=task_outcome,
outputs=outputs.fn_outputs,
router_output=outputs.router_output,
stdout=outputs.stdout,
stderr=outputs.stderr,
reducer=outputs.reducer,
)
self._task_store.complete(outcome=completed_task)
except BrokenProcessPool:
self._task_store.retriable_failure(async_task.task.id)
continue
except Exception as e:
console.print(
Text(
f"Failed to execute task {async_task.task.id}\n",
style="red bold",
case TaskEnum.DOWNLOAD_INPUT_TASK:
async_task: DownloadTask
if async_task.exception():
self._console_log_exception(async_task)
self.mark_task_as_failed(async_task.task)
continue
downloaded_inputs: DownloadedInputs = await async_task
task: Task = async_task.task
fn_queue.append(
FunctionInput(
task_id=task.id,
namespace=task.namespace,
compute_graph=task.compute_graph,
function=task.compute_fn,
input=downloaded_inputs.input,
init_value=downloaded_inputs.init_value,
)
+ Text(f"Exception: {e}", style="red")
)
completed_task = CompletedTask(
task=async_task.task,
task_outcome="failure",
outputs=[],
case TaskEnum.RUN_FUNCTION_TASK:
async_task: RunFunctionTask
if async_task.exception():
self.mark_task_as_failed(
async_task.task, str(async_task.exception())
)
continue
try:
outputs: FunctionWorkerOutput = await async_task
if not outputs.success:
task_outcome = "failure"
else:
task_outcome = "success"

completed_task = CompletedTask(
task=async_task.task,
task_outcome=task_outcome,
outputs=outputs.fn_outputs,
router_output=outputs.router_output,
stdout=outputs.stdout,
stderr=outputs.stderr,
reducer=outputs.reducer,
)
self._task_store.complete(outcome=completed_task)
except BrokenProcessPool:
self._task_store.retriable_failure(async_task.task.id)
continue
except Exception as e:
console.print(
Text(
f"Failed to execute task {async_task.task.id}\n",
style="red bold",
)
+ Text(f"Exception: {e}", style="red")
)
completed_task = CompletedTask(
task=async_task.task,
task_outcome="failure",
outputs=[],
)
self._task_store.complete(outcome=completed_task)
continue
case _:
raise ValueError(
f"'{async_task.get_name()}' is not a valid task name."
)
self._task_store.complete(outcome=completed_task)
continue

def _console_log_exception(self, async_task: asyncio.Task):
console.print(
Text("Task Launcher Error: ", style="red bold")
+ Text(
f"Failed to get runnable tasks: {async_task.exception()}",
style="red",
)
)

def mark_task_as_failed(self, task: Task, stderr: str = None):
completed_task = CompletedTask(
task=task,
outputs=[],
task_outcome="failure",
stderr=stderr,
)
self._task_store.complete(outcome=completed_task)

def _try_bootstrap(self, task: Task) -> bool:
try:
image_info = _get_image_info_for_compute_graph(
task, self._protocol, self._server_addr, self._config_path
)
image_dependency_installer.executor_image_builder(
image_info, self.name_alias, self.image_version
)
self._require_image_bootstrap = False
return True
except Exception as e:
console.print(
Text("Failed to bootstrap the executor ", style="red bold")
+ Text(f"Exception: {traceback.format_exc()}", style="red")
)

self._executor_bootstrap_failed = True
self.mark_task_as_failed(task)
return False

async def run(self):
console.print("Starting Extractor Agent...", style="green")
Expand Down Expand Up @@ -395,15 +381,15 @@ def to_sentence_case(snake_str):
async def _shutdown(self, loop):
console.print(Text("shutting down agent...", style="bold yellow"))
self._should_run = False
self._thread_pool.shutdown(cancel_futures=True)
for task in asyncio.all_tasks(loop):
task.cancel()

def shutdown(self, loop):
self._function_worker.shutdown()
loop.create_task(self._shutdown(loop))


async def _get_image_info_for_compute_graph(
def _get_image_info_for_compute_graph(
task: Task, protocol, server_addr, config_path: str
) -> ImageInformation:
namespace = task.namespace
Expand Down
Loading

0 comments on commit 11f50f5

Please sign in to comment.