From 9424e61967679b0cf75124185bdd593fc16508bf Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Fri, 19 Jun 2026 21:04:24 +0100 Subject: [PATCH] Add a test verifying that SnowparkOperator closes the Snowpark session when the user callable raises an exception. As part of enabling the new coverage, migrate the Snowpark tests away from the removed TaskInstance.run() API and update Snowpark test dependency detection to use the importable Snowpark module. --- .../provider_dependencies.json.sha256sum | 4 ++ .../snowflake/decorators/test_snowpark.py | 2 +- .../unit/snowflake/operators/test_snowpark.py | 52 +++++++++++++++---- .../unit/snowflake/utils/test_snowpark.py | 2 +- 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum index 943fd0fc93e4c..b110f434d3945 100644 --- a/generated/provider_dependencies.json.sha256sum +++ b/generated/provider_dependencies.json.sha256sum @@ -1 +1,5 @@ +<<<<<<< Updated upstream 93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b +======= +e14a01deac3579ec86383046f5e8fee680fae21d446e6d81cbf10395a9837cad +>>>>>>> Stashed changes diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index f0b36e7b41f9b..0d01d709d6ddd 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.common.compat.sdk import task diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py index 062b3112b39de..38fcdffe13fc1 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.operators.snowpark import SnowparkOperator from airflow.utils import timezone @@ -69,9 +69,10 @@ def func2(): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number + assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -109,8 +110,8 @@ def func3(number: int): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -135,8 +136,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -170,8 +171,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task = dag.tasks[0] + ti = dag_maker.run_ti(task.task_id, dr) query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, @@ -179,3 +180,36 @@ def func(session: Session): "task_id": TASK_ID, "operator": "SnowparkOperator", } + + @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook") + def test_snowpark_operator_closes_session_on_exception( + self, + mock_snowflake_hook, + dag_maker, + ): + mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value + + with dag_maker(dag_id=TEST_DAG_ID) as dag: + + def func(session: Session): + raise ValueError("boom") + + SnowparkOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + python_callable=func, + warehouse="test_warehouse", + database="test_database", + schema="test_schema", + role="test_role", + authenticator="externalbrowser", + dag=dag, + ) + + dr = dag_maker.create_dagrun() + task = dag.tasks[0] + + with pytest.raises(ValueError, match="boom"): + dag_maker.run_ti(task.task_id, dr) + + mock_session.close.assert_called_once() diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py index 181d8ed8f3aa2..c8c9c7945ea39 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py @@ -18,7 +18,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs