Skip to content

Commit

Permalink
Feature/categorical molecular2 (#260)
Browse files Browse the repository at this point in the history
* add basic categoricalmolecular one

* add spec

* implement new transform validation

* update tests

* add from descriptor encoding

* update inverse transform

* fix pyright

* add test
  • Loading branch information
jduerholt authored Aug 15, 2023
1 parent 8eb3fe4 commit 810a385
Show file tree
Hide file tree
Showing 11 changed files with 323 additions and 54 deletions.
78 changes: 38 additions & 40 deletions bofire/data_models/domain/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools
import warnings
from enum import Enum
from typing import Dict, List, Literal, Optional, Sequence, Tuple, Type, Union, cast

import numpy as np
Expand All @@ -18,6 +19,7 @@
AnyOutput,
CategoricalDescriptorInput,
CategoricalInput,
CategoricalMolecularInput,
ContinuousInput,
ContinuousOutput,
DiscreteInput,
Expand Down Expand Up @@ -379,6 +381,7 @@ def transform(
Returns:
pd.DataFrame: Transformed dataframe. Only input features are included.
"""
# TODO: clean this up and move it into the individual classes
specs = self._validate_transform_specs(specs)
transformed = []
for feat in self.get():
Expand Down Expand Up @@ -418,6 +421,7 @@ def inverse_transform(
Returns:
pd.DataFrame: Back transformed dataframe. Only input features are included.
"""
# TODO: clean this up and move it into the individual classes
self._validate_transform_specs(specs=specs)
transformed = []
for feat in self.get():
Expand All @@ -437,57 +441,51 @@ def inverse_transform(
elif specs[feat.key] == CategoricalEncodingEnum.DESCRIPTOR:
assert isinstance(feat, CategoricalDescriptorInput)
transformed.append(feat.from_descriptor_encoding(experiments))
elif isinstance(specs[feat.key], MolFeatures):
assert isinstance(feat, CategoricalMolecularInput)
transformed.append(feat.from_descriptor_encoding(specs[feat.key], experiments)) # type: ignore

return pd.concat(transformed, axis=1)

def _validate_transform_specs(self, specs: TInputTransformSpecs):
def _validate_transform_specs(
self, specs: TInputTransformSpecs
) -> TInputTransformSpecs:
"""Checks the validity of the transform specs.
Args:
specs (TInputTransformSpecs): Transform specs to be validated.
"""
# first check that the keys in the specs dict are also valid feature keys
if (
len(
set(specs.keys())
- set(self.get_keys(CategoricalInput))
- set(self.get_keys(MolecularInput))
)
> 0
):
raise ValueError("Unknown features specified in transform specs.")
# next check that all values are of type CategoricalEncodingEnum or MolFeatures
if not (
all(
isinstance(enc, (CategoricalEncodingEnum, MolFeatures))
for enc in specs.values()
)
):
raise ValueError("Unknown transform specified.")
# next check that only CategoricalDescriptorInput can have the value DESCRIPTOR
descriptor_keys = []
for key, value in specs.items():
if value == CategoricalEncodingEnum.DESCRIPTOR:
descriptor_keys.append(key)
if (
len(set(descriptor_keys) - set(self.get_keys(CategoricalDescriptorInput)))
> 0
):
raise ValueError("Wrong features types assigned to DESCRIPTOR transform.")
# next check if MolFeatures have been assigned to feature types other than MolecularInput
molfeature_keys = []
for key, value in specs.items():
if isinstance(value, MolFeatures):
molfeature_keys.append(key)
if len(set(molfeature_keys) - set(self.get_keys(MolecularInput))) > 0:
raise ValueError("Wrong features types assigned to MolFeatures transforms.")
# next check that all MolecularInput have MolFeatures transforms
for feat in self.get(includes=[MolecularInput]):
mol_encoding = specs.get(feat.key)
if mol_encoding is None:
raise ValueError("No transform assigned to MolecularInput.")
elif not isinstance(mol_encoding, MolFeatures):
raise ValueError("Incorrect transform assigned to MolecularInput.")
try:
feat = self.get_by_key(key)
except KeyError:
raise ValueError(
f"Unknown feature with key {key} specified in transform specs."
)
# TODO
# this is ugly, on the long run we have to get rid of the transform enums
# and replace them with classes, then the following lines collapse into just two
assert isinstance(feat, Input)
enums = [t for t in feat.valid_transform_types() if isinstance(t, Enum)]
no_enums = [
t for t in feat.valid_transform_types() if not isinstance(t, Enum)
]
if isinstance(value, Enum):
if value not in enums:
raise ValueError(
f"Forbidden transform type for feature with key {key}"
)
else:
if len(no_enums) == 0:
raise ValueError(
f"Forbidden transform type for feature with key {key}"
)
if not isinstance(value, tuple(no_enums)): # type: ignore
raise ValueError(
f"Forbidden transform type for feature with key {key}"
)
return specs

def get_bounds(
Expand Down
7 changes: 6 additions & 1 deletion bofire/data_models/features/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
Output,
TInputTransformSpecs,
)
from bofire.data_models.features.molecular import MolecularInput
from bofire.data_models.features.molecular import (
CategoricalMolecularInput,
MolecularInput,
)
from bofire.data_models.features.numerical import NumericalInput

AbstractFeature = Union[
Expand All @@ -33,6 +36,7 @@
CategoricalDescriptorInput,
MolecularInput,
CategoricalOutput,
CategoricalMolecularInput,
]

AnyInput = Union[
Expand All @@ -42,6 +46,7 @@
ContinuousDescriptorInput,
CategoricalDescriptorInput,
MolecularInput,
CategoricalMolecularInput,
]

AnyOutput = Union[ContinuousOutput, CategoricalOutput]
8 changes: 8 additions & 0 deletions bofire/data_models/features/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def init_allowed(cls, values):
raise ValueError("no category is allowed")
return values

@staticmethod
def valid_transform_types() -> List[CategoricalEncodingEnum]:
    """List the encodings that may be assigned to this feature.

    Returns:
        List[CategoricalEncodingEnum]: The one-hot, dummy and ordinal
            encodings, in that order.
    """
    accepted = [
        CategoricalEncodingEnum.ONE_HOT,
        CategoricalEncodingEnum.DUMMY,
        CategoricalEncodingEnum.ORDINAL,
    ]
    return accepted

def is_fixed(self) -> bool:
"""Returns True if there is only one allowed category.
Expand Down
9 changes: 9 additions & 0 deletions bofire/data_models/features/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ def validate_values(cls, v, values):
raise ValueError(f"No variation for descriptor {d}.")
return v

@staticmethod
def valid_transform_types() -> List[CategoricalEncodingEnum]:
    """List the encodings that may be assigned to this feature.

    Returns:
        List[CategoricalEncodingEnum]: The one-hot, dummy, ordinal and
            descriptor encodings, in that order.
    """
    accepted = [
        CategoricalEncodingEnum.ONE_HOT,
        CategoricalEncodingEnum.DUMMY,
        CategoricalEncodingEnum.ORDINAL,
        CategoricalEncodingEnum.DESCRIPTOR,
    ]
    return accepted

def to_df(self):
"""tabular overview of the feature as DataFrame
Expand Down
5 changes: 5 additions & 0 deletions bofire/data_models/features/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def __lt__(self, other) -> bool:
class Input(Feature):
"""Base class for all input features."""

@staticmethod
@abstractmethod
def valid_transform_types() -> List[Union[CategoricalEncodingEnum, AnyMolFeatures]]:
    """Return the transform types that may be assigned to this input.

    Abstract: every concrete input class has to provide its own list.

    Returns:
        List[Union[CategoricalEncodingEnum, AnyMolFeatures]]: The accepted
            transform types; may be empty.
    """

@abstractmethod
def is_fixed(self) -> bool:
"""Indicates if a variable is set to a fixed value.
Expand Down
126 changes: 123 additions & 3 deletions bofire/data_models/features/molecular.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
from typing import ClassVar, List, Literal, Optional, Tuple
import warnings
from typing import ClassVar, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from pydantic import validator

from bofire.data_models.features.categorical import _CAT_SEP
from bofire.data_models.enum import CategoricalEncodingEnum
from bofire.data_models.features.categorical import _CAT_SEP, CategoricalInput
from bofire.data_models.features.feature import Input
from bofire.data_models.molfeatures.api import AnyMolFeatures
from bofire.data_models.molfeatures.api import (
AnyMolFeatures,
Fingerprints,
FingerprintsFragments,
Fragments,
MordredDescriptors,
)
from bofire.utils.cheminformatics import smiles2mol


class MolecularInput(Input):
type: Literal["MolecularInput"] = "MolecularInput"
order: ClassVar[int] = 6

@staticmethod
def valid_transform_types() -> List[AnyMolFeatures]:
    """List the molecular featurizer classes accepted as transform types.

    Returns:
        List[AnyMolFeatures]: The featurizer classes (not instances):
            Fingerprints, FingerprintsFragments, Fragments and
            MordredDescriptors.
    """
    featurizers = [
        Fingerprints,
        FingerprintsFragments,
        Fragments,
        MordredDescriptors,
    ]
    return featurizers  # type: ignore

def validate_experimental(
self, values: pd.Series, strict: bool = False
) -> pd.Series:
Expand Down Expand Up @@ -68,3 +82,109 @@ def to_descriptor_encoding(
descriptor_values.index = values.index

return descriptor_values


class CategoricalMolecularInput(CategoricalInput, MolecularInput):
    """Categorical input whose categories are molecules given as SMILES strings.

    The feature accepts the plain categorical encodings as well as the
    molecular featurizers, see ``valid_transform_types``.
    """

    type: Literal["CategoricalMolecularInput"] = "CategoricalMolecularInput"
    order: ClassVar[int] = 7

    @validator("categories")
    def validate_smiles(cls, categories: Sequence[str]):
        """Validate that all categories are valid smiles.

        Note that this check can only be executed when rdkit is available;
        otherwise a warning is emitted and the categories are returned
        unchecked.

        Args:
            categories (Sequence[str]): Sequence of smiles.

        Raises:
            ValueError: when a string is not a smiles.

        Returns:
            Sequence[str]: The unchanged categories.
        """
        # A NameError from smiles2mol signals that rdkit is not available.
        # Looping inside the try block avoids indexing ``categories[0]``
        # (which fails on an empty sequence) and validates every entry once.
        try:
            for cat in categories:
                smiles2mol(cat)
        except NameError:
            warnings.warn("rdkit not installed, categories cannot be validated.")
        return categories

    @staticmethod
    def valid_transform_types() -> List[Union[AnyMolFeatures, CategoricalEncodingEnum]]:
        """List the transform types that may be assigned to this feature.

        Returns:
            List[Union[AnyMolFeatures, CategoricalEncodingEnum]]: The
                categorical encodings followed by the molecular featurizer
                classes.
        """
        # Reuse the parents' lists instead of repeating the featurizers here.
        return (
            CategoricalInput.valid_transform_types()
            + MolecularInput.valid_transform_types()  # type: ignore
        )

    def get_bounds(
        self,
        transform_type: Union[CategoricalEncodingEnum, AnyMolFeatures],
        values: Optional[pd.Series] = None,
    ) -> Tuple[List[float], List[float]]:
        """Return lower and upper bounds of the feature in the encoded space.

        Args:
            transform_type (Union[CategoricalEncodingEnum, AnyMolFeatures]):
                Encoding for which the bounds are computed.
            values (Optional[pd.Series], optional): If None, the bounds are
                computed over the allowed categories only (optimization
                bounds); otherwise over all categories. Defaults to None.

        Returns:
            Tuple[List[float], List[float]]: Lower bounds and upper bounds,
                one entry per encoded column.
        """
        if isinstance(transform_type, CategoricalEncodingEnum):
            # we are just using the standard categorical transformations
            return super().get_bounds(transform_type=transform_type, values=values)

        # in case that values is None, we return the optimization bounds
        # else we return the complete bounds
        categories = (
            self.get_allowed_categories() if values is None else self.categories
        )
        data = self.to_descriptor_encoding(
            transform_type=transform_type,
            values=pd.Series(categories),
        )
        lower = data.min(axis=0).values.tolist()
        upper = data.max(axis=0).values.tolist()
        return lower, upper

    def from_descriptor_encoding(
        self, transform_type: AnyMolFeatures, values: pd.DataFrame
    ) -> pd.Series:
        """Converts values back from descriptor encoding.

        Every row is mapped to the allowed category whose descriptor vector
        has the smallest Euclidean distance to the row.

        Args:
            transform_type (AnyMolFeatures): Molecular featurizer that was
                used to create the encoding.
            values (pd.DataFrame): Descriptor encoded dataframe.

        Raises:
            ValueError: If descriptor columns not found in the dataframe.

        Returns:
            pd.Series: Series with categorical values.
        """
        # This method is modified based on the categorical descriptor feature
        # TODO: move it to more central place
        cat_cols = [
            f"{self.key}{_CAT_SEP}{d}" for d in transform_type.get_descriptor_names()
        ]
        # we allow here explicitly that the dataframe can have more columns than needed to have it
        # easier in the backtransform.
        if any(c not in values.columns for c in cat_cols):
            raise ValueError(
                f"{self.key}: Column names don't match categorical levels: {values.columns}, {cat_cols}."
            )
        allowed = self.get_allowed_categories()
        reference = self.to_descriptor_encoding(
            transform_type=transform_type,
            values=pd.Series(allowed),
        ).to_numpy()
        encoded = values[cat_cols].to_numpy()
        # Broadcast to (rows, allowed categories, descriptors) and reduce over
        # the descriptor axis to get one distance per row and category.
        distances = np.sqrt(
            np.sum((encoded[:, np.newaxis, :] - reference) ** 2, axis=2)
        )
        s = pd.DataFrame(
            data=distances,
            columns=allowed,
            index=values.index,
        ).idxmin(axis=1)
        s.name = self.key
        return s
4 changes: 4 additions & 0 deletions bofire/data_models/features/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class NumericalInput(Input):

unit: Optional[str] = None

@staticmethod
def valid_transform_types() -> List:
    """Return the accepted transform types, which is none for this feature.

    Returns:
        List: Always an empty list.
    """
    return list()

def to_unit_range(
self, values: Union[pd.Series, np.ndarray], use_real_bounds: bool = False
) -> Union[pd.Series, np.ndarray]:
Expand Down
9 changes: 9 additions & 0 deletions bofire/data_models/molfeatures/molfeatures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import List, Literal, Optional

import pandas as pd
Expand All @@ -19,6 +20,14 @@ class MolFeatures(BaseModel):

type: str

@abstractmethod
def get_descriptor_names(self) -> List[str]:
    """Return the names of the descriptors produced by this featurizer.

    Abstract: concrete featurizers have to implement it.

    Returns:
        List[str]: One name per descriptor.
    """

@abstractmethod
def get_descriptor_values(self, values: pd.Series) -> pd.DataFrame:
    """Compute the descriptor values for the given entries.

    Abstract: concrete featurizers have to implement it.

    Args:
        values (pd.Series): Entries to featurize (presumably SMILES
            strings -- confirm against the implementations).

    Returns:
        pd.DataFrame: The descriptor values.
    """


class Fingerprints(MolFeatures):
type: Literal["Fingerprints"] = "Fingerprints"
Expand Down
14 changes: 14 additions & 0 deletions tests/bofire/data_models/specs/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,17 @@
"key": str(uuid.uuid4()),
},
)

# Valid spec for CategoricalMolecularInput: a fresh random key on every call,
# four SMILES strings as categories, and every category marked as allowed.
specs.add_valid(
    features.CategoricalMolecularInput,
    lambda: {
        "key": str(uuid.uuid4()),
        "categories": [
            "CC(=O)Oc1ccccc1C(=O)O",
            "c1ccccc1",
            "[CH3][CH2][OH]",
            "N[C@](C)(F)C(=O)O",
        ],
        # one flag per category, in the same order as "categories"
        "allowed": [True, True, True, True],
    },
)
Loading

0 comments on commit 810a385

Please sign in to comment.