From f143eaf9fdd95f138335e49b3a26693847af5dac Mon Sep 17 00:00:00 2001 From: Gil Date: Mon, 29 Jun 2020 08:53:56 +0100 Subject: [PATCH] Dynamic cast constraint (#5082) * dynamic cast constraint * cleanup HistogramWordStringKernel --- src/shogun/base/AnyParameter.h | 7 +- src/shogun/base/SGObject.h | 10 +- src/shogun/base/constraint.h | 110 ++++++++++++++---- .../string/HistogramWordStringKernel.cpp | 102 ++++------------ .../kernel/string/HistogramWordStringKernel.h | 6 +- 5 files changed, 118 insertions(+), 117 deletions(-) diff --git a/src/shogun/base/AnyParameter.h b/src/shogun/base/AnyParameter.h index 8fb7959b0d7..82f6677d83a 100644 --- a/src/shogun/base/AnyParameter.h +++ b/src/shogun/base/AnyParameter.h @@ -14,6 +14,7 @@ #include #include #include +#include namespace shogun { @@ -168,7 +169,7 @@ namespace shogun } AnyParameter( Any&& value, const AnyParameterProperties& properties, - std::function constrain_function) + std::function(Any)> constrain_function) : m_value(std::move(value)), m_properties(properties), m_constrain_function(std::move(constrain_function)) { @@ -218,7 +219,7 @@ namespace shogun return m_init_function; } - const std::function& get_constrain_function() const + const std::function(Any)>& get_constrain_function() const noexcept { return m_constrain_function; @@ -272,7 +273,7 @@ namespace shogun Any m_value; AnyParameterProperties m_properties; std::shared_ptr m_init_function; - std::function m_constrain_function; + std::function(Any)> m_constrain_function; std::vector> m_callback_functions; }; } // namespace shogun diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 74ad866d2f9..c5fb0b14e90 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -921,10 +921,8 @@ class SGObject: public std::enable_shared_from_this BaseTag(name), AnyParameter( make_any_ref(value), properties, [constrain_function](const auto& val) { - std::string result; auto casted_val = any_cast(val); - constrain_function.run(casted_val, result); - return result; + return constrain_function.check(casted_val); })); register_parameter_visitor(); } @@ -1131,12 +1129,12 @@ class SGObject: public std::enable_shared_from_this if (pprop.has_property(ParameterProperties::CONSTRAIN)) { - auto msg = param.get_constrain_function()(make_any(value)); - if (!msg.empty()) + const auto& val = param.get_constrain_function()(make_any(value)); + if (val) { require(!do_checks, "{}::{} cannot be updated because it must be: {}!", - get_name(), _tag.name().c_str(), msg.c_str()); + get_name(), _tag.name().c_str(), *val); } } if constexpr (std::is_same_v) diff --git a/src/shogun/base/constraint.h b/src/shogun/base/constraint.h index 5bb2ccd4435..839e5a43849 100644 --- a/src/shogun/base/constraint.h +++ b/src/shogun/base/constraint.h @@ -7,11 +7,16 @@ #ifndef __CONSTRAINT_H__ #define __CONSTRAINT_H__ +#include +#include + #include #include namespace shogun { + class SGObject; + namespace constraint_detail { template @@ -86,18 +91,80 @@ namespace shogun template struct generic_checker { - public: - generic_checker(T val) : m_val(val){}; - bool operator()(T val) const + generic_checker() = default; + bool operator()(const T& val) const { return check(val); }; virtual std::string error_msg() const = 0; + protected: + virtual bool check(const T& val) const = 0; + }; + + template + struct custom_constraint: generic_checker + { + template + custom_constraint(Functor&& func): m_func(func) + { + } + + std::string error_msg() const override + { + return msg; + } + + protected: + bool check(const T& val) const override + { + try + { + m_func(val); + } + catch (const std::exception& e) + { + msg = std::string(e.what()); + return false; + } + + return true; + } + + private: + std::string msg; + std::function m_func; + }; + + + template + struct castable: generic_checker> + { + castable(): generic_checker>() + { + } + + std::string error_msg() const override + { + return "of type " + demangled_type(); + } + + protected: + bool check(const std::shared_ptr& ptr) const override + { + return static_cast(std::dynamic_pointer_cast(ptr)); + } + }; + + template + struct comparisson_checker: generic_checker + { + comparisson_checker(T val): generic_checker() { + m_val = val; + } protected: T m_val; - virtual bool check(T val) const = 0; }; /** @@ -106,10 +173,9 @@ namespace shogun * @tparam T the type of val */ template - struct less_than : generic_checker + struct less_than : comparisson_checker { - public: - less_than(T val) : generic_checker(val){}; + less_than(T val) : comparisson_checker(val){}; std::string error_msg() const override { @@ -117,7 +183,7 @@ namespace shogun } protected: - bool check(T val) const override + bool check(const T& val) const override { return val < this->m_val; } @@ -129,10 +195,9 @@ namespace shogun * @tparam T the type of val */ template - struct less_than_or_equal : generic_checker + struct less_than_or_equal : comparisson_checker { - public: - less_than_or_equal(T val) : generic_checker(val){}; + less_than_or_equal(T val) : comparisson_checker(val){}; std::string error_msg() const override { @@ -140,7 +205,7 @@ namespace shogun } protected: - bool check(T val) const override + bool check(const T& val) const override { return val <= this->m_val; } @@ -152,17 +217,16 @@ namespace shogun * @tparam T the type of val */ template - struct greater_than : generic_checker + struct greater_than : comparisson_checker { - public: - greater_than(T val) : generic_checker(val){}; + greater_than(T val) : comparisson_checker(val){}; std::string error_msg() const override { return "greater than " + std::to_string(this->m_val); } protected: - bool check(T val) const override + bool check(const T& val) const override { return val > this->m_val; } @@ -174,10 +238,9 @@ namespace shogun * @tparam T the type of val */ template - struct greater_than_or_equal : generic_checker + struct greater_than_or_equal : comparisson_checker { - public: - greater_than_or_equal(T val) : generic_checker(val){}; + greater_than_or_equal(T val) : comparisson_checker(val){}; std::string error_msg() const override { @@ -185,7 +248,7 @@ namespace shogun } protected: - bool check(T val) const override + bool check(const T& val) const override { return val >= this->m_val; } @@ -237,14 +300,13 @@ namespace shogun } template - bool run(T val, std::string& buffer) const + std::optional check(const T& val) const { if (!constraint_detail::apply(val, m_funcs)) { - buffer = constraint_detail::get_error(m_funcs); - return false; + return constraint_detail::get_error(m_funcs); } - return true; + return std::nullopt; } private: diff --git a/src/shogun/kernel/string/HistogramWordStringKernel.cpp b/src/shogun/kernel/string/HistogramWordStringKernel.cpp index b14ac30a42c..ac010334c7c 100644 --- a/src/shogun/kernel/string/HistogramWordStringKernel.cpp +++ b/src/shogun/kernel/string/HistogramWordStringKernel.cpp @@ -21,18 +21,16 @@ HistogramWordStringKernel::HistogramWordStringKernel() init(); } -HistogramWordStringKernel::HistogramWordStringKernel(int32_t size, const std::shared_ptr& pie) +HistogramWordStringKernel::HistogramWordStringKernel(int32_t size, const std::shared_ptr& pie) : HistogramWordStringKernel() { - auto casted_pie = std::dynamic_pointer_cast(pie); - require(casted_pie, "Expected Machine to be PluginEstimate"); - estimate=std::move(casted_pie); + estimate=pie; set_cache_size(size); } HistogramWordStringKernel::HistogramWordStringKernel( const std::shared_ptr>& l, const std::shared_ptr>& r, - const std::shared_ptr& pie) + const std::shared_ptr& pie) : HistogramWordStringKernel(10, pie) { init(l, r); @@ -58,7 +56,6 @@ bool HistogramWordStringKernel::init(std::shared_ptr p_l, std::shared_ StringKernel::init(p_l,p_r); auto l=std::static_pointer_cast>(p_l); auto r=std::static_pointer_cast>(p_r); - auto plugin_estimate = std::static_pointer_cast(estimate); ASSERT(l) ASSERT(r) @@ -120,12 +117,12 @@ bool HistogramWordStringKernel::init(std::shared_ptr p_l, std::shared_ num_params=llen*((int32_t) l->get_num_symbols()); num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols()); - if ((!plugin_estimate) || (!plugin_estimate->check_models())) + if ((!estimate) || (!estimate->check_models())) { error("no estimate available"); return false ; } ; - if (num_params2!=plugin_estimate->get_num_params()) + if (num_params2!=estimate->get_num_params()) { error("number of parameters of estimate and feature representation do not match"); return false ; @@ -152,13 +149,13 @@ bool HistogramWordStringKernel::init(std::shared_ptr p_l, std::shared_ bool free_vec; uint16_t* vec=l->get_feature_vector(i, len, free_vec); - mean[0]+=plugin_estimate->posterior_log_odds_obsolete(vec, len)/num_vectors; + mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors; for (int32_t j=0; jlog_derivative_pos_obsolete(vec[j], j)/num_vectors; - mean[idx+num_params] += plugin_estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors; + mean[idx] += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors; + mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors; } l->free_feature_vector(vec, i, free_vec); @@ -171,7 +168,7 @@ bool HistogramWordStringKernel::init(std::shared_ptr p_l, std::shared_ bool free_vec; uint16_t* vec=l->get_feature_vector(i, len, free_vec); - variance[0] += Math::sq(plugin_estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors; + variance[0] += Math::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors; for (int32_t j=0; j p_l, std::shared_ } else { - variance[idx] += Math::sq(plugin_estimate->log_derivative_pos_obsolete(vec[j], j) + variance[idx] += Math::sq(estimate->log_derivative_pos_obsolete(vec[j], j) -mean[idx])/num_vectors; - variance[idx+num_params] += Math::sq(plugin_estimate->log_derivative_neg_obsolete(vec[j], j) + variance[idx+num_params] += Math::sq(estimate->log_derivative_neg_obsolete(vec[j], j) -mean[idx+num_params])/num_vectors; } } @@ -222,13 +219,13 @@ bool HistogramWordStringKernel::init(std::shared_ptr p_l, std::shared_ for (int32_t j=0; jlog_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; - result -= plugin_estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; + result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; + result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; } ld_mean_lhs[i]=result ; // precompute posterior-log-odds - plo_lhs[i] = plugin_estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; + plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; l->free_feature_vector(avec, i, free_avec); } ; @@ -247,13 +244,13 @@ bool HistogramWordStringKernel::init(std::shared_ptr p_l, std::shared_ for (int32_t j=0; jlog_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; - result -= plugin_estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; + result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; + result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; } ld_mean_rhs[i]=result ; // precompute posterior-log-odds - plo_rhs[i] = plugin_estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; + plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; r->free_feature_vector(avec, i, free_avec); } ; } ; @@ -354,7 +351,6 @@ float64_t HistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b) bool free_avec, free_bvec; uint16_t* avec=std::static_pointer_cast>(lhs)->get_feature_vector(idx_a, alen, free_avec); uint16_t* bvec=std::static_pointer_cast>(rhs)->get_feature_vector(idx_b, blen, free_bvec); - auto plugin_estimate = std::static_pointer_cast(estimate); // can only deal with strings of same length ASSERT(alen==blen) @@ -366,9 +362,9 @@ float64_t HistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b) if (avec[i]==bvec[i]) { int32_t a_idx = compute_index(i, avec[i]) ; - float64_t dd = plugin_estimate->log_derivative_pos_obsolete(avec[i], i) ; + float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ; result += dd*dd/variance[a_idx] ; - dd = plugin_estimate->log_derivative_neg_obsolete(avec[i], i) ; + dd = estimate->log_derivative_neg_obsolete(avec[i], i) ; result += dd*dd/variance[a_idx+num_params] ; } ; } @@ -377,11 +373,6 @@ float64_t HistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b) if (initialized) result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; -#ifdef DEBUG_HWSK_COMPUTATION - float64_t result2 = compute_slow(idx_a, idx_b) ; - if (fabs(result - result2)>1e-10) - error("new={:e} old = {:e} diff = {:e}", result, result2, result - result2); -#endif std::static_pointer_cast>(lhs)->free_feature_vector(avec, idx_a, free_avec); std::static_pointer_cast>(rhs)->free_feature_vector(bvec, idx_b, free_bvec); return result; @@ -434,57 +425,6 @@ void HistogramWordStringKernel::init() /*m_parameters->add_vector(&variance, &num_params2, "variance");*/ watch_param("variance", &variance, &num_params2); - SG_ADD(&estimate, "estimate", "Plugin Estimate."); - - add_callback_function("estimate", [this](){ - if (!std::dynamic_pointer_cast(estimate)) { - error("Expected Machine to be PluginEstimate"); - estimate.reset(); - } - }); + SG_ADD((std::shared_ptr*)&estimate, "estimate", "Plugin Estimate.", ParameterProperties::CONSTRAIN, + SG_CONSTRAINT(castable())); } - -#ifdef DEBUG_HWSK_COMPUTATION -float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b) -{ - int32_t alen, blen; - bool free_avec, free_bvec; - uint16_t* avec=std::static_pointer_cast>(lhs)->get_feature_vector(idx_a, alen, free_avec); - uint16_t* bvec=std::static_pointer_cast>(rhs)->get_feature_vector(idx_b, blen, free_bvec); - auto plugin_estimate = std::static_pointer_cast(estimate); - - // can only deal with strings of same length - ASSERT(alen==blen) - - float64_t result=(plugin_estimate->posterior_log_odds_obsolete(avec, alen)-mean[0])* - (plugin_estimate->posterior_log_odds_obsolete(bvec, blen)-mean[0])/(variance[0]); - result+= sum_m2_s2 ; // does not contain 0-th element - - for (int32_t i=0; ilog_derivative_pos_obsolete(avec[i], i) ; - result += dd*dd/variance[a_idx] ; - dd = plugin_estimate->log_derivative_neg_obsolete(avec[i], i) ; - result += dd*dd/variance[a_idx+num_params] ; - } ; - - result -= plugin_estimate->log_derivative_pos_obsolete(avec[i], i)*mean[a_idx]/variance[a_idx] ; - result -= plugin_estimate->log_derivative_pos_obsolete(bvec[i], i)*mean[b_idx]/variance[b_idx] ; - result -= plugin_estimate->log_derivative_neg_obsolete(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ; - result -= plugin_estimate->log_derivative_neg_obsolete(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ; - } - - if (initialized) - result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; - - std::static_pointer_cast>(lhs)->free_feature_vector(avec, idx_a, free_avec); - std::static_pointer_cast>(rhs)->free_feature_vector(bvec, idx_b, free_bvec); - return result; -} - -#endif diff --git a/src/shogun/kernel/string/HistogramWordStringKernel.h b/src/shogun/kernel/string/HistogramWordStringKernel.h index 404dc44086c..faf49502adf 100644 --- a/src/shogun/kernel/string/HistogramWordStringKernel.h +++ b/src/shogun/kernel/string/HistogramWordStringKernel.h @@ -31,7 +31,7 @@ class HistogramWordStringKernel: public StringKernel * @param size cache size * @param pie plugin estimate */ - HistogramWordStringKernel(int32_t size, const std::shared_ptr& pie); + HistogramWordStringKernel(int32_t size, const std::shared_ptr& pie); /** constructor * @@ -41,7 +41,7 @@ class HistogramWordStringKernel: public StringKernel */ HistogramWordStringKernel( const std::shared_ptr>& l, const std::shared_ptr>& r, - const std::shared_ptr& pie); + const std::shared_ptr& pie); virtual ~HistogramWordStringKernel(); @@ -95,7 +95,7 @@ class HistogramWordStringKernel: public StringKernel protected: /** plugin estimate */ - std::shared_ptr estimate; + std::shared_ptr estimate; /** mean */ float64_t* mean;