Skip to content

Commit

Permalink
Add name and dag_id to asset decorators
Browse files Browse the repository at this point in the history
This allows the user to provide a custom asset name or dag_id to the
underlying asset or DAG, instead of always bounded to the function name.
This provides feature parity to the dag decorator.
  • Loading branch information
uranusjr committed Mar 11, 2025
1 parent 637525c commit b523dd0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
16 changes: 10 additions & 6 deletions task-sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class AssetDefinition(Asset):
_source: asset

def __attrs_post_init__(self) -> None:
with self._source.create_dag(dag_id=self.name):
with self._source.create_dag(default_dag_id=self.name):
_AssetMainOperator.from_definition(self)


Expand All @@ -117,7 +117,7 @@ class MultiAssetDefinition(BaseAsset):
_source: asset.multi

def __attrs_post_init__(self) -> None:
with self._source.create_dag(dag_id=self._function.__name__):
with self._source.create_dag(default_dag_id=self._function.__name__):
_AssetMainOperator.from_definition(self)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
Expand Down Expand Up @@ -153,7 +153,8 @@ class _DAGFactory:
schedule: ScheduleArg
is_paused_upon_creation: bool | None = None

display_name: str | None = None
dag_id: str | None = None
dag_display_name: str | None = None
description: str | None = None

params: ParamsDict | None = None
Expand All @@ -163,15 +164,16 @@ class _DAGFactory:
access_control: dict[str, dict[str, Collection[str]]] | None = None
owner_links: dict[str, str] | None = None

def create_dag(self, *, dag_id: str) -> DAG:
def create_dag(self, *, default_dag_id: str) -> DAG:
from airflow.models.dag import DAG # TODO: Use the SDK DAG when it works.

dag_id = self.dag_id or default_dag_id
return DAG(
dag_id=dag_id,
schedule=self.schedule,
is_paused_upon_creation=self.is_paused_upon_creation,
catchup=False,
dag_display_name=self.display_name or dag_id,
dag_display_name=self.dag_display_name or dag_id,
description=self.description,
params=self.params,
on_success_callback=self.on_success_callback,
Expand All @@ -184,6 +186,7 @@ def create_dag(self, *, dag_id: str) -> DAG:
class asset(_DAGFactory):
"""Create an asset by decorating a materialization function."""

name: str | None = None
uri: str | ObjectStoragePath | None = None
group: str = Asset.asset_type
extra: dict[str, Any] = attrs.field(factory=dict)
Expand All @@ -203,8 +206,9 @@ def __call__(self, f: Callable) -> MultiAssetDefinition:
return MultiAssetDefinition(function=f, source=self)

def __call__(self, f: Callable) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
if f.__name__ != f.__qualname__:
raise ValueError("nested function not supported")
name = self.name or f.__name__
return AssetDefinition(
name=name,
uri=name if self.uri is None else str(self.uri),
Expand Down
25 changes: 25 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,23 @@ def test_with_invalid_asset_name(self, example_asset_func):

assert err.value.args[0].startswith("prohibited name for asset: ")

@pytest.mark.parametrize(
"provided_uri, expected_uri",
[
pytest.param(None, "custom", id="default-uri"),
pytest.param("s3://bucket/object", "s3://bucket/object", id="custom-uri"),
],
)
def test_custom_name(self, example_asset_func, provided_uri, expected_uri):
asset_definition = asset(name="custom", uri=provided_uri, schedule=None)(example_asset_func)
assert asset_definition.name == "custom"
assert asset_definition.uri == expected_uri

def test_custom_dag_id(self, example_asset_func):
asset_definition = asset(name="asset", dag_id="dag", schedule=None)(example_asset_func)
assert asset_definition.name == "asset"
assert asset_definition._source.dag_id == "dag"


class TestAssetMultiDecorator:
def test_multi_asset(self, example_asset_func):
Expand All @@ -118,6 +135,14 @@ def test_multi_asset(self, example_asset_func):
assert definition._source.schedule is None
assert definition._source.outlets == [Asset(name="a"), Asset(name="b")]

def test_multi_custom_dag_id(self, example_asset_func):
definition = asset.multi(
dag_id="custom",
schedule=None,
outlets=[Asset(name="a"), Asset(name="b")],
)(example_asset_func)
assert definition._source.dag_id == "custom"


class TestAssetDefinition:
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
Expand Down

0 comments on commit b523dd0

Please sign in to comment.