Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/_static/docstring_previews/de_disp_ests.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
161 changes: 161 additions & 0 deletions pertpy/tools/_differential_gene_expression/_pydeseq2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib.lines import Line2D
from matplotlib.pyplot import Figure
from numpy import ndarray
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
from scipy.sparse import issparse

from pertpy._doc import _doc_params, doc_common_plot_args

from ._base import LinearModelBase
from ._checks import check_is_integer_matrix

Expand Down Expand Up @@ -66,6 +71,162 @@ def fit(self, **kwargs) -> pd.DataFrame:
dds.deseq2()
self.dds = dds

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_disp_ests( # pragma: no cover # noqa: D417
self,
*,
ymin: float | None = None,
cv: bool = False,
gene_col: str = "black",
fit_col: str = "red",
final_col: str = "dodgerblue",
legend: bool = True,
xlabel: str | None = None,
ylabel: str | None = None,
log: str = "xy",
point_size: float = 0.45,
return_fig: bool = False,
**kwargs,
) -> Figure | None:
"""Plots per-gene dispersion estimates together with the fitted mean–dispersion relationship.

Args:
Comment thread
LuisHeinzlmeier marked this conversation as resolved.
ymin: Lower bound for plotted values. Points below this threshold are drawn at ymin using triangle markers.
cv: If True, plot the square root of dispersion (coefficient of variation) instead of dispersion.
gene_col: Color for gene-wise dispersion estimates.
fit_col: Color for fitted dispersion trend.
final_col: Color for final dispersion estimates used for testing.
legend: Whether to draw a legend.
xlabel: Label for the x-axis (default: "mean of normalized counts").
ylabel: Label for the y-axis (default: "dispersion" or "coefficient of variation").
log: Axis scaling. "x", "y", or "xy" for log scaling.
point_size: Scaling factor for point sizes.
{common_plot_args}
**kwargs: Additional arguments for ax.scatter.

Returns:
If `return_fig` is `True`, returns the figure, otherwise `None`.

Examples:
>>> import pertpy as pt
>>> import decoupler as dc
>>> adata = pt.dt.zhang_2021()
>>> adata = adata[adata.obs["Origin"] == "t", :].copy()
>>> adata.layers["counts"] = adata.X.copy()
>>> pdata = dc.pp.pseudobulk(adata, sample_col="Patient", groups_col="Cluster", layer="counts", mode="sum")
>>> dc.pp.filter_samples(pdata, inplace=True)
>>> pds2 = pt.tl.PyDESeq2(pdata, design="~Efficacy+Treatment")
>>> pds2.fit()
>>> pds2.plot_disp_ests(point_size=0.1)

Preview:
Comment thread
LuisHeinzlmeier marked this conversation as resolved.
.. image:: /_static/docstring_previews/de_disp_ests.png
"""
if not hasattr(self, "dds"):
raise ValueError("Model not fitted yet. Call .fit() first.")

dds = self.dds

if xlabel is None:
xlabel = "mean of normalized counts"
if ylabel is None:
ylabel = "coefficient of variation" if cv else "dispersion"

px = np.asarray(dds.var["_normed_means"])
sel = px > 0
px = px[sel]

py = np.asarray(dds.var["genewise_dispersions"])[sel]
if cv:
py = np.sqrt(py)

if ymin is None:
positive = py[(py > 0) & np.isfinite(py)]
ymin = 10 ** np.floor(np.log10(np.min(positive)) - 0.1)

py_plot = np.maximum(py, ymin)

fig, ax = plt.subplots(dpi=300)

below = py < ymin
above = ~below

if above.any():
ax.scatter(
px[above],
py_plot[above],
facecolor=gene_col,
edgecolors="none",
s=point_size * 20,
marker="o",
**kwargs,
)

if below.any():
ax.scatter(
px[below],
py_plot[below],
facecolor=gene_col,
edgecolors="none",
s=point_size * 20,
marker="v",
**kwargs,
)

outliers = np.asarray(
dds.var.get(
"_outlier_genes",
pd.Series(False, index=dds.var_names),
)
)[sel]

final_disp = np.asarray(dds.var["dispersions"])[sel]
final_y = np.sqrt(final_disp) if cv else final_disp

ax.scatter(
px,
final_y,
s=point_size * (20 + 20 * outliers.astype(int)),
facecolor=np.where(outliers, "none", final_col),
edgecolors=np.where(outliers, final_col, "none"),
)

fitted_disp = np.asarray(dds.var["fitted_dispersions"])[sel]
fitted_y = np.sqrt(fitted_disp) if cv else fitted_disp

ax.scatter(
px,
fitted_y,
facecolor=fit_col,
edgecolors="none",
marker="o",
s=point_size * 20,
)

if "x" in log:
ax.set_xscale("log")
if "y" in log:
ax.set_yscale("log")

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

if legend:
handles = [
Line2D([0], [0], marker="o", linestyle="", color=gene_col, label="gene-est"),
Line2D([0], [0], marker="o", linestyle="", color=fit_col, label="fitted"),
Line2D([0], [0], marker="o", linestyle="", color=final_col, label="final"),
]
ax.legend(handles=handles, loc="lower right", frameon=True)

plt.tight_layout(pad=2.0)

if return_fig:
return plt.gcf()

plt.show()
return None

def _test_single_contrast(self, contrast, alpha=0.05, *, lfc_shrink=None, **kwargs) -> pd.DataFrame:
"""Conduct a specific test and returns a Pandas DataFrame.

Expand Down
Loading