diff --git a/doc/release_notes.rst b/doc/release_notes.rst index bae61140..b2ff60de 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -79,6 +79,7 @@ Most users should keep calling ``model.solve(...)``. If you want more control, y * New ``ConstraintLabelIndex`` cached on ``Model.constraints`` (mirrors the existing ``Variables.label_index``); ``ConstraintBase`` gains ``active_labels()`` and a ``range`` property; ``CSRConstraint`` exposes ``coords``. * ``linopy.common`` gains ``values_to_lookup_array``; the legacy pandas-based helpers ``series_to_lookup_array`` and ``lookup_vals`` are removed. ``model.to_gurobipy()`` / ``model.to_highspy()`` / ``to_cupdlpx(model)`` (and similar) all return the underlying solver model as before; internally they now go through ``Solver.from_model(model, io_api="direct")``. No user-visible change. +* Adopt Python 3.11 type-syntax: the status enums (``ModelStatus``, ``SolverStatus``, ``TerminationCondition``) are now ``StrEnum``, and classmethods plus the expression base class use ``Self`` instead of string forward-references and a self-typed ``TypeVar``. No user-visible change — ``Model.solve()`` still returns ``(status, termination_condition)`` as plain strings. Version 0.7.0 ------------- diff --git a/linopy/constants.py b/linopy/constants.py index a1f4fb76..0e971827 100644 --- a/linopy/constants.py +++ b/linopy/constants.py @@ -5,8 +5,8 @@ import logging from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Literal, TypeAlias, Union, get_args +from enum import StrEnum +from typing import Any, Literal, Self, TypeAlias, get_args import numpy as np @@ -119,7 +119,7 @@ class EvolvingAPIWarning(FutureWarning): """ -class ModelStatus(Enum): +class ModelStatus(StrEnum): """ Model status. @@ -135,7 +135,7 @@ class ModelStatus(Enum): initialized = "initialized" -class SolverStatus(Enum): +class SolverStatus(StrEnum): """ Solver status. """ @@ -147,7 +147,7 @@ class SolverStatus(Enum): unknown = "unknown" @classmethod - def process(cls, status: str) -> "SolverStatus": + def process(cls, status: str) -> Self: try: return cls(status) except ValueError: @@ -163,7 +163,7 @@ def from_termination_condition( return cls("unknown") -class TerminationCondition(Enum): +class TerminationCondition(StrEnum): """ Termination condition of the solver. """ @@ -195,9 +195,7 @@ class TerminationCondition(Enum): licensing_problems = "licensing_problems" @classmethod - def process( - cls, termination_condition: Union["TerminationCondition", str] - ) -> "TerminationCondition": + def process(cls, termination_condition: Self | str) -> Self: if isinstance(termination_condition, TerminationCondition): termination_condition = termination_condition.value try: @@ -245,7 +243,7 @@ class Status: legacy_status: tuple[str, str] | str = "" @classmethod - def process(cls, status: str, termination_condition: str) -> "Status": + def process(cls, status: str, termination_condition: str) -> Self: return cls( status=SolverStatus.process(status), termination_condition=TerminationCondition.process(termination_condition), @@ -254,8 +252,8 @@ def process(cls, status: str, termination_condition: str) -> "Status": @classmethod def from_termination_condition( - cls, termination_condition: Union["TerminationCondition", str, None] - ) -> "Status": + cls, termination_condition: TerminationCondition | str | None + ) -> Self: termination_condition = TerminationCondition.process( termination_condition if termination_condition is not None else "unknown" ) diff --git a/linopy/expressions.py b/linopy/expressions.py index 96bc1adc..b0515ea2 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -14,7 +14,7 @@ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Self, TypeAlias, TypeVar, cast, overload from warnings import warn import numpy as np @@ -462,40 +462,30 @@ def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None print(self) @abstractmethod - def __add__( - self: GenericExpression, other: SideLike - ) -> GenericExpression | QuadraticExpression: ... + def __add__(self, other: SideLike) -> Self | QuadraticExpression: ... @abstractmethod - def __radd__(self: GenericExpression, other: SideLike) -> GenericExpression: ... + def __radd__(self, other: SideLike) -> Self: ... @abstractmethod - def __sub__( - self: GenericExpression, other: SideLike - ) -> GenericExpression | QuadraticExpression: ... + def __sub__(self, other: SideLike) -> Self | QuadraticExpression: ... @abstractmethod - def __rsub__(self: GenericExpression, other: SideLike) -> GenericExpression: ... + def __rsub__(self, other: SideLike) -> Self: ... @abstractmethod - def __mul__( - self: GenericExpression, other: SideLike - ) -> GenericExpression | QuadraticExpression: ... + def __mul__(self, other: SideLike) -> Self | QuadraticExpression: ... @abstractmethod - def __rmul__( - self: GenericExpression, other: SideLike - ) -> GenericExpression | QuadraticExpression: ... + def __rmul__(self, other: SideLike) -> Self | QuadraticExpression: ... @abstractmethod - def __matmul__( - self: GenericExpression, other: SideLike - ) -> GenericExpression | QuadraticExpression: ... + def __matmul__(self, other: SideLike) -> Self | QuadraticExpression: ... @abstractmethod def __pow__(self, other: int) -> QuadraticExpression: ... - def __neg__(self: GenericExpression) -> GenericExpression: + def __neg__(self) -> Self: """ Get the negative of the expression. """ @@ -529,7 +519,7 @@ def _multiply_by_linear_expression( return cast(QuadraticExpression, res) def _align_constant( - self: GenericExpression, + self, other: DataArray, fill_value: float = 0, join: JoinOptions | None = None, @@ -575,8 +565,8 @@ def _align_constant( return self_const, aligned, True def _add_constant( - self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None - ) -> GenericExpression: + self, other: ConstantLike, join: JoinOptions | None = None + ) -> Self: # NaN values in self.const or other are filled with 0 (additive identity) # so that missing data does not silently propagate through arithmetic. if np.isscalar(other) and join is None: @@ -599,12 +589,12 @@ def _add_constant( return self.assign(const=self_const + da) def _apply_constant_op( - self: GenericExpression, + self, other: ConstantLike, op: Callable[[DataArray, DataArray], DataArray], fill_value: float, join: JoinOptions | None = None, - ) -> GenericExpression: + ) -> Self: """ Apply a constant operation (mul, div, etc.) to this expression with a scalar or array. @@ -633,16 +623,16 @@ def _apply_constant_op( return self.assign(coeffs=op(coeffs, factor), const=op(self_const, factor)) def _multiply_by_constant( - self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None - ) -> GenericExpression: + self, other: ConstantLike, join: JoinOptions | None = None + ) -> Self: return self._apply_constant_op(other, operator.mul, fill_value=0, join=join) def _divide_by_constant( - self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None - ) -> GenericExpression: + self, other: ConstantLike, join: JoinOptions | None = None + ) -> Self: return self._apply_constant_op(other, operator.truediv, fill_value=1, join=join) - def __div__(self: GenericExpression, other: SideLike) -> GenericExpression: + def __div__(self, other: SideLike) -> Self: try: if isinstance(other, SUPPORTED_EXPRESSION_TYPES): raise TypeError( @@ -654,7 +644,7 @@ def __div__(self: GenericExpression, other: SideLike) -> GenericExpression: except TypeError: return NotImplemented - def __truediv__(self: GenericExpression, other: SideLike) -> GenericExpression: + def __truediv__(self, other: SideLike) -> Self: return self.__div__(other) def __le__(self, rhs: SideLike) -> Constraint: @@ -677,10 +667,10 @@ def __lt__(self, other: Any) -> NotImplementedType: ) def add( - self: GenericExpression, + self, other: SideLike, join: JoinOptions | None = None, - ) -> GenericExpression | QuadraticExpression: + ) -> Self | QuadraticExpression: """ Add an expression to others. @@ -705,10 +695,10 @@ def add( return merge([self, other], cls=self.__class__, join=join) def sub( - self: GenericExpression, + self, other: SideLike, join: JoinOptions | None = None, - ) -> GenericExpression | QuadraticExpression: + ) -> Self | QuadraticExpression: """ Subtract others from expression. @@ -724,10 +714,10 @@ def sub( return self.add(-other, join=join) def mul( - self: GenericExpression, + self, other: SideLike, join: JoinOptions | None = None, - ) -> GenericExpression | QuadraticExpression: + ) -> Self | QuadraticExpression: """ Multiply the expr by a factor. @@ -749,10 +739,10 @@ def mul( return self._multiply_by_constant(other, join=join) def div( - self: GenericExpression, + self, other: VariableLike | ConstantLike, join: JoinOptions | None = None, - ) -> GenericExpression | QuadraticExpression: + ) -> Self | QuadraticExpression: """ Divide the expr by a factor. @@ -776,7 +766,7 @@ def div( return self._divide_by_constant(other, join=join) def le( - self: GenericExpression, + self, rhs: SideLike, join: JoinOptions | None = None, ) -> Constraint: @@ -838,17 +828,13 @@ def pow(self, other: int) -> QuadraticExpression: """ return self.__pow__(other) - def dot( - self: GenericExpression, other: ndarray - ) -> GenericExpression | QuadraticExpression: + def dot(self, other: ndarray) -> Self | QuadraticExpression: """ Matrix multiplication with other, similar to xarray dot. """ return self.__matmul__(other) - def __getitem__( - self: GenericExpression, selector: int | tuple[slice, list[int]] | slice - ) -> GenericExpression: + def __getitem__(self, selector: int | tuple[slice, list[int]] | slice) -> Self: """ Get selection from the expression. This is a wrapper around the xarray __getitem__ method. It returns a @@ -987,11 +973,11 @@ def solution(self) -> DataArray: return sol.rename("solution") def sum( - self: GenericExpression, + self, dim: DimsLike | None = None, drop_zeros: bool = False, **kwargs: Any, - ) -> GenericExpression: + ) -> Self: """ Sum the expression over all or a subset of dimensions. @@ -1141,7 +1127,7 @@ def to_constraint( ) return constraints.Constraint(data, model=self.model) - def reset_const(self: GenericExpression) -> GenericExpression: + def reset_const(self) -> Self: """ Reset the constant of the linear expression to zero. """ @@ -1159,7 +1145,7 @@ def isnull(self) -> DataArray: return (self.vars == -1).all(helper_dims) & self.const.isnull() def where( - self: GenericExpression, + self, cond: DataArray, other: LinearExpression | int @@ -1167,7 +1153,7 @@ def where( | dict[str, float | int | DataArray] | None = None, **kwargs: Any, - ) -> GenericExpression: + ) -> Self: """ Filter variables based on a condition. @@ -1211,14 +1197,14 @@ def where( return self.__class__(self.data.where(cond, other=_other, **kwargs), self.model) def fillna( - self: GenericExpression, + self, value: int | float | DataArray | Dataset | LinearExpression | dict[str, float | int | DataArray], - ) -> GenericExpression: + ) -> Self: """ Fill missing values with a given value. @@ -1242,7 +1228,7 @@ def fillna( value = {"const": value} return self.__class__(self.data.fillna(value), self.model) - def diff(self: GenericExpression, dim: str, n: int = 1) -> GenericExpression: + def diff(self, dim: str, n: int = 1) -> Self: """ Calculate the n-th order discrete difference along given axis. @@ -1414,7 +1400,7 @@ def empty(self) -> EmptyDeprecationWrapper: """ return EmptyDeprecationWrapper(not self.size) - def densify_terms(self: GenericExpression) -> GenericExpression: + def densify_terms(self) -> Self: """ Move all non-zero term entries to the front and cut off all-zero entries in the term-axis. @@ -1445,7 +1431,7 @@ def densify_terms(self: GenericExpression) -> GenericExpression: return self.__class__(data.sel({TERM_DIM: slice(0, nterm)}), self.model) - def sanitize(self: GenericExpression) -> GenericExpression: + def sanitize(self) -> Self: """ Sanitize LinearExpression by ensuring int dtype for variables.