Skip to content

Commit

Permalink
simplification
Browse files (browse the repository at this point in the history)
  • Loading branch information
hmusta committed Jun 11, 2024
1 parent 1296ada commit 443c059
Showing 1 changed file with 25 additions and 74 deletions.
99 changes: 25 additions & 74 deletions metagraph/src/graph/annotated_graph_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <boost/math/tools/roots.hpp>
#include <boost/math/special_functions/digamma.hpp>
#include <boost/math/special_functions/trigamma.hpp>
#include <boost/math/special_functions/polygamma.hpp>

#include "common/logger.hpp"
#include "common/vectors/bitmap.hpp"
Expand Down Expand Up @@ -238,12 +237,8 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
for (size_t j = 0; j < groups.size(); ++j) {
const auto &col = *columns_all[j];
const auto &col_vals = *column_values_all[j];
if (uint64_t r = col.conditional_rank1(row_i)) {
uint64_t c = col_vals[r - 1];
unitig_sums[j] += c;
std::lock_guard<std::mutex> lock(agg_mu);
++unitig_hists[j][c];
}
if (uint64_t r = col.conditional_rank1(row_i))
unitig_sums[j] += col_vals[r - 1];
}
}

Expand All @@ -253,6 +248,9 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,

std::lock_guard<std::mutex> lock(agg_mu);
++num_unitigs;
for (size_t j = 0; j < groups.size(); ++j) {
++unitig_hists[j][unitig_sums[j]];
}
}, num_threads);
}

Expand Down Expand Up @@ -435,8 +433,7 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
double var_orig = var;
var = std::max(var, mu + 0.1);
double r_guess = mu * mu / (var - mu);
common::logger->trace("{}: initial guess:\tmu: {}\tvar: {}\tr: {}\tp: {}\t",
j, mu, var_orig, r_guess, mu / var);
double p_guess = mu / var;
double r = r_guess;
try {
r = boost::math::tools::newton_raphson_iterate([&](double r) {
Expand All @@ -454,17 +451,12 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
r = r_guess;
}

return std::make_pair(r, r / (r + mu));
};

// geometric mean
double target_sum = 0.0;
for (size_t j = 0; j < groups.size(); ++j) {
target_sum += log(static_cast<double>(sums[j]));
}
target_sum = exp(target_sum / groups.size());
double p = r / (r + mu);
common::logger->trace("{}: mu: {}\tvar: {}\tmoment est:\tr: {}\tp: {}\tmle: r: {}\tp: {}",
j, mu, var_orig, r_guess, p_guess, r, p);

double min_p = 1.0;
return std::make_pair(r, p);
};

std::vector<std::pair<double, double>> nb_params(groups.size());

Expand All @@ -476,65 +468,24 @@ mask_nodes_by_label_dual(std::shared_ptr<const DeBruijnGraph> graph_ptr,
double mu2 = mu * mu;
double var = static_cast<double>(config.test_by_unitig ? sums_of_squares_unitigs[j] : sums_of_squares[j]) / nelem - mu2;
nb_params[j] = get_rp(j, mu, var, hist);
min_p = std::min(min_p, nb_params[j].second);
}

// geometric mean
double target_sum = 0.0;
double p_guess = 0.0;
for (size_t j = 0; j < groups.size(); ++j) {
target_sum += log(static_cast<double>(sums[j]));
p_guess += log(nb_params[j].second);
}
target_sum = exp(target_sum / groups.size());
p_guess = exp(p_guess / groups.size());

common::logger->trace("Fitting common p");
std::vector<double> r_maps(groups.size(), 1.0);
double target_p = min_p;
try {
target_p = boost::math::tools::newton_raphson_iterate([&](double p) {
double dl = 0;
double ddl = 0;
for (size_t j = 0; j < groups.size(); ++j) {
double f = target_sum / sums[j];
double &r = r_maps[j];
r = boost::math::tools::newton_raphson_iterate([&](double r) {
double dl = nelem * (log(p) - boost::math::digamma(r));
double ddl = -nelem * boost::math::trigamma(r);
for (const auto &[k, c] : hists[j]) {
dl += boost::math::digamma(k * f + r) * c;
ddl += boost::math::trigamma(k * f + r) * c;
}
return std::make_pair(dl, ddl);
}, r, std::numeric_limits<double>::min(), f * sums[j], 30);
double factor = nelem * r / p;
dl += factor;
ddl -= factor / p;
const auto &hist = config.test_by_unitig ? unitig_hists[j] : hists[j];
for (const auto &[k, c] : hist) {
double factor = k * f * c / (1 - p);
dl -= factor;
ddl -= factor / (1 - p);
}
}

return std::make_pair(dl, ddl);
}, min_p, std::numeric_limits<double>::min(), 1.0, 30);
} catch (std::exception &e) {
common::logger->warn("Caught exception while fitting p: Falling back to initial guess");
common::logger->warn("{}", e.what());
target_p = min_p;
for (size_t j = 0; j < groups.size(); ++j) {
double f = target_sum / sums[j];
double &r = r_maps[j];
try {
r = boost::math::tools::newton_raphson_iterate([&](double r) {
double dl = nelem * (log(target_p) - boost::math::digamma(r));
double ddl = -nelem * boost::math::trigamma(r);
for (const auto &[k, c] : hists[j]) {
dl += boost::math::digamma(k * f + r) * c;
ddl += boost::math::trigamma(k * f + r) * c;
}
return std::make_pair(dl, ddl);
}, r, std::numeric_limits<double>::min(), f * sums[j], 30);
} catch (std::exception &e) {
common::logger->warn("Caught exception while fitting r for {}: Falling back to initial guess", j);
common::logger->warn("{}", e.what());
double mu = static_cast<double>(sums[j]) * f / nelem;
r = target_p * mu / (1 - target_p);
}
}
double target_p = p_guess;
double r_map_base = target_p / nelem / (1 - target_p);
for (size_t j = 0; j < groups.size(); ++j) {
r_maps[j] = r_map_base * sums[j];
}

std::vector<VectorMap<uint64_t, std::pair<size_t, uint64_t>>> count_maps(groups.size());
Expand Down

0 comments on commit 443c059

Please sign in to comment.