From 50a79af16d732516de12a74ce6211ae68d208b9d Mon Sep 17 00:00:00 2001 From: Darren Eberly Date: Thu, 12 Mar 2026 00:23:19 -0400 Subject: [PATCH 1/3] Rework HitBox to support multiple named collision regions Replaces the single-polygon HitBox with a multi-region system where each HitBox can contain named Point2List regions. This enables sprites to have distinct collision shapes (e.g. separate head/body/feet hitboxes) while maintaining backward compatibility with the single-region API. Key changes: - HitBox accepts either a Point2List (single "default" region) or a dict[str, Point2List] mapping region names to point lists - Merge RotatableHitBox into HitBox (removed RotatableHitBox) - Add serialization support (save/load to JSON, with gzip option) - Update collision detection, pymunk physics, sprite lists, and tilemap to work with multi-region hitboxes - Add sprite_multi_hitbox example - Fix type annotations for pyright/mypy compatibility Co-Authored-By: Claude Opus 4.6 --- arcade/examples/sprite_multi_hitbox.py | 200 +++++++++ arcade/gui/widgets/dropdown.py | 3 +- arcade/gui/widgets/layout.py | 2 +- arcade/hitbox/__init__.py | 6 +- arcade/hitbox/base.py | 396 +++++++++++------- arcade/pymunk_physics_engine.py | 88 ++-- arcade/sprite/base.py | 15 +- arcade/sprite/sprite.py | 26 +- arcade/sprite_list/collision.py | 16 +- arcade/sprite_list/sprite_list.py | 12 +- arcade/tilemap/tilemap.py | 6 +- .../unit/gui/test_layout_size_hint_warning.py | 1 + tests/unit/gui/test_uidropdown.py | 4 +- tests/unit/hitbox/test_hitbox.py | 177 +++++++- tests/unit/test_hexagon.py | 4 +- webplayground/local_scripts/example_test.py | 5 +- 16 files changed, 734 insertions(+), 227 deletions(-) create mode 100644 arcade/examples/sprite_multi_hitbox.py diff --git a/arcade/examples/sprite_multi_hitbox.py b/arcade/examples/sprite_multi_hitbox.py new file mode 100644 index 0000000000..47fa934d5c --- /dev/null +++ b/arcade/examples/sprite_multi_hitbox.py @@ -0,0 +1,200 @@ +""" +Sprite Multi-Region Hit Boxes + +Demonstrates sprites with multiple hit box regions. The player sprite +has a "body" region and a "shield" region extending to one side. +Coins that touch any region of the player are collected. + +Hit box outlines are drawn for visual debugging. Use A/D to rotate +the player and see both regions rotate together. + +If Python and Arcade are installed, this example can be run from the +command line with: +python -m arcade.examples.sprite_multi_hitbox +""" + +import random +import math +import arcade +from arcade.hitbox import HitBox + +WINDOW_WIDTH = 1280 +WINDOW_HEIGHT = 720 +WINDOW_TITLE = "Multi-Region Hit Box Example" + +PLAYER_SPEED = 5.0 +COIN_SPEED = 2.0 +COIN_COUNT = 20 + + +class GameView(arcade.View): + + def __init__(self): + super().__init__() + self.player_sprite = None + self.player_list = None + self.coin_list = None + self.score = 0 + self.score_display = None + self.background_color = arcade.csscolor.DARK_SLATE_GRAY + + self.left_pressed = False + self.right_pressed = False + self.up_pressed = False + self.down_pressed = False + + def setup(self): + self.player_list = arcade.SpriteList() + self.coin_list = arcade.SpriteList() + self.score = 0 + self.score_display = arcade.Text( + text="Score: 0", + x=10, y=WINDOW_HEIGHT - 30, + color=arcade.color.WHITE, font_size=16, + ) + + # Create the player sprite + img = ":resources:images/animated_characters/female_person/femalePerson_idle.png" + self.player_sprite = arcade.Sprite(img, scale=0.5) + self.player_sprite.position = WINDOW_WIDTH / 2, WINDOW_HEIGHT / 2 + + # Replace the default hitbox with a multi-region hitbox. + # "body" is a box around the torso, "shield" extends to the right. + # Collision detection automatically checks all regions. + self.player_sprite.hit_box = HitBox( + { + "body": [(-15, -48), (-15, 40), (15, 40), (15, -48)], + "shield": [(15, -30), (15, 30), (45, 30), (45, -30)], + }, + position=self.player_sprite.position, + scale=self.player_sprite.scale, + angle=self.player_sprite.angle, + ) + + self.player_list.append(self.player_sprite) + self._spawn_coins(COIN_COUNT) + + def _spawn_coins(self, count): + for _ in range(count): + coin = arcade.Sprite(":resources:images/items/coinGold.png", scale=0.4) + + # Spawn along a random edge + side = random.randint(0, 3) + if side == 0: + coin.center_x = random.randrange(WINDOW_WIDTH) + coin.center_y = WINDOW_HEIGHT + 20 + elif side == 1: + coin.center_x = random.randrange(WINDOW_WIDTH) + coin.center_y = -20 + elif side == 2: + coin.center_x = -20 + coin.center_y = random.randrange(WINDOW_HEIGHT) + else: + coin.center_x = WINDOW_WIDTH + 20 + coin.center_y = random.randrange(WINDOW_HEIGHT) + + # Aim toward the center with some randomness + target_x = WINDOW_WIDTH / 2 + random.randint(-200, 200) + target_y = WINDOW_HEIGHT / 2 + random.randint(-200, 200) + dx = target_x - coin.center_x + dy = target_y - coin.center_y + dist = math.hypot(dx, dy) + if dist > 0: + coin.change_x = (dx / dist) * COIN_SPEED + coin.change_y = (dy / dist) * COIN_SPEED + + self.coin_list.append(coin) + + def on_draw(self): + self.clear() + + self.coin_list.draw() + self.player_list.draw() + + # Debug: draw each hitbox region in a different color + player_hb = self.player_sprite.hit_box + for region_name in player_hb.region_names: + pts = player_hb.get_adjusted_points(region_name) + color = arcade.color.RED if region_name == "body" else arcade.color.CYAN + arcade.draw_line_strip(tuple(pts) + (pts[0],), color=color, line_width=2) + + self.coin_list.draw_hit_boxes(color=arcade.color.YELLOW, line_thickness=1) + self.score_display.draw() + + arcade.draw_text( + "Red = Body | Cyan = Shield | Arrow keys to move | A/D to rotate", + WINDOW_WIDTH / 2, 20, + arcade.color.WHITE, font_size=12, anchor_x="center", + ) + + def on_key_press(self, key, modifiers): + if key in (arcade.key.UP, arcade.key.W): + self.up_pressed = True + elif key in (arcade.key.DOWN, arcade.key.S): + self.down_pressed = True + elif key == arcade.key.LEFT: + self.left_pressed = True + elif key == arcade.key.RIGHT: + self.right_pressed = True + + def on_key_release(self, key, modifiers): + if key in (arcade.key.UP, arcade.key.W): + self.up_pressed = False + elif key in (arcade.key.DOWN, arcade.key.S): + self.down_pressed = False + elif key == arcade.key.LEFT: + self.left_pressed = False + elif key == arcade.key.RIGHT: + self.right_pressed = False + + def on_update(self, delta_time): + # Move the player + if self.up_pressed: + self.player_sprite.center_y += PLAYER_SPEED + if self.down_pressed: + self.player_sprite.center_y -= PLAYER_SPEED + if self.left_pressed: + self.player_sprite.center_x -= PLAYER_SPEED + if self.right_pressed: + self.player_sprite.center_x += PLAYER_SPEED + + # Rotate with A/D + keys = self.window.keyboard + if keys[arcade.key.A]: + self.player_sprite.angle -= 3.0 + if keys[arcade.key.D]: + self.player_sprite.angle += 3.0 + + # Move coins + self.coin_list.update() + + # Standard collision check — automatically tests all hitbox regions + hit_list = arcade.check_for_collision_with_list( + self.player_sprite, self.coin_list + ) + if hit_list: + for coin in hit_list: + coin.remove_from_sprite_lists() + self.score += 1 + self.score_display.text = f"Score: {self.score}" + + # Replace coins that left the screen + margin = 100 + for coin in list(self.coin_list): + if (coin.center_x < -margin or coin.center_x > WINDOW_WIDTH + margin + or coin.center_y < -margin or coin.center_y > WINDOW_HEIGHT + margin): + coin.remove_from_sprite_lists() + if len(self.coin_list) < COIN_COUNT: + self._spawn_coins(COIN_COUNT - len(self.coin_list)) + + +def main(): + window = arcade.Window(WINDOW_WIDTH, WINDOW_HEIGHT, WINDOW_TITLE) + game = GameView() + game.setup() + window.show_view(game) + arcade.run() + + +if __name__ == "__main__": + main() diff --git a/arcade/gui/widgets/dropdown.py b/arcade/gui/widgets/dropdown.py index 9cb59c2881..97802db1c6 100644 --- a/arcade/gui/widgets/dropdown.py +++ b/arcade/gui/widgets/dropdown.py @@ -294,8 +294,7 @@ def do_layout(self): overlay_w = self.width + scroll_bar_w overlay.rect = ( - overlay.rect - .resize(overlay_w, visible_h) + overlay.rect.resize(overlay_w, visible_h) .align_top(self.bottom - 2) .align_left(self._default_button.left) ) diff --git a/arcade/gui/widgets/layout.py b/arcade/gui/widgets/layout.py index ffd010689d..63f6396279 100644 --- a/arcade/gui/widgets/layout.py +++ b/arcade/gui/widgets/layout.py @@ -3,8 +3,8 @@ import warnings from collections.abc import Iterable from dataclasses import dataclass -from typing import Literal, TypeVar from types import EllipsisType +from typing import Literal, TypeVar from typing_extensions import override diff --git a/arcade/hitbox/__init__.py b/arcade/hitbox/__init__.py index 01086e7a9b..8bbf1c710b 100644 --- a/arcade/hitbox/__init__.py +++ b/arcade/hitbox/__init__.py @@ -2,7 +2,7 @@ from arcade.types import Point2List -from .base import HitBox, HitBoxAlgorithm, RotatableHitBox +from .base import HitBox, HitBoxAlgorithm from .bounding_box import BoundingHitBoxAlgorithm from .simple import SimpleHitBoxAlgorithm @@ -13,7 +13,8 @@ #: The detailed hit box algorithm. This depends on pymunk and will fallback to the simple algorithm. try: from .pymunk import PymunkHitBoxAlgorithm - algo_detailed = PymunkHitBoxAlgorithm() + + algo_detailed: HitBoxAlgorithm = PymunkHitBoxAlgorithm() except ImportError: print("WARNING: Running without PyMunk. The detailed hitbox algorithm will fallback to simple") algo_detailed = SimpleHitBoxAlgorithm() @@ -58,7 +59,6 @@ def calculate_hit_box_points_detailed( __all__ = [ "HitBoxAlgorithm", "HitBox", - "RotatableHitBox", "SimpleHitBoxAlgorithm", "PymunkHitBoxAlgorithm", "BoundingHitBoxAlgorithm", diff --git a/arcade/hitbox/base.py b/arcade/hitbox/base.py index 9225ebf5ab..c8b7c46ef5 100644 --- a/arcade/hitbox/base.py +++ b/arcade/hitbox/base.py @@ -1,6 +1,9 @@ from __future__ import annotations +import gzip +import json from math import cos, radians, sin +from pathlib import Path from typing import Any from PIL.Image import Image @@ -8,7 +11,7 @@ from arcade.types import EMPTY_POINT_LIST, Point2, Point2List -__all__ = ["HitBoxAlgorithm", "HitBox", "RotatableHitBox"] +__all__ = ["HitBoxAlgorithm", "HitBox"] class HitBoxAlgorithm: @@ -92,52 +95,134 @@ def create_bounding_box(self, image: Image) -> Point2List: class HitBox: """ - A basic hit box class supporting scaling. + A hit box with support for multiple named regions, scaling, and rotation. - It includes support for rescaling as well as shorthand properties - for boundary values along the X and Y axes. For rotation support, - use :py:meth:`.create_rotatable` to create an instance of - :py:class:`RotatableHitBox`. + Each region is a named polygon (sequence of points). A hitbox with a + single region can be constructed by passing a ``Point2List`` directly, + which creates a region named ``"default"``. For multiple regions, pass + a ``dict[str, Point2List]``. + + **Single-region construction** (backward compatible):: + + box = HitBox( + [(-10, -10), (10, -10), (10, 10), (-10, 10)] + ) + + **Multi-region construction** with a dict:: + + box = HitBox({ + "body": [(-10, -10), (10, -10), (10, 10), (-10, 10)], + "head": [(-5, 10), (5, 10), (5, 20), (-5, 20)], + }) + + **Rotation** (replaces the former ``RotatableHitBox`` class):: + + box = HitBox(points, angle=45.0) + # Angle can be updated later: + box.angle = 90.0 + + **Region management**:: + + box.add_region("shield", shield_points) + box.has_region("shield") # True + box.remove_region("shield") + + **Serialization** to and from JSON files:: + + box.save("hitbox.json") # plain JSON + box.save("hitbox.json.gz") # gzip-compressed + + loaded = HitBox.load("hitbox.json") + + # Dict round-trip + data = box.to_dict() + copy = HitBox.from_dict(data) Args: points: - The unmodified points bounding the hit box + Either a single ``Point2List`` (creates a ``"default"`` region) + or a ``dict[str, Point2List]`` mapping region names to point lists. position: - The center around which the points will be offset + The center around which the points will be offset. scale: - The X and Y scaling factors to use when offsetting the points + The X and Y scaling factors. + angle: + The rotation angle in degrees (clockwise). """ + DEFAULT_REGION = "default" + def __init__( self, - points: Point2List, + points: Point2List | dict[str, Point2List], position: Point2 = (0.0, 0.0), scale: Point2 = (1.0, 1.0), + angle: float = 0.0, ): - self._points = points + if isinstance(points, dict): + self._regions: dict[str, Point2List] = dict(points) + else: + self._regions = {self.DEFAULT_REGION: points} + self._position = position self._scale = scale + self._angle: float = angle + self._is_single_region: bool = len(self._regions) == 1 - # This empty tuple will be replaced the first time - # get_adjusted_points is called - self._adjusted_points: Point2List = EMPTY_POINT_LIST + # Cached adjusted points per region + self._adjusted_regions: dict[str, Point2List] = {} self._adjusted_cache_dirty = True @property def points(self) -> Point2List: """ - The raw, unadjusted points of this hit box. + The raw, unadjusted points of the default region. - These are the points as originally passed before offsetting, scaling, - and any operations subclasses may perform, such as rotation. + This is provided for backward compatibility. For multi-region + hitboxes, use :py:attr:`regions` instead. """ - return self._points + return self._regions.get(self.DEFAULT_REGION, EMPTY_POINT_LIST) @property - def position(self) -> Point2: + def regions(self) -> dict[str, Point2List]: + """All raw, unadjusted regions as a dict mapping names to point lists.""" + return self._regions + + @property + def region_names(self) -> tuple[str, ...]: + """The names of all regions in this hit box.""" + return tuple(self._regions.keys()) + + def has_region(self, name: str) -> bool: + """Check if a region with the given name exists.""" + return name in self._regions + + def add_region(self, name: str, points: Point2List) -> None: + """ + Add a named region to this hit box. + + Args: + name: The name for the new region. + points: The polygon points for the region. + """ + self._regions[name] = points + self._is_single_region = len(self._regions) == 1 + self._adjusted_cache_dirty = True + + def remove_region(self, name: str) -> None: """ - The center point used to offset the final adjusted positions. + Remove a named region from this hit box. + + Args: + name: The name of the region to remove. """ + del self._regions[name] + self._is_single_region = len(self._regions) == 1 + self._adjusted_cache_dirty = True + + @property + def position(self) -> Point2: + """The center point used to offset the final adjusted positions.""" return self._position @position.setter @@ -145,45 +230,59 @@ def position(self, position: Point2): self._position = position self._adjusted_cache_dirty = True - # Per Clepto's testing as of around May 2023, these are better - # left uncached because caching them is somehow slower than what - # we currently do. Any readers should feel free to retest / - # investigate further. + @property + def angle(self) -> float: + """The angle to rotate the raw points by in degrees.""" + return self._angle + + @angle.setter + def angle(self, angle: float): + self._angle = angle + self._adjusted_cache_dirty = True + @property def left(self) -> float: - """ - Calculates the leftmost adjusted x position of this hit box - """ - points = self.get_adjusted_points() - x_points = [point[0] for point in points] - return min(x_points) + """Calculates the leftmost adjusted x position across all regions.""" + self._recalculate_if_dirty() + min_x = float("inf") + for points in self._adjusted_regions.values(): + for point in points: + if point[0] < min_x: + min_x = point[0] + return min_x @property def right(self) -> float: - """ - Calculates the rightmost adjusted x position of this hit box - """ - points = self.get_adjusted_points() - x_points = [point[0] for point in points] - return max(x_points) + """Calculates the rightmost adjusted x position across all regions.""" + self._recalculate_if_dirty() + max_x = float("-inf") + for points in self._adjusted_regions.values(): + for point in points: + if point[0] > max_x: + max_x = point[0] + return max_x @property def top(self) -> float: - """ - Calculates the topmost adjusted y position of this hit box - """ - points = self.get_adjusted_points() - y_points = [point[1] for point in points] - return max(y_points) + """Calculates the topmost adjusted y position across all regions.""" + self._recalculate_if_dirty() + max_y = float("-inf") + for points in self._adjusted_regions.values(): + for point in points: + if point[1] > max_y: + max_y = point[1] + return max_y @property def bottom(self) -> float: - """ - Calculates the bottommost adjusted y position of this hit box - """ - points = self.get_adjusted_points() - y_points = [point[1] for point in points] - return min(y_points) + """Calculates the bottommost adjusted y position across all regions.""" + self._recalculate_if_dirty() + min_y = float("inf") + for points in self._adjusted_regions.values(): + for point in points: + if point[1] < min_y: + min_y = point[1] + return min_y @property def scale(self) -> tuple[float, float]: @@ -199,127 +298,138 @@ def scale(self, scale: tuple[float, float]): self._scale = scale self._adjusted_cache_dirty = True - def create_rotatable( - self, - angle: float = 0.0, - ) -> RotatableHitBox: - """ - Create a rotatable instance of this hit box. - - The internal ``PointList`` is transferred directly instead of - deep copied, so care should be taken if using a mutable internal - representation. - - Args: - angle: The angle to rotate points by (0 by default) - """ - return RotatableHitBox( - self._points, position=self._position, scale=self._scale, angle=angle - ) - - def get_adjusted_points(self) -> Point2List: - """ - Return the positions of points, scaled and offset from the center. - - Unlike the boundary helper properties (left, etc), this method will - only recalculate the values when necessary: - - * The first time this method is called - * After properties affecting adjusted position were changed - """ + def _recalculate_if_dirty(self) -> None: + """Recalculate all adjusted regions if the cache is dirty.""" if not self._adjusted_cache_dirty: - return self._adjusted_points # type: ignore + return - position_x, position_y = self._position + rad = radians(-self._angle) scale_x, scale_y = self._scale + position_x, position_y = self._position + rad_cos = cos(rad) + rad_sin = sin(rad) + do_rotate = bool(rad) - def _adjust_point(point) -> Point2: + def _adjust_point(point: Point2) -> Point2: x, y = point - x *= scale_x y *= scale_y + if do_rotate: + rot_x = x * rad_cos - y * rad_sin + rot_y = x * rad_sin + y * rad_cos + x = rot_x + y = rot_y + return (x + position_x, y + position_y) - self._adjusted_points = [_adjust_point(point) for point in self._points] + self._adjusted_regions = { + name: [_adjust_point(p) for p in pts] for name, pts in self._regions.items() + } self._adjusted_cache_dirty = False - return self._adjusted_points + def get_adjusted_points(self, region: str | None = None) -> Point2List: + """ + Return the positions of points, scaled, rotated, and offset. -class RotatableHitBox(HitBox): - """ - A hit box with support for rotation. + Args: + region: + The name of the region to get points for. If ``None``, + returns the default region's points (backward compatible). + """ + self._recalculate_if_dirty() + name = region if region is not None else self.DEFAULT_REGION + return self._adjusted_regions.get(name, EMPTY_POINT_LIST) - Rotation is separated from the basic hitbox because it is much - slower than offsetting and scaling. + def get_all_adjusted_polygons(self) -> list[Point2List]: + """ + Return adjusted points for all regions as a list of polygons. - Args: - points: - The unmodified points bounding the hit box - position: - The translation to apply to the points - angle: - The angle to rotate the points by - scale: - The X and Y scaling factors - """ + This is used by collision detection to check all regions. + """ + self._recalculate_if_dirty() + return list(self._adjusted_regions.values()) - def __init__( - self, - points: Point2List, - *, - position: tuple[float, float] = (0.0, 0.0), - angle: float = 0.0, - scale: Point2 = (1.0, 1.0), - ): - super().__init__(points, position=position, scale=scale) - self._angle: float = angle + # --- Serialization --- - @property - def angle(self) -> float: - """ - The angle to rotate the raw points by in degrees + def to_dict(self) -> dict: """ - return self._angle + Serialize the hitbox shape to a dictionary. - @angle.setter - def angle(self, angle: float): - self._angle = angle - self._adjusted_cache_dirty = True + Only the region definitions (point data) are serialized. + Position, scale, and angle are runtime state and are not included. + """ + return { + "version": 1, + "regions": {name: [list(p) for p in pts] for name, pts in self._regions.items()}, + } + + @classmethod + def from_dict( + cls, + data: dict, + position: Point2 = (0.0, 0.0), + scale: Point2 = (1.0, 1.0), + angle: float = 0.0, + ) -> HitBox: + """ + Create a HitBox from a serialized dictionary. - def get_adjusted_points(self) -> Point2List: + Args: + data: The dictionary to deserialize from. + position: The center offset. + scale: The scaling factors. + angle: The rotation angle in degrees. """ - Return the offset, scaled, & rotated points of this hitbox. + regions: dict[str, Point2List] = { + name: tuple(tuple(p) for p in pts) for name, pts in data["regions"].items() + } + return cls(points=regions, position=position, scale=scale, angle=angle) - As with :py:meth:`.HitBox.get_adjusted_points`, this method only - recalculates the adjusted values when necessary. + def save(self, path: str | Path) -> None: """ - if not self._adjusted_cache_dirty: - return self._adjusted_points + Save the hitbox shape definition to a JSON file. - rad = radians(-self._angle) - scale_x, scale_y = self._scale - position_x, position_y = self._position - rad_cos = cos(rad) - rad_sin = sin(rad) + If the path ends with ``.gz``, the file will be gzip-compressed. - def _adjust_point(point) -> Point2: - x, y = point + Args: + path: The file path to save to. + """ + path = Path(path) + data_str = json.dumps(self.to_dict()) + data_bytes = data_str.encode("utf-8") - x *= scale_x - y *= scale_y + if path.suffix == ".gz": + data_bytes = gzip.compress(data_bytes) - if rad: - rot_x = x * rad_cos - y * rad_sin - rot_y = x * rad_sin + y * rad_cos - x = rot_x - y = rot_y + with open(path, mode="wb") as fd: + fd.write(data_bytes) - return ( - x + position_x, - y + position_y, - ) + @classmethod + def load( + cls, + path: str | Path, + position: Point2 = (0.0, 0.0), + scale: Point2 = (1.0, 1.0), + angle: float = 0.0, + ) -> HitBox: + """ + Load a hitbox shape definition from a JSON file. - self._adjusted_points = [_adjust_point(point) for point in self._points] - self._adjusted_cache_dirty = False - return self._adjusted_points + If the path ends with ``.gz``, the file is assumed to be gzip-compressed. + + Args: + path: The file path to load from. + position: The center offset. + scale: The scaling factors. + angle: The rotation angle in degrees. + """ + path = Path(path) + if path.suffix == ".gz": + with gzip.open(path, mode="rb") as fd: + data = json.loads(fd.read()) + else: + with open(path) as fd: + data = json.loads(fd.read()) + + return cls.from_dict(data, position=position, scale=scale, angle=angle) diff --git a/arcade/pymunk_physics_engine.py b/arcade/pymunk_physics_engine.py index 7dde352585..1fa0282a66 100644 --- a/arcade/pymunk_physics_engine.py +++ b/arcade/pymunk_physics_engine.py @@ -22,9 +22,17 @@ class PymunkPhysicsObject: """Object that holds pymunk body/shape for a sprite.""" - def __init__(self, body: pymunk.Body | None = None, shape: pymunk.Shape | None = None): + def __init__( + self, + body: pymunk.Body | None = None, + shape: pymunk.Shape | None = None, + shapes: list[pymunk.Shape] | None = None, + ): self.body: pymunk.Body | None = body self.shape: pymunk.Shape | None = shape + self.shapes: list[pymunk.Shape] = ( + shapes if shapes is not None else ([shape] if shape is not None else []) + ) class PymunkException(Exception): @@ -318,30 +326,33 @@ def velocity_callback( if body_type == self.DYNAMIC: body.velocity_func = velocity_callback - # Set the physics shape to the sprite's hitbox - poly = sprite.hit_box.points - scaled_poly = [[x * sprite.scale_x for x in z] for z in poly] - shape = pymunk.Poly(body, scaled_poly, radius=radius) # type: ignore + # Set the physics shapes to the sprite's hitbox regions + shapes: list[pymunk.Shape] = [] + for region_points in sprite.hit_box.regions.values(): + scaled_poly = [[x * sprite.scale_x for x in z] for z in region_points] + shape = pymunk.Poly(body, scaled_poly, radius=radius) # type: ignore - # Set collision type, used in collision callbacks - if collision_type: - shape.collision_type = collision_type_id + # Set collision type, used in collision callbacks + if collision_type: + shape.collision_type = collision_type_id - # How bouncy is the shape? - if elasticity is not None: - shape.elasticity = elasticity + # How bouncy is the shape? + if elasticity is not None: + shape.elasticity = elasticity - # Set shapes friction - shape.friction = friction + # Set shapes friction + shape.friction = friction + + shapes.append(shape) # Create physics object and add to list - physics_object = PymunkPhysicsObject(body, shape) + physics_object = PymunkPhysicsObject(body, shape=shapes[0], shapes=shapes) self.sprites[sprite] = physics_object if body_type != self.STATIC: self.non_static_sprite_list.append(sprite) - # Add body and shape to pymunk engine - self.space.add(body, shape) + # Add body and shapes to pymunk engine + self.space.add(body, *shapes) # Register physics engine with sprite, so we can remove from physics engine # if we tell the sprite to go away. @@ -431,7 +442,8 @@ def remove_sprite(self, sprite: Sprite) -> None: """Remove a sprite from the physics engine.""" physics_object = self.sprites[sprite] self.space.remove(physics_object.body) # type: ignore - self.space.remove(physics_object.shape) # type: ignore + for s in physics_object.shapes: + self.space.remove(s) self.sprites.pop(sprite) if sprite in self.non_static_sprite_list: self.non_static_sprite_list.remove(sprite) @@ -454,7 +466,7 @@ def get_sprite_for_shape(self, shape: pymunk.Shape | None) -> Sprite | None: A sprite for the ``shape``; ``None`` if no sprite is known. """ for sprite in self.sprites: - if self.sprites[sprite].shape is shape: + if shape in self.sprites[sprite].shapes: return sprite return None @@ -601,23 +613,35 @@ def update_sprite(self, sprite: Sprite) -> None: sprite: The Sprite to update """ physics_object = self.sprites[sprite] - old_shape = physics_object.shape - assert old_shape is not None, """ - Tried to update the shape for a Sprite which does not currently have a shape + old_shapes = physics_object.shapes + assert old_shapes, """ + Tried to update the shape for a Sprite which does not currently have shapes """ - # Set the physics shape to the sprite's hitbox - poly = sprite.hit_box.points - scaled_poly = [[x * sprite.scale_x for x in z] for z in poly] - shape = pymunk.Poly(physics_object.body, scaled_poly, radius=old_shape.radius) # type: ignore - - shape.collision_type = old_shape.collision_type - shape.elasticity = old_shape.elasticity - shape.friction = old_shape.friction + # Preserve properties from the first old shape + old_shape = old_shapes[0] + collision_type = old_shape.collision_type + elasticity = old_shape.elasticity + friction = old_shape.friction + radius = old_shape.radius + + # Remove all old shapes + for s in old_shapes: + self.space.remove(s) + + # Create new shapes from all hitbox regions + new_shapes: list[pymunk.Shape] = [] + for region_points in sprite.hit_box.regions.values(): + scaled_poly = [[x * sprite.scale_x for x in z] for z in region_points] + shape = pymunk.Poly(physics_object.body, scaled_poly, radius=radius) # type: ignore + shape.collision_type = collision_type + shape.elasticity = elasticity + shape.friction = friction + new_shapes.append(shape) - self.space.remove(old_shape) - self.space.add(shape) - physics_object.shape = shape + self.space.add(*new_shapes) + physics_object.shape = new_shapes[0] + physics_object.shapes = new_shapes def resync_sprites(self) -> None: """ diff --git a/arcade/sprite/base.py b/arcade/sprite/base.py index 8ecec9657e..7ffaf8edc1 100644 --- a/arcade/sprite/base.py +++ b/arcade/sprite/base.py @@ -8,7 +8,7 @@ from arcade.exceptions import ReplacementWarning, warning from arcade.hitbox import HitBox from arcade.texture import Texture -from arcade.types import LRBT, AsFloat, Color, Point, Point2, Point2List, Rect, RGBOrA255 +from arcade.types import LRBT, AsFloat, Color, Point, Point2, Rect, RGBOrA255 from arcade.utils import copy_dunders_unimplemented if TYPE_CHECKING: @@ -787,10 +787,10 @@ def draw_hit_box(self, color: RGBOrA255 = BLACK, line_thickness: float = 2.0) -> How thick the box should be """ converted_color = Color.from_iterable(color) - points: Point2List = self.hit_box.get_adjusted_points() - # NOTE: This is a COPY operation. We don't want to modify the points. - points = tuple(points) + tuple(points[:-1]) - arcade.draw_line_strip(points, color=converted_color, line_width=line_thickness) + for polygon in self.hit_box.get_all_adjusted_polygons(): + # NOTE: This is a COPY operation. We don't want to modify the points. + points = tuple(polygon) + tuple(polygon[:-1]) + arcade.draw_line_strip(points, color=converted_color, line_width=line_thickness) # ---- Shortcut Methods ---- @@ -812,7 +812,10 @@ def collides_with_point(self, point: Point2) -> bool: from arcade.geometry import is_point_in_polygon x, y = point - return is_point_in_polygon(x, y, self.hit_box.get_adjusted_points()) + return any( + is_point_in_polygon(x, y, polygon) + for polygon in self.hit_box.get_all_adjusted_polygons() + ) def collides_with_sprite(self, other: BasicSprite) -> bool: """Will check if a sprite is overlapping (colliding) another Sprite. diff --git a/arcade/sprite/sprite.py b/arcade/sprite/sprite.py index 7f54dff850..a82a61879b 100644 --- a/arcade/sprite/sprite.py +++ b/arcade/sprite/sprite.py @@ -4,7 +4,7 @@ import arcade from arcade import Texture -from arcade.hitbox import HitBox, RotatableHitBox +from arcade.hitbox import HitBox from arcade.texture import get_default_texture from arcade.types import PathOrTexture, Point2 @@ -141,7 +141,12 @@ def __init__( self.guid: str | None = None """A unique id for debugging purposes.""" - self._hit_box: RotatableHitBox = self._hit_box.create_rotatable(angle=self._angle) + self._hit_box = HitBox( + self._texture.hit_box_points, + self._position, + self._scale, + angle=self._angle, + ) self._width = self._texture.width * self._scale[0] self._height = self._texture.height * self._scale[1] @@ -225,13 +230,8 @@ def hit_box(self) -> HitBox: return self._hit_box @hit_box.setter - def hit_box(self, hit_box: HitBox | RotatableHitBox) -> None: - if type(hit_box) is HitBox: - self._hit_box = hit_box.create_rotatable(self.angle) - else: - # Mypy doesn't seem to understand the type check above - # It still thinks hit_box can be a union here - self._hit_box = hit_box # type: ignore + def hit_box(self, hit_box: HitBox) -> None: + self._hit_box = hit_box @property def texture(self) -> Texture: @@ -251,11 +251,11 @@ def texture(self, texture: Texture) -> None: # If sprite is using default texture, update the hit box if self._texture is get_default_texture(): - self.hit_box = RotatableHitBox( + self.hit_box = HitBox( texture.hit_box_points, position=self._position, - angle=self.angle, scale=self._scale, + angle=self.angle, ) self._texture = texture @@ -425,9 +425,9 @@ def sync_hit_box_to_texture(self) -> None: """ Update the sprite's hit box to match the current texture's hit box. """ - self.hit_box = RotatableHitBox( + self.hit_box = HitBox( self.texture.hit_box_points, position=self._position, - angle=self.angle, scale=self._scale, + angle=self.angle, ) diff --git a/arcade/sprite_list/collision.py b/arcade/sprite_list/collision.py index 3734b51e3e..5a60f197b0 100644 --- a/arcade/sprite_list/collision.py +++ b/arcade/sprite_list/collision.py @@ -120,9 +120,9 @@ def _check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool: if distance > radius_sum_sq: return False - return are_polygons_intersecting( - sprite1.hit_box.get_adjusted_points(), sprite2.hit_box.get_adjusted_points() - ) + polys1 = sprite1.hit_box.get_all_adjusted_polygons() + polys2 = sprite2.hit_box.get_all_adjusted_polygons() + return any(are_polygons_intersecting(p1, p2) for p1 in polys1 for p2 in polys2) def _get_nearby_sprites( @@ -283,7 +283,10 @@ def get_sprites_at_point(point: Point, sprite_list: SpriteSequence[SpriteType]) return [ s for s in sprites_to_check - if is_point_in_polygon(point[0], point[1], s.hit_box.get_adjusted_points()) + if any( + is_point_in_polygon(point[0], point[1], polygon) + for polygon in s.hit_box.get_all_adjusted_polygons() + ) ] @@ -346,5 +349,8 @@ def get_sprites_in_rect(rect: Rect, sprite_list: SpriteSequence[SpriteType]) -> return [ s for s in sprites_to_check - if are_polygons_intersecting(rect_points, s.hit_box.get_adjusted_points()) + if any( + are_polygons_intersecting(rect_points, polygon) + for polygon in s.hit_box.get_all_adjusted_polygons() + ) ] diff --git a/arcade/sprite_list/sprite_list.py b/arcade/sprite_list/sprite_list.py index d24641dea6..c627e14a4e 100644 --- a/arcade/sprite_list/sprite_list.py +++ b/arcade/sprite_list/sprite_list.py @@ -989,12 +989,12 @@ def draw_hit_boxes( # TODO: Make this faster in the future # NOTE: This will be easier when/if we change to triangles for sprite in self.sprite_list: - adjusted_points = sprite.hit_box.get_adjusted_points() - for i in range(len(adjusted_points) - 1): - points.append(adjusted_points[i]) - points.append(adjusted_points[i + 1]) - points.append(adjusted_points[-1]) - points.append(adjusted_points[0]) + for adjusted_points in sprite.hit_box.get_all_adjusted_polygons(): + for i in range(len(adjusted_points) - 1): + points.append(adjusted_points[i]) + points.append(adjusted_points[i + 1]) + points.append(adjusted_points[-1]) + points.append(adjusted_points[0]) arcade.draw_lines(points, color=converted_color, line_width=line_thickness) diff --git a/arcade/tilemap/tilemap.py b/arcade/tilemap/tilemap.py index f1415ae86c..2939b25b7a 100644 --- a/arcade/tilemap/tilemap.py +++ b/arcade/tilemap/tilemap.py @@ -31,7 +31,7 @@ get_window, hexagon, ) -from arcade.hitbox import HitBoxAlgorithm, RotatableHitBox +from arcade.hitbox import HitBox, HitBoxAlgorithm from arcade.types import RGBA255 from arcade.types import Color as ArcadeColor @@ -617,11 +617,11 @@ def _create_sprite_from_tile( if tile.flipped_diagonally: points = [(point[1], point[0]) for point in points] - my_sprite.hit_box = RotatableHitBox( + my_sprite.hit_box = HitBox( cast(list[Point2], points), position=my_sprite.position, - angle=my_sprite.angle, scale=my_sprite.scale, + angle=my_sprite.angle, ) if tile.animation: diff --git a/tests/unit/gui/test_layout_size_hint_warning.py b/tests/unit/gui/test_layout_size_hint_warning.py index a0e5c29103..40020db966 100644 --- a/tests/unit/gui/test_layout_size_hint_warning.py +++ b/tests/unit/gui/test_layout_size_hint_warning.py @@ -1,4 +1,5 @@ """Tests that layouts warn when explicit width/height conflicts with active size_hint.""" + import warnings import pytest diff --git a/tests/unit/gui/test_uidropdown.py b/tests/unit/gui/test_uidropdown.py index 22af8dd79c..8741969207 100644 --- a/tests/unit/gui/test_uidropdown.py +++ b/tests/unit/gui/test_uidropdown.py @@ -164,7 +164,9 @@ def test_dropdown_value_setter_updates_button_text(ui): def test_dropdown_few_options_no_scrolling(ui): - dropdown = UIDropdown(options=["Apple", "Banana", "Cherry"], width=200, height=30, max_height=200) + dropdown = UIDropdown( + options=["Apple", "Banana", "Cherry"], width=200, height=30, max_height=200 + ) anchor = ui.add(UIAnchorLayout()) anchor.add(dropdown, anchor_x="center", anchor_y="center") ui.execute_layout() diff --git a/tests/unit/hitbox/test_hitbox.py b/tests/unit/hitbox/test_hitbox.py index ba601bd109..54c16990ed 100644 --- a/tests/unit/hitbox/test_hitbox.py +++ b/tests/unit/hitbox/test_hitbox.py @@ -1,3 +1,7 @@ +import json +import tempfile +from pathlib import Path + import pytest from arcade import hitbox @@ -19,6 +23,7 @@ def test_create(): assert hb.get_adjusted_points() == points assert hb.position == (0.0, 0.0) assert hb.scale == (1.0, 1.0) + assert hb.angle == 0.0 assert hb.bottom == 0.0 assert hb.top == 10.0 assert hb.left == 0.0 @@ -39,14 +44,170 @@ def test_position(): assert hb.get_adjusted_points() == [(10.0, 10.0), (10.0, 20.0), (20.0, 20.0), (20.0, 10.0)] -def test_create_rotatable(): - hb = hitbox.HitBox(points) - rot = hb.create_rotatable() - assert rot.angle == 0.0 - assert rot.position == (0.0, 0.0) - rot.angle = 90.0 - assert rot.angle == 90.0 +def test_rotation(): + hb = hitbox.HitBox(points, angle=0.0) + assert hb.angle == 0.0 + assert hb.position == (0.0, 0.0) + hb.angle = 90.0 + assert hb.angle == 90.0 + + rot_p = hb.get_adjusted_points() + for i, (a, b) in enumerate(zip(rot_90, rot_p)): + assert a == pytest.approx(b, abs=1e-6), f"[{i}] {a} != {b}" + - rot_p = rot.get_adjusted_points() +def test_angle_constructor(): + hb = hitbox.HitBox(points, angle=90.0) + rot_p = hb.get_adjusted_points() for i, (a, b) in enumerate(zip(rot_90, rot_p)): assert a == pytest.approx(b, abs=1e-6), f"[{i}] {a} != {b}" + + +# --- Multi-region tests --- + + +def test_multi_region_create(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) + assert hb.region_names == ("body", "head") + assert hb.has_region("body") + assert hb.has_region("head") + assert not hb.has_region("default") + assert hb.regions["body"] == body_pts + assert hb.regions["head"] == head_pts + + +def test_multi_region_adjusted(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}, position=(5.0, 5.0)) + body_adj = hb.get_adjusted_points("body") + head_adj = hb.get_adjusted_points("head") + assert body_adj == [(5.0, 5.0), (5.0, 15.0), (15.0, 15.0), (15.0, 5.0)] + assert head_adj == [(7.0, 15.0), (7.0, 20.0), (13.0, 20.0), (13.0, 15.0)] + + +def test_multi_region_boundaries(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) + # Boundaries span all regions + assert hb.left == 0.0 + assert hb.right == 10.0 + assert hb.bottom == 0.0 + assert hb.top == 15.0 + + +def test_get_all_adjusted_polygons(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) + all_polys = hb.get_all_adjusted_polygons() + assert len(all_polys) == 2 + + +def test_add_remove_region(): + hb = hitbox.HitBox(points) + assert hb.has_region("default") + assert len(hb.region_names) == 1 + + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb.add_region("head", head_pts) + assert hb.has_region("head") + assert len(hb.region_names) == 2 + assert hb.top == 15.0 + + hb.remove_region("head") + assert not hb.has_region("head") + assert len(hb.region_names) == 1 + assert hb.top == 10.0 + + +def test_default_region_points(): + hb = hitbox.HitBox(points) + assert hb.points == points + assert hb.get_adjusted_points() == hb.get_adjusted_points("default") + + +def test_single_region_fast_path(): + hb = hitbox.HitBox(points) + polys = hb.get_all_adjusted_polygons() + assert len(polys) == 1 + assert polys[0] == list(points) + + +# --- Serialization tests --- + + +def test_to_dict_single_region(): + hb = hitbox.HitBox(points) + d = hb.to_dict() + assert d["version"] == 1 + assert "default" in d["regions"] + assert len(d["regions"]) == 1 + + +def test_to_dict_multi_region(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) + d = hb.to_dict() + assert d["version"] == 1 + assert "body" in d["regions"] + assert "head" in d["regions"] + + +def test_from_dict(): + d = { + "version": 1, + "regions": { + "default": [[0.0, 0.0], [0.0, 10.0], [10.0, 10.0], [10.0, 0.0]], + }, + } + hb = hitbox.HitBox.from_dict(d) + assert hb.points == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) + assert hb.get_adjusted_points() == [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + + +def test_roundtrip_dict(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) + d = hb.to_dict() + hb2 = hitbox.HitBox.from_dict(d, position=(5.0, 5.0)) + assert hb2.has_region("body") + assert hb2.has_region("head") + assert hb2.position == (5.0, 5.0) + + +def test_save_load_json(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f: + path = Path(f.name) + + try: + hb.save(path) + hb2 = hitbox.HitBox.load(path, position=(1.0, 2.0)) + assert hb2.has_region("body") + assert hb2.has_region("head") + assert hb2.position == (1.0, 2.0) + finally: + path.unlink(missing_ok=True) + + +def test_save_load_gzip(): + hb = hitbox.HitBox(points) + + with tempfile.NamedTemporaryFile(suffix=".gz", delete=False) as f: + path = Path(f.name) + + try: + hb.save(path) + hb2 = hitbox.HitBox.load(path) + assert hb2.points == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) + finally: + path.unlink(missing_ok=True) diff --git a/tests/unit/test_hexagon.py b/tests/unit/test_hexagon.py index ca19893b68..9ebe464103 100644 --- a/tests/unit/test_hexagon.py +++ b/tests/unit/test_hexagon.py @@ -92,7 +92,9 @@ def test_hex_round(): a = HexTile(0.0, 0.0, 0.0) b = HexTile(1.0, -1.0, 0.0) c = HexTile(0.0, -1.0, 1.0) - assert HexTile(5, -10, 5) == round(HexTile(0.0, 0.0, 0.0).lerp_between(HexTile(10.0, -20.0, 10.0), 0.5)) + assert HexTile(5, -10, 5) == round( + HexTile(0.0, 0.0, 0.0).lerp_between(HexTile(10.0, -20.0, 10.0), 0.5) + ) assert round(a) == round(a.lerp_between(b, 0.499)) assert round(b) == round(a.lerp_between(b, 0.501)) diff --git a/webplayground/local_scripts/example_test.py b/webplayground/local_scripts/example_test.py index a2a70650b9..dbe1784bbb 100644 --- a/webplayground/local_scripts/example_test.py +++ b/webplayground/local_scripts/example_test.py @@ -24,7 +24,7 @@ def on_draw(self): arcade.color.WHITE, font_size=30, anchor_x="center", - anchor_y="center" + anchor_y="center", ) arcade.draw_text( "Edit this file and refresh to see changes", @@ -33,11 +33,10 @@ def on_draw(self): arcade.color.WHITE, font_size=16, anchor_x="center", - anchor_y="center" + anchor_y="center", ) if __name__ == "__main__": window = MyWindow() arcade.run() - From 0f6d7cc643d7079f76c7ea3c949f6252b5b37873 Mon Sep 17 00:00:00 2001 From: Darren Eberly Date: Thu, 12 Mar 2026 11:26:07 -0400 Subject: [PATCH 2/3] Add RawHitBox TypedDict and normalize points to tuples Add a RawHitBox TypedDict for typed serialization in to_dict/from_dict. Normalize all points to tuples of tuples on construction, add_region, and in adjusted point calculation for immutability and consistency. Update tests to expect tuple-based comparisons. Co-Authored-By: Claude Opus 4.6 --- arcade/hitbox/__init__.py | 3 +- arcade/hitbox/base.py | 47 ++++++++++++++++--------- tests/unit/hitbox/test_hitbox.py | 20 +++++------ tests/unit/sprite/test_sprite_hitbox.py | 10 +++--- 4 files changed, 47 insertions(+), 33 deletions(-) diff --git a/arcade/hitbox/__init__.py b/arcade/hitbox/__init__.py index 8bbf1c710b..d7ebfd852c 100644 --- a/arcade/hitbox/__init__.py +++ b/arcade/hitbox/__init__.py @@ -2,7 +2,7 @@ from arcade.types import Point2List -from .base import HitBox, HitBoxAlgorithm +from .base import HitBox, HitBoxAlgorithm, RawHitBox from .bounding_box import BoundingHitBoxAlgorithm from .simple import SimpleHitBoxAlgorithm @@ -59,6 +59,7 @@ def calculate_hit_box_points_detailed( __all__ = [ "HitBoxAlgorithm", "HitBox", + "RawHitBox", "SimpleHitBoxAlgorithm", "PymunkHitBoxAlgorithm", "BoundingHitBoxAlgorithm", diff --git a/arcade/hitbox/base.py b/arcade/hitbox/base.py index c8b7c46ef5..787a1c0d4e 100644 --- a/arcade/hitbox/base.py +++ b/arcade/hitbox/base.py @@ -4,14 +4,21 @@ import json from math import cos, radians, sin from pathlib import Path -from typing import Any +from typing import Any, TypedDict from PIL.Image import Image from typing_extensions import Self from arcade.types import EMPTY_POINT_LIST, Point2, Point2List -__all__ = ["HitBoxAlgorithm", "HitBox"] +__all__ = ["HitBoxAlgorithm", "HitBox", "RawHitBox"] + + +class RawHitBox(TypedDict): + """Typed dictionary representing the serialized form of a :py:class:`HitBox`.""" + + version: int + regions: dict[str, Point2List] class HitBoxAlgorithm: @@ -134,14 +141,21 @@ class HitBox: loaded = HitBox.load("hitbox.json") - # Dict round-trip + # Dict round-trip (see RawHitBox for the schema) data = box.to_dict() copy = HitBox.from_dict(data) + .. note:: + + All points are normalized to tuples of tuples on construction. + Any sequence type is accepted as input, but regions will always + store tuples internally. + Args: points: Either a single ``Point2List`` (creates a ``"default"`` region) or a ``dict[str, Point2List]`` mapping region names to point lists. + Points are normalized to tuples on storage. position: The center around which the points will be offset. scale: @@ -160,9 +174,11 @@ def __init__( angle: float = 0.0, ): if isinstance(points, dict): - self._regions: dict[str, Point2List] = dict(points) + self._regions: dict[str, Point2List] = { + name: tuple(tuple(p) for p in pts) for name, pts in points.items() + } else: - self._regions = {self.DEFAULT_REGION: points} + self._regions = {self.DEFAULT_REGION: tuple(tuple(p) for p in points)} self._position = position self._scale = scale @@ -205,7 +221,7 @@ def add_region(self, name: str, points: Point2List) -> None: name: The name for the new region. points: The polygon points for the region. """ - self._regions[name] = points + self._regions[name] = tuple(tuple(p) for p in points) self._is_single_region = len(self._regions) == 1 self._adjusted_cache_dirty = True @@ -324,7 +340,7 @@ def _adjust_point(point: Point2) -> Point2: return (x + position_x, y + position_y) self._adjusted_regions = { - name: [_adjust_point(p) for p in pts] for name, pts in self._regions.items() + name: tuple(_adjust_point(p) for p in pts) for name, pts in self._regions.items() } self._adjusted_cache_dirty = False @@ -352,39 +368,36 @@ def get_all_adjusted_polygons(self) -> list[Point2List]: # --- Serialization --- - def to_dict(self) -> dict: + def to_dict(self) -> RawHitBox: """ - Serialize the hitbox shape to a dictionary. + Serialize the hitbox shape to a :py:class:`RawHitBox` dictionary. Only the region definitions (point data) are serialized. Position, scale, and angle are runtime state and are not included. """ return { "version": 1, - "regions": {name: [list(p) for p in pts] for name, pts in self._regions.items()}, + "regions": {name: pts for name, pts in self._regions.items()}, } @classmethod def from_dict( cls, - data: dict, + data: RawHitBox, position: Point2 = (0.0, 0.0), scale: Point2 = (1.0, 1.0), angle: float = 0.0, ) -> HitBox: """ - Create a HitBox from a serialized dictionary. + Create a HitBox from a :py:class:`RawHitBox` dictionary. Args: - data: The dictionary to deserialize from. + data: A :py:class:`RawHitBox` dictionary to deserialize from. position: The center offset. scale: The scaling factors. angle: The rotation angle in degrees. """ - regions: dict[str, Point2List] = { - name: tuple(tuple(p) for p in pts) for name, pts in data["regions"].items() - } - return cls(points=regions, position=position, scale=scale, angle=angle) + return cls(points=data["regions"], position=position, scale=scale, angle=angle) def save(self, path: str | Path) -> None: """ diff --git a/tests/unit/hitbox/test_hitbox.py b/tests/unit/hitbox/test_hitbox.py index 54c16990ed..5bdabd7bf8 100644 --- a/tests/unit/hitbox/test_hitbox.py +++ b/tests/unit/hitbox/test_hitbox.py @@ -5,8 +5,8 @@ import pytest from arcade import hitbox -points = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] -rot_90 = [(0.0, 0.0), (10.0, 0), (10.0, -10.0), (0.0, -10.0)] +points = ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) +rot_90 = ((0.0, 0.0), (10.0, 0), (10.0, -10.0), (0.0, -10.0)) def test_module(): @@ -34,14 +34,14 @@ def test_scale(): hb = hitbox.HitBox(points) hb.scale = (2.0, 2.0) assert hb.scale == (2.0, 2.0) - assert hb.get_adjusted_points() == [(0.0, 0.0), (0.0, 20.0), (20.0, 20.0), (20.0, 0.0)] + assert hb.get_adjusted_points() == ((0.0, 0.0), (0.0, 20.0), (20.0, 20.0), (20.0, 0.0)) def test_position(): hb = hitbox.HitBox(points) hb.position = (10.0, 10.0) assert hb.position == (10.0, 10.0) - assert hb.get_adjusted_points() == [(10.0, 10.0), (10.0, 20.0), (20.0, 20.0), (20.0, 10.0)] + assert hb.get_adjusted_points() == ((10.0, 10.0), (10.0, 20.0), (20.0, 20.0), (20.0, 10.0)) def test_rotation(): @@ -74,8 +74,8 @@ def test_multi_region_create(): assert hb.has_region("body") assert hb.has_region("head") assert not hb.has_region("default") - assert hb.regions["body"] == body_pts - assert hb.regions["head"] == head_pts + assert hb.regions["body"] == tuple(tuple(p) for p in body_pts) + assert hb.regions["head"] == tuple(tuple(p) for p in head_pts) def test_multi_region_adjusted(): @@ -84,8 +84,8 @@ def test_multi_region_adjusted(): hb = hitbox.HitBox({"body": body_pts, "head": head_pts}, position=(5.0, 5.0)) body_adj = hb.get_adjusted_points("body") head_adj = hb.get_adjusted_points("head") - assert body_adj == [(5.0, 5.0), (5.0, 15.0), (15.0, 15.0), (15.0, 5.0)] - assert head_adj == [(7.0, 15.0), (7.0, 20.0), (13.0, 20.0), (13.0, 15.0)] + assert body_adj == ((5.0, 5.0), (5.0, 15.0), (15.0, 15.0), (15.0, 5.0)) + assert head_adj == ((7.0, 15.0), (7.0, 20.0), (13.0, 20.0), (13.0, 15.0)) def test_multi_region_boundaries(): @@ -134,7 +134,7 @@ def test_single_region_fast_path(): hb = hitbox.HitBox(points) polys = hb.get_all_adjusted_polygons() assert len(polys) == 1 - assert polys[0] == list(points) + assert polys[0] == points # --- Serialization tests --- @@ -167,7 +167,7 @@ def test_from_dict(): } hb = hitbox.HitBox.from_dict(d) assert hb.points == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) - assert hb.get_adjusted_points() == [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + assert hb.get_adjusted_points() == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) def test_roundtrip_dict(): diff --git a/tests/unit/sprite/test_sprite_hitbox.py b/tests/unit/sprite/test_sprite_hitbox.py index 13bbf1f407..aa79b4d6f0 100644 --- a/tests/unit/sprite/test_sprite_hitbox.py +++ b/tests/unit/sprite/test_sprite_hitbox.py @@ -15,27 +15,27 @@ def test_1(): print() hitbox = my_sprite.hit_box.get_adjusted_points() print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}") - assert hitbox == [(90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0)] + assert hitbox == ((90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0)) my_sprite.scale = 0.5, 0.5 hitbox = my_sprite.hit_box.get_adjusted_points() print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}") - assert hitbox == [(95.0, 95.0), (95.0, 105.0), (105.0, 105.0), (105.0, 95.0)] + assert hitbox == ((95.0, 95.0), (95.0, 105.0), (105.0, 105.0), (105.0, 95.0)) my_sprite.scale = 1.0 hitbox = my_sprite.hit_box.get_adjusted_points() print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}") - assert hitbox == [(90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0)] + assert hitbox == ((90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0)) my_sprite.scale = 2.0 hitbox = my_sprite.hit_box.get_adjusted_points() print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}") - assert hitbox == [(80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0)] + assert hitbox == ((80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0)) my_sprite.scale = 2.0 hitbox = my_sprite.hit_box.get_adjusted_points() print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}") - assert hitbox == [(80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0)] + assert hitbox == ((80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0)) def test_2(): From 28f322a227904ca177e140e714df511149993927 Mon Sep 17 00:00:00 2001 From: Darren Eberly Date: Fri, 13 Mar 2026 00:56:00 -0400 Subject: [PATCH 3/3] Add per-region collision channel/mask filtering to HitBox Each hitbox region can now have a channel (what it is) and mask (what it collides with) bitmask. Collision only occurs when both sides agree: (A.channel & B.mask) != 0 AND (B.channel & A.mask) != 0. Adds channels() utility for converting 1-based channel numbers to bitmasks, RawHitBox v2 serialization format, and updates the multi-hitbox example to demonstrate channel-based filtering. Co-Authored-By: Claude Opus 4.6 --- arcade/examples/sprite_multi_hitbox.py | 142 ++++++++++++------ arcade/hitbox/__init__.py | 14 +- arcade/hitbox/base.py | 199 ++++++++++++++++++++++-- arcade/sprite_list/collision.py | 29 +++- tests/unit/hitbox/test_hitbox.py | 200 ++++++++++++++++++++++++- 5 files changed, 515 insertions(+), 69 deletions(-) diff --git a/arcade/examples/sprite_multi_hitbox.py b/arcade/examples/sprite_multi_hitbox.py index 47fa934d5c..c9714a4e9b 100644 --- a/arcade/examples/sprite_multi_hitbox.py +++ b/arcade/examples/sprite_multi_hitbox.py @@ -1,9 +1,13 @@ """ -Sprite Multi-Region Hit Boxes +Sprite Multi-Region Hit Boxes with Collision Channels -Demonstrates sprites with multiple hit box regions. The player sprite -has a "body" region and a "shield" region extending to one side. -Coins that touch any region of the player are collected. +Demonstrates sprites with multiple hit box regions and per-region +collision channels. The player sprite has a "body" region on channel 1 +and a "shield" region on channel 2. + +Gold coins are on channel 1 and can only be collected by the body. +Blue gems are on channel 2 and can only be collected by the shield. +Coins pass through the shield, and gems pass through the body. Hit box outlines are drawn for visual debugging. Use A/D to rotate the player and see both regions rotate together. @@ -16,15 +20,16 @@ import random import math import arcade -from arcade.hitbox import HitBox +from arcade.hitbox import HitBox, channels WINDOW_WIDTH = 1280 WINDOW_HEIGHT = 720 WINDOW_TITLE = "Multi-Region Hit Box Example" PLAYER_SPEED = 5.0 -COIN_SPEED = 2.0 -COIN_COUNT = 20 +ITEM_SPEED = 2.0 +COIN_COUNT = 15 +GEM_COUNT = 10 class GameView(arcade.View): @@ -34,7 +39,9 @@ def __init__(self): self.player_sprite = None self.player_list = None self.coin_list = None - self.score = 0 + self.gem_list = None + self.coin_score = 0 + self.gem_score = 0 self.score_display = None self.background_color = arcade.csscolor.DARK_SLATE_GRAY @@ -46,9 +53,11 @@ def __init__(self): def setup(self): self.player_list = arcade.SpriteList() self.coin_list = arcade.SpriteList() - self.score = 0 + self.gem_list = arcade.SpriteList() + self.coin_score = 0 + self.gem_score = 0 self.score_display = arcade.Text( - text="Score: 0", + text="Coins: 0 | Gems: 0", x=10, y=WINDOW_HEIGHT - 30, color=arcade.color.WHITE, font_size=16, ) @@ -58,13 +67,17 @@ def setup(self): self.player_sprite = arcade.Sprite(img, scale=0.5) self.player_sprite.position = WINDOW_WIDTH / 2, WINDOW_HEIGHT / 2 - # Replace the default hitbox with a multi-region hitbox. - # "body" is a box around the torso, "shield" extends to the right. - # Collision detection automatically checks all regions. + # Multi-region hitbox with collision channels: + # "body" on channel 1 — collides with coins (also on channel 1) + # "shield" on channel 2 — collides with gems (also on channel 2) self.player_sprite.hit_box = HitBox( { "body": [(-15, -48), (-15, 40), (15, 40), (15, -48)], - "shield": [(15, -30), (15, 30), (45, 30), (45, -30)], + "shield": { + "points": [(35, -30), (35, 30), (65, 30), (65, -30)], + "channel": channels(2), + "mask": channels(2), + }, }, position=self.player_sprite.position, scale=self.player_sprite.scale, @@ -72,43 +85,61 @@ def setup(self): ) self.player_list.append(self.player_sprite) - self._spawn_coins(COIN_COUNT) + self._spawn_items(self.coin_list, COIN_COUNT, is_gem=False) + self._spawn_items(self.gem_list, GEM_COUNT, is_gem=True) - def _spawn_coins(self, count): + def _spawn_items(self, sprite_list, count, is_gem): for _ in range(count): - coin = arcade.Sprite(":resources:images/items/coinGold.png", scale=0.4) + if is_gem: + item = arcade.Sprite( + ":resources:images/items/gemBlue.png", scale=0.4, + ) + # Gems are on channel 2 — only the shield can catch them + item.hit_box = HitBox( + item.hit_box.points, + position=item.position, + scale=item.scale, + channel=channels(2), + mask=channels(2), + ) + else: + item = arcade.Sprite( + ":resources:images/items/coinGold.png", scale=0.4, + ) + # Coins keep the default channel 1 — only the body can catch them # Spawn along a random edge side = random.randint(0, 3) if side == 0: - coin.center_x = random.randrange(WINDOW_WIDTH) - coin.center_y = WINDOW_HEIGHT + 20 + item.center_x = random.randrange(WINDOW_WIDTH) + item.center_y = WINDOW_HEIGHT + 20 elif side == 1: - coin.center_x = random.randrange(WINDOW_WIDTH) - coin.center_y = -20 + item.center_x = random.randrange(WINDOW_WIDTH) + item.center_y = -20 elif side == 2: - coin.center_x = -20 - coin.center_y = random.randrange(WINDOW_HEIGHT) + item.center_x = -20 + item.center_y = random.randrange(WINDOW_HEIGHT) else: - coin.center_x = WINDOW_WIDTH + 20 - coin.center_y = random.randrange(WINDOW_HEIGHT) + item.center_x = WINDOW_WIDTH + 20 + item.center_y = random.randrange(WINDOW_HEIGHT) # Aim toward the center with some randomness target_x = WINDOW_WIDTH / 2 + random.randint(-200, 200) target_y = WINDOW_HEIGHT / 2 + random.randint(-200, 200) - dx = target_x - coin.center_x - dy = target_y - coin.center_y + dx = target_x - item.center_x + dy = target_y - item.center_y dist = math.hypot(dx, dy) if dist > 0: - coin.change_x = (dx / dist) * COIN_SPEED - coin.change_y = (dy / dist) * COIN_SPEED + item.change_x = (dx / dist) * ITEM_SPEED + item.change_y = (dy / dist) * ITEM_SPEED - self.coin_list.append(coin) + sprite_list.append(item) def on_draw(self): self.clear() self.coin_list.draw() + self.gem_list.draw() self.player_list.draw() # Debug: draw each hitbox region in a different color @@ -119,10 +150,12 @@ def on_draw(self): arcade.draw_line_strip(tuple(pts) + (pts[0],), color=color, line_width=2) self.coin_list.draw_hit_boxes(color=arcade.color.YELLOW, line_thickness=1) + self.gem_list.draw_hit_boxes(color=arcade.color.BLUE, line_thickness=1) self.score_display.draw() arcade.draw_text( - "Red = Body | Cyan = Shield | Arrow keys to move | A/D to rotate", + "Red body (ch1) = coins | Cyan shield (ch2) = gems | " + "Arrows to move | A/D to rotate", WINDOW_WIDTH / 2, 20, arcade.color.WHITE, font_size=12, anchor_x="center", ) @@ -165,27 +198,46 @@ def on_update(self, delta_time): if keys[arcade.key.D]: self.player_sprite.angle += 3.0 - # Move coins + # Move items self.coin_list.update() + self.gem_list.update() - # Standard collision check — automatically tests all hitbox regions - hit_list = arcade.check_for_collision_with_list( + # Coins only collide with the body (channel 1) + coin_hits = arcade.check_for_collision_with_list( self.player_sprite, self.coin_list ) - if hit_list: - for coin in hit_list: - coin.remove_from_sprite_lists() - self.score += 1 - self.score_display.text = f"Score: {self.score}" + for coin in coin_hits: + coin.remove_from_sprite_lists() + self.coin_score += 1 + + # Gems only collide with the shield (channel 2) + gem_hits = arcade.check_for_collision_with_list( + self.player_sprite, self.gem_list + ) + for gem in gem_hits: + gem.remove_from_sprite_lists() + self.gem_score += 1 + + if coin_hits or gem_hits: + self.score_display.text = ( + f"Coins: {self.coin_score} | Gems: {self.gem_score}" + ) - # Replace coins that left the screen + # Replace items that left the screen margin = 100 - for coin in list(self.coin_list): - if (coin.center_x < -margin or coin.center_x > WINDOW_WIDTH + margin - or coin.center_y < -margin or coin.center_y > WINDOW_HEIGHT + margin): - coin.remove_from_sprite_lists() + for item in list(self.coin_list): + if (item.center_x < -margin or item.center_x > WINDOW_WIDTH + margin + or item.center_y < -margin or item.center_y > WINDOW_HEIGHT + margin): + item.remove_from_sprite_lists() + for item in list(self.gem_list): + if (item.center_x < -margin or item.center_x > WINDOW_WIDTH + margin + or item.center_y < -margin or item.center_y > WINDOW_HEIGHT + margin): + item.remove_from_sprite_lists() + if len(self.coin_list) < COIN_COUNT: - self._spawn_coins(COIN_COUNT - len(self.coin_list)) + self._spawn_items(self.coin_list, COIN_COUNT - len(self.coin_list), is_gem=False) + if len(self.gem_list) < GEM_COUNT: + self._spawn_items(self.gem_list, GEM_COUNT - len(self.gem_list), is_gem=True) def main(): diff --git a/arcade/hitbox/__init__.py b/arcade/hitbox/__init__.py index d7ebfd852c..1a01e3a734 100644 --- a/arcade/hitbox/__init__.py +++ b/arcade/hitbox/__init__.py @@ -2,7 +2,15 @@ from arcade.types import Point2List -from .base import HitBox, HitBoxAlgorithm, RawHitBox +from .base import ( + DEFAULT_CHANNEL, + DEFAULT_MASK, + HitBox, + HitBoxAlgorithm, + RawHitBox, + RawHitBoxRegion, + channels, +) from .bounding_box import BoundingHitBoxAlgorithm from .simple import SimpleHitBoxAlgorithm @@ -60,6 +68,10 @@ def calculate_hit_box_points_detailed( "HitBoxAlgorithm", "HitBox", "RawHitBox", + "RawHitBoxRegion", + "DEFAULT_CHANNEL", + "DEFAULT_MASK", + "channels", "SimpleHitBoxAlgorithm", "PymunkHitBoxAlgorithm", "BoundingHitBoxAlgorithm", diff --git a/arcade/hitbox/base.py b/arcade/hitbox/base.py index 787a1c0d4e..2dd0893ba0 100644 --- a/arcade/hitbox/base.py +++ b/arcade/hitbox/base.py @@ -7,18 +7,59 @@ from typing import Any, TypedDict from PIL.Image import Image -from typing_extensions import Self +from typing_extensions import NotRequired, Self from arcade.types import EMPTY_POINT_LIST, Point2, Point2List -__all__ = ["HitBoxAlgorithm", "HitBox", "RawHitBox"] +__all__ = [ + "HitBoxAlgorithm", + "HitBox", + "RawHitBox", + "RawHitBoxRegion", + "DEFAULT_CHANNEL", + "DEFAULT_MASK", + "channels", +] + +#: Default channel bitmask: layer 1 only. +DEFAULT_CHANNEL: int = 1 +#: Default mask bitmask: layer 1 only. +DEFAULT_MASK: int = 1 + + +def channels(*nums: int) -> int: + """Convert 1-based channel numbers to a bitmask. + + Example:: + + channels(1) # 0b01 == 1 + channels(1, 3) # 0b101 == 5 + channels(1, 3, 5) # 0b10101 == 21 + + Args: + nums: One or more 1-based channel numbers. + """ + mask = 0 + for n in nums: + if n < 1: + raise ValueError(f"Channel numbers must be >= 1, got {n}") + mask |= 1 << (n - 1) + return mask + + +class RawHitBoxRegion(TypedDict): + """Typed dictionary for a single region in a serialized :py:class:`HitBox`.""" + + points: Point2List + channel: NotRequired[int] + mask: NotRequired[int] class RawHitBox(TypedDict): """Typed dictionary representing the serialized form of a :py:class:`HitBox`.""" version: int - regions: dict[str, Point2List] + regions: dict[str, RawHitBoxRegion] class HitBoxAlgorithm: @@ -122,6 +163,19 @@ class HitBox: "head": [(-5, 10), (5, 10), (5, 20), (-5, 20)], }) + **Per-region collision channels**:: + + from arcade.hitbox import channels + + box = HitBox({ + "body": body_pts, + "shield": { + "points": shield_pts, + "channel": channels(2), + "mask": channels(2), + }, + }) + **Rotation** (replaces the former ``RotatableHitBox`` class):: box = HitBox(points, angle=45.0) @@ -154,31 +208,52 @@ class HitBox: Args: points: Either a single ``Point2List`` (creates a ``"default"`` region) - or a ``dict[str, Point2List]`` mapping region names to point lists. - Points are normalized to tuples on storage. + or a ``dict`` mapping region names to point lists or + ``RawHitBoxRegion`` dicts with ``"points"``, ``"channel"``, + and ``"mask"`` keys. position: The center around which the points will be offset. scale: The X and Y scaling factors. angle: The rotation angle in degrees (clockwise). + channel: + Default channel bitmask for all regions that don't specify + their own. Defaults to :py:data:`DEFAULT_CHANNEL`. + mask: + Default mask bitmask for all regions that don't specify + their own. Defaults to :py:data:`DEFAULT_MASK`. """ DEFAULT_REGION = "default" def __init__( self, - points: Point2List | dict[str, Point2List], + points: Point2List | dict[str, Point2List | RawHitBoxRegion], position: Point2 = (0.0, 0.0), scale: Point2 = (1.0, 1.0), angle: float = 0.0, + channel: int = DEFAULT_CHANNEL, + mask: int = DEFAULT_MASK, ): + self._channels: dict[str, int] = {} + self._masks: dict[str, int] = {} + if isinstance(points, dict): - self._regions: dict[str, Point2List] = { - name: tuple(tuple(p) for p in pts) for name, pts in points.items() - } + self._regions: dict[str, Point2List] = {} + for name, value in points.items(): + if isinstance(value, dict): + self._regions[name] = tuple(tuple(p) for p in value["points"]) + self._channels[name] = value.get("channel", channel) + self._masks[name] = value.get("mask", mask) + else: + self._regions[name] = tuple(tuple(p) for p in value) + self._channels[name] = channel + self._masks[name] = mask else: self._regions = {self.DEFAULT_REGION: tuple(tuple(p) for p in points)} + self._channels[self.DEFAULT_REGION] = channel + self._masks[self.DEFAULT_REGION] = mask self._position = position self._scale = scale @@ -213,15 +288,25 @@ def has_region(self, name: str) -> bool: """Check if a region with the given name exists.""" return name in self._regions - def add_region(self, name: str, points: Point2List) -> None: + def add_region( + self, + name: str, + points: Point2List, + channel: int = DEFAULT_CHANNEL, + mask: int = DEFAULT_MASK, + ) -> None: """ Add a named region to this hit box. Args: name: The name for the new region. points: The polygon points for the region. + channel: Channel bitmask for this region. + mask: Mask bitmask for this region. """ self._regions[name] = tuple(tuple(p) for p in points) + self._channels[name] = channel + self._masks[name] = mask self._is_single_region = len(self._regions) == 1 self._adjusted_cache_dirty = True @@ -233,9 +318,45 @@ def remove_region(self, name: str) -> None: name: The name of the region to remove. """ del self._regions[name] + self._channels.pop(name, None) + self._masks.pop(name, None) self._is_single_region = len(self._regions) == 1 self._adjusted_cache_dirty = True + def get_channel(self, name: str) -> int: + """Return the channel bitmask for a region. + + Args: + name: The region name. + """ + return self._channels.get(name, DEFAULT_CHANNEL) + + def set_channel(self, name: str, channel: int) -> None: + """Set the channel bitmask for a region. + + Args: + name: The region name. + channel: The channel bitmask. + """ + self._channels[name] = channel + + def get_mask(self, name: str) -> int: + """Return the mask bitmask for a region. + + Args: + name: The region name. + """ + return self._masks.get(name, DEFAULT_MASK) + + def set_mask(self, name: str, mask: int) -> None: + """Set the mask bitmask for a region. + + Args: + name: The region name. + mask: The mask bitmask. + """ + self._masks[name] = mask + @property def position(self) -> Point2: """The center point used to offset the final adjusted positions.""" @@ -361,23 +482,54 @@ def get_all_adjusted_polygons(self) -> list[Point2List]: """ Return adjusted points for all regions as a list of polygons. - This is used by collision detection to check all regions. + This returns only the geometry without channel/mask metadata. + It is used by point and rectangle query functions. For + channel-filtered collision detection, use + :py:meth:`get_collision_regions` instead. """ self._recalculate_if_dirty() return list(self._adjusted_regions.values()) + def get_collision_regions(self) -> list[tuple[Point2List, int, int]]: + """ + Return ``(adjusted_points, channel, mask)`` for each region. + + Used by collision detection to perform channel-filtered checks. + """ + self._recalculate_if_dirty() + return [ + ( + self._adjusted_regions[name], + self._channels.get(name, DEFAULT_CHANNEL), + self._masks.get(name, DEFAULT_MASK), + ) + for name in self._regions + ] + # --- Serialization --- def to_dict(self) -> RawHitBox: """ Serialize the hitbox shape to a :py:class:`RawHitBox` dictionary. - Only the region definitions (point data) are serialized. - Position, scale, and angle are runtime state and are not included. + Only the region definitions (point data, channel, mask) are + serialized. Position, scale, and angle are runtime state and + are not included. Channel and mask are omitted when they equal + the defaults. """ + regions: dict[str, RawHitBoxRegion] = {} + for name, pts in self._regions.items(): + region: RawHitBoxRegion = {"points": pts} + ch = self._channels.get(name, DEFAULT_CHANNEL) + mk = self._masks.get(name, DEFAULT_MASK) + if ch != DEFAULT_CHANNEL: + region["channel"] = ch + if mk != DEFAULT_MASK: + region["mask"] = mk + regions[name] = region return { - "version": 1, - "regions": {name: pts for name, pts in self._regions.items()}, + "version": 2, + "regions": regions, } @classmethod @@ -391,13 +543,28 @@ def from_dict( """ Create a HitBox from a :py:class:`RawHitBox` dictionary. + Handles both v1 (region values are point lists) and v2 (region + values are :py:class:`RawHitBoxRegion` dicts). + Args: data: A :py:class:`RawHitBox` dictionary to deserialize from. position: The center offset. scale: The scaling factors. angle: The rotation angle in degrees. """ - return cls(points=data["regions"], position=position, scale=scale, angle=angle) + version = data.get("version", 1) + raw_regions = data["regions"] + + if version <= 1: + # v1: region values are plain Point2Lists + points: dict[str, Point2List | RawHitBoxRegion] = { + name: pts for name, pts in raw_regions.items() + } + else: + # v2: region values are RawHitBoxRegion dicts + points = dict(raw_regions) + + return cls(points=points, position=position, scale=scale, angle=angle) def save(self, path: str | Path) -> None: """ diff --git a/arcade/sprite_list/collision.py b/arcade/sprite_list/collision.py index 5a60f197b0..def5431dbf 100644 --- a/arcade/sprite_list/collision.py +++ b/arcade/sprite_list/collision.py @@ -57,12 +57,18 @@ def check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool: """ Check for a collision between two sprites. + Two sprites collide when their hitbox regions overlap geometrically + **and** their collision channels agree. Specifically, for each pair + of regions, both ``(A.channel & B.mask) != 0`` and + ``(B.channel & A.mask) != 0`` must hold. See :py:class:`~arcade.hitbox.HitBox` + for details on configuring channels and masks. + Args: sprite1: First sprite sprite2: Second sprite Returns: - ``True`` or ``False`` depending if the sprites intersect. + ``True`` if the sprites collide, ``False`` otherwise. """ if __debug__: if not isinstance(sprite1, BasicSprite): @@ -83,11 +89,14 @@ def _check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool: """ Check for collision between two sprites. + Performs a fast distance pre-check, then tests each pair of hitbox + regions for both channel/mask compatibility and polygon intersection. + Args: sprite1: Sprite 1 sprite2: Sprite 2 Returns: - ``True`` if sprites overlap. + ``True`` if sprites collide (overlap and channels match). """ # NOTE: for speed because attribute look ups are slow. @@ -120,9 +129,13 @@ def _check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool: if distance > radius_sum_sq: return False - polys1 = sprite1.hit_box.get_all_adjusted_polygons() - polys2 = sprite2.hit_box.get_all_adjusted_polygons() - return any(are_polygons_intersecting(p1, p2) for p1 in polys1 for p2 in polys2) + regions1 = sprite1.hit_box.get_collision_regions() + regions2 = sprite2.hit_box.get_collision_regions() + return any( + (ch1 & mask2) and (ch2 & mask1) and are_polygons_intersecting(p1, p2) + for p1, ch1, mask1 in regions1 + for p2, ch2, mask2 in regions2 + ) def _get_nearby_sprites( @@ -142,6 +155,9 @@ def check_for_collision_with_list( """ Check for a collision between a sprite, and a list of sprites. + Collision channel/mask filtering is applied per-region. + See :py:func:`check_for_collision` for details. + Args: sprite: Sprite to check @@ -206,6 +222,9 @@ def check_for_collision_with_lists( """ Check for a collision between a Sprite, and a list of SpriteLists. + Collision channel/mask filtering is applied per-region. + See :py:func:`check_for_collision` for details. + Args: sprite: Sprite to check diff --git a/tests/unit/hitbox/test_hitbox.py b/tests/unit/hitbox/test_hitbox.py index 5bdabd7bf8..50ef458f7a 100644 --- a/tests/unit/hitbox/test_hitbox.py +++ b/tests/unit/hitbox/test_hitbox.py @@ -143,9 +143,12 @@ def test_single_region_fast_path(): def test_to_dict_single_region(): hb = hitbox.HitBox(points) d = hb.to_dict() - assert d["version"] == 1 + assert d["version"] == 2 assert "default" in d["regions"] assert len(d["regions"]) == 1 + # Default channel/mask should be omitted + assert "channel" not in d["regions"]["default"] + assert "mask" not in d["regions"]["default"] def test_to_dict_multi_region(): @@ -153,7 +156,7 @@ def test_to_dict_multi_region(): head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] hb = hitbox.HitBox({"body": body_pts, "head": head_pts}) d = hb.to_dict() - assert d["version"] == 1 + assert d["version"] == 2 assert "body" in d["regions"] assert "head" in d["regions"] @@ -211,3 +214,196 @@ def test_save_load_gzip(): assert hb2.points == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) finally: path.unlink(missing_ok=True) + + +# --- Channel/mask tests --- + + +def test_channels_utility(): + assert hitbox.channels(1) == 1 + assert hitbox.channels(2) == 2 + assert hitbox.channels(3) == 4 + assert hitbox.channels(1, 2) == 3 + assert hitbox.channels(1, 3) == 5 + assert hitbox.channels(1, 3, 5) == 21 # 0b10101 + + +def test_channels_invalid(): + with pytest.raises(ValueError): + hitbox.channels(0) + with pytest.raises(ValueError): + hitbox.channels(-1) + + +def test_default_channel_mask(): + hb = hitbox.HitBox(points) + assert hb.get_channel("default") == hitbox.DEFAULT_CHANNEL + assert hb.get_mask("default") == hitbox.DEFAULT_MASK + + +def test_constructor_channel_mask_single(): + ch = hitbox.channels(2) + mk = hitbox.channels(1, 2) + hb = hitbox.HitBox(points, channel=ch, mask=mk) + assert hb.get_channel("default") == ch + assert hb.get_mask("default") == mk + + +def test_constructor_channel_mask_multi(): + ch = hitbox.channels(3) + mk = hitbox.channels(1, 3) + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb = hitbox.HitBox({"body": body_pts, "head": head_pts}, channel=ch, mask=mk) + assert hb.get_channel("body") == ch + assert hb.get_channel("head") == ch + assert hb.get_mask("body") == mk + assert hb.get_mask("head") == mk + + +def test_mixed_format_dict(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + shield_pts = [(15.0, -30.0), (15.0, 30.0), (45.0, 30.0), (45.0, -30.0)] + ch2 = hitbox.channels(2) + hb = hitbox.HitBox({ + "body": body_pts, + "shield": {"points": shield_pts, "channel": ch2, "mask": ch2}, + }) + # body gets defaults + assert hb.get_channel("body") == hitbox.DEFAULT_CHANNEL + assert hb.get_mask("body") == hitbox.DEFAULT_MASK + # shield gets explicit values + assert hb.get_channel("shield") == ch2 + assert hb.get_mask("shield") == ch2 + # points are stored correctly + assert hb.regions["body"] == tuple(tuple(p) for p in body_pts) + assert hb.regions["shield"] == tuple(tuple(p) for p in shield_pts) + + +def test_mixed_format_inherits_constructor_defaults(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + shield_pts = [(15.0, -30.0), (15.0, 30.0), (45.0, 30.0), (45.0, -30.0)] + ch3 = hitbox.channels(3) + hb = hitbox.HitBox( + {"body": body_pts, "shield": {"points": shield_pts, "channel": ch3}}, + channel=hitbox.channels(1, 2), + mask=hitbox.channels(1, 2, 3), + ) + # body inherits constructor defaults + assert hb.get_channel("body") == hitbox.channels(1, 2) + assert hb.get_mask("body") == hitbox.channels(1, 2, 3) + # shield overrides channel but inherits mask from constructor default + assert hb.get_channel("shield") == ch3 + assert hb.get_mask("shield") == hitbox.channels(1, 2, 3) + + +def test_set_get_channel_mask(): + hb = hitbox.HitBox(points) + hb.set_channel("default", hitbox.channels(2)) + hb.set_mask("default", hitbox.channels(1, 2)) + assert hb.get_channel("default") == hitbox.channels(2) + assert hb.get_mask("default") == hitbox.channels(1, 2) + + +def test_add_region_with_channel_mask(): + hb = hitbox.HitBox(points) + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + ch = hitbox.channels(3) + mk = hitbox.channels(1, 3) + hb.add_region("head", head_pts, channel=ch, mask=mk) + assert hb.get_channel("head") == ch + assert hb.get_mask("head") == mk + + +def test_remove_region_cleans_channel_mask(): + hb = hitbox.HitBox(points) + head_pts = [(2.0, 10.0), (2.0, 15.0), (8.0, 15.0), (8.0, 10.0)] + hb.add_region("head", head_pts, channel=hitbox.channels(2)) + hb.remove_region("head") + # After removal, get_channel returns the default + assert hb.get_channel("head") == hitbox.DEFAULT_CHANNEL + + +def test_get_collision_regions(): + body_pts = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)] + shield_pts = [(15.0, -5.0), (15.0, 5.0), (25.0, 5.0), (25.0, -5.0)] + ch2 = hitbox.channels(2) + hb = hitbox.HitBox({ + "body": body_pts, + "shield": {"points": shield_pts, "channel": ch2, "mask": ch2}, + }) + regions = hb.get_collision_regions() + assert len(regions) == 2 + _, ch_body, mk_body = regions[0] + _, ch_shield, mk_shield = regions[1] + assert ch_body == hitbox.DEFAULT_CHANNEL + assert mk_body == hitbox.DEFAULT_MASK + assert ch_shield == ch2 + assert mk_shield == ch2 + + +def test_to_dict_with_channel_mask(): + ch2 = hitbox.channels(2) + mk = hitbox.channels(1, 2) + hb = hitbox.HitBox({ + "body": points, + "shield": {"points": points, "channel": ch2, "mask": mk}, + }) + d = hb.to_dict() + assert d["version"] == 2 + # body has defaults, so channel/mask omitted + assert "channel" not in d["regions"]["body"] + assert "mask" not in d["regions"]["body"] + # shield has non-default values + assert d["regions"]["shield"]["channel"] == ch2 + assert d["regions"]["shield"]["mask"] == mk + + +def test_from_dict_v1(): + """v1 format (point lists as region values) still loads correctly.""" + d = { + "version": 1, + "regions": { + "default": [[0.0, 0.0], [0.0, 10.0], [10.0, 10.0], [10.0, 0.0]], + }, + } + hb = hitbox.HitBox.from_dict(d) + assert hb.points == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)) + assert hb.get_channel("default") == hitbox.DEFAULT_CHANNEL + assert hb.get_mask("default") == hitbox.DEFAULT_MASK + + +def test_roundtrip_dict_with_channel_mask(): + ch2 = hitbox.channels(2) + mk = hitbox.channels(1, 2) + hb = hitbox.HitBox({ + "body": points, + "shield": {"points": points, "channel": ch2, "mask": mk}, + }) + d = hb.to_dict() + hb2 = hitbox.HitBox.from_dict(d) + assert hb2.get_channel("body") == hitbox.DEFAULT_CHANNEL + assert hb2.get_mask("body") == hitbox.DEFAULT_MASK + assert hb2.get_channel("shield") == ch2 + assert hb2.get_mask("shield") == mk + + +def test_save_load_with_channel_mask(): + ch2 = hitbox.channels(2) + mk = hitbox.channels(1, 2) + hb = hitbox.HitBox({ + "body": points, + "shield": {"points": points, "channel": ch2, "mask": mk}, + }) + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f: + path = Path(f.name) + + try: + hb.save(path) + hb2 = hitbox.HitBox.load(path) + assert hb2.get_channel("shield") == ch2 + assert hb2.get_mask("shield") == mk + assert hb2.get_channel("body") == hitbox.DEFAULT_CHANNEL + finally: + path.unlink(missing_ok=True)