Skip to content

Commit 26157a0

Browse files
committed
Updated example notebook
1 parent 6816cde commit 26157a0

2 files changed

Lines changed: 218 additions & 327 deletions

File tree

datamint/utils/visualization.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import colorsys
88
from collections.abc import Sequence
99
from matplotlib.axes import Axes
10+
import logging
11+
12+
_LOGGER = logging.getLogger(__name__)
1013

1114

1215
def show(imgs: Sequence[Tensor | np.ndarray] | Tensor | np.ndarray,
@@ -94,7 +97,7 @@ def draw_masks(
9497
image: Tensor | np.ndarray,
9598
masks: Tensor | np.ndarray,
9699
alpha: float = 0.5,
97-
colors: list[str | tuple[int, int, int]] | str | tuple[int, int, int] | None = None,
100+
colors: Sequence[str | tuple[int, int, int]] | str | tuple[int, int, int] | None = None,
98101
) -> Tensor:
99102
"""
100103
Draws segmentation masks on given RGB image.
@@ -121,6 +124,10 @@ def draw_masks(
121124
if isinstance(masks, np.ndarray):
122125
masks = torch.from_numpy(masks)
123126

127+
if masks.ndim == 4:
128+
_LOGGER.warning(f"In draw_masks: Expected masks to have shape (num_masks, H, W) or (H, W), but got {masks.shape}."
129+
" It might produce unexpected results. Please check the shape of the masks.")
130+
124131
if image.ndim == 3 and image.shape[0] == 1:
125132
# convert to RGB
126133
image = image.expand(3, -1, -1)

0 commit comments

Comments
 (0)