diff --git a/bin/annotate_omega_failing.py b/bin/annotate_omega_failing.py index 9ef3666f..9d860ed6 100755 --- a/bin/annotate_omega_failing.py +++ b/bin/annotate_omega_failing.py @@ -41,12 +41,31 @@ def load_flagged_tables(paths: List[Path]) -> Tuple[pd.DataFrame, pd.DataFrame]: "cohort": "sample", "ID": "gene", }) - syn_df = all_df[all_df["regions"] == "synonymous"][["sample", "gene", "reason_exclusion"]].reset_index(drop=True) - npa_df = all_df[all_df["regions"] == "non_protein_affecting"][["sample", "gene", "reason_exclusion"]].reset_index(drop=True) - return syn_df, npa_df - - -def plot_flagged_summary(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, output_prefix: str = 'flagged_gene', top_n: int = 50) -> None: + syn_sample_df = all_df[(all_df["regions"] == "synonymous") + & (all_df["criteria"] == "per_sample") + ][["sample", "gene", "reason_exclusion"]].reset_index(drop=True) + syn_sample_df.rename(columns={"sample": "cohort", "gene": "sample"}, inplace=True) + syn_gene_df = all_df[(all_df["regions"] == "synonymous") + & (all_df["criteria"] == "per_gene") + ][["sample", "gene", "reason_exclusion"]].reset_index(drop=True) + syn_gene_df.rename(columns={"sample": "cohort"}, inplace=True) + + npa_sample_df = all_df[(all_df["regions"] == "non_protein_affecting") + & (all_df["criteria"] == "per_sample") + ][["sample", "gene", "reason_exclusion"]].reset_index(drop=True) + npa_sample_df.rename(columns={"sample": "cohort", "gene": "sample"}, inplace=True) + + npa_gene_df = all_df[(all_df["regions"] == "non_protein_affecting") + & (all_df["criteria"] == "per_gene") + ][["sample", "gene", "reason_exclusion"]].reset_index(drop=True) + npa_gene_df.rename(columns={"sample": "cohort"}, inplace=True) + + return syn_sample_df, syn_gene_df, npa_sample_df, npa_gene_df + + +def plot_flagged_summary(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, + output_prefix: str = 'flagged_gene', top_n: int = 50, + variable : str = 'gene') -> None: """Plot stacked bar summaries of flagged genes for synonymous and non-protein-affecting. Uses the gene order from `syn_flagged` (most frequently flagged first) to @@ -55,22 +74,22 @@ def plot_flagged_summary(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, o will get colors from the seaborn color palette. """ # Ensure dataframes exist - syn = syn_flagged.copy() if syn_flagged is not None else pd.DataFrame(columns=['sample', 'gene', 'reason_exclusion']) - npa = npa_flagged.copy() if npa_flagged is not None else pd.DataFrame(columns=['sample', 'gene', 'reason_exclusion']) + syn = syn_flagged.copy() + npa = npa_flagged.copy() # Determine gene order from synonymously flagged genes (descending frequency) - syn_order = syn['gene'].value_counts().index.tolist() + syn_order = syn[variable].value_counts().index.tolist() # Include any genes present in npa but not in syn at the end - npa_genes = [g for g in npa['gene'].unique() if g not in syn_order] + npa_genes = [g for g in npa[variable].unique() if g not in syn_order] order = syn_order + npa_genes # Count occurrences per gene x reason - syn_counts = syn.groupby(['gene', 'reason_exclusion']).size().unstack(fill_value=0) - npa_counts = npa.groupby(['gene', 'reason_exclusion']).size().unstack(fill_value=0) + syn_counts = syn.groupby([variable, 'reason_exclusion']).size().unstack(fill_value=0) + npa_counts = npa.groupby([variable, 'reason_exclusion']).size().unstack(fill_value=0) # Reindex to the chosen order; keep only top_n genes by total syn+npa counts - total_counts = (syn.groupby('gene').size().add(npa.groupby('gene').size(), fill_value=0)).sort_values(ascending=False) + total_counts = (syn.groupby(variable).size().add(npa.groupby(variable).size(), fill_value=0)).sort_values(ascending=False) top_genes = total_counts.head(top_n).index.tolist() order = [g for g in order if g in top_genes] @@ -85,30 +104,24 @@ def plot_flagged_summary(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, o } # Collect all reasons in consistent order (predefined first) - reasons = [] - for r in predefined.keys(): - if r in set(list(syn_counts.columns) + list(npa_counts.columns)): - reasons.append(r) - # add any other reasons found - other_reasons = [r for r in sorted(set(list(syn_counts.columns) + list(npa_counts.columns))) if r not in reasons] - reasons.extend(other_reasons) + syn_reasons = list(syn_counts.columns) + npa_reasons = list(npa_counts.columns) # Build color list: use predefined colors then palette for others - colors = [] - palette = sns.color_palette('tab10', n_colors=max(3, len(other_reasons))) - other_iter = iter(palette) - for r in reasons: - if r in predefined: - colors.append(predefined[r]) - else: - colors.append(next(other_iter)) + syn_colors = [] + for r in syn_reasons: + syn_colors.append(predefined.get(r, 'grey')) + + npa_colors = [] + for r in npa_reasons: + npa_colors.append(predefined.get(r, 'grey')) # Plotting stacked horizontal barplots (two subplots stacked vertically) fig, axes = plt.subplots(2, 1, figsize=(10, max(6, 0.25 * len(order) * 2)), sharex=True) # Synonymous stacked bar if not syn_counts.empty: - syn_counts[reasons].plot(kind='barh', stacked=True, color=colors, ax=axes[0]) + syn_counts[syn_reasons].plot(kind='barh', stacked=True, color=syn_colors, ax=axes[0]) axes[0].set_title('Synonymous flagged cases (by gene)') axes[0].set_xlabel('Count') else: @@ -116,130 +129,82 @@ def plot_flagged_summary(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, o # NPA stacked bar if not npa_counts.empty: - npa_counts[reasons].plot(kind='barh', stacked=True, color=colors, ax=axes[1]) + npa_counts[npa_reasons].plot(kind='barh', stacked=True, color=npa_colors, ax=axes[1]) axes[1].set_title('Non-protein-affecting flagged cases (by gene)') axes[1].set_xlabel('Count') else: axes[1].text(0.5, 0.5, 'No non-protein-affecting flagged cases', ha='center') plt.tight_layout() - out_png = f"{output_prefix}_cases_summary.png" - fig.savefig(out_png, dpi=300) + fig.savefig(f"{output_prefix}.{variable}.cases_summary.png", dpi=300) plt.close(fig) - # Also write TSV summary - summary = pd.DataFrame({'gene': order}) - for r in reasons: - summary[r] = syn_counts.get(r, 0).reindex(order, fill_value=0).values + npa_counts.get(r, 0).reindex(order, fill_value=0).values - summary.to_csv(f"{output_prefix}_cases_counts.flagged.tsv", sep='\t', index=False) - print(f"Wrote flagged summary plot {out_png} and table {output_prefix}_cases_counts.flagged.tsv") + # # Also write TSV summary + # summary = pd.DataFrame({variable: order}) + # for r in reasons: + # summary[r] = syn_counts.get(r, 0).reindex(order, fill_value=0).values + npa_counts.get(r, 0).reindex(order, fill_value=0).values + # summary.to_csv(f"{output_prefix}_cases_counts.flagged.tsv", sep='\t', index=False) + # print(f"Wrote flagged summary plot {out_png} and table {output_prefix}_cases_counts.flagged.tsv") -def plot_flagged_sample_summary(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, output_prefix: str = 'flagged_sample', top_n: int = 50) -> None: - """Plot stacked bar summaries of flagged genes for synonymous and non-protein-affecting. - Uses the gene order from `syn_flagged` (most frequently flagged first) to - order the bars. Bars are stacked by `reason_exclusion` and colors for a - small set of known reasons are defined a priori; any additional reasons - will get colors from the seaborn color palette. - """ - # Ensure dataframes exist - syn = syn_flagged.copy() if syn_flagged is not None else pd.DataFrame(columns=['sample', 'gene', 'reason_exclusion']) - npa = npa_flagged.copy() if npa_flagged is not None else pd.DataFrame(columns=['sample', 'gene', 'reason_exclusion']) - - # Determine sample order from synonymously flagged genes (descending frequency) - syn_order = syn['sample'].value_counts().index.tolist() - - # Include any samples present in npa but not in syn at the end - npa_samples = [g for g in npa['sample'].unique() if g not in syn_order] - order = syn_order + npa_samples - - # Count occurrences per sample x reason - syn_counts = syn.groupby(['sample', 'reason_exclusion']).size().unstack(fill_value=0) - npa_counts = npa.groupby(['sample', 'reason_exclusion']).size().unstack(fill_value=0) - - # Reindex to the chosen order; keep only top_n samples by total syn+npa counts - total_counts = (syn.groupby('sample').size().add(npa.groupby('sample').size(), fill_value=0)).sort_values(ascending=False) - top_samples = total_counts.head(top_n).index.tolist() - order = [g for g in order if g in top_samples] - - syn_counts = syn_counts.reindex(order, fill_value=0) - npa_counts = npa_counts.reindex(order, fill_value=0) - - # Colors: predefined mapping for known reasons - predefined = { - 'mutdensity = 0': '#8c8c8c', - 'high_mutdensity - zscore > 2': '#d62728', - 'low_mutdensity - zscore < -2': '#1f77b4', - } - - # Collect all reasons in consistent order (predefined first) - reasons = [] - for r in predefined.keys(): - if r in set(list(syn_counts.columns) + list(npa_counts.columns)): - reasons.append(r) - # add any other reasons found - other_reasons = [r for r in sorted(set(list(syn_counts.columns) + list(npa_counts.columns))) if r not in reasons] - reasons.extend(other_reasons) - - # Build color list: use predefined colors then palette for others - colors = [] - palette = sns.color_palette('tab10', n_colors=max(3, len(other_reasons))) - other_iter = iter(palette) - for r in reasons: - if r in predefined: - colors.append(predefined[r]) - else: - colors.append(next(other_iter)) +def annotate(omegas: pd.DataFrame, gene_flagged: pd.DataFrame, sample_flagged: pd.DataFrame ) -> pd.DataFrame: + """Annotate omegas DataFrame with flagged info. - # Plotting stacked horizontal barplots (two subplots stacked vertically) - fig, axes = plt.subplots(2, 1, figsize=(10, max(6, 0.25 * len(order) * 2)), sharex=True) + The returned DataFrame includes two new columns: `flagged` (bool) and + `flag_reason` (string). + """ + # omegas has these columns: + # gene sample impact mutations dnds pvalue lower upper - # Synonymous stacked bar - if not syn_counts.empty: - syn_counts[reasons].plot(kind='barh', stacked=True, color=colors, ax=axes[0]) - axes[0].set_title('Synonymous flagged cases (by sample)') - axes[0].set_xlabel('Count') - else: - axes[0].text(0.5, 0.5, 'No synonymous flagged cases', ha='center') + # and flagged has: + # sample, gene, reason_exclusion - # NPA stacked bar - if not npa_counts.empty: - npa_counts[reasons].plot(kind='barh', stacked=True, color=colors, ax=axes[1]) - axes[1].set_title('Non-protein-affecting flagged cases (by sample)') - axes[1].set_xlabel('Count') - else: - axes[1].text(0.5, 0.5, 'No non-protein-affecting flagged cases', ha='center') + omegas["original_gene"] = omegas["gene"].copy() + omegas["gene"] = omegas["gene"].str.split("--").str[0] + annotated_omegas_gene = omegas.merge(gene_flagged, how="left", + left_on=["sample", "gene"], + right_on=["cohort", "gene"], + suffixes = ("", "_gene") + ).drop(columns=["cohort"]) + annotated_omegas = annotated_omegas_gene.merge(sample_flagged.drop(columns=["cohort"]), + how="left", + on=["sample"], + suffixes = ("", "_sample") + ) + + annotated_omegas["gene"] = annotated_omegas["original_gene"] + annotated_omegas["flag_reason"] = annotated_omegas["reason_exclusion"].fillna("") + annotated_omegas["reason_exclusion_sample"].fillna("") + annotated_omegas["flagged"] = annotated_omegas["flag_reason"] != "" + annotated_omegas = annotated_omegas.drop(columns=["original_gene", "reason_exclusion", "reason_exclusion_sample"]) - plt.tight_layout() - out_png = f"{output_prefix}_cases_summary.png" - fig.savefig(out_png, dpi=300) - plt.close(fig) + return annotated_omegas[['gene', 'sample', 'impact', 'mutations', + 'dnds', 'pvalue', 'lower', 'upper', + 'flagged', 'flag_reason']] - # Also write TSV summary - summary = pd.DataFrame({'sample': order}) - for r in reasons: - summary[r] = syn_counts.get(r, 0).reindex(order, fill_value=0).values + npa_counts.get(r, 0).reindex(order, fill_value=0).values - summary.to_csv(f"{output_prefix}_cases_counts.flagged.tsv", sep='\t', index=False) - print(f"Wrote flagged summary plot {out_png} and table {output_prefix}_cases_counts.flagged.tsv") +def plot_flagged_heatmap(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, + output_prefix: str = 'flagged_heatmap', top_n_genes: int = 200, + top_n_samples: int = 200, mode: str = 'combined', + variable: str = 'gene' + ) -> None: -def plot_flagged_heatmap(syn_flagged: pd.DataFrame, npa_flagged: pd.DataFrame, output_prefix: str = 'flagged_heatmap', top_n_genes: int = 200, top_n_samples: int = 200, mode: str = 'combined') -> None: """Create a genes x samples heatmap showing failure reasons per cell. mode: 'combined' (default) aggregates syn+npa into one heatmap. mode: 'separate' will produce two heatmaps (suffixes '_syn' and '_npa'). """ - def _heatmap_one(df: pd.DataFrame, prefix: str) -> None: + def _heatmap_one(df: pd.DataFrame, prefix: str, variable: str) -> None: syn = df.copy() if syn.empty: print(f'No flagged entries to create heatmap for {prefix}') return - gene_order = syn['gene'].value_counts().head(top_n_genes).index.tolist() - sample_order = syn['sample'].value_counts().head(top_n_samples).index.tolist() + gene_order = syn[variable].value_counts().head(top_n_genes).index.tolist() + sample_order = syn['cohort'].value_counts().head(top_n_samples).index.tolist() - pivot = syn.groupby(['gene', 'sample'])['reason_exclusion'].apply(lambda s: ';'.join(sorted(set([x for x in s if x])))).unstack(fill_value='') + pivot = syn.groupby([variable, 'cohort'])['reason_exclusion'].apply(lambda s: ';'.join(sorted(set([x for x in s if x])))).unstack(fill_value='') pivot = pivot.reindex(index=gene_order, columns=sample_order, fill_value='') + print(pivot) cell_reason = pivot.fillna('').astype(str) cell_reason = cell_reason.apply(lambda col: col.map(lambda v: v.split(';')[0] if v else '')) @@ -250,12 +215,7 @@ def _heatmap_one(df: pd.DataFrame, prefix: str) -> None: 'low_mutdensity - zscore < -2': '#1f77b4', } unique_reasons = sorted(set([r for r in cell_reason.values.flatten() if r])) - other_reasons = [r for r in unique_reasons if r not in predefined] - from matplotlib.colors import to_hex - palette = [to_hex(c) for c in sns.color_palette('tab20', n_colors=max(1, len(other_reasons)))] reason_colors = {r: predefined[r] for r in predefined if r in unique_reasons} - for i, r in enumerate(other_reasons): - reason_colors[r] = palette[i] reason_to_int = {r: idx+1 for idx, r in enumerate(sorted(reason_colors.keys()))} int_matrix = cell_reason.replace('', pd.NA).apply(lambda col: col.map(lambda v: reason_to_int.get(v, pd.NA))) @@ -268,56 +228,29 @@ def _heatmap_one(df: pd.DataFrame, prefix: str) -> None: ax = sns.heatmap(int_matrix.fillna(0).astype(int), cmap=cmap, cbar=False, linewidths=0.2) ax.set_yticklabels(int_matrix.index, rotation=0) ax.set_xticklabels(int_matrix.columns, rotation=90) - plt.title('Flagged genes (rows) x samples (columns) - colored by failure reason') + plt.title('Flagged genes (rows) x cohorts (columns) - colored by failure reason') plt.tight_layout() - out_png = f"{prefix}.png" + out_png = f"{prefix}.{variable}.png" plt.savefig(out_png, dpi=200) plt.close() - legend_df = pd.DataFrame([{'reason': r, 'color': reason_colors[r]} for r in sorted(reason_colors.keys())]) - legend_df.to_csv(f"{prefix}_legend.tsv", sep='\t', index=False) - print(f"Wrote heatmap {out_png} and legend {prefix}_legend.tsv") - # combined df - all_df = pd.concat((syn_flagged.copy() if syn_flagged is not None else pd.DataFrame(columns=['sample','gene','reason_exclusion']), - npa_flagged.copy() if npa_flagged is not None else pd.DataFrame(columns=['sample','gene','reason_exclusion'])), + all_df = pd.concat((syn_flagged.copy() if syn_flagged is not None else pd.DataFrame(columns=['cohort',variable,'reason_exclusion']), + npa_flagged.copy() if npa_flagged is not None else pd.DataFrame(columns=['cohort',variable,'reason_exclusion'])), ignore_index=True, sort=False) if mode == 'combined': if all_df.empty: print('No flagged entries to create heatmap') return - _heatmap_one(all_df, output_prefix) + _heatmap_one(all_df, output_prefix, variable) elif mode == 'separate': - _heatmap_one(syn_flagged.copy() if syn_flagged is not None else pd.DataFrame(columns=['sample','gene','reason_exclusion']), f"{output_prefix}_syn") - _heatmap_one(npa_flagged.copy() if npa_flagged is not None else pd.DataFrame(columns=['sample','gene','reason_exclusion']), f"{output_prefix}_npa") + _heatmap_one(syn_flagged.copy() if syn_flagged is not None else pd.DataFrame(columns=['cohort',variable,'reason_exclusion']), f"{output_prefix}_syn", variable) + _heatmap_one(npa_flagged.copy() if npa_flagged is not None else pd.DataFrame(columns=['cohort',variable,'reason_exclusion']), f"{output_prefix}_npa", variable) else: raise ValueError(f"Unknown mode for plot_flagged_heatmap: {mode}") -def annotate(omegas: pd.DataFrame, flagged: pd.DataFrame) -> pd.DataFrame: - """Annotate omegas DataFrame with flagged info. - - The returned DataFrame includes two new columns: `flagged` (bool) and - `flag_reason` (string). - """ - # omegas has these columns: - # gene sample impact mutations dnds pvalue lower upper - - # and flagged has: - # sample, gene, reason_exclusion - - omegas["original_gene"] = omegas["gene"].copy() - omegas["gene"] = omegas["gene"].str.split("-").str[0] - annotated_omegas = omegas.merge(flagged, how="left", - on=["sample", "gene"]) - annotated_omegas["gene"] = annotated_omegas["original_gene"] - annotated_omegas["flagged"] = annotated_omegas["reason_exclusion"].notnull() - annotated_omegas["flag_reason"] = annotated_omegas["reason_exclusion"].fillna("") - - return annotated_omegas.drop(columns=["original_gene", "reason_exclusion"]) - - @@ -341,33 +274,38 @@ def main(omegas_file: str, compiled_flagged_files: str, output: str) -> None: # Read omegas omegas = pd.read_csv(omegas_path, sep="\t", header=0, dtype=str).fillna("") - syn_flagged, npa_flagged = load_flagged_tables(flagged_paths) + syn_flagged_sample, syn_flagged_gene, npa_flagged_sample, npa_flagged_gene = load_flagged_tables(flagged_paths) # keep debug outputs for inspection - syn_flagged.to_csv("debug.syn_flagged.tsv", sep="\t", index=False) - npa_flagged.to_csv("debug.npa_flagged.tsv", sep="\t", index=False) + syn_flagged_sample.to_csv("debug.syn_flagged_sample.tsv", sep="\t", index=False) + syn_flagged_gene.to_csv("debug.syn_flagged_gene.tsv", sep="\t", index=False) + npa_flagged_sample.to_csv("debug.npa_flagged_sample.tsv", sep="\t", index=False) + npa_flagged_gene.to_csv("debug.npa_flagged_gene.tsv", sep="\t", index=False) - if syn_flagged.empty and npa_flagged.empty: + if syn_flagged_sample.empty and npa_flagged_sample.empty: print('No flagged entries found; skipping plots and annotating with no flags.') - + else: - # Gene-based summary try: - plot_flagged_summary(syn_flagged, npa_flagged) + # Gene-based summary + plot_flagged_summary(syn_flagged_gene, npa_flagged_gene, variable = 'gene') except Exception as e: print(f"Warning: plot_flagged_summary failed: {e}") - # Sample-based summaries (separate syn/npa) + try: - plot_flagged_sample_summary(syn_flagged, npa_flagged) + # Sample-based summaries (separate syn/npa) + plot_flagged_summary(syn_flagged_sample, npa_flagged_sample, variable = 'sample') except Exception as e: - print(f"Warning: plot_flagged_sample_summary failed: {e}") - # Heatmaps (separate syn/npa) + print(f"Warning: plot_flagged_summary failed: {e}") + try: - plot_flagged_heatmap(syn_flagged, npa_flagged, mode='separate') + # Heatmaps (separate syn/npa) + plot_flagged_heatmap(syn_flagged_gene, npa_flagged_gene , mode='separate', variable='gene' ) + plot_flagged_heatmap(syn_flagged_sample, npa_flagged_sample , mode='separate', variable='sample') except Exception as e: - print(f"Warning: plot_flagged_heatmap failed: {e}") + print(f"Warning: plot_flagged_summary failed: {e}") - annotated = annotate(omegas, syn_flagged) + annotated = annotate(omegas, syn_flagged_gene, syn_flagged_sample) annotated.to_csv(output, sep="\t", index=False) print(f"Wrote annotated omegas to {output}") diff --git a/bin/check_contamination.py b/bin/check_contamination.py index c6505fe3..bd2f83f4 100755 --- a/bin/check_contamination.py +++ b/bin/check_contamination.py @@ -41,20 +41,7 @@ def compute_shared_variants(somatic_variants, germline_variants): -def contamination_detection(maf_file, somatic_maf_file): - - maf_df = pd.read_table(maf_file, na_values=custom_na_values) - print(maf_df.shape) - maf_df = maf_df[~(maf_df["FILTER.not_covered"]) - & (maf_df["TYPE"] == 'SNV') - ].reset_index() - print(maf_df.shape) - - somatic_maf_df = pd.read_table(somatic_maf_file, na_values=custom_na_values) - print(somatic_maf_df.shape) - somatic_maf_df = somatic_maf_df[(somatic_maf_df["TYPE"] == 'SNV')] - print(somatic_maf_df.shape) - +def contamination_detection_between_samples(maf_df, somatic_maf_df): # this is if we were to consider both unique and no-unique variants vaf_threshold = 0.2 @@ -372,6 +359,54 @@ def contamination_detection(maf_file, somatic_maf_file): +def data_loading(maf_path, somatic_maf_path): + maf_df = pd.read_table(maf_path, na_values=custom_na_values) + print(maf_df.shape) + maf_df = maf_df[~(maf_df["FILTER.not_covered"]) + & (maf_df["TYPE"] == 'SNV') + ].reset_index() + print(maf_df.shape) + + somatic_maf_df = pd.read_table(somatic_maf_path, na_values=custom_na_values) + print(somatic_maf_df.shape) + somatic_maf_df = somatic_maf_df[(somatic_maf_df["TYPE"] == 'SNV')] + print(somatic_maf_df.shape) + return maf_df, somatic_maf_df + + +def contamination_detection_in_snps(maf): + + snp_positions_maf = maf[maf["FILTER.gnomAD_SNP"]][ + ["SAMPLE_ID", "MUT_ID", "VAF"] + ].reset_index(drop = True) + + # being very restrictive in the VAF to count the occurrences of potentially contaminated mutations + somatic_snp_positions_maf = snp_positions_maf[snp_positions_maf["VAF"] < 0.05].reset_index(drop = True) + germline_snp_positions_maf = snp_positions_maf[snp_positions_maf["VAF"] >= 0.05].reset_index(drop = True) + + unique_SNP_positions = snp_positions_maf["MUT_ID"].unique() + number_unique_SNP_positions = len(unique_SNP_positions) + + sample_SNP_mutation_freq = [] + for sample in snp_positions_maf["SAMPLE_ID"].unique(): + germline_count = len(germline_snp_positions_maf[germline_snp_positions_maf["SAMPLE_ID"] == sample]) + somatic_count = len(somatic_snp_positions_maf[somatic_snp_positions_maf["SAMPLE_ID"] == sample]) + remaining_germline = number_unique_SNP_positions-germline_count + sample_SNP_mutation_freq.append([sample, + germline_count, + remaining_germline, + somatic_count, + somatic_count / remaining_germline if remaining_germline > 0 else 1 + ]) + sample_SNP_mutation_freq_df = pd.DataFrame(sample_SNP_mutation_freq) + sample_SNP_mutation_freq_df.columns = ["SAMPLE_ID", "germline_count", "remaining_germline", "somatic_count", "prop_somatic_SNPs"] + + # identify outliers in the "prop_somatic_SNPs" column + sample_SNP_mutation_freq_df = sample_SNP_mutation_freq_df.sort_values(by = "prop_somatic_SNPs", ascending = False) + sample_SNP_mutation_freq_df.to_csv("sample_SNP_mutation_freq.tsv", header = True, sep = '\t', index = False) + + + @click.command() @click.option('--maf_path', type=click.Path(exists=True), required=True, help='Path to the MAF file.') @click.option('--somatic_maf', type=click.Path(exists=True), required=True, help='Path to the filtered somatic mutations file.') @@ -379,7 +414,14 @@ def main(maf_path, somatic_maf): """ CLI entry point for assessing contamination between samples using germline and somatic mutations. """ - contamination_detection(maf_path, somatic_maf) + + maf_df, somatic_maf_df = data_loading(maf_path, somatic_maf) + + print("Running contamination analysis between samples") + contamination_detection_between_samples(maf_df, somatic_maf_df) + + print("Running general contamination analysis") + contamination_detection_in_snps(maf_df) diff --git a/bin/mut_profile.py b/bin/mut_profile.py index 7f7da278..60c834fb 100755 --- a/bin/mut_profile.py +++ b/bin/mut_profile.py @@ -10,6 +10,17 @@ from utils_plot import plot_profile from read_utils import custom_na_values +def bayesian_update(observed_counts, prior_profile, limit=200): + + N = np.sum(observed_counts).item() + + if N >= limit: + return observed_counts + else: + prior_profile = prior_profile.set_index("CONTEXT_MUT") + posterior_matrix = ((limit - N) * prior_profile).values + observed_counts + return posterior_matrix + def compute_mutation_matrix(sample_name, mutations_file, mutation_matrix, method, pseudocount, sigprofiler, per_sample): @@ -98,7 +109,6 @@ def compute_mutation_matrix(sample_name, mutations_file, mutation_matrix, method sep = "\t") - def profile_stability(counts, denominator): """ Compute stability score of a probability profile by @@ -164,23 +174,28 @@ def profile_stability(counts, denominator): } -def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_counts_file, plot, - wgs = False, wgs_trinucleotide_counts = False, sigprofiler = False): +def compute_mutation_profile(sample_name, mutation_matrix, trinucleotide_counts_file, plot, + wgs = False, wgs_trinucleotide_counts = False, sigprofiler = False, + smoothed = False, prior_profile_file = None, + minimum_mutations = 200 + ): """ Compute mutational profile from the input data Required information: Mutation matrix Trinucleotide content of the sequenced region (depth-aware or non-depth-aware) + + Minimum number of mutations to apply Bayesian smoothing (if smoothed is active) + This has been set to 200 after proper testing of the minimum required entropy of mutational profiles. Output: Mutational profile per sample """ # Load your mutation matrix - mutation_matrix = pd.read_csv(mutation_matrix_file, sep = "\t", header = 0) mutation_matrix = mutation_matrix.set_index("CONTEXT_MUT") total_mutations = np.sum(mutation_matrix[sample_name]) - + # proportion of SBS mutations per trinucleotide in panel mutation_matrix_proportions = mutation_matrix.copy() mutation_matrix_proportions[sample_name] = mutation_matrix_proportions[sample_name] / total_mutations @@ -214,17 +229,6 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co # normalize mut_probability = mut_probability / mut_probability.sum() - # if there is any channel with 0 probability we need to add a pseudocount - if not all(mut_probability[sample_name].values > 0): - # find the minimum value greater than 0 - min_value_non_zero = mut_probability[mut_probability > 0].min() - # print(min_value_non_zero) - - # add a dynamic pseudocount of one third of the minimum number greater than 0 - mut_probability = mut_probability + (min_value_non_zero / 3) - - mut_probability = mut_probability / mut_probability.sum() - # reindex to ensure the right order mut_probability = mut_probability.reindex(contexts_formatted) @@ -271,6 +275,32 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co header = True, index = True, sep = "\t") + + if total_mutations < minimum_mutations and smoothed: + print(f"Using a prior profile to smooth the profile, since the mutation count is < {minimum_mutations}") + prior_profile = pd.read_table(prior_profile_file) + + upd_mutation_matrix_wgs = bayesian_update(profile_trinuc_clean, prior_profile, minimum_mutations).reset_index() + total_mutations = max(minimum_mutations, total_mutations) + + # we have updated the mutation counts at the WGS level, we should now revert this update + # for the specific sample's trinucleotide content and depth + + upd_mutation_matrix_wgs["CONTEXT"] = upd_mutation_matrix_wgs["CONTEXT_MUT"].apply( lambda x : x[:3]) + profile_trinuc_merge = upd_mutation_matrix_wgs.merge(ref_trinuc32, on = "CONTEXT") + + # relative mutability per trinucleotide change in the panel after smoothing + profile_trinuc_merge["RELATIVE_MUTABILITY_PER_CHANNEL"] = profile_trinuc_merge[sample_name] / profile_trinuc_merge["COUNT"] + profile_trinuc_merge["RELATIVE_MUTABILITY_PER_CHANNEL"] = profile_trinuc_merge["RELATIVE_MUTABILITY_PER_CHANNEL"] / profile_trinuc_merge["RELATIVE_MUTABILITY_PER_CHANNEL"].sum() + + profile_n_sample_trinuc = profile_trinuc_merge.merge(trinucleotide_counts.reset_index(), on = "CONTEXT", suffixes = ("", "_PANEL")) + profile_n_sample_trinuc["MUTS_PANEL"] = profile_n_sample_trinuc[sample_name] * profile_n_sample_trinuc[f"{sample_name}_PANEL"] + profile_n_sample_trinuc["MUTS_PANEL_FINAL"] = profile_n_sample_trinuc["MUTS_PANEL"] / profile_n_sample_trinuc["MUTS_PANEL"].sum() * total_mutations + smoothed_mutation_matrix_panel_counts = profile_n_sample_trinuc[["CONTEXT_MUT", "MUTS_PANEL_FINAL"]] + + return smoothed_mutation_matrix_panel_counts.rename({"MUTS_PANEL_FINAL": sample_name}, axis = 1) + else: + print(f"No need to smooth the profile, the mutation count is already >= {minimum_mutations}") profile_trinuc_clean_proportion = profile_trinuc_clean.copy() profile_trinuc_clean_proportion[sample_name] = profile_trinuc_clean_proportion[sample_name] / profile_trinuc_clean_proportion[sample_name].sum() @@ -278,7 +308,7 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co header = True, index = True, sep = "\t") - + # plot the profile as a percentage of SBS mutations seen after sequencing one WGS # if mutations were occuring with the same probabilities as they occur in our sequenced panel @@ -295,7 +325,7 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co index = False, sep = "\t") - + return None @@ -313,12 +343,14 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co @click.option('--plot', is_flag=True, help='Generate plot and save as PDF') @click.option('--wgs', is_flag=True, help='Store matrix of mutation counts at WGS level') @click.option('--wgs_trinucleotide_counts', type=click.Path(exists=True), help='Trinucleotide counts file of the WGS (for profile mode if WGS active)') +@click.option('--smoothed', is_flag=True, help='Apply Bayesian smoothing to the mutation counts using a prior profile') +@click.option('--prior_profile', type=click.Path(exists=True), help='Prior profile file to use for Bayesian smoothing (required if --smoothed is set)') @click.option('--sigprofiler', is_flag=True, help='Store the index column using the SigProfiler format') def main(mode, sample_name, mut_file, out_matrix, method, pseud, sigprofiler, per_sample, mutation_matrix, - trinucleotide_counts, plot, wgs, wgs_trinucleotide_counts): + trinucleotide_counts, plot, wgs, wgs_trinucleotide_counts, smoothed, prior_profile): if mode == 'matrix': click.echo(f"Running in matrix mode...") @@ -328,7 +360,11 @@ def main(mode, sample_name, mut_file, out_matrix, method, pseud, sigprofiler, pe elif mode == 'profile': click.echo(f"Running in profile mode...") - compute_mutation_profile(sample_name, mutation_matrix, trinucleotide_counts, plot, wgs, wgs_trinucleotide_counts, sigprofiler) + mutation_matrix_loaded = pd.read_csv(mutation_matrix, sep = "\t", header = 0) + smoothed_mutation_matrix = compute_mutation_profile(sample_name, mutation_matrix_loaded, trinucleotide_counts, plot, wgs, wgs_trinucleotide_counts, sigprofiler, smoothed, prior_profile) + if smoothed_mutation_matrix is not None: + compute_mutation_profile(sample_name, smoothed_mutation_matrix, trinucleotide_counts, plot, wgs, wgs_trinucleotide_counts, sigprofiler) + click.echo("Profile computation completed.") else: diff --git a/conf/modules.config b/conf/modules.config index 05366308..7a084208 100644 --- a/conf/modules.config +++ b/conf/modules.config @@ -164,6 +164,10 @@ process { ext.args = "--level ${params.confidence_level}" } + withName: COMPUTEPROFILE { + ext.smoothing = params.profile_smoothing + } + withName: MAF2VCF { ext.args = "--output-dir . --maf-from-deepcsa --sample-name-column SAMPLE_ID" publishDir = [ diff --git a/modules/local/compute_profile/main.nf b/modules/local/compute_profile/main.nf index 568e1957..89039194 100644 --- a/modules/local/compute_profile/main.nf +++ b/modules/local/compute_profile/main.nf @@ -7,8 +7,9 @@ process COMPUTE_PROFILE { label 'deepcsa_core' input: - tuple val(meta), path(matrix), path(trinucleotide) + tuple val(meta) , path(matrix), path(trinucleotide) path( wgs_trinucleotides ) + tuple val(meta2), path( cohort_profile , stageAs: 'global_mutprofile.tsv') output: tuple val(meta), path("*.profile.tsv") , emit: profile @@ -28,12 +29,14 @@ process COMPUTE_PROFILE { def prefix = task.ext.prefix ?: "" prefix = "${meta.id}${prefix}" def wgs_trinuc = wgs_trinucleotides ? "--wgs --wgs_trinucleotide_counts ${wgs_trinucleotides}" : "" + def smoothing_args = task.ext.smoothing ? "--smoothed --prior_profile ${cohort_profile}" : "" """ mut_profile.py profile \\ --sample_name ${prefix} \\ --mutation_matrix ${matrix} \\ --trinucleotide_counts ${trinucleotide} \\ ${wgs_trinuc} \\ + ${smoothing_args} \\ ${args} cat <<-END_VERSIONS > versions.yml "${task.process}": diff --git a/nextflow.config b/nextflow.config index 8ebb82b2..6b3ec4f7 100644 --- a/nextflow.config +++ b/nextflow.config @@ -40,6 +40,7 @@ params { profilenonprot = false profileexons = false profileintrons = false + profile_smoothing = false positive_selection_non_protein_affecting = false oncodrivefml = false diff --git a/nextflow_schema.json b/nextflow_schema.json index e283e8eb..4be35947 100644 --- a/nextflow_schema.json +++ b/nextflow_schema.json @@ -580,6 +580,11 @@ "type": "boolean", "description": "Do you want to run the profile for intron mutations only?", "fa_icon": "fas fa-book" + }, + "profile_smoothing": { + "type": "boolean", + "description": "Do you want to apply smoothing to the mutational profile when having low numbers of mutations?", + "fa_icon": "fas fa-book" } } }, diff --git a/subworkflows/local/mutationprofile/main.nf b/subworkflows/local/mutationprofile/main.nf index 173a1d46..877fabb9 100644 --- a/subworkflows/local/mutationprofile/main.nf +++ b/subworkflows/local/mutationprofile/main.nf @@ -7,6 +7,7 @@ include { COMPUTE_MATRIX as COMPUTEMATRIX } from '../../.. include { COMPUTE_TRINUCLEOTIDE as COMPUTETRINUC } from '../../../modules/local/compute_trinucleotide/main' include { COMPUTE_PROFILE as COMPUTEPROFILE } from '../../../modules/local/compute_profile/main' +include { COMPUTE_PROFILE as COMPUTEPROFILECOHORT } from '../../../modules/local/compute_profile/main' include { CONCAT_PROFILES as CONCATPROFILES } from '../../../modules/local/concatprofiles/main' @@ -40,7 +41,19 @@ workflow MUTATIONAL_PROFILE { .join(COMPUTETRINUC.out.trinucleotides) .set{ matrix_n_trinucleotide } - COMPUTEPROFILE(matrix_n_trinucleotide, wgs_trinuc) + dummy_file = wgs_trinuc.map{ it -> [ [ id: "dummy_file" ], it ] } + + if (params.profile_smoothing) { + matrix_n_trinucleotide + .join( channel.of([ [ id: "all_samples" ] ]) ) + .set{ matrix_n_trinucleotide_all } + + COMPUTEPROFILECOHORT(matrix_n_trinucleotide_all, wgs_trinuc, dummy_file) + cohort_profile = COMPUTEPROFILECOHORT.out.wgs_proportions.first() + } else { + cohort_profile = dummy_file + } + COMPUTEPROFILE(matrix_n_trinucleotide, wgs_trinuc, cohort_profile) sigprofiler_empty = channel.of([]) sigprofiler_empty