Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions airflow-core/docs/administration-and-deployment/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ looks like:
# A list of timetable classes to register so they can be used in Dags.
timetables = []

# A list of deadline reference classes that can be used as custom deadlines in Dags.
# Custom deadline reference classes must be registered here in order to be
# resolvable at scheduler-side deserialization time; classes that are not
# registered will raise ``DeadlineReferenceNotRegistered`` when a Dag attempts
# to use them.
deadline_references = []

# A list of Listeners that plugin provides. Listeners can register to
# listen to particular events that happen in Airflow, like
# TaskInstance state changes. Listeners are python modules.
Expand Down
1 change: 1 addition & 0 deletions airflow-core/newsfragments/66737.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Custom deadline reference classes must now be registered via the new ``deadline_references`` attribute on ``AirflowPlugin``, matching the existing pattern for custom timetables and custom partition mappers. To use a custom ``DeadlineReference`` subclass, register it in a plugin's ``deadline_references`` list. Custom references that are not registered will raise ``DeadlineReferenceNotRegistered`` at deserialization.
13 changes: 13 additions & 0 deletions airflow-core/src/airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

if TYPE_CHECKING:
from airflow.listeners.listener import ListenerManager
from airflow.models.deadline import DeadlineReferenceType
from airflow.partition_mappers.base import PartitionMapper
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.timetables.base import Timetable
Expand Down Expand Up @@ -286,6 +287,18 @@ def get_partition_mapper_plugins() -> dict[str, type[PartitionMapper]]:
}


@cache
def get_deadline_references_plugins() -> dict[str, type[DeadlineReferenceType]]:
"""Collect and get deadline reference classes registered by plugins."""
log.debug("Initialize extra deadline reference plugins")

return {
qualname(deadline_ref_cls): deadline_ref_cls
for plugin in _get_plugins()[0]
for deadline_ref_cls in plugin.deadline_references
}


@cache
def integrate_macros_plugins() -> None:
"""Integrates macro plugins."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ def serialize_reference(self) -> dict:

@classmethod
def deserialize_reference(cls, reference_data: dict):
from airflow._shared.module_loading import import_string
from airflow.serialization.helpers import find_registered_custom_deadline_reference

custom_class = import_string(reference_data["__class_path"])
custom_class = find_registered_custom_deadline_reference(reference_data["__class_path"])
inner_ref = custom_class.deserialize_reference(reference_data)
return cls(inner_ref)

Expand Down
27 changes: 27 additions & 0 deletions airflow-core/src/airflow/serialization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.configuration import conf

if TYPE_CHECKING:
from airflow.models.deadline import DeadlineReferenceType
from airflow.partition_mappers.base import PartitionMapper
from airflow.timetables.base import Timetable as CoreTimetable

Expand Down Expand Up @@ -145,6 +146,32 @@ def find_registered_custom_partition_mapper(importable_string: str) -> type[Part
raise PartitionMapperNotFound(importable_string)


class DeadlineReferenceNotRegistered(ValueError):
"""When an unregistered custom deadline reference is being accessed."""

def __init__(self, type_string: str) -> None:
self.type_string = type_string

def __str__(self) -> str:
return (
f"Custom deadline reference class {self.type_string!r} is not "
"registered. Custom deadline references must be registered via the "
"`deadline_references` attribute on an AirflowPlugin."
)


def find_registered_custom_deadline_reference(
importable_string: str,
) -> type[DeadlineReferenceType]:
"""Find a user-defined custom deadline reference class registered via a plugin."""
from airflow import plugins_manager

deadline_ref_classes = plugins_manager.get_deadline_references_plugins()
with contextlib.suppress(KeyError):
return deadline_ref_classes[importable_string]
raise DeadlineReferenceNotRegistered(importable_string)


def is_core_timetable_import_path(importable_string: str) -> bool:
"""Whether an importable string points to a core timetable class."""
return importable_string.startswith("airflow.timetables.")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow import plugins_manager
from airflow.models.deadline import ReferenceModels
from airflow.serialization.definitions.deadline import SerializedReferenceModels
from airflow.serialization.helpers import (
DeadlineReferenceNotRegistered,
find_registered_custom_deadline_reference,
)


class _RegisteredCustomReference(ReferenceModels.BaseDeadlineReference):
"""Fake deadline reference registered through a plugin in these tests."""

required_kwargs: set[str] = set()

@classmethod
def deserialize_reference(cls, reference_data: dict):
return cls()

def serialize_reference(self) -> dict:
return {}

def _evaluate_with(self, *, session, **kwargs):
return None


_IMPORTABLE = f"{_RegisteredCustomReference.__module__}._RegisteredCustomReference"


@pytest.fixture
def fake_plugin_registry(monkeypatch):
"""Stub `get_deadline_references_plugins` to advertise a single registered class."""
registered = {_IMPORTABLE: _RegisteredCustomReference}
monkeypatch.setattr(
plugins_manager,
"get_deadline_references_plugins",
lambda: registered,
)
return registered


def test_find_registered_returns_class(fake_plugin_registry):
assert find_registered_custom_deadline_reference(_IMPORTABLE) is _RegisteredCustomReference


def test_find_registered_raises_for_unknown(fake_plugin_registry):
with pytest.raises(DeadlineReferenceNotRegistered) as exc_info:
find_registered_custom_deadline_reference("not.registered.SomeReference")
assert exc_info.value.type_string == "not.registered.SomeReference"
assert "not.registered.SomeReference" in str(exc_info.value)


def test_find_registered_raises_when_registry_empty(monkeypatch):
monkeypatch.setattr(
plugins_manager,
"get_deadline_references_plugins",
lambda: {},
)
with pytest.raises(DeadlineReferenceNotRegistered):
find_registered_custom_deadline_reference("anything.at.all.MyReference")


def test_serialized_custom_reference_uses_registry(fake_plugin_registry):
result = SerializedReferenceModels.SerializedCustomReference.deserialize_reference(
{"__class_path": _IMPORTABLE}
)

assert isinstance(result, SerializedReferenceModels.SerializedCustomReference)
assert isinstance(result.inner_ref, _RegisteredCustomReference)


def test_serialized_custom_reference_rejects_unregistered(monkeypatch):
monkeypatch.setattr(
plugins_manager,
"get_deadline_references_plugins",
lambda: {},
)
with pytest.raises(DeadlineReferenceNotRegistered):
SerializedReferenceModels.SerializedCustomReference.deserialize_reference(
{"__class_path": "some.other.module.UnregisteredReference"}
)
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ class AirflowPlugin:
# A list of timetable classes that can be used for Dag scheduling.
partition_mappers: list[Any] = []

# A list of deadline reference classes that can be used as custom deadlines in Dags.
deadline_references: list[Any] = []

# A list of listeners that can be used for tracking task and Dag states.
listeners: list[ModuleType | object] = []

Expand Down
Loading