Skip to content

Commit

Permalink
Adapted RelationalDeepBlocker encoding saving (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
dobraczka authored Oct 26, 2023
1 parent 844551e commit 186840e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 10 deletions.
35 changes: 28 additions & 7 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ def _handle_artifacts(
upload(artifact_dir, artifact_dir)


def _handle_encodings_dir(blocker, artifact_name, experiment_artifact_dir):
if isinstance(blocker, (EmbeddingBlocker, RelationalDeepBlocker)):
if blocker.force:
encodings_dir = _create_artifact_path(
artifact_name, experiment_artifact_dir, suffix="_encoded"
)
else:
encodings_dir = _create_artifact_path(
"ignoring_params", experiment_artifact_dir, suffix="_encoded"
)
if not os.path.exists(encodings_dir):
os.makedirs(encodings_dir)
run_info_path = _create_artifact_path(
f"created_by_{artifact_name}",
encodings_dir,
suffix="_encoded",
)
Path(run_info_path).touch()
blocker.save = True
blocker.save_dir = encodings_dir
return encodings_dir
else:
return None


def prepare(
blocker: Blocker, dataset: EADataset, params: Dict, wandb: bool
) -> ExperimentInfo:
Expand Down Expand Up @@ -173,13 +198,9 @@ def prepare(
os.makedirs(experiment_artifact_dir)
tracker.start_run()
artifact_name = _create_artifact_name(tracker, params)
encodings_dir = None
if isinstance(blocker, EmbeddingBlocker):
encodings_dir = _create_artifact_path(
artifact_name, experiment_artifact_dir, suffix="_encoded"
)
blocker.save = True
blocker.save_dir = encodings_dir
encodings_dir = _handle_encodings_dir(
blocker, artifact_name, experiment_artifact_dir
)

params_artifact_path = (
_create_artifact_path(
Expand Down
50 changes: 47 additions & 3 deletions src/klinker/blockers/relation_aware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, List, Optional, Tuple, TypeVar
import pathlib
from typing import Callable, List, Optional, Tuple, TypeVar, Union

import dask.dataframe as dd
import pandas as pd
Expand Down Expand Up @@ -345,6 +346,9 @@ class RelationalDeepBlocker(RelationalBlocker):
>>> blocks = blocker.assign(left=ds.left, right=ds.right, left_rel=ds.left_rel, right_rel=ds.right_rel)
"""

_attribute_blocker: DeepBlocker
_relation_blocker: DeepBlocker

def __init__(
self,
attr_frame_encoder: HintOrType[DeepBlockerFrameEncoder] = None,
Expand All @@ -355,19 +359,59 @@ def __init__(
rel_frame_encoder_kwargs: OptionalKwargs = None,
rel_embedding_block_builder: HintOrType[EmbeddingBlockBuilder] = None,
rel_embedding_block_builder_kwargs: OptionalKwargs = None,
save: bool = True,
save_dir: Optional[Union[str, pathlib.Path]] = None,
force: bool = False,
):
self._attribute_blocker = DeepBlocker(
frame_encoder=attr_frame_encoder,
frame_encoder_kwargs=attr_frame_encoder_kwargs,
embedding_block_builder=attr_embedding_block_builder,
embedding_block_builder_kwargs=attr_embedding_block_builder_kwargs,
force=force,
)
self._relation_blocker = DeepBlocker(
frame_encoder=rel_frame_encoder,
frame_encoder_kwargs=rel_frame_encoder_kwargs,
embedding_block_builder=rel_embedding_block_builder,
embedding_block_builder_kwargs=rel_embedding_block_builder_kwargs,
force=force,
)
# set after instatiating seperate blocker to use setter
self.save = save
self.force = force
self.save_dir = save_dir

@property
def save(self) -> bool:
return self._save

@save.setter
def save(self, value: bool):
self._save = value
self._attribute_blocker.save = value
self._relation_blocker.save = value

@property
def force(self) -> bool:
return self._force

@force.setter
def force(self, value: bool):
self._force = value
self._attribute_blocker.force = value
self._relation_blocker.force = value

@property
def save_dir(self) -> Optional[Union[str, pathlib.Path]]:
return self._save_dir

@save_dir.setter
def save_dir(self, value: Optional[Union[str, pathlib.Path]]):
if value is None:
self._save_dir = None
self._attribute_blocker.save_dir = None
self._relation_blocker.save_dir = None
else:
sd = pathlib.Path(value)
self._save_dir = sd
self._attribute_blocker.save_dir = sd.joinpath("attributes")
self._relation_blocker.save_dir = sd.joinpath("relation")

0 comments on commit 186840e

Please sign in to comment.