diff --git a/.ci/ci.yml b/.ci/ci.yml index f8b4707ac31..17f5ab949af 100644 --- a/.ci/ci.yml +++ b/.ci/ci.yml @@ -62,6 +62,7 @@ jobs: - job: swig displayName: Linux SWIG Interface dependsOn: libshogun + timeoutInMinutes: 90 pool: vmImage: ubuntu-16.04 diff --git a/src/shogun/base/AnyParameter.h b/src/shogun/base/AnyParameter.h index 5ada21ce6b8..8fb7959b0d7 100644 --- a/src/shogun/base/AnyParameter.h +++ b/src/shogun/base/AnyParameter.h @@ -39,12 +39,17 @@ namespace shogun // has automatically computed value AUTO = 1u << 10, + // is const READONLY = 1u << 11, - // an executable function - RUNFUNCTION = 1u << 12, + // a class member function with side effects + FUNCTION = 1u << 12, - CONSTRAIN = 1u << 13, + // a const class member function + CONSTFUNCTION = 1u << 13, + + // parameter is constrained, e.g. has to be positive + CONSTRAIN = 1u << 14, ALL = std::numeric_limits::max(), }; @@ -58,7 +63,8 @@ namespace shogun {ParameterProperties::SETTING, "SETTING"}, {ParameterProperties::AUTO, "AUTO"}, {ParameterProperties::READONLY, "READONLY"}, - {ParameterProperties::RUNFUNCTION, "RUNFUNCTION"}, + {ParameterProperties::FUNCTION, "FUNCTION"}, + {ParameterProperties::CONSTFUNCTION, "CONSTFUNCTION"}, {ParameterProperties::CONSTRAIN, "CONSTRAIN"}}; enableEnumClassBitmask(ParameterProperties); @@ -182,7 +188,12 @@ namespace shogun { } - Any get_value() const + const Any& get_value() const + { + return m_value; + } + + Any& get_value() { return m_value; } diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index 2bcfa2f9c48..96accdf7a1c 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -359,69 +359,69 @@ void SGObject::create_parameter( self->create(std::forward(_tag), std::forward(parameter)); } -void SGObject::update_parameter(const BaseTag& _tag, const Any& value, bool do_checks) +const AnyParameter& SGObject::get_parameter(const BaseTag& _tag) const { - auto& param = self->at(_tag); - auto& pprop = param.get_properties(); - if (pprop.has_property(ParameterProperties::READONLY)) - require(!do_checks, - "{}::{} is marked as read-only and cannot be modified!", - get_name(), _tag.name().c_str()); - - if (pprop.has_property(ParameterProperties::CONSTRAIN)) + if (!has_parameter(_tag)) { - auto msg = self->find(_tag)->second.get_constrain_function()(value); - if (!msg.empty()) - { - require(!do_checks, - "{}::{} cannot be updated because it must be: {}!", - get_name(), _tag.name().c_str(), msg.c_str()); - } + error( + "Parameter {}::{} does not exist.", get_name(), + _tag.name().c_str()); } - param.set_value(value); - for (auto& method : param.get_callbacks()) - method(); - - pprop.remove_property(ParameterProperties::AUTO); -} - -AnyParameter SGObject::get_parameter(const BaseTag& _tag) const -{ - const auto& parameter = self->get(_tag); + + const auto& parameter = self->at(_tag); + if (parameter.get_properties().has_property( - ParameterProperties::RUNFUNCTION)) + ParameterProperties::FUNCTION)) { error( "The parameter {}::{} is registered as a function, " "use the .run() method instead!\n", get_name(), _tag.name().c_str()); } - if (parameter.get_value().empty()) + return parameter; +} + +AnyParameter& SGObject::get_parameter(const BaseTag& _tag) +{ + if (!has_parameter(_tag)) + { + error( + "Parameter {}::{} does not exist.", get_name(), + _tag.name().c_str()); + } + + auto& parameter = self->at(_tag); + + if (parameter.get_properties().has_property( + ParameterProperties::FUNCTION)) { error( - "There is no parameter called \"{}\" in {}", _tag.name().c_str(), - get_name()); + "The parameter {}::{} is registered as a function, " + "use the .run() method instead!\n", + get_name(), _tag.name().c_str()); } return parameter; } -AnyParameter SGObject::get_function(const BaseTag& _tag) const +const AnyParameter& SGObject::get_function(const BaseTag& _tag) const { - const auto& parameter = self->get(_tag); + if (!has_parameter(_tag)) + { + error( + "Parameter {}::{} does not exist.", get_name(), + _tag.name().c_str()); + } + + const auto& parameter = self->at(_tag); + if (!parameter.get_properties().has_property( - ParameterProperties::RUNFUNCTION)) + ParameterProperties::FUNCTION)) { error( "The parameter {}::{} is not registered as a function, " "use the .get() method instead", get_name(), _tag.name().c_str()); } - if (parameter.get_value().empty()) - { - error( - "There is no parameter called \"{}\" in {}", _tag.name().c_str(), - get_name()); - } return parameter; } diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 32b40334838..6b9969561f5 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -130,6 +130,18 @@ SG_FORCED_INLINE const char* convert_string_to_char(const char* name) */ class SGObject: public std::enable_shared_from_this { + template + struct ParameterGetterInterface + { + ReturnType& m_value; + }; + + template + struct ParameterPutInterface + { + const ValueType& m_value; + }; + public: /** Definition of observed subject */ typedef rxcpp::subjects::subject> SGSubject; @@ -298,37 +310,18 @@ class SGObject: public std::enable_shared_from_this */ template ::value>* = nullptr> - void put(const Tag& _tag, const T& value) noexcept(false) + void put(const Tag& _tag, const T& value) { - if (has_parameter(_tag)) - { - auto parameter_value = get_parameter(_tag).get_value(); - if (!parameter_value.cloneable()) - { - error( - "Cannot put parameter {}::{}.", get_name(), - _tag.name().c_str()); - } - try - { - any_cast(parameter_value); - } - catch (const TypeMismatchException& exc) - { - error( - "Cannot put parameter {}::{} of type {}, incompatible " - "provided type {}.", - get_name(), _tag.name().c_str(), exc.actual().c_str(), - exc.expected().c_str()); - } - update_parameter(_tag, make_any(value)); - } - else + const auto& parameter_value = get_parameter(_tag).get_value(); + + if (!parameter_value.cloneable()) { error( - "Parameter {}::{} does not exist.", get_name(), + "Cannot put parameter {}::{}.", get_name(), _tag.name().c_str()); } + + update_parameter(_tag, value); } /** Setter for a class parameter that has values of type string, @@ -340,7 +333,7 @@ class SGObject: public std::enable_shared_from_this */ template ::value>* = nullptr> - void put(const Tag& _tag, const T& value) noexcept(false) + void put(const Tag& _tag, const T& value) { std::string val_string(value); @@ -351,7 +344,7 @@ class SGObject: public std::enable_shared_from_this _tag.name().c_str()); } - auto string_to_enum = m_string_to_enum_map[_tag.name()]; + auto string_to_enum = m_string_to_enum_map.at(_tag.name()); if (string_to_enum.find(val_string) == string_to_enum.end()) { @@ -404,7 +397,7 @@ class SGObject: public std::enable_shared_from_this Tag>> tag_vector(name); auto dispatched = get(tag_vector); dispatched.push_back(value); - update_parameter(BaseTag(name), make_any(dispatched), false); + update_parameter(BaseTag(name), dispatched, false); } #ifndef SWIG @@ -460,9 +453,9 @@ class SGObject: public std::enable_shared_from_this * @return object parameter */ #ifdef SWIG - std::shared_ptr get(const std::string& name) const noexcept(false); + std::shared_ptr get(const std::string& name) const; #else - std::shared_ptr get(std::string_view name) const noexcept(false); + std::shared_ptr get(std::string_view name) const; #endif #ifndef SWIG @@ -519,68 +512,43 @@ class SGObject: public std::enable_shared_from_this * @param _tag name and type information of parameter * @return value of the parameter identified by the input tag */ - template ::value && !is_sg_base::value>* = nullptr> - T get(const Tag& _tag) const noexcept(false) + template + auto get(const Tag& _tag) const { - const Any value = get_parameter(_tag).get_value(); - try - { - return any_cast(value); - } - catch (const TypeMismatchException& exc) - { - error( - "Cannot get parameter {}::{} of type {}, incompatible " - "requested type {}.", - get_name(), _tag.name().c_str(), exc.actual().c_str(), - exc.expected().c_str()); - } - // we won't be there - return any_cast(value); - } + using ReturnType = std::conditional_t::value, std::shared_ptr, T>; + ReturnType result; + + const auto& param = get_parameter(_tag); - template ::value>* = nullptr> - std::shared_ptr get(const Tag& _tag) const noexcept(false) - { - const Any value = get_parameter(_tag).get_value(); - try - { - return any_cast>(value); - } - catch (const TypeMismatchException& exc) + if constexpr (is_string::value) { - error( - "Cannot get parameter {}::{} of type {}, incompatible " - "requested type {} or there are no options for parameter " - "{}::{}.", - get_name(), _tag.name().c_str(), exc.actual().c_str(), - exc.expected().c_str(), get_name(), _tag.name().c_str()); + if (m_string_to_enum_map.count(_tag.name())) + return std::string(string_enum_reverse_lookup(_tag.name(), get(_tag.name()))); } - // we won't be there - return nullptr; - } - - template ::value>* = nullptr> - T get(const Tag& _tag) const noexcept(false) - { - if (m_string_to_enum_map.find(_tag.name()) == m_string_to_enum_map.end()) + + const auto& value = param.get_value(); + try { - const Any value = get_parameter(_tag).get_value(); - try + if (param.get_properties().has_property(ParameterProperties::CONSTFUNCTION)) { - return any_cast(value); + ParameterGetterInterface> visitor{result}; + value.visit_with(&visitor); } - catch (const TypeMismatchException& exc) + else { - error( - "Cannot get parameter {}::{} of type {}, incompatible " - "requested type {} or there are no options for parameter " - "{}::{}.", - get_name(), _tag.name().c_str(), exc.actual().c_str(), - exc.expected().c_str(), get_name(), _tag.name().c_str()); + ParameterGetterInterface visitor{result}; + value.visit_with(&visitor); } } - return std::string(string_enum_reverse_lookup(_tag.name(), get(_tag.name()))); + catch (...) + { + error( + "Cannot get parameter {}::{} of type {}, incompatible with " + "requested type {}.", + get_name(), _tag.name().c_str(), value.type().c_str(), + demangled_type().c_str()); + } + return result; } #endif @@ -592,9 +560,9 @@ class SGObject: public std::enable_shared_from_this */ template ::value>> #ifdef SWIG - T get(const std::string& name) const noexcept(false) + T get(const std::string& name) const #else - T get(std::string_view name) const noexcept(false) + T get(std::string_view name) const #endif { Tag tag(name); @@ -607,13 +575,13 @@ class SGObject: public std::enable_shared_from_this * @return value of the parameter corresponding to the input name and type */ #ifdef SWIG - void run(const std::string& name) const noexcept(false) + void run(const std::string& name) const #else - void run(std::string_view name) const noexcept(false) + void run(std::string_view name) const #endif { - Tag tag(name); - auto param = get_function(tag); + const auto tag = Tag(name); + const auto& param = get_function(tag); if (!any_cast(param.get_value())) { error("Failed to run function {}::{}", get_name(), name.data()); @@ -622,7 +590,7 @@ class SGObject: public std::enable_shared_from_this #ifndef SWIG template ::value>* = nullptr> - std::shared_ptr get(std::string_view name) const noexcept(false) + std::shared_ptr get(std::string_view name) const { Tag tag(name); return get(tag); @@ -734,7 +702,7 @@ class SGObject: public std::enable_shared_from_this * * @exception ShogunException will be thrown if an error occurs. */ - virtual void load_serializable_pre() noexcept(false); + virtual void load_serializable_pre(); /** Can (optionally) be overridden to post-initialize some member * variables which are not PARAMETER::ADD'ed. Make sure that at @@ -743,7 +711,7 @@ class SGObject: public std::enable_shared_from_this * * @exception ShogunException will be thrown if an error occurs. */ - virtual void load_serializable_post() noexcept(false); + virtual void load_serializable_post(); /** Can (optionally) be overridden to pre-initialize some member * variables which are not PARAMETER::ADD'ed. Make sure that at @@ -752,7 +720,7 @@ class SGObject: public std::enable_shared_from_this * * @exception ShogunException will be thrown if an error occurs. */ - virtual void save_serializable_pre() noexcept(false); + virtual void save_serializable_pre(); /** Can (optionally) be overridden to post-initialize some member * variables which are not PARAMETER::ADD'ed. Make sure that at @@ -761,7 +729,7 @@ class SGObject: public std::enable_shared_from_this * * @exception ShogunException will be thrown if an error occurs. */ - virtual void save_serializable_post() noexcept(false); + virtual void save_serializable_post(); inline bool get_load_serializable_pre() const { @@ -784,6 +752,38 @@ class SGObject: public std::enable_shared_from_this } protected: + template + void register_parameter_visitor() const + { + Any::register_visitor>( + [](T* value, auto* visitor) + { + *value = visitor->m_value; + } + ); + if constexpr (traits::is_functional::value) + { + if constexpr (!traits::returns_void::value) + { + using ReturnType = typename T::result_type; + Any::register_visitor>( + [](T* value, auto* visitor) + { + visitor->m_value = value->operator()(); + } + ); + } + } + else + { + Any::register_visitor>( + [](T* value, auto* visitor) + { + visitor->m_value = *value; + } + ); + } + } /** Registers a class parameter which is identified by a tag. * This enables the parameter to be modified by put() and retrieved by * get(). @@ -824,6 +824,7 @@ class SGObject: public std::enable_shared_from_this AnyParameterProperties properties = AnyParameterProperties()) { create_parameter(BaseTag(name), AnyParameter(make_any_ref(value), properties)); + register_parameter_visitor(); } /** Puts a pointer to some parameter into the parameter map. @@ -848,6 +849,7 @@ class SGObject: public std::enable_shared_from_this BaseTag(name), AnyParameter( make_any_ref(value), properties, std::move(auto_init))); + register_parameter_visitor(); } #ifndef SWIG @@ -880,6 +882,7 @@ class SGObject: public std::enable_shared_from_this constrain_function.run(casted_val, result); return result; })); + register_parameter_visitor(); } #endif @@ -898,6 +901,7 @@ class SGObject: public std::enable_shared_from_this { create_parameter( BaseTag(name), AnyParameter(make_any_ref(value, len), properties)); + register_parameter_visitor(); } /** Puts a pointer to some 2d parameter array (i.e. a matrix) into the @@ -917,6 +921,7 @@ class SGObject: public std::enable_shared_from_this { create_parameter( BaseTag(name), AnyParameter(make_any_ref(value, rows, cols), properties)); + register_parameter_visitor(); } #ifndef SWIG @@ -930,10 +935,12 @@ class SGObject: public std::enable_shared_from_this { AnyParameterProperties properties( "Dynamic parameter", - ParameterProperties::READONLY); - std::function bind_method = - std::bind(method, dynamic_cast(this)); + ParameterProperties::READONLY | ParameterProperties::CONSTFUNCTION); + std::function bind_method = [this, method](){ + return (static_cast(this) ->* method)(); + }; create_parameter(BaseTag(name), AnyParameter(make_any(bind_method), properties)); + register_parameter_visitor>(); } /** Puts a pointer to a (lazily evaluated) function into the parameter map. @@ -948,10 +955,12 @@ class SGObject: public std::enable_shared_from_this { AnyParameterProperties properties( "Non-const function", - ParameterProperties::RUNFUNCTION | ParameterProperties::READONLY); - std::function bind_method = - std::bind(method, dynamic_cast(this)); + ParameterProperties::READONLY | ParameterProperties::FUNCTION); + std::function bind_method = [this, method](){ + return (static_cast(this) ->* method)(); + }; create_parameter(BaseTag(name), AnyParameter(make_any(bind_method), properties)); + register_parameter_visitor>(); } /** Adds a callback function to a parameter identified by its name @@ -1065,7 +1074,62 @@ class SGObject: public std::enable_shared_from_this * @param _tag name information of parameter * @param value new value of parameter */ - void update_parameter(const BaseTag& _tag, const Any& value, bool do_checks = true); + template + void update_parameter(const BaseTag& _tag, const T& value, bool do_checks = true) + { + auto& param = get_parameter(_tag); + auto& pprop = param.get_properties(); + auto& parameter_value = param.get_value(); + if (pprop.has_property(ParameterProperties::READONLY)) + require(!do_checks, + "{}::{} is marked as read-only and cannot be modified!", + get_name(), _tag.name().c_str()); + + if (pprop.has_property(ParameterProperties::CONSTRAIN)) + { + auto msg = param.get_constrain_function()(make_any(value)); + if (!msg.empty()) + { + require(!do_checks, + "{}::{} cannot be updated because it must be: {}!", + get_name(), _tag.name().c_str(), msg.c_str()); + } + } + if constexpr (std::is_same_v) + { + param.set_value(value); + } + else + { + ParameterPutInterface visitor{value}; + + try + { + parameter_value.visit_with(&visitor); + } + catch (...) + { + error( + "Cannot put parameter {}::{} of type {}, incompatible with " + "provided type {}.", + get_name(), _tag.name().c_str(), parameter_value.type().c_str(), + demangled_type().c_str()); + } + } + + for (auto& method : param.get_callbacks()) + method(); + + pprop.remove_property(ParameterProperties::AUTO); + } + + /** Getter for a class parameter, identified by a BaseTag. + * Throws an exception if the class does not have such a parameter. + * + * @param _tag name information of parameter + * @return value of the parameter identified by the input tag + */ + const AnyParameter& get_parameter(const BaseTag& _tag) const; /** Getter for a class parameter, identified by a BaseTag. * Throws an exception if the class does not have such a parameter. @@ -1073,7 +1137,7 @@ class SGObject: public std::enable_shared_from_this * @param _tag name information of parameter * @return value of the parameter identified by the input tag */ - AnyParameter get_parameter(const BaseTag& _tag) const; + AnyParameter& get_parameter(const BaseTag& _tag); /** Getter for a class function, identified by a BaseTag. * Throws an exception if the class does not have such a parameter. @@ -1081,7 +1145,7 @@ class SGObject: public std::enable_shared_from_this * @param _tag name information of parameter * @return value of the parameter identified by the input tag */ - AnyParameter get_function(const BaseTag& _tag) const; + const AnyParameter& get_function(const BaseTag& _tag) const; class Self; std::unique_ptr self; @@ -1158,11 +1222,15 @@ class SGObject: public std::enable_shared_from_this template void observe(const int64_t step, std::string_view name) const { - auto param = this->get_parameter(BaseTag(name)); - auto cloned = any_cast(param.get_value()); + const auto tag = Tag(name); + auto& param = this->get_parameter(tag); + auto pprop = param.get_properties(); + auto cloned = get(tag); + pprop.remove_property(ParameterProperties::CONSTFUNCTION); + pprop.remove_property(ParameterProperties::FUNCTION); this->observe( step, name, static_cast(clone_utils::clone(cloned)), - param.get_properties()); + pprop); } /** diff --git a/src/shogun/base/class_list.cpp.py b/src/shogun/base/class_list.cpp.py index ebce1926584..290b88b0ad4 100644 --- a/src/shogun/base/class_list.cpp.py +++ b/src/shogun/base/class_list.cpp.py @@ -62,7 +62,8 @@ "NullFileSystem", "FilterVisitor", "RandomMixin", "MaxCrossValidation", "StreamingDataFetcher", "MaxMeasure", "MaxTestPower", "MedianHeuristic", "WeightedMaxMeasure", "WeightedMaxTestPower", - "Seedable", "ShogunEnv", "ShapeVisitor", "FeatureImportanceTree"] + "Seedable", "ShogunEnv", "ShapeVisitor", "FeatureImportanceTree", + "ParameterPutInterface", "ParameterGetterInterface"] SHOGUN_TEMPLATE_CLASS = "SHOGUN_TEMPLATE_CLASS" SHOGUN_BASIC_CLASS = "SHOGUN_BASIC_CLASS" diff --git a/src/shogun/lib/any.h b/src/shogun/lib/any.h index cc702e8d9c9..2060a4c958f 100644 --- a/src/shogun/lib/any.h +++ b/src/shogun/lib/any.h @@ -1205,19 +1205,15 @@ namespace shogun void visit(AnyVisitor* visitor) const; template - void visit_with(State* state = nullptr) + void visit_with(State* state = nullptr) const { const auto value_type = std::type_index{policy->type_info()}; const auto state_type = std::type_index{typeid(State)}; - if (policy->is_functional()) - { - throw std::logic_error{ - "Visit is not supported for functional Any"}; - } const auto key = std::make_pair(state_type, value_type); - if (Any::visitor_registry.count(key)) + auto visitor = Any::visitor_registry.find(key); + if (visitor != Any::visitor_registry.end()) { - visitor_registry[key](storage, state); + visitor->second(storage, state); } else { @@ -1229,7 +1225,7 @@ namespace shogun static void register_caster(std::function caster); template - static void register_visitor(std::function visitor); + static void register_visitor(std::function visitor); private: void set_or_inherit(const Any& other); @@ -1275,7 +1271,7 @@ namespace shogun } template - void Any::register_visitor(std::function visitor) + void Any::register_visitor(std::function visitor) { const auto key = std::make_pair( std::type_index{typeid(State)}, std::type_index{typeid(Type)}); @@ -1288,7 +1284,7 @@ namespace shogun visitor_registry[key] = [visitor](void* value, void* state) { auto typed_state = static_cast(state); auto typed_value = static_cast(value); - visitor(*typed_value, typed_state); + visitor(typed_value, typed_state); }; } diff --git a/tests/unit/lib/Any_unittest.cc b/tests/unit/lib/Any_unittest.cc index f5cb5e2300d..9c171d1f00e 100644 --- a/tests/unit/lib/Any_unittest.cc +++ b/tests/unit/lib/Any_unittest.cc @@ -850,7 +850,7 @@ TEST(Any, simple_visit) struct ExtractVisitor { }; - Any::register_visitor([&] (auto value, auto visitor) { extracted_value = value; }); + Any::register_visitor([&] (auto* value, auto* visitor) { extracted_value = *value; }); auto any = make_any(42); any.visit_with(); EXPECT_EQ(any.as(), extracted_value); @@ -863,8 +863,8 @@ TEST(Any, stateful_visit) { std::stringstream ss; }; - Any::register_visitor([&] (auto value, auto visitor) { - (visitor->ss) << value; + Any::register_visitor([&] (auto* value, auto* visitor) { + (visitor->ss) << *value; }); auto any = make_any(initial_value); StringStreamVisitor visitor;