Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
976c683
updated config files, added array size parameter for cluster execution
tpall Nov 20, 2025
ef13eed
updated nextflow.config
tpall Nov 20, 2025
d9a3bc6
Swap output assignments for rRNA and tRNA collections
tpall Nov 21, 2025
a087cf4
Merge branch 'dev' of https://github.com/WrightonLabCSU/DRAM into dev
tpall Nov 21, 2025
f5697b2
Merge branch 'dev' of https://github.com/tpall/DRAM into dev
tpall Nov 21, 2025
63ea268
Refactor distill script and configuration for improved clarity and fu…
tpall Nov 25, 2025
7414979
Refactor input and output path definitions for consistency in the SUM…
tpall Nov 26, 2025
a77c29e
Fix conditional check for gene columns in genome summary export to pr…
tpall Nov 26, 2025
e8b0e95
Refactor channel usage for consistency across workflows and improve r…
tpall Nov 26, 2025
4418739
Update SUMMARIZE module to use parameterized fasta column for grouping
tpall Nov 27, 2025
cd3b7ac
Fix closure in QC workflow
tpall Nov 28, 2025
ed054bd
Fix closure in DB_SEARCH workflow
tpall Nov 28, 2025
d39ff14
Updated combine_annotations.py to fix binwise summary. TODO: getting…
tpall Dec 1, 2025
64ab39e
Add QC:COLLECT_RNA to array pattern
tpall Dec 18, 2025
1ee2aea
Merge branch 'dev' of https://github.com/WrightonLabCSU/DRAM into dev
tpall Dec 23, 2025
2c93e2a
Merge branch 'dev' of https://github.com/WrightonLabCSU/DRAM into dev
tpall Dec 23, 2025
702666f
Merge branch 'dev' of https://github.com/tpall/DRAM into dev
tpall Dec 23, 2025
69b23f2
Merge branch 'dev' of https://github.com/WrightonLabCSU/DRAM into dev
tpall Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 6 additions & 21 deletions bin/combine_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
logger = get_logger(filename=Path(__file__).stem)

def read_and_preprocess(path: Path):
    """Load one annotation CSV and tag every row with its source fasta name.

    Intermediate annotation files are named like
    ``input_fasta___some_information_annotation_file.tsv``, so the fasta
    name is recovered from the file name itself (see
    ``input_fasta_from_filepath``).

    Returns an empty DataFrame on any read error so the caller can keep
    combining the remaining files instead of aborting.
    """
    # Fix: the scraped diff left duplicated statements here (the FASTA_COLUMN
    # assignment and the error-path return each appeared twice); keep one of each.
    input_fasta = input_fasta_from_filepath(path)
    try:
        df = pd.read_csv(path)
        df[FASTA_COLUMN] = input_fasta
        return df
    except Exception as e:
        logger.error(f"Error loading DataFrame for input_fasta {input_fasta}: {str(e)}")
        return pd.DataFrame()

def input_fasta_from_filepath(file_path: Path):
    """Recover the input fasta name encoded before the ``___`` separator
    in an intermediate file's stem (the whole stem if no separator)."""
    stem = file_path.stem
    return stem.partition("___")[0]
Expand Down Expand Up @@ -51,7 +50,6 @@ def count_motifs(gene_faa, motif="(C..CH)", genes_faa_dict=None):
for seq in read_sequence(gene_faa, format="fasta"):
if seq.metadata["id"] not in genes_faa_dict:
genes_faa_dict[seq.metadata["id"]] = {}

genes_faa_dict[seq.metadata["id"]]["heme_regulatory_motif_count"] = len(list(seq.find_with_regex(motif)))
return genes_faa_dict

Expand All @@ -61,7 +59,6 @@ def set_gene_data(gene_faa, genes_faa_dict=None):
for seq in read_sequence(gene_faa, format="fasta"):
if seq.metadata["id"] not in genes_faa_dict:
genes_faa_dict[seq.metadata["id"]] = {}

