Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion flexkv/kvtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class KVTask:

# batch: points to the batch task id if this task was merged into a batch
batch_task_id: Optional[int] = None
# pending sub-task count: number of merged sub-tasks not yet released; the batch task is released once this drops to 0
pending_sub_count: int = 0

def is_completed(self) -> bool:
    """Report whether this task has reached a terminal status.

    Returns:
        True when ``self.status`` is COMPLETED, CANCELLED or FAILED,
        False otherwise.
    """
    terminal_statuses = [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED]
    return self.status in terminal_statuses
Expand Down Expand Up @@ -354,6 +356,8 @@ def _update_tasks(self, timeout: float = 0.001) -> None:
task.op_callback_dict[completed_op.op_id]()

def _cancel_task(self, task_id: int) -> None:
if task_id not in self.tasks:
return
task = self.tasks[task_id]
if task.is_completed():
flexkv_logger.warning(f"Task {task_id} is already completed, cannot cancel")
Expand All @@ -365,7 +369,7 @@ def _cancel_task(self, task_id: int) -> None:
flexkv_logger.warning(f"Task {task_id} is already cancelled, cannot cancel")
return
task.status = TaskStatus.CANCELLED
self.graph_to_task.pop(task.graph.graph_id, None)
self._release_task(task_id)

def check_completed(self, task_id: int, completely: bool = False) -> bool:
task = self.tasks[task_id]
Expand Down Expand Up @@ -411,6 +415,19 @@ def check_task_ready(self, task_id: int) -> TransferOpGraph:
task.status = TaskStatus.RUNNING
return task.graph

def _release_task(self, task_id: int) -> None:
    """Drop a task's bookkeeping entries, cascading to its batch task.

    Removes the task from ``self.tasks`` and its graph from
    ``self.graph_to_task``. If the task was merged into a batch task that
    is still tracked, that batch task's ``pending_sub_count`` is
    decremented, and the batch task is released in turn once the count
    reaches zero or below. Unknown task ids are ignored.

    Args:
        task_id: id of the task to release.
    """
    current_id = task_id
    # Walk up the task -> batch-task chain instead of recursing.
    while current_id in self.tasks:
        entry = self.tasks[current_id]
        parent_id = entry.batch_task_id
        self.graph_to_task.pop(entry.graph.graph_id, None)
        del self.tasks[current_id]

        if parent_id is None or parent_id not in self.tasks:
            return
        parent = self.tasks[parent_id]
        parent.pending_sub_count -= 1
        if parent.pending_sub_count > 0:
            # Other sub-tasks still reference the batch task; keep it.
            return
        current_id = parent_id

def _mark_completed(self, task_id: int) -> None:
task = self.tasks[task_id]
if task.is_completed():
Expand Down Expand Up @@ -601,6 +618,8 @@ def _wait_impl(self,
task_id=task_id,
return_mask=self.tasks[task_id].return_mask
)
if self.tasks[effective_id].is_completed():
self._release_task(task_id)
break
elif only_return_finished:
break
Expand Down Expand Up @@ -835,6 +854,7 @@ def merge_to_batch_kvtask(self,
op_callback_dict=op_callback_dict,
)
self.graph_to_task[batch_task_graph.graph_id] = batch_id
self.tasks[batch_id].pending_sub_count = len(task_ids)
for task_id in task_ids:
self.graph_to_task.pop(self.tasks[task_id].graph.graph_id, None)
self.tasks[task_id].batch_task_id = batch_id
Expand Down
Loading