Skip to content

Commit

Permalink
Improve documentation and device config of labellers
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Oct 9, 2023
1 parent 10d1d6e commit 93f863f
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 51 deletions.
4 changes: 2 additions & 2 deletions src/algs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class Algorithm(DcModule):
max_grad_norm: Optional[float] = None

def __post_init__(self) -> None:
self.use_gpu: bool = torch.cuda.is_available() and self.gpu >= 0
self.device: torch.device = resolve_device(self.gpu)
self.use_amp = self.use_amp and self.use_gpu
use_gpu = torch.cuda.is_available() and self.gpu >= 0
self.use_amp = self.use_amp and use_gpu
self.grad_scaler: Optional[GradScaler] = GradScaler() if self.use_amp else None
logger.info(f"{torch.cuda.device_count()} GPU(s) available - using device '{self.device}'")

Expand Down
10 changes: 8 additions & 2 deletions src/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,13 @@ def set_transforms_all(self, value: Optional[ImageTform]) -> None:

@classmethod
def from_ds(
cls, *, config: DataModuleConf, ds: Dataset, splitter: DataSplitter, labeller: "Labeller"
cls,
*,
config: DataModuleConf,
ds: Dataset,
splitter: DataSplitter,
labeller: "Labeller",
device: torch.device,
) -> Self:
splits = splitter(ds)
dm = cls(
Expand All @@ -432,7 +438,7 @@ def from_ds(
test=splits.test,
split_seed=getattr(splitter, "seed", None),
)
deployment_ids = labeller.run(dm=dm)
deployment_ids = labeller.run(dm=dm, device=device)
dm.deployment_ids = deployment_ids
return dm

Expand Down
4 changes: 4 additions & 0 deletions src/labelling/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def centroidal_label_noise(
num[row_inds, inv[indices]] = 0.0
denom = num.sum(dim=1, keepdim=True)
probs = num / denom
# random sampling is better done on the CPU
probs = probs.to(torch.device("cpu"))
new_labels = torch.multinomial(probs, num_samples=1, replacement=False, generator=generator)
del probs
labels = labels.to(torch.device("cpu"))
indices = indices.to(torch.device("cpu"))
if not inplace:
labels = labels.clone()
labels[indices] = new_labels.squeeze(1)
Expand Down
92 changes: 51 additions & 41 deletions src/labelling/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import wandb
from wandb.wandb_run import Run

from src.data import DataModule, resolve_device
from src.data import DataModule
from src.evaluation.metrics import print_metrics
from src.models import Classifier, OptimizerCfg
from src.utils import to_item
Expand Down Expand Up @@ -41,60 +41,54 @@

class Labeller(ABC):
@abstractmethod
def run(self, dm: DataModule) -> Optional[Tensor]:
def run(self, dm: DataModule, device: torch.device) -> Optional[Tensor]:
raise NotImplementedError()

def __call__(self, dm: DataModule) -> Optional[Tensor]:
return self.run(dm=dm)
def __call__(self, dm: DataModule, device: torch.device) -> Optional[Tensor]:
return self.run(dm=dm, device=device)


@dataclass(repr=False, eq=False)
class KmeansOnClipEncodings(DcModule, Labeller):
"""Generate embeddings with CLIP (optionally finetuned) and then do k-means."""

clip_version: ClipVersion = ClipVersion.RN50
download_root: Optional[str] = None
ft: FineTuneParams = field(default_factory=lambda: FineTuneParams(steps=1000))
enc_batch_size: int = 64

gpu: int = 0
spherical: bool = True
fft_cluster_init: bool = False
supervised_cluster_init: bool = False
n_init: int = 10
save_as_artifact: bool = True
artifact_name: Optional[str] = None

cache_encoder: bool = False
# cache_encoder: bool = False
encodings_path: Optional[Path] = None

encoder: Optional[ClipVisualEncoder] = field(
init=False, default=None, metadata={"omegaconf_ignore": True}
)
# encoder: Optional[ClipVisualEncoder] = field(
# init=False, default=None, metadata={"omegaconf_ignore": True}
# )
_fitted_kmeans: Optional[KMeans] = field(
init=False, default=None, metadata={"omegaconf_ignore": True}
)

@override
def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor:
device = resolve_device(self.gpu)
def run(self, dm: DataModule, device: torch.device) -> Tensor:
if self.encodings_path is not None and self.encodings_path.exists():
encodings = Encodings.from_npz(self.encodings_path)
else:
if self.encoder is None or not use_cached_encoder:
encoder = ClipVisualEncoder(
version=self.clip_version, download_root=self.download_root
)
if self.ft.steps > 0:
encoder.finetune(dm=dm, params=self.ft, device=device)
else:
encoder = self.encoder
encoder = ClipVisualEncoder(version=self.clip_version, download_root=self.download_root)
if self.ft.steps > 0:
encoder.finetune(dm=dm, params=self.ft, device=device)
encodings = encoder.encode(dm=dm, batch_size_tr=self.enc_batch_size, device=device)
if self.encodings_path is not None:
encodings.save(self.encodings_path)
if self.cache_encoder:
self.encoder = encoder
else:
del encoder
torch.cuda.empty_cache()
# if self.cache_encoder:
# self.encoder = encoder
del encoder
torch.cuda.empty_cache()

kmeans = KMeans(
spherical=self.spherical,
Expand All @@ -115,16 +109,17 @@ def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor:

@dataclass(eq=False)
class ClipClassifier(Labeller):
"""Predict s and y with a fine-tuned CLIP classifier."""

clip_version: ClipVersion = ClipVersion.RN50
download_root: Optional[str] = None
ft: FineTuneParams = field(default_factory=lambda: FineTuneParams(steps=1000))
batch_size_te: int = 64

gpu: int = 0
save_as_artifact: bool = True
artifact_name: Optional[str] = None

cache_encoder: bool = False
# cache_encoder: bool = False

@torch.no_grad()
def evaluate(
Expand All @@ -147,8 +142,7 @@ def evaluate(
return metrics

@override
def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor:
device = resolve_device(self.gpu)
def run(self, dm: DataModule, device: torch.device) -> Tensor:
encoder = ClipVisualEncoder(version=self.clip_version, download_root=self.download_root)
ft_model = encoder.finetune(dm=dm, params=self.ft, device=device)
classifier = Classifier(model=ft_model, opt=OptimizerCfg())
Expand All @@ -172,12 +166,14 @@ def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor:

@dataclass(eq=False)
class LabelFromArtifact(Labeller):
"""Load labels from W&B."""

version: Optional[int] = None # latest by default
artifact_name: Optional[str] = None
root: Optional[Path] = None # artifacts/clustering by default

@override
def run(self, dm: DataModule) -> Tensor:
def run(self, dm: DataModule, device: torch.device) -> Tensor:
return load_labels_from_artifact(
run=wandb.run,
datamodule=dm,
Expand All @@ -189,26 +185,32 @@ def run(self, dm: DataModule) -> Tensor:

@dataclass(eq=False)
class NullLabeller(Labeller):
"""Don't do any bag balancing."""

@override
def run(self, dm: DataModule) -> None:
def run(self, dm: DataModule, device: torch.device) -> None:
return None


@dataclass(eq=False)
class GroundTruthLabeller(Labeller):
"""Use ground truth for bag balancing."""

seed: int = 47

@property
def generator(self) -> torch.Generator:
return torch.Generator().manual_seed(self.seed)

@override
def run(self, dm: DataModule) -> Tensor:
def run(self, dm: DataModule, device: torch.device) -> Tensor:
return dm.group_ids_dep


@dataclass(eq=False)
class LabelNoiser(Labeller):
"""Base class for methods that take the ground-truth labels and add noise to them."""

level: float = 0.10
seed: int = 47
weighted_index_sampling: bool = True
Expand All @@ -222,11 +224,13 @@ def generator(self) -> torch.Generator:
return torch.Generator().manual_seed(self.seed)

@abstractmethod
def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tensor:
def _noise(
self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule, device: torch.device
) -> Tensor:
raise NotImplementedError()

@override
def run(self, dm: DataModule) -> Tensor:
def run(self, dm: DataModule, device: torch.device) -> Tensor:
group_ids = dm.group_ids_dep
logger.info(
f"Injecting noise into ground-truth labels with noise level '{self.level}'"
Expand All @@ -240,40 +244,46 @@ def run(self, dm: DataModule) -> Tensor:
weighted=self.weighted_index_sampling,
)
# Inject label-noise into the group identifiers.
group_ids = self._noise(dep_ids=group_ids, flip_inds=flip_inds, dm=dm)
group_ids = self._noise(dep_ids=group_ids, flip_inds=flip_inds, dm=dm, device=device)
return group_ids


@dataclass(eq=False)
class UniformLabelNoiser(LabelNoiser):
"""Take the ground-truth labels and flip them uniformly at random."""

@override
def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tensor:
def _noise(
self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule, device: torch.device
) -> Tensor:
return uniform_label_noise(
labels=dep_ids, indices=flip_inds, generator=self.generator, inplace=True
)


@dataclass(eq=False)
class CentroidalLabelNoiser(LabelNoiser):
"""Get embeddings from (non-fine-tuned) CLIP and then flip to the nearest centroid."""

metric: ClnMetric = ClnMetric.COSINE
clip_version: ClipVersion = ClipVersion.RN50
download_root: Optional[str] = None
enc_batch_size: int = 64
gpu: int = 0

@override
def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tensor:
device = resolve_device(self.gpu)
def _noise(
self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule, device: torch.device
) -> Tensor:
encoder = ClipVisualEncoder(version=self.clip_version, download_root=self.download_root)
encodings, _ = encode_with_group_ids(
model=encoder,
dl=dm.deployment_dataloader(eval=True, batch_size=self.enc_batch_size),
device=device,
)
return centroidal_label_noise(
labels=dep_ids,
indices=flip_inds,
encodings=encodings,
labels=dep_ids.to(device),
indices=flip_inds.to(device),
encodings=encodings.to(device),
generator=self.generator,
inplace=True,
)
4 changes: 1 addition & 3 deletions src/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def init(
) -> Run:
if self.tags is None:
self.tags = []
self.tags.extend(
cfg_obj.__class__.__name__ for cfg_obj in cfgs_for_group
)
self.tags.extend(cfg_obj.__class__.__name__ for cfg_obj in cfgs_for_group)
if with_tag is not None:
self.tags.append(with_tag)
# TODO: not sure whether `reinit` really should be hardcoded
Expand Down
9 changes: 7 additions & 2 deletions src/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ class BaseRelay:
}

def init_dm(
self, ds: CdtVisionDataset[TernarySample, Tensor, Tensor], labeller: Labeller
self,
ds: CdtVisionDataset[TernarySample, Tensor, Tensor],
labeller: Labeller,
device: torch.device,
) -> DataModule:
assert isinstance(self.split, DataSplitter)

logger.info(f"Current working directory: '{os.getcwd()}'")
random_seed(self.seed, use_cuda=True)
torch.multiprocessing.set_sharing_strategy("file_system")
dm = DataModule.from_ds(config=self.dm, ds=ds, splitter=self.split, labeller=labeller)
dm = DataModule.from_ds(
config=self.dm, ds=ds, splitter=self.split, labeller=labeller, device=device
)
logger.info(str(dm))
return dm
2 changes: 1 addition & 1 deletion src/relay/supmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def run(self, raw_config: Optional[dict[str, Any]] = None) -> Optional[float]:

ds = self.ds()
run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
dm = self.init_dm(ds, self.labeller)
dm = self.init_dm(ds, self.labeller, self.alg.device)
ae_pair = self.ae_arch(input_shape=dm.dim_x)
ae = SplitLatentAe(opt=self.ae, model=ae_pair, feature_group_slices=dm.feature_group_slices)
logger.info(f"Encoding dim: {ae.latent_dim}, {ae.encoding_size}")
Expand Down

0 comments on commit 93f863f

Please sign in to comment.