@_doc_params(common_plot_args=doc_common_plot_args)
def plot_disp_ests(  # pragma: no cover # noqa: D417
    self,
    *,
    ymin: float | None = None,
    cv: bool = False,
    gene_col: str = "black",
    fit_col: str = "red",
    final_col: str = "dodgerblue",
    legend: bool = True,
    xlabel: str | None = None,
    ylabel: str | None = None,
    log: str = "xy",
    point_size: float = 0.45,
    return_fig: bool = False,
    **kwargs,
) -> Figure | None:
    """Plots per-gene dispersion estimates together with the fitted mean–dispersion relationship.

    Three layers are drawn on the same axes, in this order: gene-wise estimates,
    final dispersions (outlier genes as larger, hollow markers) and the fitted trend.
    Only genes with a strictly positive normalized mean are shown.

    Args:
        ymin: Lower bound for plotted values. Points below this threshold are drawn at ymin using triangle markers.
            If `None`, it is derived from the smallest positive, finite gene-wise estimate.
        cv: If True, plot the square root of dispersion (coefficient of variation) instead of dispersion.
        gene_col: Color for gene-wise dispersion estimates.
        fit_col: Color for fitted dispersion trend.
        final_col: Color for final dispersion estimates used for testing.
        legend: Whether to draw a legend.
        xlabel: Label for the x-axis (default: "mean of normalized counts").
        ylabel: Label for the y-axis (default: "dispersion" or "coefficient of variation").
        log: Axis scaling. "x", "y", or "xy" for log scaling.
        point_size: Scaling factor for point sizes.
        {common_plot_args}
        **kwargs: Additional arguments for ax.scatter. Only applied to the gene-wise estimates layer.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Raises:
        ValueError: If the model has not been fitted, or if `ymin` is not given and
            there is no positive, finite gene-wise dispersion to derive it from.

    Examples:
        >>> import pertpy as pt
        >>> import decoupler as dc
        >>> adata = pt.dt.zhang_2021()
        >>> adata = adata[adata.obs["Origin"] == "t", :].copy()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> pdata = dc.pp.pseudobulk(adata, sample_col="Patient", groups_col="Cluster", layer="counts", mode="sum")
        >>> dc.pp.filter_samples(pdata, inplace=True)
        >>> pds2 = pt.tl.PyDESeq2(pdata, design="~Efficacy+Treatment")
        >>> pds2.fit()
        >>> pds2.plot_disp_ests(point_size=0.1)

    Preview:
        .. image:: /_static/docstring_previews/de_disp_ests.png
    """
    if not hasattr(self, "dds"):
        raise ValueError("Model not fitted yet. Call .fit() first.")

    dds = self.dds

    if xlabel is None:
        xlabel = "mean of normalized counts"
    if ylabel is None:
        ylabel = "coefficient of variation" if cv else "dispersion"

    # Restrict every layer to genes with a strictly positive normalized mean.
    normed_means = np.asarray(dds.var["_normed_means"])
    sel = normed_means > 0
    px = normed_means[sel]

    py = np.asarray(dds.var["genewise_dispersions"])[sel]
    if cv:
        py = np.sqrt(py)

    if ymin is None:
        positive = py[(py > 0) & np.isfinite(py)]
        if positive.size == 0:
            # np.min on an empty array fails with an opaque "zero-size array" error;
            # tell the caller what to do instead.
            raise ValueError(
                "Cannot derive ymin: no positive, finite gene-wise dispersion estimates. Pass ymin explicitly."
            )
        # Round the smallest estimate down to a power of ten so the axis starts on a clean tick.
        ymin = 10 ** np.floor(np.log10(positive.min()) - 0.1)

    py_plot = np.maximum(py, ymin)

    fig, ax = plt.subplots(dpi=300)

    # Gene-wise estimates: circles for regular points, downward triangles for points clipped to ymin.
    below = py < ymin
    for mask, marker in ((~below, "o"), (below, "v")):
        if mask.any():
            ax.scatter(
                px[mask],
                py_plot[mask],
                facecolor=gene_col,
                edgecolors="none",
                s=point_size * 20,
                marker=marker,
                **kwargs,
            )

    # Fall back to "no outliers" when the fitted object carries no outlier column.
    outliers = np.asarray(
        dds.var.get(
            "_outlier_genes",
            pd.Series(False, index=dds.var_names),
        )
    )[sel]

    final_disp = np.asarray(dds.var["dispersions"])[sel]
    final_y = np.sqrt(final_disp) if cv else final_disp

    # Final dispersions: outliers are drawn twice as large and hollow.
    ax.scatter(
        px,
        final_y,
        s=point_size * (20 + 20 * outliers.astype(int)),
        facecolor=np.where(outliers, "none", final_col),
        edgecolors=np.where(outliers, final_col, "none"),
    )

    fitted_disp = np.asarray(dds.var["fitted_dispersions"])[sel]
    fitted_y = np.sqrt(fitted_disp) if cv else fitted_disp

    ax.scatter(
        px,
        fitted_y,
        facecolor=fit_col,
        edgecolors="none",
        marker="o",
        s=point_size * 20,
    )

    if "x" in log:
        ax.set_xscale("log")
    if "y" in log:
        ax.set_yscale("log")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if legend:
        handles = [
            Line2D([0], [0], marker="o", linestyle="", color=color, label=label)
            for color, label in ((gene_col, "gene-est"), (fit_col, "fitted"), (final_col, "final"))
        ]
        ax.legend(handles=handles, loc="lower right", frameon=True)

    # Operate on the figure created above rather than on pyplot's implicit "current figure".
    fig.tight_layout(pad=2.0)

    if return_fig:
        return fig

    plt.show()
    return None