diff --git a/countNUM.py b/countNUM.py new file mode 100644 index 000000000..58f8477d4 --- /dev/null +++ b/countNUM.py @@ -0,0 +1,123 @@ +#Counting the unique target ID from result2.m8 & reclass_result.txt +from collections import defaultdict +import pdb + +PATH1 = '/home/yakim/benchmark/result_mmseq.m8' # mmseq +PATH2 = '/home/yakim/benchmark/reclass_bit.m8' # after reclassification - bit score +PATH3 = '/home/yakim/benchmark/reclass_seqid.m8' # after reclassification - seq id +PATH4 = '/home/yakim/benchmark/reclass0922.m8' # after reclassification - C++ (seqid) + + +# MMseq2 result count +qandt = [] +hit_num = 0 +redund1 = [] +with open(PATH1, 'r') as f: + for line in f: + a = line.split() + qandt.append([a[0],a[1]]) + +for q, t in qandt: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund1: + redund1.append(qnum) + if query == t: + hit_num += 1 +print('Hit # of mmseqs result:', hit_num, '&& Total #:', len(redund1)) + + +# Mmseq2 result count (count also for every top 10) +''' +hit_num1 = 0 +grouped = defaultdict(list) + +with open(PATH1, 'r') as f: + for line in f: + a = line.split() + query_full = a[0] # query705_from_UniRef90_A0AA46S9G1 + target = a[1] # A0AA46S9G1 + seq_identity = round(float(a[2]), 3) + query_num = query_full.split('_')[0] # query705 + query_last_id = query_full.split('_')[-1] # A0AA46S9G1 + + # query 번호별로 그룹핑 + grouped[query_num].append({ + 'query_last_id': query_last_id, + 'target': target, + 'seq_identity': seq_identity + }) # row[0] row[1] row[2] + #{query1: (A0AA46S9G1, A0AA46S9R1, 0.9), (A0AA46S9G1, A0AA46S9G1, 0.9), (A0AA46S9G1, A0AA34S9T1, 0.85)} + +for query_num, rows in grouped.items(): + if not rows: + continue + top_seq_identity = rows[0]['seq_identity'] # 0.9 + top_query_last_id = rows[0]['query_last_id'] # A0AA46S9G1 + + count = sum( + (row['query_last_id'] == row['target']) and + (row['seq_identity'] == top_seq_identity) + for row in rows + ) + #print(f'{query_num}: {count}') # 각 그룹별로 개수 출력 + hit_num1 += count + +print('Hit # of mmseqs result(top 10):', hit_num1, '&& Total #:', len(redund)) +''' + +# After reclassification result (bit score) +qandt2 = [] +hit_num2 = 0 +redund2 = [] +with open(PATH2, 'r') as f: + for line in f: + a = line.split() + qandt2.append([a[0],a[1]]) + +for q, t in qandt2: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund2: + redund2.append(qnum) + if query == t: + hit_num2 += 1 +print('Hit # of reclassification result(bit score):', hit_num2, '&& Total #:', len(redund2)) + + +# After reclassification result (seq id) +qandt3 = [] +hit_num3 = 0 +redund3 = [] +with open(PATH3, 'r') as f: + for line in f: + a = line.split() + qandt3.append([a[0],a[1]]) + +for q, t in qandt3: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund3: + redund3.append(qnum) + if query == t: + hit_num3 += 1 +print('Hit # of reclassification result(seqid):', hit_num3, '&& Total #:', len(redund3)) + +# after reclass result count (C++) +qandt4 = [] +hit_num4 = 0 +redund4 = [] +with open(PATH4, 'r') as f: + for line in f: + a = line.split() + qandt4.append([a[0],a[1]]) + +for q, t in qandt4: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund4: + redund4.append(qnum) + if query == t: + hit_num4 += 1 +print('Hit # of reclassification result(C++):', hit_num4, '&& Total #:', len(redund4)) + diff --git a/src/CommandDeclarations.h b/src/CommandDeclarations.h index 42a1e3d3b..1edbfd451 100644 --- a/src/CommandDeclarations.h +++ b/src/CommandDeclarations.h @@ -101,6 +101,8 @@ extern int gappedprefilter(int argc, const char **argv, const Command& command); extern int unpackdb(int argc, const char **argv, const Command& command); extern int rbh(int argc, const char **argv, const Command& command); extern int recoverlongestorf(int argc, const char **argv, const Command& command); +extern int emreclassify(int argc, const char **argv, const Command& command); +extern int emabundance(int argc, const char **argv, const Command& command); extern int result2flat(int argc, const char **argv, const Command& command); extern int result2msa(int argc, const char **argv, const Command& command); extern int result2dnamsa(int argc, const char **argv, const Command& command); diff --git a/src/MMseqsBase.cpp b/src/MMseqsBase.cpp index 94c876133..3a00fa00d 100644 --- a/src/MMseqsBase.cpp +++ b/src/MMseqsBase.cpp @@ -1072,6 +1072,26 @@ std::vector baseCommands = { " ", CITATION_MMSEQS2, {{"resultDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::resultDb }, {"resultDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::resultDb }}}, + {"reclassify", emreclassify, &par.reclassify, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, + "Reclassify alignments and write a new alignment DB with posterior probabilities(convertalis will convert it into column 3(instead of seqId))", + "mmseqs reclassify queryDB targetDB alignmentDB newDB\n", + "Yaeji Kim", + " ", + CITATION_MMSEQS2, {{"queryDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, + {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, + {"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, + {"newDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }}}, + {"abundance", emabundance, &par.abundance, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, + "Summarize abundance from reclassified alignment DB", + "mmseqs abundance queryDB targetDB newDB abundance.tsv\n" + "# targetDB_mapping and targetDB_taxonomy are only required with --taxonomy 1\n" + "mmseqs abundance queryDB targetDB newDB abundance.report --taxonomy 1\n", + "Yaeji Kim", + " ", + CITATION_MMSEQS2 | CITATION_TAXONOMY, {{"queryDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, + {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::taxSequenceDb }, + {"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, + {"abundanceFile", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::flatfile }}}, {"summarizealis", summarizealis, &par.threadsandcompression, COMMAND_RESULT, "Summarize alignment result to one row (uniq. cov., cov., avg. seq. id.)", NULL, diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index a3669d99a..49c4fa162 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -308,6 +308,14 @@ Parameters::Parameters(): // unpackdb PARAM_UNPACK_SUFFIX(PARAM_UNPACK_SUFFIX_ID, "--unpack-suffix", "Unpack suffix", "File suffix for unpacked files.\nAdd .gz suffix to write compressed files.", typeid(std::string), (void *) &unpackSuffix, "^.*$"), PARAM_UNPACK_NAME_MODE(PARAM_UNPACK_NAME_MODE_ID, "--unpack-name-mode", "Unpack name mode", "Name unpacked files by 0: DB key, 1: accession (through .lookup)", typeid(int), (void *) &unpackNameMode, "^[0-1]{1}$"), + // reclassify + PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification bit score term", typeid(double), (void *) &reclassifyLambda, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_ALPHA(PARAM_RECLASSIFY_ALPHA_ID, "--alpha", "Reclassify alpha", "Exponent applied to abundance during reclassification", typeid(double), (void *) &reclassifyAlpha, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--gamma", "Reclassify gamma", "Exponent applied to coverage confidence during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), + PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Abundance taxonomy output", "0: write alignment and protein abundance only; taxonomy files are not required. 1: also write taxonomy_abundance.tsv and taxonomic columns in protein_abundance.tsv; requires targetDB_mapping and targetDB_taxonomy", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), + PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE_ID, "--drop-percentage", "Max drop percentage", "Maximum percentage of cumulative low-tail abundance mass that the automatic jump-based filter may drop (range 0.0-100.0, default 10.0)", typeid(double), (void *) &reclassifyMaxDropPercentage, "^100(\\.0+)?$|^([0-9]|[1-9][0-9])(\\.[0-9]+)?$"), // for modules that should handle -h themselves PARAM_HELP(PARAM_HELP_ID, "-h", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN), PARAM_HELP_LONG(PARAM_HELP_LONG_ID, "--help", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN) @@ -335,6 +343,29 @@ Parameters::Parameters(): threadsandcompression.push_back(&PARAM_COMPRESSED); threadsandcompression.push_back(&PARAM_V); + // reclassify + reclassify.push_back(&PARAM_RECLASSIFY_LAMBDA); + reclassify.push_back(&PARAM_RECLASSIFY_ALPHA); + reclassify.push_back(&PARAM_RECLASSIFY_GAMMA); + reclassify.push_back(&PARAM_RECLASSIFY_MAX_ITER); + reclassify.push_back(&PARAM_RECLASSIFY_TOL); + reclassify.push_back(&PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE); + reclassify.push_back(&PARAM_THREADS); + reclassify.push_back(&PARAM_COMPRESSED); + reclassify.push_back(&PARAM_V); + + // abundance + abundance.push_back(&PARAM_RECLASSIFY_LAMBDA); + abundance.push_back(&PARAM_RECLASSIFY_ALPHA); + abundance.push_back(&PARAM_RECLASSIFY_GAMMA); + abundance.push_back(&PARAM_RECLASSIFY_MAX_ITER); + abundance.push_back(&PARAM_RECLASSIFY_TOL); + abundance.push_back(&PARAM_RECLASSIFY_TAXONOMY); + abundance.push_back(&PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE); + abundance.push_back(&PARAM_THREADS); + abundance.push_back(&PARAM_COMPRESSED); + abundance.push_back(&PARAM_V); + // createclusearchdb createclusearchdb.push_back(&PARAM_THREADS); createclusearchdb.push_back(&PARAM_COMPRESSED); @@ -2637,6 +2668,15 @@ void Parameters::setDefaults() { unpackSuffix = ""; unpackNameMode = Parameters::UNPACK_NAME_ACCESSION; + // reclassify + reclassifyLambda = 1.0; + reclassifyAlpha = 1.0; + reclassifyGamma = 1.0; + reclassifyMaxIterations = 100; + reclassifyTolerance = 1e-5; + reclassifyTaxonomy = 0; + reclassifyMaxDropPercentage = 10.0; + lcaRanks = ""; showTaxLineage = 0; // bin for all unclassified sequences diff --git a/src/commons/Parameters.h b/src/commons/Parameters.h index caa0e4735..c882b8bac 100644 --- a/src/commons/Parameters.h +++ b/src/commons/Parameters.h @@ -724,6 +724,15 @@ class Parameters { std::string unpackSuffix; int unpackNameMode; + // reclassify + double reclassifyLambda; + double reclassifyAlpha; + double reclassifyGamma; + int reclassifyMaxIterations; + double reclassifyTolerance; + int reclassifyTaxonomy; + double reclassifyMaxDropPercentage; + // for modules that should handle -h themselves bool help; @@ -1081,6 +1090,15 @@ class Parameters { PARAMETER(PARAM_UNPACK_SUFFIX) PARAMETER(PARAM_UNPACK_NAME_MODE) + // reclassify + PARAMETER(PARAM_RECLASSIFY_LAMBDA) + PARAMETER(PARAM_RECLASSIFY_ALPHA) + PARAMETER(PARAM_RECLASSIFY_GAMMA) + PARAMETER(PARAM_RECLASSIFY_MAX_ITER) + PARAMETER(PARAM_RECLASSIFY_TOL) + PARAMETER(PARAM_RECLASSIFY_TAXONOMY) + PARAMETER(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE) + // for modules that should handle -h themselves PARAMETER(PARAM_HELP) PARAMETER(PARAM_HELP_LONG) @@ -1207,6 +1225,8 @@ class Parameters { std::vector touchdb; std::vector gpuserver; std::vector tsv2exprofiledb; + std::vector reclassify; + std::vector abundance; std::vector combineList(const std::vector &par1, const std::vector &par2); diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 4c7e19137..656c04941 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -48,6 +48,8 @@ set(util_source_files util/profile2pssm.cpp util/profile2neff.cpp util/profile2seq.cpp + util/EM_reclassify.cpp + util/EM_abundnace.cpp util/recoverlongestorf.cpp util/result2dnamsa.cpp util/result2flat.cpp diff --git a/src/util/EM_abundnace.cpp b/src/util/EM_abundnace.cpp new file mode 100644 index 000000000..1294743ee --- /dev/null +++ b/src/util/EM_abundnace.cpp @@ -0,0 +1,738 @@ +#include "Parameters.h" +#include "DBReader.h" +#include "Debug.h" +#include "Util.h" +#include "Matcher.h" +#include "FastSort.h" +#include "FileUtil.h" +#include "NcbiTaxonomy.h" +#include "MappingReader.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef OPENMP +#include +#endif + +namespace { +struct ReclassTaxEntry { + Matcher::result_t result; + double abundance; + double posterior; + double coverageConfidence; +}; + +typedef std::unordered_map > MappingTable; + +struct Interval { + int start; + int end; +}; + +struct TargetStats { + unsigned int key; + unsigned int targetLength; + double abundance; + double coverageConfidence; + bool dropped; + std::vector intervals; +}; + +struct ReclassTaxContext { + MappingTable mappingTable; + std::vector queryOrder; + std::unordered_set targetSet; + size_t queryCount; + bool hasBacktrace; + bool hasOrfPosition; + + ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +}; + +static const double EPS = 1e-12; +static const size_t MIN_FILTER_TARGETS = 20; +static const size_t MIN_TAIL_TARGETS = 2; + +static double clamp01(double value); + +static std::vector targetListFromSet(const std::unordered_set &targets) { + std::vector out(targets.begin(), targets.end()); + SORT_SERIAL(out.begin(), out.end()); + return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + if (records.empty()) { + ctx.queryOrder.push_back(queryKey); + } + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + + // Always use the 3rd alignment column (seqId) as posterior. + // Do not consume optional trailing columns as posterior. + double posterior = static_cast(result.seqId); + + records.push_back(ReclassTaxEntry{result, 0.0, posterior, 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +struct TargetHitRef { + const ReclassTaxEntry *entry; + double score; + double weight; +}; + +static void initCoverageConfidence(MappingTable &mappingTable, + const std::unordered_set &targetSet, + int threads) { + (void)threads; + std::unordered_map targetMin; + std::unordered_map targetMax; + std::unordered_map > hitsByTarget; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + hitsByTarget.emplace(*it, std::vector()); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += static_cast(it->second[j].result.score); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + const double score = static_cast(it->second[j].result.score); + const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); + } + } + + std::unordered_map coverageFraction; + coverageFraction.reserve(targetSet.size()); + const std::vector targetList = targetListFromSet(targetSet); + std::vector coverageFractionByIndex(targetList.size(), 0.0); + +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) + for (size_t i = 0; i < targetList.size(); ++i) { + const unsigned int target = targetList[i]; + const int startPos = targetMin[target]; + const int endPos = targetMax[target]; + const int len = (endPos >= startPos) ? (endPos - startPos + 1) : 1; + std::vector cov(static_cast(len), 0.0); + std::vector covConf(static_cast(len), 0.0); + + std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); + if (hitIt != hitsByTarget.end()) { + const std::vector &hits = hitIt->second; + for (size_t h = 0; h < hits.size(); ++h) { + const Matcher::result_t &result = hits[h].entry->result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = hits[h].score / static_cast(targetLen); + const int start = std::max(0, result.dbStartPos - startPos); + const int end = std::min(len - 1, result.dbEndPos - startPos); + for (int pos = start; pos <= end; ++pos) { + cov[static_cast(pos)] += mq; + covConf[static_cast(pos)] += hits[h].weight; + } + } + } + + double covered = 0.0; + double squaredCovered = 0.0; + for (size_t pos = 0; pos < covConf.size(); ++pos) { + const double clipped = std::min(1.0, covConf[pos]); + covered += clipped; + squaredCovered += clipped * clipped; + } + const double fraction = covered / static_cast(len); + const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; + const double concentrationPenalty = 1.0 - hhi; + coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); + } + + for (size_t i = 0; i < targetList.size(); ++i) { + coverageFraction[targetList[i]] = coverageFractionByIndex[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + std::unordered_map::const_iterator cf = coverageFraction.find(target); + it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? cf->second : 0.0; + } + } +} + +static void computeAbundanceFromPosterior(MappingTable &mappingTable, + const std::unordered_set &targetSet, + size_t queryCount) { + std::unordered_map abundance; + abundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + abundance[*it] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] += it->second[j].posterior; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = abundance.begin(); it != abundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static void addInterval(std::vector &intervals, int start, int end) { + Interval interval; + interval.start = std::min(start, end); + interval.end = std::max(start, end); + intervals.push_back(interval); +} + +static std::vector mergeIntervals(std::vector intervals) { + if (intervals.empty()) { + return intervals; + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.end < rhs.end; + }); + + std::vector merged; + merged.push_back(intervals[0]); + for (size_t i = 1; i < intervals.size(); ++i) { + if (intervals[i].start <= merged.back().end + 1) { + merged.back().end = std::max(merged.back().end, intervals[i].end); + } else { + merged.push_back(intervals[i]); + } + } + return merged; +} + +static unsigned int intervalCoverage(const std::vector &intervals) { + unsigned int covered = 0; + for (size_t i = 0; i < intervals.size(); ++i) { + covered += static_cast(intervals[i].end - intervals[i].start + 1); + } + return covered; +} + +static std::vector collectTargetStats(const ReclassTaxContext &ctx) { + std::unordered_map statsByTarget; + + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const ReclassTaxEntry &entry = it->second[j]; + TargetStats &stats = statsByTarget[entry.result.dbKey]; + stats.key = entry.result.dbKey; + stats.targetLength = entry.result.dbLen; + stats.abundance = entry.abundance; + stats.coverageConfidence = entry.coverageConfidence; + stats.dropped = false; + addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); + } + } + + std::vector out; + out.reserve(statsByTarget.size()); + for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { + it->second.intervals = mergeIntervals(it->second.intervals); + out.push_back(it->second); + } + + std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.key < rhs.key; + }); + return out; +} + +static double clamp01(double value) { + return std::max(0.0, std::min(1.0, value)); +} + +static bool largestJumpCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double bestGap = 0.0; + size_t bestIdx = 0; + double lowTailMass = 0.0; + for (size_t i = 0; i + 1 < values.size(); ++i) { + const size_t lowTailCount = i + 1; + const size_t highTailCount = values.size() - lowTailCount; + const size_t candidateTailCount = useLowTail ? lowTailCount : highTailCount; + lowTailMass += values[i]; + const double highTailMass = totalMass - lowTailMass; + const double candidateTailMass = useLowTail ? lowTailMass : highTailMass; + if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailMass > (maxTailMass + EPS)) { + continue; + } + + const double gap = values[i + 1] - values[i]; + if (gap > bestGap) { + bestGap = gap; + bestIdx = i; + tailCount = candidateTailCount; + } + } + + if (bestGap <= EPS) { + return false; + } + + cutoff = 0.5 * (values[bestIdx] + values[bestIdx + 1]); + return true; +} + +static bool tailQuantileCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double accumulatedMass = 0.0; + size_t maxTailCount = 0; + for (size_t i = 0; i < values.size(); ++i) { + const double candidate = accumulatedMass + values[i]; + if (candidate > (maxTailMass + EPS)) { + break; + } + accumulatedMass = candidate; + ++maxTailCount; + } + if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { + return false; + } + + tailCount = maxTailCount; + if (useLowTail) { + cutoff = values[tailCount - 1]; + } else { + cutoff = values[values.size() - tailCount]; + } + return true; +} + +static std::unordered_set selectTailTargets(const std::vector &stats, + bool useLowTail, + size_t tailCount, + double maxTailFraction) { + std::vector ordered; + ordered.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + ordered.push_back(&stats[i]); + } + + std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { + const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; + const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; + if (lhsValue != rhsValue) { + return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); + } + return lhs->key < rhs->key; + }); + + double totalMass = 0.0; + for (size_t i = 0; i < ordered.size(); ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + totalMass += value; + } + const double maxTailMass = clamp01(maxTailFraction) * totalMass; + + std::unordered_set selected; + const size_t limit = std::min(tailCount, ordered.size()); + double selectedMass = 0.0; + selected.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { + break; + } + selectedMass += value; + selected.insert(ordered[i]->key); + } + return selected; +} + +static std::unordered_set selectDroppedTargets(const std::vector &stats, + double maxDropPercentage, + double &abundanceCutoff) { + std::unordered_set dropped; + if (stats.empty()) { + abundanceCutoff = 0.0; + return dropped; + } + if (stats.size() < MIN_FILTER_TARGETS) { + abundanceCutoff = 0.0; + return dropped; + } + + std::vector abundances; + abundances.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + abundances.push_back(stats[i].abundance); + } + + const double maxTailFraction = clamp01(maxDropPercentage / 100.0); + size_t abundanceTailCount = 0; + bool hasAbundanceCutoff = largestJumpCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + if (hasAbundanceCutoff == false) { + hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + } + if (hasAbundanceCutoff == false) { + abundanceCutoff = 0.0; + return dropped; + } + + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); + for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { + dropped.insert(*it); + } + if (dropped.size() == stats.size()) { + dropped.clear(); + } + return dropped; +} + +static void markDroppedTargets(std::vector &stats, const std::unordered_set &dropped) { + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].dropped = (dropped.find(stats[i].key) != dropped.end()); + } +} + +static void convertAbundanceToPercent(std::vector &stats) { + double total = 0.0; + for (size_t i = 0; i < stats.size(); ++i) { + total += stats[i].abundance; + } + + if (total <= 0.0) { + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].abundance = 0.0; + } + return; + } + + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].abundance = 100.0 * (stats[i].abundance / total); + } +} + +static const char *headerForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { + size_t id = headerReader.getId(key); + if (id == UINT_MAX) { + return NULL; + } + return headerReader.getData(id, threadIdx); +} + +static std::string identifierForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { + const char *header = headerForKey(headerReader, key, threadIdx); + if (header == NULL) { + return SSTR(key); + } + std::string parsed = Util::parseFastaHeader(header); + return parsed.empty() ? SSTR(key) : parsed; +} + +static void writeProteinStats(const std::vector &stats, + DBReader &targetHeaderReader, + const std::string &path) { + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + fputs("target_key\ttarget_id\tabundance_pct\tcoverage_confidence\tDrop(y/n)\tmapped_length\ttarget_length\n", handle); + + for (size_t i = 0; i < stats.size(); ++i) { + const unsigned int key = stats[i].key; + const std::string targetId = identifierForKey(targetHeaderReader, key, 0); + const unsigned int mappedLength = intervalCoverage(stats[i].intervals); + + fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\n", + key, + targetId.c_str(), + stats[i].abundance, + stats[i].coverageConfidence, + stats[i].dropped ? "y" : "n", + mappedLength, + stats[i].targetLength); + } + + fclose(handle); +} + +static void writeKrakenReport(const std::vector &stats, + MappingReader &mapping, + NcbiTaxonomy *taxonomy, + size_t queryCount, + const std::string &path) { + std::unordered_map directCounts; + directCounts.reserve(stats.size()); + + for (size_t i = 0; i < stats.size(); ++i) { + const TaxID taxId = mapping.lookup(stats[i].key); + if (taxId == 0) { + continue; + } + const double expectedReads = stats[i].abundance * static_cast(queryCount) / 100.0; + directCounts[taxId] += static_cast(std::floor(expectedReads + 0.5)); + } + + const std::unordered_map > parentToChildren = taxonomy->getParentToChildren(); + const std::unordered_map cladeCounts = taxonomy->getCladeCounts(directCounts, parentToChildren); + + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + const double totalReads = (queryCount > 0) ? static_cast(queryCount) : 1.0; + + std::vector stack; + std::vector depthStack; + stack.push_back(1); + depthStack.push_back(0); + + while (!stack.empty()) { + TaxID taxId = stack.back(); + stack.pop_back(); + int depth = depthStack.back(); + depthStack.pop_back(); + + unsigned int cladeCount = 0; + unsigned int directCount = 0; + std::unordered_map::const_iterator it = cladeCounts.find(taxId); + if (it != cladeCounts.end()) { + cladeCount = it->second.cladeCount; + directCount = it->second.taxCount; + } + + if (cladeCount > 0) { + const TaxonNode *node = taxonomy->taxonNode(taxId, false); + const char *rankStr = (node != NULL) ? taxonomy->getString(node->rankIdx) : NULL; + char rankCode = '-'; + if (rankStr != NULL) { + std::map::const_iterator rankIt = NcbiShortRanks.find(std::string(rankStr)); + if (rankIt != NcbiShortRanks.end()) { + rankCode = rankIt->second; + } + } + const char *name = (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified"; + const double pct = 100.0 * static_cast(cladeCount) / totalReads; + + for (int i = 0; i < depth; ++i) { + fputs(" ", handle); + } + fprintf(handle, "%.4f\t%u\t%u\t%c\t%u\t%s\n", + pct, cladeCount, directCount, rankCode, static_cast(taxId), name); + } + + std::unordered_map::const_iterator childIt = cladeCounts.find(taxId); + if (childIt != cladeCounts.end()) { + const std::vector &children = childIt->second.children; + for (size_t i = 0; i < children.size(); ++i) { + stack.push_back(children[i]); + depthStack.push_back(depth + 1); + } + } + } + + fclose(handle); +} + +static void writeBrackenReport(const std::vector &stats, + MappingReader &mapping, + NcbiTaxonomy *taxonomy, + size_t queryCount, + const std::string &path) { + std::unordered_map directCounts; + directCounts.reserve(stats.size()); + + for (size_t i = 0; i < stats.size(); ++i) { + const TaxID taxId = mapping.lookup(stats[i].key); + if (taxId == 0) { + continue; + } + const double expectedReads = stats[i].abundance * static_cast(queryCount) / 100.0; + directCounts[taxId] += static_cast(std::floor(expectedReads + 0.5)); + } + + const std::unordered_map > parentToChildren = taxonomy->getParentToChildren(); + const std::unordered_map cladeCounts = taxonomy->getCladeCounts(directCounts, parentToChildren); + + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + fputs("name\ttaxonomy_id\ttaxonomy_lvl\tkraken_assigned_reads\tadded_reads\tnew_est_reads\tfraction_total_reads\n", handle); + + const double totalReads = (queryCount > 0) ? static_cast(queryCount) : 1.0; + for (std::unordered_map::const_iterator it = cladeCounts.begin(); it != cladeCounts.end(); ++it) { + const TaxID taxId = it->first; + const TaxonCounts &counts = it->second; + if (counts.cladeCount == 0) { + continue; + } + + const TaxonNode *node = taxonomy->taxonNode(taxId, false); + const char *rankStr = (node != NULL) ? taxonomy->getString(node->rankIdx) : "-"; + const char *name = (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified"; + const unsigned int cladeCount = counts.cladeCount; + const unsigned int directCount = counts.taxCount; + const unsigned int addedReads = (cladeCount >= directCount) ? (cladeCount - directCount) : 0; + const double fraction = static_cast(cladeCount) / totalReads; + + fprintf(handle, "%s\t%u\t%s\t%u\t%u\t%u\t%.12g\n", + name, + static_cast(taxId), + rankStr, + directCount, + addedReads, + cladeCount, + fraction); + } + + fclose(handle); +} +} + +int emabundance(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + DBReader targetHeaderReader((par.db2 + "_h").c_str(), (par.db2 + "_h.index").c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + targetHeaderReader.open(DBReader::NOSORT); + + const bool withTaxonomy = (par.reclassifyTaxonomy == 1); + NcbiTaxonomy *taxonomy = NULL; + MappingReader *mapping = NULL; + if (withTaxonomy) { + taxonomy = NcbiTaxonomy::openTaxonomy(par.db2); + mapping = new MappingReader(par.db2); + } + + ReclassTaxContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + initCoverageConfidence(ctx.mappingTable, ctx.targetSet, par.threads); + computeAbundanceFromPosterior(ctx.mappingTable, ctx.targetSet, ctx.queryCount); + + std::vector allTargetStats = collectTargetStats(ctx); + double abundanceCutoff = 0.0; + const std::unordered_set dropped = selectDroppedTargets(allTargetStats, + par.reclassifyMaxDropPercentage, + abundanceCutoff); + markDroppedTargets(allTargetStats, dropped); + convertAbundanceToPercent(allTargetStats); + + if (withTaxonomy) { + std::vector targetStats = allTargetStats; + targetStats.erase(std::remove_if(targetStats.begin(), targetStats.end(), [](const TargetStats &entry) { + return entry.dropped; + }), targetStats.end()); + writeKrakenReport(targetStats, *mapping, taxonomy, ctx.queryCount, par.db4); + writeBrackenReport(targetStats, *mapping, taxonomy, ctx.queryCount, par.db4 + ".bracken"); + } else { + writeProteinStats(allTargetStats, targetHeaderReader, par.db4); + } + + delete mapping; + delete taxonomy; + targetHeaderReader.close(); + reader.close(); + return EXIT_SUCCESS; +} diff --git a/src/util/EM_reclassify.cpp b/src/util/EM_reclassify.cpp new file mode 100644 index 000000000..49294c5f5 --- /dev/null +++ b/src/util/EM_reclassify.cpp @@ -0,0 +1,931 @@ +#include "Parameters.h" +#include "DBReader.h" +#include "DBWriter.h" +#include "Debug.h" +#include "Util.h" +#include "Matcher.h" +#include "FastSort.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef OPENMP +#include +#endif + +namespace { +struct ReclassTaxEntry { + Matcher::result_t result; + double abundance; + double posterior; + double coverageConfidence; +}; + +typedef std::unordered_map > MappingTable; + +struct Interval { + int start; + int end; +}; + +struct TargetStats { + unsigned int key; + unsigned int targetLength; + double abundance; + double coverageConfidence; + bool dropped; + std::vector intervals; +}; + +struct ReclassTaxContext { + MappingTable mappingTable; + std::vector queryOrder; + std::unordered_set targetSet; + size_t queryCount; + bool hasBacktrace; + bool hasOrfPosition; + + ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +}; + +static const double STEP_MIN = -1.0; +static const double STEP_MAX = 1.0; +static const double EPS = 1e-12; +static const double LOG_COMPATIBILITY_MIN = -60.0; +static const double LOG_COMPATIBILITY_MAX = 60.0; +static const double ABUNDANCE_EXP_TAU = 3.0; +static const double ABUNDANCE_SMOOTH_EPS = 1e-8; +static const size_t MIN_FILTER_TARGETS = 20; +static const size_t MIN_TAIL_TARGETS = 2; + +static double clamp01(double value); + +static std::vector targetListFromSet(const std::unordered_set &targets) { + std::vector out(targets.begin(), targets.end()); + SORT_SERIAL(out.begin(), out.end()); + return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + if (records.empty()) { + ctx.queryOrder.push_back(queryKey); + } + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + records.push_back(ReclassTaxEntry{result, 0.0, static_cast(result.seqId), 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { + std::unordered_map initAbundance; + initAbundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + initAbundance[*it] = 0.0; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += std::max(static_cast(it->second[j].result.score), 0.0); + } + if (scoreSum <= 0.0) { + continue; + } + for (size_t j = 0; j < it->second.size(); ++j) { + const double nonNegativeScore = std::max(static_cast(it->second[j].result.score), 0.0); + initAbundance[it->second[j].result.dbKey] += nonNegativeScore / scoreSum; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + it->second /= denom; + } + } + double totalAbundance = 0.0; + for (std::unordered_map::const_iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + totalAbundance += it->second; + } + if (totalAbundance > 0.0) { + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + it->second /= totalAbundance; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; + } + } +} + +struct TargetHitRef { + const ReclassTaxEntry *entry; + double score; + double weight; +}; + +static void initCoverageConfidence(MappingTable &mappingTable, + const std::unordered_set &targetSet, + int threads) { + (void)threads; + std::unordered_map targetMin; + std::unordered_map targetMax; + std::unordered_map > hitsByTarget; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + hitsByTarget.emplace(*it, std::vector()); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += static_cast(it->second[j].result.score); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + const double score = static_cast(it->second[j].result.score); + const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); + } + } + + std::unordered_map coverageFraction; + coverageFraction.reserve(targetSet.size()); + const std::vector targetList = targetListFromSet(targetSet); + std::vector coverageFractionByIndex(targetList.size(), 0.0); + +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) + for (size_t i = 0; i < targetList.size(); ++i) { + const unsigned int target = targetList[i]; + const int startPos = targetMin[target]; + const int endPos = targetMax[target]; + const int len = (endPos >= startPos) ? (endPos - startPos + 1) : 1; + std::vector cov(static_cast(len), 0.0); + std::vector covConf(static_cast(len), 0.0); + + std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); + if (hitIt != hitsByTarget.end()) { + const std::vector &hits = hitIt->second; + for (size_t h = 0; h < hits.size(); ++h) { + const Matcher::result_t &result = hits[h].entry->result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = hits[h].score / static_cast(targetLen); + const int start = std::max(0, result.dbStartPos - startPos); + const int end = std::min(len - 1, result.dbEndPos - startPos); + for (int pos = start; pos <= end; ++pos) { + cov[static_cast(pos)] += mq; + covConf[static_cast(pos)] += hits[h].weight; + } + } + } + + double covered = 0.0; + double squaredCovered = 0.0; + for (size_t pos = 0; pos < covConf.size(); ++pos) { + const double clipped = std::min(1.0, covConf[pos]); + covered += clipped; + squaredCovered += clipped * clipped; + } + const double fraction = covered / static_cast(len); + const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; + const double concentrationPenalty = 1.0 - hhi; + coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); + } + + for (size_t i = 0; i < targetList.size(); ++i) { + coverageFraction[targetList[i]] = coverageFractionByIndex[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + std::unordered_map::const_iterator cf = coverageFraction.find(target); + it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? cf->second : 0.0; + } + } +} + +static double maxQueryBitScore(const std::vector &entries) { + double maxScore = std::numeric_limits::lowest(); + for (size_t i = 0; i < entries.size(); ++i) { + maxScore = std::max(maxScore, static_cast(entries[i].result.score)); + } + return maxScore; +} + +static double hitCoverage(const Matcher::result_t &result) { + return clamp01(std::min(static_cast(result.qcov), static_cast(result.dbcov))); +} + +static double compatibilityLogTerm(const ReclassTaxEntry &entry, + double queryMaxScore, + double betaBit, + double betaSeqId, + double betaCov) { + const double deltaBit = static_cast(entry.result.score) - queryMaxScore; + const double seqId = static_cast(entry.result.seqId); + const double cov = hitCoverage(entry.result); + const double r = (betaBit * deltaBit) + (betaSeqId * seqId) + (betaCov * cov); + return std::max(LOG_COMPATIBILITY_MIN, std::min(LOG_COMPATIBILITY_MAX, r)); +} + +static void computePosterior(MappingTable &mappingTable, + double betaBit, + double betaSeqId, + double betaCov, + double abundanceExponent) { + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); + std::vector numerators(it->second.size(), 0.0); + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const double phi = std::exp(compatibilityLogTerm(it->second[j], queryMaxScore, betaBit, betaSeqId, betaCov)); + const double abundance = std::max(it->second[j].abundance, ABUNDANCE_SMOOTH_EPS); + const double weighted = phi * std::pow(abundance, abundanceExponent); + numerators[j] = weighted; + denom += weighted; + } + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].posterior = (denom > 0.0) ? (numerators[j] / denom) : 0.0; + } + } +} + +static double logLikelihood(const MappingTable &mappingTable, + double betaBit, + double betaSeqId, + double betaCov, + double abundanceExponent, + size_t queryCount) { + if (queryCount == 0) { + return 0.0; + } + + double ll = 0.0; + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); + double mixture = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const double phi = std::exp(compatibilityLogTerm(it->second[j], queryMaxScore, betaBit, betaSeqId, betaCov)); + const double abundance = std::max(it->second[j].abundance, ABUNDANCE_SMOOTH_EPS); + mixture += phi * std::pow(abundance, abundanceExponent); + } + ll += std::log(mixture > 0.0 ? mixture : 1e-300); + } + return ll / static_cast(queryCount); +} + +static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] = it->second[j].abundance; + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(abundance[targetList[i]]); + } + return out; +} + +static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = abundanceVector[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static std::vector emUpdate(MappingTable &mappingTable, + double betaBit, + double betaSeqId, + double betaCov, + const std::unordered_map &fixedCoverageConfidence, + const std::vector &targetList, + size_t queryCount, + double abundanceExponent, + double coveragePriorWeight) { + computePosterior(mappingTable, betaBit, betaSeqId, betaCov, abundanceExponent); + + std::unordered_map nextAbundance; + nextAbundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + nextAbundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; + } + } + + (void)queryCount; + double denom = 0.0; + for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { + const std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(it->first); + const double confidence = (fixed != fixedCoverageConfidence.end()) ? fixed->second : 0.0; + it->second = it->second + (coveragePriorWeight * confidence) + ABUNDANCE_SMOOTH_EPS; + denom += it->second; + } + if (denom > 0.0) { + for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + it->second[j].abundance = nextAbundance[target]; + std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(target); + if (fixed != fixedCoverageConfidence.end()) { + it->second[j].coverageConfidence = fixed->second; + } else { + it->second[j].coverageConfidence = 0.0; + } + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(nextAbundance[targetList[i]]); + } + return out; +} + +static std::vector projectSimplex(const std::vector &x) { + std::vector projected = x; + for (size_t i = 0; i < projected.size(); ++i) { + if (projected[i] < 0.0) { + projected[i] = 0.0; + } + } + + const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); + if (sum > 0.0) { + for (size_t i = 0; i < projected.size(); ++i) { + projected[i] /= sum; + } + } + return projected; +} + +static void squarem(ReclassTaxContext &ctx, + double betaBit, + int maxIter, + double tol, + double alphaMax, + double coveragePriorWeight, + int threads) { + if (ctx.queryCount == 0 || ctx.targetSet.empty()) { + return; + } + + initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); + initCoverageConfidence(ctx.mappingTable, ctx.targetSet, threads); + Debug(Debug::INFO) << "Reclassify initialized coverage confidence." << "\n"; + + std::unordered_map fixedCoverageConfidence; + fixedCoverageConfidence.reserve(ctx.targetSet.size()); + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + fixedCoverageConfidence[it->second[j].result.dbKey] = it->second[j].coverageConfidence; + } + } + + const std::vector targetList = targetListFromSet(ctx.targetSet); + std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); + std::vector logLikelihoods; + + for (int iter = 0; iter < maxIter; ++iter) { + const double abundanceExponent = alphaMax * (1.0 - std::exp(-static_cast(iter + 1) / ABUNDANCE_EXP_TAU)); + const std::vector x1 = emUpdate(ctx.mappingTable, + betaBit, + 1.0, + 0.25, + fixedCoverageConfidence, + targetList, + ctx.queryCount, + abundanceExponent, + coveragePriorWeight); + const std::vector x2 = emUpdate(ctx.mappingTable, + betaBit, + 1.0, + 0.25, + fixedCoverageConfidence, + targetList, + ctx.queryCount, + abundanceExponent, + coveragePriorWeight); + + std::vector r(x0.size(), 0.0); + std::vector v(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + r[i] = x1[i] - x0[i]; + v[i] = x2[i] - x1[i] - r[i]; + } + + double normR = 0.0; + double normV = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + normR += r[i] * r[i]; + normV += v[i] * v[i]; + } + normR = std::sqrt(normR); + normV = std::sqrt(normV); + + double accel = (normV == 0.0) ? -1.0 : -(normR / normV); + accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + + std::vector xNew(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; + } + xNew = projectSimplex(xNew); + + setAbundance(ctx.mappingTable, targetList, xNew); + computePosterior(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent); + double currentLl = logLikelihood(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent, ctx.queryCount); + + if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { + setAbundance(ctx.mappingTable, targetList, x2); + computePosterior(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent); + currentLl = logLikelihood(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent, ctx.queryCount); + xNew = x2; + } + + logLikelihoods.push_back(currentLl); + + double parameterChange = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); + } + + Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; + x0 = xNew; + if (parameterChange < tol && iter > 5) { + Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations." << "\n"; + break; + } + } +} + +static void addInterval(std::vector &intervals, int start, int end) { + Interval interval; + interval.start = std::min(start, end); + interval.end = std::max(start, end); + intervals.push_back(interval); +} + +static std::vector mergeIntervals(std::vector intervals) { + if (intervals.empty()) { + return intervals; + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.end < rhs.end; + }); + + std::vector merged; + merged.push_back(intervals[0]); + for (size_t i = 1; i < intervals.size(); ++i) { + if (intervals[i].start <= merged.back().end + 1) { + merged.back().end = std::max(merged.back().end, intervals[i].end); + } else { + merged.push_back(intervals[i]); + } + } + return merged; +} + +static std::vector collectTargetStats(const ReclassTaxContext &ctx) { + std::unordered_map statsByTarget; + + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const ReclassTaxEntry &entry = it->second[j]; + TargetStats &stats = statsByTarget[entry.result.dbKey]; + stats.key = entry.result.dbKey; + stats.targetLength = entry.result.dbLen; + stats.abundance = entry.abundance; + stats.coverageConfidence = entry.coverageConfidence; + stats.dropped = false; + addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); + } + } + + std::vector out; + out.reserve(statsByTarget.size()); + for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { + it->second.intervals = mergeIntervals(it->second.intervals); + out.push_back(it->second); + } + + std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.key < rhs.key; + }); + return out; +} + +static void printAbundanceDistribution(const std::vector &stats) { + if (stats.empty()) { + Debug(Debug::INFO) << "Abundance distribution: no targets.\n"; + return; + } + + std::vector values; + values.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + values.push_back(stats[i].abundance); + } + std::sort(values.begin(), values.end()); + + const auto quantile = [&values](double q) -> double { + const size_t idx = static_cast(q * static_cast(values.size() - 1)); + return values[idx]; + }; + + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + const auto cumulativeCountAtMassFraction = [&values, totalMass](double massFraction, size_t &count, double &countFrac) { + count = 0; + countFrac = 0.0; + if (values.empty() || totalMass <= 0.0 || massFraction <= 0.0) { + return; + } + + const double targetMass = massFraction * totalMass; + double cumulativeMass = 0.0; + for (size_t i = 0; i < values.size(); ++i) { + if ((cumulativeMass + values[i]) <= targetMass) { + cumulativeMass += values[i]; + ++count; + } else { + break; + } + } + countFrac = static_cast(count) / static_cast(values.size()); + }; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(8); + oss << "Abundance distribution (targets=" << values.size() << "):" + << " min=" << values.front() + << " p25=" << quantile(0.25) + << " p50=" << quantile(0.50) + << " p75=" << quantile(0.75) + << " p90=" << quantile(0.90) + << " p95=" << quantile(0.95) + << " p99=" << quantile(0.99) + << " max=" << values.back(); + Debug(Debug::INFO) << oss.str() << "\n"; + + const double cutoffs[] = {0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99}; + std::ostringstream cumulativeOss; + cumulativeOss << std::fixed << std::setprecision(8); + cumulativeOss << "Abundance cumulative:"; + for (size_t i = 0; i < sizeof(cutoffs) / sizeof(cutoffs[0]); ++i) { + size_t count = 0; + double countFrac = 0.0; + cumulativeCountAtMassFraction(cutoffs[i], count, countFrac); + cumulativeOss << " <=" << cutoffs[i] + << "[count=" << count + << ",countFrac=" << countFrac << "]"; + } + Debug(Debug::INFO) << cumulativeOss.str() << "\n"; +} + +static double clamp01(double value) { + return std::max(0.0, std::min(1.0, value)); +} + +static bool tailQuantileCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double accumulatedMass = 0.0; + size_t maxTailCount = 0; + for (size_t i = 0; i < values.size(); ++i) { + const double candidate = accumulatedMass + values[i]; + if (candidate > (maxTailMass + EPS)) { + break; + } + accumulatedMass = candidate; + ++maxTailCount; + } + if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { + return false; + } + + tailCount = maxTailCount; + if (useLowTail) { + cutoff = values[tailCount - 1]; + } else { + cutoff = values[values.size() - tailCount]; + } + return true; +} + +static std::unordered_set selectTailTargets(const std::vector &stats, + bool useLowTail, + size_t tailCount, + double maxTailFraction) { + std::vector ordered; + ordered.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + ordered.push_back(&stats[i]); + } + + std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { + const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; + const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; + if (lhsValue != rhsValue) { + return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); + } + return lhs->key < rhs->key; + }); + + double totalMass = 0.0; + for (size_t i = 0; i < ordered.size(); ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + totalMass += value; + } + const double maxTailMass = clamp01(maxTailFraction) * totalMass; + + std::unordered_set selected; + const size_t limit = std::min(tailCount, ordered.size()); + double selectedMass = 0.0; + selected.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { + break; + } + selectedMass += value; + selected.insert(ordered[i]->key); + } + return selected; +} + +static std::unordered_set selectDroppedTargets(const std::vector &stats, + double maxDropPercentage, + double &abundanceCutoff) { + std::unordered_set dropped; + if (stats.empty()) { + abundanceCutoff = 0.0; + return dropped; + } + if (stats.size() < MIN_FILTER_TARGETS) { + abundanceCutoff = 0.0; + return dropped; + } + + std::vector abundances; + abundances.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + abundances.push_back(stats[i].abundance); + } + + const double maxTailFraction = clamp01(maxDropPercentage / 100.0); + size_t abundanceTailCount = 0; + const bool hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + if (hasAbundanceCutoff == false) { + abundanceCutoff = 0.0; + return dropped; + } + + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); + for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { + dropped.insert(*it); + } + if (dropped.size() == stats.size()) { + dropped.clear(); + } + return dropped; +} + +static void applyDroppedTargets(ReclassTaxContext &ctx, + const std::unordered_set &dropped, + size_t totalTargets, + double abundanceCutoff) { + if (dropped.empty()) { + Debug(Debug::INFO) << "Reclassify target filter kept all targets. abundance cutoff=" + << abundanceCutoff << "\n"; + return; + } + + for (MappingTable::iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end();) { + std::vector &records = it->second; + records.erase(std::remove_if(records.begin(), records.end(), [&dropped](const ReclassTaxEntry &entry) { + return dropped.find(entry.result.dbKey) != dropped.end(); + }), records.end()); + + if (records.empty()) { + it = ctx.mappingTable.erase(it); + } else { + ++it; + } + } + + ctx.queryOrder.erase(std::remove_if(ctx.queryOrder.begin(), ctx.queryOrder.end(), [&ctx](unsigned int queryKey) { + return ctx.mappingTable.find(queryKey) == ctx.mappingTable.end(); + }), ctx.queryOrder.end()); + + for (std::unordered_set::const_iterator it = dropped.begin(); it != dropped.end(); ++it) { + ctx.targetSet.erase(*it); + } + + const double removedPct = (totalTargets > 0) + ? (100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) + : 0.0; + Debug(Debug::INFO) << "Reclassify dropped " << dropped.size() + << " of " << totalTargets + << " targets (" << removedPct << "%)" + << " using abundance <= " << abundanceCutoff << ".\n"; +} + +static bool compareByPosteriorThenBitScore(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { + if (a.posterior != b.posterior) { + return a.posterior > b.posterior; + } + if (a.result.score != b.result.score) { + return a.result.score > b.result.score; + } + return Matcher::compareHits(a.result, b.result); +} + +static void writeReclassifiedDb(const ReclassTaxContext &ctx, + int dbType, + const std::string &outDb, + const std::string &outIndex, + int threads, + bool compress) { + DBWriter writer(outDb.c_str(), outIndex.c_str(), threads, compress, dbType); + writer.open(); + + Debug::Progress progress(ctx.queryOrder.size()); +#pragma omp parallel + { + unsigned int thread_idx = 0; +#ifdef OPENMP + thread_idx = static_cast(omp_get_thread_num()); +#endif + char buffer[1024 + 32768 * 4]; + +#pragma omp for schedule(dynamic, 5) + for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = ctx.queryOrder[i]; + MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); + if (recordsIt == ctx.mappingTable.end()) { + continue; + } + + std::vector records = recordsIt->second; + SORT_SERIAL(records.begin(), records.end(), compareByPosteriorThenBitScore); + + writer.writeStart(thread_idx); + for (size_t j = 0; j < records.size(); ++j) { + Matcher::result_t res = records[j].result; + res.seqId = static_cast(records[j].posterior); + size_t len = Matcher::resultToBuffer(buffer, res, ctx.hasBacktrace, ctx.hasOrfPosition); + writer.writeAdd(buffer, len, thread_idx); + } + writer.writeEnd(queryKey, thread_idx); + } + } + + writer.close(); +} +} + +int emreclassify(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + ReclassTaxContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + squarem(ctx, + par.reclassifyLambda, + par.reclassifyMaxIterations, + par.reclassifyTolerance, + par.reclassifyAlpha, + par.reclassifyGamma, + par.threads); + + std::vector allTargetStats = collectTargetStats(ctx); + printAbundanceDistribution(allTargetStats); + double abundanceCutoff = 0.0; + const std::unordered_set dropped = selectDroppedTargets(allTargetStats, + par.reclassifyMaxDropPercentage, + abundanceCutoff); + applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); + const size_t totalTargets = allTargetStats.size(); + const size_t removedTargets = dropped.size(); + const double removedPct = (totalTargets > 0) + ? (100.0 * static_cast(removedTargets) / static_cast(totalTargets)) + : 0.0; + Debug(Debug::INFO) << "Reclassify-drop summary: removed " << removedTargets + << " / " << totalTargets + << " targets (" << removedPct << "%).\n"; + + writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed); + + reader.close(); + return EXIT_SUCCESS; +} diff --git a/src/util/EM_reclassify2.cpp b/src/util/EM_reclassify2.cpp new file mode 100644 index 000000000..b3a4e1d56 --- /dev/null +++ b/src/util/EM_reclassify2.cpp @@ -0,0 +1,849 @@ +#include "Parameters.h" +#include "DBReader.h" +#include "DBWriter.h" +#include "Debug.h" +#include "Util.h" +#include "Matcher.h" +#include "FastSort.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef OPENMP +#include +#endif + +namespace { +struct ReclassTaxEntry { + Matcher::result_t result; + double abundance; + double posterior; + double coverageConfidence; +}; + +typedef std::unordered_map > MappingTable; + +struct Interval { + int start; + int end; +}; + +struct TargetStats { + unsigned int key; + unsigned int targetLength; + double abundance; + double coverageConfidence; + bool dropped; + std::vector intervals; +}; + +struct ReclassTaxContext { + MappingTable mappingTable; + std::vector queryOrder; + std::unordered_set targetSet; + size_t queryCount; + bool hasBacktrace; + bool hasOrfPosition; + + ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +}; + +static const double STEP_MIN = -1.0; +static const double STEP_MAX = 1.0; +static const double EPS = 1e-12; +static const size_t MIN_FILTER_TARGETS = 20; +static const size_t MIN_TAIL_TARGETS = 2; + +static double clamp01(double value); + +static std::vector targetListFromSet(const std::unordered_set &targets) { + std::vector out(targets.begin(), targets.end()); + SORT_SERIAL(out.begin(), out.end()); + return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + if (records.empty()) { + ctx.queryOrder.push_back(queryKey); + } + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { + std::unordered_map initAbundance; + initAbundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + initAbundance[*it] = 0.0; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += it->second[j].result.score; + } + if (scoreSum <= 0.0) { + continue; + } + for (size_t j = 0; j < it->second.size(); ++j) { + initAbundance[it->second[j].result.dbKey] += it->second[j].result.score / scoreSum; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; + } + } +} + +struct TargetHitRef { + const ReclassTaxEntry *entry; + double score; + double weight; +}; + +static void initCoverageConfidence(MappingTable &mappingTable, + const std::unordered_set &targetSet, + int threads) { + (void)threads; + std::unordered_map targetMin; + std::unordered_map targetMax; + std::unordered_map > hitsByTarget; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + hitsByTarget.emplace(*it, std::vector()); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += static_cast(it->second[j].result.score); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + const double score = static_cast(it->second[j].result.score); + const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); + } + } + + std::unordered_map coverageFraction; + coverageFraction.reserve(targetSet.size()); + const std::vector targetList = targetListFromSet(targetSet); + std::vector coverageFractionByIndex(targetList.size(), 0.0); + +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) + for (size_t i = 0; i < targetList.size(); ++i) { + const unsigned int target = targetList[i]; + const int startPos = targetMin[target]; + const int endPos = targetMax[target]; + const int len = (endPos >= startPos) ? (endPos - startPos + 1) : 1; + std::vector cov(static_cast(len), 0.0); + std::vector covConf(static_cast(len), 0.0); + + std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); + if (hitIt != hitsByTarget.end()) { + const std::vector &hits = hitIt->second; + for (size_t h = 0; h < hits.size(); ++h) { + const Matcher::result_t &result = hits[h].entry->result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = hits[h].score / static_cast(targetLen); + const int start = std::max(0, result.dbStartPos - startPos); + const int end = std::min(len - 1, result.dbEndPos - startPos); + for (int pos = start; pos <= end; ++pos) { + cov[static_cast(pos)] += mq; + covConf[static_cast(pos)] += hits[h].weight; + } + } + } + + double covered = 0.0; + double squaredCovered = 0.0; + for (size_t pos = 0; pos < covConf.size(); ++pos) { + const double clipped = std::min(1.0, covConf[pos]); + covered += clipped; + squaredCovered += clipped * clipped; + } + const double fraction = covered / static_cast(len); + const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; + const double concentrationPenalty = 1.0 - hhi; + coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); + } + + for (size_t i = 0; i < targetList.size(); ++i) { + coverageFraction[targetList[i]] = coverageFractionByIndex[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + std::unordered_map::const_iterator cf = coverageFraction.find(target); + it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? cf->second : 0.0; + } + } +} + +static double maxQueryBitScore(const std::vector &entries) { + double maxScore = 0.0; + for (size_t i = 0; i < entries.size(); ++i) { + maxScore = std::max(maxScore, static_cast(entries[i].result.score)); + } + return maxScore; +} + +static double scoreTerm(const ReclassTaxEntry &entry, double queryMaxScore, double lambda, double alpha, double gamma) { + if (queryMaxScore <= 0.0) { + return 0.0; + } + + const double normalizedScore = static_cast(entry.result.score) / queryMaxScore; + const double abundance = std::max(entry.abundance, EPS); + const double coverageConfidence = std::max(entry.coverageConfidence, EPS); + return std::exp(lambda * normalizedScore) * std::pow(abundance, alpha) * std::pow(coverageConfidence, gamma); +} + +static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double gamma) { + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + denom += scoreTerm(it->second[j], queryMaxScore, lambda, alpha, gamma); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const double value = scoreTerm(it->second[j], queryMaxScore, lambda, alpha, gamma); + it->second[j].posterior = (denom > 0.0) ? (value / denom) : 0.0; + } + } +} + +static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double gamma, size_t queryCount) { + if (queryCount == 0) { + return 0.0; + } + + double ll = 0.0; + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const double value = scoreTerm(it->second[j], queryMaxScore, lambda, alpha, gamma); + denom += value; + if (it->second[j].posterior > 0.0) { + ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); + } + } + ll -= std::log(denom > 0.0 ? denom : 1e-300); + } + return ll / static_cast(queryCount); +} + +static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] = it->second[j].abundance; + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(abundance[targetList[i]]); + } + return out; +} + +static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = abundanceVector[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static std::vector emUpdate(MappingTable &mappingTable, + double lambda, + const std::unordered_map &fixedCoverageConfidence, + const std::vector &targetList, + size_t queryCount, + double alpha, + double gamma) { + computePosterior(mappingTable, lambda, alpha, gamma); + + std::unordered_map nextAbundance; + nextAbundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + nextAbundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + it->second[j].abundance = nextAbundance[target]; + std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(target); + if (fixed != fixedCoverageConfidence.end()) { + it->second[j].coverageConfidence = fixed->second; + } else { + it->second[j].coverageConfidence = 0.0; + } + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(nextAbundance[targetList[i]]); + } + return out; +} + +static std::vector projectSimplex(const std::vector &x) { + std::vector projected = x; + for (size_t i = 0; i < projected.size(); ++i) { + if (projected[i] < 0.0) { + projected[i] = 0.0; + } + } + + const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); + if (sum > 0.0) { + for (size_t i = 0; i < projected.size(); ++i) { + projected[i] /= sum; + } + } + return projected; +} + +static void squarem(ReclassTaxContext &ctx, + double lambda, + int maxIter, + double tol, + double alpha, + double gamma, + int threads) { + if (ctx.queryCount == 0 || ctx.targetSet.empty()) { + return; + } + + initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); + initCoverageConfidence(ctx.mappingTable, ctx.targetSet, threads); + Debug(Debug::INFO) << "Reclassify initialized coverage confidence." << "\n"; + + std::unordered_map fixedCoverageConfidence; + fixedCoverageConfidence.reserve(ctx.targetSet.size()); + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + fixedCoverageConfidence[it->second[j].result.dbKey] = it->second[j].coverageConfidence; + } + } + + const std::vector targetList = targetListFromSet(ctx.targetSet); + std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); + std::vector logLikelihoods; + + for (int iter = 0; iter < maxIter; ++iter) { + const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); + const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); + + std::vector r(x0.size(), 0.0); + std::vector v(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + r[i] = x1[i] - x0[i]; + v[i] = x2[i] - x1[i] - r[i]; + } + + double normR = 0.0; + double normV = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + normR += r[i] * r[i]; + normV += v[i] * v[i]; + } + normR = std::sqrt(normR); + normV = std::sqrt(normV); + + double accel = (normV == 0.0) ? -1.0 : -(normR / normV); + accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + + std::vector xNew(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; + } + xNew = projectSimplex(xNew); + + setAbundance(ctx.mappingTable, targetList, xNew); + computePosterior(ctx.mappingTable, lambda, alpha, gamma); + double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); + + if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { + setAbundance(ctx.mappingTable, targetList, x2); + computePosterior(ctx.mappingTable, lambda, alpha, gamma); + currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); + xNew = x2; + } + + logLikelihoods.push_back(currentLl); + + double parameterChange = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); + } + + Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; + x0 = xNew; + if (parameterChange < tol && iter > 5) { + Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations." << "\n"; + break; + } + } +} + +static void addInterval(std::vector &intervals, int start, int end) { + Interval interval; + interval.start = std::min(start, end); + interval.end = std::max(start, end); + intervals.push_back(interval); +} + +static std::vector mergeIntervals(std::vector intervals) { + if (intervals.empty()) { + return intervals; + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.end < rhs.end; + }); + + std::vector merged; + merged.push_back(intervals[0]); + for (size_t i = 1; i < intervals.size(); ++i) { + if (intervals[i].start <= merged.back().end + 1) { + merged.back().end = std::max(merged.back().end, intervals[i].end); + } else { + merged.push_back(intervals[i]); + } + } + return merged; +} + +static std::vector collectTargetStats(const ReclassTaxContext &ctx) { + std::unordered_map statsByTarget; + + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const ReclassTaxEntry &entry = it->second[j]; + TargetStats &stats = statsByTarget[entry.result.dbKey]; + stats.key = entry.result.dbKey; + stats.targetLength = entry.result.dbLen; + stats.abundance = entry.abundance; + stats.coverageConfidence = entry.coverageConfidence; + stats.dropped = false; + addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); + } + } + + std::vector out; + out.reserve(statsByTarget.size()); + for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { + it->second.intervals = mergeIntervals(it->second.intervals); + out.push_back(it->second); + } + + std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.key < rhs.key; + }); + return out; +} + +static double clamp01(double value) { + return std::max(0.0, std::min(1.0, value)); +} + +static bool largestJumpCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double bestGap = 0.0; + size_t bestIdx = 0; + double lowTailMass = 0.0; + for (size_t i = 0; i + 1 < values.size(); ++i) { + const size_t lowTailCount = i + 1; + const size_t highTailCount = values.size() - lowTailCount; + const size_t candidateTailCount = useLowTail ? lowTailCount : highTailCount; + lowTailMass += values[i]; + const double highTailMass = totalMass - lowTailMass; + const double candidateTailMass = useLowTail ? lowTailMass : highTailMass; + if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailMass > (maxTailMass + EPS)) { + continue; + } + + const double gap = values[i + 1] - values[i]; + if (gap > bestGap) { + bestGap = gap; + bestIdx = i; + tailCount = candidateTailCount; + } + } + + if (bestGap <= EPS) { + return false; + } + + cutoff = 0.5 * (values[bestIdx] + values[bestIdx + 1]); + return true; +} + +static bool tailQuantileCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double accumulatedMass = 0.0; + size_t maxTailCount = 0; + for (size_t i = 0; i < values.size(); ++i) { + const double candidate = accumulatedMass + values[i]; + if (candidate > (maxTailMass + EPS)) { + break; + } + accumulatedMass = candidate; + ++maxTailCount; + } + if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { + return false; + } + + tailCount = maxTailCount; + if (useLowTail) { + cutoff = values[tailCount - 1]; + } else { + cutoff = values[values.size() - tailCount]; + } + return true; +} + +static std::unordered_set selectTailTargets(const std::vector &stats, + bool useLowTail, + size_t tailCount, + double maxTailFraction) { + std::vector ordered; + ordered.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + ordered.push_back(&stats[i]); + } + + std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { + const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; + const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; + if (lhsValue != rhsValue) { + return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); + } + return lhs->key < rhs->key; + }); + + double totalMass = 0.0; + for (size_t i = 0; i < ordered.size(); ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + totalMass += value; + } + const double maxTailMass = clamp01(maxTailFraction) * totalMass; + + std::unordered_set selected; + const size_t limit = std::min(tailCount, ordered.size()); + double selectedMass = 0.0; + selected.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { + break; + } + selectedMass += value; + selected.insert(ordered[i]->key); + } + return selected; +} + +static std::unordered_set selectDroppedTargets(const std::vector &stats, + double maxDropPercentage, + double &abundanceCutoff) { + std::unordered_set dropped; + if (stats.empty()) { + abundanceCutoff = 0.0; + return dropped; + } + if (stats.size() < MIN_FILTER_TARGETS) { + abundanceCutoff = 0.0; + return dropped; + } + + std::vector abundances; + abundances.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + abundances.push_back(stats[i].abundance); + } + + const double maxTailFraction = clamp01(maxDropPercentage / 100.0); + size_t abundanceTailCount = 0; + bool hasAbundanceCutoff = largestJumpCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + if (hasAbundanceCutoff == false) { + hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + } + if (hasAbundanceCutoff == false) { + abundanceCutoff = 0.0; + return dropped; + } + + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); + for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { + dropped.insert(*it); + } + if (dropped.size() == stats.size()) { + dropped.clear(); + } + return dropped; +} + +static void applyDroppedTargets(ReclassTaxContext &ctx, + const std::unordered_set &dropped, + size_t totalTargets, + double abundanceCutoff) { + if (dropped.empty()) { + Debug(Debug::INFO) << "Reclassify target filter kept all targets. abundance cutoff=" + << abundanceCutoff << "\n"; + return; + } + + for (MappingTable::iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end();) { + std::vector &records = it->second; + records.erase(std::remove_if(records.begin(), records.end(), [&dropped](const ReclassTaxEntry &entry) { + return dropped.find(entry.result.dbKey) != dropped.end(); + }), records.end()); + + if (records.empty()) { + it = ctx.mappingTable.erase(it); + } else { + ++it; + } + } + + ctx.queryOrder.erase(std::remove_if(ctx.queryOrder.begin(), ctx.queryOrder.end(), [&ctx](unsigned int queryKey) { + return ctx.mappingTable.find(queryKey) == ctx.mappingTable.end(); + }), ctx.queryOrder.end()); + + for (std::unordered_set::const_iterator it = dropped.begin(); it != dropped.end(); ++it) { + ctx.targetSet.erase(*it); + } + + const double removedPct = (totalTargets > 0) + ? (100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) + : 0.0; + Debug(Debug::INFO) << "Reclassify dropped " << dropped.size() + << " of " << totalTargets + << " targets (" << removedPct << "%)" + << " using abundance <= " << abundanceCutoff << ".\n"; +} + +static bool compareByPosteriorThenBitScore(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { + if (a.posterior != b.posterior) { + return a.posterior > b.posterior; + } + if (a.result.score != b.result.score) { + return a.result.score > b.result.score; + } + return Matcher::compareHits(a.result, b.result); +} + +static void writeReclassifiedDb(const ReclassTaxContext &ctx, + int dbType, + const std::string &outDb, + const std::string &outIndex, + int threads, + bool compress) { + DBWriter writer(outDb.c_str(), outIndex.c_str(), threads, compress, dbType); + writer.open(); + + Debug::Progress progress(ctx.queryOrder.size()); +#pragma omp parallel + { + unsigned int thread_idx = 0; +#ifdef OPENMP + thread_idx = static_cast(omp_get_thread_num()); +#endif + char buffer[1024 + 32768 * 4]; + +#pragma omp for schedule(dynamic, 5) + for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = ctx.queryOrder[i]; + MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); + if (recordsIt == ctx.mappingTable.end()) { + continue; + } + + std::vector records = recordsIt->second; + SORT_SERIAL(records.begin(), records.end(), compareByPosteriorThenBitScore); + + writer.writeStart(thread_idx); + for (size_t j = 0; j < records.size(); ++j) { + Matcher::result_t res = records[j].result; + res.seqId = static_cast(records[j].posterior); + size_t len = Matcher::resultToBuffer(buffer, res, ctx.hasBacktrace, ctx.hasOrfPosition); + writer.writeAdd(buffer, len, thread_idx); + } + writer.writeEnd(queryKey, thread_idx); + } + } + + writer.close(); +} +} + +int emreclassify(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + ReclassTaxContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + squarem(ctx, + par.reclassifyLambda, + par.reclassifyMaxIterations, + par.reclassifyTolerance, + par.reclassifyAlpha, + par.reclassifyGamma, + par.threads); + + std::vector allTargetStats = collectTargetStats(ctx); + double abundanceCutoff = 0.0; + const std::unordered_set dropped = selectDroppedTargets(allTargetStats, + par.reclassifyMaxDropPercentage, + abundanceCutoff); + applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); + + writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed); + + reader.close(); + return EXIT_SUCCESS; +} diff --git a/src/util/EM_reclassify3.cpp b/src/util/EM_reclassify3.cpp new file mode 100644 index 000000000..a6f47a3eb --- /dev/null +++ b/src/util/EM_reclassify3.cpp @@ -0,0 +1,835 @@ +// #include "Parameters.h" +// #include "DBReader.h" +// #include "DBWriter.h" +// #include "Debug.h" +// #include "Util.h" +// #include "Matcher.h" +// #include "FastSort.h" + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #ifdef OPENMP +// #include +// #endif + +// namespace { +// struct ReclassTaxEntry { +// Matcher::result_t result; +// double abundance; +// double posterior; +// double coverageConfidence; +// }; + +// typedef std::unordered_map > MappingTable; + +// struct Interval { +// int start; +// int end; +// }; + +// struct TargetStats { +// unsigned int key; +// unsigned int targetLength; +// double abundance; +// double coverageConfidence; +// bool dropped; +// std::vector intervals; +// }; + +// struct ReclassTaxContext { +// MappingTable mappingTable; +// std::vector queryOrder; +// std::unordered_set targetSet; +// size_t queryCount; +// bool hasBacktrace; +// bool hasOrfPosition; + +// ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +// }; + +// static const double STEP_MIN = -1.0; +// static const double STEP_MAX = 1.0; +// static const double EPS = 1e-12; +// static const size_t MIN_FILTER_TARGETS = 20; +// static const size_t MIN_TAIL_TARGETS = 2; + +// static double clamp01(double value); + +// static std::vector targetListFromSet(const std::unordered_set &targets) { +// std::vector out(targets.begin(), targets.end()); +// SORT_SERIAL(out.begin(), out.end()); +// return out; +// } + +// static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { +// Debug::Progress progress(reader.getSize()); +// const char *entry[255]; + +// for (size_t i = 0; i < reader.getSize(); ++i) { +// progress.updateProgress(); +// const unsigned int queryKey = reader.getDbKey(i); +// char *data = reader.getData(i, 0); + +// if (reader.getEntryLen(i) <= 1) { +// continue; +// } + +// std::vector &records = ctx.mappingTable[queryKey]; +// if (records.empty()) { +// ctx.queryOrder.push_back(queryKey); +// } +// while (*data != '\0') { +// const size_t columns = Util::getWordsOfLine(data, entry, 255); +// if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { +// Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; +// EXIT(EXIT_FAILURE); +// } + +// if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT +// || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { +// ctx.hasBacktrace = true; +// } +// if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT +// || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { +// ctx.hasOrfPosition = true; +// } + +// Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); +// records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0}); +// ctx.targetSet.insert(result.dbKey); +// data = Util::skipLine(data); +// } +// } + +// ctx.queryCount = ctx.mappingTable.size(); +// } + +// static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { +// std::unordered_map initAbundance; +// initAbundance.reserve(targetSet.size()); +// for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { +// initAbundance[*it] = 0.0; +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double scoreSum = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// scoreSum += it->second[j].result.score; +// } +// if (scoreSum <= 0.0) { +// continue; +// } +// for (size_t j = 0; j < it->second.size(); ++j) { +// initAbundance[it->second[j].result.dbKey] += it->second[j].result.score / scoreSum; +// } +// } + +// if (queryCount > 0) { +// const double denom = static_cast(queryCount); +// for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { +// it->second /= denom; +// } +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; +// } +// } +// } + +// struct TargetHitRef { +// const ReclassTaxEntry *entry; +// double score; +// double weight; +// }; + +// static void initCoverageConfidence(MappingTable &mappingTable, +// const std::unordered_set &targetSet, +// int threads) { +// (void)threads; +// std::unordered_map targetMin; +// std::unordered_map targetMax; +// std::unordered_map > hitsByTarget; + +// for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { +// targetMin[*it] = std::numeric_limits::max(); +// targetMax[*it] = std::numeric_limits::min(); +// hitsByTarget.emplace(*it, std::vector()); +// } + +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double scoreSum = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// scoreSum += static_cast(it->second[j].result.score); +// } +// for (size_t j = 0; j < it->second.size(); ++j) { +// const unsigned int target = it->second[j].result.dbKey; +// if (it->second[j].result.dbStartPos < targetMin[target]) { +// targetMin[target] = it->second[j].result.dbStartPos; +// } +// if (it->second[j].result.dbEndPos > targetMax[target]) { +// targetMax[target] = it->second[j].result.dbEndPos; +// } +// const double score = static_cast(it->second[j].result.score); +// const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; +// hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); +// } +// } + +// std::unordered_map coverageFraction; +// coverageFraction.reserve(targetSet.size()); +// const std::vector targetList = targetListFromSet(targetSet); +// std::vector coverageFractionByIndex(targetList.size(), 0.0); + +// #pragma omp parallel for num_threads(threads) schedule(dynamic, 1) +// for (size_t i = 0; i < targetList.size(); ++i) { +// const unsigned int target = targetList[i]; +// const int startPos = targetMin[target]; +// const int endPos = targetMax[target]; +// const int len = (endPos >= startPos) ? (endPos - startPos + 1) : 1; +// std::vector cov(static_cast(len), 0.0); +// std::vector covConf(static_cast(len), 0.0); + +// std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); +// if (hitIt != hitsByTarget.end()) { +// const std::vector &hits = hitIt->second; +// for (size_t h = 0; h < hits.size(); ++h) { +// const Matcher::result_t &result = hits[h].entry->result; +// const int targetLen = result.dbEndPos - result.dbStartPos + 1; +// if (targetLen <= 0) { +// continue; +// } + +// const double mq = hits[h].score / static_cast(targetLen); +// const int start = std::max(0, result.dbStartPos - startPos); +// const int end = std::min(len - 1, result.dbEndPos - startPos); +// for (int pos = start; pos <= end; ++pos) { +// cov[static_cast(pos)] += mq; +// covConf[static_cast(pos)] += hits[h].weight; +// } +// } +// } + +// double covered = 0.0; +// double squaredCovered = 0.0; +// for (size_t pos = 0; pos < covConf.size(); ++pos) { +// const double clipped = std::min(1.0, covConf[pos]); +// covered += clipped; +// squaredCovered += clipped * clipped; +// } +// const double fraction = covered / static_cast(len); +// const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; +// const double concentrationPenalty = 1.0 - hhi; +// coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); +// } + +// for (size_t i = 0; i < targetList.size(); ++i) { +// coverageFraction[targetList[i]] = coverageFractionByIndex[i]; +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// const unsigned int target = it->second[j].result.dbKey; +// std::unordered_map::const_iterator cf = coverageFraction.find(target); +// it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? cf->second : 0.0; +// } +// } +// } + +// static double scoreTerm(const ReclassTaxEntry &entry, double lambda, double alpha, double gamma) { +// const double bitScore = static_cast(entry.result.score); +// const double abundance = std::max(entry.abundance, EPS); +// const double coverageConfidence = std::max(entry.coverageConfidence, EPS); +// return std::exp(lambda * bitScore) * std::pow(abundance, alpha) * std::pow(coverageConfidence, gamma); +// } + +// static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double gamma) { +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double denom = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// denom += scoreTerm(it->second[j], lambda, alpha, gamma); +// } +// for (size_t j = 0; j < it->second.size(); ++j) { +// const double value = scoreTerm(it->second[j], lambda, alpha, gamma); +// it->second[j].posterior = (denom > 0.0) ? (value / denom) : 0.0; +// } +// } +// } + +// static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double gamma, size_t queryCount) { +// if (queryCount == 0) { +// return 0.0; +// } + +// double ll = 0.0; +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double denom = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// const double value = scoreTerm(it->second[j], lambda, alpha, gamma); +// denom += value; +// if (it->second[j].posterior > 0.0) { +// ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); +// } +// } +// ll -= std::log(denom > 0.0 ? denom : 1e-300); +// } +// return ll / static_cast(queryCount); +// } + +// static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { +// std::unordered_map abundance; +// abundance.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// abundance[targetList[i]] = 0.0; +// } + +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// abundance[it->second[j].result.dbKey] = it->second[j].abundance; +// } +// } + +// std::vector out; +// out.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// out.push_back(abundance[targetList[i]]); +// } +// return out; +// } + +// static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { +// std::unordered_map abundance; +// abundance.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// abundance[targetList[i]] = abundanceVector[i]; +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// it->second[j].abundance = abundance[it->second[j].result.dbKey]; +// } +// } +// } + +// static std::vector emUpdate(MappingTable &mappingTable, +// double lambda, +// const std::unordered_map &fixedCoverageConfidence, +// const std::vector &targetList, +// size_t queryCount, +// double alpha, +// double gamma) { +// computePosterior(mappingTable, lambda, alpha, gamma); + +// std::unordered_map nextAbundance; +// nextAbundance.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// nextAbundance[targetList[i]] = 0.0; +// } + +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; +// } +// } + +// if (queryCount > 0) { +// const double denom = static_cast(queryCount); +// for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { +// it->second /= denom; +// } +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// const unsigned int target = it->second[j].result.dbKey; +// it->second[j].abundance = nextAbundance[target]; +// std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(target); +// if (fixed != fixedCoverageConfidence.end()) { +// it->second[j].coverageConfidence = fixed->second; +// } else { +// it->second[j].coverageConfidence = 0.0; +// } +// } +// } + +// std::vector out; +// out.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// out.push_back(nextAbundance[targetList[i]]); +// } +// return out; +// } + +// static std::vector projectSimplex(const std::vector &x) { +// std::vector projected = x; +// for (size_t i = 0; i < projected.size(); ++i) { +// if (projected[i] < 0.0) { +// projected[i] = 0.0; +// } +// } + +// const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); +// if (sum > 0.0) { +// for (size_t i = 0; i < projected.size(); ++i) { +// projected[i] /= sum; +// } +// } +// return projected; +// } + +// static void squarem(ReclassTaxContext &ctx, +// double lambda, +// int maxIter, +// double tol, +// double alpha, +// double gamma, +// int threads) { +// if (ctx.queryCount == 0 || ctx.targetSet.empty()) { +// return; +// } + +// initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); +// initCoverageConfidence(ctx.mappingTable, ctx.targetSet, threads); +// Debug(Debug::INFO) << "Reclassify initialized coverage confidence." << "\n"; + +// std::unordered_map fixedCoverageConfidence; +// fixedCoverageConfidence.reserve(ctx.targetSet.size()); +// for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// fixedCoverageConfidence[it->second[j].result.dbKey] = it->second[j].coverageConfidence; +// } +// } + +// const std::vector targetList = targetListFromSet(ctx.targetSet); +// std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); +// std::vector logLikelihoods; + +// for (int iter = 0; iter < maxIter; ++iter) { +// const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); +// const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); + +// std::vector r(x0.size(), 0.0); +// std::vector v(x0.size(), 0.0); +// for (size_t i = 0; i < x0.size(); ++i) { +// r[i] = x1[i] - x0[i]; +// v[i] = x2[i] - x1[i] - r[i]; +// } + +// double normR = 0.0; +// double normV = 0.0; +// for (size_t i = 0; i < x0.size(); ++i) { +// normR += r[i] * r[i]; +// normV += v[i] * v[i]; +// } +// normR = std::sqrt(normR); +// normV = std::sqrt(normV); + +// double accel = (normV == 0.0) ? -1.0 : -(normR / normV); +// accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + +// std::vector xNew(x0.size(), 0.0); +// for (size_t i = 0; i < x0.size(); ++i) { +// xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; +// } +// xNew = projectSimplex(xNew); + +// setAbundance(ctx.mappingTable, targetList, xNew); +// computePosterior(ctx.mappingTable, lambda, alpha, gamma); +// double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); + +// if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { +// setAbundance(ctx.mappingTable, targetList, x2); +// computePosterior(ctx.mappingTable, lambda, alpha, gamma); +// currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); +// xNew = x2; +// } + +// logLikelihoods.push_back(currentLl); + +// double parameterChange = 0.0; +// for (size_t i = 0; i < x0.size(); ++i) { +// parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); +// } + +// Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; +// x0 = xNew; +// if (parameterChange < tol && iter > 5) { +// Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations." << "\n"; +// break; +// } +// } +// } + +// static void addInterval(std::vector &intervals, int start, int end) { +// Interval interval; +// interval.start = std::min(start, end); +// interval.end = std::max(start, end); +// intervals.push_back(interval); +// } + +// static std::vector mergeIntervals(std::vector intervals) { +// if (intervals.empty()) { +// return intervals; +// } + +// std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { +// if (lhs.start != rhs.start) { +// return lhs.start < rhs.start; +// } +// return lhs.end < rhs.end; +// }); + +// std::vector merged; +// merged.push_back(intervals[0]); +// for (size_t i = 1; i < intervals.size(); ++i) { +// if (intervals[i].start <= merged.back().end + 1) { +// merged.back().end = std::max(merged.back().end, intervals[i].end); +// } else { +// merged.push_back(intervals[i]); +// } +// } +// return merged; +// } + +// static std::vector collectTargetStats(const ReclassTaxContext &ctx) { +// std::unordered_map statsByTarget; + +// for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// const ReclassTaxEntry &entry = it->second[j]; +// TargetStats &stats = statsByTarget[entry.result.dbKey]; +// stats.key = entry.result.dbKey; +// stats.targetLength = entry.result.dbLen; +// stats.abundance = entry.abundance; +// stats.coverageConfidence = entry.coverageConfidence; +// stats.dropped = false; +// addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); +// } +// } + +// std::vector out; +// out.reserve(statsByTarget.size()); +// for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { +// it->second.intervals = mergeIntervals(it->second.intervals); +// out.push_back(it->second); +// } + +// std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { +// if (lhs.abundance != rhs.abundance) { +// return lhs.abundance > rhs.abundance; +// } +// return lhs.key < rhs.key; +// }); +// return out; +// } + +// static double clamp01(double value) { +// return std::max(0.0, std::min(1.0, value)); +// } + +// static bool largestJumpCutoff(std::vector values, +// bool useLowTail, +// double maxTailFraction, +// double &cutoff, +// size_t &tailCount) { +// cutoff = 0.0; +// tailCount = 0; +// if (values.size() < MIN_FILTER_TARGETS) { +// return false; +// } + +// std::sort(values.begin(), values.end()); +// maxTailFraction = clamp01(maxTailFraction); +// const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); +// if (totalMass <= EPS || maxTailFraction <= 0.0) { +// return false; +// } +// const double maxTailMass = maxTailFraction * totalMass; + +// double bestGap = 0.0; +// size_t bestIdx = 0; +// double lowTailMass = 0.0; +// for (size_t i = 0; i + 1 < values.size(); ++i) { +// const size_t lowTailCount = i + 1; +// const size_t highTailCount = values.size() - lowTailCount; +// const size_t candidateTailCount = useLowTail ? lowTailCount : highTailCount; +// lowTailMass += values[i]; +// const double highTailMass = totalMass - lowTailMass; +// const double candidateTailMass = useLowTail ? lowTailMass : highTailMass; +// if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailMass > (maxTailMass + EPS)) { +// continue; +// } + +// const double gap = values[i + 1] - values[i]; +// if (gap > bestGap) { +// bestGap = gap; +// bestIdx = i; +// tailCount = candidateTailCount; +// } +// } + +// if (bestGap <= EPS) { +// return false; +// } + +// cutoff = 0.5 * (values[bestIdx] + values[bestIdx + 1]); +// return true; +// } + +// static bool tailQuantileCutoff(std::vector values, +// bool useLowTail, +// double maxTailFraction, +// double &cutoff, +// size_t &tailCount) { +// cutoff = 0.0; +// tailCount = 0; +// if (values.size() < MIN_FILTER_TARGETS) { +// return false; +// } + +// std::sort(values.begin(), values.end()); +// maxTailFraction = clamp01(maxTailFraction); +// const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); +// if (totalMass <= EPS || maxTailFraction <= 0.0) { +// return false; +// } +// const double maxTailMass = maxTailFraction * totalMass; + +// double accumulatedMass = 0.0; +// size_t maxTailCount = 0; +// for (size_t i = 0; i < values.size(); ++i) { +// const double candidate = accumulatedMass + values[i]; +// if (candidate > (maxTailMass + EPS)) { +// break; +// } +// accumulatedMass = candidate; +// ++maxTailCount; +// } +// if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { +// return false; +// } + +// tailCount = maxTailCount; +// if (useLowTail) { +// cutoff = values[tailCount - 1]; +// } else { +// cutoff = values[values.size() - tailCount]; +// } +// return true; +// } + +// static std::unordered_set selectTailTargets(const std::vector &stats, +// bool useLowTail, +// size_t tailCount, +// double maxTailFraction) { +// std::vector ordered; +// ordered.reserve(stats.size()); +// for (size_t i = 0; i < stats.size(); ++i) { +// ordered.push_back(&stats[i]); +// } + +// std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { +// const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; +// const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; +// if (lhsValue != rhsValue) { +// return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); +// } +// return lhs->key < rhs->key; +// }); + +// double totalMass = 0.0; +// for (size_t i = 0; i < ordered.size(); ++i) { +// const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; +// totalMass += value; +// } +// const double maxTailMass = clamp01(maxTailFraction) * totalMass; + +// std::unordered_set selected; +// const size_t limit = std::min(tailCount, ordered.size()); +// double selectedMass = 0.0; +// selected.reserve(limit); +// for (size_t i = 0; i < limit; ++i) { +// const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; +// if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { +// break; +// } +// selectedMass += value; +// selected.insert(ordered[i]->key); +// } +// return selected; +// } + +// static std::unordered_set selectDroppedTargets(const std::vector &stats, +// double maxDropPercentage, +// double &abundanceCutoff) { +// std::unordered_set dropped; +// if (stats.empty()) { +// abundanceCutoff = 0.0; +// return dropped; +// } +// if (stats.size() < MIN_FILTER_TARGETS) { +// abundanceCutoff = 0.0; +// return dropped; +// } + +// std::vector abundances; +// abundances.reserve(stats.size()); +// for (size_t i = 0; i < stats.size(); ++i) { +// abundances.push_back(stats[i].abundance); +// } + +// const double maxTailFraction = clamp01(maxDropPercentage / 100.0); +// size_t abundanceTailCount = 0; +// bool hasAbundanceCutoff = largestJumpCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); +// if (hasAbundanceCutoff == false) { +// hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); +// } +// if (hasAbundanceCutoff == false) { +// abundanceCutoff = 0.0; +// return dropped; +// } + +// const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); +// for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { +// dropped.insert(*it); +// } +// if (dropped.size() == stats.size()) { +// dropped.clear(); +// } +// return dropped; +// } + +// static void applyDroppedTargets(ReclassTaxContext &ctx, +// const std::unordered_set &dropped, +// size_t totalTargets, +// double abundanceCutoff) { +// if (dropped.empty()) { +// Debug(Debug::INFO) << "Reclassify target filter kept all targets. abundance cutoff=" +// << abundanceCutoff << "\n"; +// return; +// } + +// for (MappingTable::iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end();) { +// std::vector &records = it->second; +// records.erase(std::remove_if(records.begin(), records.end(), [&dropped](const ReclassTaxEntry &entry) { +// return dropped.find(entry.result.dbKey) != dropped.end(); +// }), records.end()); + +// if (records.empty()) { +// it = ctx.mappingTable.erase(it); +// } else { +// ++it; +// } +// } + +// ctx.queryOrder.erase(std::remove_if(ctx.queryOrder.begin(), ctx.queryOrder.end(), [&ctx](unsigned int queryKey) { +// return ctx.mappingTable.find(queryKey) == ctx.mappingTable.end(); +// }), ctx.queryOrder.end()); + +// for (std::unordered_set::const_iterator it = dropped.begin(); it != dropped.end(); ++it) { +// ctx.targetSet.erase(*it); +// } + +// const double removedPct = (totalTargets > 0) +// ? (100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) +// : 0.0; +// Debug(Debug::INFO) << "Reclassify dropped " << dropped.size() +// << " of " << totalTargets +// << " targets (" << removedPct << "%)" +// << " using abundance <= " << abundanceCutoff << ".\n"; +// } + +// static bool compareByPosteriorThenBitScore(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { +// if (a.posterior != b.posterior) { +// return a.posterior > b.posterior; +// } +// if (a.result.score != b.result.score) { +// return a.result.score > b.result.score; +// } +// return Matcher::compareHits(a.result, b.result); +// } + +// static void writeReclassifiedDb(const ReclassTaxContext &ctx, +// int dbType, +// const std::string &outDb, +// const std::string &outIndex, +// int threads, +// bool compress) { +// DBWriter writer(outDb.c_str(), outIndex.c_str(), threads, compress, dbType); +// writer.open(); + +// Debug::Progress progress(ctx.queryOrder.size()); +// #pragma omp parallel +// { +// unsigned int thread_idx = 0; +// #ifdef OPENMP +// thread_idx = static_cast(omp_get_thread_num()); +// #endif +// char buffer[1024 + 32768 * 4]; + +// #pragma omp for schedule(dynamic, 5) +// for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { +// progress.updateProgress(); +// const unsigned int queryKey = ctx.queryOrder[i]; +// MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); +// if (recordsIt == ctx.mappingTable.end()) { +// continue; +// } + +// std::vector records = recordsIt->second; +// SORT_SERIAL(records.begin(), records.end(), compareByPosteriorThenBitScore); + +// writer.writeStart(thread_idx); +// for (size_t j = 0; j < records.size(); ++j) { +// Matcher::result_t res = records[j].result; +// res.seqId = static_cast(records[j].posterior); +// size_t len = Matcher::resultToBuffer(buffer, res, ctx.hasBacktrace, ctx.hasOrfPosition); +// writer.writeAdd(buffer, len, thread_idx); +// } +// writer.writeEnd(queryKey, thread_idx); +// } +// } + +// writer.close(); +// } +// } + +// int emreclassify(int argc, const char **argv, const Command &command) { +// Parameters &par = Parameters::getInstance(); +// par.parseParameters(argc, argv, command, true, 0, 0); + +// DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, +// DBReader::USE_INDEX | DBReader::USE_DATA); +// reader.open(DBReader::LINEAR_ACCCESS); + +// ReclassTaxContext ctx; +// loadAlignmentDb(reader, ctx); +// Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + +// squarem(ctx, +// par.reclassifyLambda, +// par.reclassifyMaxIterations, +// par.reclassifyTolerance, +// par.reclassifyAlpha, +// par.reclassifyGamma, +// par.threads); + +// std::vector allTargetStats = collectTargetStats(ctx); +// double abundanceCutoff = 0.0; +// const std::unordered_set dropped = selectDroppedTargets(allTargetStats, +// par.reclassifyMaxDropPercentage, +// abundanceCutoff); +// applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); + +// writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed); + +// reader.close(); +// return EXIT_SUCCESS; +// }