diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 010637c693f..5e54cc91674 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -19,11 +19,13 @@ from poetry.utils._compat import WINDOWS from poetry.utils.helpers import Downloader from poetry.utils.helpers import HTTPRangeRequestSupportedError +from poetry.utils.helpers import directory from poetry.utils.helpers import download_file from poetry.utils.helpers import ensure_path from poetry.utils.helpers import extractall from poetry.utils.helpers import get_file_hash from poetry.utils.helpers import get_highest_priority_hash_type +from poetry.utils.helpers import merge_dicts if TYPE_CHECKING: @@ -157,6 +159,20 @@ def test_download_file( assert http.calls[-1].request.headers["Accept-Encoding"] == "Identity" +def test_downloader_with_invalid_content_length( + http: responses.RequestsMock, tmp_path: Path +) -> None: + url = "https://foo.com/demo.txt" + http.get(url, body=b"demo", headers={"Content-Length": "invalid"}) + dest = tmp_path / "demo.txt" + + downloader = Downloader(url, dest) + + assert downloader.total_size == 0 + assert list(downloader.download_with_progress(chunk_size=2)) == [2, 4] + assert dest.read_bytes() == b"demo" + + def test_download_file_recover_from_error( http: responses.RequestsMock, fixture_dir: FixtureDirGetter, tmp_path: Path ) -> None: @@ -348,6 +364,37 @@ def test_ensure_path_directory(tmp_path: Path) -> None: assert ensure_path(path=path, is_directory=True) is path +def test_directory_restores_working_directory_after_error(tmp_path: Path) -> None: + cwd = Path.cwd() + + with pytest.raises(RuntimeError), directory(tmp_path): + assert Path.cwd() == tmp_path + raise RuntimeError("expected failure") + + assert Path.cwd() == cwd + + +def test_merge_dicts_merges_nested_mappings() -> None: + config = { + "installer": {"parallel": True, "max-workers": 4}, + "virtualenvs": {"create": True}, + } + + merge_dicts( + config, + { + "installer": {"max-workers": 8}, + "repositories": {"foo": {"url": "https://foo.example/simple/"}}, + }, + ) + + assert config == { + "installer": {"parallel": True, "max-workers": 8}, + "virtualenvs": {"create": True}, + "repositories": {"foo": {"url": "https://foo.example/simple/"}}, + } + + @pytest.mark.parametrize("relative", [False, True]) @pytest.mark.parametrize("existing", [False, True]) def test_extractall_sdist_no_path_traversal(