diff --git a/modAL/expected_error.py b/modAL/expected_error.py index d7b3611..848ae7b 100644 --- a/modAL/expected_error.py +++ b/modAL/expected_error.py @@ -16,7 +16,7 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary', - p_subsample: np.float = 1.0, n_instances: int = 1, + p_subsample: float = 1.0, n_instances: int = 1, random_tie_break: bool = False) -> np.ndarray: """ Expected error reduction query strategy. diff --git a/modAL/utils/data.py b/modAL/utils/data.py index 3e707ff..c75350e 100644 --- a/modAL/utils/data.py +++ b/modAL/utils/data.py @@ -6,7 +6,7 @@ try: import torch -except: +except ImportError: pass @@ -23,22 +23,37 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput: Returns: New sequence of vertically stacked elements. """ + + if not blocks: + return blocks + + types = {type(block) for block in blocks} + if any([sp.issparse(b) for b in blocks]): return sp.vstack(blocks) - elif isinstance(blocks[0], pd.DataFrame): - return blocks[0].append(blocks[1:]) - elif isinstance(blocks[0], np.ndarray): + elif types - {pd.DataFrame, pd.Series} == set(): + def _block_to_df(block): + if isinstance(block, pd.DataFrame): + return block + elif isinstance(block, pd.Series): + # interpret series as a row + return block.to_frame().T + else: + raise TypeError(f"Expected DataFrame or Series but encountered {type(block)}") + + return pd.concat([_block_to_df(block) for block in blocks]) + elif types == {np.ndarray}: return np.concatenate(blocks) - elif isinstance(blocks[0], list): + elif types == {list}: return np.concatenate(blocks).tolist() try: - if torch.is_tensor(blocks[0]): + if all(torch.is_tensor(block) for block in blocks): return torch.cat(blocks) except: pass - raise TypeError("%s datatype is not supported" % type(blocks[0])) + raise TypeError("%s datatype(s) not supported" % types) def data_hstack(blocks: Sequence[modALinput]) -> modALinput: @@ -51,13 +66,19 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput: Returns: New sequence of horizontally stacked elements. """ + + if not blocks: + return blocks + + types = {type(block) for block in blocks} + if any([sp.issparse(b) for b in blocks]): return sp.hstack(blocks) - elif isinstance(blocks[0], pd.DataFrame): + elif types == {pd.DataFrame}: pd.concat(blocks, axis=1) - elif isinstance(blocks[0], np.ndarray): + elif types == {np.ndarray}: return np.hstack(blocks) - elif isinstance(blocks[0], list): + elif types == {list}: return np.hstack(blocks).tolist() try: @@ -66,7 +87,7 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput: except: pass - TypeError("%s datatype is not supported" % type(blocks[0])) + raise TypeError("%s datatype(s) not supported" % types) def add_row(X: modALinput, row: modALinput): diff --git a/rtd_requirements.txt b/rtd_requirements.txt index db0bd81..8108132 100644 --- a/rtd_requirements.txt +++ b/rtd_requirements.txt @@ -1,7 +1,8 @@ -numpy==1.20.0 +numpy scipy scikit-learn ipykernel nbsphinx pandas skorch +torch \ No newline at end of file