split_label = seq.metadata["id"].split("_")
gene_position = split_label[-1]
start_position, end_position, strandedness = seq.metadata["description"].split("#")[1:4]
Expand All @@ -84,16 +81,13 @@ def organize_columns(df, special_columns=None):
special_columns = []
base_columns = ['query_id', FASTA_COLUMN, "scaffold", 'gene_number', 'start_position', 'stop_position', 'strandedness', 'rank']
base_columns = [col for col in base_columns if col in df.columns]

kegg_columns = sorted([col for col in df.columns if col.startswith('kegg_')], key=lambda x: (x != 'kegg_id', x))
other_columns = [col for col in df.columns if col not in base_columns + kegg_columns + special_columns]

db_prefixes = set(col.split('_')[0] for col in other_columns)
sorted_other_columns = []
for prefix in db_prefixes:
prefixed_columns = sorted([col for col in other_columns if col.startswith(prefix + '_')], key=lambda x: (x != f"{prefix}_id", x))
sorted_other_columns.extend(prefixed_columns)

final_columns_order = base_columns + kegg_columns + sorted_other_columns + special_columns
return df[final_columns_order]

Expand All @@ -107,29 +101,20 @@ def combine_annotations(annotations_dir, genes_dir, output, threads):
annotations = Path(annotations_dir).glob("*")
genes_faa = Path(genes_dir).glob("*")
with ThreadPoolExecutor(max_workers=threads) as executor:
# futures = [executor.submit(read_and_preprocess, input_fasta, path) for input_fasta, path in input_fastas_and_paths]
futures = [executor.submit(read_and_preprocess, Path(path)) for path in annotations]
data_frames = [future.result() for future in as_completed(futures)]

combined_data = pd.concat(data_frames, ignore_index=True)
combined_data = pd.concat([df for df in data_frames if not df.empty], ignore_index=True)
if genes_faa:
genes_faa_dict = dict()
for gene_path in genes_faa:
gene_path = str(gene_path)
genes_faa_dict
count_motifs(gene_path, "(C..CH)", genes_faa_dict=genes_faa_dict)
set_gene_data(gene_path, genes_faa_dict)
df = pd.DataFrame.from_dict(genes_faa_dict, orient='index')
columns = [col for col in df.columns.tolist() if col != FASTA_COLUMN]
combined_data = combined_data.drop(columns=columns, errors='ignore')
df.index.name = 'query_id'
df = df.rename(columns={FASTA_COLUMN: FASTA_COLUMN+"2"})

# we use outer to get any genes that don't have hits
combined_data = pd.merge(combined_data, df, how="outer", on="query_id")
combined_data[FASTA_COLUMN] = combined_data[FASTA_COLUMN].fillna("")
mask = combined_data[FASTA_COLUMN] != ""
combined_data[FASTA_COLUMN] = combined_data[FASTA_COLUMN].where(mask, other=combined_data[FASTA_COLUMN+"2"])
df = pd.DataFrame.from_dict(genes_faa_dict, orient='index').reset_index().rename(columns={'index': 'query_id'})
combined_data = combined_data.drop(columns=df.columns.difference(["query_id", "scaffold", FASTA_COLUMN]), errors='ignore')
combined_data = pd.merge(combined_data, df, how="outer", on=["query_id", FASTA_COLUMN])

combined_data = convert_bit_scores_to_numeric(combined_data)

Expand Down
205 changes: 147 additions & 58 deletions bin/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
DISTILATE_SORT_ORDER_COLUMNS = [COL_HEADER, COL_SUBHEADER, COL_MODULE, COL_GENE_ID]
EXCEL_MAX_CELL_SIZE = 32767

FASTA_COLUMN = os.getenv('FASTA_COLUMN', 'input_fasta')
DISTILL_DIR = Path(__file__).parent / "assets/forms/distill_sheets"


Expand All @@ -34,21 +33,107 @@ def check_columns(data, logger):
missing = [i for i in ID_EXPR_DICT if i not in data.columns]
logger.info("Note: the following id fields "
f"were not in the annotations file and are not being used: {missing},"
f" but these are {functions}")
f" but these are {list(functions.keys())}")

