77import colorsys
88from collections .abc import Sequence
99from matplotlib .axes import Axes
10+ import logging
11+
12+ _LOGGER = logging .getLogger (__name__ )
1013
1114
1215def show (imgs : Sequence [Tensor | np .ndarray ] | Tensor | np .ndarray ,
@@ -94,7 +97,7 @@ def draw_masks(
9497 image : Tensor | np .ndarray ,
9598 masks : Tensor | np .ndarray ,
9699 alpha : float = 0.5 ,
97- colors : list [str | tuple [int , int , int ]] | str | tuple [int , int , int ] | None = None ,
100+ colors : Sequence [str | tuple [int , int , int ]] | str | tuple [int , int , int ] | None = None ,
98101) -> Tensor :
99102 """
100103 Draws segmentation masks on given RGB image.
@@ -121,6 +124,10 @@ def draw_masks(
121124 if isinstance (masks , np .ndarray ):
122125 masks = torch .from_numpy (masks )
123126
127+ if masks .ndim == 4 :
128+ _LOGGER .warning (f"In draw_masks: Expected masks to have shape (num_masks, H, W) or (H, W), but got { masks .shape } ."
129+ " It might produce unexpected results. Please check the shape of the masks." )
130+
124131 if image .ndim == 3 and image .shape [0 ] == 1 :
125132 # convert to RGB
126133 image = image .expand (3 , - 1 , - 1 )
0 commit comments