diff --git a/pubchempy.py b/pubchempy.py index 72c053b..33012ae 100644 --- a/pubchempy.py +++ b/pubchempy.py @@ -773,7 +773,7 @@ def get_all_sources(domain: str = "substance") -> list[str]: def download( outformat: str, - path: str, + path: str | os.PathLike, identifier: str | int | list[str | int], namespace: str = "cid", domain: str = "compound", diff --git a/tests/test_download.py b/tests/test_download.py index 89e9605..0e4c949 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -2,38 +2,23 @@ import csv import os -import shutil -import tempfile import pytest from pubchempy import download -@pytest.fixture(scope="module") -def tmp_dir(): - dir = tempfile.mkdtemp() - yield dir - shutil.rmtree(dir) +def test_image_download(tmp_path): + download("PNG", tmp_path / "aspirin.png", "Aspirin", "name") + with pytest.raises(OSError): + download("PNG", tmp_path / "aspirin.png", "Aspirin", "name") + download("PNG", tmp_path / "aspirin.png", "Aspirin", "name", overwrite=True) -def test_image_download(tmp_dir): - download("PNG", os.path.join(tmp_dir, "aspirin.png"), "Aspirin", "name") - with pytest.raises(IOError): - download("PNG", os.path.join(tmp_dir, "aspirin.png"), "Aspirin", "name") - download( - "PNG", os.path.join(tmp_dir, "aspirin.png"), "Aspirin", "name", overwrite=True - ) - - -def test_csv_download(tmp_dir): - download( - "CSV", - os.path.join(tmp_dir, "s.csv"), - [1, 2, 3], - operation="property/ConnectivitySMILES,SMILES", - ) - with open(os.path.join(tmp_dir, "s.csv")) as f: +def test_csv_download(tmp_path): + props = "property/ConnectivitySMILES,SMILES" + download("CSV", tmp_path / "s.csv", [1, 2, 3], operation=props) + with open(tmp_path / "s.csv") as f: rows = list(csv.reader(f)) assert rows[0] == ["CID", "ConnectivitySMILES", "SMILES"] assert rows[1][0] == "1"