Skip to content

Commit

Permalink
changes for using prostt5
Browse files Browse the repository at this point in the history
  • Loading branch information
gamcil committed Aug 5, 2024
1 parent e739fef commit dc52098
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 52 deletions.
116 changes: 78 additions & 38 deletions src/strucclustutils/msa2lddt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,22 @@ int msa2lddt(int argc, const char **argv, const Command& command, int makeReport
seqDbrAA.open(DBReader<unsigned int>::NOSORT);
DBReader<unsigned int> seqDbr3Di((par.db1+"_ss").c_str(), (par.db1+"_ss.index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);
seqDbr3Di.open(DBReader<unsigned int>::NOSORT);
DBReader<unsigned int> seqDbrCA((par.db1+"_ca").c_str(), (par.db1+"_ca.index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);
seqDbrCA.open(DBReader<unsigned int>::NOSORT);

// Check for CA database
DBReader<unsigned int> *seqDbrCA = NULL;
bool caExist = FileUtil::fileExists((par.db1 + "_ca.dbtype").c_str());
if (caExist == false) {
Debug(Debug::INFO) << "Did not find " << FileUtil::baseName(par.db1) << " C-alpha database, not calculating LDDT\n";
} else {
seqDbrCA = new DBReader<unsigned int>(
(par.db1 + "_ca").c_str(),
(par.db1 + "_ca.index").c_str(),
par.threads,
DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA
);
seqDbrCA->open(DBReader<unsigned int>::NOSORT);
}

IndexReader headerDB(par.db1, par.threads, IndexReader::HEADERS, touch ? IndexReader::PRELOAD_INDEX : 0);

// Read in MSA, mapping headers to database indices
Expand All @@ -394,20 +408,21 @@ int msa2lddt(int argc, const char **argv, const Command& command, int makeReport
for (size_t i = 0; i < subset.size(); i++)
subset[i] = i;

std::tie(perColumnScore, perColumnCount, lddtScore, numCols) = calculate_lddt(cigars_aa, subset, indices, lengths, &seqDbrCA, par.pairThreshold);


std::string scores;
for (float score : perColumnScore) {
if (scores.length() > 0) scores += ",";
scores += std::to_string(score);
if (caExist) {
std::tie(perColumnScore, perColumnCount, lddtScore, numCols) = calculate_lddt(cigars_aa, subset, indices, lengths, seqDbrCA, par.pairThreshold);
std::string scores;
for (float score : perColumnScore) {
if (scores.length() > 0) scores += ",";
scores += std::to_string(score);
}
Debug(Debug::INFO) << "Average MSA LDDT: " << lddtScore << '\n';
Debug(Debug::INFO) << "Columns considered: " << numCols << "/" << alnLength << '\n';
Debug(Debug::INFO) << "Column scores: " << scores << '\n';
}
std::cout << "Average MSA LDDT: " << lddtScore << '\n';
std::cout << "Columns considered: " << numCols << "/" << alnLength << '\n';
std::cout << "Column scores: " << scores << '\n';

// Write clustal format MSA HTML
if (makeReport) {
Debug(Debug::INFO) << "Generating report\n";
DBWriter resultWriter(par.db3.c_str(), (par.db3 + ".index").c_str(), static_cast<unsigned int>(par.threads), par.compressed, Parameters::DBTYPE_OMIT_FILE);
resultWriter.open();

Expand Down Expand Up @@ -470,34 +485,46 @@ R"html(<!DOCTYPE html>
for (size_t i = 0; i < cigars_aa.size(); i++) {
std::string seq_aa = expand(cigars_aa[i]);
std::string seq_ss = expand(cigars_ss[i]);
std::string seq_ca = getXYZstring(indices[i], lengths[i], &seqDbrCA);
std::string entry;
entry.append("{\"name\":\"");
entry.append(headers[i]);
entry.append("\",\"aa\": \"");
entry.append(seq_aa);
entry.append("\",\"ss\": \"");
entry.append(seq_ss);
entry.append("\",\"ca\": \"");
entry.append(seq_ca);
entry.append("\"}");
if (i != cigars_aa.size() - 1)
entry.append("\"");
if (caExist) {
std::string seq_ca = getXYZstring(indices[i], lengths[i], seqDbrCA);
entry.append(",\"ca\": \"");
entry.append(seq_ca);
entry.append("\"");
}
entry.append("}");
if (i != cigars_aa.size() - 1) {
entry.append(",");
} else {
entry.append("]");
}
resultWriter.writeData(entry.c_str(), entry.length(), 0, 0, false, false);
}

std::string middle = "],\"scores\": [";
resultWriter.writeData(middle.c_str(), middle.length(), 0, 0, false, false);
}

// Per-column scores, as [score, score, ...]
// TODO: optionally save this as .csv
for (int i = 0; i < alnLength; i++) {
std::string entry = (perColumnCount[i] == 0) ? "-1" : std::to_string(perColumnScore[i]);
if (i != alnLength - 1)
entry.append(",");
resultWriter.writeData(entry.c_str(), entry.length(), 0, 0, false, false);
std::string middle = "";

if (caExist) {
middle.append(",\"scores\": [");
for (int i = 0; i < alnLength; i++) {
std::string entry = (perColumnCount[i] == 0) ? "-1" : std::to_string(perColumnScore[i]);
if (i != alnLength - 1) {
entry.append(",");
}
middle.append(entry);
}
middle.append("]");
}
std::string end = "],";
resultWriter.writeData(middle.c_str(), middle.length(), 0, 0, false, false);

std::string end = "";

if (par.guideTree != "") {
std::string tree;
Expand All @@ -508,25 +535,37 @@ R"html(<!DOCTYPE html>
tree += line;
newick.close();
}
end.append("\"tree\": \"");
end.append(",\"tree\": \"");
end.append(tree);
end.append("\",");
end.append("\"");
}
end.append("\"statistics\": {");
end.append(",\"statistics\": {");

bool hasPrev = false;
if (par.reportPaths) {
end.append("\"db\":\"");
end.append(par.db1);
end.append("\",\"msaFile\":\"");
end.append(par.db2);
end.append("\",");
end.append("\"");
hasPrev = true;
}
if (caExist) {
if (hasPrev) {
end.append(",");
}
end.append("\"msaLDDT\":");
end.append(std::to_string(lddtScore));
hasPrev = true;
}
end.append("\"msaLDDT\":");
end.append(std::to_string(lddtScore));

if (par.reportCommand != "") {
end.append(",\"cmdString\":\"");
if (hasPrev) {
end.append(",");
}
end.append("\"cmdString\":\"");
end.append(par.reportCommand);
end.append("\"");
hasPrev = true;
}
end.append("}}");

Expand All @@ -544,9 +583,10 @@ R"html(<!DOCTYPE html>
}

seqDbrAA.close();
seqDbrCA.close();
seqDbr3Di.close();

if (caExist) {
seqDbrCA->close();
}
return EXIT_SUCCESS;
}

Expand Down
43 changes: 29 additions & 14 deletions src/strucclustutils/structuremsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,26 +1054,26 @@ Matcher::result_t pairwiseTMAlign(
int mergedId,
int targetId,
DBReader<unsigned int> &seqDbrAA,
DBReader<unsigned int> &seqDbrCA
DBReader<unsigned int> *seqDbrCA
) {
int qLen = seqDbrAA.getSeqLen(mergedId);
int tLen = seqDbrAA.getSeqLen(targetId);

unsigned int qKey = seqDbrAA.getDbKey(mergedId);
size_t qCaId = seqDbrCA.getId(qKey);
size_t qCaId = seqDbrCA->getId(qKey);

unsigned int tKey = seqDbrAA.getDbKey(targetId);
size_t tCaId = seqDbrCA.getId(tKey);
size_t tCaId = seqDbrCA->getId(tKey);

Coordinate16 qcoords;
char *qcadata = seqDbrCA.getData(qCaId, 0);
size_t qCaLength = seqDbrCA.getEntryLen(qCaId);
char *qcadata = seqDbrCA->getData(qCaId, 0);
size_t qCaLength = seqDbrCA->getEntryLen(qCaId);
float *qCaData = qcoords.read(qcadata, qLen, qCaLength);
char *merged_aa_seq = seqDbrAA.getData(qCaId, 0);

Coordinate16 tcoords;
char *tcadata = seqDbrCA.getData(tCaId, 0);
size_t tCaLength = seqDbrCA.getEntryLen(tCaId);
char *tcadata = seqDbrCA->getData(tCaId, 0);
size_t tCaLength = seqDbrCA->getEntryLen(tCaId);
float *tCaData = tcoords.read(tcadata, tLen, tCaLength);
char *target_aa_seq = seqDbrAA.getData(tCaId, 0);

Expand Down Expand Up @@ -1117,8 +1117,21 @@ int structuremsa(int argc, const char **argv, const Command& command, bool preCl
seqDbrAA.open(DBReader<unsigned int>::NOSORT);
DBReader<unsigned int> seqDbr3Di((par.db1+"_ss").c_str(), (par.db1+"_ss.index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);
seqDbr3Di.open(DBReader<unsigned int>::NOSORT);
DBReader<unsigned int> seqDbrCA((par.db1+"_ca").c_str(), (par.db1+"_ca.index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);
seqDbrCA.open(DBReader<unsigned int>::NOSORT);

// Check for CA database
DBReader<unsigned int> *seqDbrCA = NULL;
bool caExist = FileUtil::fileExists((par.db1 + "_ca.dbtype").c_str());
if (caExist == false) {
Debug(Debug::INFO) << "Did not find " << FileUtil::baseName(par.db1) << " C-alpha database, not using\n";
} else {
seqDbrCA = new DBReader<unsigned int>(
(par.db1 + "_ca").c_str(),
(par.db1 + "_ca.index").c_str(),
par.threads,
DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA
);
seqDbrCA->open(DBReader<unsigned int>::NOSORT);
}

IndexReader qdbrH(par.db1, par.threads, IndexReader::HEADERS, touch ? IndexReader::PRELOAD_INDEX : 0);

Expand Down Expand Up @@ -1512,7 +1525,7 @@ int structuremsa(int argc, const char **argv, const Command& command, bool preCl
// If neither are profiles, do TM-align as well and take the best alignment
bool tmaligned = false;
// if (false) {
if (!queryIsProfile && !targetIsProfile) {
if (caExist && !queryIsProfile && !targetIsProfile) {
Matcher::result_t tmRes = pairwiseTMAlign(mergedId, targetId, seqDbrAA, seqDbrCA);
std::vector<Instruction> qBtTM;
std::vector<Instruction> tBtTM;
Expand Down Expand Up @@ -1549,7 +1562,7 @@ int structuremsa(int argc, const char **argv, const Command& command, bool preCl
std::vector<size_t> indices_tm = { dbKeys[mergedId], dbKeys[targetId] };
std::vector<int> lengths_tm = { seqLens[mergedId], seqLens[targetId] };

float lddtTM = std::get<2>(calculate_lddt(cigars_tm, subset_tm, indices_tm, lengths_tm, &seqDbrCA, par.pairThreshold));
float lddtTM = std::get<2>(calculate_lddt(cigars_tm, subset_tm, indices_tm, lengths_tm, seqDbrCA, par.pairThreshold));
// std::cout << "got TM lddt: " << lddtTM << '\n';

// adjust cigars with 3Di alignment result
Expand Down Expand Up @@ -1578,7 +1591,7 @@ int structuremsa(int argc, const char **argv, const Command& command, bool preCl
// std::cout << expand(query_aa) << '\n';
// std::cout << expand(target_aa) << '\n';

float lddt3Di = std::get<2>(calculate_lddt(cigars_tm, subset_tm, indices_tm, lengths_tm, &seqDbrCA, par.pairThreshold));
float lddt3Di = std::get<2>(calculate_lddt(cigars_tm, subset_tm, indices_tm, lengths_tm, seqDbrCA, par.pairThreshold));
// std::cout << "got 3Di lddt: " << lddt3Di << '\n';

if (lddtTM > lddt3Di) {
Expand Down Expand Up @@ -1686,7 +1699,7 @@ int structuremsa(int argc, const char **argv, const Command& command, bool preCl
{
if (par.refineIters > 0) {
refineMany(
tinySubMatAA, tinySubMat3Di, &seqDbrCA, cigars_aa, cigars_ss, calculator_aa,
tinySubMatAA, tinySubMat3Di, seqDbrCA, cigars_aa, cigars_ss, calculator_aa,
filter_aa, subMat_aa, calculator_3di, filter_3di, subMat_3di, structureSmithWaterman,
par.refineIters, par.compBiasCorrection, par.wg, par.filterMaxSeqId, par.qsc,
par.Ndiff, par.covMSAThr, par.filterMinEnable, par.filterMsa, par.gapExtend.values.aminoacid(),
Expand Down Expand Up @@ -1749,7 +1762,9 @@ int structuremsa(int argc, const char **argv, const Command& command, bool preCl
free(tinySubMat3Di);
seqDbrAA.close();
seqDbr3Di.close();
seqDbrCA.close();
if (caExist) {
seqDbrCA->close();
}

return EXIT_SUCCESS;
}
Expand Down

0 comments on commit dc52098

Please sign in to comment.