Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unsafely broad except clauses #173

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modAL/expected_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary',
p_subsample: np.float = 1.0, n_instances: int = 1,
p_subsample: float = 1.0, n_instances: int = 1,
random_tie_break: bool = False) -> np.ndarray:
"""
Expected error reduction query strategy.
Expand Down
26 changes: 12 additions & 14 deletions modAL/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,13 @@ def query(self, X_pool, *query_args, return_metrics: bool = False, **query_kwarg
query_metrics: returns also the corresponding metrics, if return_metrics == True
"""

try:
query_result, query_metrics = self.query_strategy(
self, X_pool, *query_args, **query_kwargs)

except:
_query_strategy_result = self.query_strategy(
self, X_pool, *query_args, **query_kwargs)
if isinstance(_query_strategy_result, tuple) and len(_query_strategy_result) == 2:
query_result, query_metrics = _query_strategy_result
else:
query_result = _query_strategy_result
query_metrics = None
query_result = self.query_strategy(
self, X_pool, *query_args, **query_kwargs)

if return_metrics:
if query_metrics is None:
Expand Down Expand Up @@ -313,14 +312,13 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
query_metrics: returns also the corresponding metrics, if return_metrics == True
"""

try:
query_result, query_metrics = self.query_strategy(
self, X_pool, *query_args, **query_kwargs)

except:
_query_strategy_result = self.query_strategy(
self, X_pool, *query_args, **query_kwargs)
if isinstance(_query_strategy_result, tuple) and len(_query_strategy_result) == 2:
query_result, query_metrics = _query_strategy_result
else:
query_result = _query_strategy_result
query_metrics = None
query_result = self.query_strategy(
self, X_pool, *query_args, **query_kwargs)

if return_metrics:
if query_metrics is None:
Expand Down
108 changes: 59 additions & 49 deletions modAL/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import List, Sequence, Union

import numpy as np
Expand All @@ -6,7 +7,7 @@

try:
import torch
except:
except ImportError:
pass


Expand All @@ -23,22 +24,34 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
Returns:
New sequence of vertically stacked elements.
"""

if not blocks:
return blocks

types = {type(block) for block in blocks}

if any([sp.issparse(b) for b in blocks]):
return sp.vstack(blocks)
elif isinstance(blocks[0], pd.DataFrame):
return blocks[0].append(blocks[1:])
elif isinstance(blocks[0], np.ndarray):
elif types - {pd.DataFrame, pd.Series} == set():
def _block_to_df(block):
if isinstance(block, pd.DataFrame):
return block
elif isinstance(block, pd.Series):
# interpret series as a row
return block.to_frame().T
else:
raise TypeError(f"Expected DataFrame or Series but encountered {type(block)}")

return pd.concat([_block_to_df(block) for block in blocks])
elif types == {np.ndarray}:
return np.concatenate(blocks)
elif isinstance(blocks[0], list):
elif types == {list}:
return np.concatenate(blocks).tolist()

try:
if torch.is_tensor(blocks[0]):
return torch.cat(blocks)
except:
pass
if 'torch' in sys.modules and all(torch.is_tensor(block) for block in blocks):
return torch.cat(blocks)

raise TypeError("%s datatype is not supported" % type(blocks[0]))
raise TypeError("%s datatype(s) not supported" % types)


