From b13ed892b1fcc98b1bc9896dd4b2273a21121705 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Wed, 25 Dec 2024 16:52:31 +0900
Subject: [PATCH] Fail a task if an inlet or outlet asset is inactive or an inactive asset is added to an asset alias (#44831)

* feat(taskinstance): fail a task if its outlets contain an inactive asset
* test(taskinstance): activate assets in test cases
* test(taskinstance): add test cases test_run_with_inactive_assets_in_the_same_dag and test_run_with_inactive_assets_in_different_dags
* feat(taskinstance): fail a task if an asset in its inlets is not active
* refactor(taskinstance): rework warning message
* feat(asset_alias): block adding asset events to assets that cannot be active
* feat(asset-alias): handle the case where an asset is not active but might be activated when added
* feat(taskinstance): refactor name_uri as asset key
* refactor(exceptions): move error message logic to a customized exception
* test(taskinstance): add test case test_outlet_asset_alias_asset_inactive
* refactor(taskinstance): remove AssetUniqueKey.to_tuple and use attrs.astuple instead
---
 airflow/exceptions.py             |  31 ++++
 airflow/models/taskinstance.py    |  45 ++++-
 tests/models/test_taskinstance.py | 274 ++++++++++++++++++++++++++----
 tests/utils/test_context.py       |  17 +-
 4 files changed, 328 insertions(+), 39 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 4035488cf87e1..5e32c00c7d5da 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -22,6 +22,7 @@
 from __future__ import annotations
 
 import warnings
+from collections.abc import Collection
 from datetime import timedelta
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, NamedTuple
@@ -32,6 +33,7 @@
     from collections.abc import Sized
 
     from airflow.models import DagRun
+    from airflow.sdk.definitions.asset import AssetUniqueKey
 
 
 class AirflowException(Exception):
@@ -111,6 +113,35 @@ class AirflowFailException(AirflowException):
     """Raise when the task should be failed without retrying."""
 
 
+class AirflowExecuteWithInactiveAssetException(AirflowFailException):
+    """Raise when the task is executed with inactive assets."""
+
+    def __init__(self, inactive_asset_unique_keys: Collection[AssetUniqueKey]) -> None:
+        self.inactive_asset_unique_keys = inactive_asset_unique_keys
+
+    @property
+    def inactive_assets_error_msg(self):
+        return ", ".join(
+            f'Asset(name="{key.name}", uri="{key.uri}")' for key in self.inactive_asset_unique_keys
+        )
+
+
+class AirflowInactiveAssetInInletOrOutletException(AirflowExecuteWithInactiveAssetException):
+    """Raise when the task is executed with inactive assets in its inlet or outlet."""
+
+    def __str__(self) -> str:
+        return f"Task has the following inactive assets in its inlets or outlets: {self.inactive_assets_error_msg}"
+
+
+class AirflowInactiveAssetAddedToAssetAliasException(AirflowExecuteWithInactiveAssetException):
+    """Raise when inactive assets are added to an asset alias."""
+
+    def __str__(self) -> str:
+        return (
+            f"The following assets accessed by an AssetAlias are inactive: {self.inactive_assets_error_msg}"
+        )
+
+
 class AirflowOptionalProviderFeatureException(AirflowException):
     """Raise by providers when imports are missing for optional provider features."""
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 9bd8e8da12ac6..6ef4452834fe3 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -35,6 +35,7 @@
 from typing import TYPE_CHECKING, Any, Callable
 from urllib.parse import quote
 
+import attrs
 import dill
 import jinja2
 import lazy_object_proxy
@@ -79,6 +80,8 @@
 from airflow.exceptions import (
     AirflowException,
     AirflowFailException,
+    AirflowInactiveAssetAddedToAssetAliasException,
+    AirflowInactiveAssetInInletOrOutletException,
     AirflowRescheduleException,
     AirflowSensorTimeout,
     AirflowSkipException,
@@ -91,7 +94,7 @@
     XComForMappingNotPushed,
 )
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.asset import AssetEvent, AssetModel
+from airflow.models.asset import AssetActive, AssetEvent, AssetModel
 from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
 from airflow.models.dagbag import DagBag
 from airflow.models.log import Log
@@ -263,6 +266,7 @@ def _run_raw_task(
         context = ti.get_template_context(ignore_param_exceptions=False, session=session)
 
         try:
+            ti._validate_inlet_outlet_assets_activeness(session=session)
             if not mark_success:
                 TaskInstance._execute_task_with_callbacks(
                     self=ti,  # type: ignore[arg-type]
@@ -2749,16 +2753,24 @@ def _register_asset_changes_int(
                 frozen_extra = frozenset(asset_alias_event.extra.items())
                 asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name)
 
+        asset_unique_keys = {key for key, _ in asset_alias_names}
         asset_models: dict[AssetUniqueKey, AssetModel] = {
             AssetUniqueKey.from_asset(asset_obj): asset_obj
             for asset_obj in session.scalars(
                 select(AssetModel).where(
                     tuple_(AssetModel.name, AssetModel.uri).in_(
-                        (key.name, key.uri) for key, _ in asset_alias_names
+                        attrs.astuple(key) for key in asset_unique_keys
                     )
                 )
             )
         }
+        inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys(
+            asset_unique_keys={key for key in asset_unique_keys if key in asset_models},
+            session=session,
+        )
+        if inactive_asset_unique_keys:
+            raise AirflowInactiveAssetAddedToAssetAliasException(inactive_asset_unique_keys)
+
         if missing_assets := [
             asset_unique_key.to_asset()
             for asset_unique_key, _ in asset_alias_names
@@ -3642,6 +3654,35 @@ def duration_expression_update(
             }
         )
 
+    def _validate_inlet_outlet_assets_activeness(self, session: Session) -> None:
+        if not self.task or not (self.task.outlets or self.task.inlets):
+            return
+
+        all_asset_unique_keys = {
+            AssetUniqueKey.from_asset(inlet_or_outlet)
+            for inlet_or_outlet in itertools.chain(self.task.inlets, self.task.outlets)
+            if isinstance(inlet_or_outlet, Asset)
+        }
+        inactive_asset_unique_keys = self._get_inactive_asset_unique_keys(all_asset_unique_keys, session)
+        if inactive_asset_unique_keys:
+            raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys)
+
+    @staticmethod
+    def _get_inactive_asset_unique_keys(
+        asset_unique_keys: set[AssetUniqueKey], session: Session
+    ) -> set[AssetUniqueKey]:
+        active_asset_unique_keys = {
+            AssetUniqueKey(name, uri)
+            for name, uri in session.execute(
+                select(AssetActive.name, AssetActive.uri).where(
+                    tuple_in_condition(
+                        (AssetActive.name, AssetActive.uri), [attrs.astuple(key) for key in asset_unique_keys]
+                    )
+                )
+            )
+        }
+        return asset_unique_keys - active_asset_unique_keys
+
 
 def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
     """Given two operators, find their innermost common mapped task group."""
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 244fa2ed5bf3a..73f5908b707cf 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -43,6 +43,8 @@
 from airflow.exceptions import (
     AirflowException,
     AirflowFailException,
+    AirflowInactiveAssetAddedToAssetAliasException,
+    AirflowInactiveAssetInInletOrOutletException,
     AirflowRescheduleException,
     AirflowSensorTimeout,
     AirflowSkipException,
@@ -51,7 +53,8 @@
     UnmappableXComTypePushed,
     XComForMappingNotPushed,
 )
-from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel
+from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
+from airflow.models.asset import AssetActive, AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG
 from airflow.models.dagbag import DagBag
@@ -2289,6 +2292,11 @@ def test_outlet_assets(self, create_task_instance):
         dagbag = DagBag(dag_folder=example_assets.__file__)
         dagbag.collect_dags(only_if_updated=False, safe_mode=False)
         dagbag.sync_to_db(session=session)
+
+        asset_models = session.scalars(select(AssetModel)).all()
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
         run_id = str(uuid4())
         dr = DagRun(dag1.dag_id, run_id=run_id, run_type="anything")
         session.merge(dr)
@@ -2402,6 +2410,11 @@ def test_outlet_assets_skipped(self):
         dagbag = DagBag(dag_folder=test_assets.__file__)
         dagbag.collect_dags(only_if_updated=False, safe_mode=False)
         dagbag.sync_to_db(session=session)
+
+        asset_models = session.scalars(select(AssetModel)).all()
+        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
+        session.flush()
+
         run_id = str(uuid4())
         dr = DagRun(dag_with_skip_task.dag_id, run_id=run_id, run_type="anything")
         session.merge(dr)
@@ -2419,10 +2432,11 @@ def test_outlet_assets_skipped(self):
         # check that no asset events were generated
         assert session.query(AssetEvent).count() == 0
 
+    @pytest.mark.want_activate_assets(True)
     def test_outlet_asset_extra(self, dag_maker, session):
         from airflow.sdk.definitions.asset import Asset
 
-        with dag_maker(schedule=None, session=session) as dag:
+        with dag_maker(schedule=None, serialized=True, session=session):
 
             @task(outlets=Asset("test_outlet_asset_extra_1"))
             def write1(*, outlet_events):
@@ -2442,7 +2456,6 @@ def _write2_post_execute(context, _):
 
         dr: DagRun = dag_maker.create_dagrun()
         for ti in dr.get_task_instances(session=session):
-            ti.refresh_from_task(dag.get_task(ti.task_id))
             ti.run(session=session)
 
         events = dict(iter(session.execute(select(AssetEvent.source_task_id, AssetEvent))))
@@ -2460,10 +2473,11 @@ def _write2_post_execute(context, _):
         assert events["write2"].asset.uri == "test_outlet_asset_extra_2"
         assert events["write2"].extra == {"x": 1}
 
+    @pytest.mark.want_activate_assets(True)
     def test_outlet_asset_extra_ignore_different(self, dag_maker, session):
         from airflow.sdk.definitions.asset import Asset
 
-        with dag_maker(schedule=None, session=session):
+        with dag_maker(schedule=None, serialized=True, session=session):
 
             @task(outlets=Asset("test_outlet_asset_extra"))
             def write(*, outlet_events):
@@ -2481,11 +2495,12 @@ def write(*, outlet_events):
         assert event.source_task_id == "write"
         assert event.extra == {"one": 1}
 
+    @pytest.mark.want_activate_assets(True)
     def test_outlet_asset_extra_yield(self, dag_maker, session):
         from airflow.sdk.definitions.asset import Asset
         from airflow.sdk.definitions.asset.metadata import Metadata
 
-        with dag_maker(schedule=None, session=session) as dag:
+        with dag_maker(schedule=None, serialized=True, session=session):
 
             @task(outlets=Asset("test_outlet_asset_extra_1"))
             def write1():
@@ -2507,7 +2522,6 @@ def _write2_post_execute(context, result):
 
         dr: DagRun = dag_maker.create_dagrun()
for ti in dr.get_task_instances(session=session): - ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) xcom = session.scalars( @@ -2532,17 +2546,18 @@ def _write2_post_execute(context, result): assert events["write2"].asset.name == "test_outlet_asset_extra_2" assert events["write2"].extra == {"x": 1} + @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset, AssetAlias asset_uri = "test_outlet_asset_alias_test_case_ds" alias_name_1 = "test_outlet_asset_alias_test_case_asset_alias_1" - ds1 = AssetModel(id=1, uri=asset_uri) - session.add(ds1) + asm = AssetModel(id=1, uri=asset_uri) + session.add_all([asm, AssetActive.for_asset(asm.to_public())]) session.commit() - with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: + with dag_maker(dag_id="producer_dag", schedule=None, serialized=True, session=session): @task(outlets=AssetAlias(alias_name_1)) def producer(*, outlet_events): @@ -2553,7 +2568,6 @@ def producer(*, outlet_events): dr: DagRun = dag_maker.create_dagrun() for ti in dr.get_task_instances(session=session): - ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) producer_events = session.execute( @@ -2580,6 +2594,7 @@ def producer(*, outlet_events): assert len(asset_alias_obj.assets) == 1 assert asset_alias_obj.assets[0].uri == asset_uri + @pytest.mark.want_activate_assets(True) def test_outlet_multiple_asset_alias(self, dag_maker, session): from airflow.sdk.definitions.asset import Asset, AssetAlias @@ -2588,11 +2603,11 @@ def test_outlet_multiple_asset_alias(self, dag_maker, session): asset_alias_name_2 = "test_outlet_maa_asset_alias_2" asset_alias_name_3 = "test_outlet_maa_asset_alias_3" - ds1 = AssetModel(id=1, uri=asset_uri) - session.add(ds1) + asm = AssetModel(id=1, uri=asset_uri) + session.add_all([asm, AssetActive.for_asset(asm.to_public())]) session.commit() - with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: + with dag_maker(dag_id="producer_dag", schedule=None, serialized=True, session=session): @task( outlets=[ @@ -2611,7 +2626,6 @@ def producer(*, outlet_events): dr: DagRun = dag_maker.create_dagrun() for ti in dr.get_task_instances(session=session): - ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) producer_events = session.execute( @@ -2653,6 +2667,7 @@ def producer(*, outlet_events): assert len(asset_alias_obj.assets) == 1 assert asset_alias_obj.assets[0].uri == asset_uri + @pytest.mark.want_activate_assets(True) def test_outlet_asset_alias_through_metadata(self, dag_maker, session): from airflow.sdk.definitions.asset import AssetAlias from airflow.sdk.definitions.asset.metadata import Metadata @@ -2660,11 +2675,11 @@ def test_outlet_asset_alias_through_metadata(self, dag_maker, session): asset_uri = "test_outlet_asset_alias_through_metadata_ds" asset_alias_name = "test_outlet_asset_alias_through_metadata_asset_alias" - ds1 = AssetModel(id=1, uri="test_outlet_asset_alias_through_metadata_ds") - session.add(ds1) + asm = AssetModel(id=1, uri="test_outlet_asset_alias_through_metadata_ds") + session.add_all([asm, AssetActive.for_asset(asm)]) session.commit() - with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag: + with dag_maker(dag_id="producer_dag", schedule=None, serialized=True, session=session): @task(outlets=AssetAlias(asset_alias_name)) def producer(*, outlet_events): @@ -2675,7 +2690,6 @@ def producer(*, outlet_events): dr: DagRun = 
         for ti in dr.get_task_instances(session=session):
-            ti.refresh_from_task(dag.get_task(ti.task_id))
             ti.run(session=session)
 
         producer_event = session.scalar(select(AssetEvent).where(AssetEvent.source_task_id == "producer"))
@@ -2697,13 +2711,14 @@ def producer(*, outlet_events):
         assert len(asset_alias_obj.assets) == 1
         assert asset_alias_obj.assets[0].uri == asset_uri
 
+    @pytest.mark.want_activate_assets(True)
     def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session):
         from airflow.sdk.definitions.asset import Asset, AssetAlias
 
         asset_alias_name = "test_outlet_asset_alias_asset_not_exists_asset_alias"
         asset_uri = "did_not_exists"
 
-        with dag_maker(dag_id="producer_dag", schedule=None, session=session) as dag:
+        with dag_maker(dag_id="producer_dag", schedule=None, serialized=True, session=session):
 
             @task(outlets=AssetAlias(asset_alias_name))
             def producer(*, outlet_events):
@@ -2714,7 +2729,6 @@ def producer(*, outlet_events):
 
         dr: DagRun = dag_maker.create_dagrun()
         for ti in dr.get_task_instances(session=session):
-            ti.refresh_from_task(dag.get_task(ti.task_id))
             ti.run(session=session)
 
         producer_event = session.scalar(select(AssetEvent).where(AssetEvent.source_task_id == "producer"))
@@ -2736,12 +2750,65 @@ def producer(*, outlet_events):
         assert len(asset_alias_obj.assets) == 1
         assert asset_alias_obj.assets[0].uri == asset_uri
 
+    def test_outlet_asset_alias_asset_inactive(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset, AssetAlias
+
+        asset_name = "did_not_exists"
+        asset = Asset(asset_name)
+        asset2 = Asset(asset_name, uri="test://asset")
+        asm = AssetModel.from_public(asset)
+        asm2 = AssetModel.from_public(asset2)
+        session.add_all([asm, asm2, AssetActive.for_asset(asm)])
+
+        asset_alias_name = "alias_with_inactive_asset"
+
+        with dag_maker(dag_id="producer_dag", schedule=None, session=session):
+
+            @task(outlets=AssetAlias(asset_alias_name))
+            def producer_without_inactive(*, outlet_events):
+                outlet_events[AssetAlias(asset_alias_name)].add(asset, extra={"key": "value"})
+
+            @task(outlets=AssetAlias(asset_alias_name))
+            def producer_with_inactive(*, outlet_events):
+                outlet_events[AssetAlias(asset_alias_name)].add(asset2, extra={"key": "value"})
+
+            producer_without_inactive() >> producer_with_inactive()
+
+        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
+        tis["producer_without_inactive"].run(session=session)
+        with pytest.raises(AirflowInactiveAssetAddedToAssetAliasException) as exc:
+            tis["producer_with_inactive"].run(session=session)
+
+        assert 'Asset(name="did_not_exists", uri="test://asset/")' in str(exc.value)
+
+        producer_event = session.scalar(
+            select(AssetEvent).where(AssetEvent.source_task_id == "producer_without_inactive")
+        )
+
+        assert producer_event.source_task_id == "producer_without_inactive"
+        assert producer_event.source_dag_id == "producer_dag"
+        assert producer_event.source_run_id == "test"
+        assert producer_event.source_map_index == -1
+        assert producer_event.asset.uri == asset_name
+        assert producer_event.extra == {"key": "value"}
+        assert len(producer_event.source_aliases) == 1
+        assert producer_event.source_aliases[0].name == asset_alias_name
+
+        asset_obj = session.scalar(select(AssetModel).where(AssetModel.uri == asset_name))
+        assert len(asset_obj.aliases) == 1
+        assert asset_obj.aliases[0].name == asset_alias_name
+
+        asset_alias_obj = session.scalar(select(AssetAliasModel))
+        assert len(asset_alias_obj.assets) == 1
+        assert asset_alias_obj.assets[0].uri == asset_name
+
+    @pytest.mark.want_activate_assets(True)
     def test_inlet_asset_extra(self, dag_maker, session):
         from airflow.sdk.definitions.asset import Asset
 
         read_task_evaluated = False
 
-        with dag_maker(schedule=None, session=session):
+        with dag_maker(schedule=None, serialized=True, session=session):
 
             @task(outlets=Asset("test_inlet_asset_extra"))
             def write(*, ti, outlet_events):
@@ -2791,21 +2858,22 @@ def read(*, inlet_events):
         assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis
         assert read_task_evaluated
 
+    @pytest.mark.want_activate_assets(True)
     def test_inlet_asset_alias_extra(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset, AssetAlias
+
         asset_uri = "test_inlet_asset_extra_ds"
         asset_alias_name = "test_inlet_asset_extra_asset_alias"
 
         asset_model = AssetModel(id=1, uri=asset_uri, group="asset")
         asset_alias_model = AssetAliasModel(name=asset_alias_name)
         asset_alias_model.assets.append(asset_model)
-        session.add_all([asset_model, asset_alias_model])
+        session.add_all([asset_model, asset_alias_model, AssetActive.for_asset(Asset(asset_uri))])
         session.commit()
 
-        from airflow.sdk.definitions.asset import Asset, AssetAlias
-
         read_task_evaluated = False
 
-        with dag_maker(schedule=None, session=session):
+        with dag_maker(schedule=None, serialized=True, session=session):
 
             @task(outlets=AssetAlias(asset_alias_name))
             def write(*, ti, outlet_events):
@@ -2878,6 +2946,7 @@ def read(*, inlet_events):
         # Should be done.
         assert not dr.task_instance_scheduling_decisions(session=session).schedulable_tis
 
+    @pytest.mark.want_activate_assets(True)
     @pytest.mark.parametrize(
         "slicer, expected",
         [
@@ -2894,7 +2963,7 @@ def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected):
 
         asset_uri = "test_inlet_asset_extra_slice"
 
-        with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, session=session):
+        with dag_maker(dag_id="write", serialized=True, schedule="@daily", params={"i": -1}, session=session):
 
             @task(outlets=Asset(asset_uri))
             def write(*, params, outlet_events):
@@ -2942,19 +3011,20 @@ def read(*, inlet_events):
             (lambda x: x[-5:5], []),
         ],
     )
+    @pytest.mark.want_activate_assets(True)
     def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expected):
+        from airflow.sdk.definitions.asset import Asset
+
         asset_uri = "test_inlet_asset_alias_extra_slice_ds"
         asset_alias_name = "test_inlet_asset_alias_extra_slice_asset_alias"
 
         asset_model = AssetModel(id=1, uri=asset_uri)
         asset_alias_model = AssetAliasModel(name=asset_alias_name)
         asset_alias_model.assets.append(asset_model)
-        session.add_all([asset_model, asset_alias_model])
+        session.add_all([asset_model, asset_alias_model, AssetActive.for_asset(Asset(asset_uri))])
         session.commit()
 
-        from airflow.sdk.definitions.asset import Asset
-
-        with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, session=session):
+        with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, serialized=True, session=session):
 
             @task(outlets=AssetAlias(asset_alias_name))
             def write(*, params, outlet_events):
@@ -2973,7 +3043,7 @@ def write(*, params, outlet_events):
 
         result = "the task does not run"
 
-        with dag_maker(dag_id="read", schedule=None, session=session):
+        with dag_maker(dag_id="read", schedule=None, serialized=True, session=session):
 
             @task(inlets=AssetAlias(asset_alias_name))
             def read(*, inlet_events):
@@ -4055,6 +4125,148 @@ def test_task_instance_history_is_created_when_ti_goes_for_retry(self, dag_maker
         assert session.query(TaskInstance).count() == 1
         assert session.query(TaskInstanceHistory).count() == 1
 
+    @pytest.mark.want_activate_assets(True)
+    def test_run_with_inactive_assets(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(outlets=Asset("asset_first"))
+            def first_asset_task(*, outlet_events):
+                outlet_events[Asset("asset_first")].extra = {"foo": "bar"}
+
+            first_asset_task()
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(inlets=Asset("asset_second"))
+            def asset_task_in_inlet():
+                pass
+
+            @task(outlets=Asset(name="asset_first", uri="test://asset"), inlets=Asset("asset_second"))
+            def duplicate_asset_task_in_outlet(*, outlet_events):
+                outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"}
+
+            duplicate_asset_task_in_outlet() >> asset_task_in_inlet()
+
+        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
+
+        tis["asset_task_in_inlet"].run(session=session)
+        with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc:
+            tis["duplicate_asset_task_in_outlet"].run(session=session)
+
+        assert 'Asset(name="asset_second", uri="asset_second")' in str(exc.value)
+        assert 'Asset(name="asset_first", uri="test://asset/")' in str(exc.value)
+
+    @pytest.mark.want_activate_assets(True)
+    def test_run_with_inactive_assets_in_outlets_within_the_same_dag(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(outlets=Asset("asset_first"))
+            def first_asset_task(*, outlet_events):
+                outlet_events[Asset("asset_first")].extra = {"foo": "bar"}
+
+            @task(outlets=Asset(name="asset_first", uri="test://asset"))
+            def duplicate_asset_task(*, outlet_events):
+                outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"}
+
+            first_asset_task() >> duplicate_asset_task()
+
+        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
+        tis["first_asset_task"].run(session=session)
+        with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc:
+            tis["duplicate_asset_task"].run(session=session)
+
+        assert str(exc.value) == (
+            "Task has the following inactive assets in its inlets or outlets: "
+            'Asset(name="asset_first", uri="test://asset/")'
+        )
+
+    @pytest.mark.want_activate_assets(True)
+    def test_run_with_inactive_assets_in_outlets_in_different_dag(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(outlets=Asset("asset_first"))
+            def first_asset_task(*, outlet_events):
+                outlet_events[Asset("asset_first")].extra = {"foo": "bar"}
+
+            first_asset_task()
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(outlets=Asset(name="asset_first", uri="test://asset"))
+            def duplicate_asset_task(*, outlet_events):
+                outlet_events[Asset(name="asset_first", uri="test://asset")].extra = {"foo": "bar"}
+
+            duplicate_asset_task()
+
+        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
+        with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc:
+            tis["duplicate_asset_task"].run(session=session)
+
+        assert str(exc.value) == (
+            "Task has the following inactive assets in its inlets or outlets: "
+            'Asset(name="asset_first", uri="test://asset/")'
+        )
+
+    @pytest.mark.want_activate_assets(True)
+    def test_run_with_inactive_assets_in_inlets_within_the_same_dag(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(inlets=Asset("asset_first"))
+            def first_asset_task():
+                pass
+
+            @task(inlets=Asset(name="asset_first", uri="test://asset"))
+            def duplicate_asset_task():
+                pass
+
+            first_asset_task() >> duplicate_asset_task()
+
+        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
+        with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc:
+            tis["first_asset_task"].run(session=session)
+
+        assert str(exc.value) == (
+            "Task has the following inactive assets in its inlets or outlets: "
+            'Asset(name="asset_first", uri="asset_first")'
+        )
+
+    @pytest.mark.want_activate_assets(True)
+    def test_run_with_inactive_assets_in_inlets_in_different_dag(self, dag_maker, session):
+        from airflow.sdk.definitions.asset import Asset
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(inlets=Asset("asset_first"))
+            def first_asset_task(*, outlet_events):
+                pass
+
+            first_asset_task()
+
+        with dag_maker(schedule=None, serialized=True, session=session):
+
+            @task(inlets=Asset(name="asset_first", uri="test://asset"))
+            def duplicate_asset_task(*, outlet_events):
+                pass
+
+            duplicate_asset_task()
+
+        tis = {ti.task_id: ti for ti in dag_maker.create_dagrun().task_instances}
+        with pytest.raises(AirflowInactiveAssetInInletOrOutletException) as exc:
+            tis["duplicate_asset_task"].run(session=session)
+
+        assert str(exc.value) == (
+            "Task has the following inactive assets in its inlets or outlets: "
+            'Asset(name="asset_first", uri="test://asset/")'
+        )
+
     @pytest.mark.parametrize("pool_override", [None, "test_pool2"])
     @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"])
diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py
index 3388a33845a4b..0046ca33cc4da 100644
--- a/tests/utils/test_context.py
+++ b/tests/utils/test_context.py
@@ -20,7 +20,7 @@
 
 import pytest
 
-from airflow.models.asset import AssetAliasModel, AssetModel
+from airflow.models.asset import AssetActive, AssetAliasModel, AssetModel
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
 from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors
 
@@ -42,9 +42,14 @@ class TestOutletEventAccessor:
             ),
         ),
     )
-    def test_add(self, key, asset_alias_events):
+    @pytest.mark.db_test
+    def test_add(self, key, asset_alias_events, session):
+        asset = Asset("test_uri")
+        session.add_all([AssetModel.from_public(asset), AssetActive.for_asset(asset)])
+        session.flush()
+
         outlet_event_accessor = OutletEventAccessor(key=key, extra={})
-        outlet_event_accessor.add(Asset("test_uri"))
+        outlet_event_accessor.add(asset)
         assert outlet_event_accessor.asset_alias_events == asset_alias_events
 
     @pytest.mark.db_test
@@ -65,11 +70,11 @@ def test_add(self, key, asset_alias_events):
         ),
     )
     def test_add_with_db(self, key, asset_alias_events, session):
-        asm = AssetModel(uri="test://asset-uri", name="test-asset", group="asset")
+        asset = Asset(uri="test://asset-uri", name="test-asset")
+        asm = AssetModel.from_public(asset)
         aam = AssetAliasModel(name="test_alias")
-        session.add_all([asm, aam])
+        session.add_all([asm, aam, AssetActive.for_asset(asset)])
         session.flush()
 
-        asset = Asset(uri="test://asset-uri", name="test-asset")
         outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
         outlet_event_accessor.add(asset, extra={})
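
Reviewer note (not part of the patch): the sketch below is a hypothetical DAG
illustrating how the new check surfaces to DAG authors. The dag id, asset
name, and URIs are invented; it assumes another DAG has already registered
and activated Asset(name="shared_dataset", uri="s3://bucket/old-key"), so the
same name under a different URI stays inactive.

    # Illustrative only -- names and URIs are not from this patch.
    from airflow.decorators import dag, task
    from airflow.sdk.definitions.asset import Asset

    # Same name as an already-active asset, but a different URI: this
    # reference is recorded as inactive and cannot be activated.
    conflicting = Asset(name="shared_dataset", uri="s3://bucket/new-key")


    @dag(schedule=None)
    def produce_shared_dataset():
        @task(outlets=conflicting)
        def produce(*, outlet_events):
            # With this patch, _validate_inlet_outlet_assets_activeness raises
            # AirflowInactiveAssetInInletOrOutletException from _run_raw_task,
            # before execute() is reached, naming the inactive asset.
            outlet_events[conflicting].extra = {"rows": 42}

        produce()


    produce_shared_dataset()

Because both new exceptions derive from AirflowFailException, the task is
failed outright rather than retried against an asset reference that cannot
become active.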