def make_genome_summary(annotations, genome_summary_frame: pl.LazyFrame, logger, groupby_column=FASTA_COLUMN):
rules_col = "rules"
if rules_col not in genome_summary_frame.collect_schema().names():
genome_summary_frame = genome_summary_frame.with_columns(
pl.lit(None).cast(pl.String).alias(rules_col)
)

genome_summary_frame = genome_summary_frame.with_columns(
pl.when(pl.col(rules_col).is_not_null())
.then(pl.col(rules_col))
.otherwise(pl.col("gene_id"))
.alias(rules_col)
)
def get_ids_from_annotations_by_row(data):
    """Return a Series mapping each annotation row to the set of database
    identifiers parsed out of every recognized column (per FUNCTION_DICT)."""
    # Only keep the parsers whose column actually exists in this frame.
    parsers = {col: fn for col, fn in FUNCTION_DICT.items() if col in data.columns}

    def _row_ids(row):
        ids = set()
        for col, parse in parsers.items():
            value = row[col]
            if pd.isna(value):
                continue
            ids.update(i for i in parse(str(value)) if not pd.isna(i))
        return ids

    return data.apply(_row_ids, axis=1)


def get_ids_from_annotations_all(data):
    """Count how many rows carry each annotation identifier.

    Parameters
    ----------
    data : pd.DataFrame
        Annotations frame; identifier columns are parsed per FUNCTION_DICT.

    Returns
    -------
    Counter
        identifier -> number of rows in which it appears.
    """
    # Fix: removed the dead `data.apply(list)` statement, whose result was
    # discarded and which did nothing but burn a full pass over the Series.
    per_row_ids = get_ids_from_annotations_by_row(data)
    return Counter(chain(*per_row_ids.values))


def fill_genome_summary_frame(annotations, genome_summary_frame, groupby_column, logger):
    """Add one count column per genome to the distillate summary frame.

    Each row of ``genome_summary_frame`` lists one or more gene IDs in its
    gene_id column (comma separated); for every genome (group of
    ``annotations`` rows sharing ``groupby_column``) this counts how many
    annotation hits match each row's ID set and appends the counts as a
    new column named after the genome.
    """
    # One set of candidate IDs per distillate row, split on commas.
    genome_summary_id_sets = [set([str(k).strip() for k in j.split(',')]) for j in genome_summary_frame[COL_GENE_ID]]
    logger.info(f"Genome summary ID sets: {genome_summary_id_sets}")

    def fill_a_frame(frame: pd.DataFrame):
        # identifier -> hit count for this genome's annotation rows.
        id_dict = get_ids_from_annotations_all(frame)
        logger.info(f"ID dictionary for {frame.name}: {id_dict}")

        counts = list()
        for set_ in genome_summary_id_sets:
            identifier_count = 0
            for gene_id in set_:
                # Try matching with and without '.hmm'
                # (exact match, or a dotted suffix such as "K00001.hmm").
                matching_keys = [key for key in id_dict.keys() if gene_id == key or key.startswith(gene_id + ".")]
                for key in matching_keys:
                    identifier_count += id_dict[key]
            counts.append(identifier_count)
        # logger.info(f"Counts for {frame.name}: {counts}")

        # Indexed like the summary frame so the transpose below aligns rows.
        return pd.Series(counts, index=genome_summary_frame.index)

    # One Series per genome; .T turns genomes into columns before concat.
    counts = annotations.groupby(groupby_column, sort=False)[annotations.columns].apply(fill_a_frame)
    genome_summary_frame = pd.concat([genome_summary_frame, counts.T], axis=1)

    return genome_summary_frame


