From 3e87e708ce91ea90eb5be776ed0a668ca5ca2c15 Mon Sep 17 00:00:00 2001 From: Cadu Date: Sun, 25 Jan 2026 20:00:01 -0300 Subject: [PATCH 1/2] Improve type hint docs, refactor rewriter, restructure tests (#31) - Update README type hint recommendation from `deleted_at: datetime` to `deleted_at: Mapped[datetime | None]` for proper SQLAlchemy support - Remove global_rewriter singleton, attach rewriter to generated mixin class as `_sqlalchemy_easy_softdelete_rewriter` for better isolation - Restructure tests into isolated "worlds" to enable testing different mixin configurations: - tests/default_config/ - default mixin configuration - tests/custom_deleted_field_name/ - custom field name (removed_at) - tests/custom_method_names/ - custom method names (soft_delete/restore) - tests/disabled_methods/ - disabled delete/undelete methods - tests/custom_default_value/ - custom default value function - tests/integer_field_type/ - integer field type instead of DateTime - Add comprehensive tests for all mixin configuration options --- README.md | 8 +- .../handler/sqlalchemy_easy_softdelete.py | 17 +- sqlalchemy_easy_softdelete/mixin.py | 6 +- tests/conftest.py | 33 +--- tests/custom_default_value/__init__.py | 0 tests/custom_default_value/conftest.py | 15 ++ tests/custom_default_value/model.py | 31 +++ .../test_custom_default_value.py | 31 +++ tests/custom_deleted_field_name/__init__.py | 0 tests/custom_deleted_field_name/conftest.py | 15 ++ tests/custom_deleted_field_name/model.py | 31 +++ .../test_custom_field_name.py | 57 ++++++ tests/custom_method_names/__init__.py | 0 tests/custom_method_names/conftest.py | 15 ++ tests/custom_method_names/model.py | 38 ++++ .../test_custom_method_names.py | 37 ++++ tests/default_config/__init__.py | 0 .../__snapshots__/test_queries.ambr | 0 .../__snapshots__/test_seed_data.ambr | 0 tests/default_config/conftest.py | 30 +++ tests/{ => default_config}/model.py | 16 +- .../seed_data/__init__.py | 2 +- .../seed_data/parent_child_childchild.py | 2 +- tests/{ => default_config}/test_queries.py | 2 +- tests/{ => default_config}/test_seed_data.py | 2 +- tests/default_config/test_type_hints.py | 178 ++++++++++++++++++ tests/disabled_methods/__init__.py | 0 tests/disabled_methods/conftest.py | 15 ++ tests/disabled_methods/model.py | 32 ++++ .../disabled_methods/test_disabled_methods.py | 50 +++++ tests/integer_field_type/__init__.py | 0 tests/integer_field_type/conftest.py | 15 ++ tests/integer_field_type/model.py | 32 ++++ .../test_integer_field_type.py | 54 ++++++ 34 files changed, 713 insertions(+), 51 deletions(-) create mode 100644 tests/custom_default_value/__init__.py create mode 100644 tests/custom_default_value/conftest.py create mode 100644 tests/custom_default_value/model.py create mode 100644 tests/custom_default_value/test_custom_default_value.py create mode 100644 tests/custom_deleted_field_name/__init__.py create mode 100644 tests/custom_deleted_field_name/conftest.py create mode 100644 tests/custom_deleted_field_name/model.py create mode 100644 tests/custom_deleted_field_name/test_custom_field_name.py create mode 100644 tests/custom_method_names/__init__.py create mode 100644 tests/custom_method_names/conftest.py create mode 100644 tests/custom_method_names/model.py create mode 100644 tests/custom_method_names/test_custom_method_names.py create mode 100644 tests/default_config/__init__.py rename tests/{ => default_config}/__snapshots__/test_queries.ambr (100%) rename tests/{ => default_config}/__snapshots__/test_seed_data.ambr (100%) create mode 100644 tests/default_config/conftest.py rename tests/{ => default_config}/model.py (81%) rename tests/{ => default_config}/seed_data/__init__.py (85%) rename tests/{ => default_config}/seed_data/parent_child_childchild.py (96%) rename tests/{ => default_config}/test_queries.py (99%) rename tests/{ => default_config}/test_seed_data.py (89%) create mode 100644 tests/default_config/test_type_hints.py create mode 100644 tests/disabled_methods/__init__.py create mode 100644 tests/disabled_methods/conftest.py create mode 100644 tests/disabled_methods/model.py create mode 100644 tests/disabled_methods/test_disabled_methods.py create mode 100644 tests/integer_field_type/__init__.py create mode 100644 tests/integer_field_type/conftest.py create mode 100644 tests/integer_field_type/model.py create mode 100644 tests/integer_field_type/test_integer_field_type.py diff --git a/README.md b/README.md index 8a451a3..626f782 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ pip install sqlalchemy-easy-softdelete ```py from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class from sqlalchemy_easy_softdelete.hook import IgnoredTable -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declarative_base, Mapped from sqlalchemy import Column, Integer from datetime import datetime @@ -37,8 +37,10 @@ class SoftDeleteMixin(generate_soft_delete_mixin_class( # even if the table has the soft-delete column ignored_tables=[IgnoredTable(table_schema="public", name="cars"),] )): - # type hint for autocomplete IDE support - deleted_at: datetime + # Type hint for IDE autocomplete and type checker support. + # Using Mapped[T | None] ensures type checkers understand this is a + # SQLAlchemy column that supports query operations like .where() + deleted_at: Mapped[datetime | None] # Apply the mixin to your Models Base = declarative_base() diff --git a/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py b/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py index a970b2f..ea5d22a 100644 --- a/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py +++ b/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py @@ -9,16 +9,15 @@ from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from sqlalchemy_easy_softdelete.hook import IgnoredTable -global_rewriter: SoftDeleteQueryRewriter | None = None - def activate_soft_delete_hook( deleted_field_name: str, disable_soft_delete_option_name: str, ignored_tables: list[IgnoredTable] -): - """Activate an event hook to rewrite the queries.""" +) -> SoftDeleteQueryRewriter: + """Activate an event hook to rewrite the queries. - global global_rewriter - global_rewriter = SoftDeleteQueryRewriter( + Returns the SoftDeleteQueryRewriter instance for use by the mixin class. + """ + rewriter = SoftDeleteQueryRewriter( deleted_field_name=deleted_field_name, disable_soft_delete_option_name=disable_soft_delete_option_name, ignored_tables=ignored_tables, @@ -30,10 +29,12 @@ def soft_delete_execute(state: ORMExecuteState): if not state.is_select: return - # Rewrite the statement - adapted = global_rewriter.rewrite_statement(state.statement) + # Rewrite the statement (closure captures local `rewriter`) + adapted = rewriter.rewrite_statement(state.statement) # Replace the statement # Cast needed because Statement type includes LambdaElement which mypy # doesn't recognize as Executable (even though it is at runtime) state.statement = cast(Executable, adapted) + + return rewriter diff --git a/sqlalchemy_easy_softdelete/mixin.py b/sqlalchemy_easy_softdelete/mixin.py index 8546fdb..d34b606 100644 --- a/sqlalchemy_easy_softdelete/mixin.py +++ b/sqlalchemy_easy_softdelete/mixin.py @@ -45,7 +45,11 @@ def undelete_method(_self): class_attributes[undelete_method_name] = undelete_method - activate_soft_delete_hook(deleted_field_name, disable_soft_delete_filtering_option_name, ignored_tables) + # Activate the soft delete hook and get the rewriter instance + rewriter = activate_soft_delete_hook(deleted_field_name, disable_soft_delete_filtering_option_name, ignored_tables) + + # Store rewriter on the generated class for testing purposes + class_attributes["_sqlalchemy_easy_softdelete_rewriter"] = rewriter generated_class = type(class_name, tuple(), class_attributes) diff --git a/tests/conftest.py b/tests/conftest.py index 0a196cc..5f7b22c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,18 +3,12 @@ import pytest from sqlalchemy import create_engine from sqlalchemy.engine import Connection, Engine -from sqlalchemy.orm import Session, sessionmaker - -from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter -from tests.model import TestModelBase -from tests.seed_data import generate_table_with_inheritance_obj -from tests.seed_data.parent_child_childchild import generate_parent_child_object_hierarchy env_connection_string = os.environ.get("TEST_CONNECTION_STRING", None) @pytest.fixture -def sqla2_warnings() -> Engine: +def sqla2_warnings() -> None: # Enable SQLAlchemy 2.0 Warnings mode to help with 2.0 support os.environ["SQLALCHEMY_WARN_20"] = "1" @@ -38,28 +32,3 @@ def db_connection(db_engine) -> Connection: finally: transaction.rollback() connection.close() - - -@pytest.fixture -def db_session(db_connection) -> Session: - TestModelBase.metadata.create_all(db_connection) - return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() - - -@pytest.fixture -def seeded_session(db_session) -> Session: - generate_parent_child_object_hierarchy(db_session, 1000) - generate_parent_child_object_hierarchy(db_session, 1001) - generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True) - - generate_table_with_inheritance_obj(db_session, 1000, deleted=False) - generate_table_with_inheritance_obj(db_session, 1001, deleted=False) - generate_table_with_inheritance_obj(db_session, 1002, deleted=True) - return db_session - - -@pytest.fixture -def rewriter() -> SoftDeleteQueryRewriter: - from sqlalchemy_easy_softdelete.handler.sqlalchemy_easy_softdelete import global_rewriter - - return global_rewriter diff --git a/tests/custom_default_value/__init__.py b/tests/custom_default_value/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/custom_default_value/conftest.py b/tests/custom_default_value/conftest.py new file mode 100644 index 0000000..3bc4169 --- /dev/null +++ b/tests/custom_default_value/conftest.py @@ -0,0 +1,15 @@ +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from tests.custom_default_value.model import CDVModelBase, CDVSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection) -> Session: + CDVModelBase.metadata.create_all(db_connection) + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter(): + return CDVSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter diff --git a/tests/custom_default_value/model.py b/tests/custom_default_value/model.py new file mode 100644 index 0000000..4dfa7e6 --- /dev/null +++ b/tests/custom_default_value/model.py @@ -0,0 +1,31 @@ +"""Models for testing custom default value option.""" + +from datetime import datetime, timezone + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative, declared_attr + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class CDVModelBase: + """CDV = Custom Default Value""" + + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class CDVSoftDeleteMixin( + generate_soft_delete_mixin_class( + delete_method_default_value=lambda: datetime(2000, 1, 1, tzinfo=timezone.utc), + ) +): + deleted_at: Mapped[datetime | None] + + +class CDVTable(CDVModelBase, CDVSoftDeleteMixin): + value = Column(Integer) diff --git a/tests/custom_default_value/test_custom_default_value.py b/tests/custom_default_value/test_custom_default_value.py new file mode 100644 index 0000000..d1b1562 --- /dev/null +++ b/tests/custom_default_value/test_custom_default_value.py @@ -0,0 +1,31 @@ +"""Tests for custom default value option.""" + +from datetime import datetime, timezone + +from tests.custom_default_value.model import CDVTable + + +def test_delete_uses_custom_default_value(db_session): + """Verify delete() uses the custom default value function.""" + obj = CDVTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.delete() + + # Should use our custom date (2000-01-01) + # SQLite doesn't preserve timezone, so compare without it + assert obj.deleted_at.replace(tzinfo=None) == datetime(2000, 1, 1) + + +def test_delete_with_explicit_value_overrides_default(db_session): + """Verify delete(value) uses the passed value instead of default.""" + obj = CDVTable(value=1) + db_session.add(obj) + db_session.commit() + + custom_date = datetime(2020, 6, 15, 12, 30, tzinfo=timezone.utc) + obj.delete(custom_date) + + # SQLite doesn't preserve timezone, so compare without it + assert obj.deleted_at.replace(tzinfo=None) == custom_date.replace(tzinfo=None) diff --git a/tests/custom_deleted_field_name/__init__.py b/tests/custom_deleted_field_name/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/custom_deleted_field_name/conftest.py b/tests/custom_deleted_field_name/conftest.py new file mode 100644 index 0000000..6b99ac9 --- /dev/null +++ b/tests/custom_deleted_field_name/conftest.py @@ -0,0 +1,15 @@ +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from tests.custom_deleted_field_name.model import CFNModelBase, CFNSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection) -> Session: + CFNModelBase.metadata.create_all(db_connection) + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter(): + return CFNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter diff --git a/tests/custom_deleted_field_name/model.py b/tests/custom_deleted_field_name/model.py new file mode 100644 index 0000000..3686ebf --- /dev/null +++ b/tests/custom_deleted_field_name/model.py @@ -0,0 +1,31 @@ +"""Models for testing custom deleted_field_name option.""" + +from datetime import datetime + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative, declared_attr + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class CFNModelBase: + """CFN = Custom Field Name""" + + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class CFNSoftDeleteMixin( + generate_soft_delete_mixin_class( + deleted_field_name="removed_at", + ) +): + removed_at: Mapped[datetime | None] + + +class CFNTable(CFNModelBase, CFNSoftDeleteMixin): + value = Column(Integer) diff --git a/tests/custom_deleted_field_name/test_custom_field_name.py b/tests/custom_deleted_field_name/test_custom_field_name.py new file mode 100644 index 0000000..af53043 --- /dev/null +++ b/tests/custom_deleted_field_name/test_custom_field_name.py @@ -0,0 +1,57 @@ +"""Tests for custom deleted_field_name option.""" + +from datetime import datetime, timezone + +from tests.custom_deleted_field_name.model import CFNTable + + +def test_custom_field_name_column_exists(): + """Verify the column uses the custom field name.""" + assert "removed_at" in CFNTable.__table__.columns + assert "deleted_at" not in CFNTable.__table__.columns + + +def test_rewriter_has_correct_field_name(rewriter): + """Verify the rewriter is configured with the custom field name.""" + assert rewriter.deleted_field_name == "removed_at" + + +def test_delete_sets_custom_field(db_session): + """Verify delete() sets the custom field.""" + obj = CFNTable(value=1) + db_session.add(obj) + db_session.commit() + + assert obj.removed_at is None + obj.delete() + assert obj.removed_at is not None + + +def test_undelete_clears_custom_field(db_session): + """Verify undelete() clears the custom field.""" + obj = CFNTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.delete() + assert obj.removed_at is not None + + obj.undelete() + assert obj.removed_at is None + + +def test_soft_delete_filtering_uses_custom_field(db_session): + """Verify soft-delete filtering works with custom field name.""" + active = CFNTable(value=1) + deleted = CFNTable(value=2) + deleted.removed_at = datetime.now(timezone.utc) + + db_session.add_all([active, deleted]) + db_session.commit() + + results = db_session.query(CFNTable).all() + assert len(results) == 1 + assert results[0].value == 1 + + all_results = db_session.query(CFNTable).execution_options(include_deleted=True).all() + assert len(all_results) == 2 diff --git a/tests/custom_method_names/__init__.py b/tests/custom_method_names/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/custom_method_names/conftest.py b/tests/custom_method_names/conftest.py new file mode 100644 index 0000000..37ad9ef --- /dev/null +++ b/tests/custom_method_names/conftest.py @@ -0,0 +1,15 @@ +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from tests.custom_method_names.model import CMNModelBase, CMNSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection) -> Session: + CMNModelBase.metadata.create_all(db_connection) + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter(): + return CMNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter diff --git a/tests/custom_method_names/model.py b/tests/custom_method_names/model.py new file mode 100644 index 0000000..8de788b --- /dev/null +++ b/tests/custom_method_names/model.py @@ -0,0 +1,38 @@ +"""Models for testing custom method names option.""" + +from datetime import datetime + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative, declared_attr + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class CMNModelBase: + """CMN = Custom Method Names""" + + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class CMNSoftDeleteMixin( + generate_soft_delete_mixin_class( + delete_method_name="soft_delete", + undelete_method_name="restore", + ) +): + deleted_at: Mapped[datetime | None] + + def soft_delete(self) -> None: + super().soft_delete() # type: ignore[misc] + + def restore(self) -> None: + super().restore() # type: ignore[misc] + + +class CMNTable(CMNModelBase, CMNSoftDeleteMixin): + value = Column(Integer) diff --git a/tests/custom_method_names/test_custom_method_names.py b/tests/custom_method_names/test_custom_method_names.py new file mode 100644 index 0000000..e1872eb --- /dev/null +++ b/tests/custom_method_names/test_custom_method_names.py @@ -0,0 +1,37 @@ +"""Tests for custom method names option.""" + +from tests.custom_method_names.model import CMNSoftDeleteMixin, CMNTable + + +def test_custom_method_names_exist(): + """Verify custom method names are used.""" + assert hasattr(CMNSoftDeleteMixin, "soft_delete") + assert hasattr(CMNSoftDeleteMixin, "restore") + # Original names should not exist on the generated parent class + generated_class = CMNSoftDeleteMixin.__bases__[0] + assert not hasattr(generated_class, "delete") + assert not hasattr(generated_class, "undelete") + + +def test_soft_delete_method_sets_deleted_at(db_session): + """Verify soft_delete() method works.""" + obj = CMNTable(value=1) + db_session.add(obj) + db_session.commit() + + assert obj.deleted_at is None + obj.soft_delete() + assert obj.deleted_at is not None + + +def test_restore_method_clears_deleted_at(db_session): + """Verify restore() method works.""" + obj = CMNTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.soft_delete() + assert obj.deleted_at is not None + + obj.restore() + assert obj.deleted_at is None diff --git a/tests/default_config/__init__.py b/tests/default_config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__snapshots__/test_queries.ambr b/tests/default_config/__snapshots__/test_queries.ambr similarity index 100% rename from tests/__snapshots__/test_queries.ambr rename to tests/default_config/__snapshots__/test_queries.ambr diff --git a/tests/__snapshots__/test_seed_data.ambr b/tests/default_config/__snapshots__/test_seed_data.ambr similarity index 100% rename from tests/__snapshots__/test_seed_data.ambr rename to tests/default_config/__snapshots__/test_seed_data.ambr diff --git a/tests/default_config/conftest.py b/tests/default_config/conftest.py new file mode 100644 index 0000000..1e0ab5d --- /dev/null +++ b/tests/default_config/conftest.py @@ -0,0 +1,30 @@ +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from tests.default_config.model import SoftDeleteMixin, TestModelBase +from tests.default_config.seed_data import generate_table_with_inheritance_obj +from tests.default_config.seed_data.parent_child_childchild import generate_parent_child_object_hierarchy + + +@pytest.fixture +def db_session(db_connection) -> Session: + TestModelBase.metadata.create_all(db_connection) + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def seeded_session(db_session) -> Session: + generate_parent_child_object_hierarchy(db_session, 1000) + generate_parent_child_object_hierarchy(db_session, 1001) + generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True) + + generate_table_with_inheritance_obj(db_session, 1000, deleted=False) + generate_table_with_inheritance_obj(db_session, 1001, deleted=False) + generate_table_with_inheritance_obj(db_session, 1002, deleted=True) + return db_session + + +@pytest.fixture +def rewriter(): + # Access world-specific rewriter from mixin class + return SoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter diff --git a/tests/model.py b/tests/default_config/model.py similarity index 81% rename from tests/model.py rename to tests/default_config/model.py index bd7a951..02fee69 100644 --- a/tests/model.py +++ b/tests/default_config/model.py @@ -1,7 +1,7 @@ from datetime import datetime from sqlalchemy import Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import as_declarative, declared_attr, relationship +from sqlalchemy.orm import Mapped, as_declarative, declared_attr, relationship from sqlalchemy_easy_softdelete.hook import IgnoredTable from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -26,8 +26,18 @@ class SoftDeleteMixin( ], ) ): - # for autocomplete - deleted_at: datetime + # Type hint for IDE autocomplete and type checker support. + # Using Mapped[T | None] ensures type checkers understand this is a + # SQLAlchemy column that supports query operations like .where() + deleted_at: Mapped[datetime | None] + + # Optional: Add method stubs for delete/undelete if you want type hints for these. + # The actual implementations are provided by the generated mixin class. + def delete(self, v: datetime | None = None) -> None: + super().delete(v) # type: ignore[misc] + + def undelete(self) -> None: + super().undelete() # type: ignore[misc] class SDSimpleTable(TestModelBase, SoftDeleteMixin): diff --git a/tests/seed_data/__init__.py b/tests/default_config/seed_data/__init__.py similarity index 85% rename from tests/seed_data/__init__.py rename to tests/default_config/seed_data/__init__.py index 4709082..50f2bf2 100644 --- a/tests/seed_data/__init__.py +++ b/tests/default_config/seed_data/__init__.py @@ -2,7 +2,7 @@ from sqlalchemy.orm import Session -from tests.model import SDDerivedRequest +from tests.default_config.model import SDDerivedRequest def generate_table_with_inheritance_obj(s: Session, obj_id: int, deleted: bool = False): diff --git a/tests/seed_data/parent_child_childchild.py b/tests/default_config/seed_data/parent_child_childchild.py similarity index 96% rename from tests/seed_data/parent_child_childchild.py rename to tests/default_config/seed_data/parent_child_childchild.py index f402710..4235c54 100644 --- a/tests/seed_data/parent_child_childchild.py +++ b/tests/default_config/seed_data/parent_child_childchild.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session -from tests.model import SDChild, SDChildChild, SDParent +from tests.default_config.model import SDChild, SDChildChild, SDParent TEST_EPOCH = datetime.datetime(year=1985, month=8, day=4) diff --git a/tests/test_queries.py b/tests/default_config/test_queries.py similarity index 99% rename from tests/test_queries.py rename to tests/default_config/test_queries.py index 5b622ad..ebc751b 100644 --- a/tests/test_queries.py +++ b/tests/default_config/test_queries.py @@ -7,7 +7,7 @@ from sqlalchemy.sql import CompoundSelect, Select from sqlalchemy.sql.lambdas import LambdaElement, LinkedLambdaElement, StatementLambdaElement -from tests.model import ( +from tests.default_config.model import ( SDBaseRequest, SDChild, SDChildChild, diff --git a/tests/test_seed_data.py b/tests/default_config/test_seed_data.py similarity index 89% rename from tests/test_seed_data.py rename to tests/default_config/test_seed_data.py index e79399a..14cbb64 100644 --- a/tests/test_seed_data.py +++ b/tests/default_config/test_seed_data.py @@ -1,6 +1,6 @@ """Tests for `sqlalchemy_easy_softdelete` package.""" -from tests.model import SDChild, SDChildChild, SDParent +from tests.default_config.model import SDChild, SDChildChild, SDParent def test_ensure_stable_seed_data(snapshot, seeded_session): diff --git a/tests/default_config/test_type_hints.py b/tests/default_config/test_type_hints.py new file mode 100644 index 0000000..0275d4f --- /dev/null +++ b/tests/default_config/test_type_hints.py @@ -0,0 +1,178 @@ +"""Tests for type hint compatibility with SQLAlchemy operations. + +These tests verify that the Mapped[T | None] type hint recommendation works +correctly with SQLAlchemy query operations. See GitHub issue #31. + +The key insight is that using `deleted_at: datetime` as a type hint causes +type checkers to treat the attribute as a plain datetime, which breaks +type checking for expressions like `.where(Model.deleted_at < value)`. + +Using `deleted_at: Mapped[datetime | None]` tells the type checker this is +a SQLAlchemy mapped column that supports comparison operations. +""" + +from datetime import datetime, timezone + +from sqlalchemy import select +from sqlalchemy.sql.elements import BinaryExpression + +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter +from tests.default_config.model import SDChild, SDParent, SoftDeleteMixin + + +def test_deleted_at_column_supports_comparison_operators(): + """Verify that deleted_at can be used with comparison operators. + + This tests that our type hints allow using the column in expressions. + With the old `deleted_at: datetime` hint, type checkers would complain + that datetime doesn't support these operations in a SQLAlchemy context. + """ + # These should all work without type errors when using Mapped[datetime | None] + now = datetime.now(timezone.utc) + + # Less than + expr_lt = SDChild.deleted_at < now + assert isinstance(expr_lt, BinaryExpression) + + # Greater than + expr_gt = SDChild.deleted_at > now + assert isinstance(expr_gt, BinaryExpression) + + # Equals + expr_eq = SDChild.deleted_at == now + assert isinstance(expr_eq, BinaryExpression) + + # Not equals + expr_ne = SDChild.deleted_at != now + assert isinstance(expr_ne, BinaryExpression) + + # IS NULL + expr_is_none = SDChild.deleted_at.is_(None) + assert isinstance(expr_is_none, BinaryExpression) + + # IS NOT NULL + expr_is_not_none = SDChild.deleted_at.isnot(None) + assert isinstance(expr_is_not_none, BinaryExpression) + + +def test_deleted_at_column_works_in_where_clause(): + """Verify that deleted_at can be used in .where() clauses. + + This is the primary use case from issue #31 - the user was getting + type errors when using .where(Model.deleted_at < value). + """ + now = datetime.now(timezone.utc) + + # Build a select statement with a where clause using deleted_at + stmt = select(SDChild).where(SDChild.deleted_at < now) + + # The statement should compile without errors + assert stmt is not None + assert "deleted_at" in str(stmt) + + +def test_deleted_at_column_works_in_filter(): + """Verify that deleted_at works with the legacy .filter() method.""" + now = datetime.now(timezone.utc) + + # Using filter (ORM Query style) + stmt = select(SDParent).filter(SDParent.deleted_at > now) + + assert stmt is not None + assert "deleted_at" in str(stmt) + + +def test_deleted_at_column_works_with_between(): + """Verify that deleted_at works with .between().""" + now = datetime.now(timezone.utc) + earlier = datetime(2020, 1, 1, tzinfo=timezone.utc) + + expr = SDChild.deleted_at.between(earlier, now) + assert isinstance(expr, BinaryExpression) + + +def test_delete_method_is_callable(seeded_session): + """Verify that the delete() method stub works correctly. + + The SoftDeleteMixin provides method stubs for delete() and undelete() + so that type checkers know these methods exist. + """ + # Get an instance + child = seeded_session.query(SDChild).first() + assert child is not None + + # The delete method should be callable + assert hasattr(child, "delete") + assert callable(child.delete) + + # Call delete and verify it sets deleted_at + child.delete() + assert child.deleted_at is not None + + +def test_delete_without_value_uses_default(seeded_session): + """Verify delete() without value uses the default function (current time).""" + child = seeded_session.query(SDChild).first() + assert child is not None + + before = datetime.now(timezone.utc) + child.delete() + after = datetime.now(timezone.utc) + + # Should be between before and after (current time) + # SQLite doesn't preserve timezone, so compare without it + deleted_at = child.deleted_at.replace(tzinfo=timezone.utc) if child.deleted_at.tzinfo is None else child.deleted_at + assert before <= deleted_at <= after + + +def test_delete_with_custom_value(seeded_session): + """Verify delete(value) uses the passed value instead of default.""" + child = seeded_session.query(SDChild).first() + assert child is not None + + custom_date = datetime(2020, 6, 15, 12, 30, tzinfo=timezone.utc) + child.delete(custom_date) + + # SQLite doesn't preserve timezone, so compare without it + assert child.deleted_at.replace(tzinfo=None) == custom_date.replace(tzinfo=None) + + +def test_undelete_method_is_callable(seeded_session): + """Verify that the undelete() method stub works correctly.""" + # Get a deleted instance + child = ( + seeded_session.query(SDChild) + .execution_options(include_deleted=True) + .filter(SDChild.deleted_at.isnot(None)) + .first() + ) + assert child is not None + assert child.deleted_at is not None + + # The undelete method should be callable + assert hasattr(child, "undelete") + assert callable(child.undelete) + + # Call undelete and verify it clears deleted_at + child.undelete() + assert child.deleted_at is None + + +def test_mixin_class_has_correct_type_annotations(): + """Verify that the SoftDeleteMixin has the expected type annotations.""" + annotations = SoftDeleteMixin.__annotations__ + + # Should have deleted_at annotation + assert "deleted_at" in annotations + + # The annotation should include Mapped + annotation_str = str(annotations["deleted_at"]) + assert "Mapped" in annotation_str or "datetime" in annotation_str + + +def test_rewriter_is_attached_to_mixin(): + """Verify the rewriter is attached to the mixin class.""" + assert hasattr(SoftDeleteMixin, "_sqlalchemy_easy_softdelete_rewriter") + rewriter = SoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter + assert isinstance(rewriter, SoftDeleteQueryRewriter) + assert rewriter.deleted_field_name == "deleted_at" diff --git a/tests/disabled_methods/__init__.py b/tests/disabled_methods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/disabled_methods/conftest.py b/tests/disabled_methods/conftest.py new file mode 100644 index 0000000..c8d93d4 --- /dev/null +++ b/tests/disabled_methods/conftest.py @@ -0,0 +1,15 @@ +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from tests.disabled_methods.model import DMModelBase, DMSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection) -> Session: + DMModelBase.metadata.create_all(db_connection) + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter(): + return DMSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter diff --git a/tests/disabled_methods/model.py b/tests/disabled_methods/model.py new file mode 100644 index 0000000..3f1b699 --- /dev/null +++ b/tests/disabled_methods/model.py @@ -0,0 +1,32 @@ +"""Models for testing disabled methods option.""" + +from datetime import datetime + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative, declared_attr + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class DMModelBase: + """DM = Disabled Methods""" + + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class DMSoftDeleteMixin( + generate_soft_delete_mixin_class( + generate_delete_method=False, + generate_undelete_method=False, + ) +): + deleted_at: Mapped[datetime | None] + + +class DMTable(DMModelBase, DMSoftDeleteMixin): + value = Column(Integer) diff --git a/tests/disabled_methods/test_disabled_methods.py b/tests/disabled_methods/test_disabled_methods.py new file mode 100644 index 0000000..d9fc7cb --- /dev/null +++ b/tests/disabled_methods/test_disabled_methods.py @@ -0,0 +1,50 @@ +"""Tests for disabled methods option.""" + +from datetime import datetime, timezone + +from tests.disabled_methods.model import DMSoftDeleteMixin, DMTable + + +def test_no_delete_method_on_generated_class(): + """Verify delete/undelete methods are not generated.""" + generated_class = DMSoftDeleteMixin.__bases__[0] + assert not hasattr(generated_class, "delete") + assert not hasattr(generated_class, "undelete") + + +def test_can_manually_set_deleted_at(db_session): + """Verify we can still manually set deleted_at.""" + obj = DMTable(value=1) + db_session.add(obj) + db_session.commit() + + obj_id = obj.id + assert obj.deleted_at is None + + deleted_time = datetime.now(timezone.utc) + obj.deleted_at = deleted_time + db_session.commit() + + # After commit, the object is soft-deleted so query with include_deleted + result = ( + db_session.query(DMTable) + .execution_options(include_deleted=True) + .filter(DMTable.id == obj_id) + .first() + ) + assert result is not None + assert result.deleted_at is not None + + +def test_soft_delete_filtering_still_works(db_session): + """Verify soft-delete filtering works even without methods.""" + active = DMTable(value=1) + deleted = DMTable(value=2) + deleted.deleted_at = datetime.now(timezone.utc) + + db_session.add_all([active, deleted]) + db_session.commit() + + results = db_session.query(DMTable).all() + assert len(results) == 1 + assert results[0].value == 1 diff --git a/tests/integer_field_type/__init__.py b/tests/integer_field_type/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integer_field_type/conftest.py b/tests/integer_field_type/conftest.py new file mode 100644 index 0000000..6232310 --- /dev/null +++ b/tests/integer_field_type/conftest.py @@ -0,0 +1,15 @@ +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from tests.integer_field_type.model import IFTModelBase, IFTSoftDeleteMixin + + +@pytest.fixture +def db_session(db_connection) -> Session: + IFTModelBase.metadata.create_all(db_connection) + return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() + + +@pytest.fixture +def rewriter(): + return IFTSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter diff --git a/tests/integer_field_type/model.py b/tests/integer_field_type/model.py new file mode 100644 index 0000000..db22ef1 --- /dev/null +++ b/tests/integer_field_type/model.py @@ -0,0 +1,32 @@ +"""Models for testing integer field type option.""" + +from datetime import datetime, timezone + +from sqlalchemy import Column, Integer +from sqlalchemy.orm import Mapped, as_declarative, declared_attr + +from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class + + +@as_declarative() +class IFTModelBase: + """IFT = Integer Field Type""" + + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class IFTSoftDeleteMixin( + generate_soft_delete_mixin_class( + deleted_field_type=Integer(), + delete_method_default_value=lambda: int(datetime.now(timezone.utc).timestamp()), + ) +): + deleted_at: Mapped[int | None] + + +class IFTTable(IFTModelBase, IFTSoftDeleteMixin): + value = Column(Integer) diff --git a/tests/integer_field_type/test_integer_field_type.py b/tests/integer_field_type/test_integer_field_type.py new file mode 100644 index 0000000..cfe3ea9 --- /dev/null +++ b/tests/integer_field_type/test_integer_field_type.py @@ -0,0 +1,54 @@ +"""Tests for integer field type option.""" + +from datetime import datetime, timezone + +from sqlalchemy import Integer + +from tests.integer_field_type.model import IFTTable + + +def test_field_is_integer_type(): + """Verify the field uses Integer type.""" + column = IFTTable.__table__.columns["deleted_at"] + assert isinstance(column.type, Integer) + + +def test_delete_sets_integer_timestamp(db_session): + """Verify delete() sets an integer timestamp.""" + obj = IFTTable(value=1) + db_session.add(obj) + db_session.commit() + + before = int(datetime.now(timezone.utc).timestamp()) + obj.delete() + after = int(datetime.now(timezone.utc).timestamp()) + + assert isinstance(obj.deleted_at, int) + assert before <= obj.deleted_at <= after + + +def test_undelete_clears_integer_field(db_session): + """Verify undelete() clears the integer field.""" + obj = IFTTable(value=1) + db_session.add(obj) + db_session.commit() + + obj.delete() + assert obj.deleted_at is not None + + obj.undelete() + assert obj.deleted_at is None + + +def test_integer_field_soft_delete_filtering(db_session): + """Verify soft-delete filtering works with integer field.""" + active = IFTTable(value=1) + deleted = IFTTable(value=2) + deleted.deleted_at = int(datetime.now(timezone.utc).timestamp()) + + db_session.add_all([active, deleted]) + db_session.commit() + + results = db_session.query(IFTTable).all() + assert len(results) == 1 + assert results[0].value == 1 From b0f2ad6fa506ed3bf0cdcac03e77e9394e402656 Mon Sep 17 00:00:00 2001 From: Cadu Date: Sun, 25 Jan 2026 20:38:16 -0300 Subject: [PATCH 2/2] Add mypy type checking for type hints test - Add mypy configuration to pyproject.toml - Add `make typecheck-hints` target to validate type hints test file - This ensures the Mapped[datetime | None] type hint is verified at compile time --- Makefile | 4 +- README.md | 22 +++++-- mypy.ini | 10 +-- tests/conftest.py | 5 +- tests/custom_default_value/conftest.py | 12 ++-- tests/custom_default_value/model.py | 11 +--- tests/custom_deleted_field_name/conftest.py | 12 ++-- tests/custom_deleted_field_name/model.py | 11 +--- tests/custom_method_names/conftest.py | 12 ++-- tests/custom_method_names/model.py | 11 +--- tests/default_config/conftest.py | 15 +++-- tests/default_config/model.py | 63 +++++++++---------- .../seed_data/parent_child_childchild.py | 8 +-- tests/default_config/test_type_hints.py | 3 +- tests/disabled_methods/conftest.py | 12 ++-- tests/disabled_methods/model.py | 11 +--- .../disabled_methods/test_disabled_methods.py | 7 +-- tests/integer_field_type/conftest.py | 12 ++-- tests/integer_field_type/model.py | 11 +--- tests/utils/__init__.py | 20 +++--- tests/utils/simple_select_extractor.py | 16 ++--- 21 files changed, 147 insertions(+), 141 deletions(-) diff --git a/Makefile b/Makefile index 7ccd039..95a7cf8 100644 --- a/Makefile +++ b/Makefile @@ -6,9 +6,9 @@ sources = sqlalchemy_easy_softdelete lint: uv run pre-commit run --all-files -# Run type checking (mypy) +# Run type checking (mypy) on source code and tests typecheck: - uv run mypy $(sources) + uv run mypy $(sources) tests/ # Quick test with SQLite (no docker needed) test: diff --git a/README.md b/README.md index 626f782..036cdfd 100644 --- a/README.md +++ b/README.md @@ -32,16 +32,28 @@ from sqlalchemy import Column, Integer from datetime import datetime # Create a Class that inherits from our class builder -class SoftDeleteMixin(generate_soft_delete_mixin_class( - # This table will be ignored by the hook - # even if the table has the soft-delete column - ignored_tables=[IgnoredTable(table_schema="public", name="cars"),] -)): +class SoftDeleteMixin( + generate_soft_delete_mixin_class( # type: ignore[misc] + # This table will be ignored by the hook + # even if the table has the soft-delete column + ignored_tables=[IgnoredTable(table_schema="public", name="cars"),] + ) +): + # type: ignore[misc] is required because the mixin is dynamically generated + # Type hint for IDE autocomplete and type checker support. # Using Mapped[T | None] ensures type checkers understand this is a # SQLAlchemy column that supports query operations like .where() deleted_at: Mapped[datetime | None] + # Optional: Add method stubs for delete/undelete for type checker support. + # The actual implementations are provided by the generated mixin class. + def delete(self, v: datetime | None = None) -> None: + super().delete(v) # type: ignore[misc] + + def undelete(self) -> None: + super().undelete() # type: ignore[misc] + # Apply the mixin to your Models Base = declarative_base() diff --git a/mypy.ini b/mypy.ini index 053bccf..cc989f1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,12 +1,12 @@ [mypy] -packages = sqlalchemy_easy_softdelete +packages = sqlalchemy_easy_softdelete,tests python_version = 3.10 -# Lenient settings - this codebase wasn't originally typed -disallow_untyped_calls = false +# Strictness settings +disallow_untyped_calls = true disallow_untyped_defs = false -disallow_untyped_decorators = false +disallow_untyped_decorators = true check_untyped_defs = false -ignore_missing_imports = true +ignore_missing_imports = false allow_redefinition = true warn_unused_configs = true diff --git a/tests/conftest.py b/tests/conftest.py index 5f7b22c..981afb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from collections.abc import Generator import pytest from sqlalchemy import create_engine @@ -14,14 +15,14 @@ def sqla2_warnings() -> None: @pytest.fixture -def db_engine(sqla2_warnings) -> Engine: +def db_engine(sqla2_warnings: None) -> Engine: test_db_url = env_connection_string or "sqlite://" print(f"connection_string={test_db_url}") return create_engine(test_db_url, future=True) @pytest.fixture -def db_connection(db_engine) -> Connection: +def db_connection(db_engine: Engine) -> Generator[Connection, None, None]: connection = db_engine.connect() # start a transaction diff --git a/tests/custom_default_value/conftest.py b/tests/custom_default_value/conftest.py index 3bc4169..1f8121a 100644 --- a/tests/custom_default_value/conftest.py +++ b/tests/custom_default_value/conftest.py @@ -1,15 +1,19 @@ +from typing import cast + import pytest +from sqlalchemy.engine import Connection from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from tests.custom_default_value.model import CDVModelBase, CDVSoftDeleteMixin @pytest.fixture -def db_session(db_connection) -> Session: - CDVModelBase.metadata.create_all(db_connection) +def db_session(db_connection: Connection) -> Session: + CDVModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() @pytest.fixture -def rewriter(): - return CDVSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, CDVSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/custom_default_value/model.py b/tests/custom_default_value/model.py index 4dfa7e6..37811cb 100644 --- a/tests/custom_default_value/model.py +++ b/tests/custom_default_value/model.py @@ -1,9 +1,7 @@ -"""Models for testing custom default value option.""" - from datetime import datetime, timezone from sqlalchemy import Column, Integer -from sqlalchemy.orm import Mapped, as_declarative, declared_attr +from sqlalchemy.orm import Mapped, as_declarative from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -12,15 +10,11 @@ class CDVModelBase: """CDV = Custom Default Value""" - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - id = Column(Integer, primary_key=True, autoincrement=True) class CDVSoftDeleteMixin( - generate_soft_delete_mixin_class( + generate_soft_delete_mixin_class( # type: ignore[misc] delete_method_default_value=lambda: datetime(2000, 1, 1, tzinfo=timezone.utc), ) ): @@ -28,4 +22,5 @@ class CDVSoftDeleteMixin( class CDVTable(CDVModelBase, CDVSoftDeleteMixin): + __tablename__ = "cdvtable" value = Column(Integer) diff --git a/tests/custom_deleted_field_name/conftest.py b/tests/custom_deleted_field_name/conftest.py index 6b99ac9..640d014 100644 --- a/tests/custom_deleted_field_name/conftest.py +++ b/tests/custom_deleted_field_name/conftest.py @@ -1,15 +1,19 @@ +from typing import cast + import pytest +from sqlalchemy.engine import Connection from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from tests.custom_deleted_field_name.model import CFNModelBase, CFNSoftDeleteMixin @pytest.fixture -def db_session(db_connection) -> Session: - CFNModelBase.metadata.create_all(db_connection) +def db_session(db_connection: Connection) -> Session: + CFNModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() @pytest.fixture -def rewriter(): - return CFNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, CFNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/custom_deleted_field_name/model.py b/tests/custom_deleted_field_name/model.py index 3686ebf..e88d675 100644 --- a/tests/custom_deleted_field_name/model.py +++ b/tests/custom_deleted_field_name/model.py @@ -1,9 +1,7 @@ -"""Models for testing custom deleted_field_name option.""" - from datetime import datetime from sqlalchemy import Column, Integer -from sqlalchemy.orm import Mapped, as_declarative, declared_attr +from sqlalchemy.orm import Mapped, as_declarative from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -12,15 +10,11 @@ class CFNModelBase: """CFN = Custom Field Name""" - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - id = Column(Integer, primary_key=True, autoincrement=True) class CFNSoftDeleteMixin( - generate_soft_delete_mixin_class( + generate_soft_delete_mixin_class( # type: ignore[misc] deleted_field_name="removed_at", ) ): @@ -28,4 +22,5 @@ class CFNSoftDeleteMixin( class CFNTable(CFNModelBase, CFNSoftDeleteMixin): + __tablename__ = "cfntable" value = Column(Integer) diff --git a/tests/custom_method_names/conftest.py b/tests/custom_method_names/conftest.py index 37ad9ef..785e8cb 100644 --- a/tests/custom_method_names/conftest.py +++ b/tests/custom_method_names/conftest.py @@ -1,15 +1,19 @@ +from typing import cast + import pytest +from sqlalchemy.engine import Connection from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from tests.custom_method_names.model import CMNModelBase, CMNSoftDeleteMixin @pytest.fixture -def db_session(db_connection) -> Session: - CMNModelBase.metadata.create_all(db_connection) +def db_session(db_connection: Connection) -> Session: + CMNModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() @pytest.fixture -def rewriter(): - return CMNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, CMNSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/custom_method_names/model.py b/tests/custom_method_names/model.py index 8de788b..2a57d36 100644 --- a/tests/custom_method_names/model.py +++ b/tests/custom_method_names/model.py @@ -1,9 +1,7 @@ -"""Models for testing custom method names option.""" - from datetime import datetime from sqlalchemy import Column, Integer -from sqlalchemy.orm import Mapped, as_declarative, declared_attr +from sqlalchemy.orm import Mapped, as_declarative from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -12,15 +10,11 @@ class CMNModelBase: """CMN = Custom Method Names""" - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - id = Column(Integer, primary_key=True, autoincrement=True) class CMNSoftDeleteMixin( - generate_soft_delete_mixin_class( + generate_soft_delete_mixin_class( # type: ignore[misc] delete_method_name="soft_delete", undelete_method_name="restore", ) @@ -35,4 +29,5 @@ def restore(self) -> None: class CMNTable(CMNModelBase, CMNSoftDeleteMixin): + __tablename__ = "cmntable" value = Column(Integer) diff --git a/tests/default_config/conftest.py b/tests/default_config/conftest.py index 1e0ab5d..87a0831 100644 --- a/tests/default_config/conftest.py +++ b/tests/default_config/conftest.py @@ -1,19 +1,23 @@ +from typing import cast + import pytest +from sqlalchemy.engine import Connection from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from tests.default_config.model import SoftDeleteMixin, TestModelBase from tests.default_config.seed_data import generate_table_with_inheritance_obj from tests.default_config.seed_data.parent_child_childchild import generate_parent_child_object_hierarchy @pytest.fixture -def db_session(db_connection) -> Session: - TestModelBase.metadata.create_all(db_connection) +def db_session(db_connection: Connection) -> Session: + TestModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() @pytest.fixture -def seeded_session(db_session) -> Session: +def seeded_session(db_session: Session) -> Session: generate_parent_child_object_hierarchy(db_session, 1000) generate_parent_child_object_hierarchy(db_session, 1001) generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True) @@ -25,6 +29,5 @@ def seeded_session(db_session) -> Session: @pytest.fixture -def rewriter(): - # Access world-specific rewriter from mixin class - return SoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, SoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/default_config/model.py b/tests/default_config/model.py index 02fee69..c36c6a5 100644 --- a/tests/default_config/model.py +++ b/tests/default_config/model.py @@ -1,7 +1,7 @@ from datetime import datetime from sqlalchemy import Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import Mapped, as_declarative, declared_attr, relationship +from sqlalchemy.orm import Mapped, as_declarative, relationship from sqlalchemy_easy_softdelete.hook import IgnoredTable from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -9,18 +9,14 @@ @as_declarative() class TestModelBase: - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - id = Column(Integer, primary_key=True, autoincrement=True) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} id={self.id}>" class SoftDeleteMixin( - generate_soft_delete_mixin_class( + generate_soft_delete_mixin_class( # type: ignore[misc] ignored_tables=[ IgnoredTable(table_schema=None, name="sdtablethatshouldnotbesoftdeleted"), ], @@ -31,7 +27,7 @@ class SoftDeleteMixin( # SQLAlchemy column that supports query operations like .where() deleted_at: Mapped[datetime | None] - # Optional: Add method stubs for delete/undelete if you want type hints for these. + # Optional: Add method stubs for delete/undelete for type checker support. # The actual implementations are provided by the generated mixin class. def delete(self, v: datetime | None = None) -> None: super().delete(v) # type: ignore[misc] @@ -41,62 +37,58 @@ def undelete(self) -> None: class SDSimpleTable(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdsimpletable" int_field = Column(Integer) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>" class SDParent(TestModelBase, SoftDeleteMixin): - __allow_unmapped__ = True - children = relationship("SDChild", back_populates="parent") + __tablename__ = "sdparent" + children: Mapped[list["SDChild"]] = relationship("SDChild", back_populates="parent") # type: ignore[assignment] - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>" class SDChild(TestModelBase, SoftDeleteMixin): - __allow_unmapped__ = True - parent_id = Column(Integer, ForeignKey(f"{SDParent.__tablename__}.id"), nullable=False) - parent = relationship("SDParent", back_populates="children") - - child_children = relationship("SDChildChild", back_populates="child") + __tablename__ = "sdchild" + parent_id = Column(Integer, ForeignKey("sdparent.id"), nullable=False) + parent: Mapped["SDParent"] = relationship("SDParent", back_populates="children") # type: ignore[assignment] + child_children: Mapped[list["SDChildChild"]] = relationship("SDChildChild", back_populates="child") # type: ignore[assignment] - def __repr__(self): + def __repr__(self) -> str: pid = f"(parent_id={self.parent_id})" left = f"{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}" return f"<{left:30} {pid:>15}>" class SDChildChild(TestModelBase, SoftDeleteMixin): - __allow_unmapped__ = True + __tablename__ = "sdchildchild" + child_id = Column(Integer, ForeignKey("sdchild.id"), nullable=False) + child: Mapped["SDChild"] = relationship("SDChild", back_populates="child_children") # type: ignore[assignment] - child_id = Column(Integer, ForeignKey(f"{SDChild.__tablename__}.id"), nullable=False) - child: SDChild = relationship("SDChild", back_populates="child_children") - - def __repr__(self): + def __repr__(self) -> str: pid = f"(child_id={self.child_id})" left = f"{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}" return f"<{left:30} {pid:>15}>" -class SDBaseRequest( - TestModelBase, - SoftDeleteMixin, -): +class SDBaseRequest(TestModelBase, SoftDeleteMixin): + __tablename__ = "sdbaserequest" request_type = Column(String(50)) - base_field = Column(Integer) __mapper_args__ = { "polymorphic_identity": "sdbaserequest", - "polymorphic_on": request_type, + "polymorphic_on": "request_type", } class SDDerivedRequest(SDBaseRequest): - id: Integer = Column(Integer, ForeignKey("sdbaserequest.id"), primary_key=True) - + __tablename__ = "sdderivedrequest" + id = Column(Integer, ForeignKey("sdbaserequest.id"), primary_key=True) derived_field = Column(Integer) __mapper_args__ = { @@ -105,8 +97,9 @@ class SDDerivedRequest(SDBaseRequest): class SDTableThatShouldNotBeSoftDeleted(TestModelBase): - id: Integer = Column(Integer, primary_key=True) - deleted_at: datetime = Column(DateTime(timezone=True)) + __tablename__ = "sdtablethatshouldnotbesoftdeleted" + id = Column(Integer, primary_key=True) + deleted_at = Column(DateTime(timezone=True)) - def __repr__(self): - return f"<{self.__class__.__name__} id={self.id} name={self.name}>" + def __repr__(self) -> str: + return f"<{self.__class__.__name__} id={self.id}>" diff --git a/tests/default_config/seed_data/parent_child_childchild.py b/tests/default_config/seed_data/parent_child_childchild.py index 4235c54..247d727 100644 --- a/tests/default_config/seed_data/parent_child_childchild.py +++ b/tests/default_config/seed_data/parent_child_childchild.py @@ -1,20 +1,20 @@ -import datetime import random +from datetime import datetime, timedelta from sqlalchemy.orm import Session from tests.default_config.model import SDChild, SDChildChild, SDParent -TEST_EPOCH = datetime.datetime(year=1985, month=8, day=4) +TEST_EPOCH = datetime(year=1985, month=8, day=4) def pseudorandom_date(max_days: int = 3650) -> datetime: - return TEST_EPOCH + datetime.timedelta(days=random.randint(0, max_days)) + return TEST_EPOCH + timedelta(days=random.randint(0, max_days)) def generate_parent_child_object_hierarchy( s: Session, parent_id: int, min_children: int = 1, max_children: int = 5, parent_deleted: bool = False -): +) -> None: # Fix a seed in the RNG for deterministic outputs random.seed(parent_id) diff --git a/tests/default_config/test_type_hints.py b/tests/default_config/test_type_hints.py index 0275d4f..4c33ef8 100644 --- a/tests/default_config/test_type_hints.py +++ b/tests/default_config/test_type_hints.py @@ -141,7 +141,8 @@ def test_undelete_method_is_callable(seeded_session): """Verify that the undelete() method stub works correctly.""" # Get a deleted instance child = ( - seeded_session.query(SDChild) + seeded_session + .query(SDChild) .execution_options(include_deleted=True) .filter(SDChild.deleted_at.isnot(None)) .first() diff --git a/tests/disabled_methods/conftest.py b/tests/disabled_methods/conftest.py index c8d93d4..a59d520 100644 --- a/tests/disabled_methods/conftest.py +++ b/tests/disabled_methods/conftest.py @@ -1,15 +1,19 @@ +from typing import cast + import pytest +from sqlalchemy.engine import Connection from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from tests.disabled_methods.model import DMModelBase, DMSoftDeleteMixin @pytest.fixture -def db_session(db_connection) -> Session: - DMModelBase.metadata.create_all(db_connection) +def db_session(db_connection: Connection) -> Session: + DMModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() @pytest.fixture -def rewriter(): - return DMSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, DMSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/disabled_methods/model.py b/tests/disabled_methods/model.py index 3f1b699..92ee1c8 100644 --- a/tests/disabled_methods/model.py +++ b/tests/disabled_methods/model.py @@ -1,9 +1,7 @@ -"""Models for testing disabled methods option.""" - from datetime import datetime from sqlalchemy import Column, Integer -from sqlalchemy.orm import Mapped, as_declarative, declared_attr +from sqlalchemy.orm import Mapped, as_declarative from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -12,15 +10,11 @@ class DMModelBase: """DM = Disabled Methods""" - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - id = Column(Integer, primary_key=True, autoincrement=True) class DMSoftDeleteMixin( - generate_soft_delete_mixin_class( + generate_soft_delete_mixin_class( # type: ignore[misc] generate_delete_method=False, generate_undelete_method=False, ) @@ -29,4 +23,5 @@ class DMSoftDeleteMixin( class DMTable(DMModelBase, DMSoftDeleteMixin): + __tablename__ = "dmtable" value = Column(Integer) diff --git a/tests/disabled_methods/test_disabled_methods.py b/tests/disabled_methods/test_disabled_methods.py index d9fc7cb..81c9fe0 100644 --- a/tests/disabled_methods/test_disabled_methods.py +++ b/tests/disabled_methods/test_disabled_methods.py @@ -26,12 +26,7 @@ def test_can_manually_set_deleted_at(db_session): db_session.commit() # After commit, the object is soft-deleted so query with include_deleted - result = ( - db_session.query(DMTable) - .execution_options(include_deleted=True) - .filter(DMTable.id == obj_id) - .first() - ) + result = db_session.query(DMTable).execution_options(include_deleted=True).filter(DMTable.id == obj_id).first() assert result is not None assert result.deleted_at is not None diff --git a/tests/integer_field_type/conftest.py b/tests/integer_field_type/conftest.py index 6232310..289b349 100644 --- a/tests/integer_field_type/conftest.py +++ b/tests/integer_field_type/conftest.py @@ -1,15 +1,19 @@ +from typing import cast + import pytest +from sqlalchemy.engine import Connection from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter from tests.integer_field_type.model import IFTModelBase, IFTSoftDeleteMixin @pytest.fixture -def db_session(db_connection) -> Session: - IFTModelBase.metadata.create_all(db_connection) +def db_session(db_connection: Connection) -> Session: + IFTModelBase.metadata.create_all(db_connection) # type: ignore[attr-defined] return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)() @pytest.fixture -def rewriter(): - return IFTSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter +def rewriter() -> SoftDeleteQueryRewriter: + return cast(SoftDeleteQueryRewriter, IFTSoftDeleteMixin._sqlalchemy_easy_softdelete_rewriter) diff --git a/tests/integer_field_type/model.py b/tests/integer_field_type/model.py index db22ef1..cec1b74 100644 --- a/tests/integer_field_type/model.py +++ b/tests/integer_field_type/model.py @@ -1,9 +1,7 @@ -"""Models for testing integer field type option.""" - from datetime import datetime, timezone from sqlalchemy import Column, Integer -from sqlalchemy.orm import Mapped, as_declarative, declared_attr +from sqlalchemy.orm import Mapped, as_declarative from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class @@ -12,15 +10,11 @@ class IFTModelBase: """IFT = Integer Field Type""" - @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() - id = Column(Integer, primary_key=True, autoincrement=True) class IFTSoftDeleteMixin( - generate_soft_delete_mixin_class( + generate_soft_delete_mixin_class( # type: ignore[misc] deleted_field_type=Integer(), delete_method_default_value=lambda: int(datetime.now(timezone.utc).timestamp()), ) @@ -29,4 +23,5 @@ class IFTSoftDeleteMixin( class IFTTable(IFTModelBase, IFTSoftDeleteMixin): + __tablename__ = "ifttable" value = Column(Integer) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index de13e63..5b18f73 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -1,11 +1,13 @@ -from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, Null +from typing import Any + +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, ColumnElement, Null from sqlalchemy.sql.schema import Table from sqlalchemy.sql.selectable import Select from tests.utils.simple_select_extractor import extract_simple_selects -def extract_binary_expressions_from_where(whereclause) -> tuple[BinaryExpression]: +def extract_binary_expressions_from_where(whereclause: Any) -> tuple[ColumnElement[Any], ...]: if isinstance(whereclause, BinaryExpression): return (whereclause,) @@ -14,16 +16,16 @@ def extract_binary_expressions_from_where(whereclause) -> tuple[BinaryExpression # Make sure we only have BinaryExpressions assert all(isinstance(c, BinaryExpression) for c in clauses) - return tuple(whereclause.clauses) + return clauses raise NotImplementedError(f'Unsupported whereclause type "{(type(whereclause))}"!') -def is_soft_delete_filter(b: BinaryExpression, tables: list[Table], deleted_field: str): +def is_soft_delete_filter(b: BinaryExpression[Any], tables: set[Table], deleted_field: str) -> bool: return b.left.table in tables and b.left.name == deleted_field and isinstance(b.right, Null) -def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table], deleted_field: str) -> bool: +def is_simple_select_doing_soft_delete_filtering(stmt: Select[Any], tables: set[Table], deleted_field: str) -> bool: # Check if query is disabled for soft-deletion opts = stmt.get_execution_options() if opts and opts.get("include_deleted"): @@ -37,9 +39,11 @@ def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table binary_expressions = extract_binary_expressions_from_where(stmt.whereclause) - found_tables = set() + found_tables: set[Table] = set() for binary_expression in binary_expressions: - if is_soft_delete_filter(binary_expression, tables, deleted_field): + if isinstance(binary_expression, BinaryExpression) and is_soft_delete_filter( + binary_expression, tables, deleted_field + ): found_tables.add(binary_expression.left.table) if found_tables == tables: @@ -48,7 +52,7 @@ def is_simple_select_doing_soft_delete_filtering(stmt: Select, tables: set[Table return False -def is_filtering_for_softdeleted(statement: Select, tables: set[Table], deleted_field: str = "deleted_at") -> bool: +def is_filtering_for_softdeleted(statement: Select[Any], tables: set[Table], deleted_field: str = "deleted_at") -> bool: selects = extract_simple_selects(statement) # Make sure all extracted selects are doing soft-delete filtering diff --git a/tests/utils/simple_select_extractor.py b/tests/utils/simple_select_extractor.py index a397f5a..2663f4a 100644 --- a/tests/utils/simple_select_extractor.py +++ b/tests/utils/simple_select_extractor.py @@ -6,9 +6,11 @@ from __future__ import annotations +from typing import Any + from sqlalchemy.orm.util import _ORMJoin from sqlalchemy.sql.schema import Table -from sqlalchemy.sql.selectable import CompoundSelect, Join, Select, SelectBase, Subquery +from sqlalchemy.sql.selectable import CompoundSelect, Join, Select, Subquery def is_simple_join(j: Join | _ORMJoin) -> bool: @@ -27,7 +29,7 @@ def is_simple_join(j: Join | _ORMJoin) -> bool: return left_simple and right_simple -def is_simple_select(s: Select | Subquery | CompoundSelect) -> bool: +def is_simple_select(s: Select[Any] | Subquery | CompoundSelect[Any]) -> bool: if isinstance(s, CompoundSelect): return False @@ -53,20 +55,20 @@ def is_simple_select(s: Select | Subquery | CompoundSelect) -> bool: return True -def extract_simple_selects(statement: Select | CompoundSelect | SelectBase) -> list[SelectBase]: +def extract_simple_selects(statement: Select[Any] | CompoundSelect[Any]) -> list[Select[Any]]: if is_simple_select(statement): - return [statement] + return [statement] # type: ignore[list-item] # We know it's a Select here if isinstance(statement, CompoundSelect): - extracted_elements = [] + extracted_elements: list[Select[Any]] = [] for select in statement.selects: - extracted_elements.extend(extract_simple_selects(select)) + extracted_elements.extend(extract_simple_selects(select)) # type: ignore[arg-type] return extracted_elements for from_obj in statement.get_final_froms(): if isinstance(from_obj, Table): continue elif isinstance(from_obj, Subquery): - return extract_simple_selects(from_obj.element) + return extract_simple_selects(from_obj.element) # type: ignore[arg-type] # element is Select raise NotImplementedError(f'Should not reach this point! statement.froms -> "{statement.froms}"!')