From c6139cdbf7398faa905e97c23c5c2ecda31a446e Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Tue, 24 Mar 2026 16:45:49 +0900 Subject: [PATCH 01/12] Initial EM reclassify integration --- src/CommandDeclarations.h | 1 + src/MMseqsBase.cpp | 7 + src/util/CMakeLists.txt | 1 + src/util/reclassify.cpp | 473 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 482 insertions(+) create mode 100644 src/util/reclassify.cpp diff --git a/src/CommandDeclarations.h b/src/CommandDeclarations.h index 42a1e3d3b..01364629c 100644 --- a/src/CommandDeclarations.h +++ b/src/CommandDeclarations.h @@ -101,6 +101,7 @@ extern int gappedprefilter(int argc, const char **argv, const Command& command); extern int unpackdb(int argc, const char **argv, const Command& command); extern int rbh(int argc, const char **argv, const Command& command); extern int recoverlongestorf(int argc, const char **argv, const Command& command); +extern int reclassify(int argc, const char **argv, const Command& command); extern int result2flat(int argc, const char **argv, const Command& command); extern int result2msa(int argc, const char **argv, const Command& command); extern int result2dnamsa(int argc, const char **argv, const Command& command); diff --git a/src/MMseqsBase.cpp b/src/MMseqsBase.cpp index 94c876133..bdaee9ff4 100644 --- a/src/MMseqsBase.cpp +++ b/src/MMseqsBase.cpp @@ -1072,6 +1072,13 @@ std::vector baseCommands = { " ", CITATION_MMSEQS2, {{"resultDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::resultDb }, {"resultDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::resultDb }}}, + {"reclassify", reclassify, &par.sortresult, COMMAND_RESULT, + "Reorder alignment hits per query using EM-based reclassification", + NULL, + "Yeji Kim", + " ", + CITATION_MMSEQS2, {{"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, + {"alignmentDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }}}, {"summarizealis", 
summarizealis, &par.threadsandcompression, COMMAND_RESULT, "Summarize alignment result to one row (uniq. cov., cov., avg. seq. id.)", NULL, diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 4c7e19137..a37ce238c 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -48,6 +48,7 @@ set(util_source_files util/profile2pssm.cpp util/profile2neff.cpp util/profile2seq.cpp + util/reclassify.cpp util/recoverlongestorf.cpp util/result2dnamsa.cpp util/result2flat.cpp diff --git a/src/util/reclassify.cpp b/src/util/reclassify.cpp new file mode 100644 index 000000000..0637fc349 --- /dev/null +++ b/src/util/reclassify.cpp @@ -0,0 +1,473 @@ +#include "Parameters.h" +#include "DBReader.h" +#include "DBWriter.h" +#include "Debug.h" +#include "Util.h" +#include "Matcher.h" +#include "FastSort.h" + +#include +#include +#include +#include +#include +#include + +namespace { +struct ReclassEntry { + Matcher::result_t result; + double abundance; + double posterior; + double entropyPenalty; +}; + +typedef std::unordered_map > MappingTable; + +struct ReclassContext { + MappingTable mappingTable; + std::unordered_set targetSet; + size_t queryCount; + bool hasBacktrace; + bool hasOrfPosition; + + ReclassContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +}; + +static const double DEFAULT_LAMBDA = 0.02; +static const double DEFAULT_ALPHA = 1.0; +static const double DEFAULT_BETA = 1.0; +static const double DEFAULT_GAMMA = 1.0; +static const int DEFAULT_MAX_ITER = 100; +static const double DEFAULT_TOL = 1e-5; +static const double STEP_MIN = -1.0; +static const double STEP_MAX = 1.0; +static const double EPS = 1e-12; + +static std::vector targetListFromSet(const std::unordered_set &targets) { + std::vector out(targets.begin(), targets.end()); + SORT_SERIAL(out.begin(), out.end()); + return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + 
+ for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + records.push_back(ReclassEntry{result, 0.0, 0.0, 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { + std::unordered_map initAbundance; + initAbundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + initAbundance[*it] = 0.0; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += it->second[j].result.score; + } + if (scoreSum <= 0.0) { + continue; + } + for (size_t j = 0; j < it->second.size(); ++j) { + initAbundance[it->second[j].result.dbKey] += it->second[j].result.score / scoreSum; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + 
it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; + } + } +} + +static void initEntropy(MappingTable &mappingTable, const std::unordered_set &targetSet, double lambda) { + std::unordered_map targetMin; + std::unordered_map targetMax; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + } + } + + std::unordered_map > coverage; + coverage.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + const int start = targetMin[*it]; + const int end = targetMax[*it]; + const int len = (end >= start) ? 
(end - start + 1) : 1; + coverage[*it] = std::vector(len, 0.0); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const Matcher::result_t &result = it->second[j].result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = std::exp(lambda * static_cast(result.score)) / static_cast(targetLen); + std::vector &cov = coverage[result.dbKey]; + const int start = std::max(0, result.dbStartPos - targetMin[result.dbKey]); + const int end = std::min(static_cast(cov.size()) - 1, result.dbEndPos - targetMin[result.dbKey]); + for (int pos = start; pos <= end; ++pos) { + cov[pos] += mq; + } + } + } + + std::unordered_map entropy; + entropy.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + const std::vector &cov = coverage[*it]; + const double covSum = std::accumulate(cov.begin(), cov.end(), 0.0); + if (covSum <= 0.0) { + entropy[*it] = 0.0; + continue; + } + + double ent = 0.0; + for (size_t pos = 0; pos < cov.size(); ++pos) { + if (cov[pos] <= 0.0) { + continue; + } + const double p = cov[pos] / covSum; + ent -= p * std::log2(p); + } + entropy[*it] = ent; + } + + double entropySum = 0.0; + for (std::unordered_map::const_iterator it = entropy.begin(); it != entropy.end(); ++it) { + entropySum += it->second; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + it->second[j].entropyPenalty = (entropySum > 0.0) ? 
(1.0 - (entropy[target] / entropySum)) : 0.0; + } + } +} + +static double scoreTerm(const ReclassEntry &entry, double lambda, double alpha, double beta, double gamma) { + if (entry.result.alnLength == 0) { + return 0.0; + } + + const double bitPerLen = static_cast(entry.result.score) / static_cast(entry.result.alnLength); + const double seqId = std::max(static_cast(entry.result.seqId) / 100.0, EPS); + const double abundance = std::max(entry.abundance, EPS); + const double entropyPenalty = std::max(entry.entropyPenalty, EPS); + return std::exp(lambda * bitPerLen) * std::pow(seqId, beta) * std::pow(abundance, alpha) * std::pow(entropyPenalty, gamma); +} + +static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma) { + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + denom += scoreTerm(it->second[j], lambda, alpha, beta, gamma); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); + it->second[j].posterior = (denom > 0.0) ? (value / denom) : 0.0; + } + } +} + +static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma, size_t queryCount) { + if (queryCount == 0) { + return 0.0; + } + + double ll = 0.0; + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); + denom += value; + if (it->second[j].posterior > 0.0) { + ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); + } + } + ll -= std::log(denom > 0.0 ? 
denom : 1e-300); + } + return ll / static_cast(queryCount); +} + +static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] = it->second[j].abundance; + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(abundance[targetList[i]]); + } + return out; +} + +static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = abundanceVector[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static std::vector emUpdate(MappingTable &mappingTable, + double lambda, + const std::unordered_map &fixedEntropy, + const std::vector &targetList, + size_t queryCount, + double alpha, + double beta, + double gamma) { + computePosterior(mappingTable, lambda, alpha, beta, gamma); + + std::unordered_map nextAbundance; + nextAbundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + nextAbundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; + } + } + + if (queryCount > 0) { + const double denom = 
// Maps `x` onto the probability simplex by clipping and renormalising.
//
// Entries that are negative, NaN or infinite are replaced by 0; the remaining
// entries are divided by their sum so the result sums to 1.
// If nothing positive is left (all entries <= 0 or non-finite) the uniform
// distribution 1/n is returned, so a non-empty input always yields a valid
// distribution instead of an all-zero or NaN vector.
// An empty input yields an empty vector.
//
// Note: this is clip-and-rescale, not the Euclidean projection onto the simplex.
static std::vector<double> projectSimplex(const std::vector<double> &x) {
    std::vector<double> projected(x.size(), 0.0);

    double sum = 0.0;
    for (size_t i = 0; i < x.size(); ++i) {
        // `x[i] > 0.0` is false for NaN; isfinite() additionally rejects +inf,
        // which would otherwise turn the whole result into NaN (inf / inf).
        if (x[i] > 0.0 && std::isfinite(x[i])) {
            projected[i] = x[i];
            sum += x[i];
        }
    }

    if (sum > 0.0) {
        for (size_t i = 0; i < projected.size(); ++i) {
            projected[i] /= sum;
        }
    } else if (!projected.empty()) {
        const double uniform = 1.0 / static_cast<double>(projected.size());
        for (size_t i = 0; i < projected.size(); ++i) {
            projected[i] = uniform;
        }
    }
    return projected;
}
maxIter; ++iter) { + const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); + const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); + + std::vector r(x0.size(), 0.0); + std::vector v(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + r[i] = x1[i] - x0[i]; + v[i] = x2[i] - x1[i] - r[i]; + } + + double normR = 0.0; + double normV = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + normR += r[i] * r[i]; + normV += v[i] * v[i]; + } + normR = std::sqrt(normR); + normV = std::sqrt(normV); + + double accel = (normV == 0.0) ? -1.0 : -(normR / normV); + accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + + std::vector xNew(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; + } + xNew = projectSimplex(xNew); + + setAbundance(ctx.mappingTable, targetList, xNew); + computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); + double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); + + if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { + setAbundance(ctx.mappingTable, targetList, x2); + computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); + currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); + xNew = x2; + } + + logLikelihoods.push_back(currentLl); + + double parameterChange = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); + } + + Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; + x0 = xNew; + if (parameterChange < tol && iter > 5) { + Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations.\n"; + break; + } + } +} + +static bool compareByPosterior(const ReclassEntry &a, const 
ReclassEntry &b) { + if (a.posterior != b.posterior) { + return a.posterior > b.posterior; + } + return Matcher::compareHits(a.result, b.result); +} +} + +int reclassify(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db1.c_str(), par.db1Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + ReclassContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + squarem(ctx, DEFAULT_LAMBDA, DEFAULT_MAX_ITER, DEFAULT_TOL, DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_GAMMA); + + DBWriter writer(par.db2.c_str(), par.db2Index.c_str(), par.threads, par.compressed, reader.getDbtype()); + writer.open(); + + Debug::Progress progress(reader.getSize()); + char buffer[1024 + 32768 * 4]; + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + MappingTable::iterator it = ctx.mappingTable.find(queryKey); + if (it == ctx.mappingTable.end() || it->second.empty()) { + writer.writeData("", 0, queryKey, 0); + continue; + } + + SORT_SERIAL(it->second.begin(), it->second.end(), compareByPosterior); + writer.writeStart(0); + for (size_t j = 0; j < it->second.size(); ++j) { + const size_t len = Matcher::resultToBuffer(buffer, it->second[j].result, ctx.hasBacktrace, false, ctx.hasOrfPosition); + writer.writeAdd(buffer, len, 0); + } + writer.writeEnd(queryKey, 0); + } + + writer.close(); + reader.close(); + return EXIT_SUCCESS; +} From fe424a3cb02014beec58df48286a233a2f894da4 Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Wed, 25 Mar 2026 18:06:29 +0900 Subject: [PATCH 02/12] Add reclassify flat-file and taxonomy outputs --- src/CommandDeclarations.h | 1 + src/MMseqsBase.cpp | 17 +- src/commons/Parameters.cpp | 29 ++ 
src/commons/Parameters.h | 19 + src/util/CMakeLists.txt | 1 + src/util/reclassify_taxonomy.cpp | 783 +++++++++++++++++++++++++++++++ 6 files changed, 843 insertions(+), 7 deletions(-) create mode 100644 src/util/reclassify_taxonomy.cpp diff --git a/src/CommandDeclarations.h b/src/CommandDeclarations.h index 01364629c..fa530a95e 100644 --- a/src/CommandDeclarations.h +++ b/src/CommandDeclarations.h @@ -102,6 +102,7 @@ extern int unpackdb(int argc, const char **argv, const Command& command); extern int rbh(int argc, const char **argv, const Command& command); extern int recoverlongestorf(int argc, const char **argv, const Command& command); extern int reclassify(int argc, const char **argv, const Command& command); +extern int reclassifytaxonomy(int argc, const char **argv, const Command& command); extern int result2flat(int argc, const char **argv, const Command& command); extern int result2msa(int argc, const char **argv, const Command& command); extern int result2dnamsa(int argc, const char **argv, const Command& command); diff --git a/src/MMseqsBase.cpp b/src/MMseqsBase.cpp index bdaee9ff4..e8c04ae14 100644 --- a/src/MMseqsBase.cpp +++ b/src/MMseqsBase.cpp @@ -1072,13 +1072,16 @@ std::vector baseCommands = { " ", CITATION_MMSEQS2, {{"resultDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::resultDb }, {"resultDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::resultDb }}}, - {"reclassify", reclassify, &par.sortresult, COMMAND_RESULT, - "Reorder alignment hits per query using EM-based reclassification", - NULL, - "Yeji Kim", - " ", - CITATION_MMSEQS2, {{"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, - {"alignmentDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }}}, + {"reclassify", reclassifytaxonomy, &par.reclassify, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, + "Reclassify alignments and export default flat-file summaries", + "mmseqs reclassify queryDB targetDB 
alignmentDB outDir\n" + "mmseqs reclassify queryDB targetDB alignmentDB outDir --taxonomy 1\n", + "Yaeji Kim", + " ", + CITATION_MMSEQS2 | CITATION_TAXONOMY, {{"queryDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, + {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER|DbType::NEED_TAXONOMY, &DbValidator::taxSequenceDb }, + {"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, + {"outDir", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::directory }}}, {"summarizealis", summarizealis, &par.threadsandcompression, COMMAND_RESULT, "Summarize alignment result to one row (uniq. cov., cov., avg. seq. id.)", NULL, diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index a3669d99a..085d50018 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -308,6 +308,14 @@ Parameters::Parameters(): // unpackdb PARAM_UNPACK_SUFFIX(PARAM_UNPACK_SUFFIX_ID, "--unpack-suffix", "Unpack suffix", "File suffix for unpacked files.\nAdd .gz suffix to write compressed files.", typeid(std::string), (void *) &unpackSuffix, "^.*$"), PARAM_UNPACK_NAME_MODE(PARAM_UNPACK_NAME_MODE_ID, "--unpack-name-mode", "Unpack name mode", "Name unpacked files by 0: DB key, 1: accession (through .lookup)", typeid(int), (void *) &unpackNameMode, "^[0-1]{1}$"), + // reclassify + PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--reclassify-lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification score term", typeid(double), (void *) &reclassifyLambda, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_ALPHA(PARAM_RECLASSIFY_ALPHA_ID, "--reclassify-alpha", "Reclassify alpha", "Exponent applied to abundance during reclassification", typeid(double), (void *) &reclassifyAlpha, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_BETA(PARAM_RECLASSIFY_BETA_ID, 
"--reclassify-beta", "Reclassify beta", "Exponent applied to sequence identity during reclassification", typeid(double), (void *) &reclassifyBeta, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--reclassify-gamma", "Reclassify gamma", "Exponent applied to entropy penalty during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--reclassify-max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), + PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--reclassify-tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Reclassify taxonomy output", "0: write alignment and protein abundance only, 1: also write taxonomy_abundance.tsv", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), // for modules that should handle -h themselves PARAM_HELP(PARAM_HELP_ID, "-h", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN), PARAM_HELP_LONG(PARAM_HELP_LONG_ID, "--help", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN) @@ -335,6 +343,18 @@ Parameters::Parameters(): threadsandcompression.push_back(&PARAM_COMPRESSED); threadsandcompression.push_back(&PARAM_V); + // reclassify + reclassify.push_back(&PARAM_RECLASSIFY_LAMBDA); + reclassify.push_back(&PARAM_RECLASSIFY_ALPHA); + reclassify.push_back(&PARAM_RECLASSIFY_BETA); + reclassify.push_back(&PARAM_RECLASSIFY_GAMMA); + reclassify.push_back(&PARAM_RECLASSIFY_MAX_ITER); + reclassify.push_back(&PARAM_RECLASSIFY_TOL); + 
reclassify.push_back(&PARAM_RECLASSIFY_TAXONOMY); + reclassify.push_back(&PARAM_THREADS); + reclassify.push_back(&PARAM_COMPRESSED); + reclassify.push_back(&PARAM_V); + // createclusearchdb createclusearchdb.push_back(&PARAM_THREADS); createclusearchdb.push_back(&PARAM_COMPRESSED); @@ -2637,6 +2657,15 @@ void Parameters::setDefaults() { unpackSuffix = ""; unpackNameMode = Parameters::UNPACK_NAME_ACCESSION; + // reclassify + reclassifyLambda = 0.02; + reclassifyAlpha = 1.0; + reclassifyBeta = 1.0; + reclassifyGamma = 1.0; + reclassifyMaxIterations = 100; + reclassifyTolerance = 1e-5; + reclassifyTaxonomy = 0; + lcaRanks = ""; showTaxLineage = 0; // bin for all unclassified sequences diff --git a/src/commons/Parameters.h b/src/commons/Parameters.h index caa0e4735..f1231a90f 100644 --- a/src/commons/Parameters.h +++ b/src/commons/Parameters.h @@ -724,6 +724,15 @@ class Parameters { std::string unpackSuffix; int unpackNameMode; + // reclassify + double reclassifyLambda; + double reclassifyAlpha; + double reclassifyBeta; + double reclassifyGamma; + int reclassifyMaxIterations; + double reclassifyTolerance; + int reclassifyTaxonomy; + // for modules that should handle -h themselves bool help; @@ -1081,6 +1090,15 @@ class Parameters { PARAMETER(PARAM_UNPACK_SUFFIX) PARAMETER(PARAM_UNPACK_NAME_MODE) + // reclassify + PARAMETER(PARAM_RECLASSIFY_LAMBDA) + PARAMETER(PARAM_RECLASSIFY_ALPHA) + PARAMETER(PARAM_RECLASSIFY_BETA) + PARAMETER(PARAM_RECLASSIFY_GAMMA) + PARAMETER(PARAM_RECLASSIFY_MAX_ITER) + PARAMETER(PARAM_RECLASSIFY_TOL) + PARAMETER(PARAM_RECLASSIFY_TAXONOMY) + // for modules that should handle -h themselves PARAMETER(PARAM_HELP) PARAMETER(PARAM_HELP_LONG) @@ -1207,6 +1225,7 @@ class Parameters { std::vector touchdb; std::vector gpuserver; std::vector tsv2exprofiledb; + std::vector reclassify; std::vector combineList(const std::vector &par1, const std::vector &par2); diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index a37ce238c..616d97bd2 100644 --- 
// Plain accumulator record keyed by `taxId`; every field starts at zero.
// NOTE(review): the code that fills these fields lies outside this view, so the
// exact meaning of each sum should be confirmed against its writer.
struct TaxonomyStats {
    unsigned int taxId = 0;
    double abundance = 0.0;
    double entropySum = 0.0;
    double entropyPenaltySum = 0.0;
    size_t proteinCount = 0;

    // User-provided on purpose: keeps the type a non-aggregate, so it can only
    // be created through default construction.
    TaxonomyStats() {}
};
return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + if (records.empty()) { + ctx.queryOrder.push_back(queryKey); + } + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0, 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { + std::unordered_map initAbundance; + initAbundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + initAbundance[*it] = 0.0; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += it->second[j].result.score; + } + if (scoreSum <= 0.0) { + continue; + } + for (size_t j = 0; j < it->second.size(); ++j) { + 
initAbundance[it->second[j].result.dbKey] += it->second[j].result.score / scoreSum; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; + } + } +} + +static void initEntropy(MappingTable &mappingTable, const std::unordered_set &targetSet, double lambda) { + std::unordered_map targetMin; + std::unordered_map targetMax; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + } + } + + std::unordered_map > coverage; + coverage.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + const int start = targetMin[*it]; + const int end = targetMax[*it]; + const int len = (end >= start) ? 
(end - start + 1) : 1; + coverage[*it] = std::vector(len, 0.0); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const Matcher::result_t &result = it->second[j].result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = std::exp(lambda * static_cast(result.score)) / static_cast(targetLen); + std::vector &cov = coverage[result.dbKey]; + const int start = std::max(0, result.dbStartPos - targetMin[result.dbKey]); + const int end = std::min(static_cast(cov.size()) - 1, result.dbEndPos - targetMin[result.dbKey]); + for (int pos = start; pos <= end; ++pos) { + cov[pos] += mq; + } + } + } + + std::unordered_map entropy; + entropy.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + const std::vector &cov = coverage[*it]; + const double covSum = std::accumulate(cov.begin(), cov.end(), 0.0); + if (covSum <= 0.0) { + entropy[*it] = 0.0; + continue; + } + + double ent = 0.0; + for (size_t pos = 0; pos < cov.size(); ++pos) { + if (cov[pos] <= 0.0) { + continue; + } + const double p = cov[pos] / covSum; + ent -= p * std::log2(p); + } + entropy[*it] = ent; + } + + double entropySum = 0.0; + for (std::unordered_map::const_iterator it = entropy.begin(); it != entropy.end(); ++it) { + entropySum += it->second; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + it->second[j].entropyValue = entropy[target]; + it->second[j].entropyPenalty = (entropySum > 0.0) ? 
(1.0 - (entropy[target] / entropySum)) : 0.0; + } + } +} + +static double scoreTerm(const ReclassTaxEntry &entry, double lambda, double alpha, double beta, double gamma) { + if (entry.result.alnLength == 0) { + return 0.0; + } + + const double bitPerLen = static_cast(entry.result.score) / static_cast(entry.result.alnLength); + const double seqId = std::max(static_cast(entry.result.seqId) / 100.0, EPS); + const double abundance = std::max(entry.abundance, EPS); + const double entropyPenalty = std::max(entry.entropyPenalty, EPS); + return std::exp(lambda * bitPerLen) * std::pow(seqId, beta) * std::pow(abundance, alpha) * std::pow(entropyPenalty, gamma); +} + +static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma) { + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + denom += scoreTerm(it->second[j], lambda, alpha, beta, gamma); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); + it->second[j].posterior = (denom > 0.0) ? (value / denom) : 0.0; + } + } +} + +static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma, size_t queryCount) { + if (queryCount == 0) { + return 0.0; + } + + double ll = 0.0; + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); + denom += value; + if (it->second[j].posterior > 0.0) { + ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); + } + } + ll -= std::log(denom > 0.0 ? 
denom : 1e-300); + } + return ll / static_cast(queryCount); +} + +static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] = it->second[j].abundance; + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(abundance[targetList[i]]); + } + return out; +} + +static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = abundanceVector[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static std::vector emUpdate(MappingTable &mappingTable, + double lambda, + const std::unordered_map > &fixedEntropy, + const std::vector &targetList, + size_t queryCount, + double alpha, + double beta, + double gamma) { + computePosterior(mappingTable, lambda, alpha, beta, gamma); + + std::unordered_map nextAbundance; + nextAbundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + nextAbundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; + } + } + + if (queryCount > 0) { + const double denom = 
static_cast(queryCount); + for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + it->second[j].abundance = nextAbundance[target]; + std::unordered_map >::const_iterator fixed = fixedEntropy.find(target); + if (fixed != fixedEntropy.end()) { + it->second[j].entropyValue = fixed->second.first; + it->second[j].entropyPenalty = fixed->second.second; + } else { + it->second[j].entropyValue = 0.0; + it->second[j].entropyPenalty = 0.0; + } + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(nextAbundance[targetList[i]]); + } + return out; +} + +static std::vector projectSimplex(const std::vector &x) { + std::vector projected = x; + for (size_t i = 0; i < projected.size(); ++i) { + if (projected[i] < 0.0) { + projected[i] = 0.0; + } + } + + const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); + if (sum > 0.0) { + for (size_t i = 0; i < projected.size(); ++i) { + projected[i] /= sum; + } + } + return projected; +} + +static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double tol, double alpha, double beta, double gamma) { + if (ctx.queryCount == 0 || ctx.targetSet.empty()) { + return; + } + + initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); + initEntropy(ctx.mappingTable, ctx.targetSet, lambda); + + std::unordered_map > fixedEntropy; + fixedEntropy.reserve(ctx.targetSet.size()); + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + fixedEntropy[it->second[j].result.dbKey] = std::make_pair(it->second[j].entropyValue, it->second[j].entropyPenalty); + } + } + + const 
std::vector targetList = targetListFromSet(ctx.targetSet); + std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); + std::vector logLikelihoods; + + for (int iter = 0; iter < maxIter; ++iter) { + const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); + const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); + + std::vector r(x0.size(), 0.0); + std::vector v(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + r[i] = x1[i] - x0[i]; + v[i] = x2[i] - x1[i] - r[i]; + } + + double normR = 0.0; + double normV = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + normR += r[i] * r[i]; + normV += v[i] * v[i]; + } + normR = std::sqrt(normR); + normV = std::sqrt(normV); + + double accel = (normV == 0.0) ? -1.0 : -(normR / normV); + accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + + std::vector xNew(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; + } + xNew = projectSimplex(xNew); + + setAbundance(ctx.mappingTable, targetList, xNew); + computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); + double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); + + if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { + setAbundance(ctx.mappingTable, targetList, x2); + computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); + currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); + xNew = x2; + } + + logLikelihoods.push_back(currentLl); + + double parameterChange = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); + } + + Debug(Debug::INFO) << "Reclassify-taxonomy iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; + x0 = xNew; + if 
(parameterChange < tol && iter > 5) { + Debug(Debug::INFO) << "Reclassify-taxonomy converged after " << (iter + 1) << " iterations.\n"; + break; + } + } +} + +static bool compareByPosterior(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { + if (a.posterior != b.posterior) { + return a.posterior > b.posterior; + } + return Matcher::compareHits(a.result, b.result); +} + +static const char *headerForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { + size_t id = headerReader.getId(key); + if (id == UINT_MAX) { + return NULL; + } + return headerReader.getData(id, threadIdx); +} + +static std::string identifierForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { + const char *header = headerForKey(headerReader, key, threadIdx); + if (header == NULL) { + return SSTR(key); + } + std::string parsed = Util::parseFastaHeader(header); + return parsed.empty() ? SSTR(key) : parsed; +} + +static void computeAlignmentCounts(const Matcher::result_t &res, unsigned int &alnLen, unsigned int &mismatchCount, unsigned int &gapOpenCount) { + gapOpenCount = 0; + alnLen = res.alnLength; + mismatchCount = 0; + + if (!res.backtrace.empty()) { + size_t matchCount = 0; + alnLen = 0; + for (size_t pos = 0; pos < res.backtrace.size(); ++pos) { + int cnt = 0; + if (std::isdigit(static_cast(res.backtrace[pos]))) { + cnt += Util::fast_atoi(res.backtrace.c_str() + pos); + while (std::isdigit(static_cast(res.backtrace[pos]))) { + pos++; + } + } + alnLen += cnt; + + switch (res.backtrace[pos]) { + case 'M': + matchCount += cnt; + break; + case 'D': + case 'I': + gapOpenCount += 1; + break; + } + } + const unsigned int identical = static_cast(res.seqId * static_cast(alnLen) + 0.5f); + mismatchCount = static_cast(matchCount - identical); + } else { + const int adjustQstart = (res.qStartPos == -1) ? 0 : res.qStartPos; + const int adjustDBstart = (res.dbStartPos == -1) ? 
0 : res.dbStartPos; + const float bestMatchEstimate = static_cast(std::min(abs(res.qEndPos - adjustQstart), abs(res.dbEndPos - adjustDBstart))); + mismatchCount = static_cast(bestMatchEstimate * (1.0f - res.seqId) + 0.5f); + } +} + +static void addInterval(std::vector &intervals, int start, int end) { + Interval interval; + interval.start = std::min(start, end); + interval.end = std::max(start, end); + intervals.push_back(interval); +} + +static std::vector mergeIntervals(std::vector intervals) { + if (intervals.empty()) { + return intervals; + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.end < rhs.end; + }); + + std::vector merged; + merged.push_back(intervals[0]); + for (size_t i = 1; i < intervals.size(); ++i) { + if (intervals[i].start <= merged.back().end + 1) { + merged.back().end = std::max(merged.back().end, intervals[i].end); + } else { + merged.push_back(intervals[i]); + } + } + return merged; +} + +static std::string intervalsToString(const std::vector &intervals) { + std::string out; + for (size_t i = 0; i < intervals.size(); ++i) { + if (i > 0) { + out.append(","); + } + out.append(SSTR(intervals[i].start + 1)); + out.append(":"); + out.append(SSTR(intervals[i].end + 1)); + } + return out; +} + +static unsigned int intervalCoverage(const std::vector &intervals) { + unsigned int covered = 0; + for (size_t i = 0; i < intervals.size(); ++i) { + covered += static_cast(intervals[i].end - intervals[i].start + 1); + } + return covered; +} + +static std::vector collectTargetStats(const ReclassTaxContext &ctx) { + std::unordered_map statsByTarget; + + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const ReclassTaxEntry &entry = it->second[j]; + TargetStats &stats = statsByTarget[entry.result.dbKey]; + stats.key = 
entry.result.dbKey; + stats.targetLength = entry.result.dbLen; + stats.abundance = entry.abundance; + stats.entropyValue = entry.entropyValue; + stats.entropyPenalty = entry.entropyPenalty; + addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); + } + } + + std::vector out; + out.reserve(statsByTarget.size()); + for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { + it->second.intervals = mergeIntervals(it->second.intervals); + out.push_back(it->second); + } + + std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.key < rhs.key; + }); + return out; +} + +static void writeReclassifiedM8(const ReclassTaxContext &ctx, + DBReader &queryHeaderReader, + DBReader &targetHeaderReader, + const std::string &path) { + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + char line[4096]; + + for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { + const unsigned int queryKey = ctx.queryOrder[i]; + MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); + if (recordsIt == ctx.mappingTable.end()) { + continue; + } + + std::string queryId = identifierForKey(queryHeaderReader, queryKey, 0); + std::vector records = recordsIt->second; + SORT_SERIAL(records.begin(), records.end(), compareByPosterior); + + for (size_t j = 0; j < records.size(); ++j) { + const Matcher::result_t &res = records[j].result; + const std::string targetId = identifierForKey(targetHeaderReader, res.dbKey, 0); + + unsigned int alnLen = 0; + unsigned int mismatchCount = 0; + unsigned int gapOpenCount = 0; + computeAlignmentCounts(res, alnLen, mismatchCount, gapOpenCount); + + const int written = snprintf(line, sizeof(line), + "%s\t%s\t%1.3f\t%u\t%u\t%u\t%d\t%d\t%d\t%d\t%.2E\t%d\n", + queryId.c_str(), targetId.c_str(), res.seqId, alnLen, + mismatchCount, gapOpenCount, + res.qStartPos + 
1, res.qEndPos + 1, + res.dbStartPos + 1, res.dbEndPos + 1, + res.eval, res.score); + if (written < 0 || static_cast(written) >= sizeof(line)) { + Debug(Debug::WARNING) << "Truncated M8 line for query " << queryKey << " and target " << res.dbKey << ".\n"; + continue; + } + fputs(line, handle); + } + } + + fclose(handle); +} + +static void writeProteinStats(const std::vector &stats, + DBReader &targetHeaderReader, + MappingReader &mapping, + NcbiTaxonomy *taxonomy, + const std::string &path) { + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + fputs("target_key\ttarget_id\tabundance\tentropy\tmapping_parts\tmapped_length\ttarget_length\ttaxid\trank\ttaxname\ttaxlineage\n", handle); + + for (size_t i = 0; i < stats.size(); ++i) { + const unsigned int key = stats[i].key; + const std::string targetId = identifierForKey(targetHeaderReader, key, 0); + const unsigned int taxId = mapping.lookup(key); + const TaxonNode *node = (taxId != 0) ? taxonomy->taxonNode(taxId, false) : NULL; + const std::string lineage = (node != NULL) ? taxonomy->taxLineage(node, true) : "unclassified"; + const std::string parts = intervalsToString(stats[i].intervals); + const unsigned int mappedLength = intervalCoverage(stats[i].intervals); + + fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\t%u\t%s\t%s\t%s\n", + key, + targetId.c_str(), + stats[i].abundance, + stats[i].entropyValue, + parts.c_str(), + mappedLength, + stats[i].targetLength, + taxId, + (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", + (node != NULL) ? 
taxonomy->getString(node->nameIdx) : "unclassified", + lineage.c_str()); + } + + fclose(handle); +} + +static void writeTaxonomyStats(const std::vector &stats, + MappingReader &mapping, + NcbiTaxonomy *taxonomy, + const std::string &path) { + std::unordered_map aggregated; + for (size_t i = 0; i < stats.size(); ++i) { + const unsigned int taxId = mapping.lookup(stats[i].key); + TaxonomyStats &entry = aggregated[taxId]; + entry.taxId = taxId; + entry.abundance += stats[i].abundance; + entry.entropySum += stats[i].entropyValue; + entry.entropyPenaltySum += stats[i].entropyPenalty; + entry.proteinCount += 1; + } + + std::vector rows; + rows.reserve(aggregated.size()); + for (std::unordered_map::const_iterator it = aggregated.begin(); it != aggregated.end(); ++it) { + rows.push_back(it->second); + } + std::sort(rows.begin(), rows.end(), [](const TaxonomyStats &lhs, const TaxonomyStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.taxId < rhs.taxId; + }); + + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance\tprotein_count\tmean_entropy\tmean_entropy_penalty\n", handle); + + for (size_t i = 0; i < rows.size(); ++i) { + const TaxonNode *node = (rows[i].taxId != 0) ? taxonomy->taxonNode(rows[i].taxId, false) : NULL; + const std::string lineage = (node != NULL) ? taxonomy->taxLineage(node, true) : "unclassified"; + const double denom = (rows[i].proteinCount > 0) ? static_cast(rows[i].proteinCount) : 1.0; + fprintf(handle, "%u\t%s\t%s\t%s\t%.12g\t%zu\t%.12g\t%.12g\n", + rows[i].taxId, + (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", + (node != NULL) ? 
taxonomy->getString(node->nameIdx) : "unclassified", + lineage.c_str(), + rows[i].abundance, + rows[i].proteinCount, + rows[i].entropySum / denom, + rows[i].entropyPenaltySum / denom); + } + + fclose(handle); +} +} + +int reclassifytaxonomy(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + DBReader queryHeaderReader((par.db1 + "_h").c_str(), (par.db1 + "_h.index").c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + queryHeaderReader.open(DBReader::NOSORT); + + DBReader targetHeaderReader((par.db2 + "_h").c_str(), (par.db2 + "_h.index").c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + targetHeaderReader.open(DBReader::NOSORT); + + NcbiTaxonomy *taxonomy = NcbiTaxonomy::openTaxonomy(par.db2); + MappingReader mapping(par.db2); + + ReclassTaxContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + squarem(ctx, + par.reclassifyLambda, + par.reclassifyMaxIterations, + par.reclassifyTolerance, + par.reclassifyAlpha, + par.reclassifyBeta, + par.reclassifyGamma); + + const std::string outDir = par.db4; + const std::string m8Path = outDir + "/new_alignment_result.m8"; + const std::string proteinPath = outDir + "/protein_abundance.tsv"; + const std::string taxonomyPath = outDir + "/taxonomy_abundance.tsv"; + + const std::vector targetStats = collectTargetStats(ctx); + + writeReclassifiedM8(ctx, queryHeaderReader, targetHeaderReader, m8Path); + writeProteinStats(targetStats, targetHeaderReader, mapping, taxonomy, proteinPath); + if (par.reclassifyTaxonomy == 1) { + writeTaxonomyStats(targetStats, mapping, taxonomy, taxonomyPath); + } + + delete taxonomy; + 
targetHeaderReader.close(); + queryHeaderReader.close(); + reader.close(); + return EXIT_SUCCESS; +} From bec1f51796c30c69e3b8b97e23ad49888cf58c1c Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Wed, 25 Mar 2026 23:25:59 +0900 Subject: [PATCH 03/12] Update taxonomy reclassification scoring --- EM.md | 329 +++++++++++++++++++++ EM_paper.md | 317 +++++++++++++++++++++ src/CommandDeclarations.h | 1 - src/commons/Parameters.cpp | 3 - src/commons/Parameters.h | 2 - src/util/CMakeLists.txt | 1 - src/util/reclassify.cpp | 473 ------------------------------- src/util/reclassify_taxonomy.cpp | 47 +-- 8 files changed, 673 insertions(+), 500 deletions(-) create mode 100644 EM.md create mode 100644 EM_paper.md delete mode 100644 src/util/reclassify.cpp diff --git a/EM.md b/EM.md new file mode 100644 index 000000000..2648dbdaa --- /dev/null +++ b/EM.md @@ -0,0 +1,329 @@ +# EM Formulation in `reclassify_taxonomy.cpp` + +## Overview + +`reclassify_taxonomy.cpp` reorders MMseqs2 alignment hits by estimating posterior support for each target using: + +- alignment score +- target abundance +- entropy-based coverage penalty + +The method is an EM-like iterative update accelerated with SQUAREM. + +## Notation + +Let: + +- \( q \): a query +- \( H_q \): the set of hits for query \(q\) +- \( h \in H_q \): one hit +- \( t(h) \): the target of hit \(h\) +- \( Q \): total number of queries +- \( s_h \): alignment score of hit \(h\) +- \( A_t \): abundance of target \(t\) +- \( E_t \): entropy penalty of target \(t\) +- \( \lambda, \alpha, \gamma \): model parameters + +## 1. Initial Abundance + +For each query \(q\), compute the sum of scores over all hits: + +\[ +S_q = \sum_{h \in H_q} s_h +\] + +Each hit contributes a normalized score fraction to its target: + +\[ +c_h = \frac{s_h}{S_q} +\] + +Initial abundance of target \(t\) is: + +\[ +A_t^{(0)} = \frac{1}{Q} \sum_q \sum_{h \in H_q,\ t(h)=t} \frac{s_h}{S_q} +\] + +## 2. 
Coverage Weight for Entropy + +For each hit \(h\), define target-covered length: + +\[ +\ell_h = dbEndPos_h - dbStartPos_h + 1 +\] + +The code assigns each hit a coverage weight: + +\[ +m_h = \frac{e^{\lambda s_h}}{\ell_h} +\] + +For each covered target position \(p\), accumulate over the hits of that target: + +\[ +C_t(p) = \sum_{h \text{ covering } p,\ t(h)=t} m_h +\] + +## 3. Position Probability on a Target + +Let total coverage mass on target \(t\) be: + +\[ +Z_t = \sum_p C_t(p) +\] + +Then the normalized positional probability is: + +\[ +p_t(p) = \frac{C_t(p)}{Z_t} +\] + +## 4. Target Entropy + +Coverage entropy for target \(t\) is: + +\[ +H_t = - \sum_p p_t(p)\log_2 p_t(p) +\] + +If \(Z_t = 0\), then: + +\[ +H_t = 0 +\] + +## 5. Entropy Penalty + +Let the total entropy over all targets be: + +\[ +H_{\mathrm{sum}} = \sum_t H_t +\] + +Then the penalty used in scoring is: + +\[ +E_t = +\begin{cases} +1 - \frac{H_t}{H_{\mathrm{sum}}}, & H_{\mathrm{sum}} > 0 \\ +0, & H_{\mathrm{sum}} = 0 +\end{cases} +\] + +## 6. Score Term for One Hit + +For each query \(q\), define: + +\[ +S_{\max,q} = \max_{h \in H_q} s_h +\] + +In the current implementation, this is the maximum observed bit score among the query's candidate hits. + +With + +\[ +\varepsilon = 10^{-12} +\] + +the abundance and entropy penalty are clipped from below: + +\[ +A_h = \max(A_{t(h)}, \varepsilon) +\] + +\[ +E_h = \max(E_{t(h)}, \varepsilon) +\] + +The score term is: + +\[ +f_h = +\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) +\cdot A_h^{\alpha} +\cdot E_h^{\gamma} +\] + +## 7. E-step: Posterior Probability + +For each query \(q\), normalize the score terms over all its hits: + +\[ +Z_q = \sum_{h \in H_q} f_h +\] + +Then the posterior probability of hit \(h\) is: + +\[ +P(h \mid q) = +\begin{cases} +\frac{f_h}{Z_q}, & Z_q > 0 \\ +0, & Z_q = 0 +\end{cases} +\] + +## 8. 
M-step: Abundance Update + +The updated abundance of target \(t\) is the average posterior mass assigned to it: + +\[ +A_t^{\mathrm{new}} = +\frac{1}{Q} +\sum_q +\sum_{h \in H_q,\ t(h)=t} +P(h \mid q) +\] + +This is the EM abundance update used in the code. + +## 9. Log-Likelihood + +For each query \(q\), the code evaluates: + +\[ +\ell_q = +\sum_{h \in H_q} P(h \mid q)\log f_h +- +\log\left(\sum_{h \in H_q} f_h\right) +\] + +The average log-likelihood is: + +\[ +LL = \frac{1}{Q} \sum_q \ell_q +\] + +This value is used to monitor whether the accelerated update is acceptable. + +## 10. Simplex Projection + +After acceleration, abundance values may become negative or fail to sum to 1. + +The code projects back to the simplex by: + +1. clipping negatives: + +\[ +x_i' = \max(x_i, 0) +\] + +2. renormalizing if the total is positive: + +\[ +x_i'' = \frac{x_i'}{\sum_j x_j'} +\] + +## 11. SQUAREM Acceleration + +Let: + +\[ +x_1 = EM(x_0), \qquad x_2 = EM(x_1) +\] + +Define: + +\[ +r = x_1 - x_0 +\] + +\[ +v = x_2 - x_1 - r +\] + +Their Euclidean norms are: + +\[ +\|r\| = \sqrt{\sum_i r_i^2}, \qquad +\|v\| = \sqrt{\sum_i v_i^2} +\] + +The acceleration factor is: + +\[ +a = +\begin{cases} +-1, & \|v\| = 0 \\ +-\frac{\|r\|}{\|v\|}, & \|v\| > 0 +\end{cases} +\] + +Then it is clipped to: + +\[ +a \in [-1, 1] +\] + +The accelerated update is: + +\[ +x_{\mathrm{new}} = x_0 - 2ar + a^2 v +\] + +Finally, `projectSimplex(x_new)` is applied. + +## 12. Likelihood Safeguard + +If the accelerated update lowers the log-likelihood, the code falls back to the second EM step: + +\[ +x_{\mathrm{new}} \leftarrow x_2 +\] + +This keeps the iteration stable. + +## 13. Convergence Criterion + +The stopping criterion is the maximum absolute parameter change: + +\[ +\Delta = \max_i |x_{\mathrm{new}, i} - x_{0,i}| +\] + +Convergence is declared when: + +\[ +\Delta < \mathrm{tol} +\] + +and the zero-based iteration counter is greater than 5, so the earliest possible stop is at the end of the seventh iteration. + +## 14. 
Final Ranking + +After convergence, hits for each query are sorted by posterior probability: + +\[ +h_i \prec h_j \quad \text{if} \quad P(h_i \mid q) > P(h_j \mid q) +\] + +If two hits have equal posterior, MMseqs2's default hit comparison is used as a tie-breaker. + +## Summary + +The model implemented in `reclassify_taxonomy.cpp` is: + +\[ +f_h = +\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) +\cdot +A_{t(h)}^{\alpha} +\cdot +E_{t(h)}^{\gamma} +\] + +with EM updates: + +\[ +P(h \mid q) = \frac{f_h}{\sum_{h' \in H_q} f_{h'}} +\] + +\[ +A_t^{\mathrm{new}} = +\frac{1}{Q} +\sum_q +\sum_{h \in H_q,\ t(h)=t} +P(h \mid q) +\] + +and SQUAREM acceleration applied to the abundance vector. diff --git a/EM_paper.md b/EM_paper.md new file mode 100644 index 000000000..f16156bed --- /dev/null +++ b/EM_paper.md @@ -0,0 +1,317 @@ +# A Paper-Style Description of the EM Reclassification Model + +## Abstract + +The reclassification procedure implemented in `reclassify_taxonomy.cpp` refines target ranking for each query by combining local alignment evidence with global target-level support. The model assigns each hit a latent posterior probability and iteratively updates target abundances using an expectation-maximization framework. To improve convergence speed, the abundance update is accelerated using SQUAREM. + +## Model Setup + +Let \(q \in \{1, \dots, Q\}\) index queries, and let \(H_q\) denote the set of candidate hits for query \(q\). For each hit \(h \in H_q\), define: + +- \(t(h)\): target assigned to hit \(h\) +- \(s_h\): alignment score + +For each target \(t\), the model estimates: + +- abundance \(A_t\) +- entropy-derived penalty \(E_t\) + +The free model parameters are: + +\[ +\lambda, \alpha, \gamma +\] + +which control the relative effect of alignment score, abundance, and entropy penalty. 
+ +## Hit Likelihood Term + +For each query \(q\), define the query-specific normalization term: + +\[ +S_{\max,q} = \max_{h \in H_q} s_h +\] + +In the current implementation, this is the maximum observed bit score among the candidate hits of query \(q\). + +For each hit \(h\), abundance and entropy penalty are clipped from below: + +\[ +\tilde{A}_h = \max(A_{t(h)}, \varepsilon) +\] + +\[ +\tilde{E}_h = \max(E_{t(h)}, \varepsilon) +\] + +with + +\[ +\varepsilon = 10^{-12} +\] + +The unnormalized compatibility score of hit \(h\) is: + +\[ +f_h = +\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) +\cdot \tilde{A}_h^{\alpha} +\cdot \tilde{E}_h^{\gamma} +\] + +This quantity defines the relative support for assigning query \(q\) to target \(t(h)\). + +## Entropy-Based Penalty + +For a hit \(h\), let its target-covered length be: + +\[ +\ell_h = dbEndPos_h - dbStartPos_h + 1 +\] + +The hit contributes position-wise target coverage mass: + +\[ +m_h = \frac{\exp(\lambda s_h)}{\ell_h} +\] + +For target \(t\), the accumulated coverage at target position \(p\) is: + +\[ +C_t(p) = \sum_{h \text{ covering } p,\ t(h)=t} m_h +\] + +The normalized positional distribution is: + +\[ +p_t(p) = \frac{C_t(p)}{\sum_{p'} C_t(p')} +\] + +The target entropy is then: + +\[ +H_t = - \sum_p p_t(p)\log_2 p_t(p) +\] + +Let + +\[ +H_{\mathrm{sum}} = \sum_t H_t +\] + +The penalty assigned to target \(t\) is: + +\[ +E_t = +\begin{cases} +1 - \frac{H_t}{H_{\mathrm{sum}}}, & H_{\mathrm{sum}} > 0 \\ +0, & H_{\mathrm{sum}} = 0 +\end{cases} +\] + +This term modulates target support according to the spatial distribution of its matched regions. + +## Initialization + +The abundance is initialized by normalized query-local alignment support. 
For each query \(q\), define: + +\[ +S_q = \sum_{h \in H_q} s_h +\] + +Then the initial abundance of target \(t\) is: + +\[ +A_t^{(0)} = +\frac{1}{Q} +\sum_q +\sum_{h \in H_q,\ t(h)=t} +\frac{s_h}{S_q} +\] + +This initialization gives each query unit mass, distributed proportionally to hit score. + +## E-step + +Given the current abundances \(A_t\), posterior probabilities are assigned to hits within each query: + +\[ +P(h \mid q) = +\frac{f_h}{\sum_{h' \in H_q} f_{h'}} +\] + +whenever the denominator is positive; otherwise the posterior is set to zero. + +Thus, the E-step computes a soft assignment of each query to its candidate targets. + +## M-step + +The target abundance is updated by averaging posterior mass across all queries: + +\[ +A_t^{\mathrm{new}} = +\frac{1}{Q} +\sum_q +\sum_{h \in H_q,\ t(h)=t} +P(h \mid q) +\] + +The entropy penalty is kept fixed during the EM iteration after its initial computation. + +## Objective Function + +The code evaluates an average query-normalized log-likelihood-like objective: + +\[ +\ell_q = +\sum_{h \in H_q} P(h \mid q)\log f_h +- +\log\left(\sum_{h \in H_q} f_h\right) +\] + +and + +\[ +LL = \frac{1}{Q}\sum_q \ell_q +\] + +This objective is used to reject unstable accelerated updates. + +## SQUAREM Acceleration + +Let \(EM(x)\) denote one EM abundance update applied to abundance vector \(x\). Starting from \(x_0\), compute: + +\[ +x_1 = EM(x_0), \qquad x_2 = EM(x_1) +\] + +Define: + +\[ +r = x_1 - x_0 +\] + +\[ +v = x_2 - x_1 - r +\] + +with Euclidean norms: + +\[ +\|r\| = \sqrt{\sum_i r_i^2}, \qquad +\|v\| = \sqrt{\sum_i v_i^2} +\] + +The acceleration parameter is: + +\[ +a = +\begin{cases} +-1, & \|v\| = 0 \\ +-\frac{\|r\|}{\|v\|}, & \|v\| > 0 +\end{cases} +\] + +and is clipped into the interval: + +\[ +a \in [-1,1] +\] + +The accelerated proposal is: + +\[ +x_{\mathrm{new}} = x_0 - 2ar + a^2 v +\] + +Since \(x_{\mathrm{new}}\) must remain a valid abundance vector, it is projected back onto the simplex. 
+ +## Simplex Projection + +Given a proposed abundance vector \(x\), projection is implemented as: + +1. clip negative entries: + +\[ +x_i' = \max(x_i, 0) +\] + +2. normalize: + +\[ +x_i'' = \frac{x_i'}{\sum_j x_j'} +\] + +if the denominator is positive. + +This ensures abundances are nonnegative and, when the denominator is positive, sum to 1. + +## Safeguard Step + +If the objective at the accelerated proposal falls more than \(10^{-9}\) below the objective recorded in the previous iteration (so the first iteration never falls back), the algorithm falls back to the second ordinary EM iterate: + +\[ +x_{\mathrm{new}} \leftarrow x_2 +\] + +This provides a monotonicity safeguard against unstable extrapolation. + +## Convergence Criterion + +The iteration stops when the maximum coordinate-wise parameter change becomes sufficiently small: + +\[ +\Delta = \max_i |x_{\mathrm{new},i} - x_{0,i}| +\] + +Convergence is declared when: + +\[ +\Delta < \mathrm{tol} +\] + +after an initial burn-in of at least six iterations. + +## Final Ranking + +After convergence, hits are sorted by posterior probability: + +\[ +h_i \prec h_j +\quad \Longleftrightarrow \quad +P(h_i \mid q) > P(h_j \mid q) +\] + +with MMseqs2's default hit comparison used to break ties. + +## Final Model Summary + +The reclassification model can be summarized as: + +\[ +f_h = +\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) +\cdot +\tilde{A}_h^{\alpha} +\cdot +\tilde{E}_h^{\gamma} +\] + +where \(\tilde{A}_h\) and \(\tilde{E}_h\) are the abundance and entropy penalty of target \(t(h)\), clipped from below at \(\varepsilon\), with posterior assignment: + +\[ +P(h \mid q) = +\frac{f_h}{\sum_{h' \in H_q} f_{h'}} +\] + +and abundance update: + +\[ +A_t^{\mathrm{new}} = +\frac{1}{Q} +\sum_q +\sum_{h \in H_q,\ t(h)=t} +P(h \mid q) +\] + +This framework combines local alignment quality with global target prevalence and a coverage-entropy penalty to produce a query-specific posterior ranking of targets.
diff --git a/src/CommandDeclarations.h b/src/CommandDeclarations.h index fa530a95e..6f4d75293 100644 --- a/src/CommandDeclarations.h +++ b/src/CommandDeclarations.h @@ -101,7 +101,6 @@ extern int gappedprefilter(int argc, const char **argv, const Command& command); extern int unpackdb(int argc, const char **argv, const Command& command); extern int rbh(int argc, const char **argv, const Command& command); extern int recoverlongestorf(int argc, const char **argv, const Command& command); -extern int reclassify(int argc, const char **argv, const Command& command); extern int reclassifytaxonomy(int argc, const char **argv, const Command& command); extern int result2flat(int argc, const char **argv, const Command& command); extern int result2msa(int argc, const char **argv, const Command& command); diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index 085d50018..9be8bb260 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -311,7 +311,6 @@ Parameters::Parameters(): // reclassify PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--reclassify-lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification score term", typeid(double), (void *) &reclassifyLambda, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_ALPHA(PARAM_RECLASSIFY_ALPHA_ID, "--reclassify-alpha", "Reclassify alpha", "Exponent applied to abundance during reclassification", typeid(double), (void *) &reclassifyAlpha, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_BETA(PARAM_RECLASSIFY_BETA_ID, "--reclassify-beta", "Reclassify beta", "Exponent applied to sequence identity during reclassification", typeid(double), (void *) &reclassifyBeta, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--reclassify-gamma", "Reclassify gamma", "Exponent applied to entropy penalty during reclassification", 
typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--reclassify-max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--reclassify-tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), @@ -346,7 +345,6 @@ Parameters::Parameters(): // reclassify reclassify.push_back(&PARAM_RECLASSIFY_LAMBDA); reclassify.push_back(&PARAM_RECLASSIFY_ALPHA); - reclassify.push_back(&PARAM_RECLASSIFY_BETA); reclassify.push_back(&PARAM_RECLASSIFY_GAMMA); reclassify.push_back(&PARAM_RECLASSIFY_MAX_ITER); reclassify.push_back(&PARAM_RECLASSIFY_TOL); @@ -2660,7 +2658,6 @@ void Parameters::setDefaults() { // reclassify reclassifyLambda = 0.02; reclassifyAlpha = 1.0; - reclassifyBeta = 1.0; reclassifyGamma = 1.0; reclassifyMaxIterations = 100; reclassifyTolerance = 1e-5; diff --git a/src/commons/Parameters.h b/src/commons/Parameters.h index f1231a90f..6f7a62ddd 100644 --- a/src/commons/Parameters.h +++ b/src/commons/Parameters.h @@ -727,7 +727,6 @@ class Parameters { // reclassify double reclassifyLambda; double reclassifyAlpha; - double reclassifyBeta; double reclassifyGamma; int reclassifyMaxIterations; double reclassifyTolerance; @@ -1093,7 +1092,6 @@ class Parameters { // reclassify PARAMETER(PARAM_RECLASSIFY_LAMBDA) PARAMETER(PARAM_RECLASSIFY_ALPHA) - PARAMETER(PARAM_RECLASSIFY_BETA) PARAMETER(PARAM_RECLASSIFY_GAMMA) PARAMETER(PARAM_RECLASSIFY_MAX_ITER) PARAMETER(PARAM_RECLASSIFY_TOL) diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 616d97bd2..65562c7f4 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -48,7 +48,6 @@ set(util_source_files 
util/profile2pssm.cpp util/profile2neff.cpp util/profile2seq.cpp - util/reclassify.cpp util/reclassify_taxonomy.cpp util/recoverlongestorf.cpp util/result2dnamsa.cpp diff --git a/src/util/reclassify.cpp b/src/util/reclassify.cpp deleted file mode 100644 index 0637fc349..000000000 --- a/src/util/reclassify.cpp +++ /dev/null @@ -1,473 +0,0 @@ -#include "Parameters.h" -#include "DBReader.h" -#include "DBWriter.h" -#include "Debug.h" -#include "Util.h" -#include "Matcher.h" -#include "FastSort.h" - -#include -#include -#include -#include -#include -#include - -namespace { -struct ReclassEntry { - Matcher::result_t result; - double abundance; - double posterior; - double entropyPenalty; -}; - -typedef std::unordered_map > MappingTable; - -struct ReclassContext { - MappingTable mappingTable; - std::unordered_set targetSet; - size_t queryCount; - bool hasBacktrace; - bool hasOrfPosition; - - ReclassContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} -}; - -static const double DEFAULT_LAMBDA = 0.02; -static const double DEFAULT_ALPHA = 1.0; -static const double DEFAULT_BETA = 1.0; -static const double DEFAULT_GAMMA = 1.0; -static const int DEFAULT_MAX_ITER = 100; -static const double DEFAULT_TOL = 1e-5; -static const double STEP_MIN = -1.0; -static const double STEP_MAX = 1.0; -static const double EPS = 1e-12; - -static std::vector targetListFromSet(const std::unordered_set &targets) { - std::vector out(targets.begin(), targets.end()); - SORT_SERIAL(out.begin(), out.end()); - return out; -} - -static void loadAlignmentDb(DBReader &reader, ReclassContext &ctx) { - Debug::Progress progress(reader.getSize()); - const char *entry[255]; - - for (size_t i = 0; i < reader.getSize(); ++i) { - progress.updateProgress(); - const unsigned int queryKey = reader.getDbKey(i); - char *data = reader.getData(i, 0); - - if (reader.getEntryLen(i) <= 1) { - continue; - } - - std::vector &records = ctx.mappingTable[queryKey]; - while (*data != '\0') { - const size_t 
columns = Util::getWordsOfLine(data, entry, 255); - if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { - Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; - EXIT(EXIT_FAILURE); - } - - if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { - ctx.hasBacktrace = true; - } - if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { - ctx.hasOrfPosition = true; - } - - Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); - records.push_back(ReclassEntry{result, 0.0, 0.0, 0.0}); - ctx.targetSet.insert(result.dbKey); - data = Util::skipLine(data); - } - } - - ctx.queryCount = ctx.mappingTable.size(); -} - -static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { - std::unordered_map initAbundance; - initAbundance.reserve(targetSet.size()); - for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { - initAbundance[*it] = 0.0; - } - - for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - double scoreSum = 0.0; - for (size_t j = 0; j < it->second.size(); ++j) { - scoreSum += it->second[j].result.score; - } - if (scoreSum <= 0.0) { - continue; - } - for (size_t j = 0; j < it->second.size(); ++j) { - initAbundance[it->second[j].result.dbKey] += it->second[j].result.score / scoreSum; - } - } - - if (queryCount > 0) { - const double denom = static_cast(queryCount); - for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { - it->second /= denom; - } - } - - for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; - } - } -} - -static void initEntropy(MappingTable &mappingTable, const 
std::unordered_set &targetSet, double lambda) { - std::unordered_map targetMin; - std::unordered_map targetMax; - - for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { - targetMin[*it] = std::numeric_limits::max(); - targetMax[*it] = std::numeric_limits::min(); - } - - for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - const unsigned int target = it->second[j].result.dbKey; - if (it->second[j].result.dbStartPos < targetMin[target]) { - targetMin[target] = it->second[j].result.dbStartPos; - } - if (it->second[j].result.dbEndPos > targetMax[target]) { - targetMax[target] = it->second[j].result.dbEndPos; - } - } - } - - std::unordered_map > coverage; - coverage.reserve(targetSet.size()); - for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { - const int start = targetMin[*it]; - const int end = targetMax[*it]; - const int len = (end >= start) ? 
(end - start + 1) : 1; - coverage[*it] = std::vector(len, 0.0); - } - - for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - const Matcher::result_t &result = it->second[j].result; - const int targetLen = result.dbEndPos - result.dbStartPos + 1; - if (targetLen <= 0) { - continue; - } - - const double mq = std::exp(lambda * static_cast(result.score)) / static_cast(targetLen); - std::vector &cov = coverage[result.dbKey]; - const int start = std::max(0, result.dbStartPos - targetMin[result.dbKey]); - const int end = std::min(static_cast(cov.size()) - 1, result.dbEndPos - targetMin[result.dbKey]); - for (int pos = start; pos <= end; ++pos) { - cov[pos] += mq; - } - } - } - - std::unordered_map entropy; - entropy.reserve(targetSet.size()); - for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { - const std::vector &cov = coverage[*it]; - const double covSum = std::accumulate(cov.begin(), cov.end(), 0.0); - if (covSum <= 0.0) { - entropy[*it] = 0.0; - continue; - } - - double ent = 0.0; - for (size_t pos = 0; pos < cov.size(); ++pos) { - if (cov[pos] <= 0.0) { - continue; - } - const double p = cov[pos] / covSum; - ent -= p * std::log2(p); - } - entropy[*it] = ent; - } - - double entropySum = 0.0; - for (std::unordered_map::const_iterator it = entropy.begin(); it != entropy.end(); ++it) { - entropySum += it->second; - } - - for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - const unsigned int target = it->second[j].result.dbKey; - it->second[j].entropyPenalty = (entropySum > 0.0) ? 
(1.0 - (entropy[target] / entropySum)) : 0.0; - } - } -} - -static double scoreTerm(const ReclassEntry &entry, double lambda, double alpha, double beta, double gamma) { - if (entry.result.alnLength == 0) { - return 0.0; - } - - const double bitPerLen = static_cast(entry.result.score) / static_cast(entry.result.alnLength); - const double seqId = std::max(static_cast(entry.result.seqId) / 100.0, EPS); - const double abundance = std::max(entry.abundance, EPS); - const double entropyPenalty = std::max(entry.entropyPenalty, EPS); - return std::exp(lambda * bitPerLen) * std::pow(seqId, beta) * std::pow(abundance, alpha) * std::pow(entropyPenalty, gamma); -} - -static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma) { - for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - double denom = 0.0; - for (size_t j = 0; j < it->second.size(); ++j) { - denom += scoreTerm(it->second[j], lambda, alpha, beta, gamma); - } - for (size_t j = 0; j < it->second.size(); ++j) { - const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); - it->second[j].posterior = (denom > 0.0) ? (value / denom) : 0.0; - } - } -} - -static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma, size_t queryCount) { - if (queryCount == 0) { - return 0.0; - } - - double ll = 0.0; - for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - double denom = 0.0; - for (size_t j = 0; j < it->second.size(); ++j) { - const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); - denom += value; - if (it->second[j].posterior > 0.0) { - ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); - } - } - ll -= std::log(denom > 0.0 ? 
denom : 1e-300); - } - return ll / static_cast(queryCount); -} - -static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { - std::unordered_map abundance; - abundance.reserve(targetList.size()); - for (size_t i = 0; i < targetList.size(); ++i) { - abundance[targetList[i]] = 0.0; - } - - for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - abundance[it->second[j].result.dbKey] = it->second[j].abundance; - } - } - - std::vector out; - out.reserve(targetList.size()); - for (size_t i = 0; i < targetList.size(); ++i) { - out.push_back(abundance[targetList[i]]); - } - return out; -} - -static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { - std::unordered_map abundance; - abundance.reserve(targetList.size()); - for (size_t i = 0; i < targetList.size(); ++i) { - abundance[targetList[i]] = abundanceVector[i]; - } - - for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - it->second[j].abundance = abundance[it->second[j].result.dbKey]; - } - } -} - -static std::vector emUpdate(MappingTable &mappingTable, - double lambda, - const std::unordered_map &fixedEntropy, - const std::vector &targetList, - size_t queryCount, - double alpha, - double beta, - double gamma) { - computePosterior(mappingTable, lambda, alpha, beta, gamma); - - std::unordered_map nextAbundance; - nextAbundance.reserve(targetList.size()); - for (size_t i = 0; i < targetList.size(); ++i) { - nextAbundance[targetList[i]] = 0.0; - } - - for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; - } - } - - if (queryCount > 0) { - const double denom = 
static_cast(queryCount); - for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { - it->second /= denom; - } - } - - for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - const unsigned int target = it->second[j].result.dbKey; - it->second[j].abundance = nextAbundance[target]; - std::unordered_map::const_iterator fixed = fixedEntropy.find(target); - it->second[j].entropyPenalty = (fixed != fixedEntropy.end()) ? fixed->second : 0.0; - } - } - - std::vector out; - out.reserve(targetList.size()); - for (size_t i = 0; i < targetList.size(); ++i) { - out.push_back(nextAbundance[targetList[i]]); - } - return out; -} - -static std::vector projectSimplex(const std::vector &x) { - std::vector projected = x; - for (size_t i = 0; i < projected.size(); ++i) { - if (projected[i] < 0.0) { - projected[i] = 0.0; - } - } - - const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); - if (sum > 0.0) { - for (size_t i = 0; i < projected.size(); ++i) { - projected[i] /= sum; - } - } - return projected; -} - -static void squarem(ReclassContext &ctx, double lambda, int maxIter, double tol, double alpha, double beta, double gamma) { - if (ctx.queryCount == 0 || ctx.targetSet.empty()) { - return; - } - - initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); - initEntropy(ctx.mappingTable, ctx.targetSet, lambda); - - std::unordered_map fixedEntropy; - fixedEntropy.reserve(ctx.targetSet.size()); - for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - fixedEntropy[it->second[j].result.dbKey] = it->second[j].entropyPenalty; - } - } - - const std::vector targetList = targetListFromSet(ctx.targetSet); - std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); - std::vector logLikelihoods; - - for (int iter = 0; iter < 
maxIter; ++iter) { - const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); - const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); - - std::vector r(x0.size(), 0.0); - std::vector v(x0.size(), 0.0); - for (size_t i = 0; i < x0.size(); ++i) { - r[i] = x1[i] - x0[i]; - v[i] = x2[i] - x1[i] - r[i]; - } - - double normR = 0.0; - double normV = 0.0; - for (size_t i = 0; i < x0.size(); ++i) { - normR += r[i] * r[i]; - normV += v[i] * v[i]; - } - normR = std::sqrt(normR); - normV = std::sqrt(normV); - - double accel = (normV == 0.0) ? -1.0 : -(normR / normV); - accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); - - std::vector xNew(x0.size(), 0.0); - for (size_t i = 0; i < x0.size(); ++i) { - xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; - } - xNew = projectSimplex(xNew); - - setAbundance(ctx.mappingTable, targetList, xNew); - computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); - double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); - - if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { - setAbundance(ctx.mappingTable, targetList, x2); - computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); - currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); - xNew = x2; - } - - logLikelihoods.push_back(currentLl); - - double parameterChange = 0.0; - for (size_t i = 0; i < x0.size(); ++i) { - parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); - } - - Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; - x0 = xNew; - if (parameterChange < tol && iter > 5) { - Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations.\n"; - break; - } - } -} - -static bool compareByPosterior(const ReclassEntry &a, const 
ReclassEntry &b) { - if (a.posterior != b.posterior) { - return a.posterior > b.posterior; - } - return Matcher::compareHits(a.result, b.result); -} -} - -int reclassify(int argc, const char **argv, const Command &command) { - Parameters &par = Parameters::getInstance(); - par.parseParameters(argc, argv, command, true, 0, 0); - - DBReader reader(par.db1.c_str(), par.db1Index.c_str(), par.threads, - DBReader::USE_INDEX | DBReader::USE_DATA); - reader.open(DBReader::LINEAR_ACCCESS); - - ReclassContext ctx; - loadAlignmentDb(reader, ctx); - Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; - - squarem(ctx, DEFAULT_LAMBDA, DEFAULT_MAX_ITER, DEFAULT_TOL, DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_GAMMA); - - DBWriter writer(par.db2.c_str(), par.db2Index.c_str(), par.threads, par.compressed, reader.getDbtype()); - writer.open(); - - Debug::Progress progress(reader.getSize()); - char buffer[1024 + 32768 * 4]; - for (size_t i = 0; i < reader.getSize(); ++i) { - progress.updateProgress(); - const unsigned int queryKey = reader.getDbKey(i); - MappingTable::iterator it = ctx.mappingTable.find(queryKey); - if (it == ctx.mappingTable.end() || it->second.empty()) { - writer.writeData("", 0, queryKey, 0); - continue; - } - - SORT_SERIAL(it->second.begin(), it->second.end(), compareByPosterior); - writer.writeStart(0); - for (size_t j = 0; j < it->second.size(); ++j) { - const size_t len = Matcher::resultToBuffer(buffer, it->second[j].result, ctx.hasBacktrace, false, ctx.hasOrfPosition); - writer.writeAdd(buffer, len, 0); - } - writer.writeEnd(queryKey, 0); - } - - writer.close(); - reader.close(); - return EXIT_SUCCESS; -} diff --git a/src/util/reclassify_taxonomy.cpp b/src/util/reclassify_taxonomy.cpp index 5d71056f0..d6ebc4dae 100644 --- a/src/util/reclassify_taxonomy.cpp +++ b/src/util/reclassify_taxonomy.cpp @@ -232,41 +232,50 @@ static void initEntropy(MappingTable &mappingTable, const std::unordered_set 
&entries) { + double maxScore = 0.0; + for (size_t i = 0; i < entries.size(); ++i) { + maxScore = std::max(maxScore, static_cast(entries[i].result.score)); + } + return maxScore; +} + +static double scoreTerm(const ReclassTaxEntry &entry, double queryMaxScore, double lambda, double alpha, double gamma) { + if (queryMaxScore <= 0.0) { return 0.0; } - const double bitPerLen = static_cast(entry.result.score) / static_cast(entry.result.alnLength); - const double seqId = std::max(static_cast(entry.result.seqId) / 100.0, EPS); + const double normalizedScore = static_cast(entry.result.score) / queryMaxScore; const double abundance = std::max(entry.abundance, EPS); const double entropyPenalty = std::max(entry.entropyPenalty, EPS); - return std::exp(lambda * bitPerLen) * std::pow(seqId, beta) * std::pow(abundance, alpha) * std::pow(entropyPenalty, gamma); + return std::exp(lambda * normalizedScore) * std::pow(abundance, alpha) * std::pow(entropyPenalty, gamma); } -static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma) { +static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double gamma) { for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); double denom = 0.0; for (size_t j = 0; j < it->second.size(); ++j) { - denom += scoreTerm(it->second[j], lambda, alpha, beta, gamma); + denom += scoreTerm(it->second[j], queryMaxScore, lambda, alpha, gamma); } for (size_t j = 0; j < it->second.size(); ++j) { - const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); + const double value = scoreTerm(it->second[j], queryMaxScore, lambda, alpha, gamma); it->second[j].posterior = (denom > 0.0) ? 
(value / denom) : 0.0; } } } -static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double beta, double gamma, size_t queryCount) { +static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double gamma, size_t queryCount) { if (queryCount == 0) { return 0.0; } double ll = 0.0; for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); double denom = 0.0; for (size_t j = 0; j < it->second.size(); ++j) { - const double value = scoreTerm(it->second[j], lambda, alpha, beta, gamma); + const double value = scoreTerm(it->second[j], queryMaxScore, lambda, alpha, gamma); denom += value; if (it->second[j].posterior > 0.0) { ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); @@ -318,9 +327,8 @@ static std::vector emUpdate(MappingTable &mappingTable, const std::vector &targetList, size_t queryCount, double alpha, - double beta, double gamma) { - computePosterior(mappingTable, lambda, alpha, beta, gamma); + computePosterior(mappingTable, lambda, alpha, gamma); std::unordered_map nextAbundance; nextAbundance.reserve(targetList.size()); @@ -381,7 +389,7 @@ static std::vector projectSimplex(const std::vector &x) { return projected; } -static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double tol, double alpha, double beta, double gamma) { +static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double tol, double alpha, double gamma) { if (ctx.queryCount == 0 || ctx.targetSet.empty()) { return; } @@ -402,8 +410,8 @@ static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double t std::vector logLikelihoods; for (int iter = 0; iter < maxIter; ++iter) { - const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); - const std::vector x2 = emUpdate(ctx.mappingTable, lambda, 
fixedEntropy, targetList, ctx.queryCount, alpha, beta, gamma); + const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, gamma); + const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, gamma); std::vector r(x0.size(), 0.0); std::vector v(x0.size(), 0.0); @@ -431,13 +439,13 @@ static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double t xNew = projectSimplex(xNew); setAbundance(ctx.mappingTable, targetList, xNew); - computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); - double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); + computePosterior(ctx.mappingTable, lambda, alpha, gamma); + double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { setAbundance(ctx.mappingTable, targetList, x2); - computePosterior(ctx.mappingTable, lambda, alpha, beta, gamma); - currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, beta, gamma, ctx.queryCount); + computePosterior(ctx.mappingTable, lambda, alpha, gamma); + currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); xNew = x2; } @@ -759,7 +767,6 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { par.reclassifyMaxIterations, par.reclassifyTolerance, par.reclassifyAlpha, - par.reclassifyBeta, par.reclassifyGamma); const std::string outDir = par.db4; From 1d1f3d8380229ab6ec0ad0601a4e0a1d017ffa6c Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Thu, 26 Mar 2026 13:28:50 +0900 Subject: [PATCH 04/12] Rename reclassify CLI parameters --- src/commons/Parameters.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index 9be8bb260..ffa52c439 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ 
-309,11 +309,11 @@ Parameters::Parameters(): PARAM_UNPACK_SUFFIX(PARAM_UNPACK_SUFFIX_ID, "--unpack-suffix", "Unpack suffix", "File suffix for unpacked files.\nAdd .gz suffix to write compressed files.", typeid(std::string), (void *) &unpackSuffix, "^.*$"), PARAM_UNPACK_NAME_MODE(PARAM_UNPACK_NAME_MODE_ID, "--unpack-name-mode", "Unpack name mode", "Name unpacked files by 0: DB key, 1: accession (through .lookup)", typeid(int), (void *) &unpackNameMode, "^[0-1]{1}$"), // reclassify - PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--reclassify-lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification score term", typeid(double), (void *) &reclassifyLambda, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_ALPHA(PARAM_RECLASSIFY_ALPHA_ID, "--reclassify-alpha", "Reclassify alpha", "Exponent applied to abundance during reclassification", typeid(double), (void *) &reclassifyAlpha, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--reclassify-gamma", "Reclassify gamma", "Exponent applied to entropy penalty during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--reclassify-max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), - PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--reclassify-tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification score term", typeid(double), (void *) &reclassifyLambda, 
"^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_ALPHA(PARAM_RECLASSIFY_ALPHA_ID, "--alpha", "Reclassify alpha", "Exponent applied to abundance during reclassification", typeid(double), (void *) &reclassifyAlpha, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--gamma", "Reclassify gamma", "Exponent applied to entropy penalty during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), + PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Reclassify taxonomy output", "0: write alignment and protein abundance only, 1: also write taxonomy_abundance.tsv", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), // for modules that should handle -h themselves PARAM_HELP(PARAM_HELP_ID, "-h", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN), From ac5ef7a5298700693166cfa936fe2b9943039eec Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Fri, 27 Mar 2026 15:09:27 +0900 Subject: [PATCH 05/12] Update reclassify taxonomy target filtering --- src/util/reclassify_taxonomy.cpp | 161 ++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 5 deletions(-) diff --git a/src/util/reclassify_taxonomy.cpp b/src/util/reclassify_taxonomy.cpp index d6ebc4dae..8e2e01bf8 100644 --- a/src/util/reclassify_taxonomy.cpp +++ b/src/util/reclassify_taxonomy.cpp @@ -40,6 +40,7 @@ struct 
TargetStats { double abundance; double entropyValue; double entropyPenalty; + bool dropped; std::vector intervals; }; @@ -591,6 +592,7 @@ static std::vector collectTargetStats(const ReclassTaxContext &ctx) stats.abundance = entry.abundance; stats.entropyValue = entry.entropyValue; stats.entropyPenalty = entry.entropyPenalty; + stats.dropped = false; addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); } } @@ -611,6 +613,140 @@ static std::vector collectTargetStats(const ReclassTaxContext &ctx) return out; } +static bool largestJumpCutoff(std::vector values, double &cutoff) { + cutoff = 0.0; + if (values.size() < 4) { + return false; + } + + std::sort(values.begin(), values.end()); + double bestGap = 0.0; + size_t bestIdx = 0; + for (size_t i = 0; i + 1 < values.size(); ++i) { + const double gap = values[i + 1] - values[i]; + if (gap > bestGap) { + bestGap = gap; + bestIdx = i; + } + } + + if (bestGap <= EPS) { + return false; + } + + cutoff = 0.5 * (values[bestIdx] + values[bestIdx + 1]); + return true; +} + +static std::unordered_set selectDroppedTargets(const std::vector &stats, + double &abundanceCutoff, + double &entropyCutoff) { + std::unordered_set dropped; + if (stats.empty()) { + abundanceCutoff = 0.0; + entropyCutoff = 0.0; + return dropped; + } + if (stats.size() < 4) { + abundanceCutoff = 0.0; + entropyCutoff = 0.0; + return dropped; + } + + std::vector abundances; + std::vector entropies; + abundances.reserve(stats.size()); + entropies.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + abundances.push_back(stats[i].abundance); + entropies.push_back(stats[i].entropyValue); + } + + const bool hasAbundanceCutoff = largestJumpCutoff(abundances, abundanceCutoff); + const bool hasEntropyCutoff = largestJumpCutoff(entropies, entropyCutoff); + if (hasAbundanceCutoff == false || hasEntropyCutoff == false) { + abundanceCutoff = 0.0; + entropyCutoff = 0.0; + return dropped; + } + + for (size_t i = 0; i < stats.size(); 
++i) { + if (stats[i].abundance <= abundanceCutoff && stats[i].entropyValue >= entropyCutoff) { + dropped.insert(stats[i].key); + } + } + if (dropped.size() == stats.size()) { + dropped.clear(); + } + return dropped; +} + +static void applyDroppedTargets(ReclassTaxContext &ctx, + const std::unordered_set &dropped, + size_t totalTargets, + double abundanceCutoff, + double entropyCutoff) { + if (dropped.empty()) { + Debug(Debug::INFO) << "Reclassify-taxonomy target filter kept all targets. abundance cutoff=" + << abundanceCutoff << " entropy cutoff=" << entropyCutoff << "\n"; + return; + } + + for (MappingTable::iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end();) { + std::vector &records = it->second; + records.erase(std::remove_if(records.begin(), records.end(), [&dropped](const ReclassTaxEntry &entry) { + return dropped.find(entry.result.dbKey) != dropped.end(); + }), records.end()); + + if (records.empty()) { + it = ctx.mappingTable.erase(it); + } else { + ++it; + } + } + + ctx.queryOrder.erase(std::remove_if(ctx.queryOrder.begin(), ctx.queryOrder.end(), [&ctx](unsigned int queryKey) { + return ctx.mappingTable.find(queryKey) == ctx.mappingTable.end(); + }), ctx.queryOrder.end()); + + for (std::unordered_set::const_iterator it = dropped.begin(); it != dropped.end(); ++it) { + ctx.targetSet.erase(*it); + } + + const double removedPct = (totalTargets > 0) + ? 
(100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) + : 0.0; + Debug(Debug::INFO) << "Reclassify-taxonomy dropped " << dropped.size() + << " of " << totalTargets + << " targets (" << removedPct << "%)" + << " using abundance <= " << abundanceCutoff + << " and entropy >= " << entropyCutoff << ".\n"; +} + +static void markDroppedTargets(std::vector &stats, const std::unordered_set &dropped) { + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].dropped = (dropped.find(stats[i].key) != dropped.end()); + } +} + +static void convertAbundanceToPercent(std::vector &stats) { + double total = 0.0; + for (size_t i = 0; i < stats.size(); ++i) { + total += stats[i].abundance; + } + + if (total <= 0.0) { + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].abundance = 0.0; + } + return; + } + + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].abundance = 100.0 * (stats[i].abundance / total); + } +} + static void writeReclassifiedM8(const ReclassTaxContext &ctx, DBReader &queryHeaderReader, DBReader &targetHeaderReader, @@ -662,7 +798,7 @@ static void writeProteinStats(const std::vector &stats, NcbiTaxonomy *taxonomy, const std::string &path) { FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - fputs("target_key\ttarget_id\tabundance\tentropy\tmapping_parts\tmapped_length\ttarget_length\ttaxid\trank\ttaxname\ttaxlineage\n", handle); + fputs("target_key\ttarget_id\tabundance_pct\tentropy\tDrop(y/n)\tmapping_parts\tmapped_length\ttarget_length\ttaxid\trank\ttaxname\ttaxlineage\n", handle); for (size_t i = 0; i < stats.size(); ++i) { const unsigned int key = stats[i].key; @@ -673,11 +809,12 @@ static void writeProteinStats(const std::vector &stats, const std::string parts = intervalsToString(stats[i].intervals); const unsigned int mappedLength = intervalCoverage(stats[i].intervals); - fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\t%u\t%s\t%s\t%s\n", + fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%s\t%u\t%u\t%u\t%s\t%s\t%s\n", key, 
targetId.c_str(), stats[i].abundance, stats[i].entropyValue, + stats[i].dropped ? "y" : "n", parts.c_str(), mappedLength, stats[i].targetLength, @@ -718,7 +855,7 @@ static void writeTaxonomyStats(const std::vector &stats, }); FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance\tprotein_count\tmean_entropy\tmean_entropy_penalty\n", handle); + fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance_pct\tprotein_count\tmean_entropy\tmean_entropy_penalty\n", handle); for (size_t i = 0; i < rows.size(); ++i) { const TaxonNode *node = (rows[i].taxId != 0) ? taxonomy->taxonNode(rows[i].taxId, false) : NULL; @@ -774,10 +911,24 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { const std::string proteinPath = outDir + "/protein_abundance.tsv"; const std::string taxonomyPath = outDir + "/taxonomy_abundance.tsv"; - const std::vector targetStats = collectTargetStats(ctx); + std::vector allTargetStats = collectTargetStats(ctx); + double abundanceCutoff = 0.0; + double entropyCutoff = 0.0; + const std::unordered_set dropped = selectDroppedTargets(allTargetStats, + abundanceCutoff, + entropyCutoff); + markDroppedTargets(allTargetStats, dropped); + convertAbundanceToPercent(allTargetStats); + + std::vector targetStats = allTargetStats; + targetStats.erase(std::remove_if(targetStats.begin(), targetStats.end(), [](const TargetStats &entry) { + return entry.dropped; + }), targetStats.end()); + applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff, entropyCutoff); + convertAbundanceToPercent(targetStats); writeReclassifiedM8(ctx, queryHeaderReader, targetHeaderReader, m8Path); - writeProteinStats(targetStats, targetHeaderReader, mapping, taxonomy, proteinPath); + writeProteinStats(allTargetStats, targetHeaderReader, mapping, taxonomy, proteinPath); if (par.reclassifyTaxonomy == 1) { writeTaxonomyStats(targetStats, mapping, taxonomy, taxonomyPath); } From 
c16b63c263fd761543d60994bbacd89a9c7d5505 Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Fri, 27 Mar 2026 17:38:51 +0900 Subject: [PATCH 06/12] Refine reclassify abundance filtering --- src/commons/Parameters.cpp | 3 + src/commons/Parameters.h | 2 + src/util/reclassify_taxonomy.cpp | 132 +++++++++++++++++++++++-------- 3 files changed, 106 insertions(+), 31 deletions(-) diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index ffa52c439..8f51ea5f6 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -315,6 +315,7 @@ Parameters::Parameters(): PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Reclassify taxonomy output", "0: write alignment and protein abundance only, 1: also write taxonomy_abundance.tsv", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), + PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE_ID, "--max-drop-percentage", "Max drop percentage", "Maximum percentage of targets that the automatic jump-based filter may classify as a tail for dropping (range 0.0-100.0)", typeid(double), (void *) &reclassifyMaxDropPercentage, "^100(\\.0+)?$|^([0-9]|[1-9][0-9])(\\.[0-9]+)?$"), // for modules that should handle -h themselves PARAM_HELP(PARAM_HELP_ID, "-h", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN), PARAM_HELP_LONG(PARAM_HELP_LONG_ID, "--help", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN) @@ -349,6 +350,7 @@ Parameters::Parameters(): 
reclassify.push_back(&PARAM_RECLASSIFY_MAX_ITER); reclassify.push_back(&PARAM_RECLASSIFY_TOL); reclassify.push_back(&PARAM_RECLASSIFY_TAXONOMY); + reclassify.push_back(&PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE); reclassify.push_back(&PARAM_THREADS); reclassify.push_back(&PARAM_COMPRESSED); reclassify.push_back(&PARAM_V); @@ -2662,6 +2664,7 @@ void Parameters::setDefaults() { reclassifyMaxIterations = 100; reclassifyTolerance = 1e-5; reclassifyTaxonomy = 0; + reclassifyMaxDropPercentage = 20.0; lcaRanks = ""; showTaxLineage = 0; diff --git a/src/commons/Parameters.h b/src/commons/Parameters.h index 6f7a62ddd..41c67ffd7 100644 --- a/src/commons/Parameters.h +++ b/src/commons/Parameters.h @@ -731,6 +731,7 @@ class Parameters { int reclassifyMaxIterations; double reclassifyTolerance; int reclassifyTaxonomy; + double reclassifyMaxDropPercentage; // for modules that should handle -h themselves bool help; @@ -1096,6 +1097,7 @@ class Parameters { PARAMETER(PARAM_RECLASSIFY_MAX_ITER) PARAMETER(PARAM_RECLASSIFY_TOL) PARAMETER(PARAM_RECLASSIFY_TAXONOMY) + PARAMETER(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE) // for modules that should handle -h themselves PARAMETER(PARAM_HELP) diff --git a/src/util/reclassify_taxonomy.cpp b/src/util/reclassify_taxonomy.cpp index 8e2e01bf8..6e30198c2 100644 --- a/src/util/reclassify_taxonomy.cpp +++ b/src/util/reclassify_taxonomy.cpp @@ -68,6 +68,8 @@ struct ReclassTaxContext { static const double STEP_MIN = -1.0; static const double STEP_MAX = 1.0; static const double EPS = 1e-12; +static const size_t MIN_FILTER_TARGETS = 20; +static const size_t MIN_TAIL_TARGETS = 2; static std::vector targetListFromSet(const std::unordered_set &targets) { std::vector out(targets.begin(), targets.end()); @@ -613,20 +615,40 @@ static std::vector collectTargetStats(const ReclassTaxContext &ctx) return out; } -static bool largestJumpCutoff(std::vector values, double &cutoff) { +static double clamp01(double value) { + return std::max(0.0, std::min(1.0, value)); +} + 
+static bool largestJumpCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { cutoff = 0.0; - if (values.size() < 4) { + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { return false; } std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const size_t maxTailCount = std::max(MIN_TAIL_TARGETS, + static_cast(std::floor(maxTailFraction * static_cast(values.size())))); double bestGap = 0.0; size_t bestIdx = 0; for (size_t i = 0; i + 1 < values.size(); ++i) { + const size_t lowTailCount = i + 1; + const size_t highTailCount = values.size() - lowTailCount; + const size_t candidateTailCount = useLowTail ? lowTailCount : highTailCount; + if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailCount > maxTailCount) { + continue; + } + const double gap = values[i + 1] - values[i]; if (gap > bestGap) { bestGap = gap; bestIdx = i; + tailCount = candidateTailCount; } } @@ -638,42 +660,94 @@ static bool largestJumpCutoff(std::vector values, double &cutoff) { return true; } +static bool tailQuantileCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const size_t maxTailCount = std::max(MIN_TAIL_TARGETS, + static_cast(std::floor(maxTailFraction * static_cast(values.size())))); + if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { + return false; + } + + tailCount = maxTailCount; + if (useLowTail) { + cutoff = values[tailCount - 1]; + } else { + cutoff = values[values.size() - tailCount]; + } + return true; +} + +static std::unordered_set selectTailTargets(const std::vector &stats, + bool useLowTail, + size_t tailCount) { + std::vector ordered; + ordered.reserve(stats.size()); + for (size_t i = 0; i < 
stats.size(); ++i) { + ordered.push_back(&stats[i]); + } + + std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { + const double lhsValue = useLowTail ? lhs->abundance : lhs->entropyValue; + const double rhsValue = useLowTail ? rhs->abundance : rhs->entropyValue; + if (lhsValue != rhsValue) { + return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); + } + return lhs->key < rhs->key; + }); + + std::unordered_set selected; + const size_t limit = std::min(tailCount, ordered.size()); + selected.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + selected.insert(ordered[i]->key); + } + return selected; +} + static std::unordered_set selectDroppedTargets(const std::vector &stats, - double &abundanceCutoff, - double &entropyCutoff) { + double maxDropPercentage, + double &abundanceCutoff) { std::unordered_set dropped; if (stats.empty()) { abundanceCutoff = 0.0; - entropyCutoff = 0.0; return dropped; } - if (stats.size() < 4) { + if (stats.size() < MIN_FILTER_TARGETS) { abundanceCutoff = 0.0; - entropyCutoff = 0.0; return dropped; } std::vector abundances; - std::vector entropies; abundances.reserve(stats.size()); - entropies.reserve(stats.size()); for (size_t i = 0; i < stats.size(); ++i) { abundances.push_back(stats[i].abundance); - entropies.push_back(stats[i].entropyValue); } - const bool hasAbundanceCutoff = largestJumpCutoff(abundances, abundanceCutoff); - const bool hasEntropyCutoff = largestJumpCutoff(entropies, entropyCutoff); - if (hasAbundanceCutoff == false || hasEntropyCutoff == false) { + const double maxTailFraction = clamp01(maxDropPercentage / 100.0); + size_t abundanceTailCount = 0; + bool hasAbundanceCutoff = largestJumpCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + if (hasAbundanceCutoff == false) { + hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + } + if (hasAbundanceCutoff == false) { 
abundanceCutoff = 0.0; - entropyCutoff = 0.0; return dropped; } - for (size_t i = 0; i < stats.size(); ++i) { - if (stats[i].abundance <= abundanceCutoff && stats[i].entropyValue >= entropyCutoff) { - dropped.insert(stats[i].key); - } + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount); + for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { + dropped.insert(*it); } if (dropped.size() == stats.size()) { dropped.clear(); @@ -684,11 +758,10 @@ static std::unordered_set selectDroppedTargets(const std::vector &dropped, size_t totalTargets, - double abundanceCutoff, - double entropyCutoff) { + double abundanceCutoff) { if (dropped.empty()) { Debug(Debug::INFO) << "Reclassify-taxonomy target filter kept all targets. abundance cutoff=" - << abundanceCutoff << " entropy cutoff=" << entropyCutoff << "\n"; + << abundanceCutoff << "\n"; return; } @@ -719,8 +792,7 @@ static void applyDroppedTargets(ReclassTaxContext &ctx, Debug(Debug::INFO) << "Reclassify-taxonomy dropped " << dropped.size() << " of " << totalTargets << " targets (" << removedPct << "%)" - << " using abundance <= " << abundanceCutoff - << " and entropy >= " << entropyCutoff << ".\n"; + << " using abundance <= " << abundanceCutoff << ".\n"; } static void markDroppedTargets(std::vector &stats, const std::unordered_set &dropped) { @@ -855,21 +927,20 @@ static void writeTaxonomyStats(const std::vector &stats, }); FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance_pct\tprotein_count\tmean_entropy\tmean_entropy_penalty\n", handle); + fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance_pct\tprotein_count\tmean_entropy\n", handle); for (size_t i = 0; i < rows.size(); ++i) { const TaxonNode *node = (rows[i].taxId != 0) ? taxonomy->taxonNode(rows[i].taxId, false) : NULL; const std::string lineage = (node != NULL) ? 
taxonomy->taxLineage(node, true) : "unclassified"; const double denom = (rows[i].proteinCount > 0) ? static_cast(rows[i].proteinCount) : 1.0; - fprintf(handle, "%u\t%s\t%s\t%s\t%.12g\t%zu\t%.12g\t%.12g\n", + fprintf(handle, "%u\t%s\t%s\t%s\t%.12g\t%zu\t%.12g\n", rows[i].taxId, (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified", lineage.c_str(), rows[i].abundance, rows[i].proteinCount, - rows[i].entropySum / denom, - rows[i].entropyPenaltySum / denom); + rows[i].entropySum / denom); } fclose(handle); @@ -913,10 +984,9 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { std::vector allTargetStats = collectTargetStats(ctx); double abundanceCutoff = 0.0; - double entropyCutoff = 0.0; const std::unordered_set dropped = selectDroppedTargets(allTargetStats, - abundanceCutoff, - entropyCutoff); + par.reclassifyMaxDropPercentage, + abundanceCutoff); markDroppedTargets(allTargetStats, dropped); convertAbundanceToPercent(allTargetStats); @@ -924,7 +994,7 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { targetStats.erase(std::remove_if(targetStats.begin(), targetStats.end(), [](const TargetStats &entry) { return entry.dropped; }), targetStats.end()); - applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff, entropyCutoff); + applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); convertAbundanceToPercent(targetStats); writeReclassifiedM8(ctx, queryHeaderReader, targetHeaderReader, m8Path); From 2ab1ef7c7b938c6d4a1e6915f519e89081831f51 Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Mon, 6 Apr 2026 16:43:39 +0900 Subject: [PATCH 07/12] 0406 updated EM algorithm --- src/MMseqsBase.cpp | 3 +- src/commons/Parameters.cpp | 10 +- src/util/reclassify_taxonomy.cpp | 253 ++++++++++++++++++------------- 3 files changed, 154 insertions(+), 112 deletions(-) diff --git a/src/MMseqsBase.cpp 
b/src/MMseqsBase.cpp index e8c04ae14..eab17f897 100644 --- a/src/MMseqsBase.cpp +++ b/src/MMseqsBase.cpp @@ -1075,11 +1075,12 @@ std::vector baseCommands = { {"reclassify", reclassifytaxonomy, &par.reclassify, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, "Reclassify alignments and export default flat-file summaries", "mmseqs reclassify queryDB targetDB alignmentDB outDir\n" + "# targetDB_mapping and targetDB_taxonomy are only required with --taxonomy 1\n" "mmseqs reclassify queryDB targetDB alignmentDB outDir --taxonomy 1\n", "Yaeji Kim", " ", CITATION_MMSEQS2 | CITATION_TAXONOMY, {{"queryDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, - {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER|DbType::NEED_TAXONOMY, &DbValidator::taxSequenceDb }, + {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::taxSequenceDb }, {"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, {"outDir", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::directory }}}, {"summarizealis", summarizealis, &par.threadsandcompression, COMMAND_RESULT, diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index 8f51ea5f6..acbccd66a 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -309,13 +309,13 @@ Parameters::Parameters(): PARAM_UNPACK_SUFFIX(PARAM_UNPACK_SUFFIX_ID, "--unpack-suffix", "Unpack suffix", "File suffix for unpacked files.\nAdd .gz suffix to write compressed files.", typeid(std::string), (void *) &unpackSuffix, "^.*$"), PARAM_UNPACK_NAME_MODE(PARAM_UNPACK_NAME_MODE_ID, "--unpack-name-mode", "Unpack name mode", "Name unpacked files by 0: DB key, 1: accession (through .lookup)", typeid(int), (void *) &unpackNameMode, "^[0-1]{1}$"), // reclassify - PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification score term", 
typeid(double), (void *) &reclassifyLambda, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_LAMBDA(PARAM_RECLASSIFY_LAMBDA_ID, "--lambda", "Reclassify lambda", "Lambda scaling factor for the reclassification bit score term", typeid(double), (void *) &reclassifyLambda, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_ALPHA(PARAM_RECLASSIFY_ALPHA_ID, "--alpha", "Reclassify alpha", "Exponent applied to abundance during reclassification", typeid(double), (void *) &reclassifyAlpha, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--gamma", "Reclassify gamma", "Exponent applied to entropy penalty during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), + PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--gamma", "Reclassify gamma", "Exponent applied to coverage confidence during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Reclassify taxonomy output", "0: write alignment and protein abundance only, 1: also write taxonomy_abundance.tsv", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), - PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE_ID, "--max-drop-percentage", "Max drop percentage", 
"Maximum percentage of targets that the automatic jump-based filter may classify as a tail for dropping (range 0.0-100.0)", typeid(double), (void *) &reclassifyMaxDropPercentage, "^100(\\.0+)?$|^([0-9]|[1-9][0-9])(\\.[0-9]+)?$"), + PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Reclassify taxonomy output", "0: write alignment and protein abundance only; taxonomy files are not required. 1: also write taxonomy_abundance.tsv and taxonomic columns in protein_abundance.tsv; requires targetDB_mapping and targetDB_taxonomy", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), + PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE_ID, "--max-drop-percentage", "Max drop percentage", "Maximum percentage of targets that the automatic jump-based filter may classify as a tail for dropping (range 0.0-100.0, default 30.0)", typeid(double), (void *) &reclassifyMaxDropPercentage, "^100(\\.0+)?$|^([0-9]|[1-9][0-9])(\\.[0-9]+)?$"), // for modules that should handle -h themselves PARAM_HELP(PARAM_HELP_ID, "-h", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN), PARAM_HELP_LONG(PARAM_HELP_LONG_ID, "--help", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN) @@ -2664,7 +2664,7 @@ void Parameters::setDefaults() { reclassifyMaxIterations = 100; reclassifyTolerance = 1e-5; reclassifyTaxonomy = 0; - reclassifyMaxDropPercentage = 20.0; + reclassifyMaxDropPercentage = 30.0; lcaRanks = ""; showTaxLineage = 0; diff --git a/src/util/reclassify_taxonomy.cpp b/src/util/reclassify_taxonomy.cpp index 6e30198c2..3e03fc710 100644 --- a/src/util/reclassify_taxonomy.cpp +++ b/src/util/reclassify_taxonomy.cpp @@ -17,14 +17,16 @@ #include #include #include +#ifdef OPENMP +#include +#endif namespace { struct ReclassTaxEntry { Matcher::result_t result; double abundance; double posterior; - double entropyValue; - double entropyPenalty; + double coverageConfidence; }; typedef std::unordered_map 
> MappingTable; @@ -38,8 +40,7 @@ struct TargetStats { unsigned int key; unsigned int targetLength; double abundance; - double entropyValue; - double entropyPenalty; + double coverageConfidence; bool dropped; std::vector intervals; }; @@ -47,11 +48,10 @@ struct TargetStats { struct TaxonomyStats { unsigned int taxId; double abundance; - double entropySum; - double entropyPenaltySum; + double coverageConfidenceSum; size_t proteinCount; - TaxonomyStats() : taxId(0), abundance(0.0), entropySum(0.0), entropyPenaltySum(0.0), proteinCount(0) {} + TaxonomyStats() : taxId(0), abundance(0.0), coverageConfidenceSum(0.0), proteinCount(0) {} }; struct ReclassTaxContext { @@ -71,6 +71,8 @@ static const double EPS = 1e-12; static const size_t MIN_FILTER_TARGETS = 20; static const size_t MIN_TAIL_TARGETS = 2; +static double clamp01(double value); + static std::vector targetListFromSet(const std::unordered_set &targets) { std::vector out(targets.begin(), targets.end()); SORT_SERIAL(out.begin(), out.end()); @@ -109,7 +111,7 @@ static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &c } Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); - records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0, 0.0}); + records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0}); ctx.targetSet.insert(result.dbKey); data = Util::skipLine(data); } @@ -152,16 +154,33 @@ static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, double lambda) { +struct TargetHitRef { + const ReclassTaxEntry *entry; + double expScore; + double weight; +}; + +static void initCoverageConfidence(MappingTable &mappingTable, + const std::unordered_set &targetSet, + double lambda, + int threads) { std::unordered_map targetMin; std::unordered_map targetMax; + std::unordered_map targetLenMap; + std::unordered_map > hitsByTarget; for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { targetMin[*it] = std::numeric_limits::max(); 
targetMax[*it] = std::numeric_limits::min(); + targetLenMap[*it] = 0; + hitsByTarget.emplace(*it, std::vector()); } for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSumExp = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSumExp += std::exp(lambda * static_cast(it->second[j].result.score)); + } for (size_t j = 0; j < it->second.size(); ++j) { const unsigned int target = it->second[j].result.dbKey; if (it->second[j].result.dbStartPos < targetMin[target]) { @@ -170,67 +189,67 @@ static void initEntropy(MappingTable &mappingTable, const std::unordered_setsecond[j].result.dbEndPos > targetMax[target]) { targetMax[target] = it->second[j].result.dbEndPos; } + if (static_cast(it->second[j].result.dbLen) > targetLenMap[target]) { + targetLenMap[target] = static_cast(it->second[j].result.dbLen); + } + const double expScore = std::exp(lambda * static_cast(it->second[j].result.score)); + const double weight = (scoreSumExp > 0.0) ? (expScore / scoreSumExp) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], expScore, weight}); } } - std::unordered_map > coverage; - coverage.reserve(targetSet.size()); - for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { - const int start = targetMin[*it]; - const int end = targetMax[*it]; - const int len = (end >= start) ? 
(end - start + 1) : 1; - coverage[*it] = std::vector(len, 0.0); - } + std::unordered_map coverageFraction; + coverageFraction.reserve(targetSet.size()); + const std::vector targetList = targetListFromSet(targetSet); + std::vector coverageFractionByIndex(targetList.size(), 0.0); - for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - for (size_t j = 0; j < it->second.size(); ++j) { - const Matcher::result_t &result = it->second[j].result; - const int targetLen = result.dbEndPos - result.dbStartPos + 1; - if (targetLen <= 0) { - continue; - } +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) + for (size_t i = 0; i < targetList.size(); ++i) { + const unsigned int target = targetList[i]; + const int startPos = targetMin[target]; + const int endPos = targetMax[target]; + const int len = (endPos >= startPos) ? (endPos - startPos + 1) : 1; + std::vector cov(static_cast(len), 0.0); + std::vector covConf(static_cast(len), 0.0); + + std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); + if (hitIt != hitsByTarget.end()) { + const std::vector &hits = hitIt->second; + for (size_t h = 0; h < hits.size(); ++h) { + const Matcher::result_t &result = hits[h].entry->result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } - const double mq = std::exp(lambda * static_cast(result.score)) / static_cast(targetLen); - std::vector &cov = coverage[result.dbKey]; - const int start = std::max(0, result.dbStartPos - targetMin[result.dbKey]); - const int end = std::min(static_cast(cov.size()) - 1, result.dbEndPos - targetMin[result.dbKey]); - for (int pos = start; pos <= end; ++pos) { - cov[pos] += mq; + const double mq = hits[h].expScore / static_cast(targetLen); + const int start = std::max(0, result.dbStartPos - startPos); + const int end = std::min(len - 1, result.dbEndPos - startPos); + for (int pos = start; pos <= end; ++pos) { + cov[static_cast(pos)] += 
mq; + covConf[static_cast(pos)] += hits[h].weight; + } } } - } - std::unordered_map entropy; - entropy.reserve(targetSet.size()); - for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { - const std::vector &cov = coverage[*it]; - const double covSum = std::accumulate(cov.begin(), cov.end(), 0.0); - if (covSum <= 0.0) { - entropy[*it] = 0.0; - continue; - } - - double ent = 0.0; - for (size_t pos = 0; pos < cov.size(); ++pos) { - if (cov[pos] <= 0.0) { - continue; - } - const double p = cov[pos] / covSum; - ent -= p * std::log2(p); + double covered = 0.0; + for (size_t pos = 0; pos < covConf.size(); ++pos) { + covered += std::min(1.0, covConf[pos]); } - entropy[*it] = ent; + const unsigned int targetLen = (targetLenMap[target] > 0) ? targetLenMap[target] : 1; + const double fraction = covered / static_cast(targetLen); + coverageFractionByIndex[i] = clamp01(fraction); } - double entropySum = 0.0; - for (std::unordered_map::const_iterator it = entropy.begin(); it != entropy.end(); ++it) { - entropySum += it->second; + for (size_t i = 0; i < targetList.size(); ++i) { + coverageFraction[targetList[i]] = coverageFractionByIndex[i]; } for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { for (size_t j = 0; j < it->second.size(); ++j) { const unsigned int target = it->second[j].result.dbKey; - it->second[j].entropyValue = entropy[target]; - it->second[j].entropyPenalty = (entropySum > 0.0) ? (1.0 - (entropy[target] / entropySum)) : 0.0; + std::unordered_map::const_iterator cf = coverageFraction.find(target); + it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? 
cf->second : 0.0; } } } @@ -250,8 +269,8 @@ static double scoreTerm(const ReclassTaxEntry &entry, double queryMaxScore, doub const double normalizedScore = static_cast(entry.result.score) / queryMaxScore; const double abundance = std::max(entry.abundance, EPS); - const double entropyPenalty = std::max(entry.entropyPenalty, EPS); - return std::exp(lambda * normalizedScore) * std::pow(abundance, alpha) * std::pow(entropyPenalty, gamma); + const double coverageConfidence = std::max(entry.coverageConfidence, EPS); + return std::exp(lambda * normalizedScore) * std::pow(abundance, alpha) * std::pow(coverageConfidence, gamma); } static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double gamma) { @@ -326,7 +345,7 @@ static void setAbundance(MappingTable &mappingTable, const std::vector emUpdate(MappingTable &mappingTable, double lambda, - const std::unordered_map > &fixedEntropy, + const std::unordered_map &fixedCoverageConfidence, const std::vector &targetList, size_t queryCount, double alpha, @@ -356,13 +375,11 @@ static std::vector emUpdate(MappingTable &mappingTable, for (size_t j = 0; j < it->second.size(); ++j) { const unsigned int target = it->second[j].result.dbKey; it->second[j].abundance = nextAbundance[target]; - std::unordered_map >::const_iterator fixed = fixedEntropy.find(target); - if (fixed != fixedEntropy.end()) { - it->second[j].entropyValue = fixed->second.first; - it->second[j].entropyPenalty = fixed->second.second; + std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(target); + if (fixed != fixedCoverageConfidence.end()) { + it->second[j].coverageConfidence = fixed->second; } else { - it->second[j].entropyValue = 0.0; - it->second[j].entropyPenalty = 0.0; + it->second[j].coverageConfidence = 0.0; } } } @@ -392,19 +409,26 @@ static std::vector projectSimplex(const std::vector &x) { return projected; } -static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double tol, double alpha, 
double gamma) { +static void squarem(ReclassTaxContext &ctx, + double lambda, + int maxIter, + double tol, + double alpha, + double gamma, + int threads) { if (ctx.queryCount == 0 || ctx.targetSet.empty()) { return; } initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); - initEntropy(ctx.mappingTable, ctx.targetSet, lambda); + initCoverageConfidence(ctx.mappingTable, ctx.targetSet, lambda, threads); + Debug(Debug::INFO) << "Reclassify-taxonomy initialized coverage confidence." << "\n"; - std::unordered_map > fixedEntropy; - fixedEntropy.reserve(ctx.targetSet.size()); + std::unordered_map fixedCoverageConfidence; + fixedCoverageConfidence.reserve(ctx.targetSet.size()); for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { for (size_t j = 0; j < it->second.size(); ++j) { - fixedEntropy[it->second[j].result.dbKey] = std::make_pair(it->second[j].entropyValue, it->second[j].entropyPenalty); + fixedCoverageConfidence[it->second[j].result.dbKey] = it->second[j].coverageConfidence; } } @@ -413,8 +437,8 @@ static void squarem(ReclassTaxContext &ctx, double lambda, int maxIter, double t std::vector logLikelihoods; for (int iter = 0; iter < maxIter; ++iter) { - const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, gamma); - const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedEntropy, targetList, ctx.queryCount, alpha, gamma); + const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); + const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); std::vector r(x0.size(), 0.0); std::vector v(x0.size(), 0.0); @@ -592,8 +616,7 @@ static std::vector collectTargetStats(const ReclassTaxContext &ctx) stats.key = entry.result.dbKey; stats.targetLength = entry.result.dbLen; stats.abundance = entry.abundance; - stats.entropyValue = 
entry.entropyValue; - stats.entropyPenalty = entry.entropyPenalty; + stats.coverageConfidence = entry.coverageConfidence; stats.dropped = false; addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); } @@ -698,8 +721,8 @@ static std::unordered_set selectTailTargets(const std::vectorabundance : lhs->entropyValue; - const double rhsValue = useLowTail ? rhs->abundance : rhs->entropyValue; + const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; + const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; if (lhsValue != rhsValue) { return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); } @@ -866,34 +889,48 @@ static void writeReclassifiedM8(const ReclassTaxContext &ctx, static void writeProteinStats(const std::vector &stats, DBReader &targetHeaderReader, - MappingReader &mapping, + MappingReader *mapping, NcbiTaxonomy *taxonomy, const std::string &path) { FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - fputs("target_key\ttarget_id\tabundance_pct\tentropy\tDrop(y/n)\tmapping_parts\tmapped_length\ttarget_length\ttaxid\trank\ttaxname\ttaxlineage\n", handle); + const bool withTaxonomy = (mapping != NULL && taxonomy != NULL); + if (withTaxonomy) { + fputs("target_key\ttarget_id\tabundance_pct\tcoverage_confidence\tDrop(y/n)\tmapped_length\ttarget_length\ttaxid\trank\ttaxname\ttaxlineage\n", handle); + } else { + fputs("target_key\ttarget_id\tabundance_pct\tcoverage_confidence\tDrop(y/n)\tmapped_length\ttarget_length\n", handle); + } for (size_t i = 0; i < stats.size(); ++i) { const unsigned int key = stats[i].key; const std::string targetId = identifierForKey(targetHeaderReader, key, 0); - const unsigned int taxId = mapping.lookup(key); - const TaxonNode *node = (taxId != 0) ? taxonomy->taxonNode(taxId, false) : NULL; - const std::string lineage = (node != NULL) ? 
taxonomy->taxLineage(node, true) : "unclassified"; - const std::string parts = intervalsToString(stats[i].intervals); const unsigned int mappedLength = intervalCoverage(stats[i].intervals); - fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%s\t%u\t%u\t%u\t%s\t%s\t%s\n", - key, - targetId.c_str(), - stats[i].abundance, - stats[i].entropyValue, - stats[i].dropped ? "y" : "n", - parts.c_str(), - mappedLength, - stats[i].targetLength, - taxId, - (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", - (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified", - lineage.c_str()); + if (withTaxonomy) { + const unsigned int taxId = mapping->lookup(key); + const TaxonNode *node = (taxId != 0) ? taxonomy->taxonNode(taxId, false) : NULL; + const std::string lineage = (node != NULL) ? taxonomy->taxLineage(node, true) : "unclassified"; + fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\t%u\t%s\t%s\t%s\n", + key, + targetId.c_str(), + stats[i].abundance, + stats[i].coverageConfidence, + stats[i].dropped ? "y" : "n", + mappedLength, + stats[i].targetLength, + taxId, + (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", + (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified", + lineage.c_str()); + } else { + fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\n", + key, + targetId.c_str(), + stats[i].abundance, + stats[i].coverageConfidence, + stats[i].dropped ? 
"y" : "n", + mappedLength, + stats[i].targetLength); + } } fclose(handle); @@ -909,8 +946,7 @@ static void writeTaxonomyStats(const std::vector &stats, TaxonomyStats &entry = aggregated[taxId]; entry.taxId = taxId; entry.abundance += stats[i].abundance; - entry.entropySum += stats[i].entropyValue; - entry.entropyPenaltySum += stats[i].entropyPenalty; + entry.coverageConfidenceSum += stats[i].coverageConfidence; entry.proteinCount += 1; } @@ -927,20 +963,18 @@ static void writeTaxonomyStats(const std::vector &stats, }); FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance_pct\tprotein_count\tmean_entropy\n", handle); + fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance_pct\tprotein_count\n", handle); for (size_t i = 0; i < rows.size(); ++i) { const TaxonNode *node = (rows[i].taxId != 0) ? taxonomy->taxonNode(rows[i].taxId, false) : NULL; const std::string lineage = (node != NULL) ? taxonomy->taxLineage(node, true) : "unclassified"; - const double denom = (rows[i].proteinCount > 0) ? static_cast(rows[i].proteinCount) : 1.0; - fprintf(handle, "%u\t%s\t%s\t%s\t%.12g\t%zu\t%.12g\n", + fprintf(handle, "%u\t%s\t%s\t%s\t%.12g\t%zu\n", rows[i].taxId, (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", (node != NULL) ? 
taxonomy->getString(node->nameIdx) : "unclassified", lineage.c_str(), rows[i].abundance, - rows[i].proteinCount, - rows[i].entropySum / denom); + rows[i].proteinCount); } fclose(handle); @@ -963,19 +997,25 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { DBReader::USE_INDEX | DBReader::USE_DATA); targetHeaderReader.open(DBReader::NOSORT); - NcbiTaxonomy *taxonomy = NcbiTaxonomy::openTaxonomy(par.db2); - MappingReader mapping(par.db2); + const bool withTaxonomy = (par.reclassifyTaxonomy == 1); + NcbiTaxonomy *taxonomy = NULL; + MappingReader *mapping = NULL; + if (withTaxonomy) { + taxonomy = NcbiTaxonomy::openTaxonomy(par.db2); + mapping = new MappingReader(par.db2); + } ReclassTaxContext ctx; loadAlignmentDb(reader, ctx); Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; - squarem(ctx, + squarem(ctx, par.reclassifyLambda, par.reclassifyMaxIterations, par.reclassifyTolerance, par.reclassifyAlpha, - par.reclassifyGamma); + par.reclassifyGamma, + par.threads); const std::string outDir = par.db4; const std::string m8Path = outDir + "/new_alignment_result.m8"; @@ -999,10 +1039,11 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { writeReclassifiedM8(ctx, queryHeaderReader, targetHeaderReader, m8Path); writeProteinStats(allTargetStats, targetHeaderReader, mapping, taxonomy, proteinPath); - if (par.reclassifyTaxonomy == 1) { - writeTaxonomyStats(targetStats, mapping, taxonomy, taxonomyPath); + if (withTaxonomy) { + writeTaxonomyStats(targetStats, *mapping, taxonomy, taxonomyPath); } + delete mapping; delete taxonomy; targetHeaderReader.close(); queryHeaderReader.close(); From 9a201006d00b6c0a79669f1fc8057348617b5536 Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Wed, 29 Apr 2026 20:58:16 +0900 Subject: [PATCH 08/12] Refine EM reclassify drop cutoff and sync equations doc --- NEW0423_reclassify.md | 317 ++++++++++++++ 
src/util/EM_reclassify.cpp | 853 +++++++++++++++++++++++++++++++++++++ 2 files changed, 1170 insertions(+) create mode 100644 NEW0423_reclassify.md create mode 100644 src/util/EM_reclassify.cpp diff --git a/NEW0423_reclassify.md b/NEW0423_reclassify.md new file mode 100644 index 000000000..e3b9a47c3 --- /dev/null +++ b/NEW0423_reclassify.md @@ -0,0 +1,317 @@ +# EM_reclassify.cpp Equations (NEW0423, KaTeX) + +본 문서는 `src/util/EM_reclassify.cpp`의 현재 구현 수식을 KaTeX 친화 문법으로 정리한 것이다. + +## 0. Constants +- $\varepsilon = 10^{-12}$ (`EPS`) +- $r_{\min}=-60,\ r_{\max}=60$ (`LOG_COMPATIBILITY_MIN`, `LOG_COMPATIBILITY_MAX`) +- $\tau_{\alpha}=3$ (`ABUNDANCE_EXP_TAU`) +- $\epsilon_\pi=10^{-8}$ (`ABUNDANCE_SMOOTH_EPS`) +- SQUAREM step bound: $[-1,1]$ (`STEP_MIN`, `STEP_MAX`) +- Drop filter activation threshold: $N_{\min}=20$ (`MIN_FILTER_TARGETS`) +- Minimum tail size for cutoff: $k_{\min}=2$ (`MIN_TAIL_TARGETS`) + +--- + +## 1. Abundance Initialization (`initAbundance`) +쿼리 $q$의 hit 집합을 $H(q)$라 하자. + +쿼리 내 score 합: +$$ +S_q = \sum_{t\in H(q)} \max(\mathrm{score}_{q,t}, 0) +$$ + +초기 가중치: +$$ +w_{q,t}^{(0)} = +\begin{cases} +\dfrac{\max(\mathrm{score}_{q,t},0)}{S_q}, & S_q > 0 \\ +0, & S_q \le 0 +\end{cases} +$$ + +타깃별 누적: +$$ +\tilde{\pi}_t^{(0)} = \sum_q w_{q,t}^{(0)} +$$ + +쿼리 수 평균: +$$ +\hat{\pi}_t^{(0)} = +\begin{cases} +\dfrac{\tilde{\pi}_t^{(0)}}{|Q|}, & |Q| > 0 \\ +\tilde{\pi}_t^{(0)}, & |Q| = 0 +\end{cases} +$$ + +simplex 정규화: +$$ +\pi_t^{(0)} = \frac{\hat{\pi}_t^{(0)}}{\sum_{t'} \hat{\pi}_{t'}^{(0)}} +$$ +(분모가 0이 아니면 수행) + +--- + +## 2. 
Coverage Confidence (`initCoverageConfidence`) +타깃 $t$의 관측 span 길이: +$$ +L_t = \max\left(1,\ \mathrm{maxEnd}_t - \mathrm{minStart}_t + 1\right) +$$ + +쿼리 내 hit 정규화 weight: +$$ +\omega_{q,h} = +\begin{cases} +\dfrac{\mathrm{score}_{q,h}}{\sum_{h'\in H(q)} \mathrm{score}_{q,h'}}, & \sum \mathrm{score} > 0 \\ +0, & \text{otherwise} +\end{cases} +$$ + +위치별 누적: +$$ +\mathrm{covConf}_t(p) = \sum_{h:\ p\in[start_h,end_h]} \omega_{q(h),h} +$$ + +클리핑: +$$ +\tilde{c}_t(p) = \min\left(1,\mathrm{covConf}_t(p)\right) +$$ + +coverage fraction: +$$ +f_t = \frac{1}{L_t} \sum_{p=1}^{L_t} \tilde{c}_t(p) +$$ + +HHI: +$$ +\mathrm{HHI}_t = +\begin{cases} +\dfrac{\sum_p \tilde{c}_t(p)^2}{\left(\sum_p \tilde{c}_t(p)\right)^2}, & \sum_p \tilde{c}_t(p) > 0 \\ +1, & \sum_p \tilde{c}_t(p) = 0 +\end{cases} +$$ + +penalty: +$$ +\mathrm{penalty}_t = 1 - \mathrm{HHI}_t +$$ + +최종 confidence: +$$ +c_t = \operatorname{clamp}_{[0,1]}\left(f_t\cdot\mathrm{penalty}_t\right) +$$ + +여기서 +$$ +\operatorname{clamp}_{[0,1]}(x) = \max(0, \min(1, x)) +$$ + +--- + +## 3. Query-Target Compatibility (`compatibilityLogTerm`) +best bit score: +$$ +b_q^{\max} = \max_{t\in H(q)} \mathrm{score}_{q,t} +$$ + +coverage feature: +$$ +\mathrm{cov}_{q,t} = \operatorname{clamp}_{[0,1]}\left(\min(qcov_{q,t}, dbcov_{q,t})\right) +$$ + +bit score 차이: +$$ +\Delta b_{q,t} = \mathrm{score}_{q,t} - b_q^{\max} +$$ + +compatibility log-term: +$$ +r_{q,t} = \beta_{bit}\Delta b_{q,t} + \beta_{id}\,id_{q,t} + \beta_{cov}\,\mathrm{cov}_{q,t} +$$ + +clip: +$$ +r_{q,t} \leftarrow \min\left(r_{\max},\max\left(r_{\min}, r_{q,t}\right)\right) +$$ + +compatibility: +$$ +\phi_{q,t} = \exp(r_{q,t}) +$$ + +기본 호출값: +- $\beta_{id}=1.0$ +- $\beta_{cov}=0.25$ +- $\beta_{bit}=\texttt{par.reclassifyLambda}$ + +--- + +## 4. Annealed Exponent (`squarem`) +iteration $m$ (`iter+1`)에서: +$$ +\alpha^{(m)} = \alpha_{\max}\left(1 - e^{-m/\tau_\alpha}\right) +$$ +- $\alpha_{\max}=\texttt{par.reclassifyAlpha}$ +- $\tau_\alpha=3$ + +--- + +## 5. 
E-step Posterior (`computePosterior`) +수치 안정화 abundance: +$$ +\bar{\pi}_t = \max(\pi_t, \epsilon_\pi) +$$ + +가중치: +$$ +N_{q,t} = \phi_{q,t}\cdot\bar{\pi}_t^{\alpha^{(m)}} +$$ + +posterior: +$$ +p_{q,t} = \frac{N_{q,t}}{\sum_{t'\in H(q)} N_{q,t'}} +$$ + +--- + +## 6. M-step Abundance Update (`emUpdate`) +responsibility 합: +$$ +R_t = \sum_q p_{q,t} +$$ + +coverage prior smoothing: +$$ +\tilde{\pi}_t = R_t + \kappa c_t + \epsilon_\pi +$$ +- $\kappa=\texttt{coveragePriorWeight}$ (코드에서는 `par.reclassifyGamma` 전달) + +정규화: +$$ +\pi_t^{new} = \frac{\tilde{\pi}_t}{\sum_{t'} \tilde{\pi}_{t'}} +$$ + +--- + +## 7. Log-likelihood (`logLikelihood`) +쿼리별 mixture: +$$ +M_q = \sum_{t\in H(q)} \phi_{q,t}\,\bar{\pi}_t^{\alpha^{(m)}} +$$ + +평균 로그우도: +$$ +\mathcal{L} = \frac{1}{|Q|}\sum_q \log\left(\max(M_q, 10^{-300})\right) +$$ + +--- + +## 8. Simplex Projection (`projectSimplex`) +비음수화: +$$ +x_i^+ = \max(x_i,0) +$$ + +합 $S=\sum_i x_i^+$가 양수면: +$$ +\Pi_i = \frac{x_i^+}{S} +$$ + +--- + +## 9. SQUAREM Extrapolation (`squarem`) +두 번의 EM 결과로: +$$ +r = x_1 - x_0,\qquad v = x_2 - x_1 - r +$$ + +노름: +$$ +\|r\|_2 = \sqrt{\sum_i r_i^2},\qquad \|v\|_2 = \sqrt{\sum_i v_i^2} +$$ + +가속계수: +$$ +a = +\begin{cases} +-1, & \|v\|_2 = 0 \\ +-\dfrac{\|r\|_2}{\|v\|_2}, & \text{otherwise} +\end{cases} +$$ + +clip: +$$ +a \leftarrow \min(1,\max(-1,a)) +$$ + +외삽: +$$ +x_{new} = x_0 - 2ar + a^2 v +$$ + +LL 감소 fallback: +$$ +\text{if } \mathcal{L}(x_{new}) < \mathcal{L}_{prev} - 10^{-9},\ \ x_{new} \leftarrow x_2 +$$ + +수렴량: +$$ +\Delta = \max_i |x_{new,i} - x_{0,i}| +$$ + +조건: +$$ +\Delta < \mathrm{tol}\ \text{and}\ \mathrm{iter} > 5 +$$ + +--- + +## 10. 
Target Drop Cutoff (`selectDroppedTargets`) +$$ +\mathrm{maxTailFraction} = \operatorname{clamp}_{[0,1]}\left(\frac{\mathrm{maxDropPercentage}}{100}\right) +$$ + +활성 조건: +$$ +n_{\text{targets}} \ge 20,\quad T=\sum_i a_i > \varepsilon,\quad \mathrm{maxTailFraction}>0 +$$ + +전체 abundance 질량: +$$ +T = \sum_i a_i,\qquad T_{maxTail} = \mathrm{maxTailFraction}\cdot T +$$ + +tail-quantile cutoff: +- low-tail: $\mathrm{cutoff}=a_{(k)}$ +- high-tail: $\mathrm{cutoff}=a_{(n-k+1)}$ +- 여기서 $k$는 누적 질량이 $T_{maxTail}$를 넘지 않는 최대 tail 크기이며, $2 \le k < n$ 필요 + +(현재 reclassify 경로에서는 low-tail abundance drop 사용) + +최종 drop 집합은 abundance 오름차순으로 tail을 다시 선택하며, 선택 질량이 $T_{maxTail}$를 넘기기 직전까지 포함한다(최소 2개는 허용). 전부 drop되는 경우에는 drop을 취소한다. + +--- + +## 11. Aux +제거 비율 로그: +$$ +\mathrm{removedPct} = 100\cdot\frac{|\mathrm{dropped}|}{\mathrm{totalTargets}} +$$ +(분모 0이면 0 처리) + +출력 정렬 우선순위: +$$ +\mathrm{posterior\ desc} \rightarrow \mathrm{bitScore\ desc} +$$ + +출력 시 `seqId`에 posterior 저장: +$$ +\mathrm{seqId}_{out} = p_{q,t} +$$ + +--- + +## 12. Update Notes (Code-aligned) +- `squarem`/`emUpdate`/`logLikelihood`에서 compatibility 가중치 호출값은 $\beta_{cov}=0.25$로 고정되어 있다. +- 타깃 drop은 coverage가 아니라 abundance low-tail 기준으로만 수행된다. +- drop cutoff는 tail-quantile 방식만 사용한다. 
diff --git a/src/util/EM_reclassify.cpp b/src/util/EM_reclassify.cpp new file mode 100644 index 000000000..2d1ab23bb --- /dev/null +++ b/src/util/EM_reclassify.cpp @@ -0,0 +1,853 @@ +#include "Parameters.h" +#include "DBReader.h" +#include "DBWriter.h" +#include "Debug.h" +#include "Util.h" +#include "Matcher.h" +#include "FastSort.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef OPENMP +#include +#endif + +namespace { +struct ReclassTaxEntry { + Matcher::result_t result; + double abundance; + double posterior; + double coverageConfidence; +}; + +typedef std::unordered_map > MappingTable; + +struct Interval { + int start; + int end; +}; + +struct TargetStats { + unsigned int key; + unsigned int targetLength; + double abundance; + double coverageConfidence; + bool dropped; + std::vector intervals; +}; + +struct ReclassTaxContext { + MappingTable mappingTable; + std::vector queryOrder; + std::unordered_set targetSet; + size_t queryCount; + bool hasBacktrace; + bool hasOrfPosition; + + ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +}; + +static const double STEP_MIN = -1.0; +static const double STEP_MAX = 1.0; +static const double EPS = 1e-12; +static const double LOG_COMPATIBILITY_MIN = -60.0; +static const double LOG_COMPATIBILITY_MAX = 60.0; +static const double ABUNDANCE_EXP_TAU = 3.0; +static const double ABUNDANCE_SMOOTH_EPS = 1e-8; +static const size_t MIN_FILTER_TARGETS = 20; +static const size_t MIN_TAIL_TARGETS = 2; + +static double clamp01(double value); + +static std::vector targetListFromSet(const std::unordered_set &targets) { + std::vector out(targets.begin(), targets.end()); + SORT_SERIAL(out.begin(), out.end()); + return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned 
int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + if (records.empty()) { + ctx.queryOrder.push_back(queryKey); + } + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { + std::unordered_map initAbundance; + initAbundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + initAbundance[*it] = 0.0; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += std::max(static_cast(it->second[j].result.score), 0.0); + } + if (scoreSum <= 0.0) { + continue; + } + for (size_t j = 0; j < it->second.size(); ++j) { + const double nonNegativeScore 
= std::max(static_cast(it->second[j].result.score), 0.0); + initAbundance[it->second[j].result.dbKey] += nonNegativeScore / scoreSum; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + it->second /= denom; + } + } + double totalAbundance = 0.0; + for (std::unordered_map::const_iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + totalAbundance += it->second; + } + if (totalAbundance > 0.0) { + for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { + it->second /= totalAbundance; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; + } + } +} + +struct TargetHitRef { + const ReclassTaxEntry *entry; + double score; + double weight; +}; + +static void initCoverageConfidence(MappingTable &mappingTable, + const std::unordered_set &targetSet, + int threads) { + (void)threads; + std::unordered_map targetMin; + std::unordered_map targetMax; + std::unordered_map > hitsByTarget; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + hitsByTarget.emplace(*it, std::vector()); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += static_cast(it->second[j].result.score); + } + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > 
targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + const double score = static_cast(it->second[j].result.score); + const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); + } + } + + std::unordered_map coverageFraction; + coverageFraction.reserve(targetSet.size()); + const std::vector targetList = targetListFromSet(targetSet); + std::vector coverageFractionByIndex(targetList.size(), 0.0); + +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) + for (size_t i = 0; i < targetList.size(); ++i) { + const unsigned int target = targetList[i]; + const int startPos = targetMin[target]; + const int endPos = targetMax[target]; + const int len = (endPos >= startPos) ? (endPos - startPos + 1) : 1; + std::vector cov(static_cast(len), 0.0); + std::vector covConf(static_cast(len), 0.0); + + std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); + if (hitIt != hitsByTarget.end()) { + const std::vector &hits = hitIt->second; + for (size_t h = 0; h < hits.size(); ++h) { + const Matcher::result_t &result = hits[h].entry->result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = hits[h].score / static_cast(targetLen); + const int start = std::max(0, result.dbStartPos - startPos); + const int end = std::min(len - 1, result.dbEndPos - startPos); + for (int pos = start; pos <= end; ++pos) { + cov[static_cast(pos)] += mq; + covConf[static_cast(pos)] += hits[h].weight; + } + } + } + + double covered = 0.0; + double squaredCovered = 0.0; + for (size_t pos = 0; pos < covConf.size(); ++pos) { + const double clipped = std::min(1.0, covConf[pos]); + covered += clipped; + squaredCovered += clipped * clipped; + } + const double fraction = covered / static_cast(len); + const double hhi = (covered > 0.0) ? 
(squaredCovered / (covered * covered)) : 1.0; + const double concentrationPenalty = 1.0 - hhi; + coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); + } + + for (size_t i = 0; i < targetList.size(); ++i) { + coverageFraction[targetList[i]] = coverageFractionByIndex[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + std::unordered_map::const_iterator cf = coverageFraction.find(target); + it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? cf->second : 0.0; + } + } +} + +static double maxQueryBitScore(const std::vector &entries) { + double maxScore = std::numeric_limits::lowest(); + for (size_t i = 0; i < entries.size(); ++i) { + maxScore = std::max(maxScore, static_cast(entries[i].result.score)); + } + return maxScore; +} + +static double hitCoverage(const Matcher::result_t &result) { + return clamp01(std::min(static_cast(result.qcov), static_cast(result.dbcov))); +} + +static double compatibilityLogTerm(const ReclassTaxEntry &entry, + double queryMaxScore, + double betaBit, + double betaSeqId, + double betaCov) { + const double deltaBit = static_cast(entry.result.score) - queryMaxScore; + const double seqId = static_cast(entry.result.seqId); + const double cov = hitCoverage(entry.result); + const double r = (betaBit * deltaBit) + (betaSeqId * seqId) + (betaCov * cov); + return std::max(LOG_COMPATIBILITY_MIN, std::min(LOG_COMPATIBILITY_MAX, r)); +} + +static void computePosterior(MappingTable &mappingTable, + double betaBit, + double betaSeqId, + double betaCov, + double abundanceExponent) { + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); + std::vector numerators(it->second.size(), 0.0); + double denom = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const 
double phi = std::exp(compatibilityLogTerm(it->second[j], queryMaxScore, betaBit, betaSeqId, betaCov)); + const double abundance = std::max(it->second[j].abundance, ABUNDANCE_SMOOTH_EPS); + const double weighted = phi * std::pow(abundance, abundanceExponent); + numerators[j] = weighted; + denom += weighted; + } + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].posterior = (denom > 0.0) ? (numerators[j] / denom) : 0.0; + } + } +} + +static double logLikelihood(const MappingTable &mappingTable, + double betaBit, + double betaSeqId, + double betaCov, + double abundanceExponent, + size_t queryCount) { + if (queryCount == 0) { + return 0.0; + } + + double ll = 0.0; + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + const double queryMaxScore = maxQueryBitScore(it->second); + double mixture = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + const double phi = std::exp(compatibilityLogTerm(it->second[j], queryMaxScore, betaBit, betaSeqId, betaCov)); + const double abundance = std::max(it->second[j].abundance, ABUNDANCE_SMOOTH_EPS); + mixture += phi * std::pow(abundance, abundanceExponent); + } + ll += std::log(mixture > 0.0 ? 
mixture : 1e-300); + } + return ll / static_cast(queryCount); +} + +static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] = it->second[j].abundance; + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(abundance[targetList[i]]); + } + return out; +} + +static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { + std::unordered_map abundance; + abundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + abundance[targetList[i]] = abundanceVector[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static std::vector emUpdate(MappingTable &mappingTable, + double betaBit, + double betaSeqId, + double betaCov, + const std::unordered_map &fixedCoverageConfidence, + const std::vector &targetList, + size_t queryCount, + double abundanceExponent, + double coveragePriorWeight) { + computePosterior(mappingTable, betaBit, betaSeqId, betaCov, abundanceExponent); + + std::unordered_map nextAbundance; + nextAbundance.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + nextAbundance[targetList[i]] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + nextAbundance[it->second[j].result.dbKey] += 
it->second[j].posterior; + } + } + + (void)queryCount; + double denom = 0.0; + for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { + const std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(it->first); + const double confidence = (fixed != fixedCoverageConfidence.end()) ? fixed->second : 0.0; + it->second = it->second + (coveragePriorWeight * confidence) + ABUNDANCE_SMOOTH_EPS; + denom += it->second; + } + if (denom > 0.0) { + for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + it->second[j].abundance = nextAbundance[target]; + std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(target); + if (fixed != fixedCoverageConfidence.end()) { + it->second[j].coverageConfidence = fixed->second; + } else { + it->second[j].coverageConfidence = 0.0; + } + } + } + + std::vector out; + out.reserve(targetList.size()); + for (size_t i = 0; i < targetList.size(); ++i) { + out.push_back(nextAbundance[targetList[i]]); + } + return out; +} + +static std::vector projectSimplex(const std::vector &x) { + std::vector projected = x; + for (size_t i = 0; i < projected.size(); ++i) { + if (projected[i] < 0.0) { + projected[i] = 0.0; + } + } + + const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); + if (sum > 0.0) { + for (size_t i = 0; i < projected.size(); ++i) { + projected[i] /= sum; + } + } + return projected; +} + +static void squarem(ReclassTaxContext &ctx, + double betaBit, + int maxIter, + double tol, + double alphaMax, + double coveragePriorWeight, + int threads) { + if (ctx.queryCount == 0 || ctx.targetSet.empty()) { + return; + } + + initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); + 
initCoverageConfidence(ctx.mappingTable, ctx.targetSet, threads); + Debug(Debug::INFO) << "Reclassify initialized coverage confidence." << "\n"; + + std::unordered_map fixedCoverageConfidence; + fixedCoverageConfidence.reserve(ctx.targetSet.size()); + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + fixedCoverageConfidence[it->second[j].result.dbKey] = it->second[j].coverageConfidence; + } + } + + const std::vector targetList = targetListFromSet(ctx.targetSet); + std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); + std::vector logLikelihoods; + + for (int iter = 0; iter < maxIter; ++iter) { + const double abundanceExponent = alphaMax * (1.0 - std::exp(-static_cast(iter + 1) / ABUNDANCE_EXP_TAU)); + const std::vector x1 = emUpdate(ctx.mappingTable, + betaBit, + 1.0, + 0.25, + fixedCoverageConfidence, + targetList, + ctx.queryCount, + abundanceExponent, + coveragePriorWeight); + const std::vector x2 = emUpdate(ctx.mappingTable, + betaBit, + 1.0, + 0.25, + fixedCoverageConfidence, + targetList, + ctx.queryCount, + abundanceExponent, + coveragePriorWeight); + + std::vector r(x0.size(), 0.0); + std::vector v(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + r[i] = x1[i] - x0[i]; + v[i] = x2[i] - x1[i] - r[i]; + } + + double normR = 0.0; + double normV = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + normR += r[i] * r[i]; + normV += v[i] * v[i]; + } + normR = std::sqrt(normR); + normV = std::sqrt(normV); + + double accel = (normV == 0.0) ? 
-1.0 : -(normR / normV); + accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + + std::vector xNew(x0.size(), 0.0); + for (size_t i = 0; i < x0.size(); ++i) { + xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; + } + xNew = projectSimplex(xNew); + + setAbundance(ctx.mappingTable, targetList, xNew); + computePosterior(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent); + double currentLl = logLikelihood(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent, ctx.queryCount); + + if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { + setAbundance(ctx.mappingTable, targetList, x2); + computePosterior(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent); + currentLl = logLikelihood(ctx.mappingTable, betaBit, 1.0, 0.25, abundanceExponent, ctx.queryCount); + xNew = x2; + } + + logLikelihoods.push_back(currentLl); + + double parameterChange = 0.0; + for (size_t i = 0; i < x0.size(); ++i) { + parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); + } + + Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; + x0 = xNew; + if (parameterChange < tol && iter > 5) { + Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations." 
<< "\n"; + break; + } + } +} + +static void addInterval(std::vector &intervals, int start, int end) { + Interval interval; + interval.start = std::min(start, end); + interval.end = std::max(start, end); + intervals.push_back(interval); +} + +static std::vector mergeIntervals(std::vector intervals) { + if (intervals.empty()) { + return intervals; + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.end < rhs.end; + }); + + std::vector merged; + merged.push_back(intervals[0]); + for (size_t i = 1; i < intervals.size(); ++i) { + if (intervals[i].start <= merged.back().end + 1) { + merged.back().end = std::max(merged.back().end, intervals[i].end); + } else { + merged.push_back(intervals[i]); + } + } + return merged; +} + +static std::vector collectTargetStats(const ReclassTaxContext &ctx) { + std::unordered_map statsByTarget; + + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const ReclassTaxEntry &entry = it->second[j]; + TargetStats &stats = statsByTarget[entry.result.dbKey]; + stats.key = entry.result.dbKey; + stats.targetLength = entry.result.dbLen; + stats.abundance = entry.abundance; + stats.coverageConfidence = entry.coverageConfidence; + stats.dropped = false; + addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); + } + } + + std::vector out; + out.reserve(statsByTarget.size()); + for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { + it->second.intervals = mergeIntervals(it->second.intervals); + out.push_back(it->second); + } + + std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.key < rhs.key; + }); + return out; +} + +static 
double clamp01(double value) { + return std::max(0.0, std::min(1.0, value)); +} + +static bool tailQuantileCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double accumulatedMass = 0.0; + size_t maxTailCount = 0; + for (size_t i = 0; i < values.size(); ++i) { + const double candidate = accumulatedMass + values[i]; + if (candidate > (maxTailMass + EPS)) { + break; + } + accumulatedMass = candidate; + ++maxTailCount; + } + if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { + return false; + } + + tailCount = maxTailCount; + if (useLowTail) { + cutoff = values[tailCount - 1]; + } else { + cutoff = values[values.size() - tailCount]; + } + return true; +} + +static std::unordered_set selectTailTargets(const std::vector &stats, + bool useLowTail, + size_t tailCount, + double maxTailFraction) { + std::vector ordered; + ordered.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + ordered.push_back(&stats[i]); + } + + std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { + const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; + const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; + if (lhsValue != rhsValue) { + return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); + } + return lhs->key < rhs->key; + }); + + double totalMass = 0.0; + for (size_t i = 0; i < ordered.size(); ++i) { + const double value = useLowTail ? 
ordered[i]->abundance : ordered[i]->coverageConfidence; + totalMass += value; + } + const double maxTailMass = clamp01(maxTailFraction) * totalMass; + + std::unordered_set selected; + const size_t limit = std::min(tailCount, ordered.size()); + double selectedMass = 0.0; + selected.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { + break; + } + selectedMass += value; + selected.insert(ordered[i]->key); + } + return selected; +} + +static std::unordered_set selectDroppedTargets(const std::vector &stats, + double maxDropPercentage, + double &abundanceCutoff) { + std::unordered_set dropped; + if (stats.empty()) { + abundanceCutoff = 0.0; + return dropped; + } + if (stats.size() < MIN_FILTER_TARGETS) { + abundanceCutoff = 0.0; + return dropped; + } + + std::vector abundances; + abundances.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + abundances.push_back(stats[i].abundance); + } + + const double maxTailFraction = clamp01(maxDropPercentage / 100.0); + size_t abundanceTailCount = 0; + const bool hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + if (hasAbundanceCutoff == false) { + abundanceCutoff = 0.0; + return dropped; + } + + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); + for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { + dropped.insert(*it); + } + if (dropped.size() == stats.size()) { + dropped.clear(); + } + return dropped; +} + +static void applyDroppedTargets(ReclassTaxContext &ctx, + const std::unordered_set &dropped, + size_t totalTargets, + double abundanceCutoff) { + if (dropped.empty()) { + Debug(Debug::INFO) << "Reclassify target filter kept all 
targets. abundance cutoff=" + << abundanceCutoff << "\n"; + return; + } + + for (MappingTable::iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end();) { + std::vector &records = it->second; + records.erase(std::remove_if(records.begin(), records.end(), [&dropped](const ReclassTaxEntry &entry) { + return dropped.find(entry.result.dbKey) != dropped.end(); + }), records.end()); + + if (records.empty()) { + it = ctx.mappingTable.erase(it); + } else { + ++it; + } + } + + ctx.queryOrder.erase(std::remove_if(ctx.queryOrder.begin(), ctx.queryOrder.end(), [&ctx](unsigned int queryKey) { + return ctx.mappingTable.find(queryKey) == ctx.mappingTable.end(); + }), ctx.queryOrder.end()); + + for (std::unordered_set::const_iterator it = dropped.begin(); it != dropped.end(); ++it) { + ctx.targetSet.erase(*it); + } + + const double removedPct = (totalTargets > 0) + ? (100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) + : 0.0; + Debug(Debug::INFO) << "Reclassify dropped " << dropped.size() + << " of " << totalTargets + << " targets (" << removedPct << "%)" + << " using abundance <= " << abundanceCutoff << ".\n"; +} + +static bool compareByPosteriorThenBitScore(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { + if (a.posterior != b.posterior) { + return a.posterior > b.posterior; + } + if (a.result.score != b.result.score) { + return a.result.score > b.result.score; + } + return Matcher::compareHits(a.result, b.result); +} + +static void writeReclassifiedDb(const ReclassTaxContext &ctx, + int dbType, + const std::string &outDb, + const std::string &outIndex, + int threads, + bool compress) { + DBWriter writer(outDb.c_str(), outIndex.c_str(), threads, compress, dbType); + writer.open(); + + Debug::Progress progress(ctx.queryOrder.size()); +#pragma omp parallel + { + unsigned int thread_idx = 0; +#ifdef OPENMP + thread_idx = static_cast(omp_get_thread_num()); +#endif + char buffer[1024 + 32768 * 4]; + +#pragma omp for schedule(dynamic, 5) + for (size_t 
i = 0; i < ctx.queryOrder.size(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = ctx.queryOrder[i]; + MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); + if (recordsIt == ctx.mappingTable.end()) { + continue; + } + + std::vector records = recordsIt->second; + SORT_SERIAL(records.begin(), records.end(), compareByPosteriorThenBitScore); + + writer.writeStart(thread_idx); + for (size_t j = 0; j < records.size(); ++j) { + Matcher::result_t res = records[j].result; + res.seqId = static_cast(records[j].posterior); + size_t len = Matcher::resultToBuffer(buffer, res, ctx.hasBacktrace, ctx.hasOrfPosition); + writer.writeAdd(buffer, len, thread_idx); + } + writer.writeEnd(queryKey, thread_idx); + } + } + + writer.close(); +} +} + +int emreclassify(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + ReclassTaxContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + squarem(ctx, + par.reclassifyLambda, + par.reclassifyMaxIterations, + par.reclassifyTolerance, + par.reclassifyAlpha, + par.reclassifyGamma, + par.threads); + + std::vector allTargetStats = collectTargetStats(ctx); + double abundanceCutoff = 0.0; + const std::unordered_set dropped = selectDroppedTargets(allTargetStats, + par.reclassifyMaxDropPercentage, + abundanceCutoff); + applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); + + writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed); + + reader.close(); + return EXIT_SUCCESS; +} From 864c06bb6b3432c92f76e42a75ed49a7415605ed Mon Sep 17 00:00:00 2001 From: yeahedge01 Date: 
Wed, 29 Apr 2026 20:59:52 +0900 Subject: [PATCH 09/12] Delete NEW0423_reclassify.md --- NEW0423_reclassify.md | 317 ------------------------------------------ 1 file changed, 317 deletions(-) delete mode 100644 NEW0423_reclassify.md diff --git a/NEW0423_reclassify.md b/NEW0423_reclassify.md deleted file mode 100644 index e3b9a47c3..000000000 --- a/NEW0423_reclassify.md +++ /dev/null @@ -1,317 +0,0 @@ -# EM_reclassify.cpp Equations (NEW0423, KaTeX) - -본 문서는 `src/util/EM_reclassify.cpp`의 현재 구현 수식을 KaTeX 친화 문법으로 정리한 것이다. - -## 0. Constants -- $\varepsilon = 10^{-12}$ (`EPS`) -- $r_{\min}=-60,\ r_{\max}=60$ (`LOG_COMPATIBILITY_MIN`, `LOG_COMPATIBILITY_MAX`) -- $\tau_{\alpha}=3$ (`ABUNDANCE_EXP_TAU`) -- $\epsilon_\pi=10^{-8}$ (`ABUNDANCE_SMOOTH_EPS`) -- SQUAREM step bound: $[-1,1]$ (`STEP_MIN`, `STEP_MAX`) -- Drop filter activation threshold: $N_{\min}=20$ (`MIN_FILTER_TARGETS`) -- Minimum tail size for cutoff: $k_{\min}=2$ (`MIN_TAIL_TARGETS`) - ---- - -## 1. Abundance Initialization (`initAbundance`) -쿼리 $q$의 hit 집합을 $H(q)$라 하자. - -쿼리 내 score 합: -$$ -S_q = \sum_{t\in H(q)} \max(\mathrm{score}_{q,t}, 0) -$$ - -초기 가중치: -$$ -w_{q,t}^{(0)} = -\begin{cases} -\dfrac{\max(\mathrm{score}_{q,t},0)}{S_q}, & S_q > 0 \\ -0, & S_q \le 0 -\end{cases} -$$ - -타깃별 누적: -$$ -\tilde{\pi}_t^{(0)} = \sum_q w_{q,t}^{(0)} -$$ - -쿼리 수 평균: -$$ -\hat{\pi}_t^{(0)} = -\begin{cases} -\dfrac{\tilde{\pi}_t^{(0)}}{|Q|}, & |Q| > 0 \\ -\tilde{\pi}_t^{(0)}, & |Q| = 0 -\end{cases} -$$ - -simplex 정규화: -$$ -\pi_t^{(0)} = \frac{\hat{\pi}_t^{(0)}}{\sum_{t'} \hat{\pi}_{t'}^{(0)}} -$$ -(분모가 0이 아니면 수행) - ---- - -## 2. 
Coverage Confidence (`initCoverageConfidence`) -타깃 $t$의 관측 span 길이: -$$ -L_t = \max\left(1,\ \mathrm{maxEnd}_t - \mathrm{minStart}_t + 1\right) -$$ - -쿼리 내 hit 정규화 weight: -$$ -\omega_{q,h} = -\begin{cases} -\dfrac{\mathrm{score}_{q,h}}{\sum_{h'\in H(q)} \mathrm{score}_{q,h'}}, & \sum \mathrm{score} > 0 \\ -0, & \text{otherwise} -\end{cases} -$$ - -위치별 누적: -$$ -\mathrm{covConf}_t(p) = \sum_{h:\ p\in[start_h,end_h]} \omega_{q(h),h} -$$ - -클리핑: -$$ -\tilde{c}_t(p) = \min\left(1,\mathrm{covConf}_t(p)\right) -$$ - -coverage fraction: -$$ -f_t = \frac{1}{L_t} \sum_{p=1}^{L_t} \tilde{c}_t(p) -$$ - -HHI: -$$ -\mathrm{HHI}_t = -\begin{cases} -\dfrac{\sum_p \tilde{c}_t(p)^2}{\left(\sum_p \tilde{c}_t(p)\right)^2}, & \sum_p \tilde{c}_t(p) > 0 \\ -1, & \sum_p \tilde{c}_t(p) = 0 -\end{cases} -$$ - -penalty: -$$ -\mathrm{penalty}_t = 1 - \mathrm{HHI}_t -$$ - -최종 confidence: -$$ -c_t = \operatorname{clamp}_{[0,1]}\left(f_t\cdot\mathrm{penalty}_t\right) -$$ - -여기서 -$$ -\operatorname{clamp}_{[0,1]}(x) = \max(0, \min(1, x)) -$$ - ---- - -## 3. Query-Target Compatibility (`compatibilityLogTerm`) -best bit score: -$$ -b_q^{\max} = \max_{t\in H(q)} \mathrm{score}_{q,t} -$$ - -coverage feature: -$$ -\mathrm{cov}_{q,t} = \operatorname{clamp}_{[0,1]}\left(\min(qcov_{q,t}, dbcov_{q,t})\right) -$$ - -bit score 차이: -$$ -\Delta b_{q,t} = \mathrm{score}_{q,t} - b_q^{\max} -$$ - -compatibility log-term: -$$ -r_{q,t} = \beta_{bit}\Delta b_{q,t} + \beta_{id}\,id_{q,t} + \beta_{cov}\,\mathrm{cov}_{q,t} -$$ - -clip: -$$ -r_{q,t} \leftarrow \min\left(r_{\max},\max\left(r_{\min}, r_{q,t}\right)\right) -$$ - -compatibility: -$$ -\phi_{q,t} = \exp(r_{q,t}) -$$ - -기본 호출값: -- $\beta_{id}=1.0$ -- $\beta_{cov}=0.25$ -- $\beta_{bit}=\texttt{par.reclassifyLambda}$ - ---- - -## 4. Annealed Exponent (`squarem`) -iteration $m$ (`iter+1`)에서: -$$ -\alpha^{(m)} = \alpha_{\max}\left(1 - e^{-m/\tau_\alpha}\right) -$$ -- $\alpha_{\max}=\texttt{par.reclassifyAlpha}$ -- $\tau_\alpha=3$ - ---- - -## 5. 
E-step Posterior (`computePosterior`) -수치 안정화 abundance: -$$ -\bar{\pi}_t = \max(\pi_t, \epsilon_\pi) -$$ - -가중치: -$$ -N_{q,t} = \phi_{q,t}\cdot\bar{\pi}_t^{\alpha^{(m)}} -$$ - -posterior: -$$ -p_{q,t} = \frac{N_{q,t}}{\sum_{t'\in H(q)} N_{q,t'}} -$$ - ---- - -## 6. M-step Abundance Update (`emUpdate`) -responsibility 합: -$$ -R_t = \sum_q p_{q,t} -$$ - -coverage prior smoothing: -$$ -\tilde{\pi}_t = R_t + \kappa c_t + \epsilon_\pi -$$ -- $\kappa=\texttt{coveragePriorWeight}$ (코드에서는 `par.reclassifyGamma` 전달) - -정규화: -$$ -\pi_t^{new} = \frac{\tilde{\pi}_t}{\sum_{t'} \tilde{\pi}_{t'}} -$$ - ---- - -## 7. Log-likelihood (`logLikelihood`) -쿼리별 mixture: -$$ -M_q = \sum_{t\in H(q)} \phi_{q,t}\,\bar{\pi}_t^{\alpha^{(m)}} -$$ - -평균 로그우도: -$$ -\mathcal{L} = \frac{1}{|Q|}\sum_q \log\left(\max(M_q, 10^{-300})\right) -$$ - ---- - -## 8. Simplex Projection (`projectSimplex`) -비음수화: -$$ -x_i^+ = \max(x_i,0) -$$ - -합 $S=\sum_i x_i^+$가 양수면: -$$ -\Pi_i = \frac{x_i^+}{S} -$$ - ---- - -## 9. SQUAREM Extrapolation (`squarem`) -두 번의 EM 결과로: -$$ -r = x_1 - x_0,\qquad v = x_2 - x_1 - r -$$ - -노름: -$$ -\|r\|_2 = \sqrt{\sum_i r_i^2},\qquad \|v\|_2 = \sqrt{\sum_i v_i^2} -$$ - -가속계수: -$$ -a = -\begin{cases} --1, & \|v\|_2 = 0 \\ --\dfrac{\|r\|_2}{\|v\|_2}, & \text{otherwise} -\end{cases} -$$ - -clip: -$$ -a \leftarrow \min(1,\max(-1,a)) -$$ - -외삽: -$$ -x_{new} = x_0 - 2ar + a^2 v -$$ - -LL 감소 fallback: -$$ -\text{if } \mathcal{L}(x_{new}) < \mathcal{L}_{prev} - 10^{-9},\ \ x_{new} \leftarrow x_2 -$$ - -수렴량: -$$ -\Delta = \max_i |x_{new,i} - x_{0,i}| -$$ - -조건: -$$ -\Delta < \mathrm{tol}\ \text{and}\ \mathrm{iter} > 5 -$$ - ---- - -## 10. 
Target Drop Cutoff (`selectDroppedTargets`) -$$ -\mathrm{maxTailFraction} = \operatorname{clamp}_{[0,1]}\left(\frac{\mathrm{maxDropPercentage}}{100}\right) -$$ - -활성 조건: -$$ -n_{\text{targets}} \ge 20,\quad T=\sum_i a_i > \varepsilon,\quad \mathrm{maxTailFraction}>0 -$$ - -전체 abundance 질량: -$$ -T = \sum_i a_i,\qquad T_{maxTail} = \mathrm{maxTailFraction}\cdot T -$$ - -tail-quantile cutoff: -- low-tail: $\mathrm{cutoff}=a_{(k)}$ -- high-tail: $\mathrm{cutoff}=a_{(n-k+1)}$ -- 여기서 $k$는 누적 질량이 $T_{maxTail}$를 넘지 않는 최대 tail 크기이며, $2 \le k < n$ 필요 - -(현재 reclassify 경로에서는 low-tail abundance drop 사용) - -최종 drop 집합은 abundance 오름차순으로 tail을 다시 선택하며, 선택 질량이 $T_{maxTail}$를 넘기기 직전까지 포함한다(최소 2개는 허용). 전부 drop되는 경우에는 drop을 취소한다. - ---- - -## 11. Aux -제거 비율 로그: -$$ -\mathrm{removedPct} = 100\cdot\frac{|\mathrm{dropped}|}{\mathrm{totalTargets}} -$$ -(분모 0이면 0 처리) - -출력 정렬 우선순위: -$$ -\mathrm{posterior\ desc} \rightarrow \mathrm{bitScore\ desc} -$$ - -출력 시 `seqId`에 posterior 저장: -$$ -\mathrm{seqId}_{out} = p_{q,t} -$$ - ---- - -## 12. Update Notes (Code-aligned) -- `squarem`/`emUpdate`/`logLikelihood`에서 compatibility 가중치 호출값은 $\beta_{cov}=0.25$로 고정되어 있다. -- 타깃 drop은 coverage가 아니라 abundance low-tail 기준으로만 수행된다. -- drop cutoff는 tail-quantile 방식만 사용한다. 
From 75d64120484c19d7752db921f08d1e385cc52713 Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Thu, 14 May 2026 22:32:58 +0900 Subject: [PATCH 10/12] EM_0514 update --- EM.md | 329 ------- EM_abundance.md | 64 ++ EM_paper.md | 317 ------- EM_reclassify.md | 191 ++++ countNUM.py | 123 +++ old.md | 174 ++++ reclassify.md | 158 ++++ reclassify0403.md | 205 +++++ src/CommandDeclarations.h | 3 +- src/MMseqsBase.cpp | 21 +- src/commons/Parameters.cpp | 21 +- src/commons/Parameters.h | 1 + src/util/CMakeLists.txt | 3 +- src/util/EM_abundnace.cpp | 738 ++++++++++++++++ ...assify_taxonomy.cpp => EM_reclassify2.cpp} | 429 +++------ src/util/EM_reclassify3.cpp | 835 ++++++++++++++++++ 16 files changed, 2637 insertions(+), 975 deletions(-) delete mode 100644 EM.md create mode 100644 EM_abundance.md delete mode 100644 EM_paper.md create mode 100644 EM_reclassify.md create mode 100644 countNUM.py create mode 100644 old.md create mode 100644 reclassify.md create mode 100644 reclassify0403.md create mode 100644 src/util/EM_abundnace.cpp rename src/util/{reclassify_taxonomy.cpp => EM_reclassify2.cpp} (65%) create mode 100644 src/util/EM_reclassify3.cpp diff --git a/EM.md b/EM.md deleted file mode 100644 index 2648dbdaa..000000000 --- a/EM.md +++ /dev/null @@ -1,329 +0,0 @@ -# EM Formulation in `reclassify_taxonomy.cpp` - -## Overview - -`reclassify_taxonomy.cpp` reorders MMseqs2 alignment hits by estimating posterior support for each target using: - -- alignment score -- target abundance -- entropy-based coverage penalty - -The method is an EM-like iterative update accelerated with SQUAREM. - -## Notation - -Let: - -- \( q \): a query -- \( H_q \): the set of hits for query \(q\) -- \( h \in H_q \): one hit -- \( t(h) \): the target of hit \(h\) -- \( Q \): total number of queries -- \( s_h \): alignment score of hit \(h\) -- \( A_t \): abundance of target \(t\) -- \( E_t \): entropy penalty of target \(t\) -- \( \lambda, \alpha, \gamma \): model parameters - -## 1. 
Initial Abundance - -For each query \(q\), compute the sum of scores over all hits: - -\[ -S_q = \sum_{h \in H_q} s_h -\] - -Each hit contributes a normalized score fraction to its target: - -\[ -c_h = \frac{s_h}{S_q} -\] - -Initial abundance of target \(t\) is: - -\[ -A_t^{(0)} = \frac{1}{Q} \sum_q \sum_{h \in H_q,\ t(h)=t} \frac{s_h}{S_q} -\] - -## 2. Coverage Weight for Entropy - -For each hit \(h\), define target-covered length: - -\[ -\ell_h = dbEndPos_h - dbStartPos_h + 1 -\] - -The code assigns each hit a coverage weight: - -\[ -m_h = \frac{e^{\lambda s_h}}{\ell_h} -\] - -For each covered target position \(p\), accumulate: - -\[ -C_t(p) = \sum_{h \text{ covering } p} m_h -\] - -## 3. Position Probability on a Target - -Let total coverage mass on target \(t\) be: - -\[ -Z_t = \sum_p C_t(p) -\] - -Then the normalized positional probability is: - -\[ -p_t(p) = \frac{C_t(p)}{Z_t} -\] - -## 4. Target Entropy - -Coverage entropy for target \(t\) is: - -\[ -H_t = - \sum_p p_t(p)\log_2 p_t(p) -\] - -If \(Z_t = 0\), then: - -\[ -H_t = 0 -\] - -## 5. Entropy Penalty - -Let the total entropy over all targets be: - -\[ -H_{\mathrm{sum}} = \sum_t H_t -\] - -Then the penalty used in scoring is: - -\[ -E_t = -\begin{cases} -1 - \frac{H_t}{H_{\mathrm{sum}}}, & H_{\mathrm{sum}} > 0 \\ -0, & H_{\mathrm{sum}} = 0 -\end{cases} -\] - -## 6. Score Term for One Hit - -For each query \(q\), define: - -\[ -S_{\max,q} = \max_{h \in H_q} s_h -\] - -In the current implementation, this is the maximum observed bit score among the query's candidate hits. - -With - -\[ -\varepsilon = 10^{-12} -\] - -the abundance and entropy penalty are clipped from below: - -\[ -A_h = \max(A_{t(h)}, \varepsilon) -\] - -\[ -E_h = \max(E_{t(h)}, \varepsilon) -\] - -The score term is: - -\[ -f_h = -\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) -\cdot A_h^{\alpha} -\cdot E_h^{\gamma} -\] - -## 7. 
E-step: Posterior Probability - -For each query \(q\), normalize the score terms over all its hits: - -\[ -Z_q = \sum_{h \in H_q} f_h -\] - -Then the posterior probability of hit \(h\) is: - -\[ -P(h \mid q) = -\begin{cases} -\frac{f_h}{Z_q}, & Z_q > 0 \\ -0, & Z_q = 0 -\end{cases} -\] - -## 8. M-step: Abundance Update - -The updated abundance of target \(t\) is the average posterior mass assigned to it: - -\[ -A_t^{\mathrm{new}} = -\frac{1}{Q} -\sum_q -\sum_{h \in H_q,\ t(h)=t} -P(h \mid q) -\] - -This is the EM abundance update used in the code. - -## 9. Log-Likelihood - -For each query \(q\), the code evaluates: - -\[ -\ell_q = -\sum_{h \in H_q} P(h \mid q)\log f_h -- -\log\left(\sum_{h \in H_q} f_h\right) -\] - -The average log-likelihood is: - -\[ -LL = \frac{1}{Q} \sum_q \ell_q -\] - -This value is used to monitor whether the accelerated update is acceptable. - -## 10. Simplex Projection - -After acceleration, abundance values may become negative or fail to sum to 1. - -The code projects back to the simplex by: - -1. clipping negatives: - -\[ -x_i' = \max(x_i, 0) -\] - -2. renormalizing if the total is positive: - -\[ -x_i'' = \frac{x_i'}{\sum_j x_j'} -\] - -## 11. SQUAREM Acceleration - -Let: - -\[ -x_1 = EM(x_0), \qquad x_2 = EM(x_1) -\] - -Define: - -\[ -r = x_1 - x_0 -\] - -\[ -v = x_2 - x_1 - r -\] - -Their Euclidean norms are: - -\[ -\|r\| = \sqrt{\sum_i r_i^2}, \qquad -\|v\| = \sqrt{\sum_i v_i^2} -\] - -The acceleration factor is: - -\[ -a = -\begin{cases} --1, & \|v\| = 0 \\ --\frac{\|r\|}{\|v\|}, & \|v\| > 0 -\end{cases} -\] - -Then it is clipped to: - -\[ -a \in [-1, 1] -\] - -The accelerated update is: - -\[ -x_{\mathrm{new}} = x_0 - 2ar + a^2 v -\] - -Finally, `projectSimplex(x_new)` is applied. - -## 12. Likelihood Safeguard - -If the accelerated update lowers the log-likelihood, the code falls back to the second EM step: - -\[ -x_{\mathrm{new}} \leftarrow x_2 -\] - -This keeps the iteration stable. - -## 13. 
Convergence Criterion - -The stopping criterion is the maximum absolute parameter change: - -\[ -\Delta = \max_i |x_{\mathrm{new}, i} - x_{0,i}| -\] - -Convergence is declared when: - -\[ -\Delta < \mathrm{tol} -\] - -after at least 6 iterations. - -## 14. Final Ranking - -After convergence, hits for each query are sorted by posterior probability: - -\[ -h_i \prec h_j \quad \text{if} \quad P(h_i \mid q) > P(h_j \mid q) -\] - -If two hits have equal posterior, MMseqs2's default hit comparison is used as a tie-breaker. - -## Summary - -The model implemented in `reclassify_taxonomy.cpp` is: - -\[ -f_h = -\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) -\cdot -A_{t(h)}^{\alpha} -\cdot -E_{t(h)}^{\gamma} -\] - -with EM updates: - -\[ -P(h \mid q) = \frac{f_h}{\sum_{h' \in H_q} f_{h'}} -\] - -\[ -A_t^{\mathrm{new}} = -\frac{1}{Q} -\sum_q -\sum_{h \in H_q,\ t(h)=t} -P(h \mid q) -\] - -and SQUAREM acceleration applied to the abundance vector. diff --git a/EM_abundance.md b/EM_abundance.md new file mode 100644 index 000000000..359c1de2d --- /dev/null +++ b/EM_abundance.md @@ -0,0 +1,64 @@ +# EM Abundance Notes 04/13 + +This document summarizes the outputs implemented in `src/util/EM_abundnace.cpp`. + +## Overview + +`mmseqs abundance` reads the reclassified alignment DB produced by `mmseqs reclassify` and produces two possible outputs: + +- Per-target abundance table (default). +- Kraken-style report when `--taxonomy 1` is set. + +The input alignment DB must include the posterior probability for each hit. The current implementation reads the posterior from the `seqId` field, or from an extra trailing column if present. + +## Example usage + +- mmseqs abundance queryDB targetDB newDB abundance.tsv +- mmseqs abundance queryDB targetDB newDB abundance.report --taxonomy 1 + +Notes: +- `newDB` should be created by `mmseqs reclassify`. +- When `--taxonomy 1` is used, `targetDB_mapping` and `targetDB_taxonomy` are required. 
+ +## Per-target abundance table + +For each target, the command reports: +- target key and identifier +- abundance percentage +- coverage confidence +- drop flag based on low-abundance filtering +- mapped length and target length + +This table is written to the output path (typically `.tsv`). + +Drop handling: +- The per-target table includes all targets and marks low-abundance filtered ones in `Drop(y/n)`. +- When `--taxonomy 1` is used, dropped targets are removed before building Kraken/Bracken reports. + +## Kraken-style report + +When `--taxonomy 1` is set, the output is a Kraken-style report with fields: + +- percent of reads in the clade +- clade read count +- direct (taxon) read count +- rank code +- taxid +- name + +The report is written to the output path (typically `.report`). + +Notes on Kraken compatibility: +- The report format mirrors Kraken's `--report` layout, but values are not guaranteed to be identical to Kraken outputs. +- Counts are derived from EM abundance by converting percent to expected reads (with rounding), then aggregating clade/direct counts via the taxonomy tree. +- Ordering and handling of missing or unclassified taxa can differ from Kraken's implementation. + +## Abundance from posterior + +For each target $t$, abundance is computed from posteriors as: + +$$ +\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} +$$ + +`R_{qt}` is the posterior for hit $(q,t)$ in `newDB`. Abundances are then converted to percentages and filtered by a low-abundance tail cutoff. diff --git a/EM_paper.md b/EM_paper.md deleted file mode 100644 index f16156bed..000000000 --- a/EM_paper.md +++ /dev/null @@ -1,317 +0,0 @@ -# A Paper-Style Description of the EM Reclassification Model - -## Abstract - -The reclassification procedure implemented in `reclassify_taxonomy.cpp` refines target ranking for each query by combining local alignment evidence with global target-level support. 
The model assigns each hit a latent posterior probability and iteratively updates target abundances using an expectation-maximization framework. To improve convergence speed, the abundance update is accelerated using SQUAREM. - -## Model Setup - -Let \(q \in \{1, \dots, Q\}\) index queries, and let \(H_q\) denote the set of candidate hits for query \(q\). For each hit \(h \in H_q\), define: - -- \(t(h)\): target assigned to hit \(h\) -- \(s_h\): alignment score - -For each target \(t\), the model estimates: - -- abundance \(A_t\) -- entropy-derived penalty \(E_t\) - -The free model parameters are: - -\[ -\lambda, \alpha, \gamma -\] - -which control the relative effect of alignment score, abundance, and entropy penalty. - -## Hit Likelihood Term - -For each query \(q\), define the query-specific normalization term: - -\[ -S_{\max,q} = \max_{h \in H_q} s_h -\] - -In the current implementation, this is the maximum observed bit score among the candidate hits of query \(q\). - -For each hit \(h\), abundance and entropy penalty are clipped from below: - -\[ -\tilde{A}_h = \max(A_{t(h)}, \varepsilon) -\] - -\[ -\tilde{E}_h = \max(E_{t(h)}, \varepsilon) -\] - -with - -\[ -\varepsilon = 10^{-12} -\] - -The unnormalized compatibility score of hit \(h\) is: - -\[ -f_h = -\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) -\cdot \tilde{A}_h^{\alpha} -\cdot \tilde{E}_h^{\gamma} -\] - -This quantity defines the relative support for assigning query \(q\) to target \(t(h)\). 
- -## Entropy-Based Penalty - -For a hit \(h\), let its target-covered length be: - -\[ -\ell_h = dbEndPos_h - dbStartPos_h + 1 -\] - -The hit contributes position-wise target coverage mass: - -\[ -m_h = \frac{\exp(\lambda s_h)}{\ell_h} -\] - -For target \(t\), the accumulated coverage at target position \(p\) is: - -\[ -C_t(p) = \sum_{h \text{ covering } p,\ t(h)=t} m_h -\] - -The normalized positional distribution is: - -\[ -p_t(p) = \frac{C_t(p)}{\sum_{p'} C_t(p')} -\] - -The target entropy is then: - -\[ -H_t = - \sum_p p_t(p)\log_2 p_t(p) -\] - -Let - -\[ -H_{\mathrm{sum}} = \sum_t H_t -\] - -The penalty assigned to target \(t\) is: - -\[ -E_t = -\begin{cases} -1 - \frac{H_t}{H_{\mathrm{sum}}}, & H_{\mathrm{sum}} > 0 \\ -0, & H_{\mathrm{sum}} = 0 -\end{cases} -\] - -This term modulates target support according to the spatial distribution of its matched regions. - -## Initialization - -The abundance is initialized by normalized query-local alignment support. For each query \(q\), define: - -\[ -S_q = \sum_{h \in H_q} s_h -\] - -Then the initial abundance of target \(t\) is: - -\[ -A_t^{(0)} = -\frac{1}{Q} -\sum_q -\sum_{h \in H_q,\ t(h)=t} -\frac{s_h}{S_q} -\] - -This initialization gives each query unit mass, distributed proportionally to hit score. - -## E-step - -Given the current abundances \(A_t\), posterior probabilities are assigned to hits within each query: - -\[ -P(h \mid q) = -\frac{f_h}{\sum_{h' \in H_q} f_{h'}} -\] - -whenever the denominator is positive; otherwise the posterior is set to zero. - -Thus, the E-step computes a soft assignment of each query to its candidate targets. - -## M-step - -The target abundance is updated by averaging posterior mass across all queries: - -\[ -A_t^{\mathrm{new}} = -\frac{1}{Q} -\sum_q -\sum_{h \in H_q,\ t(h)=t} -P(h \mid q) -\] - -The entropy penalty is kept fixed during the EM iteration after its initial computation. 
- -## Objective Function - -The code evaluates an average query-normalized log-likelihood-like objective: - -\[ -\ell_q = -\sum_{h \in H_q} P(h \mid q)\log f_h -- -\log\left(\sum_{h \in H_q} f_h\right) -\] - -and - -\[ -LL = \frac{1}{Q}\sum_q \ell_q -\] - -This objective is used to reject unstable accelerated updates. - -## SQUAREM Acceleration - -Let \(EM(x)\) denote one EM abundance update applied to abundance vector \(x\). Starting from \(x_0\), compute: - -\[ -x_1 = EM(x_0), \qquad x_2 = EM(x_1) -\] - -Define: - -\[ -r = x_1 - x_0 -\] - -\[ -v = x_2 - x_1 - r -\] - -with Euclidean norms: - -\[ -\|r\| = \sqrt{\sum_i r_i^2}, \qquad -\|v\| = \sqrt{\sum_i v_i^2} -\] - -The acceleration parameter is: - -\[ -a = -\begin{cases} --1, & \|v\| = 0 \\ --\frac{\|r\|}{\|v\|}, & \|v\| > 0 -\end{cases} -\] - -and is clipped into the interval: - -\[ -a \in [-1,1] -\] - -The accelerated proposal is: - -\[ -x_{\mathrm{new}} = x_0 - 2ar + a^2 v -\] - -Since \(x_{\mathrm{new}}\) must remain a valid abundance vector, it is projected back onto the simplex. - -## Simplex Projection - -Given a proposed abundance vector \(x\), projection is implemented as: - -1. clip negative entries: - -\[ -x_i' = \max(x_i, 0) -\] - -2. normalize: - -\[ -x_i'' = \frac{x_i'}{\sum_j x_j'} -\] - -if the denominator is positive. - -This ensures abundances are nonnegative and sum to 1. - -## Safeguard Step - -If the accelerated proposal decreases the objective, the algorithm falls back to the second ordinary EM iterate: - -\[ -x_{\mathrm{new}} \leftarrow x_2 -\] - -This provides a monotonicity safeguard against unstable extrapolation. - -## Convergence Criterion - -The iteration stops when the maximum coordinate-wise parameter change becomes sufficiently small: - -\[ -\Delta = \max_i |x_{\mathrm{new},i} - x_{0,i}| -\] - -Convergence is declared when: - -\[ -\Delta < \mathrm{tol} -\] - -after an initial burn-in of at least six iterations. 
- -## Final Ranking - -After convergence, hits are sorted by posterior probability: - -\[ -h_i \prec h_j -\quad \Longleftrightarrow \quad -P(h_i \mid q) > P(h_j \mid q) -\] - -with MMseqs2's default hit comparison used to break ties. - -## Final Model Summary - -The reclassification model can be summarized as: - -\[ -f_h = -\exp\left(\lambda \frac{s_h}{S_{\max,q}}\right) -\cdot -A_{t(h)}^{\alpha} -\cdot -E_{t(h)}^{\gamma} -\] - -with posterior assignment: - -\[ -P(h \mid q) = -\frac{f_h}{\sum_{h' \in H_q} f_{h'}} -\] - -and abundance update: - -\[ -A_t^{\mathrm{new}} = -\frac{1}{Q} -\sum_q -\sum_{h \in H_q,\ t(h)=t} -P(h \mid q) -\] - -This framework combines local alignment quality with global target prevalence to produce a query-specific posterior ranking of targets. diff --git a/EM_reclassify.md b/EM_reclassify.md new file mode 100644 index 000000000..13957871a --- /dev/null +++ b/EM_reclassify.md @@ -0,0 +1,191 @@ +# EM_reclassify.cpp Summary (04/12) + +## Overview +- This module performs EM-based reclassification of alignment hits. +- It updates target abundances and posterior probabilities, optionally drops low-abundance targets, and writes a new alignment DB where hits are sorted by posterior. +- Primary input: alignment result DB (`par.db3`). Primary output: reclassified alignment DB (`par.db4`). + +## Example Usage Order +- `mmseqs createdb query.fasta queryDB` +- `mmseqs search queryDB targetDB alignDB tmp -a` +- `mmseqs reclassify queryDB targetDB alignDB newDB` +- `mmseqs convertalis queryDB targetDB newDB reclassify_result.m8` +- `mmseqs abundance queryDB targetDB newDB abundance.tsv --taxonomy 1` + +Notes: +- Plain custom FASTA target DB: use `createdb` + `createtaxdb` as needed. +- Prebuilt taxonomy-ready MMseqs DB: `createtaxdb` is usually not needed. + +## Key Data Structures +- `ReclassTaxEntry`: per-hit working record with `abundance`, `posterior`, `coverageConfidence`. +- `MappingTable`: `queryKey -> vector`. 
+- `ReclassTaxContext`: full working set (mapping table, query order, target set, query count, output-format flags). +- `TargetStats`: per-target aggregate stats used for filtering/reporting. + +## Mathematical Definitions +Let `q` be a query and `t` a target hit for `q`. + +- Normalized score: + +$$ +s_{q,t} = \frac{\text{score}_{q,t}}{\max_{t'} \text{score}_{q,t'}} +$$ + +- Score term used in posterior update: + +$$ +\text{scoreTerm}_{q,t} = \exp(\lambda \cdot s_{q,t}) \cdot a_t^{\alpha} \cdot c_t^{\gamma} +$$ + +where `a_t` is abundance, `c_t` is coverage confidence. + +- Posterior: + +$$ +p_{q,t} = \frac{\text{scoreTerm}_{q,t}}{\sum_{t'} \text{scoreTerm}_{q,t'}} +$$ + +- Abundance update: + +$$ +a_t^{\text{new}} = \frac{1}{|Q|}\sum_{q \in Q} p_{q,t} +$$ + +- Log-likelihood (average per query): + +$$ +\mathcal{L} = \frac{1}{|Q|}\sum_{q \in Q}\left(\sum_t p_{q,t}\log(\text{scoreTerm}_{q,t}) - \log\sum_t \text{scoreTerm}_{q,t}\right) +$$ + +## Coverage Confidence +Coverage confidence is computed once at initialization and then kept fixed during EM. 
+ +### 1) Query-level hit weight (score normalization) +For each query `q`, define hit weight by plain bit-score normalization (no `exp`, no `lambda`): + +$$ +w_{q,h} = \frac{\text{score}_{q,h}}{\sum_{h'} \text{score}_{q,h'}} +$$ + +### 2) Position-wise accumulation on each target +For each target position `p`, accumulate weighted support: + +$$ +\text{covConf}_t(p) = \sum_{h: p \in [\text{start}_h, \text{end}_h]} w_{q,h} +$$ + +Clip each position to avoid over-crediting stacked/repeat mappings at one locus: + +$$ +\tilde{c}_t(p) = \min(1, \text{covConf}_t(p)) +$$ + +### 3) Base coverage over observed span +Observed span length: + +$$ +L_t = \max(1, \text{endPos}_t - \text{startPos}_t + 1) +$$ + +Base coverage fraction: + +$$ +f_t = \frac{1}{L_t}\sum_{p=1}^{L_t}\tilde{c}_t(p) +$$ + +### 4) Concentration penalty (HHI-based) +Define concentration (Herfindahl-Hirschman Index over clipped position mass): + +$$ +\text{HHI}_t = \frac{\sum_{p=1}^{L_t}\tilde{c}_t(p)^2}{\left(\sum_{p=1}^{L_t}\tilde{c}_t(p)\right)^2} +$$ + +with the edge case: + +$$ +\text{HHI}_t = 1 \quad \text{if } \sum_{p}\tilde{c}_t(p)=0 +$$ + +Convert to dispersion reward (low when concentrated, high when spread): + +$$ +\text{penalty}_t = 1 - \text{HHI}_t +$$ + +### 5) Final coverage confidence + +$$ +c_t = \operatorname{clamp}_{[0,1]}\left(f_t \cdot \text{penalty}_t\right) +$$ + +Important details: +- Position contribution is clipped by `min(1, ...)`. +- Normalization length is **observed span** (`maxEnd - minStart + 1`), not full `dbLen`. +- Final value includes HHI-based concentration penalty and is clamped to `[0, 1]`. +- `lambda` still affects posterior via `scoreTerm`, but no longer affects coverage-confidence initialization. +- Behavior intent: + - Many hits stacked in one local region: high concentration -> high HHI -> small `(1-HHI)` -> lower `c_t`. + - Hits distributed across many positions: lower concentration -> lower HHI -> larger `(1-HHI)` -> higher `c_t`. + +## Processing Pipeline +1. 
Load alignment DB +- `loadAlignmentDb()` parses alignment rows into `mappingTable`. +- Keeps backtrace/ORF layout flags for output compatibility. + +2. Initialize +- `initAbundance()` initializes abundance from query-normalized hit scores. + +### Abundance Initialization (초기값 계산) + +각 쿼리 $q$에 대해, 해당 쿼리의 모든 target hit $t$에 대해 bit score를 정규화합니다: + +$$ +w_{q,t} = \frac{\text{score}_{q,t}}{\sum_{t'} \text{score}_{q,t'}} +$$ + +각 target $t$의 초기 abundance는, 해당 target이 등장한 모든 쿼리에서의 $w_{q,t}$ 값을 평균낸 값입니다: + +$$ +a_t^{(0)} = \frac{1}{|Q|} \sum_{q \in Q} w_{q,t} +$$ + +즉, 각 target의 abundance는 쿼리별로 score-normalized weight를 합산한 뒤 전체 쿼리 수로 나누어 초기화합니다. +- `initCoverageConfidence()` computes per-target `coverageConfidence` using `score/sum(score)` weights. + +3. EM + SQUAREM +- `computePosterior()` computes per-query posterior. +- `emUpdate()` updates abundance; coverage confidence remains fixed. +- `squarem()` accelerates EM with extrapolation and simplex projection. +- If LL decreases, fallback to conservative step (`x2`). + +4. Optional target filtering +- `collectTargetStats()` builds target-level abundance/coverage/interval stats. +- `selectDroppedTargets()` chooses low-tail drops under mass cap. +- `applyDroppedTargets()` removes dropped targets from mapping/query/target sets. + +5. Write output DB +- `writeReclassifiedDb()` writes hits sorted by posterior. +- Posterior is written into output `seqId` field. + +## Parameters (Current Defaults) +- `--lambda` (`reclassifyLambda`): `2` +- `--alpha` (`reclassifyAlpha`): `1.0` +- `--gamma` (`reclassifyGamma`): `1.0` +- `--max-iter` (`reclassifyMaxIterations`): `100` +- `--tol` (`reclassifyTolerance`): `1e-5` +- `--drop-percentage` (`reclassifyMaxDropPercentage`): `10.0` + +## Output/Storage Clarification +- `reclassify` output DB stores posterior in `seqId`. +- `coverageConfidence` is not persisted in alignment DB columns. +- `abundance` recomputes `coverageConfidence` from input alignments and writes it to `abundance.tsv`. 
+ +## Consistency Across Modules +The updated coverage-confidence logic (`score/sum(score)` + observed-span normalization + HHI penalty) is applied consistently in: +- `src/util/EM_reclassify.cpp` +- `src/util/EM_abundnace.cpp` + +## Logging +- Coverage-confidence initialization complete message. +- Iteration LL and parameter delta. +- Number of dropped targets and applied abundance cutoff. diff --git a/countNUM.py b/countNUM.py new file mode 100644 index 000000000..58f8477d4 --- /dev/null +++ b/countNUM.py @@ -0,0 +1,123 @@ +#Counting the unique target ID from result2.m8 & reclass_result.txt +from collections import defaultdict +import pdb + +PATH1 = '/home/yakim/benchmark/result_mmseq.m8' # mmseq +PATH2 = '/home/yakim/benchmark/reclass_bit.m8' # after reclassification - bit score +PATH3 = '/home/yakim/benchmark/reclass_seqid.m8' # after reclassification - seq id +PATH4 = '/home/yakim/benchmark/reclass0922.m8' # after reclassification - C++ (seqid) + + +# MMseq2 result count +qandt = [] +hit_num = 0 +redund1 = [] +with open(PATH1, 'r') as f: + for line in f: + a = line.split() + qandt.append([a[0],a[1]]) + +for q, t in qandt: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund1: + redund1.append(qnum) + if query == t: + hit_num += 1 +print('Hit # of mmseqs result:', hit_num, '&& Total #:', len(redund1)) + + +# Mmseq2 result count (count also for every top 10) +''' +hit_num1 = 0 +grouped = defaultdict(list) + +with open(PATH1, 'r') as f: + for line in f: + a = line.split() + query_full = a[0] # query705_from_UniRef90_A0AA46S9G1 + target = a[1] # A0AA46S9G1 + seq_identity = round(float(a[2]), 3) + query_num = query_full.split('_')[0] # query705 + query_last_id = query_full.split('_')[-1] # A0AA46S9G1 + + # query 번호별로 그룹핑 + grouped[query_num].append({ + 'query_last_id': query_last_id, + 'target': target, + 'seq_identity': seq_identity + }) # row[0] row[1] row[2] + #{query1: (A0AA46S9G1, A0AA46S9R1, 0.9), (A0AA46S9G1, A0AA46S9G1, 0.9), 
(A0AA46S9G1, A0AA34S9T1, 0.85)} + +for query_num, rows in grouped.items(): + if not rows: + continue + top_seq_identity = rows[0]['seq_identity'] # 0.9 + top_query_last_id = rows[0]['query_last_id'] # A0AA46S9G1 + + count = sum( + (row['query_last_id'] == row['target']) and + (row['seq_identity'] == top_seq_identity) + for row in rows + ) + #print(f'{query_num}: {count}') # 각 그룹별로 개수 출력 + hit_num1 += count + +print('Hit # of mmseqs result(top 10):', hit_num1, '&& Total #:', len(redund)) +''' + +# After reclassification result (bit score) +qandt2 = [] +hit_num2 = 0 +redund2 = [] +with open(PATH2, 'r') as f: + for line in f: + a = line.split() + qandt2.append([a[0],a[1]]) + +for q, t in qandt2: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund2: + redund2.append(qnum) + if query == t: + hit_num2 += 1 +print('Hit # of reclassification result(bit score):', hit_num2, '&& Total #:', len(redund2)) + + +# After reclassification result (seq id) +qandt3 = [] +hit_num3 = 0 +redund3 = [] +with open(PATH3, 'r') as f: + for line in f: + a = line.split() + qandt3.append([a[0],a[1]]) + +for q, t in qandt3: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund3: + redund3.append(qnum) + if query == t: + hit_num3 += 1 +print('Hit # of reclassification result(seqid):', hit_num3, '&& Total #:', len(redund3)) + +# after reclass result count (C++) +qandt4 = [] +hit_num4 = 0 +redund4 = [] +with open(PATH4, 'r') as f: + for line in f: + a = line.split() + qandt4.append([a[0],a[1]]) + +for q, t in qandt4: + query = q.split('_')[-1] #query + qnum = q.split('_')[0] + if qnum not in redund4: + redund4.append(qnum) + if query == t: + hit_num4 += 1 +print('Hit # of reclassification result(C++):', hit_num4, '&& Total #:', len(redund4)) + diff --git a/old.md b/old.md new file mode 100644 index 000000000..84be02a80 --- /dev/null +++ b/old.md @@ -0,0 +1,174 @@ +# Reclassify Taxonomy2 Equations + +This document summarizes the equations 
implemented in `EM/reclassify-taxonomy2.cpp`. + What reclassify adds: + + - It does not treat each query in isolation. + - It estimates target abundance over all matched queries. + - It uses that abundance to boost hits to globally supported targets and downweight isolated or ambiguous hits. + - It penalizes high-entropy or diffuse mappings. + - It iterates this with an EM-style procedure, accelerated with SQUAREM, until convergence. + +You can use the 'mmseqs reclassify' after running 'mmseqs search' +## Ex usage (search -> align -> (createtaxdb) -> reclassify) +- mmseqs createdb query.fasta queryDB +- mmseqs search queryDB targetDB alignDB tmp -a +- mmseqs reclassify queryDB targetDB alignDB newDB +- mmseqs convertalis queryDB targetDB newDB (or alignDB) reclassify_result.m8 +- mmseqs abundance queryDB targetDB newDB abundance.report (--taxonomy 1) +* Plain custom FASTA target DB: createdb + createtaxdb +* Prebuilt taxonomy-ready MMseqs database: createtaxdb usually not needed + +## Notation + +- Query index: $q$ +- Target/protein index: $t$ +- Hit list for query $q$: $H(q)$ +- Bit score: $s_{qt}$ +- Query max score: $S_q = \max_{t \in H(q)} s_{qt}$ +- Target abundance: $A_t$ +- Entropy value: $E_t$ +- Entropy penalty: $P_t$ +- Posterior for hit $(q,t)$: $R_{qt}$ +- Parameters: $\lambda, \alpha, \gamma$ +- Small constant: $\varepsilon = 10^{-12}$ + +## Abundance initialization + +For each query $q$, compute the sum of scores over its hit list: + +$$ +S^{\text{sum}}_q = \sum_{t \in H(q)} s_{qt} +$$ + +If $S^{\text{sum}}_q > 0$, each target gets a fractional count: + +$$ +c_{qt} = \frac{s_{qt}}{S^{\text{sum}}_q} +$$ + +Aggregate counts across all queries and normalize by $|Q|$: + +$$ +A_t = \frac{1}{|Q|} \sum_q c_{qt} +$$ + +If $|Q| = 0$, all $A_t$ are treated as $0$. 
+ +## Entropy value and penalty + +For each target $t$, find its global alignment bounds: + +$$ +\text{minPos}_t = \min_q \text{dbStart}_{qt}, \quad +\text{maxPos}_t = \max_q \text{dbEnd}_{qt} +$$ + +Coverage is accumulated over the interval $[\text{minPos}_t,\text{maxPos}_t]$. +For each hit $(q,t)$: + +$$ +\ell_{qt} = \text{dbEnd}_{qt} - \text{dbStart}_{qt} + 1 +$$ + +$$ +m_{qt} = \frac{\exp(\lambda s_{qt})}{\ell_{qt}} +$$ + +For each position $p$ covered by the hit: + +$$ +\text{cov}_t(p) \mathrel{+}= m_{qt} +$$ + +Normalize coverage to probabilities: + +$$ +P_t(p) = \frac{\text{cov}_t(p)}{\sum_{p'} \text{cov}_t(p')} +$$ + +Entropy for target $t$: + +$$ +E_t = -\sum_p P_t(p) \log_2 P_t(p) +$$ + +Entropy penalty (let $E^{\text{sum}} = \sum_{t'} E_{t'}$): + +$$ +P_t = 1 - \frac{E_t}{E^{\text{sum}}} +$$ + +If $E^{\text{sum}} \le 0$, then $P_t = 0$. + +## Score term + +Normalized score for a hit: + +$$ +\text{ns}_{qt} = \frac{s_{qt}}{S_q} +$$ + +Score term used for posteriors: + +$$ +T_{qt} = \exp(\lambda \text{ns}_{qt}) \cdot \max(A_t,\varepsilon)^{\alpha} \cdot \max(P_t,\varepsilon)^{\gamma} +$$ + +If $S_q \le 0$, then $T_{qt} = 0$ for all $t \in H(q)$. + +## Posterior for each query + +For each query $q$: + +$$ +Z_q = \sum_{t \in H(q)} T_{qt} +$$ + +$$ +R_{qt} = \frac{T_{qt}}{Z_q} +$$ + +If $Z_q = 0$, then $R_{qt} = 0$. 
+ +## EM abundance update + +Given posteriors, the next abundance is: + +$$ +\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} +$$ + +## Log-likelihood (per query) + +$$ +\mathcal{L} = \frac{1}{|Q|} \sum_q \left( \sum_t R_{qt} \log(T_{qt}) - \log(Z_q) \right) +$$ + +## SQUAREM step (acceleration) + +Two EM steps produce $x_1$ and $x_2$ from $x_0$: + +$$ +\mathbf{r} = x_1 - x_0, \quad \mathbf{v} = x_2 - x_1 - \mathbf{r} +$$ + +Acceleration scalar (clipped to $[-1, 1]$): + +$$ +a = -\frac{\|\mathbf{r}\|}{\|\mathbf{v}\|} +$$ + +Proposed update and simplex projection: + +$$ +\mathbf{x}_{\text{new}} = x_0 - 2a\mathbf{r} + a^2\mathbf{v}, \quad \mathbf{x}_{\text{new}} \in \Delta +$$ + +## Simplex projection + +Clamp negatives and renormalize: + +$$ +x_i \leftarrow \max(0, x_i), \quad x_i \leftarrow \frac{x_i}{\sum_j x_j} +$$ diff --git a/reclassify.md b/reclassify.md new file mode 100644 index 000000000..2fe9d6a3c --- /dev/null +++ b/reclassify.md @@ -0,0 +1,158 @@ +# Reclassify EM Notes + +This document summarizes the equations and outputs implemented in `src/util/EM_reclassify.cpp`. + +What reclassify adds: +- It does not treat each query in isolation. +- It estimates target abundance over all matched queries. +- It uses that abundance to boost hits to globally supported targets and downweight isolated or ambiguous hits. +- It penalizes high-entropy or diffuse mappings. +- It iterates this with an EM-style procedure, accelerated with SQUAREM, until convergence. + +You can use `mmseqs reclassify` after running `mmseqs search` and (optional) `mmseqs align -a`. 
+ +## Example usage (search -> align -> reclassify) +- mmseqs createdb query.fasta queryDB +- mmseqs search queryDB targetDB alignDB tmp -a +- mmseqs reclassify queryDB targetDB alignDB newDB +- mmseqs convertalis queryDB targetDB newDB reclassify_result.m8 +- mmseqs abundance queryDB targetDB newDB abundance.report (--taxonomy 1) + +Notes: +- `newDB` is an alignment DB in the same format as `alignDB`, but ordered by posterior probability per query. +- The posterior probability replaces the `seqId` field in `newDB` to keep the column count unchanged. + +## Notation + +- Query index: $q$ +- Target/protein index: $t$ +- Hit list for query $q$: $H(q)$ +- Bit score: $s_{qt}$ +- Query max score: $S_q = \max_{t \in H(q)} s_{qt}$ +- Target abundance: $A_t$ +- Coverage confidence: $C_t$ +- Posterior for hit $(q,t)$: $R_{qt}$ +- Parameters: $\lambda, \alpha, \gamma$ +- Small constant: $\varepsilon = 10^{-12}$ + +## Abundance initialization + +For each query $q$, compute the sum of scores over its hit list: + +$$ +S^{\text{sum}}_q = \sum_{t \in H(q)} s_{qt} +$$ + +If $S^{\text{sum}}_q > 0$, each target gets a fractional count: + +$$ +c_{qt} = \frac{s_{qt}}{S^{\text{sum}}_q} +$$ + +Aggregate counts across all queries and normalize by $|Q|$: + +$$ +A_t = \frac{1}{|Q|} \sum_q c_{qt} +$$ + +If $|Q| = 0$, all $A_t$ are treated as $0$. + +## Coverage confidence + +For each target $t$, find its global alignment bounds: + +$$ +\text{minPos}_t = \min_q \text{dbStart}_{qt}, \quad +\text{maxPos}_t = \max_q \text{dbEnd}_{qt} +$$ + +Coverage is accumulated over the interval $[\text{minPos}_t,\text{maxPos}_t]$. +For each hit $(q,t)$: + +$$ +\ell_{qt} = \text{dbEnd}_{qt} - \text{dbStart}_{qt} + 1 +$$ + +$$ +m_{qt} = \frac{\exp(\lambda s_{qt})}{\ell_{qt}} +$$ + +For each position $p$ covered by the hit: + +$$ +\text{cov}_t(p) \mathrel{+}= m_{qt} +$$ + +The per-target coverage confidence $C_t$ is the fraction of target positions where the accumulated per-hit weights reach at least 1.0. 
Values are clipped to $[0,1]$. + +## Score term + +Normalized score for a hit: + +$$ +\text{ns}_{qt} = \frac{s_{qt}}{S_q} +$$ + +Score term used for posteriors: + +$$ +T_{qt} = \exp(\lambda \text{ns}_{qt}) \cdot \max(A_t,\varepsilon)^{\alpha} \cdot \max(C_t,\varepsilon)^{\gamma} +$$ + +If $S_q \le 0$, then $T_{qt} = 0$ for all $t \in H(q)$. + +## Posterior for each query + +For each query $q$: + +$$ +Z_q = \sum_{t \in H(q)} T_{qt} +$$ + +$$ +R_{qt} = \frac{T_{qt}}{Z_q} +$$ + +If $Z_q = 0$, then $R_{qt} = 0$. + +## EM abundance update + +Given posteriors, the next abundance is: + +$$ +\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} +$$ + +## Log-likelihood (per query) + +$$ +\mathcal{L} = \frac{1}{|Q|} \sum_q \left( \sum_t R_{qt} \log(T_{qt}) - \log(Z_q) \right) +$$ + +## SQUAREM step (acceleration) + +Two EM steps produce $x_1$ and $x_2$ from $x_0$: + +$$ +\mathbf{r} = x_1 - x_0, \quad \mathbf{v} = x_2 - x_1 - \mathbf{r} +$$ + +Acceleration scalar (clipped to $[-1, 1]$): + +$$ +a = -\frac{\|\mathbf{r}\|}{\|\mathbf{v}\|} +$$ + +Proposed update and simplex projection: + +$$ +\mathbf{x}_{\text{new}} = x_0 - 2a\mathbf{r} + a^2\mathbf{v}, \quad \mathbf{x}_{\text{new}} \in \Delta +$$ + +## Simplex projection + +Clamp negatives and renormalize: + +$$ +x_i \leftarrow \max(0, x_i), \quad x_i \leftarrow \frac{x_i}{\sum_j x_j} +$$ diff --git a/reclassify0403.md b/reclassify0403.md new file mode 100644 index 000000000..3e9ab2e24 --- /dev/null +++ b/reclassify0403.md @@ -0,0 +1,205 @@ +# Reclassify 0406 Notes + +This document summarizes the logic and formulas in `src/util/reclassify_taxonomy.cpp`. + +## Overview + +- Input: an alignment-result DB (from `mmseqs search -a` or `mmseqs align`) with per-hit fields. +- Output: + - `new_alignment_result.m8` (re-ranked by posterior) + - `protein_abundance.tsv` (per-target stats) + - `taxonomy_abundance.tsv` (optional taxonomy aggregation) + +The pipeline is: +1) Read alignments into per-query hit lists. 
+2) Initialize abundance and coverage confidence. +3) Run EM with SQUAREM acceleration. +4) Drop low-abundance targets by tail cutoff. +5) Write outputs. + +## Ex usage order +- mmseqs createdb query.fasta queryDB +- mmseqs search queryDB targetDB alignDB tmp -a +- mmseqs reclassify queryDB targetDB alignDB newDB +- mmseqs convertalis queryDB targetDB newDB (or alignDB) reclassify_result.m8 +- mmseqs abundance queryDB targetDB newDB abundance.report (--taxonomy 1) +* Plain custom FASTA target DB: createdb + createtaxdb +* Prebuilt taxonomy-ready MMseqs database: createtaxdb usually not needed + +## Data structures + +- Query key: the DB key for a query sequence. +- Target key: the DB key for a target sequence. +- Hit: a `Matcher::result_t` record (alignment fields, scores, positions). +- Per-hit state: + - abundance $A_t$ + - posterior $R_{qt}$ + - coverage confidence $C_t$ + +## Notation + +- Query index: $q$ +- Target/protein index: $t$ +- Hit list for query $q$: $H(q)$ +- Bit score: $s_{qt}$ +- Query max score: $S_q = \max_{t \in H(q)} s_{qt}$ +- Target abundance: $A_t$ +- Coverage confidence: $C_t$ +- Posterior for hit $(q,t)$: $R_{qt}$ +- Parameters: $\lambda, \alpha, \gamma$ +- Small constant: $\varepsilon = 10^{-12}$ + +## Step 1: Load alignment DB + +- Each DB entry is parsed into per-query hit lists. +- Targets are tracked in a global set. + +## Step 2: Abundance initialization + +For each query $q$, sum scores over hits: + +$$ +S^{\text{sum}}_q = \sum_{t \in H(q)} s_{qt} +$$ + +If $S^{\text{sum}}_q > 0$, assign fractional counts: + +$$ +c_{qt} = \frac{s_{qt}}{S^{\text{sum}}_q} +$$ + +Aggregate over all queries and normalize by $|Q|$: + +$$ +A_t = \frac{1}{|Q|} \sum_q c_{qt} +$$ + +## Step 3: Coverage confidence (per target) + +Threads usage: +- The coverage-confidence calculation is parallelized across targets with OpenMP. +- The loop over target IDs runs with `#pragma omp parallel for num_threads(threads) schedule(dynamic, 1)`. 
+- Other steps in this pipeline are single-threaded in this file. + +For each target $t$, compute global bounds: + +$$ +\text{minPos}_t = \min_q \text{dbStart}_{qt}, \quad +\text{maxPos}_t = \max_q \text{dbEnd}_{qt} +$$ + +For each hit $(q,t)$, define alignment length: + +$$ +\ell_{qt} = \text{dbEnd}_{qt} - \text{dbStart}_{qt} + 1 +$$ + +Define per-hit weight: + +$$ +\text{expScore}_{qt} = \exp(\lambda s_{qt}), \quad +w_{qt} = \frac{\text{expScore}_{qt}}{\sum_{t' \in H(q)} \text{expScore}_{qt'}} +$$ + +Coverage accumulation over the target interval: + +$$ +\text{cov}_t(p) \mathrel{+}= \frac{\text{expScore}_{qt}}{\ell_{qt}} +$$ + +Coverage confidence accumulation: + +$$ +\text{conf}_t(p) \mathrel{+}= w_{qt} +$$ + +Coverage confidence (per target): + +$$ +C_t = \mathrm{clamp}_{[0,1]}\left(\frac{\sum_p \min(1, \text{conf}_t(p))}{\text{targetLen}_t}\right) +$$ + +## Step 4: Score term and posterior + +Normalized score: + +$$ +\text{ns}_{qt} = \frac{s_{qt}}{S_q} +$$ + +Score term: + +$$ +T_{qt} = \exp(\lambda \text{ns}_{qt}) \cdot \max(A_t, \varepsilon)^{\alpha} \cdot \max(C_t, \varepsilon)^{\gamma} +$$ + +Posterior for query $q$: + +$$ +Z_q = \sum_{t \in H(q)} T_{qt}, \quad +R_{qt} = \frac{T_{qt}}{Z_q} +$$ + +## Step 5: EM update + +Abundance update: + +$$ +\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} +$$ + +## Step 6: Log-likelihood (per query) + +$$ +\mathcal{L} = \frac{1}{|Q|} \sum_q \left( \sum_t R_{qt} \log(T_{qt}) - \log(Z_q) \right) +$$ + +## Step 7: SQUAREM acceleration + +Two EM steps produce $x_1$ and $x_2$ from $x_0$: + +$$ +\mathbf{r} = x_1 - x_0, \quad \mathbf{v} = x_2 - x_1 - \mathbf{r} +$$ + +Acceleration scalar (clipped to $[-1, 1]$): + +$$ +a = -\frac{\|\mathbf{r}\|}{\|\mathbf{v}\|} +$$ + +Proposed update and simplex projection: + +$$ +\mathbf{x}_{\text{new}} = x_0 - 2a\mathbf{r} + a^2\mathbf{v}, \quad \mathbf{x}_{\text{new}} \in \Delta +$$ + +Simplex projection: + +$$ +x_i \leftarrow \max(0, x_i), \quad x_i \leftarrow \frac{x_i}{\sum_j x_j} +$$ + +If the 
accelerated step decreases log-likelihood, the code falls back to $x_2$. + +## Step 8: Target filtering (low-abundance drop) + +- Compute a cutoff on target abundances using a largest-gap rule or tail quantile rule. +- Drop up to `maxDropPercentage` of targets (with a minimum tail size). +- Remove dropped targets from the mapping table and recompute output stats. + +## Outputs + +- `new_alignment_result.m8` + - Hits re-sorted by posterior. + - Alignment summary derived from backtrace if present. +- `protein_abundance.tsv` + - Per-target: abundance %, coverage confidence, drop flag, mapped interval(s). +- `taxonomy_abundance.tsv` (if taxonomy enabled) + - Aggregated by taxonomy ID. + +## Notes + +- `C_t` (coverage confidence) is used directly in the score term. +- Parallelism is used in coverage confidence initialization across targets. +- No other steps in this file use OpenMP; EM and output phases are serial here. diff --git a/src/CommandDeclarations.h b/src/CommandDeclarations.h index 6f4d75293..1edbfd451 100644 --- a/src/CommandDeclarations.h +++ b/src/CommandDeclarations.h @@ -101,7 +101,8 @@ extern int gappedprefilter(int argc, const char **argv, const Command& command); extern int unpackdb(int argc, const char **argv, const Command& command); extern int rbh(int argc, const char **argv, const Command& command); extern int recoverlongestorf(int argc, const char **argv, const Command& command); -extern int reclassifytaxonomy(int argc, const char **argv, const Command& command); +extern int emreclassify(int argc, const char **argv, const Command& command); +extern int emabundance(int argc, const char **argv, const Command& command); extern int result2flat(int argc, const char **argv, const Command& command); extern int result2msa(int argc, const char **argv, const Command& command); extern int result2dnamsa(int argc, const char **argv, const Command& command); diff --git a/src/MMseqsBase.cpp b/src/MMseqsBase.cpp index eab17f897..3a00fa00d 100644 --- 
a/src/MMseqsBase.cpp +++ b/src/MMseqsBase.cpp @@ -1072,17 +1072,26 @@ std::vector baseCommands = { " ", CITATION_MMSEQS2, {{"resultDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::resultDb }, {"resultDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::resultDb }}}, - {"reclassify", reclassifytaxonomy, &par.reclassify, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, - "Reclassify alignments and export default flat-file summaries", - "mmseqs reclassify queryDB targetDB alignmentDB outDir\n" + {"reclassify", emreclassify, &par.reclassify, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, + "Reclassify alignments and write a new alignment DB with posterior probabilities(convertalis will convert it into column 3(instead of seqId))", + "mmseqs reclassify queryDB targetDB alignmentDB newDB\n", + "Yaeji Kim", + " ", + CITATION_MMSEQS2, {{"queryDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, + {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, + {"alignmentDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, + {"newDB", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }}}, + {"abundance", emabundance, &par.abundance, COMMAND_RESULT | COMMAND_FORMAT_CONVERSION, + "Summarize abundance from reclassified alignment DB", + "mmseqs abundance queryDB targetDB newDB abundance.tsv\n" "# targetDB_mapping and targetDB_taxonomy are only required with --taxonomy 1\n" - "mmseqs reclassify queryDB targetDB alignmentDB outDir --taxonomy 1\n", + "mmseqs abundance queryDB targetDB newDB abundance.report --taxonomy 1\n", "Yaeji Kim", - " ", + " ", CITATION_MMSEQS2 | CITATION_TAXONOMY, {{"queryDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::sequenceDb }, {"targetDB", DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA|DbType::NEED_HEADER, &DbValidator::taxSequenceDb }, {"alignmentDB", 
DbType::ACCESS_MODE_INPUT, DbType::NEED_DATA, &DbValidator::alignmentDb }, - {"outDir", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::directory }}}, + {"abundanceFile", DbType::ACCESS_MODE_OUTPUT, DbType::NEED_DATA, &DbValidator::flatfile }}}, {"summarizealis", summarizealis, &par.threadsandcompression, COMMAND_RESULT, "Summarize alignment result to one row (uniq. cov., cov., avg. seq. id.)", NULL, diff --git a/src/commons/Parameters.cpp b/src/commons/Parameters.cpp index acbccd66a..49c4fa162 100644 --- a/src/commons/Parameters.cpp +++ b/src/commons/Parameters.cpp @@ -314,8 +314,8 @@ Parameters::Parameters(): PARAM_RECLASSIFY_GAMMA(PARAM_RECLASSIFY_GAMMA_ID, "--gamma", "Reclassify gamma", "Exponent applied to coverage confidence during reclassification", typeid(double), (void *) &reclassifyGamma, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), PARAM_RECLASSIFY_MAX_ITER(PARAM_RECLASSIFY_MAX_ITER_ID, "--max-iter", "Reclassify max iterations", "Maximum number of SQUAREM iterations for reclassification", typeid(int), (void *) &reclassifyMaxIterations, "^[1-9]{1}[0-9]*$"), PARAM_RECLASSIFY_TOL(PARAM_RECLASSIFY_TOL_ID, "--tol", "Reclassify tolerance", "Convergence tolerance for reclassification", typeid(double), (void *) &reclassifyTolerance, "^([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?)|([0-9]*(\\.[0-9]+)?)$"), - PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Reclassify taxonomy output", "0: write alignment and protein abundance only; taxonomy files are not required. 
1: also write taxonomy_abundance.tsv and taxonomic columns in protein_abundance.tsv; requires targetDB_mapping and targetDB_taxonomy", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), - PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE_ID, "--max-drop-percentage", "Max drop percentage", "Maximum percentage of targets that the automatic jump-based filter may classify as a tail for dropping (range 0.0-100.0, default 30.0)", typeid(double), (void *) &reclassifyMaxDropPercentage, "^100(\\.0+)?$|^([0-9]|[1-9][0-9])(\\.[0-9]+)?$"), + PARAM_RECLASSIFY_TAXONOMY(PARAM_RECLASSIFY_TAXONOMY_ID, "--taxonomy", "Abundance taxonomy output", "0: write alignment and protein abundance only; taxonomy files are not required. 1: also write taxonomy_abundance.tsv and taxonomic columns in protein_abundance.tsv; requires targetDB_mapping and targetDB_taxonomy", typeid(int), (void *) &reclassifyTaxonomy, "^[0-1]{1}$"), + PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE(PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE_ID, "--drop-percentage", "Max drop percentage", "Maximum percentage of cumulative low-tail abundance mass that the automatic jump-based filter may drop (range 0.0-100.0, default 10.0)", typeid(double), (void *) &reclassifyMaxDropPercentage, "^100(\\.0+)?$|^([0-9]|[1-9][0-9])(\\.[0-9]+)?$"), // for modules that should handle -h themselves PARAM_HELP(PARAM_HELP_ID, "-h", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN), PARAM_HELP_LONG(PARAM_HELP_LONG_ID, "--help", "Help", "Help", typeid(bool), (void *) &help, "", MMseqsParameter::COMMAND_HIDDEN) @@ -349,12 +349,23 @@ Parameters::Parameters(): reclassify.push_back(&PARAM_RECLASSIFY_GAMMA); reclassify.push_back(&PARAM_RECLASSIFY_MAX_ITER); reclassify.push_back(&PARAM_RECLASSIFY_TOL); - reclassify.push_back(&PARAM_RECLASSIFY_TAXONOMY); reclassify.push_back(&PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE); reclassify.push_back(&PARAM_THREADS); reclassify.push_back(&PARAM_COMPRESSED); 
reclassify.push_back(&PARAM_V); + // abundance + abundance.push_back(&PARAM_RECLASSIFY_LAMBDA); + abundance.push_back(&PARAM_RECLASSIFY_ALPHA); + abundance.push_back(&PARAM_RECLASSIFY_GAMMA); + abundance.push_back(&PARAM_RECLASSIFY_MAX_ITER); + abundance.push_back(&PARAM_RECLASSIFY_TOL); + abundance.push_back(&PARAM_RECLASSIFY_TAXONOMY); + abundance.push_back(&PARAM_RECLASSIFY_MAX_DROP_PERCENTAGE); + abundance.push_back(&PARAM_THREADS); + abundance.push_back(&PARAM_COMPRESSED); + abundance.push_back(&PARAM_V); + // createclusearchdb createclusearchdb.push_back(&PARAM_THREADS); createclusearchdb.push_back(&PARAM_COMPRESSED); @@ -2658,13 +2669,13 @@ void Parameters::setDefaults() { unpackNameMode = Parameters::UNPACK_NAME_ACCESSION; // reclassify - reclassifyLambda = 0.02; + reclassifyLambda = 1.0; reclassifyAlpha = 1.0; reclassifyGamma = 1.0; reclassifyMaxIterations = 100; reclassifyTolerance = 1e-5; reclassifyTaxonomy = 0; - reclassifyMaxDropPercentage = 30.0; + reclassifyMaxDropPercentage = 10.0; lcaRanks = ""; showTaxLineage = 0; diff --git a/src/commons/Parameters.h b/src/commons/Parameters.h index 41c67ffd7..c882b8bac 100644 --- a/src/commons/Parameters.h +++ b/src/commons/Parameters.h @@ -1226,6 +1226,7 @@ class Parameters { std::vector gpuserver; std::vector tsv2exprofiledb; std::vector reclassify; + std::vector abundance; std::vector combineList(const std::vector &par1, const std::vector &par2); diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 65562c7f4..656c04941 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -48,7 +48,8 @@ set(util_source_files util/profile2pssm.cpp util/profile2neff.cpp util/profile2seq.cpp - util/reclassify_taxonomy.cpp + util/EM_reclassify.cpp + util/EM_abundnace.cpp util/recoverlongestorf.cpp util/result2dnamsa.cpp util/result2flat.cpp diff --git a/src/util/EM_abundnace.cpp b/src/util/EM_abundnace.cpp new file mode 100644 index 000000000..1294743ee --- /dev/null +++ b/src/util/EM_abundnace.cpp 
@@ -0,0 +1,738 @@ +#include "Parameters.h" +#include "DBReader.h" +#include "Debug.h" +#include "Util.h" +#include "Matcher.h" +#include "FastSort.h" +#include "FileUtil.h" +#include "NcbiTaxonomy.h" +#include "MappingReader.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef OPENMP +#include +#endif + +namespace { +struct ReclassTaxEntry { + Matcher::result_t result; + double abundance; + double posterior; + double coverageConfidence; +}; + +typedef std::unordered_map > MappingTable; + +struct Interval { + int start; + int end; +}; + +struct TargetStats { + unsigned int key; + unsigned int targetLength; + double abundance; + double coverageConfidence; + bool dropped; + std::vector intervals; +}; + +struct ReclassTaxContext { + MappingTable mappingTable; + std::vector queryOrder; + std::unordered_set targetSet; + size_t queryCount; + bool hasBacktrace; + bool hasOrfPosition; + + ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +}; + +static const double EPS = 1e-12; +static const size_t MIN_FILTER_TARGETS = 20; +static const size_t MIN_TAIL_TARGETS = 2; + +static double clamp01(double value); + +static std::vector targetListFromSet(const std::unordered_set &targets) { + std::vector out(targets.begin(), targets.end()); + SORT_SERIAL(out.begin(), out.end()); + return out; +} + +static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { + Debug::Progress progress(reader.getSize()); + const char *entry[255]; + + for (size_t i = 0; i < reader.getSize(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = reader.getDbKey(i); + char *data = reader.getData(i, 0); + + if (reader.getEntryLen(i) <= 1) { + continue; + } + + std::vector &records = ctx.mappingTable[queryKey]; + if (records.empty()) { + ctx.queryOrder.push_back(queryKey); + } + while (*data != '\0') { + const size_t columns = Util::getWordsOfLine(data, entry, 255); + if (columns < 
Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { + Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; + EXIT(EXIT_FAILURE); + } + + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasBacktrace = true; + } + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { + ctx.hasOrfPosition = true; + } + + Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); + + // Always use the 3rd alignment column (seqId) as posterior. + // Do not consume optional trailing columns as posterior. + double posterior = static_cast(result.seqId); + + records.push_back(ReclassTaxEntry{result, 0.0, posterior, 0.0}); + ctx.targetSet.insert(result.dbKey); + data = Util::skipLine(data); + } + } + + ctx.queryCount = ctx.mappingTable.size(); +} + +struct TargetHitRef { + const ReclassTaxEntry *entry; + double score; + double weight; +}; + +static void initCoverageConfidence(MappingTable &mappingTable, + const std::unordered_set &targetSet, + int threads) { + (void)threads; + std::unordered_map targetMin; + std::unordered_map targetMax; + std::unordered_map > hitsByTarget; + + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + targetMin[*it] = std::numeric_limits::max(); + targetMax[*it] = std::numeric_limits::min(); + hitsByTarget.emplace(*it, std::vector()); + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + double scoreSum = 0.0; + for (size_t j = 0; j < it->second.size(); ++j) { + scoreSum += static_cast(it->second[j].result.score); + } + for (size_t j = 0; j < it->second.size(); ++j) { + 
const unsigned int target = it->second[j].result.dbKey; + if (it->second[j].result.dbStartPos < targetMin[target]) { + targetMin[target] = it->second[j].result.dbStartPos; + } + if (it->second[j].result.dbEndPos > targetMax[target]) { + targetMax[target] = it->second[j].result.dbEndPos; + } + const double score = static_cast(it->second[j].result.score); + const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); + } + } + + std::unordered_map coverageFraction; + coverageFraction.reserve(targetSet.size()); + const std::vector targetList = targetListFromSet(targetSet); + std::vector coverageFractionByIndex(targetList.size(), 0.0); + +#pragma omp parallel for num_threads(threads) schedule(dynamic, 1) + for (size_t i = 0; i < targetList.size(); ++i) { + const unsigned int target = targetList[i]; + const int startPos = targetMin[target]; + const int endPos = targetMax[target]; + const int len = (endPos >= startPos) ? 
(endPos - startPos + 1) : 1; + std::vector cov(static_cast(len), 0.0); + std::vector covConf(static_cast(len), 0.0); + + std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); + if (hitIt != hitsByTarget.end()) { + const std::vector &hits = hitIt->second; + for (size_t h = 0; h < hits.size(); ++h) { + const Matcher::result_t &result = hits[h].entry->result; + const int targetLen = result.dbEndPos - result.dbStartPos + 1; + if (targetLen <= 0) { + continue; + } + + const double mq = hits[h].score / static_cast(targetLen); + const int start = std::max(0, result.dbStartPos - startPos); + const int end = std::min(len - 1, result.dbEndPos - startPos); + for (int pos = start; pos <= end; ++pos) { + cov[static_cast(pos)] += mq; + covConf[static_cast(pos)] += hits[h].weight; + } + } + } + + double covered = 0.0; + double squaredCovered = 0.0; + for (size_t pos = 0; pos < covConf.size(); ++pos) { + const double clipped = std::min(1.0, covConf[pos]); + covered += clipped; + squaredCovered += clipped * clipped; + } + const double fraction = covered / static_cast(len); + const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; + const double concentrationPenalty = 1.0 - hhi; + coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); + } + + for (size_t i = 0; i < targetList.size(); ++i) { + coverageFraction[targetList[i]] = coverageFractionByIndex[i]; + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const unsigned int target = it->second[j].result.dbKey; + std::unordered_map::const_iterator cf = coverageFraction.find(target); + it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? 
cf->second : 0.0; + } + } +} + +static void computeAbundanceFromPosterior(MappingTable &mappingTable, + const std::unordered_set &targetSet, + size_t queryCount) { + std::unordered_map abundance; + abundance.reserve(targetSet.size()); + for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { + abundance[*it] = 0.0; + } + + for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + abundance[it->second[j].result.dbKey] += it->second[j].posterior; + } + } + + if (queryCount > 0) { + const double denom = static_cast(queryCount); + for (std::unordered_map::iterator it = abundance.begin(); it != abundance.end(); ++it) { + it->second /= denom; + } + } + + for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + it->second[j].abundance = abundance[it->second[j].result.dbKey]; + } + } +} + +static void addInterval(std::vector &intervals, int start, int end) { + Interval interval; + interval.start = std::min(start, end); + interval.end = std::max(start, end); + intervals.push_back(interval); +} + +static std::vector mergeIntervals(std::vector intervals) { + if (intervals.empty()) { + return intervals; + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.end < rhs.end; + }); + + std::vector merged; + merged.push_back(intervals[0]); + for (size_t i = 1; i < intervals.size(); ++i) { + if (intervals[i].start <= merged.back().end + 1) { + merged.back().end = std::max(merged.back().end, intervals[i].end); + } else { + merged.push_back(intervals[i]); + } + } + return merged; +} + +static unsigned int intervalCoverage(const std::vector &intervals) { + unsigned int covered = 0; + for (size_t i = 0; i < intervals.size(); ++i) { + covered += 
static_cast(intervals[i].end - intervals[i].start + 1); + } + return covered; +} + +static std::vector collectTargetStats(const ReclassTaxContext &ctx) { + std::unordered_map statsByTarget; + + for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { + for (size_t j = 0; j < it->second.size(); ++j) { + const ReclassTaxEntry &entry = it->second[j]; + TargetStats &stats = statsByTarget[entry.result.dbKey]; + stats.key = entry.result.dbKey; + stats.targetLength = entry.result.dbLen; + stats.abundance = entry.abundance; + stats.coverageConfidence = entry.coverageConfidence; + stats.dropped = false; + addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); + } + } + + std::vector out; + out.reserve(statsByTarget.size()); + for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { + it->second.intervals = mergeIntervals(it->second.intervals); + out.push_back(it->second); + } + + std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { + if (lhs.abundance != rhs.abundance) { + return lhs.abundance > rhs.abundance; + } + return lhs.key < rhs.key; + }); + return out; +} + +static double clamp01(double value) { + return std::max(0.0, std::min(1.0, value)); +} + +static bool largestJumpCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double bestGap = 0.0; + size_t bestIdx = 0; + double lowTailMass = 0.0; + for (size_t i = 0; i + 1 < values.size(); ++i) { + const size_t lowTailCount = i 
+ 1; + const size_t highTailCount = values.size() - lowTailCount; + const size_t candidateTailCount = useLowTail ? lowTailCount : highTailCount; + lowTailMass += values[i]; + const double highTailMass = totalMass - lowTailMass; + const double candidateTailMass = useLowTail ? lowTailMass : highTailMass; + if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailMass > (maxTailMass + EPS)) { + continue; + } + + const double gap = values[i + 1] - values[i]; + if (gap > bestGap) { + bestGap = gap; + bestIdx = i; + tailCount = candidateTailCount; + } + } + + if (bestGap <= EPS) { + return false; + } + + cutoff = 0.5 * (values[bestIdx] + values[bestIdx + 1]); + return true; +} + +static bool tailQuantileCutoff(std::vector values, + bool useLowTail, + double maxTailFraction, + double &cutoff, + size_t &tailCount) { + cutoff = 0.0; + tailCount = 0; + if (values.size() < MIN_FILTER_TARGETS) { + return false; + } + + std::sort(values.begin(), values.end()); + maxTailFraction = clamp01(maxTailFraction); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double accumulatedMass = 0.0; + size_t maxTailCount = 0; + for (size_t i = 0; i < values.size(); ++i) { + const double candidate = accumulatedMass + values[i]; + if (candidate > (maxTailMass + EPS)) { + break; + } + accumulatedMass = candidate; + ++maxTailCount; + } + if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { + return false; + } + + tailCount = maxTailCount; + if (useLowTail) { + cutoff = values[tailCount - 1]; + } else { + cutoff = values[values.size() - tailCount]; + } + return true; +} + +static std::unordered_set selectTailTargets(const std::vector &stats, + bool useLowTail, + size_t tailCount, + double maxTailFraction) { + std::vector ordered; + ordered.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + 
ordered.push_back(&stats[i]); + } + + std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { + const double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; + const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; + if (lhsValue != rhsValue) { + return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); + } + return lhs->key < rhs->key; + }); + + double totalMass = 0.0; + for (size_t i = 0; i < ordered.size(); ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + totalMass += value; + } + const double maxTailMass = clamp01(maxTailFraction) * totalMass; + + std::unordered_set selected; + const size_t limit = std::min(tailCount, ordered.size()); + double selectedMass = 0.0; + selected.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { + break; + } + selectedMass += value; + selected.insert(ordered[i]->key); + } + return selected; +} + +static std::unordered_set selectDroppedTargets(const std::vector &stats, + double maxDropPercentage, + double &abundanceCutoff) { + std::unordered_set dropped; + if (stats.empty()) { + abundanceCutoff = 0.0; + return dropped; + } + if (stats.size() < MIN_FILTER_TARGETS) { + abundanceCutoff = 0.0; + return dropped; + } + + std::vector abundances; + abundances.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + abundances.push_back(stats[i].abundance); + } + + const double maxTailFraction = clamp01(maxDropPercentage / 100.0); + size_t abundanceTailCount = 0; + bool hasAbundanceCutoff = largestJumpCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); + if (hasAbundanceCutoff == false) { + hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, 
abundanceCutoff, abundanceTailCount); + } + if (hasAbundanceCutoff == false) { + abundanceCutoff = 0.0; + return dropped; + } + + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); + for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { + dropped.insert(*it); + } + if (dropped.size() == stats.size()) { + dropped.clear(); + } + return dropped; +} + +static void markDroppedTargets(std::vector &stats, const std::unordered_set &dropped) { + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].dropped = (dropped.find(stats[i].key) != dropped.end()); + } +} + +static void convertAbundanceToPercent(std::vector &stats) { + double total = 0.0; + for (size_t i = 0; i < stats.size(); ++i) { + total += stats[i].abundance; + } + + if (total <= 0.0) { + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].abundance = 0.0; + } + return; + } + + for (size_t i = 0; i < stats.size(); ++i) { + stats[i].abundance = 100.0 * (stats[i].abundance / total); + } +} + +static const char *headerForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { + size_t id = headerReader.getId(key); + if (id == UINT_MAX) { + return NULL; + } + return headerReader.getData(id, threadIdx); +} + +static std::string identifierForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { + const char *header = headerForKey(headerReader, key, threadIdx); + if (header == NULL) { + return SSTR(key); + } + std::string parsed = Util::parseFastaHeader(header); + return parsed.empty() ? 
SSTR(key) : parsed; +} + +static void writeProteinStats(const std::vector &stats, + DBReader &targetHeaderReader, + const std::string &path) { + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + fputs("target_key\ttarget_id\tabundance_pct\tcoverage_confidence\tDrop(y/n)\tmapped_length\ttarget_length\n", handle); + + for (size_t i = 0; i < stats.size(); ++i) { + const unsigned int key = stats[i].key; + const std::string targetId = identifierForKey(targetHeaderReader, key, 0); + const unsigned int mappedLength = intervalCoverage(stats[i].intervals); + + fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\n", + key, + targetId.c_str(), + stats[i].abundance, + stats[i].coverageConfidence, + stats[i].dropped ? "y" : "n", + mappedLength, + stats[i].targetLength); + } + + fclose(handle); +} + +static void writeKrakenReport(const std::vector &stats, + MappingReader &mapping, + NcbiTaxonomy *taxonomy, + size_t queryCount, + const std::string &path) { + std::unordered_map directCounts; + directCounts.reserve(stats.size()); + + for (size_t i = 0; i < stats.size(); ++i) { + const TaxID taxId = mapping.lookup(stats[i].key); + if (taxId == 0) { + continue; + } + const double expectedReads = stats[i].abundance * static_cast(queryCount) / 100.0; + directCounts[taxId] += static_cast(std::floor(expectedReads + 0.5)); + } + + const std::unordered_map > parentToChildren = taxonomy->getParentToChildren(); + const std::unordered_map cladeCounts = taxonomy->getCladeCounts(directCounts, parentToChildren); + + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + const double totalReads = (queryCount > 0) ? 
static_cast(queryCount) : 1.0; + + std::vector stack; + std::vector depthStack; + stack.push_back(1); + depthStack.push_back(0); + + while (!stack.empty()) { + TaxID taxId = stack.back(); + stack.pop_back(); + int depth = depthStack.back(); + depthStack.pop_back(); + + unsigned int cladeCount = 0; + unsigned int directCount = 0; + std::unordered_map::const_iterator it = cladeCounts.find(taxId); + if (it != cladeCounts.end()) { + cladeCount = it->second.cladeCount; + directCount = it->second.taxCount; + } + + if (cladeCount > 0) { + const TaxonNode *node = taxonomy->taxonNode(taxId, false); + const char *rankStr = (node != NULL) ? taxonomy->getString(node->rankIdx) : NULL; + char rankCode = '-'; + if (rankStr != NULL) { + std::map::const_iterator rankIt = NcbiShortRanks.find(std::string(rankStr)); + if (rankIt != NcbiShortRanks.end()) { + rankCode = rankIt->second; + } + } + const char *name = (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified"; + const double pct = 100.0 * static_cast(cladeCount) / totalReads; + + for (int i = 0; i < depth; ++i) { + fputs(" ", handle); + } + fprintf(handle, "%.4f\t%u\t%u\t%c\t%u\t%s\n", + pct, cladeCount, directCount, rankCode, static_cast(taxId), name); + } + + std::unordered_map::const_iterator childIt = cladeCounts.find(taxId); + if (childIt != cladeCounts.end()) { + const std::vector &children = childIt->second.children; + for (size_t i = 0; i < children.size(); ++i) { + stack.push_back(children[i]); + depthStack.push_back(depth + 1); + } + } + } + + fclose(handle); +} + +static void writeBrackenReport(const std::vector &stats, + MappingReader &mapping, + NcbiTaxonomy *taxonomy, + size_t queryCount, + const std::string &path) { + std::unordered_map directCounts; + directCounts.reserve(stats.size()); + + for (size_t i = 0; i < stats.size(); ++i) { + const TaxID taxId = mapping.lookup(stats[i].key); + if (taxId == 0) { + continue; + } + const double expectedReads = stats[i].abundance * static_cast(queryCount) / 
100.0; + directCounts[taxId] += static_cast(std::floor(expectedReads + 0.5)); + } + + const std::unordered_map > parentToChildren = taxonomy->getParentToChildren(); + const std::unordered_map cladeCounts = taxonomy->getCladeCounts(directCounts, parentToChildren); + + FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); + fputs("name\ttaxonomy_id\ttaxonomy_lvl\tkraken_assigned_reads\tadded_reads\tnew_est_reads\tfraction_total_reads\n", handle); + + const double totalReads = (queryCount > 0) ? static_cast(queryCount) : 1.0; + for (std::unordered_map::const_iterator it = cladeCounts.begin(); it != cladeCounts.end(); ++it) { + const TaxID taxId = it->first; + const TaxonCounts &counts = it->second; + if (counts.cladeCount == 0) { + continue; + } + + const TaxonNode *node = taxonomy->taxonNode(taxId, false); + const char *rankStr = (node != NULL) ? taxonomy->getString(node->rankIdx) : "-"; + const char *name = (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified"; + const unsigned int cladeCount = counts.cladeCount; + const unsigned int directCount = counts.taxCount; + const unsigned int addedReads = (cladeCount >= directCount) ? 
(cladeCount - directCount) : 0; + const double fraction = static_cast(cladeCount) / totalReads; + + fprintf(handle, "%s\t%u\t%s\t%u\t%u\t%u\t%.12g\n", + name, + static_cast(taxId), + rankStr, + directCount, + addedReads, + cladeCount, + fraction); + } + + fclose(handle); +} +} + +int emabundance(int argc, const char **argv, const Command &command) { + Parameters &par = Parameters::getInstance(); + par.parseParameters(argc, argv, command, true, 0, 0); + + DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + reader.open(DBReader::LINEAR_ACCCESS); + + DBReader targetHeaderReader((par.db2 + "_h").c_str(), (par.db2 + "_h.index").c_str(), par.threads, + DBReader::USE_INDEX | DBReader::USE_DATA); + targetHeaderReader.open(DBReader::NOSORT); + + const bool withTaxonomy = (par.reclassifyTaxonomy == 1); + NcbiTaxonomy *taxonomy = NULL; + MappingReader *mapping = NULL; + if (withTaxonomy) { + taxonomy = NcbiTaxonomy::openTaxonomy(par.db2); + mapping = new MappingReader(par.db2); + } + + ReclassTaxContext ctx; + loadAlignmentDb(reader, ctx); + Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + + initCoverageConfidence(ctx.mappingTable, ctx.targetSet, par.threads); + computeAbundanceFromPosterior(ctx.mappingTable, ctx.targetSet, ctx.queryCount); + + std::vector allTargetStats = collectTargetStats(ctx); + double abundanceCutoff = 0.0; + const std::unordered_set dropped = selectDroppedTargets(allTargetStats, + par.reclassifyMaxDropPercentage, + abundanceCutoff); + markDroppedTargets(allTargetStats, dropped); + convertAbundanceToPercent(allTargetStats); + + if (withTaxonomy) { + std::vector targetStats = allTargetStats; + targetStats.erase(std::remove_if(targetStats.begin(), targetStats.end(), [](const TargetStats &entry) { + return entry.dropped; + }), targetStats.end()); + writeKrakenReport(targetStats, *mapping, taxonomy, ctx.queryCount, 
par.db4); + writeBrackenReport(targetStats, *mapping, taxonomy, ctx.queryCount, par.db4 + ".bracken"); + } else { + writeProteinStats(allTargetStats, targetHeaderReader, par.db4); + } + + delete mapping; + delete taxonomy; + targetHeaderReader.close(); + reader.close(); + return EXIT_SUCCESS; +} diff --git a/src/util/reclassify_taxonomy.cpp b/src/util/EM_reclassify2.cpp similarity index 65% rename from src/util/reclassify_taxonomy.cpp rename to src/util/EM_reclassify2.cpp index 3e03fc710..b3a4e1d56 100644 --- a/src/util/reclassify_taxonomy.cpp +++ b/src/util/EM_reclassify2.cpp @@ -1,16 +1,13 @@ #include "Parameters.h" #include "DBReader.h" +#include "DBWriter.h" #include "Debug.h" #include "Util.h" #include "Matcher.h" #include "FastSort.h" -#include "FileUtil.h" -#include "NcbiTaxonomy.h" -#include "MappingReader.h" #include #include -#include #include #include #include @@ -45,15 +42,6 @@ struct TargetStats { std::vector intervals; }; -struct TaxonomyStats { - unsigned int taxId; - double abundance; - double coverageConfidenceSum; - size_t proteinCount; - - TaxonomyStats() : taxId(0), abundance(0.0), coverageConfidenceSum(0.0), proteinCount(0) {} -}; - struct ReclassTaxContext { MappingTable mappingTable; std::vector queryOrder; @@ -103,10 +91,12 @@ static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &c EXIT(EXIT_FAILURE); } - if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { + if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { ctx.hasBacktrace = true; } - if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT) { + if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + || columns == 
Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { ctx.hasOrfPosition = true; } @@ -156,30 +146,28 @@ static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, - double lambda, int threads) { + (void)threads; std::unordered_map targetMin; std::unordered_map targetMax; - std::unordered_map targetLenMap; std::unordered_map > hitsByTarget; for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { targetMin[*it] = std::numeric_limits::max(); targetMax[*it] = std::numeric_limits::min(); - targetLenMap[*it] = 0; hitsByTarget.emplace(*it, std::vector()); } for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { - double scoreSumExp = 0.0; + double scoreSum = 0.0; for (size_t j = 0; j < it->second.size(); ++j) { - scoreSumExp += std::exp(lambda * static_cast(it->second[j].result.score)); + scoreSum += static_cast(it->second[j].result.score); } for (size_t j = 0; j < it->second.size(); ++j) { const unsigned int target = it->second[j].result.dbKey; @@ -189,12 +177,9 @@ static void initCoverageConfidence(MappingTable &mappingTable, if (it->second[j].result.dbEndPos > targetMax[target]) { targetMax[target] = it->second[j].result.dbEndPos; } - if (static_cast(it->second[j].result.dbLen) > targetLenMap[target]) { - targetLenMap[target] = static_cast(it->second[j].result.dbLen); - } - const double expScore = std::exp(lambda * static_cast(it->second[j].result.score)); - const double weight = (scoreSumExp > 0.0) ? (expScore / scoreSumExp) : 0.0; - hitsByTarget[target].push_back(TargetHitRef{&it->second[j], expScore, weight}); + const double score = static_cast(it->second[j].result.score); + const double weight = (scoreSum > 0.0) ? 
(score / scoreSum) : 0.0; + hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); } } @@ -222,7 +207,7 @@ static void initCoverageConfidence(MappingTable &mappingTable, continue; } - const double mq = hits[h].expScore / static_cast(targetLen); + const double mq = hits[h].score / static_cast(targetLen); const int start = std::max(0, result.dbStartPos - startPos); const int end = std::min(len - 1, result.dbEndPos - startPos); for (int pos = start; pos <= end; ++pos) { @@ -233,12 +218,16 @@ static void initCoverageConfidence(MappingTable &mappingTable, } double covered = 0.0; + double squaredCovered = 0.0; for (size_t pos = 0; pos < covConf.size(); ++pos) { - covered += std::min(1.0, covConf[pos]); + const double clipped = std::min(1.0, covConf[pos]); + covered += clipped; + squaredCovered += clipped * clipped; } - const unsigned int targetLen = (targetLenMap[target] > 0) ? targetLenMap[target] : 1; - const double fraction = covered / static_cast(targetLen); - coverageFractionByIndex[i] = clamp01(fraction); + const double fraction = covered / static_cast(len); + const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; + const double concentrationPenalty = 1.0 - hhi; + coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); } for (size_t i = 0; i < targetList.size(); ++i) { @@ -421,8 +410,8 @@ static void squarem(ReclassTaxContext &ctx, } initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); - initCoverageConfidence(ctx.mappingTable, ctx.targetSet, lambda, threads); - Debug(Debug::INFO) << "Reclassify-taxonomy initialized coverage confidence." << "\n"; + initCoverageConfidence(ctx.mappingTable, ctx.targetSet, threads); + Debug(Debug::INFO) << "Reclassify initialized coverage confidence." 
<< "\n"; std::unordered_map fixedCoverageConfidence; fixedCoverageConfidence.reserve(ctx.targetSet.size()); @@ -483,77 +472,15 @@ static void squarem(ReclassTaxContext &ctx, parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); } - Debug(Debug::INFO) << "Reclassify-taxonomy iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; + Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; x0 = xNew; if (parameterChange < tol && iter > 5) { - Debug(Debug::INFO) << "Reclassify-taxonomy converged after " << (iter + 1) << " iterations.\n"; + Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations." << "\n"; break; } } } -static bool compareByPosterior(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { - if (a.posterior != b.posterior) { - return a.posterior > b.posterior; - } - return Matcher::compareHits(a.result, b.result); -} - -static const char *headerForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { - size_t id = headerReader.getId(key); - if (id == UINT_MAX) { - return NULL; - } - return headerReader.getData(id, threadIdx); -} - -static std::string identifierForKey(DBReader &headerReader, unsigned int key, unsigned int threadIdx) { - const char *header = headerForKey(headerReader, key, threadIdx); - if (header == NULL) { - return SSTR(key); - } - std::string parsed = Util::parseFastaHeader(header); - return parsed.empty() ? 
SSTR(key) : parsed; -} - -static void computeAlignmentCounts(const Matcher::result_t &res, unsigned int &alnLen, unsigned int &mismatchCount, unsigned int &gapOpenCount) { - gapOpenCount = 0; - alnLen = res.alnLength; - mismatchCount = 0; - - if (!res.backtrace.empty()) { - size_t matchCount = 0; - alnLen = 0; - for (size_t pos = 0; pos < res.backtrace.size(); ++pos) { - int cnt = 0; - if (std::isdigit(static_cast(res.backtrace[pos]))) { - cnt += Util::fast_atoi(res.backtrace.c_str() + pos); - while (std::isdigit(static_cast(res.backtrace[pos]))) { - pos++; - } - } - alnLen += cnt; - - switch (res.backtrace[pos]) { - case 'M': - matchCount += cnt; - break; - case 'D': - case 'I': - gapOpenCount += 1; - break; - } - } - const unsigned int identical = static_cast(res.seqId * static_cast(alnLen) + 0.5f); - mismatchCount = static_cast(matchCount - identical); - } else { - const int adjustQstart = (res.qStartPos == -1) ? 0 : res.qStartPos; - const int adjustDBstart = (res.dbStartPos == -1) ? 
0 : res.dbStartPos; - const float bestMatchEstimate = static_cast(std::min(abs(res.qEndPos - adjustQstart), abs(res.dbEndPos - adjustDBstart))); - mismatchCount = static_cast(bestMatchEstimate * (1.0f - res.seqId) + 0.5f); - } -} - static void addInterval(std::vector &intervals, int start, int end) { Interval interval; interval.start = std::min(start, end); @@ -585,27 +512,6 @@ static std::vector mergeIntervals(std::vector intervals) { return merged; } -static std::string intervalsToString(const std::vector &intervals) { - std::string out; - for (size_t i = 0; i < intervals.size(); ++i) { - if (i > 0) { - out.append(","); - } - out.append(SSTR(intervals[i].start + 1)); - out.append(":"); - out.append(SSTR(intervals[i].end + 1)); - } - return out; -} - -static unsigned int intervalCoverage(const std::vector &intervals) { - unsigned int covered = 0; - for (size_t i = 0; i < intervals.size(); ++i) { - covered += static_cast(intervals[i].end - intervals[i].start + 1); - } - return covered; -} - static std::vector collectTargetStats(const ReclassTaxContext &ctx) { std::unordered_map statsByTarget; @@ -655,15 +561,23 @@ static bool largestJumpCutoff(std::vector values, std::sort(values.begin(), values.end()); maxTailFraction = clamp01(maxTailFraction); - const size_t maxTailCount = std::max(MIN_TAIL_TARGETS, - static_cast(std::floor(maxTailFraction * static_cast(values.size())))); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + double bestGap = 0.0; size_t bestIdx = 0; + double lowTailMass = 0.0; for (size_t i = 0; i + 1 < values.size(); ++i) { const size_t lowTailCount = i + 1; const size_t highTailCount = values.size() - lowTailCount; const size_t candidateTailCount = useLowTail ? 
lowTailCount : highTailCount; - if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailCount > maxTailCount) { + lowTailMass += values[i]; + const double highTailMass = totalMass - lowTailMass; + const double candidateTailMass = useLowTail ? lowTailMass : highTailMass; + if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailMass > (maxTailMass + EPS)) { continue; } @@ -696,8 +610,22 @@ static bool tailQuantileCutoff(std::vector values, std::sort(values.begin(), values.end()); maxTailFraction = clamp01(maxTailFraction); - const size_t maxTailCount = std::max(MIN_TAIL_TARGETS, - static_cast(std::floor(maxTailFraction * static_cast(values.size())))); + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + if (totalMass <= EPS || maxTailFraction <= 0.0) { + return false; + } + const double maxTailMass = maxTailFraction * totalMass; + + double accumulatedMass = 0.0; + size_t maxTailCount = 0; + for (size_t i = 0; i < values.size(); ++i) { + const double candidate = accumulatedMass + values[i]; + if (candidate > (maxTailMass + EPS)) { + break; + } + accumulatedMass = candidate; + ++maxTailCount; + } if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { return false; } @@ -713,7 +641,8 @@ static bool tailQuantileCutoff(std::vector values, static std::unordered_set selectTailTargets(const std::vector &stats, bool useLowTail, - size_t tailCount) { + size_t tailCount, + double maxTailFraction) { std::vector ordered; ordered.reserve(stats.size()); for (size_t i = 0; i < stats.size(); ++i) { @@ -729,10 +658,23 @@ static std::unordered_set selectTailTargets(const std::vectorkey < rhs->key; }); + double totalMass = 0.0; + for (size_t i = 0; i < ordered.size(); ++i) { + const double value = useLowTail ? 
ordered[i]->abundance : ordered[i]->coverageConfidence; + totalMass += value; + } + const double maxTailMass = clamp01(maxTailFraction) * totalMass; + std::unordered_set selected; const size_t limit = std::min(tailCount, ordered.size()); + double selectedMass = 0.0; selected.reserve(limit); for (size_t i = 0; i < limit; ++i) { + const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; + if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { + break; + } + selectedMass += value; selected.insert(ordered[i]->key); } return selected; @@ -768,7 +710,7 @@ static std::unordered_set selectDroppedTargets(const std::vector lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount); + const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { dropped.insert(*it); } @@ -783,7 +725,7 @@ static void applyDroppedTargets(ReclassTaxContext &ctx, size_t totalTargets, double abundanceCutoff) { if (dropped.empty()) { - Debug(Debug::INFO) << "Reclassify-taxonomy target filter kept all targets. abundance cutoff=" + Debug(Debug::INFO) << "Reclassify target filter kept all targets. abundance cutoff=" << abundanceCutoff << "\n"; return; } @@ -812,176 +754,68 @@ static void applyDroppedTargets(ReclassTaxContext &ctx, const double removedPct = (totalTargets > 0) ? 
(100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) : 0.0; - Debug(Debug::INFO) << "Reclassify-taxonomy dropped " << dropped.size() + Debug(Debug::INFO) << "Reclassify dropped " << dropped.size() << " of " << totalTargets << " targets (" << removedPct << "%)" << " using abundance <= " << abundanceCutoff << ".\n"; } -static void markDroppedTargets(std::vector &stats, const std::unordered_set &dropped) { - for (size_t i = 0; i < stats.size(); ++i) { - stats[i].dropped = (dropped.find(stats[i].key) != dropped.end()); - } -} - -static void convertAbundanceToPercent(std::vector &stats) { - double total = 0.0; - for (size_t i = 0; i < stats.size(); ++i) { - total += stats[i].abundance; - } - - if (total <= 0.0) { - for (size_t i = 0; i < stats.size(); ++i) { - stats[i].abundance = 0.0; - } - return; +static bool compareByPosteriorThenBitScore(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { + if (a.posterior != b.posterior) { + return a.posterior > b.posterior; } - - for (size_t i = 0; i < stats.size(); ++i) { - stats[i].abundance = 100.0 * (stats[i].abundance / total); + if (a.result.score != b.result.score) { + return a.result.score > b.result.score; } + return Matcher::compareHits(a.result, b.result); } -static void writeReclassifiedM8(const ReclassTaxContext &ctx, - DBReader &queryHeaderReader, - DBReader &targetHeaderReader, - const std::string &path) { - FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - char line[4096]; - - for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { - const unsigned int queryKey = ctx.queryOrder[i]; - MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); - if (recordsIt == ctx.mappingTable.end()) { - continue; - } +static void writeReclassifiedDb(const ReclassTaxContext &ctx, + int dbType, + const std::string &outDb, + const std::string &outIndex, + int threads, + bool compress) { + DBWriter writer(outDb.c_str(), outIndex.c_str(), threads, compress, dbType); + writer.open(); - 
std::string queryId = identifierForKey(queryHeaderReader, queryKey, 0); - std::vector records = recordsIt->second; - SORT_SERIAL(records.begin(), records.end(), compareByPosterior); - - for (size_t j = 0; j < records.size(); ++j) { - const Matcher::result_t &res = records[j].result; - const std::string targetId = identifierForKey(targetHeaderReader, res.dbKey, 0); - - unsigned int alnLen = 0; - unsigned int mismatchCount = 0; - unsigned int gapOpenCount = 0; - computeAlignmentCounts(res, alnLen, mismatchCount, gapOpenCount); - - const int written = snprintf(line, sizeof(line), - "%s\t%s\t%1.3f\t%u\t%u\t%u\t%d\t%d\t%d\t%d\t%.2E\t%d\n", - queryId.c_str(), targetId.c_str(), res.seqId, alnLen, - mismatchCount, gapOpenCount, - res.qStartPos + 1, res.qEndPos + 1, - res.dbStartPos + 1, res.dbEndPos + 1, - res.eval, res.score); - if (written < 0 || static_cast(written) >= sizeof(line)) { - Debug(Debug::WARNING) << "Truncated M8 line for query " << queryKey << " and target " << res.dbKey << ".\n"; + Debug::Progress progress(ctx.queryOrder.size()); +#pragma omp parallel + { + unsigned int thread_idx = 0; +#ifdef OPENMP + thread_idx = static_cast(omp_get_thread_num()); +#endif + char buffer[1024 + 32768 * 4]; + +#pragma omp for schedule(dynamic, 5) + for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { + progress.updateProgress(); + const unsigned int queryKey = ctx.queryOrder[i]; + MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); + if (recordsIt == ctx.mappingTable.end()) { continue; } - fputs(line, handle); - } - } - - fclose(handle); -} -static void writeProteinStats(const std::vector &stats, - DBReader &targetHeaderReader, - MappingReader *mapping, - NcbiTaxonomy *taxonomy, - const std::string &path) { - FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - const bool withTaxonomy = (mapping != NULL && taxonomy != NULL); - if (withTaxonomy) { - 
fputs("target_key\ttarget_id\tabundance_pct\tcoverage_confidence\tDrop(y/n)\tmapped_length\ttarget_length\ttaxid\trank\ttaxname\ttaxlineage\n", handle); - } else { - fputs("target_key\ttarget_id\tabundance_pct\tcoverage_confidence\tDrop(y/n)\tmapped_length\ttarget_length\n", handle); - } + std::vector records = recordsIt->second; + SORT_SERIAL(records.begin(), records.end(), compareByPosteriorThenBitScore); - for (size_t i = 0; i < stats.size(); ++i) { - const unsigned int key = stats[i].key; - const std::string targetId = identifierForKey(targetHeaderReader, key, 0); - const unsigned int mappedLength = intervalCoverage(stats[i].intervals); - - if (withTaxonomy) { - const unsigned int taxId = mapping->lookup(key); - const TaxonNode *node = (taxId != 0) ? taxonomy->taxonNode(taxId, false) : NULL; - const std::string lineage = (node != NULL) ? taxonomy->taxLineage(node, true) : "unclassified"; - fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\t%u\t%s\t%s\t%s\n", - key, - targetId.c_str(), - stats[i].abundance, - stats[i].coverageConfidence, - stats[i].dropped ? "y" : "n", - mappedLength, - stats[i].targetLength, - taxId, - (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", - (node != NULL) ? taxonomy->getString(node->nameIdx) : "unclassified", - lineage.c_str()); - } else { - fprintf(handle, "%u\t%s\t%.12g\t%.12g\t%s\t%u\t%u\n", - key, - targetId.c_str(), - stats[i].abundance, - stats[i].coverageConfidence, - stats[i].dropped ? 
"y" : "n", - mappedLength, - stats[i].targetLength); - } - } - - fclose(handle); -} - -static void writeTaxonomyStats(const std::vector &stats, - MappingReader &mapping, - NcbiTaxonomy *taxonomy, - const std::string &path) { - std::unordered_map aggregated; - for (size_t i = 0; i < stats.size(); ++i) { - const unsigned int taxId = mapping.lookup(stats[i].key); - TaxonomyStats &entry = aggregated[taxId]; - entry.taxId = taxId; - entry.abundance += stats[i].abundance; - entry.coverageConfidenceSum += stats[i].coverageConfidence; - entry.proteinCount += 1; - } - - std::vector rows; - rows.reserve(aggregated.size()); - for (std::unordered_map::const_iterator it = aggregated.begin(); it != aggregated.end(); ++it) { - rows.push_back(it->second); - } - std::sort(rows.begin(), rows.end(), [](const TaxonomyStats &lhs, const TaxonomyStats &rhs) { - if (lhs.abundance != rhs.abundance) { - return lhs.abundance > rhs.abundance; + writer.writeStart(thread_idx); + for (size_t j = 0; j < records.size(); ++j) { + Matcher::result_t res = records[j].result; + res.seqId = static_cast(records[j].posterior); + size_t len = Matcher::resultToBuffer(buffer, res, ctx.hasBacktrace, ctx.hasOrfPosition); + writer.writeAdd(buffer, len, thread_idx); + } + writer.writeEnd(queryKey, thread_idx); } - return lhs.taxId < rhs.taxId; - }); - - FILE *handle = FileUtil::openFileOrDie(path.c_str(), "w", false); - fputs("taxid\trank\ttaxname\ttaxlineage\tprotein_abundance_pct\tprotein_count\n", handle); - - for (size_t i = 0; i < rows.size(); ++i) { - const TaxonNode *node = (rows[i].taxId != 0) ? taxonomy->taxonNode(rows[i].taxId, false) : NULL; - const std::string lineage = (node != NULL) ? taxonomy->taxLineage(node, true) : "unclassified"; - fprintf(handle, "%u\t%s\t%s\t%s\t%.12g\t%zu\n", - rows[i].taxId, - (node != NULL) ? taxonomy->getString(node->rankIdx) : "unclassified", - (node != NULL) ? 
taxonomy->getString(node->nameIdx) : "unclassified", - lineage.c_str(), - rows[i].abundance, - rows[i].proteinCount); } - fclose(handle); + writer.close(); } } -int reclassifytaxonomy(int argc, const char **argv, const Command &command) { +int emreclassify(int argc, const char **argv, const Command &command) { Parameters &par = Parameters::getInstance(); par.parseParameters(argc, argv, command, true, 0, 0); @@ -989,27 +823,11 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { DBReader::USE_INDEX | DBReader::USE_DATA); reader.open(DBReader::LINEAR_ACCCESS); - DBReader queryHeaderReader((par.db1 + "_h").c_str(), (par.db1 + "_h.index").c_str(), par.threads, - DBReader::USE_INDEX | DBReader::USE_DATA); - queryHeaderReader.open(DBReader::NOSORT); - - DBReader targetHeaderReader((par.db2 + "_h").c_str(), (par.db2 + "_h.index").c_str(), par.threads, - DBReader::USE_INDEX | DBReader::USE_DATA); - targetHeaderReader.open(DBReader::NOSORT); - - const bool withTaxonomy = (par.reclassifyTaxonomy == 1); - NcbiTaxonomy *taxonomy = NULL; - MappingReader *mapping = NULL; - if (withTaxonomy) { - taxonomy = NcbiTaxonomy::openTaxonomy(par.db2); - mapping = new MappingReader(par.db2); - } - ReclassTaxContext ctx; loadAlignmentDb(reader, ctx); Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; - squarem(ctx, + squarem(ctx, par.reclassifyLambda, par.reclassifyMaxIterations, par.reclassifyTolerance, @@ -1017,36 +835,15 @@ int reclassifytaxonomy(int argc, const char **argv, const Command &command) { par.reclassifyGamma, par.threads); - const std::string outDir = par.db4; - const std::string m8Path = outDir + "/new_alignment_result.m8"; - const std::string proteinPath = outDir + "/protein_abundance.tsv"; - const std::string taxonomyPath = outDir + "/taxonomy_abundance.tsv"; - std::vector allTargetStats = collectTargetStats(ctx); double abundanceCutoff = 0.0; const std::unordered_set 
dropped = selectDroppedTargets(allTargetStats, par.reclassifyMaxDropPercentage, abundanceCutoff); - markDroppedTargets(allTargetStats, dropped); - convertAbundanceToPercent(allTargetStats); - - std::vector targetStats = allTargetStats; - targetStats.erase(std::remove_if(targetStats.begin(), targetStats.end(), [](const TargetStats &entry) { - return entry.dropped; - }), targetStats.end()); applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); - convertAbundanceToPercent(targetStats); - writeReclassifiedM8(ctx, queryHeaderReader, targetHeaderReader, m8Path); - writeProteinStats(allTargetStats, targetHeaderReader, mapping, taxonomy, proteinPath); - if (withTaxonomy) { - writeTaxonomyStats(targetStats, *mapping, taxonomy, taxonomyPath); - } + writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed); - delete mapping; - delete taxonomy; - targetHeaderReader.close(); - queryHeaderReader.close(); reader.close(); return EXIT_SUCCESS; } diff --git a/src/util/EM_reclassify3.cpp b/src/util/EM_reclassify3.cpp new file mode 100644 index 000000000..a6f47a3eb --- /dev/null +++ b/src/util/EM_reclassify3.cpp @@ -0,0 +1,835 @@ +// #include "Parameters.h" +// #include "DBReader.h" +// #include "DBWriter.h" +// #include "Debug.h" +// #include "Util.h" +// #include "Matcher.h" +// #include "FastSort.h" + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #ifdef OPENMP +// #include +// #endif + +// namespace { +// struct ReclassTaxEntry { +// Matcher::result_t result; +// double abundance; +// double posterior; +// double coverageConfidence; +// }; + +// typedef std::unordered_map > MappingTable; + +// struct Interval { +// int start; +// int end; +// }; + +// struct TargetStats { +// unsigned int key; +// unsigned int targetLength; +// double abundance; +// double coverageConfidence; +// bool dropped; +// std::vector intervals; +// }; + +// struct ReclassTaxContext { 
+// MappingTable mappingTable; +// std::vector queryOrder; +// std::unordered_set targetSet; +// size_t queryCount; +// bool hasBacktrace; +// bool hasOrfPosition; + +// ReclassTaxContext() : queryCount(0), hasBacktrace(false), hasOrfPosition(false) {} +// }; + +// static const double STEP_MIN = -1.0; +// static const double STEP_MAX = 1.0; +// static const double EPS = 1e-12; +// static const size_t MIN_FILTER_TARGETS = 20; +// static const size_t MIN_TAIL_TARGETS = 2; + +// static double clamp01(double value); + +// static std::vector targetListFromSet(const std::unordered_set &targets) { +// std::vector out(targets.begin(), targets.end()); +// SORT_SERIAL(out.begin(), out.end()); +// return out; +// } + +// static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &ctx) { +// Debug::Progress progress(reader.getSize()); +// const char *entry[255]; + +// for (size_t i = 0; i < reader.getSize(); ++i) { +// progress.updateProgress(); +// const unsigned int queryKey = reader.getDbKey(i); +// char *data = reader.getData(i, 0); + +// if (reader.getEntryLen(i) <= 1) { +// continue; +// } + +// std::vector &records = ctx.mappingTable[queryKey]; +// if (records.empty()) { +// ctx.queryOrder.push_back(queryKey); +// } +// while (*data != '\0') { +// const size_t columns = Util::getWordsOfLine(data, entry, 255); +// if (columns < Matcher::ALN_RES_WITHOUT_BT_COL_CNT) { +// Debug(Debug::ERROR) << "Invalid alignment result record in query " << queryKey << ".\n"; +// EXIT(EXIT_FAILURE); +// } + +// if (columns == Matcher::ALN_RES_WITH_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT +// || columns == Matcher::ALN_RES_WITH_BT_COL_CNT + 1 || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { +// ctx.hasBacktrace = true; +// } +// if (columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT || columns == Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT +// || columns == Matcher::ALN_RES_WITH_ORF_POS_WITHOUT_BT_COL_CNT + 1 || columns == 
Matcher::ALN_RES_WITH_ORF_AND_BT_COL_CNT + 1) { +// ctx.hasOrfPosition = true; +// } + +// Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); +// records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0}); +// ctx.targetSet.insert(result.dbKey); +// data = Util::skipLine(data); +// } +// } + +// ctx.queryCount = ctx.mappingTable.size(); +// } + +// static void initAbundance(MappingTable &mappingTable, const std::unordered_set &targetSet, size_t queryCount) { +// std::unordered_map initAbundance; +// initAbundance.reserve(targetSet.size()); +// for (std::unordered_set::const_iterator it = targetSet.begin(); it != targetSet.end(); ++it) { +// initAbundance[*it] = 0.0; +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double scoreSum = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// scoreSum += it->second[j].result.score; +// } +// if (scoreSum <= 0.0) { +// continue; +// } +// for (size_t j = 0; j < it->second.size(); ++j) { +// initAbundance[it->second[j].result.dbKey] += it->second[j].result.score / scoreSum; +// } +// } + +// if (queryCount > 0) { +// const double denom = static_cast(queryCount); +// for (std::unordered_map::iterator it = initAbundance.begin(); it != initAbundance.end(); ++it) { +// it->second /= denom; +// } +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// it->second[j].abundance = initAbundance[it->second[j].result.dbKey]; +// } +// } +// } + +// struct TargetHitRef { +// const ReclassTaxEntry *entry; +// double score; +// double weight; +// }; + +// static void initCoverageConfidence(MappingTable &mappingTable, +// const std::unordered_set &targetSet, +// int threads) { +// (void)threads; +// std::unordered_map targetMin; +// std::unordered_map targetMax; +// std::unordered_map > hitsByTarget; + +// for (std::unordered_set::const_iterator it = 
targetSet.begin(); it != targetSet.end(); ++it) { +// targetMin[*it] = std::numeric_limits::max(); +// targetMax[*it] = std::numeric_limits::min(); +// hitsByTarget.emplace(*it, std::vector()); +// } + +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double scoreSum = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// scoreSum += static_cast(it->second[j].result.score); +// } +// for (size_t j = 0; j < it->second.size(); ++j) { +// const unsigned int target = it->second[j].result.dbKey; +// if (it->second[j].result.dbStartPos < targetMin[target]) { +// targetMin[target] = it->second[j].result.dbStartPos; +// } +// if (it->second[j].result.dbEndPos > targetMax[target]) { +// targetMax[target] = it->second[j].result.dbEndPos; +// } +// const double score = static_cast(it->second[j].result.score); +// const double weight = (scoreSum > 0.0) ? (score / scoreSum) : 0.0; +// hitsByTarget[target].push_back(TargetHitRef{&it->second[j], score, weight}); +// } +// } + +// std::unordered_map coverageFraction; +// coverageFraction.reserve(targetSet.size()); +// const std::vector targetList = targetListFromSet(targetSet); +// std::vector coverageFractionByIndex(targetList.size(), 0.0); + +// #pragma omp parallel for num_threads(threads) schedule(dynamic, 1) +// for (size_t i = 0; i < targetList.size(); ++i) { +// const unsigned int target = targetList[i]; +// const int startPos = targetMin[target]; +// const int endPos = targetMax[target]; +// const int len = (endPos >= startPos) ? 
(endPos - startPos + 1) : 1; +// std::vector cov(static_cast(len), 0.0); +// std::vector covConf(static_cast(len), 0.0); + +// std::unordered_map >::const_iterator hitIt = hitsByTarget.find(target); +// if (hitIt != hitsByTarget.end()) { +// const std::vector &hits = hitIt->second; +// for (size_t h = 0; h < hits.size(); ++h) { +// const Matcher::result_t &result = hits[h].entry->result; +// const int targetLen = result.dbEndPos - result.dbStartPos + 1; +// if (targetLen <= 0) { +// continue; +// } + +// const double mq = hits[h].score / static_cast(targetLen); +// const int start = std::max(0, result.dbStartPos - startPos); +// const int end = std::min(len - 1, result.dbEndPos - startPos); +// for (int pos = start; pos <= end; ++pos) { +// cov[static_cast(pos)] += mq; +// covConf[static_cast(pos)] += hits[h].weight; +// } +// } +// } + +// double covered = 0.0; +// double squaredCovered = 0.0; +// for (size_t pos = 0; pos < covConf.size(); ++pos) { +// const double clipped = std::min(1.0, covConf[pos]); +// covered += clipped; +// squaredCovered += clipped * clipped; +// } +// const double fraction = covered / static_cast(len); +// const double hhi = (covered > 0.0) ? (squaredCovered / (covered * covered)) : 1.0; +// const double concentrationPenalty = 1.0 - hhi; +// coverageFractionByIndex[i] = clamp01(fraction * concentrationPenalty); +// } + +// for (size_t i = 0; i < targetList.size(); ++i) { +// coverageFraction[targetList[i]] = coverageFractionByIndex[i]; +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// const unsigned int target = it->second[j].result.dbKey; +// std::unordered_map::const_iterator cf = coverageFraction.find(target); +// it->second[j].coverageConfidence = (cf != coverageFraction.end()) ? 
cf->second : 0.0; +// } +// } +// } + +// static double scoreTerm(const ReclassTaxEntry &entry, double lambda, double alpha, double gamma) { +// const double bitScore = static_cast(entry.result.score); +// const double abundance = std::max(entry.abundance, EPS); +// const double coverageConfidence = std::max(entry.coverageConfidence, EPS); +// return std::exp(lambda * bitScore) * std::pow(abundance, alpha) * std::pow(coverageConfidence, gamma); +// } + +// static void computePosterior(MappingTable &mappingTable, double lambda, double alpha, double gamma) { +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double denom = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// denom += scoreTerm(it->second[j], lambda, alpha, gamma); +// } +// for (size_t j = 0; j < it->second.size(); ++j) { +// const double value = scoreTerm(it->second[j], lambda, alpha, gamma); +// it->second[j].posterior = (denom > 0.0) ? (value / denom) : 0.0; +// } +// } +// } + +// static double logLikelihood(const MappingTable &mappingTable, double lambda, double alpha, double gamma, size_t queryCount) { +// if (queryCount == 0) { +// return 0.0; +// } + +// double ll = 0.0; +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// double denom = 0.0; +// for (size_t j = 0; j < it->second.size(); ++j) { +// const double value = scoreTerm(it->second[j], lambda, alpha, gamma); +// denom += value; +// if (it->second[j].posterior > 0.0) { +// ll += it->second[j].posterior * std::log(value > 0.0 ? value : 1e-300); +// } +// } +// ll -= std::log(denom > 0.0 ? 
denom : 1e-300); +// } +// return ll / static_cast(queryCount); +// } + +// static std::vector abundanceVectorFromTable(const MappingTable &mappingTable, const std::vector &targetList) { +// std::unordered_map abundance; +// abundance.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// abundance[targetList[i]] = 0.0; +// } + +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// abundance[it->second[j].result.dbKey] = it->second[j].abundance; +// } +// } + +// std::vector out; +// out.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// out.push_back(abundance[targetList[i]]); +// } +// return out; +// } + +// static void setAbundance(MappingTable &mappingTable, const std::vector &targetList, const std::vector &abundanceVector) { +// std::unordered_map abundance; +// abundance.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// abundance[targetList[i]] = abundanceVector[i]; +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// it->second[j].abundance = abundance[it->second[j].result.dbKey]; +// } +// } +// } + +// static std::vector emUpdate(MappingTable &mappingTable, +// double lambda, +// const std::unordered_map &fixedCoverageConfidence, +// const std::vector &targetList, +// size_t queryCount, +// double alpha, +// double gamma) { +// computePosterior(mappingTable, lambda, alpha, gamma); + +// std::unordered_map nextAbundance; +// nextAbundance.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// nextAbundance[targetList[i]] = 0.0; +// } + +// for (MappingTable::const_iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// 
nextAbundance[it->second[j].result.dbKey] += it->second[j].posterior; +// } +// } + +// if (queryCount > 0) { +// const double denom = static_cast(queryCount); +// for (std::unordered_map::iterator it = nextAbundance.begin(); it != nextAbundance.end(); ++it) { +// it->second /= denom; +// } +// } + +// for (MappingTable::iterator it = mappingTable.begin(); it != mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// const unsigned int target = it->second[j].result.dbKey; +// it->second[j].abundance = nextAbundance[target]; +// std::unordered_map::const_iterator fixed = fixedCoverageConfidence.find(target); +// if (fixed != fixedCoverageConfidence.end()) { +// it->second[j].coverageConfidence = fixed->second; +// } else { +// it->second[j].coverageConfidence = 0.0; +// } +// } +// } + +// std::vector out; +// out.reserve(targetList.size()); +// for (size_t i = 0; i < targetList.size(); ++i) { +// out.push_back(nextAbundance[targetList[i]]); +// } +// return out; +// } + +// static std::vector projectSimplex(const std::vector &x) { +// std::vector projected = x; +// for (size_t i = 0; i < projected.size(); ++i) { +// if (projected[i] < 0.0) { +// projected[i] = 0.0; +// } +// } + +// const double sum = std::accumulate(projected.begin(), projected.end(), 0.0); +// if (sum > 0.0) { +// for (size_t i = 0; i < projected.size(); ++i) { +// projected[i] /= sum; +// } +// } +// return projected; +// } + +// static void squarem(ReclassTaxContext &ctx, +// double lambda, +// int maxIter, +// double tol, +// double alpha, +// double gamma, +// int threads) { +// if (ctx.queryCount == 0 || ctx.targetSet.empty()) { +// return; +// } + +// initAbundance(ctx.mappingTable, ctx.targetSet, ctx.queryCount); +// initCoverageConfidence(ctx.mappingTable, ctx.targetSet, threads); +// Debug(Debug::INFO) << "Reclassify initialized coverage confidence." 
<< "\n"; + +// std::unordered_map fixedCoverageConfidence; +// fixedCoverageConfidence.reserve(ctx.targetSet.size()); +// for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// fixedCoverageConfidence[it->second[j].result.dbKey] = it->second[j].coverageConfidence; +// } +// } + +// const std::vector targetList = targetListFromSet(ctx.targetSet); +// std::vector x0 = abundanceVectorFromTable(ctx.mappingTable, targetList); +// std::vector logLikelihoods; + +// for (int iter = 0; iter < maxIter; ++iter) { +// const std::vector x1 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); +// const std::vector x2 = emUpdate(ctx.mappingTable, lambda, fixedCoverageConfidence, targetList, ctx.queryCount, alpha, gamma); + +// std::vector r(x0.size(), 0.0); +// std::vector v(x0.size(), 0.0); +// for (size_t i = 0; i < x0.size(); ++i) { +// r[i] = x1[i] - x0[i]; +// v[i] = x2[i] - x1[i] - r[i]; +// } + +// double normR = 0.0; +// double normV = 0.0; +// for (size_t i = 0; i < x0.size(); ++i) { +// normR += r[i] * r[i]; +// normV += v[i] * v[i]; +// } +// normR = std::sqrt(normR); +// normV = std::sqrt(normV); + +// double accel = (normV == 0.0) ? 
-1.0 : -(normR / normV); +// accel = std::max(STEP_MIN, std::min(STEP_MAX, accel)); + +// std::vector xNew(x0.size(), 0.0); +// for (size_t i = 0; i < x0.size(); ++i) { +// xNew[i] = x0[i] - 2.0 * accel * r[i] + accel * accel * v[i]; +// } +// xNew = projectSimplex(xNew); + +// setAbundance(ctx.mappingTable, targetList, xNew); +// computePosterior(ctx.mappingTable, lambda, alpha, gamma); +// double currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); + +// if (!logLikelihoods.empty() && currentLl < logLikelihoods.back() - 1e-9) { +// setAbundance(ctx.mappingTable, targetList, x2); +// computePosterior(ctx.mappingTable, lambda, alpha, gamma); +// currentLl = logLikelihood(ctx.mappingTable, lambda, alpha, gamma, ctx.queryCount); +// xNew = x2; +// } + +// logLikelihoods.push_back(currentLl); + +// double parameterChange = 0.0; +// for (size_t i = 0; i < x0.size(); ++i) { +// parameterChange = std::max(parameterChange, std::fabs(xNew[i] - x0[i])); +// } + +// Debug(Debug::INFO) << "Reclassify iteration " << iter << ": LL=" << currentLl << " delta=" << parameterChange << "\n"; +// x0 = xNew; +// if (parameterChange < tol && iter > 5) { +// Debug(Debug::INFO) << "Reclassify converged after " << (iter + 1) << " iterations." 
<< "\n"; +// break; +// } +// } +// } + +// static void addInterval(std::vector &intervals, int start, int end) { +// Interval interval; +// interval.start = std::min(start, end); +// interval.end = std::max(start, end); +// intervals.push_back(interval); +// } + +// static std::vector mergeIntervals(std::vector intervals) { +// if (intervals.empty()) { +// return intervals; +// } + +// std::sort(intervals.begin(), intervals.end(), [](const Interval &lhs, const Interval &rhs) { +// if (lhs.start != rhs.start) { +// return lhs.start < rhs.start; +// } +// return lhs.end < rhs.end; +// }); + +// std::vector merged; +// merged.push_back(intervals[0]); +// for (size_t i = 1; i < intervals.size(); ++i) { +// if (intervals[i].start <= merged.back().end + 1) { +// merged.back().end = std::max(merged.back().end, intervals[i].end); +// } else { +// merged.push_back(intervals[i]); +// } +// } +// return merged; +// } + +// static std::vector collectTargetStats(const ReclassTaxContext &ctx) { +// std::unordered_map statsByTarget; + +// for (MappingTable::const_iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end(); ++it) { +// for (size_t j = 0; j < it->second.size(); ++j) { +// const ReclassTaxEntry &entry = it->second[j]; +// TargetStats &stats = statsByTarget[entry.result.dbKey]; +// stats.key = entry.result.dbKey; +// stats.targetLength = entry.result.dbLen; +// stats.abundance = entry.abundance; +// stats.coverageConfidence = entry.coverageConfidence; +// stats.dropped = false; +// addInterval(stats.intervals, entry.result.dbStartPos, entry.result.dbEndPos); +// } +// } + +// std::vector out; +// out.reserve(statsByTarget.size()); +// for (std::unordered_map::iterator it = statsByTarget.begin(); it != statsByTarget.end(); ++it) { +// it->second.intervals = mergeIntervals(it->second.intervals); +// out.push_back(it->second); +// } + +// std::sort(out.begin(), out.end(), [](const TargetStats &lhs, const TargetStats &rhs) { +// if (lhs.abundance != 
rhs.abundance) { +// return lhs.abundance > rhs.abundance; +// } +// return lhs.key < rhs.key; +// }); +// return out; +// } + +// static double clamp01(double value) { +// return std::max(0.0, std::min(1.0, value)); +// } + +// static bool largestJumpCutoff(std::vector values, +// bool useLowTail, +// double maxTailFraction, +// double &cutoff, +// size_t &tailCount) { +// cutoff = 0.0; +// tailCount = 0; +// if (values.size() < MIN_FILTER_TARGETS) { +// return false; +// } + +// std::sort(values.begin(), values.end()); +// maxTailFraction = clamp01(maxTailFraction); +// const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); +// if (totalMass <= EPS || maxTailFraction <= 0.0) { +// return false; +// } +// const double maxTailMass = maxTailFraction * totalMass; + +// double bestGap = 0.0; +// size_t bestIdx = 0; +// double lowTailMass = 0.0; +// for (size_t i = 0; i + 1 < values.size(); ++i) { +// const size_t lowTailCount = i + 1; +// const size_t highTailCount = values.size() - lowTailCount; +// const size_t candidateTailCount = useLowTail ? lowTailCount : highTailCount; +// lowTailMass += values[i]; +// const double highTailMass = totalMass - lowTailMass; +// const double candidateTailMass = useLowTail ? 
lowTailMass : highTailMass; +// if (candidateTailCount < MIN_TAIL_TARGETS || candidateTailMass > (maxTailMass + EPS)) { +// continue; +// } + +// const double gap = values[i + 1] - values[i]; +// if (gap > bestGap) { +// bestGap = gap; +// bestIdx = i; +// tailCount = candidateTailCount; +// } +// } + +// if (bestGap <= EPS) { +// return false; +// } + +// cutoff = 0.5 * (values[bestIdx] + values[bestIdx + 1]); +// return true; +// } + +// static bool tailQuantileCutoff(std::vector values, +// bool useLowTail, +// double maxTailFraction, +// double &cutoff, +// size_t &tailCount) { +// cutoff = 0.0; +// tailCount = 0; +// if (values.size() < MIN_FILTER_TARGETS) { +// return false; +// } + +// std::sort(values.begin(), values.end()); +// maxTailFraction = clamp01(maxTailFraction); +// const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); +// if (totalMass <= EPS || maxTailFraction <= 0.0) { +// return false; +// } +// const double maxTailMass = maxTailFraction * totalMass; + +// double accumulatedMass = 0.0; +// size_t maxTailCount = 0; +// for (size_t i = 0; i < values.size(); ++i) { +// const double candidate = accumulatedMass + values[i]; +// if (candidate > (maxTailMass + EPS)) { +// break; +// } +// accumulatedMass = candidate; +// ++maxTailCount; +// } +// if (maxTailCount < MIN_TAIL_TARGETS || maxTailCount >= values.size()) { +// return false; +// } + +// tailCount = maxTailCount; +// if (useLowTail) { +// cutoff = values[tailCount - 1]; +// } else { +// cutoff = values[values.size() - tailCount]; +// } +// return true; +// } + +// static std::unordered_set selectTailTargets(const std::vector &stats, +// bool useLowTail, +// size_t tailCount, +// double maxTailFraction) { +// std::vector ordered; +// ordered.reserve(stats.size()); +// for (size_t i = 0; i < stats.size(); ++i) { +// ordered.push_back(&stats[i]); +// } + +// std::sort(ordered.begin(), ordered.end(), [useLowTail](const TargetStats *lhs, const TargetStats *rhs) { +// const 
double lhsValue = useLowTail ? lhs->abundance : lhs->coverageConfidence; +// const double rhsValue = useLowTail ? rhs->abundance : rhs->coverageConfidence; +// if (lhsValue != rhsValue) { +// return useLowTail ? (lhsValue < rhsValue) : (lhsValue > rhsValue); +// } +// return lhs->key < rhs->key; +// }); + +// double totalMass = 0.0; +// for (size_t i = 0; i < ordered.size(); ++i) { +// const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; +// totalMass += value; +// } +// const double maxTailMass = clamp01(maxTailFraction) * totalMass; + +// std::unordered_set selected; +// const size_t limit = std::min(tailCount, ordered.size()); +// double selectedMass = 0.0; +// selected.reserve(limit); +// for (size_t i = 0; i < limit; ++i) { +// const double value = useLowTail ? ordered[i]->abundance : ordered[i]->coverageConfidence; +// if (selected.size() >= MIN_TAIL_TARGETS && (selectedMass + value) > (maxTailMass + EPS)) { +// break; +// } +// selectedMass += value; +// selected.insert(ordered[i]->key); +// } +// return selected; +// } + +// static std::unordered_set selectDroppedTargets(const std::vector &stats, +// double maxDropPercentage, +// double &abundanceCutoff) { +// std::unordered_set dropped; +// if (stats.empty()) { +// abundanceCutoff = 0.0; +// return dropped; +// } +// if (stats.size() < MIN_FILTER_TARGETS) { +// abundanceCutoff = 0.0; +// return dropped; +// } + +// std::vector abundances; +// abundances.reserve(stats.size()); +// for (size_t i = 0; i < stats.size(); ++i) { +// abundances.push_back(stats[i].abundance); +// } + +// const double maxTailFraction = clamp01(maxDropPercentage / 100.0); +// size_t abundanceTailCount = 0; +// bool hasAbundanceCutoff = largestJumpCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); +// if (hasAbundanceCutoff == false) { +// hasAbundanceCutoff = tailQuantileCutoff(abundances, true, maxTailFraction, abundanceCutoff, abundanceTailCount); +// } +// if 
(hasAbundanceCutoff == false) { +// abundanceCutoff = 0.0; +// return dropped; +// } + +// const std::unordered_set lowAbundanceTargets = selectTailTargets(stats, true, abundanceTailCount, maxTailFraction); +// for (std::unordered_set::const_iterator it = lowAbundanceTargets.begin(); it != lowAbundanceTargets.end(); ++it) { +// dropped.insert(*it); +// } +// if (dropped.size() == stats.size()) { +// dropped.clear(); +// } +// return dropped; +// } + +// static void applyDroppedTargets(ReclassTaxContext &ctx, +// const std::unordered_set &dropped, +// size_t totalTargets, +// double abundanceCutoff) { +// if (dropped.empty()) { +// Debug(Debug::INFO) << "Reclassify target filter kept all targets. abundance cutoff=" +// << abundanceCutoff << "\n"; +// return; +// } + +// for (MappingTable::iterator it = ctx.mappingTable.begin(); it != ctx.mappingTable.end();) { +// std::vector &records = it->second; +// records.erase(std::remove_if(records.begin(), records.end(), [&dropped](const ReclassTaxEntry &entry) { +// return dropped.find(entry.result.dbKey) != dropped.end(); +// }), records.end()); + +// if (records.empty()) { +// it = ctx.mappingTable.erase(it); +// } else { +// ++it; +// } +// } + +// ctx.queryOrder.erase(std::remove_if(ctx.queryOrder.begin(), ctx.queryOrder.end(), [&ctx](unsigned int queryKey) { +// return ctx.mappingTable.find(queryKey) == ctx.mappingTable.end(); +// }), ctx.queryOrder.end()); + +// for (std::unordered_set::const_iterator it = dropped.begin(); it != dropped.end(); ++it) { +// ctx.targetSet.erase(*it); +// } + +// const double removedPct = (totalTargets > 0) +// ? 
(100.0 * static_cast(dropped.size()) / static_cast(totalTargets)) +// : 0.0; +// Debug(Debug::INFO) << "Reclassify dropped " << dropped.size() +// << " of " << totalTargets +// << " targets (" << removedPct << "%)" +// << " using abundance <= " << abundanceCutoff << ".\n"; +// } + +// static bool compareByPosteriorThenBitScore(const ReclassTaxEntry &a, const ReclassTaxEntry &b) { +// if (a.posterior != b.posterior) { +// return a.posterior > b.posterior; +// } +// if (a.result.score != b.result.score) { +// return a.result.score > b.result.score; +// } +// return Matcher::compareHits(a.result, b.result); +// } + +// static void writeReclassifiedDb(const ReclassTaxContext &ctx, +// int dbType, +// const std::string &outDb, +// const std::string &outIndex, +// int threads, +// bool compress) { +// DBWriter writer(outDb.c_str(), outIndex.c_str(), threads, compress, dbType); +// writer.open(); + +// Debug::Progress progress(ctx.queryOrder.size()); +// #pragma omp parallel +// { +// unsigned int thread_idx = 0; +// #ifdef OPENMP +// thread_idx = static_cast(omp_get_thread_num()); +// #endif +// char buffer[1024 + 32768 * 4]; + +// #pragma omp for schedule(dynamic, 5) +// for (size_t i = 0; i < ctx.queryOrder.size(); ++i) { +// progress.updateProgress(); +// const unsigned int queryKey = ctx.queryOrder[i]; +// MappingTable::const_iterator recordsIt = ctx.mappingTable.find(queryKey); +// if (recordsIt == ctx.mappingTable.end()) { +// continue; +// } + +// std::vector records = recordsIt->second; +// SORT_SERIAL(records.begin(), records.end(), compareByPosteriorThenBitScore); + +// writer.writeStart(thread_idx); +// for (size_t j = 0; j < records.size(); ++j) { +// Matcher::result_t res = records[j].result; +// res.seqId = static_cast(records[j].posterior); +// size_t len = Matcher::resultToBuffer(buffer, res, ctx.hasBacktrace, ctx.hasOrfPosition); +// writer.writeAdd(buffer, len, thread_idx); +// } +// writer.writeEnd(queryKey, thread_idx); +// } +// } + +// 
writer.close(); +// } +// } + +// int emreclassify(int argc, const char **argv, const Command &command) { +// Parameters &par = Parameters::getInstance(); +// par.parseParameters(argc, argv, command, true, 0, 0); + +// DBReader reader(par.db3.c_str(), par.db3Index.c_str(), par.threads, +// DBReader::USE_INDEX | DBReader::USE_DATA); +// reader.open(DBReader::LINEAR_ACCCESS); + +// ReclassTaxContext ctx; +// loadAlignmentDb(reader, ctx); +// Debug(Debug::INFO) << "Loaded " << ctx.queryCount << " queries with hits and " << ctx.targetSet.size() << " unique targets.\n"; + +// squarem(ctx, +// par.reclassifyLambda, +// par.reclassifyMaxIterations, +// par.reclassifyTolerance, +// par.reclassifyAlpha, +// par.reclassifyGamma, +// par.threads); + +// std::vector allTargetStats = collectTargetStats(ctx); +// double abundanceCutoff = 0.0; +// const std::unordered_set dropped = selectDroppedTargets(allTargetStats, +// par.reclassifyMaxDropPercentage, +// abundanceCutoff); +// applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); + +// writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed); + +// reader.close(); +// return EXIT_SUCCESS; +// } From 14a4caaa7b0c0b52c2f16464c736e017ad326cfd Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Thu, 14 May 2026 22:52:07 +0900 Subject: [PATCH 11/12] Remove markdown docs from EM_0514 update --- EM_abundance.md | 64 --------------- EM_reclassify.md | 191 ------------------------------------------ old.md | 174 --------------------------------------- reclassify.md | 158 ----------------------------------- reclassify0403.md | 205 ---------------------------------------------- 5 files changed, 792 deletions(-) delete mode 100644 EM_abundance.md delete mode 100644 EM_reclassify.md delete mode 100644 old.md delete mode 100644 reclassify.md delete mode 100644 reclassify0403.md diff --git a/EM_abundance.md b/EM_abundance.md deleted file mode 100644 index 359c1de2d..000000000 --- 
a/EM_abundance.md +++ /dev/null @@ -1,64 +0,0 @@ -# EM Abundance Notes 04/13 - -This document summarizes the outputs implemented in `src/util/EM_abundnace.cpp`. - -## Overview - -`mmseqs abundance` reads the reclassified alignment DB produced by `mmseqs reclassify` and produces two possible outputs: - -- Per-target abundance table (default). -- Kraken-style report when `--taxonomy 1` is set. - -The input alignment DB must include the posterior probability for each hit. The current implementation reads the posterior from the `seqId` field, or from an extra trailing column if present. - -## Example usage - -- mmseqs abundance queryDB targetDB newDB abundance.tsv -- mmseqs abundance queryDB targetDB newDB abundance.report --taxonomy 1 - -Notes: -- `newDB` should be created by `mmseqs reclassify`. -- When `--taxonomy 1` is used, `targetDB_mapping` and `targetDB_taxonomy` are required. - -## Per-target abundance table - -For each target, the command reports: -- target key and identifier -- abundance percentage -- coverage confidence -- drop flag based on low-abundance filtering -- mapped length and target length - -This table is written to the output path (typically `.tsv`). - -Drop handling: -- The per-target table includes all targets and marks low-abundance filtered ones in `Drop(y/n)`. -- When `--taxonomy 1` is used, dropped targets are removed before building Kraken/Bracken reports. - -## Kraken-style report - -When `--taxonomy 1` is set, the output is a Kraken-style report with fields: - -- percent of reads in the clade -- clade read count -- direct (taxon) read count -- rank code -- taxid -- name - -The report is written to the output path (typically `.report`). - -Notes on Kraken compatibility: -- The report format mirrors Kraken's `--report` layout, but values are not guaranteed to be identical to Kraken outputs. 
-- Counts are derived from EM abundance by converting percent to expected reads (with rounding), then aggregating clade/direct counts via the taxonomy tree. -- Ordering and handling of missing or unclassified taxa can differ from Kraken's implementation. - -## Abundance from posterior - -For each target $t$, abundance is computed from posteriors as: - -$$ -\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} -$$ - -`R_{qt}` is the posterior for hit $(q,t)$ in `newDB`. Abundances are then converted to percentages and filtered by a low-abundance tail cutoff. diff --git a/EM_reclassify.md b/EM_reclassify.md deleted file mode 100644 index 13957871a..000000000 --- a/EM_reclassify.md +++ /dev/null @@ -1,191 +0,0 @@ -# EM_reclassify.cpp Summary (04/12) - -## Overview -- This module performs EM-based reclassification of alignment hits. -- It updates target abundances and posterior probabilities, optionally drops low-abundance targets, and writes a new alignment DB where hits are sorted by posterior. -- Primary input: alignment result DB (`par.db3`). Primary output: reclassified alignment DB (`par.db4`). - -## Example Usage Order -- `mmseqs createdb query.fasta queryDB` -- `mmseqs search queryDB targetDB alignDB tmp -a` -- `mmseqs reclassify queryDB targetDB alignDB newDB` -- `mmseqs convertalis queryDB targetDB newDB reclassify_result.m8` -- `mmseqs abundance queryDB targetDB newDB abundance.tsv --taxonomy 1` - -Notes: -- Plain custom FASTA target DB: use `createdb` + `createtaxdb` as needed. -- Prebuilt taxonomy-ready MMseqs DB: `createtaxdb` is usually not needed. - -## Key Data Structures -- `ReclassTaxEntry`: per-hit working record with `abundance`, `posterior`, `coverageConfidence`. -- `MappingTable`: `queryKey -> vector`. -- `ReclassTaxContext`: full working set (mapping table, query order, target set, query count, output-format flags). -- `TargetStats`: per-target aggregate stats used for filtering/reporting. 
- -## Mathematical Definitions -Let `q` be a query and `t` a target hit for `q`. - -- Normalized score: - -$$ -s_{q,t} = \frac{\text{score}_{q,t}}{\max_{t'} \text{score}_{q,t'}} -$$ - -- Score term used in posterior update: - -$$ -\text{scoreTerm}_{q,t} = \exp(\lambda \cdot s_{q,t}) \cdot a_t^{\alpha} \cdot c_t^{\gamma} -$$ - -where `a_t` is abundance, `c_t` is coverage confidence. - -- Posterior: - -$$ -p_{q,t} = \frac{\text{scoreTerm}_{q,t}}{\sum_{t'} \text{scoreTerm}_{q,t'}} -$$ - -- Abundance update: - -$$ -a_t^{\text{new}} = \frac{1}{|Q|}\sum_{q \in Q} p_{q,t} -$$ - -- Log-likelihood (average per query): - -$$ -\mathcal{L} = \frac{1}{|Q|}\sum_{q \in Q}\left(\sum_t p_{q,t}\log(\text{scoreTerm}_{q,t}) - \log\sum_t \text{scoreTerm}_{q,t}\right) -$$ - -## Coverage Confidence -Coverage confidence is computed once at initialization and then kept fixed during EM. - -### 1) Query-level hit weight (score normalization) -For each query `q`, define hit weight by plain bit-score normalization (no `exp`, no `lambda`): - -$$ -w_{q,h} = \frac{\text{score}_{q,h}}{\sum_{h'} \text{score}_{q,h'}} -$$ - -### 2) Position-wise accumulation on each target -For each target position `p`, accumulate weighted support: - -$$ -\text{covConf}_t(p) = \sum_{h: p \in [\text{start}_h, \text{end}_h]} w_{q,h} -$$ - -Clip each position to avoid over-crediting stacked/repeat mappings at one locus: - -$$ -\tilde{c}_t(p) = \min(1, \text{covConf}_t(p)) -$$ - -### 3) Base coverage over observed span -Observed span length: - -$$ -L_t = \max(1, \text{endPos}_t - \text{startPos}_t + 1) -$$ - -Base coverage fraction: - -$$ -f_t = \frac{1}{L_t}\sum_{p=1}^{L_t}\tilde{c}_t(p) -$$ - -### 4) Concentration penalty (HHI-based) -Define concentration (Herfindahl-Hirschman Index over clipped position mass): - -$$ -\text{HHI}_t = \frac{\sum_{p=1}^{L_t}\tilde{c}_t(p)^2}{\left(\sum_{p=1}^{L_t}\tilde{c}_t(p)\right)^2} -$$ - -with the edge case: - -$$ -\text{HHI}_t = 1 \quad \text{if } \sum_{p}\tilde{c}_t(p)=0 -$$ - 
-Convert to dispersion reward (low when concentrated, high when spread): - -$$ -\text{penalty}_t = 1 - \text{HHI}_t -$$ - -### 5) Final coverage confidence - -$$ -c_t = \operatorname{clamp}_{[0,1]}\left(f_t \cdot \text{penalty}_t\right) -$$ - -Important details: -- Position contribution is clipped by `min(1, ...)`. -- Normalization length is **observed span** (`maxEnd - minStart + 1`), not full `dbLen`. -- Final value includes HHI-based concentration penalty and is clamped to `[0, 1]`. -- `lambda` still affects posterior via `scoreTerm`, but no longer affects coverage-confidence initialization. -- Behavior intent: - - Many hits stacked in one local region: high concentration -> high HHI -> small `(1-HHI)` -> lower `c_t`. - - Hits distributed across many positions: lower concentration -> lower HHI -> larger `(1-HHI)` -> higher `c_t`. - -## Processing Pipeline -1. Load alignment DB -- `loadAlignmentDb()` parses alignment rows into `mappingTable`. -- Keeps backtrace/ORF layout flags for output compatibility. - -2. Initialize -- `initAbundance()` initializes abundance from query-normalized hit scores. - -### Abundance Initialization (초기값 계산) - -각 쿼리 $q$에 대해, 해당 쿼리의 모든 target hit $t$에 대해 bit score를 정규화합니다: - -$$ -w_{q,t} = \frac{\text{score}_{q,t}}{\sum_{t'} \text{score}_{q,t'}} -$$ - -각 target $t$의 초기 abundance는, 해당 target이 등장한 모든 쿼리에서의 $w_{q,t}$ 값을 평균낸 값입니다: - -$$ -a_t^{(0)} = \frac{1}{|Q|} \sum_{q \in Q} w_{q,t} -$$ - -즉, 각 target의 abundance는 쿼리별로 score-normalized weight를 합산한 뒤 전체 쿼리 수로 나누어 초기화합니다. -- `initCoverageConfidence()` computes per-target `coverageConfidence` using `score/sum(score)` weights. - -3. EM + SQUAREM -- `computePosterior()` computes per-query posterior. -- `emUpdate()` updates abundance; coverage confidence remains fixed. -- `squarem()` accelerates EM with extrapolation and simplex projection. -- If LL decreases, fallback to conservative step (`x2`). - -4. 
Optional target filtering -- `collectTargetStats()` builds target-level abundance/coverage/interval stats. -- `selectDroppedTargets()` chooses low-tail drops under mass cap. -- `applyDroppedTargets()` removes dropped targets from mapping/query/target sets. - -5. Write output DB -- `writeReclassifiedDb()` writes hits sorted by posterior. -- Posterior is written into output `seqId` field. - -## Parameters (Current Defaults) -- `--lambda` (`reclassifyLambda`): `2` -- `--alpha` (`reclassifyAlpha`): `1.0` -- `--gamma` (`reclassifyGamma`): `1.0` -- `--max-iter` (`reclassifyMaxIterations`): `100` -- `--tol` (`reclassifyTolerance`): `1e-5` -- `--drop-percentage` (`reclassifyMaxDropPercentage`): `10.0` - -## Output/Storage Clarification -- `reclassify` output DB stores posterior in `seqId`. -- `coverageConfidence` is not persisted in alignment DB columns. -- `abundance` recomputes `coverageConfidence` from input alignments and writes it to `abundance.tsv`. - -## Consistency Across Modules -The updated coverage-confidence logic (`score/sum(score)` + observed-span normalization + HHI penalty) is applied consistently in: -- `src/util/EM_reclassify.cpp` -- `src/util/EM_abundnace.cpp` - -## Logging -- Coverage-confidence initialization complete message. -- Iteration LL and parameter delta. -- Number of dropped targets and applied abundance cutoff. diff --git a/old.md b/old.md deleted file mode 100644 index 84be02a80..000000000 --- a/old.md +++ /dev/null @@ -1,174 +0,0 @@ -# Reclassify Taxonomy2 Equations - -This document summarizes the equations implemented in `EM/reclassify-taxonomy2.cpp`. - What reclassify adds: - - - It does not treat each query in isolation. - - It estimates target abundance over all matched queries. - - It uses that abundance to boost hits to globally supported targets and downweight isolated or ambiguous hits. - - It penalizes high-entropy or diffuse mappings. - - It iterates this with an EM-style procedure, accelerated with SQUAREM, until convergence. 
- -You can use the 'mmseqs reclassify' after running 'mmseqs search' -## Ex usage (search -> align -> (createtaxdb) -> reclassify) -- mmseqs createdb query.fasta queryDB -- mmseqs search queryDB targetDB alignDB tmp -a -- mmseqs reclassify queryDB targetDB alignmentDB newDB -- mmseqs convertalis queryDB targetDB newnDB(or alignDB) reclassify_result.m8 -- mmseqs abundance queryDB targetDB newDB abundance.report (--taxonomy 1) -* Plain custom FASTA target DB: createdb + createtaxdb -* Prebuilt taxonomy-ready MMseqs database: createtaxdb usually not needed - -## Notation - -- Query index: $q$ -- Target/protein index: $t$ -- Hit list for query $q$: $H(q)$ -- Bit score: $s_{qt}$ -- Query max score: $S_q = \max_{t \in H(q)} s_{qt}$ -- Target abundance: $A_t$ -- Entropy value: $E_t$ -- Entropy penalty: $P_t$ -- Posterior for hit $(q,t)$: $R_{qt}$ -- Parameters: $\lambda, \alpha, \gamma$ -- Small constant: $\varepsilon = 10^{-12}$ - -## Abundance initialization - -For each query $q$, compute the sum of scores over its hit list: - -$$ -S^{\text{sum}}_q = \sum_{t \in H(q)} s_{qt} -$$ - -If $S^{\text{sum}}_q > 0$, each target gets a fractional count: - -$$ -c_{qt} = \frac{s_{qt}}{S^{\text{sum}}_q} -$$ - -Aggregate counts across all queries and normalize by $|Q|$: - -$$ -A_t = \frac{1}{|Q|} \sum_q c_{qt} -$$ - -If $|Q| = 0$, all $A_t$ are treated as $0$. - -## Entropy value and penalty - -For each target $t$, find its global alignment bounds: - -$$ - {minPos}_t = \min_q \text{dbStart}_{qt}, \quad - {maxPos}_t = \max_q \text{dbEnd}_{qt} -$$ - -Coverage is accumulated over the interval $[\text{minPos}_t,\text{maxPos}_t]$. 
-For each hit $(q,t)$: - -$$ -\ell_{qt} = \text{dbEnd}_{qt} - \text{dbStart}_{qt} + 1 -$$ - -$$ -m_{qt} = \frac{\exp(\lambda s_{qt})}{\ell_{qt}} -$$ - -For each position $p$ covered by the hit: - -$$ - ext{cov}_t(p) \mathrel{+}= m_{qt} -$$ - -Normalize coverage to probabilities: - -$$ -P_t(p) = \frac{\text{cov}_t(p)}{\sum_{p'} \text{cov}_t(p')} -$$ - -Entropy for target $t$: - -$$ -E_t = -\sum_p P_t(p) \log_2 P_t(p) -$$ - -Entropy penalty (let $E^{\text{sum}} = \sum_{t'} E_{t'}$): - -$$ -P_t = 1 - \frac{E_t}{E^{\text{sum}}} -$$ - -If $E^{\text{sum}} \le 0$, then $P_t = 0$. - -## Score term - -Normalized score for a hit: - -$$ - {ns}_{qt} = \frac{s_{qt}}{S_q} -$$ - -Score term used for posteriors: - -$$ -T_{qt} = \exp(\lambda \text{ns}_{qt}) \cdot \max(A_t,\varepsilon)^{\alpha} \cdot \max(P_t,\varepsilon)^{\gamma} -$$ - -If $S_q \le 0$, then $T_{qt} = 0$ for all $t \in H(q)$. - -## Posterior for each query - -For each query $q$: - -$$ -Z_q = \sum_{t \in H(q)} T_{qt} -$$ - -$$ -R_{qt} = \frac{T_{qt}}{Z_q} -$$ - -If $Z_q = 0$, then $R_{qt} = 0$. 
- -## EM abundance update - -Given posteriors, the next abundance is: - -$$ -\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} -$$ - -## Log-likelihood (per query) - -$$ -\mathcal{L} = \frac{1}{|Q|} \sum_q \left( \sum_t R_{qt} \log(T_{qt}) - \log(Z_q) \right) -$$ - -## SQUAREM step (acceleration) - -Two EM steps produce $x_1$ and $x_2$ from $x_0$: - -$$ -\mathbf{r} = x_1 - x_0, \quad \mathbf{v} = x_2 - x_1 - \mathbf{r} -$$ - -Acceleration scalar (clipped to $[-1, 1]$): - -$$ -a = -\frac{\|\mathbf{r}\|}{\|\mathbf{v}\|} -$$ - -Proposed update and simplex projection: - -$$ -\mathbf{x}_{\text{new}} = x_0 - 2a\mathbf{r} + a^2\mathbf{v}, \quad \mathbf{x}_{\text{new}} \in \Delta -$$ - -## Simplex projection - -Clamp negatives and renormalize: - -$$ -x_i \leftarrow \max(0, x_i), \quad x_i \leftarrow \frac{x_i}{\sum_j x_j} -$$ diff --git a/reclassify.md b/reclassify.md deleted file mode 100644 index 2fe9d6a3c..000000000 --- a/reclassify.md +++ /dev/null @@ -1,158 +0,0 @@ -# Reclassify EM Notes - -This document summarizes the equations and outputs implemented in `src/util/EM_reclassify.cpp`. - -What reclassify adds: -- It does not treat each query in isolation. -- It estimates target abundance over all matched queries. -- It uses that abundance to boost hits to globally supported targets and downweight isolated or ambiguous hits. -- It penalizes high-entropy or diffuse mappings. -- It iterates this with an EM-style procedure, accelerated with SQUAREM, until convergence. - -You can use `mmseqs reclassify` after running `mmseqs search` and (optional) `mmseqs align -a`. 
- -## Example usage (search -> align -> reclassify) -- mmseqs createdb query.fasta queryDB -- mmseqs search queryDB targetDB alignDB tmp -a -- mmseqs reclassify queryDB targetDB alignDB newDB -- mmseqs convertalis queryDB targetDB newDB reclassify_result.m8 -- mmseqs abundance queryDB targetDB newDB abundance.report (--taxonomy 1) - -Notes: -- `newDB` is an alignment DB in the same format as `alignDB`, but ordered by posterior probability per query. -- The posterior probability replaces the `seqId` field in `newDB` to keep the column count unchanged. - -## Notation - -- Query index: $q$ -- Target/protein index: $t$ -- Hit list for query $q$: $H(q)$ -- Bit score: $s_{qt}$ -- Query max score: $S_q = \max_{t \in H(q)} s_{qt}$ -- Target abundance: $A_t$ -- Coverage confidence: $C_t$ -- Posterior for hit $(q,t)$: $R_{qt}$ -- Parameters: $\lambda, \alpha, \gamma$ -- Small constant: $\varepsilon = 10^{-12}$ - -## Abundance initialization - -For each query $q$, compute the sum of scores over its hit list: - -$$ -S^{\text{sum}}_q = \sum_{t \in H(q)} s_{qt} -$$ - -If $S^{\text{sum}}_q > 0$, each target gets a fractional count: - -$$ -c_{qt} = \frac{s_{qt}}{S^{\text{sum}}_q} -$$ - -Aggregate counts across all queries and normalize by $|Q|$: - -$$ -A_t = \frac{1}{|Q|} \sum_q c_{qt} -$$ - -If $|Q| = 0$, all $A_t$ are treated as $0$. - -## Coverage confidence - -For each target $t$, find its global alignment bounds: - -$$ - ext{minPos}_t = \min_q \text{dbStart}_{qt}, \quad - ext{maxPos}_t = \max_q \text{dbEnd}_{qt} -$$ - -Coverage is accumulated over the interval $[\text{minPos}_t,\text{maxPos}_t]$. -For each hit $(q,t)$: - -$$ -\ell_{qt} = \text{dbEnd}_{qt} - \text{dbStart}_{qt} + 1 -$$ - -$$ -m_{qt} = \frac{\exp(\lambda s_{qt})}{\ell_{qt}} -$$ - -For each position $p$ covered by the hit: - -$$ - ext{cov}_t(p) \mathrel{+}= m_{qt} -$$ - -The per-target coverage confidence $C_t$ is the fraction of target positions where the accumulated per-hit weights reach at least 1.0. 
Values are clipped to $[0,1]$. - -## Score term - -Normalized score for a hit: - -$$ -\text{ns}_{qt} = \frac{s_{qt}}{S_q} -$$ - -Score term used for posteriors: - -$$ -T_{qt} = \exp(\lambda \text{ns}_{qt}) \cdot \max(A_t,\varepsilon)^{\alpha} \cdot \max(C_t,\varepsilon)^{\gamma} -$$ - -If $S_q \le 0$, then $T_{qt} = 0$ for all $t \in H(q)$. - -## Posterior for each query - -For each query $q$: - -$$ -Z_q = \sum_{t \in H(q)} T_{qt} -$$ - -$$ -R_{qt} = \frac{T_{qt}}{Z_q} -$$ - -If $Z_q = 0$, then $R_{qt} = 0$. - -## EM abundance update - -Given posteriors, the next abundance is: - -$$ -\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} -$$ - -## Log-likelihood (per query) - -$$ -\mathcal{L} = \frac{1}{|Q|} \sum_q \left( \sum_t R_{qt} \log(T_{qt}) - \log(Z_q) \right) -$$ - -## SQUAREM step (acceleration) - -Two EM steps produce $x_1$ and $x_2$ from $x_0$: - -$$ -\mathbf{r} = x_1 - x_0, \quad \mathbf{v} = x_2 - x_1 - \mathbf{r} -$$ - -Acceleration scalar (clipped to $[-1, 1]$): - -$$ -a = -\frac{\|\mathbf{r}\|}{\|\mathbf{v}\|} -$$ - -Proposed update and simplex projection: - -$$ -\mathbf{x}_{\text{new}} = x_0 - 2a\mathbf{r} + a^2\mathbf{v}, \quad \mathbf{x}_{\text{new}} \in \Delta -$$ - -## Simplex projection - -Clamp negatives and renormalize: - -$$ -x_i \leftarrow \max(0, x_i), \quad x_i \leftarrow \frac{x_i}{\sum_j x_j} -$$ diff --git a/reclassify0403.md b/reclassify0403.md deleted file mode 100644 index 3e9ab2e24..000000000 --- a/reclassify0403.md +++ /dev/null @@ -1,205 +0,0 @@ -# Reclassify 0406 Notes - -This document summarizes the logic and formulas in `src/util/reclassify_taxonomy.cpp`. - -## Overview - -- Input: an alignment-result DB (from `mmseqs search -a` or `mmseqs align`) with per-hit fields. -- Output: - - `new_alignment_result.m8` (re-ranked by posterior) - - `protein_abundance.tsv` (per-target stats) - - `taxonomy_abundance.tsv` (optional taxonomy aggregation) - -The pipeline is: -1) Read alignments into per-query hit lists. 
-2) Initialize abundance and coverage confidence. -3) Run EM with SQUAREM acceleration. -4) Drop low-abundance targets by tail cutoff. -5) Write outputs. - -## Ex usage order -- mmseqs createdb query.fasta queryDB -- mmseqs search queryDB targetDB alignDB tmp -a -- mmseqs reclassify queryDB targetDB alignmentDB newDB -- mmseqs convertalis queryDB targetDB newnDB(or alignDB) reclassify_result.m8 -- mmseqs abundance queryDB targetDB newDB abundance.report (--taxonomy 1) -* Plain custom FASTA target DB: createdb + createtaxdb -* Prebuilt taxonomy-ready MMseqs database: createtaxdb usually not needed - -## Data structures - -- Query key: the DB key for a query sequence. -- Target key: the DB key for a target sequence. -- Hit: a `Matcher::result_t` record (alignment fields, scores, positions). -- Per-hit state: - - abundance $A_t$ - - posterior $R_{qt}$ - - coverage confidence $C_t$ - -## Notation - -- Query index: $q$ -- Target/protein index: $t$ -- Hit list for query $q$: $H(q)$ -- Bit score: $s_{qt}$ -- Query max score: $S_q = \max_{t \in H(q)} s_{qt}$ -- Target abundance: $A_t$ -- Coverage confidence: $C_t$ -- Posterior for hit $(q,t)$: $R_{qt}$ -- Parameters: $\lambda, \alpha, \gamma$ -- Small constant: $\varepsilon = 10^{-12}$ - -## Step 1: Load alignment DB - -- Each DB entry is parsed into per-query hit lists. -- Targets are tracked in a global set. - -## Step 2: Abundance initialization - -For each query $q$, sum scores over hits: - -$$ -S^{\text{sum}}_q = \sum_{t \in H(q)} s_{qt} -$$ - -If $S^{\text{sum}}_q > 0$, assign fractional counts: - -$$ -c_{qt} = \frac{s_{qt}}{S^{\text{sum}}_q} -$$ - -Aggregate over all queries and normalize by $|Q|$: - -$$ -A_t = \frac{1}{|Q|} \sum_q c_{qt} -$$ - -## Step 3: Coverage confidence (per target) - -Threads usage: -- The coverage-confidence calculation is parallelized across targets with OpenMP. -- The loop over target IDs runs with `#pragma omp parallel for num_threads(threads) schedule(dynamic, 1)`. 
-- Other steps in this pipeline are single-threaded in this file. - -For each target $t$, compute global bounds: - -$$ -\text{minPos}_t = \min_q \text{dbStart}_{qt}, \quad -\text{maxPos}_t = \max_q \text{dbEnd}_{qt} -$$ - -For each hit $(q,t)$, define alignment length: - -$$ -\ell_{qt} = \text{dbEnd}_{qt} - \text{dbStart}_{qt} + 1 -$$ - -Define per-hit weight: - -$$ -\text{expScore}_{qt} = \exp(\lambda s_{qt}), \quad -w_{qt} = \frac{\text{expScore}_{qt}}{\sum_{t' \in H(q)} \text{expScore}_{qt'}} -$$ - -Coverage accumulation over the target interval: - -$$ -\text{cov}_t(p) \mathrel{+}= \frac{\text{expScore}_{qt}}{\ell_{qt}} -$$ - -Coverage confidence accumulation: - -$$ -\text{conf}_t(p) \mathrel{+}= w_{qt} -$$ - -Coverage confidence (per target): - -$$ -C_t = \mathrm{clamp}_{[0,1]}\left(\frac{\sum_p \min(1, \text{conf}_t(p))}{\text{targetLen}_t}\right) -$$ - -## Step 4: Score term and posterior - -Normalized score: - -$$ -\text{ns}_{qt} = \frac{s_{qt}}{S_q} -$$ - -Score term: - -$$ -T_{qt} = \exp(\lambda \text{ns}_{qt}) \cdot \max(A_t, \varepsilon)^{\alpha} \cdot \max(C_t, \varepsilon)^{\gamma} -$$ - -Posterior for query $q$: - -$$ -Z_q = \sum_{t \in H(q)} T_{qt}, \quad -R_{qt} = \frac{T_{qt}}{Z_q} -$$ - -## Step 5: EM update - -Abundance update: - -$$ -\hat{A}_t = \frac{1}{|Q|} \sum_q R_{qt} -$$ - -## Step 6: Log-likelihood (per query) - -$$ -\mathcal{L} = \frac{1}{|Q|} \sum_q \left( \sum_t R_{qt} \log(T_{qt}) - \log(Z_q) \right) -$$ - -## Step 7: SQUAREM acceleration - -Two EM steps produce $x_1$ and $x_2$ from $x_0$: - -$$ -\mathbf{r} = x_1 - x_0, \quad \mathbf{v} = x_2 - x_1 - \mathbf{r} -$$ - -Acceleration scalar (clipped to $[-1, 1]$): - -$$ -a = -\frac{\|\mathbf{r}\|}{\|\mathbf{v}\|} -$$ - -Proposed update and simplex projection: - -$$ -\mathbf{x}_{\text{new}} = x_0 - 2a\mathbf{r} + a^2\mathbf{v}, \quad \mathbf{x}_{\text{new}} \in \Delta -$$ - -Simplex projection: - -$$ -x_i \leftarrow \max(0, x_i), \quad x_i \leftarrow \frac{x_i}{\sum_j x_j} -$$ - -If the 
accelerated step decreases log-likelihood, the code falls back to $x_2$. - -## Step 8: Target filtering (low-abundance drop) - -- Compute a cutoff on target abundances using a largest-gap rule or tail quantile rule. -- Drop up to `maxDropPercentage` of targets (with a minimum tail size). -- Remove dropped targets from the mapping table and recompute output stats. - -## Outputs - -- `new_alignment_result.m8` - - Hits re-sorted by posterior. - - Alignment summary derived from backtrace if present. -- `protein_abundance.tsv` - - Per-target: abundance %, coverage confidence, drop flag, mapped interval(s). -- `taxonomy_abundance.tsv` (if taxonomy enabled) - - Aggregated by taxonomy ID. - -## Notes - -- `C_t` (coverage confidence) is used directly in the score term. -- Parallelism is used in coverage confidence initialization across targets. -- No other steps in this file use OpenMP; EM and output phases are serial here. From 5f54095be41f49d80538b5a7947ba1dc83e4b028 Mon Sep 17 00:00:00 2001 From: Yaeji Kim Date: Tue, 26 May 2026 23:51:37 +0900 Subject: [PATCH 12/12] reclassify final output change- add target's abundance --- src/util/EM_reclassify.cpp | 80 +++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/src/util/EM_reclassify.cpp b/src/util/EM_reclassify.cpp index 2d1ab23bb..49294c5f5 100644 --- a/src/util/EM_reclassify.cpp +++ b/src/util/EM_reclassify.cpp @@ -9,8 +9,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -105,7 +107,7 @@ static void loadAlignmentDb(DBReader &reader, ReclassTaxContext &c } Matcher::result_t result = Matcher::parseAlignmentRecord(data, true); - records.push_back(ReclassTaxEntry{result, 0.0, 0.0, 0.0}); + records.push_back(ReclassTaxEntry{result, 0.0, static_cast(result.seqId), 0.0}); ctx.targetSet.insert(result.dbKey); data = Util::skipLine(data); } @@ -600,6 +602,73 @@ static std::vector collectTargetStats(const ReclassTaxContext &ctx) return 
out; } +static void printAbundanceDistribution(const std::vector &stats) { + if (stats.empty()) { + Debug(Debug::INFO) << "Abundance distribution: no targets.\n"; + return; + } + + std::vector values; + values.reserve(stats.size()); + for (size_t i = 0; i < stats.size(); ++i) { + values.push_back(stats[i].abundance); + } + std::sort(values.begin(), values.end()); + + const auto quantile = [&values](double q) -> double { + const size_t idx = static_cast(q * static_cast(values.size() - 1)); + return values[idx]; + }; + + const double totalMass = std::accumulate(values.begin(), values.end(), 0.0); + const auto cumulativeCountAtMassFraction = [&values, totalMass](double massFraction, size_t &count, double &countFrac) { + count = 0; + countFrac = 0.0; + if (values.empty() || totalMass <= 0.0 || massFraction <= 0.0) { + return; + } + + const double targetMass = massFraction * totalMass; + double cumulativeMass = 0.0; + for (size_t i = 0; i < values.size(); ++i) { + if ((cumulativeMass + values[i]) <= targetMass) { + cumulativeMass += values[i]; + ++count; + } else { + break; + } + } + countFrac = static_cast(count) / static_cast(values.size()); + }; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(8); + oss << "Abundance distribution (targets=" << values.size() << "):" + << " min=" << values.front() + << " p25=" << quantile(0.25) + << " p50=" << quantile(0.50) + << " p75=" << quantile(0.75) + << " p90=" << quantile(0.90) + << " p95=" << quantile(0.95) + << " p99=" << quantile(0.99) + << " max=" << values.back(); + Debug(Debug::INFO) << oss.str() << "\n"; + + const double cutoffs[] = {0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99}; + std::ostringstream cumulativeOss; + cumulativeOss << std::fixed << std::setprecision(8); + cumulativeOss << "Abundance cumulative:"; + for (size_t i = 0; i < sizeof(cutoffs) / sizeof(cutoffs[0]); ++i) { + size_t count = 0; + double countFrac = 0.0; + cumulativeCountAtMassFraction(cutoffs[i], count, countFrac); + cumulativeOss << " 
<=" << cutoffs[i] + << "[count=" << count + << ",countFrac=" << countFrac << "]"; + } + Debug(Debug::INFO) << cumulativeOss.str() << "\n"; +} + static double clamp01(double value) { return std::max(0.0, std::min(1.0, value)); } @@ -840,11 +909,20 @@ int emreclassify(int argc, const char **argv, const Command &command) { par.threads); std::vector allTargetStats = collectTargetStats(ctx); + printAbundanceDistribution(allTargetStats); double abundanceCutoff = 0.0; const std::unordered_set dropped = selectDroppedTargets(allTargetStats, par.reclassifyMaxDropPercentage, abundanceCutoff); applyDroppedTargets(ctx, dropped, allTargetStats.size(), abundanceCutoff); + const size_t totalTargets = allTargetStats.size(); + const size_t removedTargets = dropped.size(); + const double removedPct = (totalTargets > 0) + ? (100.0 * static_cast(removedTargets) / static_cast(totalTargets)) + : 0.0; + Debug(Debug::INFO) << "Reclassify-drop summary: removed " << removedTargets + << " / " << totalTargets + << " targets (" << removedPct << "%).\n"; writeReclassifiedDb(ctx, reader.getDbtype(), par.db4, par.db4Index, par.threads, par.compressed);