diff --git a/manim/camera/moving_camera.py b/manim/camera/moving_camera.py index af1fbb22a9..3bb61120d2 100644 --- a/manim/camera/moving_camera.py +++ b/manim/camera/moving_camera.py @@ -10,7 +10,7 @@ __all__ = ["MovingCamera"] from collections.abc import Iterable -from typing import Any +from typing import Any, Literal, overload from cairo import Context @@ -20,7 +20,7 @@ from ..camera.camera import Camera from ..constants import DOWN, LEFT, RIGHT, UP from ..mobject.frame import ScreenRectangle -from ..mobject.mobject import Mobject +from ..mobject.mobject import Mobject, _AnimationBuilder from ..utils.color import WHITE, ManimColor @@ -166,13 +166,31 @@ def get_mobjects_indicating_movement(self) -> list[Mobject]: """ return [self.frame] + @overload + def auto_zoom( + self, + mobjects: Iterable[Mobject], + margin: float, + only_mobjects_in_frame: bool, + animate: Literal[False], + ) -> Mobject: ... + + @overload + def auto_zoom( + self, + mobjects: Iterable[Mobject], + margin: float, + only_mobjects_in_frame: bool, + animate: Literal[True], + ) -> _AnimationBuilder: ... + def auto_zoom( self, mobjects: Iterable[Mobject], margin: float = 0, only_mobjects_in_frame: bool = False, animate: bool = True, - ) -> Mobject: + ) -> _AnimationBuilder | Mobject: """Zooms on to a given array of mobjects (or a singular mobject) and automatically resizes to frame all the mobjects. diff --git a/manim/mobject/geometry/line.py b/manim/mobject/geometry/line.py index c55fd94f6e..2cd7aff807 100644 --- a/manim/mobject/geometry/line.py +++ b/manim/mobject/geometry/line.py @@ -14,14 +14,14 @@ "RightAngle", ] -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np from manim import config from manim.constants import * from manim.mobject.geometry.arc import Arc, ArcBetweenPoints, Dot, TipableVMobject -from manim.mobject.geometry.tips import ArrowTriangleFilledTip +from manim.mobject.geometry.tips import ArrowTip, ArrowTriangleFilledTip from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.opengl.opengl_mobject import OpenGLMobject @@ -648,9 +648,11 @@ def scale(self, factor: float, scale_tips: bool = False, **kwargs: Any) -> Self: self._set_stroke_width_from_length() if has_tip: - self.add_tip(tip=old_tips[0]) + # error: Argument "tip" to "add_tip" of "TipableVMobject" has incompatible type "VMobject"; expected "ArrowTip | None" [arg-type] + self.add_tip(tip=cast(ArrowTip, old_tips[0])) if has_start_tip: - self.add_tip(tip=old_tips[1], at_start=True) + # error: Argument "tip" to "add_tip" of "TipableVMobject" has incompatible type "VMobject"; expected "ArrowTip | None" [arg-type] + self.add_tip(tip=cast(ArrowTip, old_tips[1]), at_start=True) return self def get_normal_vector(self) -> Vector3D: diff --git a/manim/mobject/graphing/probability.py b/manim/mobject/graphing/probability.py index 4fe0e1f393..4f5d882575 100644 --- a/manim/mobject/graphing/probability.py +++ b/manim/mobject/graphing/probability.py @@ -207,13 +207,11 @@ def add_braces_and_labels(self) -> None: if hasattr(parts, subattr): self.add(getattr(parts, subattr)) - def __getitem__(self, index: int) -> SampleSpace: + def __getitem__(self, index: int) -> VMobject: if hasattr(self, "horizontal_parts"): - val: SampleSpace = self.horizontal_parts[index] - return val + return self.horizontal_parts[index] elif hasattr(self, "vertical_parts"): - val = self.vertical_parts[index] - return val + return self.vertical_parts[index] return self.split()[index] diff --git a/manim/mobject/matrix.py b/manim/mobject/matrix.py index 5a965cc655..73fc7b4e53 100644 --- a/manim/mobject/matrix.py +++ b/manim/mobject/matrix.py @@ -40,15 +40,15 @@ def construct(self): import itertools as it -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable from typing import Any, Self import numpy as np -from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.text.numbers import DecimalNumber, Integer from manim.mobject.text.tex_mobject import MathTex, Tex +from manim.typing import Vector2DLike, Vector3DLike from ..constants import * from ..mobject.types.vectorized_mobject import VGroup, VMobject @@ -164,16 +164,16 @@ def construct(self): def __init__( self, - matrix: Iterable, + matrix: Iterable[Iterable[Any] | Vector2DLike], v_buff: float = 0.8, h_buff: float = 1.3, bracket_h_buff: float = MED_SMALL_BUFF, bracket_v_buff: float = MED_SMALL_BUFF, add_background_rectangles_to_entries: bool = False, include_background_rectangle: bool = False, - element_to_mobject: type[Mobject] | Callable[..., Mobject] = MathTex, - element_to_mobject_config: dict = {}, - element_alignment_corner: Sequence[float] = DR, + element_to_mobject: type[VMobject] | Callable[..., VMobject] = MathTex, + element_to_mobject_config: dict[str, Any] = {}, + element_alignment_corner: Vector3DLike = DR, left_bracket: str = "[", right_bracket: str = "]", stretch_brackets: bool = True, @@ -206,7 +206,9 @@ def __init__( if self.include_background_rectangle: self.add_background_rectangle() - def _matrix_to_mob_matrix(self, matrix: np.ndarray) -> list[list[Mobject]]: + def _matrix_to_mob_matrix( + self, matrix: Iterable[Iterable[Any]] + ) -> list[list[VMobject]]: return [ [ self.element_to_mobject(item, **self.element_to_mobject_config) @@ -215,7 +217,7 @@ def _matrix_to_mob_matrix(self, matrix: np.ndarray) -> list[list[Mobject]]: for row in matrix ] - def _organize_mob_matrix(self, matrix: list[list[Mobject]]) -> Self: + def _organize_mob_matrix(self, matrix: list[list[VMobject]]) -> Self: for i, row in enumerate(matrix): for j, _ in enumerate(row): mob = matrix[i][j] @@ -401,7 +403,7 @@ def add_background_to_entries(self) -> Self: mob.add_background_rectangle() return self - def get_mob_matrix(self) -> list[list[Mobject]]: + def get_mob_matrix(self) -> list[list[VMobject]]: """Return the underlying mob matrix mobjects. Returns @@ -483,8 +485,8 @@ def construct(self): def __init__( self, - matrix: Iterable, - element_to_mobject: type[Mobject] = DecimalNumber, + matrix: Iterable[Iterable[Any]], + element_to_mobject: type[VMobject] | Callable[..., VMobject] = DecimalNumber, element_to_mobject_config: dict[str, Any] = {"num_decimal_places": 1}, **kwargs: Any, ): @@ -528,8 +530,8 @@ def construct(self): def __init__( self, - matrix: Iterable, - element_to_mobject: type[Mobject] = Integer, + matrix: Iterable[Iterable[Any]], + element_to_mobject: type[VMobject] | Callable[..., VMobject] = Integer, **kwargs: Any, ): """ @@ -566,8 +568,8 @@ def construct(self): def __init__( self, - matrix: Iterable, - element_to_mobject: type[Mobject] | Callable[..., Mobject] = lambda m: m, + matrix: Iterable[Iterable[Any]], + element_to_mobject: type[VMobject] | Callable[..., VMobject] = lambda m: m, **kwargs: Any, ): super().__init__(matrix, element_to_mobject=element_to_mobject, **kwargs) diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 14738d22a7..9f3818c188 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -14,10 +14,10 @@ import sys import types import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Iterator, Sequence from functools import partialmethod, reduce from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -43,24 +43,30 @@ if TYPE_CHECKING: from typing import Self, TypeAlias + from PIL import Image + + from manim.mobject.types.point_cloud_mobject import Point from manim.typing import ( FunctionOverride, MappingFunction, MatrixMN, MultiMappingFunction, PathFuncType, - PixelArray, Point3D, + Point3D_Array, Point3DLike, Point3DLike_Array, + Vector3D, Vector3DLike, ) from ..animation.animation import Animation + from ..camera.camera import Camera + - TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] - NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] - Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater +_TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] +_NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] +_Updater: TypeAlias = _NonTimeBasedUpdater | _TimeBasedUpdater class Mobject: @@ -83,16 +89,18 @@ class Mobject: """ - animation_overrides = {} + original_id: str + _original__init__: Callable[..., None] + animation_overrides: dict[ + type[Animation], + FunctionOverride, + ] = {} @classmethod - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - cls.animation_overrides: dict[ - type[Animation], - FunctionOverride, - ] = {} + cls.animation_overrides = {} cls._add_intrinsic_animation_overrides() cls._original__init__ = cls.__init__ @@ -101,16 +109,16 @@ def __init__( color: ParsableManimColor | list[ParsableManimColor] = WHITE, name: str | None = None, dim: int = 3, - target=None, + target: Mobject | None = None, z_index: float = 0, - ) -> None: + ): self.name = self.__class__.__name__ if name is None else name self.dim = dim self.target = target self.z_index = z_index self.point_hash = None - self.submobjects = [] - self.updaters: list[Updater] = [] + self.submobjects: list[Mobject] = [] + self.updaters: list[_Updater] = [] self.updating_suspended = False self.color = ManimColor.parse(color) @@ -151,7 +159,7 @@ def _assert_valid_submobjects(self, submobjects: Iterable[Mobject]) -> Self: return self._assert_valid_submobjects_internal(submobjects, Mobject) def _assert_valid_submobjects_internal( - self, submobjects: list[Mobject], mob_class: type[Mobject] + self, submobjects: Iterable[Mobject], mob_class: type[Mobject] ) -> Self: for i, submob in enumerate(submobjects): if not isinstance(submob, mob_class): @@ -247,7 +255,7 @@ def add_animation_override( ) @classmethod - def set_default(cls, **kwargs) -> None: + def set_default(cls, **kwargs: Any) -> None: """Sets the default values of keyword arguments. If this method is called without any additional keyword @@ -290,8 +298,11 @@ def construct(self): """ if kwargs: - cls.__init__ = partialmethod(cls.__init__, **kwargs) + # Apparently mypy does not correctly understand `partialmethod`: + # see https://github.com/python/mypy/issues/8619 + cls.__init__ = partialmethod(cls.__init__, **kwargs) # type: ignore[assignment] else: + # error: Cannot assign to a method [method-assign] cls.__init__ = cls._original__init__ @property @@ -430,7 +441,7 @@ def construct(self): # can't use typing.cast because Self is under TYPE_CHECKING return _UpdaterBuilder(self) # type: ignore[return-value] - def __deepcopy__(self, clone_from_id) -> Self: + def __deepcopy__(self, clone_from_id: dict[int, Mobject]) -> Self: cls = self.__class__ result = cls.__new__(cls) clone_from_id[id(self)] = result @@ -447,7 +458,7 @@ def reset_points(self) -> Self: self.points = np.zeros((0, self.dim)) return self - def init_colors(self) -> object: + def init_colors(self, propagate_colors: bool = True) -> object: """Initializes the colors. Gets called upon creation. This is an empty method that can be implemented by @@ -567,10 +578,10 @@ def insert(self, index: int, mobject: Mobject) -> None: self._assert_valid_submobjects([mobject]) self.submobjects.insert(index, mobject) - def __add__(self, mobject: Mobject): + def __add__(self, mobject: Mobject) -> Self: raise NotImplementedError - def __iadd__(self, mobject: Mobject): + def __iadd__(self, mobject: Mobject) -> Self: raise NotImplementedError def add_to_back(self, *mobjects: Mobject) -> Self: @@ -650,13 +661,13 @@ def remove(self, *mobjects: Mobject) -> Self: self.submobjects.remove(mobject) return self - def __sub__(self, other): + def __sub__(self, other: Mobject) -> Self: raise NotImplementedError - def __isub__(self, other): + def __isub__(self, other: Mobject) -> Self: raise NotImplementedError - def set(self, **kwargs) -> Self: + def set(self, **kwargs: Any) -> Self: """Sets attributes. I.e. ``my_mobject.set(foo=1)`` applies ``my_mobject.foo = 1``. @@ -724,7 +735,7 @@ def __getattr__(self, attr: str) -> types.MethodType: # Remove the "get_" prefix to_get = attr[4:] - def getter(self): + def getter(self: Mobject) -> Any: warnings.warn( "This method is not guaranteed to stay around. Please prefer " "getting the attribute normally.", @@ -741,7 +752,7 @@ def getter(self): # Remove the "set_" prefix to_set = attr[4:] - def setter(self, value): + def setter(self: Mobject, value: Any) -> Mobject: warnings.warn( "This method is not guaranteed to stay around. Please prefer " "setting the attribute normally or with Mobject.set().", @@ -792,7 +803,7 @@ def construct(self): return self.length_over_dim(0) @width.setter - def width(self, value: float): + def width(self, value: float) -> None: self.scale_to_fit_width(value) @property @@ -828,7 +839,7 @@ def construct(self): return self.length_over_dim(1) @height.setter - def height(self, value: float): + def height(self, value: float) -> None: self.scale_to_fit_height(value) @property @@ -848,7 +859,7 @@ def depth(self) -> float: return self.length_over_dim(2) @depth.setter - def depth(self, value: float): + def depth(self, value: float) -> None: self.scale_to_fit_depth(value) # Can't be staticmethod because of point_cloud_mobject.py @@ -861,16 +872,13 @@ def apply_over_attr_arrays(self, func: MultiMappingFunction) -> Self: return self # Displaying - - def get_image(self, camera=None) -> PixelArray: + def get_image(self, camera: Camera | None = None) -> Image.Image: if camera is None: - from ..camera.camera import Camera - camera = Camera() camera.capture_mobject(self) return camera.get_image() - def show(self, camera=None) -> None: + def show(self, camera: Camera | None = None) -> None: self.get_image(camera=camera).show() def save_image(self, name: str | None = None) -> None: @@ -930,18 +938,21 @@ def update(self, dt: float = 0, recursive: bool = True) -> Self: :meth:`get_updaters` """ - if not self.updating_suspended: - for updater in self.updaters: - if "dt" in inspect.signature(updater).parameters: - updater(self, dt) - else: - updater(self) + if self.updating_suspended: + return self + for updater in self.updaters: + if "dt" in inspect.signature(updater).parameters: + time_based_updater = cast(_TimeBasedUpdater, updater) + time_based_updater(self, dt) + else: + non_time_based_updater = cast(_NonTimeBasedUpdater, updater) + non_time_based_updater(self) if recursive: for submob in self.submobjects: submob.update(dt, recursive=recursive) return self - def get_time_based_updaters(self) -> list[TimeBasedUpdater]: + def get_time_based_updaters(self) -> list[_TimeBasedUpdater]: """Return all updaters using the ``dt`` parameter. The updaters use this parameter as the input for difference in time. @@ -957,11 +968,12 @@ def get_time_based_updaters(self) -> list[TimeBasedUpdater]: :meth:`has_time_based_updater` """ - return [ - updater - for updater in self.updaters - if "dt" in inspect.signature(updater).parameters - ] + rv: list[_TimeBasedUpdater] = [] + for updater in self.updaters: + if "dt" in inspect.signature(updater).parameters: + time_based_updater = cast(_TimeBasedUpdater, updater) + rv.append(time_based_updater) + return rv def has_time_based_updater(self) -> bool: """Test if ``self`` has a time based updater. @@ -981,7 +993,7 @@ def has_time_based_updater(self) -> bool: "dt" in inspect.signature(updater).parameters for updater in self.updaters ) - def get_updaters(self) -> list[Updater]: + def get_updaters(self) -> list[_Updater]: """Return all updaters. Returns @@ -997,12 +1009,12 @@ def get_updaters(self) -> list[Updater]: """ return self.updaters - def get_family_updaters(self) -> list[Updater]: + def get_family_updaters(self) -> list[_Updater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) def add_updater( self, - update_function: Updater, + update_function: _Updater, index: int | None = None, call_updater: bool = False, ) -> Self: @@ -1076,12 +1088,15 @@ def construct(self): if call_updater: parameters = inspect.signature(update_function).parameters if "dt" in parameters: - update_function(self, 0) + time_based_updater = cast(_TimeBasedUpdater, update_function) + time_based_updater(self, 0) else: - update_function(self) + non_time_based_updater = cast(_NonTimeBasedUpdater, update_function) + non_time_based_updater(self) + return self - def remove_updater(self, update_function: Updater) -> Self: + def remove_updater(self, update_function: _Updater) -> Self: """Remove an updater. If the same updater is applied multiple times, every instance gets removed. @@ -1328,6 +1343,7 @@ def rotate( *, about_point: Point3DLike | None = None, about_edge: Vector3DLike | None = None, + **kwargs: Any, ) -> Self: """Rotates the :class:`~.Mobject` around a specified axis and point. @@ -1516,10 +1532,10 @@ def construct(self): self.play(t.animate.set_value(TAU), run_time=3) """ - def R3_func(point): + def R3_func(point: Point3D) -> Point3D: x, y, z = point xy_complex = function(complex(x, y)) - return [xy_complex.real, xy_complex.imag, z] + return np.array([xy_complex.real, xy_complex.imag, z]) return self.apply_function( R3_func, about_point=about_point, about_edge=about_edge @@ -1533,7 +1549,7 @@ def reverse_points(self) -> Self: def repeat(self, count: int) -> Self: """This can make transition animations nicer""" - def repeat_array(array): + def repeat_array(array: Point3D_Array) -> Point3D_Array: return reduce(lambda a1, a2: np.append(a1, a2, axis=0), [array] * count) for mob in self.family_members_with_points(): @@ -1563,7 +1579,7 @@ def apply_points_function_about_point( mob.points += about_point return self - def pose_at_angle(self, **kwargs): + def pose_at_angle(self, **kwargs: Any) -> Self: self.rotate(TAU / 14, RIGHT + UP, **kwargs) return self @@ -1709,7 +1725,7 @@ def construct(self): self.shift((target_point - point_to_align + buff * np_direction) * coor_mask) return self - def shift_onto_screen(self, **kwargs) -> Self: + def shift_onto_screen(self, **kwargs: Any) -> Self: space_lengths = [config["frame_x_radius"], config["frame_y_radius"]] for vect in UP, DOWN, LEFT, RIGHT: dim = np.argmax(np.abs(vect)) @@ -1720,20 +1736,21 @@ def shift_onto_screen(self, **kwargs) -> Self: self.to_edge(vect, **kwargs) return self - def is_off_screen(self): + def is_off_screen(self) -> bool: if self.get_left()[0] > config["frame_x_radius"]: return True if self.get_right()[0] < -config["frame_x_radius"]: return True if self.get_bottom()[1] > config["frame_y_radius"]: return True - return self.get_top()[1] < -config["frame_y_radius"] + rv: bool = self.get_top()[1] < -config["frame_y_radius"] + return rv def stretch_about_point(self, factor: float, dim: int, point: Point3DLike) -> Self: return self.stretch(factor, dim, about_point=point) def rescale_to_fit( - self, length: float, dim: int, stretch: bool = False, **kwargs + self, length: float, dim: int, stretch: bool = False, **kwargs: Any ) -> Self: old_length = self.length_over_dim(dim) if old_length == 0: @@ -1744,7 +1761,7 @@ def rescale_to_fit( self.scale(length / old_length, **kwargs) return self - def scale_to_fit_width(self, width: float, **kwargs) -> Self: + def scale_to_fit_width(self, width: float, **kwargs: Any) -> Self: """Scales the :class:`~.Mobject` to fit a width while keeping height/depth proportional. Returns @@ -1769,7 +1786,7 @@ def scale_to_fit_width(self, width: float, **kwargs) -> Self: """ return self.rescale_to_fit(width, 0, stretch=False, **kwargs) - def stretch_to_fit_width(self, width: float, **kwargs) -> Self: + def stretch_to_fit_width(self, width: float, **kwargs: Any) -> Self: """Stretches the :class:`~.Mobject` to fit a width, not keeping height/depth proportional. Returns @@ -1794,7 +1811,7 @@ def stretch_to_fit_width(self, width: float, **kwargs) -> Self: """ return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def scale_to_fit_height(self, height: float, **kwargs) -> Self: + def scale_to_fit_height(self, height: float, **kwargs: Any) -> Self: """Scales the :class:`~.Mobject` to fit a height while keeping width/depth proportional. Returns @@ -1819,7 +1836,7 @@ def scale_to_fit_height(self, height: float, **kwargs) -> Self: """ return self.rescale_to_fit(height, 1, stretch=False, **kwargs) - def stretch_to_fit_height(self, height: float, **kwargs) -> Self: + def stretch_to_fit_height(self, height: float, **kwargs: Any) -> Self: """Stretches the :class:`~.Mobject` to fit a height, not keeping width/depth proportional. Returns @@ -1844,15 +1861,17 @@ def stretch_to_fit_height(self, height: float, **kwargs) -> Self: """ return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def scale_to_fit_depth(self, depth: float, **kwargs) -> Self: + def scale_to_fit_depth(self, depth: float, **kwargs: Any) -> Self: """Scales the :class:`~.Mobject` to fit a depth while keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=False, **kwargs) - def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self: + def stretch_to_fit_depth(self, depth: float, **kwargs: Any) -> Self: """Stretches the :class:`~.Mobject` to fit a depth, not keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=True, **kwargs) - def set_coord(self, value, dim: int, direction: Vector3DLike = ORIGIN) -> Self: + def set_coord( + self, value: float, dim: int, direction: Vector3DLike = ORIGIN + ) -> Self: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr @@ -1871,7 +1890,7 @@ def set_z(self, z: float, direction: Vector3DLike = ORIGIN) -> Self: """Set z value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: + def space_out_submobjects(self, factor: float = 1.5, **kwargs: Any) -> Self: self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1.0 / factor) @@ -1950,7 +1969,10 @@ def put_start_and_end_on(self, start: Point3DLike, end: Point3DLike) -> Self: # Background rectangle def add_background_rectangle( - self, color: ParsableManimColor | None = None, opacity: float = 0.75, **kwargs + self, + color: ParsableManimColor | None = None, + opacity: float = 0.75, + **kwargs: Any, ) -> Self: """Add a BackgroundRectangle as submobject. @@ -1989,12 +2011,14 @@ def add_background_rectangle( self.add_to_back(self.background_rectangle) return self - def add_background_rectangle_to_submobjects(self, **kwargs) -> Self: + def add_background_rectangle_to_submobjects(self, **kwargs: Any) -> Self: for submobject in self.submobjects: submobject.add_background_rectangle(**kwargs) return self - def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Self: + def add_background_rectangle_to_family_members_with_points( + self, **kwargs: Any + ) -> Self: for mob in self.family_members_with_points(): mob.add_background_rectangle(**kwargs) return self @@ -2002,7 +2026,10 @@ def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Se # Color functions def set_color( - self, color: ParsableManimColor = PURE_YELLOW, family: bool = True + self, + color: ParsableManimColor = PURE_YELLOW, + alpha: Any = None, + family: bool = True, ) -> Self: """Condition is function which takes in one arguments, (x, y, z). Here it just recurses to submobjects, but in subclasses this @@ -2044,7 +2071,7 @@ def set_colors_by_radial_gradient( ) return self - def set_submobject_colors_by_gradient(self, *colors: Iterable[ParsableManimColor]): + def set_submobject_colors_by_gradient(self, *colors: ParsableManimColor) -> Self: if len(colors) == 0: raise ValueError("Need at least one color") elif len(colors) == 1: @@ -2070,7 +2097,9 @@ def set_submobject_colors_by_radial_gradient( for mob in self.family_members_with_points(): t = np.linalg.norm(mob.get_center() - center) / radius t = min(t, 1) - mob_color = interpolate_color(inner_color, outer_color, t) + mob_color = interpolate_color( + ManimColor(inner_color), ManimColor(outer_color), t + ) mob.set_color(mob_color, family=False) return self @@ -2083,7 +2112,7 @@ def fade_to( self, color: ParsableManimColor, alpha: float, family: bool = True ) -> Self: if self.get_num_points() > 0: - new_color = interpolate_color(self.get_color(), color, alpha) + new_color = interpolate_color(self.get_color(), ManimColor(color), alpha) self.set_color(new_color, family=False) if family: for submob in self.submobjects: @@ -2123,12 +2152,14 @@ def save_state(self) -> Self: def restore(self) -> Self: """Restores the state that was previously saved with :meth:`~.Mobject.save_state`.""" - if not hasattr(self, "saved_state") or self.save_state is None: + if not hasattr(self, "saved_state") or self.saved_state is None: raise Exception("Trying to restore without having saved") self.become(self.saved_state) return self - def reduce_across_dimension(self, reduce_func: Callable, dim: int): + def reduce_across_dimension( + self, reduce_func: Callable[[Iterable[float]], float], dim: int + ) -> float: """Find the min or max value from a dimension across all points in this and submobjects.""" assert dim >= 0 assert dim <= 2 @@ -2148,9 +2179,10 @@ def reduce_across_dimension(self, reduce_func: Callable, dim: int): for mobj in self.submobjects: value = mobj.reduce_across_dimension(reduce_func, dim) rv = value if rv is None else reduce_func([value, rv]) + assert rv is not None return rv - def nonempty_submobjects(self) -> list[Self]: + def nonempty_submobjects(self) -> Sequence[Mobject]: return [ submob for submob in self.submobjects @@ -2194,11 +2226,14 @@ def get_extremum_along_dim( ) values = np_points[:, dim] if key < 0: - return np.min(values) + rv: float = np.min(values) + return rv elif key == 0: - return (np.min(values) + np.max(values)) / 2 + rv = (np.min(values) + np.max(values)) / 2 + return rv else: - return np.max(values) + rv = np.max(values) + return rv def get_critical_point(self, direction: Vector3DLike) -> Point3D: """Picture a box bounding the :class:`~.Mobject`. Such a box has @@ -2223,7 +2258,7 @@ def get_critical_point(self, direction: Vector3DLike) -> Point3D: result[dim] = self.get_extremum_along_dim( all_points, dim=dim, - key=direction[dim], + key=np.array(direction[dim]), ) return result @@ -2298,14 +2333,16 @@ def get_nadir(self) -> Point3D: def length_over_dim(self, dim: int) -> float: """Measure the length of an :class:`~.Mobject` in a certain direction.""" - return self.reduce_across_dimension( + max_coord: float = self.reduce_across_dimension( max, dim, - ) - self.reduce_across_dimension(min, dim) + ) + min_coord: float = self.reduce_across_dimension(min, dim) + return max_coord - min_coord def get_coord(self, dim: int, direction: Vector3DLike = ORIGIN) -> float: """Meant to generalize ``get_x``, ``get_y`` and ``get_z``""" - return self.get_extremum_along_dim(dim=dim, key=direction[dim]) + return self.get_extremum_along_dim(dim=dim, key=np.array(direction)[dim]) def get_x(self, direction: Vector3DLike = ORIGIN) -> float: """Returns x Point3D of the center of the :class:`~.Mobject` as ``float``""" @@ -2369,19 +2406,19 @@ def match_color(self, mobject: Mobject) -> Self: """Match the color with the color of another :class:`~.Mobject`.""" return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject: Mobject, dim: int, **kwargs) -> Self: + def match_dim_size(self, mobject: Mobject, dim: int, **kwargs: Any) -> Self: """Match the specified dimension with the dimension of another :class:`~.Mobject`.""" return self.rescale_to_fit(mobject.length_over_dim(dim), dim, **kwargs) - def match_width(self, mobject: Mobject, **kwargs) -> Self: + def match_width(self, mobject: Mobject, **kwargs: Any) -> Self: """Match the width with the width of another :class:`~.Mobject`.""" return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject: Mobject, **kwargs) -> Self: + def match_height(self, mobject: Mobject, **kwargs: Any) -> Self: """Match the height with the height of another :class:`~.Mobject`.""" return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject: Mobject, **kwargs) -> Self: + def match_depth(self, mobject: Mobject, **kwargs: Any) -> Self: """Match the depth with the depth of another :class:`~.Mobject`.""" return self.match_dim_size(mobject, 2, **kwargs) @@ -2395,15 +2432,15 @@ def match_coord( direction=direction, ) - def match_x(self, mobject: Mobject, direction=ORIGIN) -> Self: + def match_x(self, mobject: Mobject, direction: Vector3DLike = ORIGIN) -> Self: """Match x coord. to the x coord. of another :class:`~.Mobject`.""" return self.match_coord(mobject, 0, direction) - def match_y(self, mobject: Mobject, direction=ORIGIN) -> Self: + def match_y(self, mobject: Mobject, direction: Vector3DLike = ORIGIN) -> Self: """Match y coord. to the x coord. of another :class:`~.Mobject`.""" return self.match_coord(mobject, 1, direction) - def match_z(self, mobject: Mobject, direction=ORIGIN) -> Self: + def match_z(self, mobject: Mobject, direction: Vector3DLike = ORIGIN) -> Self: """Match z coord. to the x coord. of another :class:`~.Mobject`.""" return self.match_coord(mobject, 2, direction) @@ -2430,14 +2467,15 @@ def align_to( # Family matters - def __getitem__(self, value): + def __getitem__(self, value: Any) -> Mobject | Group: self_list = self.split() if isinstance(value, slice): GroupClass = self.get_group_class() return GroupClass(*self_list.__getitem__(value)) - return self_list.__getitem__(value) + rv: Mobject | Group = self_list.__getitem__(value) + return rv - def __iter__(self): + def __iter__(self) -> Iterator[Mobject]: return iter(self.split()) def __len__(self) -> int: @@ -2451,11 +2489,11 @@ def get_mobject_type_class() -> type[Mobject]: """Return the base class of this mobject type.""" return Mobject - def split(self) -> list[Self]: - result = [self] if len(self.points) > 0 else [] + def split(self) -> list[Mobject]: + result: list[Mobject] = [self] if len(self.points) > 0 else [] return result + self.submobjects - def get_family(self, recurse: bool = True) -> list[Self]: + def get_family(self, recurse: bool = True) -> list[Mobject]: """Lists all mobjects in the hierarchy (family) of the given mobject, including the mobject itself and all its submobjects recursively. @@ -2489,7 +2527,7 @@ def get_family(self, recurse: bool = True) -> list[Self]: all_mobjects = [self] + list(it.chain(*sub_families)) return remove_list_redundancies(all_mobjects) - def family_members_with_points(self) -> list[Self]: + def family_members_with_points(self) -> list[Mobject]: """Filters the list of family members (generated by :meth:`.get_family`) to include only mobjects with points. Returns @@ -2520,7 +2558,7 @@ def arrange( direction: Vector3DLike = RIGHT, buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, center: bool = True, - **kwargs, + **kwargs: Any, ) -> Self: """Sorts :class:`~.Mobject` next to each other on screen. @@ -2556,7 +2594,7 @@ def arrange_in_grid( row_heights: Iterable[float | None] | None = None, col_widths: Iterable[float | None] | None = None, flow_order: str = "rd", - **kwargs, + **kwargs: Any, ) -> Self: """Arrange submobjects in a grid. @@ -2650,13 +2688,18 @@ def construct(self): start_pos = self.get_center() # get cols / rows values if given (implicitly) - def init_size(num, alignments, sizes): + def init_size( + num: int | None, + alignments: str | None, + sizes: Iterable[float | None] | None, + ) -> int | None: if num is not None: return num if alignments is not None: return len(alignments) if sizes is not None: - return len(sizes) + return len(list(sizes)) + return None cols = init_size(cols, col_alignments, col_widths) rows = init_size(rows, row_alignments, row_heights) @@ -2667,8 +2710,9 @@ def init_size(num, alignments, sizes): # make the grid as close to quadratic as possible. # choosing cols first can results in cols>rows. # This is favored over rows>cols since in general - # the sceene is wider than high. + # the scene is wider than high. if rows is None: + assert isinstance(cols, int) rows = math.ceil(len(mobs) / cols) if cols is None: cols = math.ceil(len(mobs) / rows) @@ -2683,25 +2727,29 @@ def init_size(num, alignments, sizes): buff_x = buff_y = buff # Initialize alignments correctly - def init_alignments(alignments, num, mapping, name, dir_): + def init_alignments( + alignments: str | None, + num: int, + char_to_direction: dict[str, Vector3D], + name: str, + dir_: Vector3D, + ) -> list[Vector3D]: if alignments is None: # Use cell_alignment as fallback return [cell_alignment * dir_] * num if len(alignments) != num: raise ValueError(f"{name}_alignments has a mismatching size.") - alignments = list(alignments) - for i in range(num): - alignments[i] = mapping[alignments[i]] - return alignments + alignment_directions = [char_to_direction[char] for char in alignments] + return alignment_directions - row_alignments = init_alignments( + row_alignment_directions = init_alignments( row_alignments, rows, {"u": UP, "c": ORIGIN, "d": DOWN}, "row", RIGHT, ) - col_alignments = init_alignments( + col_alignment_directions = init_alignments( col_alignments, cols, {"l": LEFT, "c": ORIGIN, "r": RIGHT}, @@ -2710,7 +2758,7 @@ def init_alignments(alignments, num, mapping, name, dir_): ) # Now row_alignment[r] + col_alignment[c] is the alignment in cell [r][c] - mapper = { + mapper: dict[str, Callable[[int, int], int]] = { "dr": lambda r, c: (rows - r - 1) + c * rows, "dl": lambda r, c: (rows - r - 1) + (cols - c - 1) * rows, "ur": lambda r, c: r + c * rows, @@ -2724,18 +2772,14 @@ def init_alignments(alignments, num, mapping, name, dir_): raise ValueError( 'flow_order must be one of the following values: "dr", "rd", "ld" "dl", "ru", "ur", "lu", "ul".', ) - flow_order = mapper[flow_order] + get_mob_index_by_position = mapper[flow_order] - # Reverse row_alignments and row_heights. Necessary since the + # Reverse row_alignment_directions and row_heights. Necessary since the # grid filling is handled bottom up for simplicity reasons. - def reverse(maybe_list): - if maybe_list is not None: - maybe_list = list(maybe_list) - maybe_list.reverse() - return maybe_list - - row_alignments = reverse(row_alignments) - row_heights = reverse(row_heights) + row_alignment_directions.reverse() + row_heights_list = list(row_heights) if row_heights is not None else [] + row_heights_list.reverse() + col_widths_list = list(col_widths) if col_widths is not None else [] placeholder = Mobject() # Used to fill up the grid temporarily, doesn't get added to the scene. @@ -2743,7 +2787,10 @@ def reverse(maybe_list): # properties of 0. mobs.extend([placeholder] * (rows * cols - len(mobs))) - grid = [[mobs[flow_order(r, c)] for c in range(cols)] for r in range(rows)] + grid = [ + [mobs[get_mob_index_by_position(r, c)] for c in range(cols)] + for r in range(rows) + ] measured_heigths = [ max(grid[r][c].height for c in range(cols)) for r in range(rows) @@ -2753,24 +2800,29 @@ def reverse(maybe_list): ] # Initialize row_heights / col_widths correctly using measurements as fallback - def init_sizes(sizes, num, measures, name): - if sizes is None: + def init_sizes( + sizes: list[float | None] | None, num: int, measures: list[float], name: str + ) -> list[float]: + if sizes is None or len(sizes) == 0: sizes = [None] * num if len(sizes) != num: raise ValueError(f"{name} has a mismatching size.") return [ - sizes[i] if sizes[i] is not None else measures[i] for i in range(num) + size if size is not None else measure + for size, measure in zip(sizes, measures, strict=False) ] - heights = init_sizes(row_heights, rows, measured_heigths, "row_heights") - widths = init_sizes(col_widths, cols, measured_widths, "col_widths") + heights = init_sizes(row_heights_list, rows, measured_heigths, "row_heights") + widths = init_sizes(col_widths_list, cols, measured_widths, "col_widths") - x, y = 0, 0 + x, y = 0.0, 0.0 for r in range(rows): x = 0 for c in range(cols): if grid[r][c] is not placeholder: - alignment = row_alignments[r] + col_alignments[c] + alignment = ( + row_alignment_directions[r] + col_alignment_directions[c] + ) line = Line( x * RIGHT + y * UP, (x + widths[c]) * RIGHT + (y + heights[r]) * UP, @@ -2834,7 +2886,7 @@ def construct(self): self.submobjects.reverse() # Just here to keep from breaking old scenes. - def arrange_submobjects(self, *args, **kwargs) -> Self: + def arrange_submobjects(self, *args: Any, **kwargs: Any) -> Self: """Arrange the position of :attr:`submobjects` with a small buffer. Examples @@ -2855,11 +2907,11 @@ def construct(self): """ return self.arrange(*args, **kwargs) - def sort_submobjects(self, *args, **kwargs) -> Self: + def sort_submobjects(self, *args: Any, **kwargs: Any) -> Self: """Sort the :attr:`submobjects`""" return self.sort(*args, **kwargs) - def shuffle_submobjects(self, *args, **kwargs) -> None: + def shuffle_submobjects(self, *args: Any, **kwargs: Any) -> None: """Shuffles the order of :attr:`submobjects` Examples @@ -2927,7 +2979,7 @@ def align_data(self, mobject: Mobject, skip_point_alignment: bool = False) -> No for m1, m2 in zip(self.submobjects, mobject.submobjects, strict=True): m1.align_data(m2) - def get_point_mobject(self, center=None): + def get_point_mobject(self, center: Point3DLike | None = None) -> Point: """The simplest :class:`~.Mobject` to be transformed to or from self. Should by a point of the appropriate type """ @@ -2943,7 +2995,7 @@ def align_points(self, mobject: Mobject) -> Self: mobject.align_points_with_larger(self) return self - def align_points_with_larger(self, larger_mobject: Mobject): + def align_points_with_larger(self, larger_mobject: Mobject) -> None: raise NotImplementedError("Please override in a child class.") def align_submobjects(self, mobject: Mobject) -> Self: @@ -2955,7 +3007,7 @@ def align_submobjects(self, mobject: Mobject) -> Self: mob2.add_n_more_submobjects(max(0, n1 - n2)) return self - def null_point_align(self, mobject: Mobject): + def null_point_align(self, mobject: Mobject) -> Self: """If a :class:`~.Mobject` with points is being aligned to one without, treat both as groups, and push the one with points into its own submobjects @@ -3000,7 +3052,7 @@ def add_n_more_submobjects(self, n: int) -> Self | None: self.submobjects = new_submobs return self - def repeat_submobject(self, submob: Mobject) -> Self: + def repeat_submobject(self, submob: Mobject) -> Mobject: return submob.copy() def interpolate( @@ -3076,7 +3128,9 @@ def construct(self): self.interpolate_color(mobject1, mobject2, alpha) return self - def interpolate_color(self, mobject1: Mobject, mobject2: Mobject, alpha: float): + def interpolate_color( + self, mobject1: Mobject, mobject2: Mobject, alpha: float + ) -> None: raise NotImplementedError("Please override in a child class.") def become( @@ -3307,25 +3361,25 @@ class Group(Mobject, metaclass=ConvertToOpenGL): be added to the group. """ - def __init__(self, *mobjects, **kwargs) -> None: + def __init__(self, *mobjects: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.add(*mobjects) class _AnimationBuilder: - def __init__(self, mobject) -> None: + def __init__(self, mobject: Mobject) -> None: self.mobject = mobject self.mobject.generate_target() - self.overridden_animation = None + self.overridden_animation: Animation | None = None self.is_chaining = False self.methods: list[MethodWithArgs] = [] # Whether animation args can be passed self.cannot_pass_args = False - self.anim_args = {} + self.anim_args: dict[str, Any] = {} - def __call__(self, **kwargs) -> Self: + def __call__(self, **kwargs: Any) -> Self: if self.cannot_pass_args: raise ValueError( "Animation arguments must be passed before accessing methods and can only be passed once", @@ -3336,7 +3390,7 @@ def __call__(self, **kwargs) -> Self: return self - def __getattr__(self, method_name) -> types.MethodType: + def __getattr__(self, method_name: str) -> Callable[..., _AnimationBuilder]: method = getattr(self.mobject.target, method_name) has_overridden_animation = hasattr(method, "_override_animate") @@ -3345,7 +3399,7 @@ def __getattr__(self, method_name) -> types.MethodType: "Method chaining is currently not supported for overridden animations", ) - def update_target(*method_args, **method_kwargs): + def update_target(*method_args: Any, **method_kwargs: Any) -> _AnimationBuilder: if has_overridden_animation: self.overridden_animation = method._override_animate( self.mobject, @@ -3364,9 +3418,7 @@ def update_target(*method_args, **method_kwargs): return update_target def build(self) -> Animation: - from ..animation.transform import ( # is this to prevent circular import? - _MethodAnimation, - ) + from ..animation.transform import _MethodAnimation anim = self.overridden_animation or _MethodAnimation(self.mobject, self.methods) @@ -3382,9 +3434,9 @@ class _UpdaterBuilder: def __init__(self, mobject: Mobject): self._mobject = mobject - def __getattr__(self, name: str, /) -> Callable[..., Self]: + def __getattr__(self, name: str, /) -> Callable[..., _UpdaterBuilder]: # just return a function that will add the updater - def add_updater(*method_args, **method_kwargs) -> Self: + def add_updater(*method_args: Any, **method_kwargs: Any) -> _UpdaterBuilder: self._mobject.add_updater( lambda m: getattr(m, name)(*method_args, **method_kwargs), call_updater=True, @@ -3394,7 +3446,9 @@ def add_updater(*method_args, **method_kwargs) -> Self: return add_updater -def override_animate(method) -> types.FunctionType: +def override_animate( + method: types.MethodType, +) -> Callable[[types.MethodType], types.MethodType]: r"""Decorator for overriding method animations. This allows to specify a method (returning an :class:`~.Animation`) @@ -3445,9 +3499,11 @@ def construct(self): self.wait() """ + temp_method = cast(_AnimationBuilder, method) - def decorator(animation_method): - method._override_animate = animation_method + def decorator(animation_method: types.MethodType) -> types.MethodType: + # error: "Callable[..., Animation]" has no attribute "_override_animate" [attr-defined] + temp_method._override_animate = animation_method # type: ignore[attr-defined] return animation_method return decorator diff --git a/manim/mobject/text/tex_mobject.py b/manim/mobject/text/tex_mobject.py index 8698e024b8..729fbb158b 100644 --- a/manim/mobject/text/tex_mobject.py +++ b/manim/mobject/text/tex_mobject.py @@ -587,7 +587,7 @@ def set_color_by_tex_to_color_map( self.id_to_vgroup_dict[match[1]].set_color(color) return self - def index_of_part(self, part: MathTex) -> int: + def index_of_part(self, part: VMobject) -> int: split_self = self.split() if part not in split_self: raise ValueError("Trying to get index of part not in MathTex") diff --git a/manim/mobject/text/text_mobject.py b/manim/mobject/text/text_mobject.py index 812a6b084f..d484420301 100644 --- a/manim/mobject/text/text_mobject.py +++ b/manim/mobject/text/text_mobject.py @@ -166,9 +166,12 @@ def __init__( lines_str_list = lines_str.split("\n") self.chars = self._gen_chars(lines_str_list) - self.lines = [list(self.chars), [self.alignment] * len(self.chars)] - self.lines_initial_positions = [line.get_center() for line in self.lines[0]] - self.add(*self.lines[0]) + # TODO: If possible get rid of self.lines_chars, as it seems to be a + # listified duplicate of self.chars. + self.lines_chars = list(self.chars) + self.lines_alignments = [self.alignment] * len(self.chars) + self.lines_initial_positions = [line.get_center() for line in self.lines_chars] + self.add(*self.lines_chars) self.move_to(np.array([0, 0, 0])) if self.alignment: self._set_all_lines_alignments(self.alignment) @@ -221,7 +224,7 @@ def _set_all_lines_alignments(self, alignment: str) -> Paragraph: alignment Defines the alignment of paragraph. Possible values are "left", "right", "center". """ - for line_no in range(len(self.lines[0])): + for line_no in range(len(self.lines_chars)): self._change_alignment_for_a_line(alignment, line_no) return self @@ -240,8 +243,8 @@ def _set_line_alignment(self, alignment: str, line_no: int) -> Paragraph: def _set_all_lines_to_initial_positions(self) -> Paragraph: """Set all lines to their initial positions.""" - self.lines[1] = [None] * len(self.lines[0]) - for line_no in range(len(self.lines[0])): + self.lines_alignments = [None] * len(self.lines_chars) + for line_no in range(len(self.lines_chars)): self[line_no].move_to( self.get_center() + self.lines_initial_positions[line_no], ) @@ -255,7 +258,7 @@ def _set_line_to_initial_position(self, line_no: int) -> Paragraph: line_no Defines the line number for which we want to set given alignment. """ - self.lines[1][line_no] = None + self.lines_alignments[line_no] = None self[line_no].move_to(self.get_center() + self.lines_initial_positions[line_no]) return self @@ -269,12 +272,12 @@ def _change_alignment_for_a_line(self, alignment: str, line_no: int) -> None: line_no Defines the line number for which we want to set given alignment. """ - self.lines[1][line_no] = alignment - if self.lines[1][line_no] == "center": + self.lines_alignments[line_no] = alignment + if self.lines_alignments[line_no] == "center": self[line_no].move_to( np.array([self.get_center()[0], self[line_no].get_center()[1], 0]), ) - elif self.lines[1][line_no] == "right": + elif self.lines_alignments[line_no] == "right": self[line_no].move_to( np.array( [ @@ -284,7 +287,7 @@ def _change_alignment_for_a_line(self, alignment: str, line_no: int) -> None: ], ), ) - elif self.lines[1][line_no] == "left": + elif self.lines_alignments[line_no] == "left": self[line_no].move_to( np.array( [ diff --git a/manim/mobject/three_d/three_dimensions.py b/manim/mobject/three_d/three_dimensions.py index c630c6464b..48297a67da 100644 --- a/manim/mobject/three_d/three_dimensions.py +++ b/manim/mobject/three_d/three_dimensions.py @@ -149,6 +149,7 @@ def __init__( self.pre_function_handle_to_anchor_scale_factor = ( pre_function_handle_to_anchor_scale_factor ) + self.list_of_faces: list[ThreeDVMobject] = [] self._func = func self._setup_in_uv_space() self.apply_function(lambda p: func(p[0], p[1])) @@ -172,6 +173,7 @@ def _get_u_values_and_v_values(self) -> tuple[np.ndarray, np.ndarray]: def _setup_in_uv_space(self) -> None: u_values, v_values = self._get_u_values_and_v_values() faces = VGroup() + self.list_of_faces = [] for i in range(len(u_values) - 1): for j in range(len(v_values) - 1): u1, u2 = u_values[i : i + 2] @@ -193,6 +195,7 @@ def _setup_in_uv_space(self) -> None: face.u2 = u2 face.v1 = v1 face.v2 = v2 + self.list_of_faces.append(face) faces.set_fill(color=self.fill_color, opacity=self.fill_opacity) faces.set_stroke( color=self.stroke_color, @@ -223,7 +226,7 @@ def set_fill_by_checkerboard( The parametric surface with an alternating pattern. """ n_colors = len(colors) - for face in self: + for face in self.list_of_faces: c_index = (face.u_index + face.v_index) % n_colors face.set_fill(colors[c_index], opacity=opacity) return self diff --git a/manim/mobject/types/image_mobject.py b/manim/mobject/types/image_mobject.py index f135f826d1..1e35c9c60a 100644 --- a/manim/mobject/types/image_mobject.py +++ b/manim/mobject/types/image_mobject.py @@ -68,7 +68,7 @@ def __init__( def get_pixel_array(self) -> PixelArray: raise NotImplementedError() - def set_color( # type: ignore[override] + def set_color( self, color: ParsableManimColor = YELLOW_C, alpha: Any = None, @@ -217,7 +217,7 @@ def get_pixel_array(self) -> PixelArray: """A simple getter method.""" return self.pixel_array - def set_color( # type: ignore[override] + def set_color( self, color: ParsableManimColor = YELLOW_C, alpha: Any = None, diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index f4d2c8b9dd..9aa02d05cd 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -47,6 +47,7 @@ from manim.utils.space_ops import rotate_vector, shoelace_direction if TYPE_CHECKING: + from collections.abc import Iterator from typing import Self import numpy.typing as npt @@ -103,6 +104,7 @@ class VMobject(Mobject): """ sheen_factor = 0.0 + target: VMobject def __init__( self, @@ -172,6 +174,9 @@ def __init__( def _assert_valid_submobjects(self, submobjects: Iterable[VMobject]) -> Self: return self._assert_valid_submobjects_internal(submobjects, VMobject) + def __iter__(self) -> Iterator[VMobject]: + return iter(self.split()) + # OpenGL compatibility @property def n_points_per_curve(self) -> int: @@ -630,6 +635,17 @@ def get_color(self) -> ManimColor: color: ManimColor = property(get_color, set_color) + def nonempty_submobjects(self) -> Sequence[VMobject]: + return [ + submob + for submob in self.submobjects + if len(submob.submobjects) != 0 or len(submob.points) != 0 + ] + + def split(self) -> list[VMobject]: + result: list[VMobject] = [self] if len(self.points) > 0 else [] + return result + self.submobjects + def set_sheen_direction(self, direction: Vector3DLike, family: bool = True) -> Self: """Sets the direction of the applied sheen. @@ -2303,6 +2319,11 @@ def __setitem__(self, key: int, value: VMobject | Sequence[VMobject]) -> None: self._assert_valid_submobjects(tuplify(value)) self.submobjects[key] = value + def __getitem__(self, key: int | slice) -> VMobject: + if isinstance(key, slice): + return VGroup(self.submobjects[key]) + return self.submobjects[key] + class VDict(VMobject, metaclass=ConvertToOpenGL): """A VGroup-like class, also offering submobject access by diff --git a/manim/scene/vector_space_scene.py b/manim/scene/vector_space_scene.py index ecb688ac2f..56afea2095 100644 --- a/manim/scene/vector_space_scene.py +++ b/manim/scene/vector_space_scene.py @@ -281,15 +281,19 @@ def get_basis_vector_labels(self, **kwargs: Any) -> VGroup: color (str), label_scale_factor=VECTOR_LABEL_SCALE_FACTOR (int, float), """ - i_hat, j_hat = self.get_basis_vectors() + i_hat = self.get_basis_vectors().submobjects[0] + j_hat = self.get_basis_vectors().submobjects[1] return VGroup( *( self.get_vector_label( vect, label, color=color, label_scale_factor=1, **kwargs ) for vect, label, color in [ - (i_hat, "\\hat{\\imath}", X_COLOR), - (j_hat, "\\hat{\\jmath}", Y_COLOR), + # Casting i_hat and j_hat to Vector, as the VGroup from + # self.get_basis_vectors() contains two vectors, but the + # type checker is currently not aware of that. + (cast(Vector, i_hat), "\\hat{\\imath}", X_COLOR), + (cast(Vector, j_hat), "\\hat{\\jmath}", Y_COLOR), ] ) ) @@ -517,7 +521,9 @@ def vector_to_coords( y_line = Line(x_line.get_end(), arrow.get_end()) x_line.set_color(X_COLOR) y_line.set_color(Y_COLOR) - x_coord, y_coord = cast(VGroup, array.get_entries()) + temp = array.get_entries() + x_coord = temp.submobjects[0] + y_coord = temp.submobjects[1] x_coord_start = self.position_x_coordinate(x_coord.copy(), x_line, vector) y_coord_start = self.position_y_coordinate(y_coord.copy(), y_line, vector) brackets = array.get_brackets() diff --git a/manim/utils/deprecation.py b/manim/utils/deprecation.py index 26f775c591..112c62fa5d 100644 --- a/manim/utils/deprecation.py +++ b/manim/utils/deprecation.py @@ -250,6 +250,9 @@ def deprecate(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: if type(func).__name__ != "function": deprecate_docs(func) + # The following line raises this mypy error: + # Accessing "__init__" on an instance is unsound, since instance.__init__ + # could be from an incompatible subclass [misc] func.__init__ = decorate(func.__init__, deprecate) return func diff --git a/mypy.ini b/mypy.ini index 4020e1f2e6..0ee928de4c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -85,9 +85,6 @@ ignore_errors = True [mypy-manim.mobject.logo] ignore_errors = True -[mypy-manim.mobject.mobject] -ignore_errors = True - [mypy-manim.mobject.opengl.opengl_point_cloud_mobject] ignore_errors = True @@ -100,6 +97,9 @@ ignore_errors = True [mypy-manim.mobject.table] ignore_errors = True +[mypy-manim.mobject.types.point_cloud_mobject] +ignore_errors = True + [mypy-manim.mobject.types.vectorized_mobject] ignore_errors = True