diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py index ee1007fba6124..c339d629d0e23 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py @@ -74,11 +74,37 @@ def _executor_initializer(): This function must be picklable, so it cannot be defined as an inner method or local function. Reconfigures the ORM engine to prevent issues that arise when multiple processes interact with - the Airflow database. + the Airflow database, and re-initializes ``Stats`` so that metrics emitted from worker + processes (e.g. ``ol.event.size.*`` from ``_emit_manual_state_change_event``) are routed to + the configured statsd backend instead of being silently dropped by ``NoStatsLogger`` — the + parent's ``Stats.initialize(...)`` call from scheduler startup does not propagate across the + spawn boundary. """ # This initializer is used only on the scheduler # We can configure_orm regardless of the Airflow version, as DB access is always allowed from scheduler. settings.configure_orm() + try: + from airflow.observability.metrics import stats_utils + + Stats.initialize(factory=stats_utils.get_stats_factory(Stats)) + except ImportError: + # ``stats_utils`` lives under ``airflow.observability.metrics`` in current Airflow; if the + # import path changes or is unavailable, fall through silently — gauge calls will simply + # land on ``NoStatsLogger`` as before, which is no worse than current behavior. + pass + + +def _emit_manual_state_change_event(adapter_method, stats_key, **kwargs): + """ + Emit an OL event via the given adapter method and record its serialized size. + + Module-level so it is picklable across the ProcessPoolExecutor boundary used by + `_on_task_instance_manual_state_change` for scheduler-side "task state changed + externally" emissions. + """ + event = adapter_method(**kwargs) + Stats.gauge(stats_key, len(Serde.to_json(event).encode("utf-8"))) + return event class OpenLineageListener: @@ -653,6 +679,17 @@ def _on_task_instance_manual_state_change( ti_state: TaskInstanceState, error: None | str | BaseException = None, ) -> None: + """ + Emit an OL event from the scheduler when a TI transitions externally. + + This path is only reached on the scheduler (``process_executor_events -> + handle_failure``, or manual UI/API state changes). Emission is routed through + the same ``ProcessPoolExecutor`` the DAG-run listeners use rather than through + ``_fork_execute``: the pool's ``_executor_initializer`` rebuilds the ORM once + per worker, so the child never shares a pooled Postgres SSL connection with + the scheduler, and bursts of external-state-change events no longer produce a + fork-per-event. + """ self.log.debug("`_on_task_instance_manual_state_change` was called with state: `%s`.", ti_state) end_date = timezone.utcnow() @@ -674,15 +711,25 @@ def _on_task_instance_manual_state_change( ) return - @print_warning(self.log) - def on_state_change(): - date = dagrun.logical_date or dagrun.run_after - parent_run_id = self.adapter.build_dag_run_id( - dag_id=ti.dag_id, - logical_date=date, - clear_number=dagrun.clear_number, - ) + try: + if self._executor is None: + self.log.debug("Executor has not started before `_on_task_instance_manual_state_change`") + return + + if ti_state == TaskInstanceState.FAILED: + adapter_method = self.adapter.fail_task + event_type = RunState.FAIL.value.lower() + elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED): + adapter_method = self.adapter.complete_task + event_type = RunState.COMPLETE.value.lower() + else: + raise ValueError(f"Unsupported ti_state: `{ti_state}`.") + # Extract primitives from live ORM objects in the parent (scheduler) + # before crossing the pool boundary. Passing ORM objects through the pool + # pickler loses TaskGroup attributes and crashes event emission -- see + # the equivalent note in `on_dag_run_running` (listener.py ~868). + date = dagrun.logical_date or dagrun.run_after task_uuid = self.adapter.build_task_instance_run_id( dag_id=ti.dag_id, task_id=ti.task_id, @@ -690,29 +737,38 @@ def on_state_change(): logical_date=date, map_index=ti.map_index, ) + parent_run_id = self.adapter.build_dag_run_id( + dag_id=ti.dag_id, + logical_date=date, + clear_number=dagrun.clear_number, + ) - data_interval_start = dagrun.data_interval_start + # Mirror the pattern used in the other listener call sites: convert + # `datetime` to ISO-8601 string, but preserve any non-`datetime` + # value as-is in case a duck-typed caller already passed a string. + data_interval_start: str | datetime | None = dagrun.data_interval_start if isinstance(data_interval_start, datetime): data_interval_start = data_interval_start.isoformat() - data_interval_end = dagrun.data_interval_end + data_interval_end: str | datetime | None = dagrun.data_interval_end if isinstance(data_interval_end, datetime): data_interval_end = data_interval_end.isoformat() - dag_tags, owners, doc, doc_type = None, None, None, None - airflow_run_facet = {} + dag_tags: list | None = None + owners: list[str] | None = None + doc: str | None = None + doc_type: str | None = None + airflow_run_facet: dict = {} if task: # on scheduler, we should have access to task doc, doc_type = get_task_documentation(task) dag = getattr(task, "dag") if dag: if not doc: doc, doc_type = get_dag_documentation(dag) - dag_tags = dag.tags owners = [x.strip() for x in (task if task.owner != "airflow" else dag).owner.split(",")] - airflow_run_facet = get_airflow_run_facet(dagrun, dag, ti, task, task_uuid) - adapter_kwargs = { + adapter_kwargs: dict = { "run_id": task_uuid, "job_name": get_job_name(ti), "end_time": end_date.isoformat(), @@ -733,23 +789,21 @@ def on_state_change(): **get_airflow_debug_facet(), }, } - if ti_state == TaskInstanceState.FAILED: - event_type = RunState.FAIL.value.lower() - redacted_event = self.adapter.fail_task(**adapter_kwargs, error=error) - elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED): - event_type = RunState.COMPLETE.value.lower() - redacted_event = self.adapter.complete_task(**adapter_kwargs) - else: - raise ValueError(f"Unsupported ti_state: `{ti_state}`.") + adapter_kwargs["error"] = error - operator_name = ti.operator.lower() - Stats.gauge( + operator_name = (ti.operator or "unknown").lower() + self.submit_callable( + _emit_manual_state_change_event, + adapter_method, f"ol.event.size.{event_type}.{operator_name}", - len(Serde.to_json(redacted_event).encode("utf-8")), + **adapter_kwargs, + ) + except BaseException as e: + self.log.warning( + "OpenLineage received exception in method `_on_task_instance_manual_state_change`", + exc_info=e, ) - - self._execute(on_state_change, "on_state_change", use_fork=True) def _execute(self, callable, callable_name: str, use_fork: bool = False): if use_fork: @@ -825,7 +879,7 @@ def on_dag_run_running(self, dag_run: DagRun, msg: str) -> None: ) return - if not self.executor: + if self._executor is None: self.log.debug("Executor have not started before `on_dag_run_running`") return @@ -877,7 +931,7 @@ def on_dag_run_success(self, dag_run: DagRun, msg: str) -> None: ) return - if not self.executor: + if self._executor is None: self.log.debug("Executor have not started before `on_dag_run_success`") return @@ -929,7 +983,7 @@ def on_dag_run_failed(self, dag_run: DagRun, msg: str) -> None: ) return - if not self.executor: + if self._executor is None: self.log.debug("Executor have not started before `on_dag_run_failed`") return diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index 608e789436e96..e62d49c5634f9 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -85,6 +85,23 @@ def regular_call(self, callable, callable_name, use_fork): callable() +def direct_submit_call(self, callable, *args, **kwargs): + """Synchronous stand-in for ``OpenLineageListener.submit_callable``. + + Bypasses the ``ProcessPoolExecutor`` so tests can assert against mocked + adapter methods without hitting pickling of ``unittest.mock.Mock``. + When the submitted callable is ``_emit_manual_state_change_event``, skip + its ``Stats.gauge`` side effect (which would try to ``Serde.to_json`` a + ``MagicMock`` return value) and invoke the adapter method directly. + """ + from airflow.providers.openlineage.plugins.listener import _emit_manual_state_change_event + + if callable is _emit_manual_state_change_event: + adapter_method, _stats_key, *_ = args + return adapter_method(**kwargs) + return callable(*args, **kwargs) + + class MockExecutor: def __init__(self, *args, **kwargs): self.submitted = False @@ -1457,13 +1474,15 @@ def test_adapter_fail_task_is_called_with_dag_description_when_task_doc_is_empty assert listener.adapter.fail_task.call_args.kwargs["job_description"] == "Test DAG Description" assert listener.adapter.fail_task.call_args.kwargs["job_description_type"] == "text/plain" + @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute") @mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet") @mock.patch( - "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + "airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable", + new=direct_submit_call, ) def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_model( self, @@ -1472,6 +1491,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_ mock_debug_facet, mock_debug_mode, mock_emit, + mock_fork_execute, time_machine, ): """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. @@ -1482,6 +1502,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_ time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False) listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False) + listener._executor = MagicMock() # satisfy `if self._executor is None` guard mock_get_airflow_run_facet.return_value = {"airflow": 3} mock_get_task_parent_run_facet.return_value = {"parent": 4} mock_debug_facet.return_value = {"debug": "packages"} @@ -1513,6 +1534,8 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_ error=err, ) listener.adapter.fail_task.assert_called_once_with(**expected_args) + # Regression guard: manual state-change emission must not go through _fork_execute. + mock_fork_execute.assert_not_called() expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e" adapter = OpenLineageAdapter() @@ -1644,15 +1667,23 @@ def test_adapter_complete_task_is_called_with_dag_description_when_task_doc_is_e assert listener.adapter.complete_task.call_args.kwargs["job_description"] == "Test DAG Description" assert listener.adapter.complete_task.call_args.kwargs["job_description_type"] == "text/plain" + @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute") @mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet") @mock.patch( - "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + "airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable", + new=direct_submit_call, ) def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model( - self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine + self, + mock_get_task_parent_run_facet, + mock_debug_facet, + mock_debug_mode, + mock_emit, + mock_fork_execute, + time_machine, ): """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. @@ -1662,6 +1693,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False) listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False) + listener._executor = MagicMock() # satisfy `if self._executor is None` guard delattr(task_instance, "task") # Test api server path, where task is not available mock_get_task_parent_run_facet.return_value = {"parent": 4} mock_debug_facet.return_value = {"debug": "packages"} @@ -1691,6 +1723,8 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta }, ) assert calls[0][1] == expected_args + # Regression guard: manual state-change emission must not go through _fork_execute. + mock_fork_execute.assert_not_called() expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e" adapter = OpenLineageAdapter() @@ -1851,15 +1885,23 @@ def test_listener_on_task_instance_skipped_do_not_call_adapter_when_disabled_ope listener.extractor_manager.extract_metadata.assert_not_called() listener.adapter.complete_task.assert_not_called() + @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute") @mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet") @mock.patch( - "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + "airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable", + new=direct_submit_call, ) def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model_on_skip( - self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine + self, + mock_get_task_parent_run_facet, + mock_debug_facet, + mock_debug_mode, + mock_emit, + mock_fork_execute, + time_machine, ): """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. @@ -1869,6 +1911,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False) listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False) + listener._executor = MagicMock() # satisfy `if self._executor is None` guard delattr(task_instance, "task") # Test api server path, where task is not available mock_get_task_parent_run_facet.return_value = {"parent": 4} mock_debug_facet.return_value = {"debug": "packages"} @@ -1898,6 +1941,8 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta }, ) assert calls[0][1] == expected_args + # Regression guard: manual state-change emission must not go through _fork_execute. + mock_fork_execute.assert_not_called() expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e" adapter = OpenLineageAdapter()