From 443c05956678c9a00277c3042ff9ba8de42b4c00 Mon Sep 17 00:00:00 2001 From: Harun Mustafa Date: Tue, 11 Jun 2024 16:36:46 +0200 Subject: [PATCH] simplification --- .../src/graph/annotated_graph_algorithm.cpp | 99 +++++-------------- 1 file changed, 25 insertions(+), 74 deletions(-) diff --git a/metagraph/src/graph/annotated_graph_algorithm.cpp b/metagraph/src/graph/annotated_graph_algorithm.cpp index 83678ea092..270c9bd1a5 100644 --- a/metagraph/src/graph/annotated_graph_algorithm.cpp +++ b/metagraph/src/graph/annotated_graph_algorithm.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include "common/logger.hpp" #include "common/vectors/bitmap.hpp" @@ -238,12 +237,8 @@ mask_nodes_by_label_dual(std::shared_ptr graph_ptr, for (size_t j = 0; j < groups.size(); ++j) { const auto &col = *columns_all[j]; const auto &col_vals = *column_values_all[j]; - if (uint64_t r = col.conditional_rank1(row_i)) { - uint64_t c = col_vals[r - 1]; - unitig_sums[j] += c; - std::lock_guard lock(agg_mu); - ++unitig_hists[j][c]; - } + if (uint64_t r = col.conditional_rank1(row_i)) + unitig_sums[j] += col_vals[r - 1]; } } @@ -253,6 +248,9 @@ mask_nodes_by_label_dual(std::shared_ptr graph_ptr, std::lock_guard lock(agg_mu); ++num_unitigs; + for (size_t j = 0; j < groups.size(); ++j) { + ++unitig_hists[j][unitig_sums[j]]; + } }, num_threads); } @@ -435,8 +433,7 @@ mask_nodes_by_label_dual(std::shared_ptr graph_ptr, double var_orig = var; var = std::max(var, mu + 0.1); double r_guess = mu * mu / (var - mu); - common::logger->trace("{}: initial guess:\tmu: {}\tvar: {}\tr: {}\tp: {}\t", - j, mu, var_orig, r_guess, mu / var); + double p_guess = mu / var; double r = r_guess; try { r = boost::math::tools::newton_raphson_iterate([&](double r) { @@ -454,17 +451,12 @@ mask_nodes_by_label_dual(std::shared_ptr graph_ptr, r = r_guess; } - return std::make_pair(r, r / (r + mu)); - }; - - // geometric mean - double target_sum = 0.0; - for (size_t j = 0; j < groups.size(); ++j) { - target_sum += log(static_cast(sums[j])); - } - target_sum = exp(target_sum / groups.size()); + double p = r / (r + mu); + common::logger->trace("{}: mu: {}\tvar: {}\tmoment est:\tr: {}\tp: {}\tmle: r: {}\tp: {}", + j, mu, var_orig, r_guess, p_guess, r, p); - double min_p = 1.0; + return std::make_pair(r, p); + }; std::vector> nb_params(groups.size()); @@ -476,65 +468,24 @@ mask_nodes_by_label_dual(std::shared_ptr graph_ptr, double mu2 = mu * mu; double var = static_cast(config.test_by_unitig ? sums_of_squares_unitigs[j] : sums_of_squares[j]) / nelem - mu2; nb_params[j] = get_rp(j, mu, var, hist); - min_p = std::min(min_p, nb_params[j].second); } + // geometric mean + double target_sum = 0.0; + double p_guess = 0.0; + for (size_t j = 0; j < groups.size(); ++j) { + target_sum += log(static_cast(sums[j])); + p_guess += log(nb_params[j].second); + } + target_sum = exp(target_sum / groups.size()); + p_guess = exp(p_guess / groups.size()); + common::logger->trace("Fitting common p"); std::vector r_maps(groups.size(), 1.0); - double target_p = min_p; - try { - target_p = boost::math::tools::newton_raphson_iterate([&](double p) { - double dl = 0; - double ddl = 0; - for (size_t j = 0; j < groups.size(); ++j) { - double f = target_sum / sums[j]; - double &r = r_maps[j]; - r = boost::math::tools::newton_raphson_iterate([&](double r) { - double dl = nelem * (log(p) - boost::math::digamma(r)); - double ddl = -nelem * boost::math::trigamma(r); - for (const auto &[k, c] : hists[j]) { - dl += boost::math::digamma(k * f + r) * c; - ddl += boost::math::trigamma(k * f + r) * c; - } - return std::make_pair(dl, ddl); - }, r, std::numeric_limits::min(), f * sums[j], 30); - double factor = nelem * r / p; - dl += factor; - ddl -= factor / p; - const auto &hist = config.test_by_unitig ? unitig_hists[j] : hists[j]; - for (const auto &[k, c] : hist) { - double factor = k * f * c / (1 - p); - dl -= factor; - ddl -= factor / (1 - p); - } - } - - return std::make_pair(dl, ddl); - }, min_p, std::numeric_limits::min(), 1.0, 30); - } catch (std::exception &e) { - common::logger->warn("Caught exception while fitting p: Falling back to initial guess"); - common::logger->warn("{}", e.what()); - target_p = min_p; - for (size_t j = 0; j < groups.size(); ++j) { - double f = target_sum / sums[j]; - double &r = r_maps[j]; - try { - r = boost::math::tools::newton_raphson_iterate([&](double r) { - double dl = nelem * (log(target_p) - boost::math::digamma(r)); - double ddl = -nelem * boost::math::trigamma(r); - for (const auto &[k, c] : hists[j]) { - dl += boost::math::digamma(k * f + r) * c; - ddl += boost::math::trigamma(k * f + r) * c; - } - return std::make_pair(dl, ddl); - }, r, std::numeric_limits::min(), f * sums[j], 30); - } catch (std::exception &e) { - common::logger->warn("Caught exception while fitting r for {}: Falling back to initial guess", j); - common::logger->warn("{}", e.what()); - double mu = static_cast(sums[j]) * f / nelem; - r = target_p * mu / (1 - target_p); - } - } + double target_p = p_guess; + double r_map_base = target_p / nelem / (1 - target_p); + for (size_t j = 0; j < groups.size(); ++j) { + r_maps[j] = r_map_base * sums[j]; } std::vector>> count_maps(groups.size());