Skip to content

Commit

Permalink
Introducing get_run_data_interval on LazyDeserializedDAG (apache#45211)
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Dec 27, 2024
1 parent c9badd7 commit 60cd5ad
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
18 changes: 16 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from airflow.exceptions import AirflowException, SerializationError, TaskDeferred
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dag import DAG, _get_model_data_interval
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
create_expand_input,
Expand Down Expand Up @@ -95,13 +95,14 @@
if TYPE_CHECKING:
from inspect import Parameter

from airflow.models import DagRun
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.expandinput import ExpandInput
from airflow.models.operator import Operator
from airflow.sdk.definitions.node import DAGNode
from airflow.serialization.json_schema import Validator
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.timetables.base import DagRunInfo, Timetable
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable

HAS_KUBERNETES: bool
try:
Expand Down Expand Up @@ -1960,6 +1961,19 @@ def get_task_assets(
if isinstance(obj, of_type):
yield task["task_id"], obj

def get_run_data_interval(self, run: DagRun) -> DataInterval:
    """Return the data interval stored on *run* for this DAG.

    :param run: the DagRun whose ``data_interval_start``/``data_interval_end``
        columns are read.
    :raises ValueError: if *run* belongs to a different DAG, or if the run
        has no stored data interval.
    """
    # Guard: a run from another DAG must not be resolved against this one.
    if run.dag_id is not None and run.dag_id != self.dag_id:
        raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}")

    interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end")
    # NOTE(review): unlike the older DAG implementation, there is no fallback to
    # infer_automated_data_interval when the stored interval is missing — confirm
    # whether raising here is the intended behavior.
    if interval is None:
        raise ValueError(f"Cannot calculate data interval for run {run}")
    return interval

if TYPE_CHECKING:
access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]] | None = pydantic.Field(
init=False, default=None
Expand Down
33 changes: 32 additions & 1 deletion tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
from airflow.exceptions import SerializationError
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagModel, Trigger
from airflow.models import DagModel, DagRun, Trigger
from airflow.models.asset import (
AssetActive,
asset_trigger_association_table,
Expand Down Expand Up @@ -449,6 +449,27 @@ def _sync_perms():
{"owners": ["airflow"]},
id="default-owner",
),
pytest.param(
{
"_tasks_": [
EmptyOperator(task_id="task", owner="owner1"),
EmptyOperator(task_id="task2", owner="owner2"),
EmptyOperator(task_id="task3"),
EmptyOperator(task_id="task4", owner="owner2"),
],
"schedule": "0 0 * * *",
"catchup": False,
},
{
"default_view": conf.get("webserver", "dag_default_view").lower(),
"owners": ["owner1", "owner2"],
"next_dagrun": tz.datetime(2020, 1, 5, 0, 0, 0),
"next_dagrun_data_interval_start": tz.datetime(2020, 1, 5, 0, 0, 0),
"next_dagrun_data_interval_end": tz.datetime(2020, 1, 6, 0, 0, 0),
"next_dagrun_create_after": tz.datetime(2020, 1, 6, 0, 0, 0),
},
id="with-scheduled-dagruns",
),
],
)
@pytest.mark.usefixtures("clean_db")
Expand All @@ -462,6 +483,16 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine):
if tasks:
dag.add_tasks(tasks)

if attrs.pop("schedule", None):
dr_kwargs = {
"dag_id": "dag",
"run_type": "scheduled",
"data_interval": (dt, dt + timedelta(minutes=5)),
}
dr1 = DagRun(logical_date=dt, run_id="test_run_id_1", **dr_kwargs, start_date=dt)
session.add(dr1)
session.commit()

update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session)

orm_dag = session.get(DagModel, ("dag",))
Expand Down

0 comments on commit 60cd5ad

Please sign in to comment.