Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ect"
version = "1.2.0"
version = "1.2.1"
authors = [
{ name="Liz Munch", email="muncheli@msu.edu" },
]
Expand Down
165 changes: 163 additions & 2 deletions src/ect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@ def __array_finalize__(self, obj):
self.directions = getattr(obj, "directions", None)
self.thresholds = getattr(obj, "thresholds", None)

def plot(self, ax=None):
"""Plot ECT matrix with proper handling for both 2D and 3D"""
def plot(self, ax=None, *, radial=False, **kwargs):
"""Plot ECT matrix with proper handling for both 2D and 3D.

Set radial=True to render a polar visualization (2D only). Any extra
keyword arguments are forwarded to the radial renderer.
"""
if radial:
return self._plot_radial(ax=ax, **kwargs)

ax = ax or plt.gca()

if self.thresholds is None:
Expand Down Expand Up @@ -107,6 +114,46 @@ def smooth(self):
# create new ECTResult with float type
return ECTResult(sect.astype(np.float64), self.directions, self.thresholds)

# Internal plotting utilities
def _ensure_2d(self):
if self.directions is None or self.directions.dim != 2:
raise ValueError("This visualization is only supported for 2D ECT results")

def _theta_threshold_mesh(self):
thetas = self.directions.thetas
thresholds = self.thresholds
THETA, R = np.meshgrid(thetas, thresholds)
return THETA, R

def _configure_polar_axes(
self, ax, rmin=0.0, rmax=None, theta_zero="N", theta_dir=-1
):
ax.set_theta_zero_location(theta_zero)
ax.set_theta_direction(theta_dir)
if rmax is None:
rmax = float(np.max(self.thresholds))
ax.set_ylim(float(rmin), float(rmax))
return ax

def _scale_overlay_radii(self, points, rmin=0.0, rmax=None, fit_to_thresholds=True):
x = points[:, 0]
y = points[:, 1]
r = np.sqrt(x**2 + y**2)
theta = np.arctan2(y, x)

if rmax is None:
rmax = float(np.max(self.thresholds))

if not fit_to_thresholds:
return theta, r

max_r_points = float(np.max(r)) if r.size else 0.0
if max_r_points > 0.0:
scaled_r = (r / max_r_points) * (rmax - float(rmin)) + float(rmin)
else:
scaled_r = r
return theta, scaled_r

def _plot_ecc(self, theta):
"""Plot the Euler Characteristic Curve for a specific direction"""
plt.step(self.thresholds, self.T, label="ECC")
Expand All @@ -115,6 +162,120 @@ def _plot_ecc(self, theta):
plt.xlabel("$a$")
plt.ylabel(r"$\chi(K_a)$")

def _plot_radial(
self,
ax=None,
title=None,
cmap="viridis",
*,
rmin=0.0,
rmax=None,
colorbar=True,
overlay=None,
overlay_kwargs=None,
**kwargs,
):
"""
Plot ECT matrix in polar coordinates (radial plot).

Args:
ax: matplotlib axes object. If None, creates a new polar subplot
title: optional string for plot title
cmap: colormap for the plot (default: 'viridis')
rmin: minimum radius for the plot (default: 0.0)
rmax: maximum radius for the plot (default: None)
colorbar: whether to show the colorbar (default: True)
overlay: points to overlay on the plot (default: None)

**kwargs: additional keyword arguments passed to pcolormesh

Returns:
matplotlib.axes.Axes: The axes object used for plotting
"""
self._ensure_2d()

if ax is None:
fig, ax = plt.subplots(
subplot_kw=dict(projection="polar"), figsize=(10, 10)
)

THETA, R = self._theta_threshold_mesh()

im = ax.pcolormesh(THETA, R, self.T, cmap=cmap, **kwargs)

self._configure_polar_axes(ax, rmin=rmin, rmax=rmax)

if title:
ax.set_title(title)

if colorbar:
plt.colorbar(im, ax=ax, label="ECT Value")

if overlay is not None:
overlay_kwargs = overlay_kwargs or {}
theta, scaled_r = self._scale_overlay_radii(
overlay, rmin=rmin, rmax=rmax, fit_to_thresholds=True
)
ax.plot(
theta,
scaled_r,
"-",
color=overlay_kwargs.get("color", "black"),
linewidth=overlay_kwargs.get("linewidth", 2),
alpha=overlay_kwargs.get("alpha", 0.5),
)

return ax

def _overlay_points(
self,
points,
ax=None,
color="black",
linewidth=2,
alpha=0.5,
*,
rmin=0.0,
rmax=None,
fit_to_thresholds=True,
**kwargs,
):
"""
Overlay original points on a radial ECT plot.

Args:
points: numpy array of shape (N, 2) containing the original points
ax: matplotlib polar axes object. If None, uses current axes
color: color for the overlay line (default: 'white')
linewidth: line width for the overlay (default: 2)
alpha: transparency for the overlay (default: 0.5)
**kwargs: additional keyword arguments passed to plot

Returns:
matplotlib.axes.Axes: The axes object used for plotting
"""
if ax is None:
ax = plt.gca()

if not hasattr(ax, "name") or ax.name != "polar":
raise ValueError("overlay_points requires a polar axes object")

theta, scaled_r = self._scale_overlay_radii(
points, rmin=rmin, rmax=rmax, fit_to_thresholds=fit_to_thresholds
)

ax.plot(
theta,
scaled_r,
"-",
color=color,
linewidth=linewidth,
alpha=alpha,
**kwargs,
)

return ax

def dist(
self,
other: Union["ECTResult", List["ECTResult"]],
Expand Down
Loading