Skip to content

Commit

Permalink
Merge pull request #354 from steineggerlab/test
Browse files Browse the repository at this point in the history
add --single-chain-include-mode
  • Loading branch information
Woosub-Kim authored Sep 19, 2024
2 parents b7c58ac + d267b3d commit 8809363
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 14 deletions.
6 changes: 5 additions & 1 deletion src/commons/LocalParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ LocalParameters::LocalParameters() :
PARAM_EXACT_TMSCORE(PARAM_EXACT_TMSCORE_ID,"--exact-tmscore", "Exact TMscore","turn on fast exact TMscore (slow), default is approximate" ,typeid(int), (void *) &exactTMscore, "^[0-1]{1}$"),
PARAM_N_SAMPLE(PARAM_N_SAMPLE_ID, "--n-sample", "Sample size","pick N random sample" ,typeid(int), (void *) &nsample, "^[0-9]{1}[0-9]*$"),
PARAM_COORD_STORE_MODE(PARAM_COORD_STORE_MODE_ID, "--coord-store-mode", "Coord store mode", "Coordinate storage mode: \n1: C-alpha as float\n2: C-alpha as difference (uint16_t)", typeid(int), (void *) &coordStoreMode, "^[1-2]{1}$",MMseqsParameter::COMMAND_EXPERT),
PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD_ID, "--min-assigned-chains-ratio", "Minimum assigned chains percentage Threshold", "minimum percentage of assigned chains out of all query chains > thr [0,100] %", typeid(float), (void *) & minAssignedChainsThreshold, "^[0-9]*(\\.[0-9]+)?$"),
PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD_ID, "--min-assigned-chains-ratio", "Minimum assigned chains percentage Threshold", "Minimum ratio of assigned chains out of all query chains > thr [0.0,1.0]", typeid(float), (void *) & minAssignedChainsThreshold, "^[0-9]*(\\.[0-9]+)?$", MMseqsParameter::COMMAND_ALIGN),
PARAM_SINGLE_CHAIN_INCLUDE_MODE(PARAM_SINGLE_CHAIN_INCLUDE_MODE_ID, "--single-chain-include-mode", "Single Chained Assignments Inclusion Mode for Multimer", "Single Chained Assignments Inclusion 0: include single chained assignments, 1: NOT include single chained assignment", typeid(int), (void *) & singleChainIncludeMode, "^[0-1]{1}$", MMseqsParameter::COMMAND_ALIGN),
PARAM_CLUSTER_SEARCH(PARAM_CLUSTER_SEARCH_ID, "--cluster-search", "Cluster search", "first find representative then align all cluster members", typeid(int), (void *) &clusterSearch, "^[0-1]{1}$",MMseqsParameter::COMMAND_MISC),
PARAM_FILE_INCLUDE(PARAM_FILE_INCLUDE_ID, "--file-include", "File Inclusion Regex", "Include file names based on this regex", typeid(std::string), (void *) &fileInclude, "^.*$"),
PARAM_FILE_EXCLUDE(PARAM_FILE_EXCLUDE_ID, "--file-exclude", "File Exclusion Regex", "Exclude file names based on this regex", typeid(std::string), (void *) &fileExclude, "^.*$"),
Expand Down Expand Up @@ -186,6 +187,7 @@ LocalParameters::LocalParameters() :

//scorecmultimer
scoremultimer.push_back(&PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD);
scoremultimer.push_back(&PARAM_SINGLE_CHAIN_INCLUDE_MODE);
scoremultimer.push_back(&PARAM_THREADS);
scoremultimer.push_back(&PARAM_V);

