diff --git a/ad/__init__.py b/ad/__init__.py index dfbda4c..11ff70a 100644 --- a/ad/__init__.py +++ b/ad/__init__.py @@ -782,6 +782,118 @@ def __complex__(self) -> complex: def __bool__(self) -> bool: return bool(self.x) + def __array_ufunc__( # noqa: PLW3201 + self, ufunc: object, method: str, *inputs: object, **kwargs: object + ) -> object: + """Support NumPy ufunc dispatch (e.g. ``numpy.sin(x)``) on ADF objects. + + Returns + ------- + object + Result of the corresponding ``admath`` function, element-wise + result for arithmetic ufuncs, or ``NotImplemented`` if the + ufunc has no mapping. + """ + if method != "__call__": + return NotImplemented + + import operator # noqa: PLC0415 + + import numpy as np # noqa: PLC0415 + + import ad.admath as adm # noqa: PLC0415 + + math_ufunc_map = { + np.sin: adm.sin, + np.cos: adm.cos, + np.tan: adm.tan, + np.arcsin: adm.asin, + np.arccos: adm.acos, + np.arctan: adm.atan, + np.arctan2: adm.atan2, + np.sinh: adm.sinh, + np.cosh: adm.cosh, + np.tanh: adm.tanh, + np.arcsinh: adm.asinh, + np.arccosh: adm.acosh, + np.arctanh: adm.atanh, + np.exp: adm.exp, + np.expm1: adm.expm1, + np.log: adm.log, + np.log10: adm.log10, + np.log1p: adm.log1p, + np.sqrt: adm.sqrt, + np.ceil: adm.ceil, + np.floor: adm.floor, + np.trunc: adm.trunc, + np.hypot: adm.hypot, + np.degrees: adm.degrees, + np.radians: adm.radians, + np.power: adm.power, + } + + adm_func = math_ufunc_map.get(ufunc) # type: ignore[arg-type] + if adm_func is not None: + result = adm_func(*inputs) + out = kwargs.get("out") + if out is not None: + out[0][...] = result + return result + + arith_ufunc_map = { + np.multiply: operator.mul, + np.add: operator.add, + np.subtract: operator.sub, + np.true_divide: operator.truediv, + np.floor_divide: operator.floordiv, + np.remainder: operator.mod, + np.negative: operator.neg, + np.positive: operator.pos, + np.absolute: abs, + np.fabs: abs, + } + + arith_op = arith_ufunc_map.get(ufunc) # type: ignore[arg-type] + if arith_op is None: + return NotImplemented + + def _to_py(v: object) -> object: + return v.item() if isinstance(v, np.generic) else v # type: ignore[union-attr] + + def _apply(op: object, a: object, b: object = None) -> object: # type: ignore[assignment] + if b is None: + if isinstance(a, np.ndarray): + return np.array( # type: ignore[call-overload] + [op(_to_py(xi)) for xi in a.flat], # type: ignore[operator] + dtype=object, + ).reshape(a.shape) + return op(_to_py(a)) # type: ignore[operator] + if isinstance(a, np.ndarray) and not isinstance(b, np.ndarray): + return np.array( # type: ignore[call-overload] + [op(_to_py(xi), b) for xi in a.flat], # type: ignore[operator] + dtype=object, + ).reshape(a.shape) + if not isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.array( # type: ignore[call-overload] + [op(a, _to_py(xi)) for xi in b.flat], # type: ignore[operator] + dtype=object, + ).reshape(b.shape) + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.array( # type: ignore[call-overload] + [ + op(_to_py(ai), _to_py(bi)) # type: ignore[operator] + for ai, bi in zip(a.flat, b.flat, strict=False) + ], + dtype=object, + ).reshape(a.shape) + return op(_to_py(a), _to_py(b)) # type: ignore[operator] + + result = _apply(arith_op, *inputs) + out = kwargs.get("out") + if out is not None: + out[0][...] = result + return result + class ADV(ADF): """ diff --git a/pyproject.toml b/pyproject.toml index c565c42..e21d36b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,9 @@ select = [ ignore = [] +[tool.ruff.lint.pylint] +max-public-methods = 31 + [tool.ruff.lint.per-file-ignores] "**/__init__.py" = ["RUF067"] "tests/*.py" = []