From b523dd0965388eead2baec998e8b609c5d960ebf Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 11 Mar 2025 15:05:26 +0800 Subject: [PATCH] Add name and dag_id to asset decorators This allows the user to provide a custom asset name or dag_id to the underlying asset or DAG, instead of always bounded to the function name. This provides feature parity to the dag decorator. --- .../sdk/definitions/asset/decorators.py | 16 +++++++----- .../definitions/test_asset_decorators.py | 25 +++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 77ab57074bbfe..e6ce95f986410 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -97,7 +97,7 @@ class AssetDefinition(Asset): _source: asset def __attrs_post_init__(self) -> None: - with self._source.create_dag(dag_id=self.name): + with self._source.create_dag(default_dag_id=self.name): _AssetMainOperator.from_definition(self) @@ -117,7 +117,7 @@ class MultiAssetDefinition(BaseAsset): _source: asset.multi def __attrs_post_init__(self) -> None: - with self._source.create_dag(dag_id=self._function.__name__): + with self._source.create_dag(default_dag_id=self._function.__name__): _AssetMainOperator.from_definition(self) def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: @@ -153,7 +153,8 @@ class _DAGFactory: schedule: ScheduleArg is_paused_upon_creation: bool | None = None - display_name: str | None = None + dag_id: str | None = None + dag_display_name: str | None = None description: str | None = None params: ParamsDict | None = None @@ -163,15 +164,16 @@ class _DAGFactory: access_control: dict[str, dict[str, Collection[str]]] | None = None owner_links: dict[str, str] | None = None - def create_dag(self, *, dag_id: str) -> DAG: + def create_dag(self, *, default_dag_id: str) -> DAG: from airflow.models.dag import DAG # TODO: Use the SDK DAG when it works. + dag_id = self.dag_id or default_dag_id return DAG( dag_id=dag_id, schedule=self.schedule, is_paused_upon_creation=self.is_paused_upon_creation, catchup=False, - dag_display_name=self.display_name or dag_id, + dag_display_name=self.dag_display_name or dag_id, description=self.description, params=self.params, on_success_callback=self.on_success_callback, @@ -184,6 +186,7 @@ def create_dag(self, *, dag_id: str) -> DAG: class asset(_DAGFactory): """Create an asset by decorating a materialization function.""" + name: str | None = None uri: str | ObjectStoragePath | None = None group: str = Asset.asset_type extra: dict[str, Any] = attrs.field(factory=dict) @@ -203,8 +206,9 @@ def __call__(self, f: Callable) -> MultiAssetDefinition: return MultiAssetDefinition(function=f, source=self) def __call__(self, f: Callable) -> AssetDefinition: - if (name := f.__name__) != f.__qualname__: + if f.__name__ != f.__qualname__: raise ValueError("nested function not supported") + name = self.name or f.__name__ return AssetDefinition( name=name, uri=name if self.uri is None else str(self.uri), diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py index 4467397f07244..b29c050bc98b2 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py @@ -106,6 +106,23 @@ def test_with_invalid_asset_name(self, example_asset_func): assert err.value.args[0].startswith("prohibited name for asset: ") + @pytest.mark.parametrize( + "provided_uri, expected_uri", + [ + pytest.param(None, "custom", id="default-uri"), + pytest.param("s3://bucket/object", "s3://bucket/object", id="custom-uri"), + ], + ) + def test_custom_name(self, example_asset_func, provided_uri, expected_uri): + asset_definition = asset(name="custom", uri=provided_uri, schedule=None)(example_asset_func) + assert asset_definition.name == "custom" + assert asset_definition.uri == expected_uri + + def test_custom_dag_id(self, example_asset_func): + asset_definition = asset(name="asset", dag_id="dag", schedule=None)(example_asset_func) + assert asset_definition.name == "asset" + assert asset_definition._source.dag_id == "dag" + class TestAssetMultiDecorator: def test_multi_asset(self, example_asset_func): @@ -118,6 +135,14 @@ def test_multi_asset(self, example_asset_func): assert definition._source.schedule is None assert definition._source.outlets == [Asset(name="a"), Asset(name="b")] + def test_multi_custom_dag_id(self, example_asset_func): + definition = asset.multi( + dag_id="custom", + schedule=None, + outlets=[Asset(name="a"), Asset(name="b")], + )(example_asset_func) + assert definition._source.dag_id == "custom" + class TestAssetDefinition: @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")