Expand Down Expand Up @@ -228,6 +230,8 @@ LocalParameters::LocalParameters() :
minDiagScoreThr = 30;
maskBfactorThreshold = 0;
chainNameMode = 0;
minAssignedChainsThreshold = 0.0;
singleChainIncludeMode = 0;
writeMapping = 0;
tmAlignFast = 1;
exactTMscore = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/commons/LocalParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class LocalParameters : public Parameters {
PARAMETER(PARAM_N_SAMPLE)
PARAMETER(PARAM_COORD_STORE_MODE)
PARAMETER(PARAM_MIN_ASSIGNED_CHAINS_THRESHOLD)
PARAMETER(PARAM_SINGLE_CHAIN_INCLUDE_MODE)
PARAMETER(PARAM_CLUSTER_SEARCH)
PARAMETER(PARAM_FILE_INCLUDE)
PARAMETER(PARAM_FILE_EXCLUDE)
Expand Down Expand Up @@ -150,6 +151,7 @@ class LocalParameters : public Parameters {
int nsample;
int coordStoreMode;
float minAssignedChainsThreshold;
int singleChainIncludeMode;
int clusterSearch;
std::string fileInclude;
std::string fileExclude;
Expand Down
3 changes: 2 additions & 1 deletion src/strucclustutils/MultimerUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "TMaligner.h"

const unsigned int NOT_AVAILABLE_CHAIN_KEY = 4294967295;
const double MAX_ASSIGNED_CHAIN_RATIO = 1.0;
const float MAX_ASSIGNED_CHAIN_RATIO = 1.0;
const double TOO_SMALL_MEAN = 1.0;
const double TOO_SMALL_CV = 0.1;
const double FILTERED_OUT = 0.0;
Expand All @@ -15,6 +15,7 @@ const float LEARNING_RATE = 0.1;
const float TM_SCORE_MARGIN = 0.7;
const unsigned int MULTIPLE_CHAINED_COMPLEX = 2;
const unsigned int SIZE_OF_SUPERPOSITION_VECTOR = 12;
const int SKIP_SINGLE_CHAIN_ASSIGNMENTS = 1;
typedef std::pair<std::string, std::string> compNameChainName_t;
typedef std::map<unsigned int, unsigned int> chainKeyToComplexId_t;
typedef std::map<unsigned int, std::vector<unsigned int>> complexIdToChainKeys_t;
Expand Down
7 changes: 3 additions & 4 deletions src/strucclustutils/createmultimerreport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ int createmultimerreport(int argc, const char **argv, const Command &command) {
std::string qLookupFile = par.db1 + ".lookup";
TranslateNucl translateNucl(static_cast<TranslateNucl::GenCode>(par.translationTable));

Matcher::result_t res;
std::map<unsigned int, unsigned int> qChainKeyToComplexIdMap;
std::map<unsigned int, std::vector<unsigned int>> qComplexIdToChainKeyMap;
std::vector<unsigned int> qComplexIdVec;
Expand Down Expand Up @@ -178,9 +177,9 @@ int createmultimerreport(int argc, const char **argv, const Command &command) {
} // MP end
SORT_PARALLEL(complexResults.begin(), complexResults.end(), compareComplexResultByQuery);
for (size_t complexResIdx = 0; complexResIdx < complexResults.size(); complexResIdx++) {
const ScoreComplexResult& res = complexResults[complexResIdx];
const resultToWrite_t& data = res.resultToWrite;
resultWriter.writeData(data.c_str(), data.length(), res.assId, 0, isDb, isDb);
const ScoreComplexResult& cRes = complexResults[complexResIdx];
const resultToWrite_t& data = cRes.resultToWrite;
resultWriter.writeData(data.c_str(), data.length(), cRes.assId, 0, isDb, isDb);
}
resultWriter.close(true);
if (isDb == false) {
Expand Down
25 changes: 17 additions & 8 deletions src/strucclustutils/scoremultimer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ struct SearchResult {
dbResidueLen = residueLen;
}

void standardize() {
void standardize(int singleChainedAssignmentIncludeMode) {
if (dbResidueLen == 0)
alnVec.clear();

if (singleChainedAssignmentIncludeMode==SKIP_SINGLE_CHAIN_ASSIGNMENTS && dbChainKeys.size() < MULTIPLE_CHAINED_COMPLEX)
alnVec.clear();

if (alnVec.empty())
return;

Expand Down Expand Up @@ -179,9 +182,11 @@ bool compareNeighborWithDist(const NeighborsWithDist &first, const NeighborsWith

class DBSCANCluster {
public:
DBSCANCluster(SearchResult &searchResult, std::set<cluster_t> &finalClusters, double minCov) : searchResult(searchResult), finalClusters(finalClusters) {
DBSCANCluster(SearchResult &searchResult, std::set<cluster_t> &finalClusters, double minCov, int singleChainMode) : searchResult(searchResult), finalClusters(finalClusters) {
cLabel = 0;
minimumClusterSize = (unsigned int) ((double) searchResult.qChainKeys.size() * minCov);
if (singleChainMode == SKIP_SINGLE_CHAIN_ASSIGNMENTS)
minimumClusterSize = std::max(MULTIPLE_CHAINED_COMPLEX, minimumClusterSize);
maximumClusterSize = std::min(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size());
maximumClusterNum = searchResult.alnVec.size() / maximumClusterSize;
prevMaxClusterSize = 0;
Expand Down Expand Up @@ -467,7 +472,7 @@ class DBSCANCluster {

class ComplexScorer {
public:
ComplexScorer(IndexReader *qDbr3Di, IndexReader *tDbr3Di, DBReader<unsigned int> &alnDbr, IndexReader *qCaDbr, IndexReader *tCaDbr, unsigned int thread_idx, double minAssignedChainsRatio) : alnDbr(alnDbr), qCaDbr(qCaDbr), tCaDbr(tCaDbr), thread_idx(thread_idx), minAssignedChainsRatio(minAssignedChainsRatio) {
ComplexScorer(IndexReader *qDbr3Di, IndexReader *tDbr3Di, DBReader<unsigned int> &alnDbr, IndexReader *qCaDbr, IndexReader *tCaDbr, unsigned int thread_idx, double minAssignedChainsRatio, int singleChainedAssignmentIncludeMode) : alnDbr(alnDbr), qCaDbr(qCaDbr), tCaDbr(tCaDbr), thread_idx(thread_idx), minAssignedChainsRatio(minAssignedChainsRatio), singleChainedAssignmentIncludeMode(singleChainedAssignmentIncludeMode) {
maxChainLen = std::max(qDbr3Di->sequenceReader->getMaxSeqLen()+1, tDbr3Di->sequenceReader->getMaxSeqLen()+1);
q3diDbr = qDbr3Di;
t3diDbr = tDbr3Di;
Expand Down Expand Up @@ -533,7 +538,7 @@ class ComplexScorer {
paredSearchResult.alnVec.emplace_back(aln);
continue;
}
paredSearchResult.standardize();
paredSearchResult.standardize(singleChainedAssignmentIncludeMode);
if (!paredSearchResult.alnVec.empty())
searchResults.emplace_back(paredSearchResult);

Expand All @@ -545,7 +550,7 @@ class ComplexScorer {
paredSearchResult.alnVec.emplace_back(aln);
}
currAlns.clear();
paredSearchResult.standardize();
paredSearchResult.standardize(singleChainedAssignmentIncludeMode);
if (!paredSearchResult.alnVec.empty())
searchResults.emplace_back(paredSearchResult);

Expand All @@ -559,7 +564,7 @@ class ComplexScorer {
tmAligner = new TMaligner(maxResLen, false, true, false);
}
finalClusters.clear();
DBSCANCluster dbscanCluster(searchResult, finalClusters, minAssignedChainsRatio);
DBSCANCluster dbscanCluster(searchResult, finalClusters, minAssignedChainsRatio, singleChainedAssignmentIncludeMode);
if (!dbscanCluster.getAlnClusters()) {
finalClusters.clear();
return;
Expand Down Expand Up @@ -605,6 +610,7 @@ class ComplexScorer {
SearchResult paredSearchResult;
std::set<cluster_t> finalClusters;
bool hasBacktrace;
int singleChainedAssignmentIncludeMode;

unsigned int getQueryResidueLength(std::vector<unsigned int> &qChainKeys) {
unsigned int qResidueLen = 0;
Expand Down Expand Up @@ -697,7 +703,8 @@ int scoremultimer(int argc, const char **argv, const Command &command) {
);
}

double minAssignedChainsRatio = par.minAssignedChainsThreshold > MAX_ASSIGNED_CHAIN_RATIO ? MAX_ASSIGNED_CHAIN_RATIO: par.minAssignedChainsThreshold;
float minAssignedChainsRatio = par.minAssignedChainsThreshold > MAX_ASSIGNED_CHAIN_RATIO ? MAX_ASSIGNED_CHAIN_RATIO: par.minAssignedChainsThreshold;
int singleChainIncludeMode = par.singleChainIncludeMode;

std::vector<unsigned int> qComplexIndices;
std::vector<unsigned int> dbComplexIndices;
Expand All @@ -723,12 +730,14 @@ int scoremultimer(int argc, const char **argv, const Command &command) {
std::vector<SearchResult> searchResults;
std::vector<Assignment> assignments;
std::vector<resultToWrite_t> resultToWriteLines;
ComplexScorer complexScorer(q3DiDbr, &t3DiDbr, alnDbr, qCaDbr, &tCaDbr, thread_idx, minAssignedChainsRatio);
ComplexScorer complexScorer(q3DiDbr, &t3DiDbr, alnDbr, qCaDbr, &tCaDbr, thread_idx, minAssignedChainsRatio, singleChainIncludeMode);
#pragma omp for schedule(dynamic, 1)
// for each q complex
for (size_t qCompIdx = 0; qCompIdx < qComplexIndices.size(); qCompIdx++) {
unsigned int qComplexId = qComplexIndices[qCompIdx];
std::vector<unsigned int> &qChainKeys = qComplexIdToChainKeysMap.at(qComplexId);
if (par.singleChainIncludeMode == SKIP_SINGLE_CHAIN_ASSIGNMENTS && qChainKeys.size() < MULTIPLE_CHAINED_COMPLEX)
continue;
complexScorer.getSearchResults(qComplexId, qChainKeys, dbChainKeyToComplexIdMap, dbComplexIdToChainKeysMap, searchResults);
// for each db complex
for (size_t dbId = 0; dbId < searchResults.size(); dbId++) {
Expand Down

0 comments on commit 8809363

Please sign in to comment.