From 60cd5ad302d9140650160a89d86288f145118fb1 Mon Sep 17 00:00:00 2001
From: Amogh Desai
Date: Fri, 27 Dec 2024 19:40:10 +0530
Subject: [PATCH] Introducing get_run_data_interval on LazyDeserializedDAG
 (#45211)

---
 airflow/serialization/serialized_objects.py | 18 +++++++++--
 tests/dag_processing/test_collection.py     | 33 ++++++++++++++++++++-
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 1f43f7865d13d..3263f3dc5c320 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -41,7 +41,7 @@
 from airflow.exceptions import AirflowException, SerializationError, TaskDeferred
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.connection import Connection
-from airflow.models.dag import DAG
+from airflow.models.dag import DAG, _get_model_data_interval
 from airflow.models.expandinput import (
     EXPAND_INPUT_EMPTY,
     create_expand_input,
@@ -95,13 +95,14 @@
 if TYPE_CHECKING:
     from inspect import Parameter
 
+    from airflow.models import DagRun
     from airflow.models.baseoperatorlink import BaseOperatorLink
     from airflow.models.expandinput import ExpandInput
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.node import DAGNode
     from airflow.serialization.json_schema import Validator
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-    from airflow.timetables.base import DagRunInfo, Timetable
+    from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
 
 HAS_KUBERNETES: bool
 try:
@@ -1960,6 +1961,19 @@ def get_task_assets(
                 if isinstance(obj, of_type):
                     yield task["task_id"], obj
 
+    def get_run_data_interval(self, run: DagRun) -> DataInterval:
+        """Get the data interval of this run."""
+        if run.dag_id is not None and run.dag_id != self.dag_id:
+            raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}")
+
+        data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end")
+        # The older implementation called infer_automated_data_interval when data_interval
+        # was None; do we want to keep that behavior here, or raise an exception?
+        if data_interval is None:
+            raise ValueError(f"Cannot calculate data interval for run {run}")
+
+        return data_interval
+
     if TYPE_CHECKING:
         access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]] | None = pydantic.Field(
             init=False, default=None
diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py
index ca435cc1a4fae..a248904cbefcc 100644
--- a/tests/dag_processing/test_collection.py
+++ b/tests/dag_processing/test_collection.py
@@ -40,7 +40,7 @@
 )
 from airflow.exceptions import SerializationError
 from airflow.listeners.listener import get_listener_manager
-from airflow.models import DagModel, Trigger
+from airflow.models import DagModel, DagRun, Trigger
 from airflow.models.asset import (
     AssetActive,
     asset_trigger_association_table,
@@ -449,6 +449,27 @@ def _sync_perms():
             {"owners": ["airflow"]},
             id="default-owner",
         ),
+        pytest.param(
+            {
+                "_tasks_": [
+                    EmptyOperator(task_id="task", owner="owner1"),
+                    EmptyOperator(task_id="task2", owner="owner2"),
+                    EmptyOperator(task_id="task3"),
+                    EmptyOperator(task_id="task4", owner="owner2"),
+                ],
+                "schedule": "0 0 * * *",
+                "catchup": False,
+            },
+            {
+                "default_view": conf.get("webserver", "dag_default_view").lower(),
+                "owners": ["owner1", "owner2"],
+                "next_dagrun": tz.datetime(2020, 1, 5, 0, 0, 0),
+                "next_dagrun_data_interval_start": tz.datetime(2020, 1, 5, 0, 0, 0),
+                "next_dagrun_data_interval_end": tz.datetime(2020, 1, 6, 0, 0, 0),
+                "next_dagrun_create_after": tz.datetime(2020, 1, 6, 0, 0, 0),
+            },
+            id="with-scheduled-dagruns",
+        ),
     ],
 )
 @pytest.mark.usefixtures("clean_db")
@@ -462,6 +483,16 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine):
         if tasks:
             dag.add_tasks(tasks)
 
+        if attrs.pop("schedule", None):
+            dr_kwargs = {
+                "dag_id": "dag",
+                "run_type": "scheduled",
+                "data_interval": (dt, dt + timedelta(minutes=5)),
+            }
+            dr1 = DagRun(logical_date=dt, run_id="test_run_id_1", **dr_kwargs, start_date=dt)
+            session.add(dr1)
+            session.commit()
+
         update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session)
 
         orm_dag = session.get(DagModel, ("dag",))
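
Note (not part of the patch): a minimal sketch of how the new method might be
exercised. It assumes `lazy_dag` is a LazyDeserializedDAG built from a serialized
DAG payload and `session` is an active SQLAlchemy session; both names are
illustrative, not from this change.

    from airflow.models import DagRun

    # Illustrative query: fetch some run row for the same DAG.
    run = session.query(DagRun).filter_by(dag_id=lazy_dag.dag_id).first()

    # get_run_data_interval raises ValueError if the run belongs to a different
    # DAG, or if no interval can be derived from the run's stored
    # data_interval_start/data_interval_end columns.
    interval = lazy_dag.get_run_data_interval(run)
    print(interval.start, interval.end)  # DataInterval is a start/end pair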