def fill_genome_summary_frame_gene_names(annotations, genome_summary_frame, groupby_column, logger):
    """Like fill_genome_summary_frame, but each genome column holds the
    comma-joined gene names matching the row's IDs instead of a count."""
    id_sets = [
        set([token.strip() for token in gene_ids.split(',')])
        for gene_ids in genome_summary_frame[COL_GENE_ID]
    ]
    for genome, frame in annotations.groupby(groupby_column, sort=False):
        # Invert per-gene identifier sets into identifier -> [gene, ...].
        genes_by_id = defaultdict(list)
        for gene, identifiers in get_ids_from_annotations_by_row(frame).items():
            for identifier in identifiers:
                genes_by_id[identifier].append(gene)
        # One comma-joined gene-name string per distillate row.
        column = list()
        for id_set in id_sets:
            hits = list()
            for identifier in id_set:
                hits += genes_by_id[identifier]
            column.append(','.join(hits))
        genome_summary_frame[genome] = column
    return genome_summary_frame


def summarize_rrnas(rrnas_df, groupby_column="input_fasta"):
    """Build a distillate-style frame counting each rRNA type per genome.

    One row per RRNA_TYPES entry, one trailing column per genome found in
    ``rrnas_df`` (grouped by ``groupby_column``).
    """
    per_genome_counts = dict()
    for genome, frame in rrnas_df.groupby(groupby_column):
        per_genome_counts[genome] = Counter(frame['type'])
    rows = list()
    for rna_type in RRNA_TYPES:
        row = [rna_type, f"{rna_type.split()[0]} ribosomal RNA gene", 'rRNA', 'rRNA', '', '']
        for counts in per_genome_counts.values():
            row.append(counts.get(rna_type, 0))
        rows.append(row)
    return pd.DataFrame(rows, columns=FRAME_COLUMNS + list(per_genome_counts.keys()))


def make_genome_summary(annotations, genome_summary_frame, logger, groupby_column="input_fasta"):
    """Produce the per-genome metabolism summary by filling the distillate
    form with identifier counts from the annotations."""
    # get ko summaries; kept as a concat over a list so further summary
    # sources can be appended here later.
    filled = fill_genome_summary_frame(annotations, genome_summary_frame.copy(), groupby_column, logger)
    return pd.concat([filled], sort=False)


def split_column_str(names, max_cell_size=None):
    """Split a comma-joined string into chunks that fit in an Excel cell.

    Excel caps a cell at EXCEL_MAX_CELL_SIZE (32767) characters, so
    over-long gene-name lists are split on commas into several strings.

    Parameters
    ----------
    names : str
        Comma-separated names.
    max_cell_size : int, optional
        Maximum characters per chunk; defaults to EXCEL_MAX_CELL_SIZE.

    Returns
    -------
    list of str
        Chunks each shorter than ``max_cell_size``, except a single name
        that is itself longer, which is kept whole rather than dropped.
    """
    if max_cell_size is None:
        max_cell_size = EXCEL_MAX_CELL_SIZE
    if len(names) < max_cell_size:
        return [names]
    chunks = ['']
    for name in names.split(','):
        # Only add a joining comma when the chunk already has content; the
        # previous implementation joined onto '' and gave every chunk a
        # spurious leading comma.
        candidate = name if not chunks[-1] else chunks[-1] + ',' + name
        if len(candidate) < max_cell_size or not chunks[-1]:
            chunks[-1] = candidate
        else:
            # Start a new chunk with this name. The previous implementation
            # advanced to a fresh chunk but never appended the overflowing
            # name, silently dropping it from the output.
            chunks.append(name)
    return chunks