def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
Expand All @@ -51,22 +64,25 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
Returns:
New sequence of horizontally stacked elements.
"""

if not blocks:
return blocks

types = {type(block) for block in blocks}

if any([sp.issparse(b) for b in blocks]):
return sp.hstack(blocks)
elif isinstance(blocks[0], pd.DataFrame):
elif types == {pd.DataFrame}:
pd.concat(blocks, axis=1)
elif isinstance(blocks[0], np.ndarray):
elif types == {np.ndarray}:
return np.hstack(blocks)
elif isinstance(blocks[0], list):
elif types == {list}:
return np.hstack(blocks).tolist()

try:
if torch.is_tensor(blocks[0]):
return torch.cat(blocks, dim=1)
except:
pass
if 'torch' in sys.modules and torch.is_tensor(blocks[0]):
return torch.cat(blocks, dim=1)

TypeError("%s datatype is not supported" % type(blocks[0]))
raise TypeError("%s datatype(s) not supported" % types)


def add_row(X: modALinput, row: modALinput):
Expand Down Expand Up @@ -100,24 +116,26 @@ def retrieve_rows(

try:
return X[I]
except:
if sp.issparse(X):
# Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
# sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
# and sp.dia_matrix don't support indexing and need to be converted to a sparse format
# that does support indexing. It seems conversion to CSR is currently most efficient.

sp_format = X.getformat()
return X.tocsr()[I].asformat(sp_format)
elif isinstance(X, pd.DataFrame):
return X.iloc[I]
elif isinstance(X, list):
return np.array(X)[I].tolist()
elif isinstance(X, dict):
X_return = {}
for key, value in X.items():
X_return[key] = retrieve_rows(value, I)
return X_return
except (KeyError, IndexError, TypeError):
pass

if sp.issparse(X):
# Out of the sparse matrix formats (sp.csc_matrix, sp.csr_matrix, sp.bsr_matrix,
# sp.lil_matrix, sp.dok_matrix, sp.coo_matrix, sp.dia_matrix), only sp.bsr_matrix, sp.coo_matrix
# and sp.dia_matrix don't support indexing and need to be converted to a sparse format
# that does support indexing. It seems conversion to CSR is currently most efficient.

sp_format = X.getformat()
return X.tocsr()[I].asformat(sp_format)
elif isinstance(X, pd.DataFrame):
return X.iloc[I]
elif isinstance(X, list):
return np.array(X)[I].tolist()
elif isinstance(X, dict):
X_return = {}
for key, value in X.items():
X_return[key] = retrieve_rows(value, I)
return X_return

raise TypeError("%s datatype is not supported" % type(X))

Expand All @@ -139,12 +157,6 @@ def drop_rows(
elif isinstance(X, list):
return np.delete(X, I, axis=0).tolist()

try:
if torch.is_tensor(blocks[0]):
return torch.cat(blocks)
except:
X[[True if row not in I else False for row in range(X.size(0))]]

raise TypeError("%s datatype is not supported" % type(X))


Expand Down Expand Up @@ -173,11 +185,9 @@ def data_shape(X: modALinput):
"""
Returns the shape of the data set X
"""
try:
# scipy.sparse, torch, pandas and numpy all support .shape
if isinstance(X, list):
return np.array(X).shape
elif hasattr(X, "shape"): # scipy.sparse, torch, pandas and numpy all support .shape
return X.shape
except:
if isinstance(X, list):
return np.array(X).shape

raise TypeError("%s datatype is not supported" % type(X))
3 changes: 2 additions & 1 deletion rtd_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
numpy==1.20.0
numpy
scipy
scikit-learn
ipykernel
nbsphinx
pandas
skorch
torch
17 changes: 5 additions & 12 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,8 @@ def dummy_function(X_in):
else:
true_result = n_functions*np.ones(shape=(n_samples, 1))

try:
np.testing.assert_almost_equal(
linear_combination(X_in), true_result)
except:
linear_combination(X_in)
np.testing.assert_almost_equal(
linear_combination(X_in), true_result)

def test_product(self):
for n_dim in range(1, 5):
Expand Down Expand Up @@ -476,15 +473,11 @@ def test_KL_max_disagreement(self):

true_KL_disagreement = np.zeros(shape=(n_samples, ))

try:
np.testing.assert_array_almost_equal(
true_KL_disagreement,
modAL.disagreement.KL_max_disagreement(
committee, np.random.rand(n_samples, 1))
)
except:
np.testing.assert_array_almost_equal(
true_KL_disagreement,
modAL.disagreement.KL_max_disagreement(
committee, np.random.rand(n_samples, 1))
)

# 2. unfitted committee
committee = mock.MockCommittee(fitted=False)
Expand Down