Skip to content

Commit 145f093

Browse files
committed
add covert warning for multiclasslabel encoder
1 parent aea04e7 commit 145f093

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

src/shogun/labels/BinaryLabelEncoder.h

+2-14
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace shogun
3333
{
3434
const auto result_vector = labs->as<DenseLabels>()->get_labels();
3535
check_is_valid(result_vector);
36-
if (!can_convert_float_to_int(result_vector))
36+
if (print_warning && !can_convert_float_to_int(result_vector))
3737
{
3838
io::warn(
3939
"({}, {}) have been converted to (-1, 1).",
@@ -122,19 +122,7 @@ namespace shogun
122122
fmt::join(unique_set, ", "));
123123
}
124124

125-
bool can_convert_float_to_int(const SGVector<float64_t>& vec) const
126-
{
127-
SGVector<int32_t> converted(vec.vlen);
128-
std::transform(
129-
vec.begin(), vec.end(), converted.begin(),
130-
[](auto&& e) { return static_cast<int32_t>(e); });
131-
return std::equal(
132-
vec.begin(), vec.end(), converted.begin(),
133-
[](auto&& e1, auto&& e2) {
134-
return std::abs(e1 - e2) <
135-
std::numeric_limits<float64_t>::epsilon();
136-
});
137-
}
125+
138126
};
139127
} // namespace shogun
140128

src/shogun/labels/LabelEncoder.h

+19
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ namespace shogun
5959
return "LabelEncoder";
6060
}
6161

62+
void set_print_warning(bool print_warning){
63+
print_warning = print_warning;
64+
}
6265
protected:
6366
SGVector<float64_t> fit_impl(const SGVector<float64_t>& origin_vector)
6467
{
@@ -98,6 +101,22 @@ namespace shogun
98101
});
99102
return original_vector;
100103
}
104+
105+
bool can_convert_float_to_int(const SGVector<float64_t>& vec) const
106+
{
107+
SGVector<int32_t> converted(vec.vlen);
108+
std::transform(
109+
vec.begin(), vec.end(), converted.begin(),
110+
[](auto&& e) { return static_cast<int32_t>(e); });
111+
return std::equal(
112+
vec.begin(), vec.end(), converted.begin(),
113+
[](auto&& e1, auto&& e2) {
114+
return std::abs(e1 - e2) <
115+
std::numeric_limits<float64_t>::epsilon();
116+
});
117+
}
118+
119+
bool print_warning = true;
101120
std::set<float64_t> unique_labels;
102121
std::unordered_map<float64_t, float64_t> normalized_to_origin;
103122
};

src/shogun/labels/MulticlassLabelsEncoder.h

+8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ namespace shogun
3232
SGVector<float64_t> fit(const std::shared_ptr<Labels>& labs) override
3333
{
3434
const auto result_vector = labs->as<DenseLabels>()->get_labels();
35+
if (print_warning && !can_convert_float_to_int(result_vector))
36+
{
37+
std::set<float64_t> s(result_vector.begin(), result_vector.end());
38+
io::warn(
39+
"{} have been converted to 0...{}",
40+
fmt::join(s, ", "),
41+
result_vector.vlen - 1);
42+
}
3543
return fit_impl(result_vector);
3644
}
3745
/** Transform labels to normalized encoding.

0 commit comments

Comments
 (0)