df = evaluate_rules_on_anno(
rules=genome_summary_frame,
Expand All @@ -66,12 +151,25 @@ def make_genome_summary(annotations, genome_summary_frame: pl.LazyFrame, logger,

df = genome_summary_frame.collect().join(df, on="gene_id", how="left")

df = df.drop(OPTIONAL_COLUMNS, strict=False)
def write_summarized_genomes_to_xlsx(summarized_genomes, output_file, extra_frames=tuple()):
    """Write the combined distillate to an Excel workbook.

    One sheet is created per distillate topic (the sheet column of
    ``summarized_genomes``), plus one sheet per non-empty frame in
    ``extra_frames`` (e.g. rRNA/tRNA summaries), named after that frame's
    first header value.
    """
    # turn all this into an xlsx
    with pd.ExcelWriter(output_file) as writer:
        for sheet, frame in summarized_genomes.groupby(COL_SHEET, sort=False):
            frame = frame.sort_values(DISTILATE_SORT_ORDER_COLUMNS)
            frame = frame.drop([COL_SHEET], axis=1)
            # Genome columns are whatever remains after the fixed distillate columns.
            gene_columns = list(set(frame.columns) - set(CONSTANT_DISTILLATE_COLUMNS))
            if gene_columns:
                # presumably reshapes over-long cell strings to respect
                # Excel's 32767-char cell cap — confirm against split_names_to_long
                split_genes = pd.concat([split_names_to_long(frame[i].astype(str)) for i in gene_columns], axis=1)
                frame = pd.concat([frame[CONSTANT_DISTILLATE_COLUMNS], split_genes], axis=1)
            frame.to_excel(writer, sheet_name=sheet, index=False)
        for extra_frame in extra_frames:
            # Skip placeholder/empty frames so no blank sheets are written.
            if extra_frame is not None and not extra_frame.empty:
                extra_frame.to_excel(writer, sheet_name=extra_frame[COL_HEADER].iloc[0], index=False)

return df

# TODO: add assembly stats like N50, longest contig, total assembled length etc
def make_genome_stats(annotations: pl.DataFrame, rrna_frame: pl.DataFrame = None, trna_frame: pl.DataFrame = None, quast_frame: pl.DataFrame = None, groupby_column: str = FASTA_COLUMN):
def make_genome_stats(annotations, rrna_frame=None, trna_frame=None, quast_frame=None, groupby_column="input_fasta"):
rows = list()
columns = ['genome']
if 'scaffold' in annotations.columns:
Expand Down Expand Up @@ -99,23 +197,15 @@ def make_genome_stats(annotations: pl.DataFrame, rrna_frame: pl.DataFrame = None
meta_cols = RRNA_COLUMNS
sample_cols = [c for c in rrna_frame.columns if c not in meta_cols]

# group_by gene_id, sum sample columns
df_rrna = (
rrna_frame
.group_by("gene_id")
.agg([pl.col(c).sum().alias(c) for c in sample_cols])
)
df_rrna = rrna_frame.groupby("gene_id")[sample_cols].sum()

# transpose: rows -> genomes (samples), columns -> gene_id
# This creates a "genome" column from original sample column names
df_rrna = df_rrna.transpose(
include_header=True,
header_name="genome", # new first column name
column_names="gene_id", # column headers come from gene_id values
)
# Transpose so samples become rows and genes become columns
df_rrna = df_rrna.T.reset_index()

genome_stats = genome_stats.join(df_rrna, on="genome", how="inner")
assert genome_stats.shape[0] == df_rrna.shape[0], "genomes from annotation file don't map to rrna file"
# Rename the index column to input_fasta (or whatever you want)
df_rrna = df_rrna.rename(columns={"index": "genome"})
df_rrna.columns.name = None
genome_stats = pd.merge(genome_stats, df_rrna, how="outer", on="genome")
if trna_frame is not None:
meta_cols = TRNA_COLUMNS

Expand Down Expand Up @@ -147,16 +237,19 @@ def make_genome_stats(annotations: pl.DataFrame, rrna_frame: pl.DataFrame = None
@click.command()
@click.option("-i", "--input_file", required=True, help="Annotations path")
# @click.option("-o", "--output_dir", required=True, help="Directory to write summarized genomes")
@click.option("--rrna_path", help="rRNA output from annotation")
@click.option("--trna_path", help="tRNA output from annotation")
@click.option("--quast_path", help="Quast summary TSV from the quast step")
@click.option("--rrna_path", help="rRNA output from annotation", default=None, type=click.Path(exists=True))
@click.option("--trna_path", help="tRNA output from annotation", default=None, type=click.Path(exists=True))
@click.option("--quast_path", help="Quast summary TSV from the quast step", default=None, type=click.Path(exists=True))
@click.option("--groupby_column", help="Column from annotations to group as organism units",
default=FASTA_COLUMN)
default="input_fasta", type = click.STRING)
@click.option("--distil_topics", default="default", help="Default distillates topics to run.")
@click.option("--distil_ecosystem", default="eng_sys,ag", help="Default distillates ecosystems to run.")
@click.option("--custom_distillate", default=[], callback=validate_comma_separated, help="Custom distillate forms to add your own modules, comma separated. ")
def distill(input_file, rrna_path=None, trna_path=None, quast_path=None, groupby_column=FASTA_COLUMN, distil_topics=None, distil_ecosystem=None,
custom_distillate=None):
@click.option("--custom_distillate", default="", callback=validate_comma_separated, help="Custom distillate forms to add your own modules, comma separated. ")
@click.option("--distillate_gene_names", is_flag=True,
show_default=True, default=False,
help="Give names of genes instead of counts in genome metabolism summary")
def distill(input_file, rrna_path, trna_path, quast_path, groupby_column, distil_topics, distil_ecosystem,
custom_distillate, distillate_gene_names):
"""Summarize metabolic content of annotated genomes"""

# read in data
Expand All @@ -170,24 +263,20 @@ def distill(input_file, rrna_path=None, trna_path=None, quast_path=None, groupby
# Check the columns are present
check_columns(annotations, logger)

if trna_path is None:
trna_frame = None
else:
trna_frame = pl.read_csv(trna_path, separator='\t')
if rrna_path is None:
rrna_frame = None
else:
rrna_frame = pl.read_csv(rrna_path, separator='\t')
# Check NF DRAM didn't pass an empty sheet to signal no tRNAs or rRNAs
if rrna_frame.is_empty():
rrna_frame = None
if trna_frame.is_empty():
trna_frame = None

if quast_path is None:
quast_frame = None
else:
quast_frame = pl.read_csv(quast_path, separator='\t')
trna_frame = None
rrna_frame = None
if all([v is not None for v in [trna_path, rrna_path]]):
trna_frame = pd.read_csv(trna_path, sep='\t')
rrna_frame = pd.read_csv(rrna_path, sep='\t')
if any(v.dropna(how="all").empty for v in [trna_frame, rrna_frame]):
trna_frame = None
rrna_frame = None

quast_frame = None
if quast_path is not None:
quast_frame = pd.read_csv(quast_path, sep='\t')
if quast_frame.dropna(how="all").empty:
quast_frame = None

distil_sheets_names = []
if "default" in distil_topics:
Expand Down Expand Up @@ -240,8 +329,8 @@ def distill(input_file, rrna_path=None, trna_path=None, quast_path=None, groupby
logger.info('Retrieved distillate genome summary form')

# make genome stats
genome_stats = make_genome_stats(annotations, rrna_frame, trna_frame, quast_frame=quast_frame, groupby_column=groupby_column)
genome_stats.write_csv('genome_stats.tsv', separator='\t')
genome_stats = make_genome_stats(annotations, rrna_frame, trna_frame, quast_frame, groupby_column=groupby_column)
genome_stats.to_csv('genome_stats.tsv', sep='\t', index=None)
logger.info('Calculated genome statistics')

# make genome metabolism summary
Expand Down
6 changes: 6 additions & 0 deletions conf/base.config
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,14 @@ process {
withLabel:error_ignore {
errorStrategy = 'ignore'
}

withLabel:error_retry {
errorStrategy = 'retry'
maxRetries = 2
}

withName: 'DRAM:ANNOTATE:CALL:.*|DRAM:ANNOTATE:DB_SEARCH:.*|DRAM:ANNOTATE:QC:COLLECT_RNA:.*' {
array = params.array_size
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to support people running DRAM2 with local executor (such as on their own computer if they want), which doesn't support array. So the array should only be used with executors that support it.

}

}
1 change: 1 addition & 0 deletions conf/modules.config
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ process {
]
}
withName: SUMMARIZE {
ext.args = { "--groupby_column ${params.CONSTANTS.FASTA_COLUMN}" }
publishDir = [
[
path: "${params.outdir}/SUMMARIZE",
Expand Down
Loading