Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gwlearn/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def fit(
self.oob_y_pooled_ = np.concatenate(non_empty_y)
else:
# Set to empty array with same dtype as y
self.oob_y_pooled_ = np.array([], dtype=y.dtype) # ty:ignore[no-matching-overload]
self.oob_y_pooled_ = np.array([], dtype=y.dtype)
if non_empty_pred:
self.oob_pred_pooled_ = np.concatenate(non_empty_pred)
else:
Expand Down Expand Up @@ -772,7 +772,7 @@ def fit(
self.oob_y_pooled_ = np.concatenate(non_empty_y)
else:
# Set to empty array with same dtype as y
self.oob_y_pooled_ = np.array([], dtype=y.dtype) # type:ignore[no-matching-overload]
self.oob_y_pooled_ = np.array([], dtype=y.dtype)
if non_empty_pred:
self.oob_pred_pooled_ = np.concatenate(non_empty_pred)
else:
Expand Down
2 changes: 1 addition & 1 deletion gwlearn/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class BandwidthSearch:
... max_iter=200,
... ).fit(X, y, geometry=gdf.representative_point())
>>> search.optimal_bandwidth_
np.int64(40)
40
"""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions gwlearn/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_interval_search_basic(sample_data): # noqa: F811
assert (
search.min_bandwidth
<= search.optimal_bandwidth_ # ty:ignore[unsupported-operator]
<= search.max_bandwidth # ty:ignore[unsupported-operator]
<= search.max_bandwidth
)

# Check the number of bandwidths tested
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_golden_section_search_basic(sample_data): # noqa: F811
assert (
search.min_bandwidth
<= search.optimal_bandwidth_ # ty:ignore[unsupported-operator]
<= search.max_bandwidth # ty:ignore[unsupported-operator]
<= search.max_bandwidth
)
# Golden section should evaluate fewer points than interval search
assert len(search.scores_) <= search.max_iterations * 2
Expand Down