From dd468c4e384e07dbac6779d7ea32ecad7dc0bad6 Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 11 Nov 2024 17:54:02 +0100 Subject: [PATCH 1/3] Use Task.fuse --- dask_expr/_expr.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 739456a1..239159c5 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -3766,32 +3766,25 @@ def _broadcast_dep(self, dep: Expr): return dep.npartitions == 1 def _task(self, name: Key, index: int) -> Task: - internal_tasks = [] - seen_keys = set() - external_deps = set() + internal_tasks = {} for _expr in self.exprs: if self._broadcast_dep(_expr): subname = (_expr._name, 0) else: subname = (_expr._name, index) t = _expr._task(subname, subname[1]) + assert t.key == subname - internal_tasks.append(t) - seen_keys.add(subname) - external_deps.update(t.dependencies) - external_deps -= seen_keys - dependencies = {dep: TaskRef(dep) for dep in external_deps} - t = Task( - name, - Fused._execute_internal_graph, - # Wrap the actual subgraph as a data node such that the tasks are - # not erroneously parsed. The external task would otherwise carry - # the internal keys as dependencies which is not satisfiable - DataNode(None, internal_tasks), - dependencies, - (self.exprs[0]._name, index), - ) - return t + internal_tasks[t.key] = t + outkey = (self.exprs[0]._name, index) + work = [outkey] + internal_tasks_culled = [] + while work: + tkey = work.pop() + t = internal_tasks[tkey] + internal_tasks_culled.append(t) + work.extend(t.dependencies) + return Task.fuse(*internal_tasks_culled, key=name) @staticmethod def _execute_internal_graph(internal_tasks, dependencies, outkey): From 72868af1b6543e624a30ae0a2769b3627242a8a6 Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 11 Nov 2024 17:58:20 +0100 Subject: [PATCH 2/3] add a comment --- dask_expr/_expr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 239159c5..99502722 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -3776,11 +3776,15 @@ def _task(self, name: Key, index: int) -> Task: assert t.key == subname internal_tasks[t.key] = t + # The above is ambiguous and we have to cull to get an unambiguous graph outkey = (self.exprs[0]._name, index) work = [outkey] internal_tasks_culled = [] while work: tkey = work.pop() + if tkey not in internal_tasks: + # External dependency + continue t = internal_tasks[tkey] internal_tasks_culled.append(t) work.extend(t.dependencies) From 9c1303394e8baeb408ab48cc762ca5be89b256b9 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:42:20 +0100 Subject: [PATCH 3/3] Remove unnecessary culling --- dask_expr/_expr.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 0e996372..9f518911 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -3765,7 +3765,7 @@ def _broadcast_dep(self, dep: Expr): return dep.npartitions == 1 def _task(self, name: Key, index: int) -> Task: - internal_tasks = {} + internal_tasks = [] for _expr in self.exprs: if self._broadcast_dep(_expr): subname = (_expr._name, 0) @@ -3774,20 +3774,8 @@ def _task(self, name: Key, index: int) -> Task: t = _expr._task(subname, subname[1]) assert t.key == subname - internal_tasks[t.key] = t - # The above is ambiguous and we have to cull to get an unambiguous graph - outkey = (self.exprs[0]._name, index) - work = [outkey] - internal_tasks_culled = [] - while work: - tkey = work.pop() - if tkey not in internal_tasks: - # External dependency - continue - t = internal_tasks[tkey] - internal_tasks_culled.append(t) - work.extend(t.dependencies) - return Task.fuse(*internal_tasks_culled, key=name) + internal_tasks.append(t) + return Task.fuse(*internal_tasks, key=name) @staticmethod def _execute_internal_graph(internal_tasks, dependencies, outkey):