diff --git a/docs/source/changelog.md b/docs/source/changelog.md index 21dd0bfc..f4133e04 100644 --- a/docs/source/changelog.md +++ b/docs/source/changelog.md @@ -1,5 +1,9 @@ # Changelog +## 0.9.9 (XXXX-XX-XX) +- Support using subclasses of lists, tuple, and dict via deep_map. This allows usage of NamedTuple for more + semantically obvious return types. + ## 0.9.8 (2024-09-06) - Bugfix for `inputs` argument for `flow.run()`. diff --git a/src/pydiverse/pipedag/util/deep_map.py b/src/pydiverse/pipedag/util/deep_map.py index 5c4aef9d..8770ad4c 100644 --- a/src/pydiverse/pipedag/util/deep_map.py +++ b/src/pydiverse/pipedag/util/deep_map.py @@ -3,6 +3,7 @@ Heavily inspired by the builtin copy module of python: https://github.com/python/cpython/blob/main/Lib/copy.py """ + from __future__ import annotations from typing import Callable @@ -19,13 +20,11 @@ def deep_map(x, fn: Callable, memo=None): if y is not _nil: return y - cls = type(x) - - if cls == list: + if isinstance(x, list): y = _deep_map_list(x, fn, memo) - elif cls == tuple: + elif isinstance(x, tuple): y = _deep_map_tuple(x, fn, memo) - elif cls == dict: + elif isinstance(x, dict): y = _deep_map_dict(x, fn, memo) else: y = fn(x) diff --git a/tests/test_flows/test_flow.py b/tests/test_flows/test_flow.py index 3921c840..34bfe7aa 100644 --- a/tests/test_flows/test_flow.py +++ b/tests/test_flows/test_flow.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import NamedTuple import pandas as pd import sqlalchemy as sa @@ -46,9 +47,14 @@ def list_arg(x: list[pd.DataFrame]): return Blob(x) +class Blobs(NamedTuple): + a: Blob + b: Blob + + @materialize def blob_task(x, y): - return Blob(x), Blob(y) + return Blobs(Blob(x), Blob(y)) def test_simple_flow(with_blob=True):