diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum index 943fd0fc93e4c..b110f434d3945 100644 --- a/generated/provider_dependencies.json.sha256sum +++ b/generated/provider_dependencies.json.sha256sum @@ -1 +1,5 @@ +<<<<<<< Updated upstream 93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b +======= +e14a01deac3579ec86383046f5e8fee680fae21d446e6d81cbf10395a9837cad +>>>>>>> Stashed changes diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index f0b36e7b41f9b..0d01d709d6ddd 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.common.compat.sdk import task diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py index 062b3112b39de..38fcdffe13fc1 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.operators.snowpark import SnowparkOperator from airflow.utils import timezone @@ -69,9 +69,10 @@ def func2(): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number + assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -109,8 +110,8 @@ def func3(number: int): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -135,8 +136,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -170,8 +171,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task = dag.tasks[0] + ti = dag_maker.run_ti(task.task_id, dr) query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, @@ -179,3 +180,36 @@ def func(session: Session): "task_id": TASK_ID, "operator": "SnowparkOperator", } + + @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook") + def test_snowpark_operator_closes_session_on_exception( + self, + mock_snowflake_hook, + dag_maker, + ): + mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value + + with dag_maker(dag_id=TEST_DAG_ID) as dag: + + def func(session: Session): + raise ValueError("boom") + + SnowparkOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + python_callable=func, + warehouse="test_warehouse", + database="test_database", + schema="test_schema", + role="test_role", + authenticator="externalbrowser", + dag=dag, + ) + + dr = dag_maker.create_dagrun() + task = dag.tasks[0] + + with pytest.raises(ValueError, match="boom"): + dag_maker.run_ti(task.task_id, dr) + + mock_session.close.assert_called_once() diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py index 181d8ed8f3aa2..c8c9c7945ea39 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py @@ -18,7 +18,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs