Skip to content

Commit

Permalink
Remove realloc in QueryMatch to safe memory
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-steinegger committed Feb 14, 2024
1 parent 78ae2c5 commit 950342d
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 47 deletions.
55 changes: 47 additions & 8 deletions src/prefiltering/CacheFriendlyOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ CacheFriendlyOperations<BINSIZE>::~CacheFriendlyOperations<BINSIZE>(){

template<unsigned int BINSIZE>
size_t CacheFriendlyOperations<BINSIZE>::findDuplicates(IndexEntryLocal **input, CounterResult *output,
size_t outputSize, unsigned short indexFrom, unsigned short indexTo, bool computeTotalScore) {
size_t outputSize, unsigned short indexFrom, unsigned short indexTo, bool computeTotalScore) {
do {
setupBinPointer();
CounterResult *lastPosition = (binDataFrame + BINCOUNT * binSize) - 1;
Expand All @@ -58,12 +58,16 @@ size_t CacheFriendlyOperations<BINSIZE>::mergeElementsByScore(CounterResult *inp
}

template<unsigned int BINSIZE>
size_t CacheFriendlyOperations<BINSIZE>::mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N) {
size_t CacheFriendlyOperations<BINSIZE>::mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N, const bool keepScoredHits) {
do {
setupBinPointer();
hashElements(inputOutputArray, N);
} while(checkForOverflowAndResizeArray(false) == true); // overflowed occurred
return mergeDiagonalDuplicates(inputOutputArray);
if(keepScoredHits){
return mergeDiagonalKeepScoredHitsDuplicates(inputOutputArray);
}else{
return mergeDiagonalDuplicates(inputOutputArray);
}
}

template<unsigned int BINSIZE>
Expand Down Expand Up @@ -93,6 +97,7 @@ size_t CacheFriendlyOperations<BINSIZE>::mergeDiagonalDuplicates(CounterResult *
--n;
}
// combine diagonals
// we keep only the last diagonal element
for (size_t n = 0; n < currBinSize; n++) {
const CounterResult &element = binStartPos[n];
const unsigned int hashBinElement = element.id >> (MASK_0_5_BIT);
Expand All @@ -109,6 +114,40 @@ size_t CacheFriendlyOperations<BINSIZE>::mergeDiagonalDuplicates(CounterResult *
return doubleElementCount;
}


template<unsigned int BINSIZE>
size_t CacheFriendlyOperations<BINSIZE>::mergeDiagonalKeepScoredHitsDuplicates(CounterResult *output) {
size_t doubleElementCount = 0;
const CounterResult *bin_ref_pointer = binDataFrame;
// duplicateBitArray is already zero'd from findDuplicates

for (size_t bin = 0; bin < BINCOUNT; bin++) {
const CounterResult *binStartPos = (bin_ref_pointer + bin * binSize);
const size_t currBinSize = (bins[bin] - binStartPos);
// write diagonals + 1 in reverse order in the byte array
for (size_t n = 0; n < currBinSize; n++) {
const unsigned int element = binStartPos[n].id >> (MASK_0_5_BIT);
duplicateBitArray[element] = static_cast<unsigned char>(binStartPos[n].diagonal) + 1;
}
// combine diagonals
// we keep only the last diagonal element
size_t n = currBinSize - 1;
while (n != static_cast<size_t>(-1)) {
const CounterResult &element = binStartPos[n];
const unsigned int hashBinElement = element.id >> (MASK_0_5_BIT);
output[doubleElementCount].id = element.id;
output[doubleElementCount].count = element.count;
output[doubleElementCount].diagonal = element.diagonal;
// std::cout << output[doubleElementCount].id << " " << (int)output[doubleElementCount].count << " " << (int)static_cast<unsigned char>(output[doubleElementCount].diagonal) << std::endl;
// memory overflow can not happen since input array = output array
doubleElementCount += (output[doubleElementCount].count != 0 || duplicateBitArray[hashBinElement] != static_cast<unsigned char>(binStartPos[n].diagonal)) ? 1 : 0;
duplicateBitArray[hashBinElement] = static_cast<unsigned char>(element.diagonal);
--n;
}
}
return doubleElementCount;
}

template<unsigned int BINSIZE>
size_t CacheFriendlyOperations<BINSIZE>::mergeScoreDuplicates(CounterResult *output) {
size_t doubleElementCount = 0;
Expand Down Expand Up @@ -211,12 +250,12 @@ size_t CacheFriendlyOperations<BINSIZE>::findDuplicates(CounterResult *output, s
output[doubleElementCount].id = element;
output[doubleElementCount].count = 0;
output[doubleElementCount].diagonal = tmpElementBuffer[n].diagonal;
// const unsigned char diagonal = static_cast<unsigned char>(tmpElementBuffer[n].diagonal);
// const unsigned char diagonal = static_cast<unsigned char>(tmpElementBuffer[n].diagonal);
// memory overflow can not happen since input array = output array
// if(duplicateBitArray[hashBinElement] != tmpElementBuffer[n].diagonal){
// std::cout << "seq="<< output[doubleElementCount].id << "\tDiag=" << (int) output[doubleElementCount].diagonal
// << " dup.Array=" << (int)duplicateBitArray[hashBinElement] << " tmp.Arr="<< (int)tmpElementBuffer[n].diagonal << std::endl;
// }
// if(duplicateBitArray[hashBinElement] != tmpElementBuffer[n].diagonal){
// std::cout << "seq="<< output[doubleElementCount].id << "\tDiag=" << (int) output[doubleElementCount].diagonal
// << " dup.Array=" << (int)duplicateBitArray[hashBinElement] << " tmp.Arr="<< (int)tmpElementBuffer[n].diagonal << std::endl;
// }
doubleElementCount += (duplicateBitArray[hashBinElement] != static_cast<unsigned char>(tmpElementBuffer[n].diagonal)) ? 1 : 0;
duplicateBitArray[hashBinElement] = static_cast<unsigned char>(tmpElementBuffer[n].diagonal);
}
Expand Down
4 changes: 3 additions & 1 deletion src/prefiltering/CacheFriendlyOperations.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class CacheFriendlyOperations {
size_t mergeElementsByScore(CounterResult *inputOutputArray, const size_t N);

// merge elements in CounterResult by diagonal, combines elements with same ids that occur after each other
size_t mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N);
size_t mergeElementsByDiagonal(CounterResult *inputOutputArray, const size_t N, const bool keepScoredHits = false);

size_t keepMaxScoreElementOnly(CounterResult *inputOutputArray, const size_t N);

Expand Down Expand Up @@ -124,6 +124,8 @@ class CacheFriendlyOperations {

size_t mergeDiagonalDuplicates(CounterResult *output);

size_t mergeDiagonalKeepScoredHitsDuplicates(CounterResult *output);

size_t keepMaxElement(CounterResult *output);
};

Expand Down
44 changes: 24 additions & 20 deletions src/prefiltering/QueryMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,17 @@ std::pair<hit_t*, size_t> QueryMatcher::matchQuery(Sequence *querySeq, unsigned
} else {
memset(compositionBias, 0, sizeof(float) * querySeq->L);
}

if(diagonalScoring == true){
ungappedAlignment->createProfile(querySeq, compositionBias);
}
size_t resultSize = match(querySeq, compositionBias);
if (hook != NULL) {
resultSize = hook->afterDiagonalMatchingHook(*this, resultSize);
}
std::pair<hit_t *, size_t> queryResult;
if (diagonalScoring) {
// write diagonal scores in count value
ungappedAlignment->processQuery(querySeq, compositionBias, foundDiagonals, resultSize);
ungappedAlignment->align(foundDiagonals, resultSize);
memset(scoreSizes, 0, SCORE_RANGE * sizeof(unsigned int));
CounterResult * resultReadPos = foundDiagonals;
CounterResult * resultWritePos = foundDiagonals + resultSize;
Expand Down Expand Up @@ -267,35 +269,37 @@ size_t QueryMatcher::match(Sequence *seq, float *compositionBias) {
//std::cout << seq->getDbKey() << std::endl;
//idx.printKmer(index[kmerPos], kmerSize, kmerSubMat->num2aa);
//std::cout << "\t" << current_i << "\t"<< index[kmerPos] << std::endl;
//for (size_t i = 0; i < seqListSize; i++) {
// char diag = entries[i].position_j - current_i;
// std::cout << "(" << entries[i].seqId << " " << (int) diag << ")\t";
//}
// for (size_t i = 0; i < seqListSize; i++) {
// if(23865 == entries[i].seqId ){
// char diag = entries[i].position_j - current_i;
// std::cout << "(" << entries[i].seqId << " " << (int) diag << ")\t";
// }
// }
//std::cout << std::endl;

// detected overflow while matching
if ((sequenceHits + seqListSize) >= lastSequenceHit) {
stats->diagonalOverflow = true;
// realloc foundDiagonals if only 10% of memory left
if((foundDiagonalsSize - overflowHitCount) < 0.1 * foundDiagonalsSize){
foundDiagonalsSize *= 1.5;
foundDiagonals = (CounterResult*) realloc(foundDiagonals, foundDiagonalsSize * sizeof(CounterResult));
if(foundDiagonals == NULL){
Debug(Debug::ERROR) << "Out of memory in QueryMatcher::match\n";
EXIT(EXIT_FAILURE);
}
}
// last pointer
indexPointer[current_i + 1] = sequenceHits;
//std::cout << "Overflow in i=" << indexStart << std::endl;
const size_t hitCount = findDuplicates(indexPointer,
foundDiagonals + overflowHitCount,
foundDiagonalsSize - overflowHitCount,
indexStart, current_i, (diagonalScoring == false));

// this happens only if we have two overflows in a row
if (overflowHitCount != 0) {
// merge lists, hitCount is max. dbSize so there can be no overflow in mergeElements
overflowHitCount = mergeElements(foundDiagonals, hitCount + overflowHitCount);
if(diagonalScoring == true){
overflowHitCount = mergeElements(foundDiagonals, hitCount + overflowHitCount, true);
// align the new diaognals
ungappedAlignment->align(foundDiagonals, overflowHitCount);
// We keep only the maximal diagonal scoring hit, so the max number of hits is DBsize
overflowHitCount = keepMaxScoreElementOnly(foundDiagonals, overflowHitCount);
} else {
// in case of scoring we just sum up in mergeElements, so the max number of hits is DBsize
// merge lists, hitCount is max. dbSize so there can be no overflow in mergeElements
overflowHitCount = mergeElements(foundDiagonals, hitCount + overflowHitCount);
}
} else {
overflowHitCount = hitCount;
}
Expand Down Expand Up @@ -463,11 +467,11 @@ size_t QueryMatcher::findDuplicates(IndexEntryLocal **hitsByIndex,
return localResultSize;
}

size_t QueryMatcher::mergeElements(CounterResult *foundDiagonals, size_t hitCounter) {
size_t QueryMatcher::mergeElements(CounterResult *foundDiagonals, size_t hitCounter, bool keepScoredHits) {
size_t overflowHitCount = 0;
#define MERGE_CASE(x) \
case x: overflowHitCount = diagonalScoring ? \
cachedOperation##x->mergeElementsByDiagonal(foundDiagonals,hitCounter) : \
cachedOperation##x->mergeElementsByDiagonal(foundDiagonals,hitCounter, keepScoredHits) : \
cachedOperation##x->mergeElementsByScore(foundDiagonals,hitCounter); \
break;

Expand Down
3 changes: 1 addition & 2 deletions src/prefiltering/QueryMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,7 @@ class QueryMatcher {
size_t findDuplicates(IndexEntryLocal **hitsByIndex, CounterResult *output,
size_t outputSize, unsigned short indexFrom, unsigned short indexTo, bool computeTotalScore);


size_t mergeElements(CounterResult *foundDiagonals, size_t hitCounter);
size_t mergeElements(CounterResult *foundDiagonals, size_t hitCounter, bool keepHitsWithCounts = false);

size_t keepMaxScoreElementOnly(CounterResult *foundDiagonals, size_t resultSize);

Expand Down
22 changes: 10 additions & 12 deletions src/prefiltering/UngappedAlignment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@ UngappedAlignment::~UngappedAlignment() {
delete [] score_arr;
}

void UngappedAlignment::processQuery(Sequence *seq,
float *biasCorrection,
CounterResult *results,
size_t resultSize) {
createProfile(seq, biasCorrection, subMatrix->subMatrix);
queryLen = seq->L;
computeScores(queryProfile, seq->L, results, resultSize);
void UngappedAlignment::align(CounterResult *results, size_t resultSize) {
computeScores(queryProfile, queryLen, results, resultSize);
}


Expand Down Expand Up @@ -290,7 +285,7 @@ void UngappedAlignment::scoreDiagonalAndUpdateHits(const char * queryProfile,
// update score
for(size_t hitIdx = 0; hitIdx < hitSize; hitIdx++){
hits[seqs[hitIdx].id]->count = static_cast<unsigned char>(std::min(static_cast<unsigned int>(255),
score_arr[hitIdx]));
score_arr[hitIdx]));
if(seqs[hitIdx].seqLen == 1){
std::pair<const unsigned char *, const unsigned int> dbSeq = sequenceLookup->getSequence(hits[hitIdx]->id);
if(dbSeq.second >= 32768){
Expand Down Expand Up @@ -344,6 +339,10 @@ void UngappedAlignment::computeScores(const char *queryProfile,
// continue;
// }
const unsigned short currDiag = results[i].diagonal;
// skip results that already have a diagonal score
if(results[i].count != 0){
continue;
}
diagonalMatches[currDiag * DIAGONALBINSIZE + diagonalCounter[currDiag]] = &results[i];
diagonalCounter[currDiag]++;
if(diagonalCounter[currDiag] == DIAGONALBINSIZE) {
Expand Down Expand Up @@ -384,9 +383,8 @@ void UngappedAlignment::extractScores(unsigned int *score_arr, simd_int score) {


void UngappedAlignment::createProfile(Sequence *seq,
float * biasCorrection,
short **subMat) {

float * biasCorrection) {
queryLen = seq->L;
if(Parameters::isEqualDbtype(seq->getSequenceType(), Parameters::DBTYPE_HMM_PROFILE)) {
memset(queryProfile, 0, (Sequence::PROFILE_AA_SIZE + 1) * seq->L);
}else{
Expand All @@ -409,7 +407,7 @@ void UngappedAlignment::createProfile(Sequence *seq,
for (int pos = 0; pos < seq->L; pos++) {
unsigned int aaIdx = seq->numSequence[pos];
for (int i = 0; i < subMatrix->alphabetSize; i++) {
queryProfile[pos * (Sequence::PROFILE_AA_SIZE + 1) + i] = (subMat[aaIdx][i] + aaCorrectionScore[pos]);
queryProfile[pos * (Sequence::PROFILE_AA_SIZE + 1) + i] = (subMatrix->subMatrix[aaIdx][i] + aaCorrectionScore[pos]);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/prefiltering/UngappedAlignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ class UngappedAlignment {

~UngappedAlignment();

void createProfile(Sequence *seq, float *biasCorrection);

// This function computes the diagonal score for each CounterResult object
// it assigns the diagonal score to the CounterResult object
void processQuery(Sequence *seq, float *compositionBias, CounterResult *results,
size_t resultSize);
void align(CounterResult *results,
size_t resultSize);

int scoreSingelSequenceByCounterResult(CounterResult &result);

Expand Down Expand Up @@ -90,8 +92,6 @@ class UngappedAlignment {

void extractScores(unsigned int *score_arr, simd_int score);

void createProfile(Sequence *seq, float *biasCorrection, short **subMat);

int computeSingelSequenceScores(const char *queryProfile, const unsigned int queryLen,
std::pair<const unsigned char *, const unsigned int> &dbSeq,
int diagonal, unsigned int minDistToDiagonal);
Expand Down

0 comments on commit 950342d

Please sign in to comment.