diff --git a/src/algs/base.py b/src/algs/base.py index abf7b619..1184cd33 100644 --- a/src/algs/base.py +++ b/src/algs/base.py @@ -23,9 +23,9 @@ class Algorithm(DcModule): max_grad_norm: Optional[float] = None def __post_init__(self) -> None: - self.use_gpu: bool = torch.cuda.is_available() and self.gpu >= 0 self.device: torch.device = resolve_device(self.gpu) - self.use_amp = self.use_amp and self.use_gpu + use_gpu = torch.cuda.is_available() and self.gpu >= 0 + self.use_amp = self.use_amp and use_gpu self.grad_scaler: Optional[GradScaler] = GradScaler() if self.use_amp else None logger.info(f"{torch.cuda.device_count()} GPU(s) available - using device '{self.device}'") diff --git a/src/data/data_module.py b/src/data/data_module.py index afe359fe..08abda55 100644 --- a/src/data/data_module.py +++ b/src/data/data_module.py @@ -422,7 +422,13 @@ def set_transforms_all(self, value: Optional[ImageTform]) -> None: @classmethod def from_ds( - cls, *, config: DataModuleConf, ds: Dataset, splitter: DataSplitter, labeller: "Labeller" + cls, + *, + config: DataModuleConf, + ds: Dataset, + splitter: DataSplitter, + labeller: "Labeller", + device: torch.device, ) -> Self: splits = splitter(ds) dm = cls( @@ -432,7 +438,7 @@ def from_ds( test=splits.test, split_seed=getattr(splitter, "seed", None), ) - deployment_ids = labeller.run(dm=dm) + deployment_ids = labeller.run(dm=dm, device=device) dm.deployment_ids = deployment_ids return dm diff --git a/src/labelling/noise.py b/src/labelling/noise.py index 4660b23c..5d8ce187 100644 --- a/src/labelling/noise.py +++ b/src/labelling/noise.py @@ -92,8 +92,12 @@ def centroidal_label_noise( num[row_inds, inv[indices]] = 0.0 denom = num.sum(dim=1, keepdim=True) probs = num / denom + # random sampling is better done on the CPU + probs = probs.to(torch.device("cpu")) new_labels = torch.multinomial(probs, num_samples=1, replacement=False, generator=generator) del probs + labels = labels.to(torch.device("cpu")) + indices = 
indices.to(torch.device("cpu")) if not inplace: labels = labels.clone() labels[indices] = new_labels.squeeze(1) diff --git a/src/labelling/pipeline.py b/src/labelling/pipeline.py index b449ac41..bd87d0b0 100644 --- a/src/labelling/pipeline.py +++ b/src/labelling/pipeline.py @@ -13,7 +13,7 @@ import wandb from wandb.wandb_run import Run -from src.data import DataModule, resolve_device +from src.data import DataModule from src.evaluation.metrics import print_metrics from src.models import Classifier, OptimizerCfg from src.utils import to_item @@ -41,21 +41,22 @@ class Labeller(ABC): @abstractmethod - def run(self, dm: DataModule) -> Optional[Tensor]: + def run(self, dm: DataModule, device: torch.device) -> Optional[Tensor]: raise NotImplementedError() - def __call__(self, dm: DataModule) -> Optional[Tensor]: - return self.run(dm=dm) + def __call__(self, dm: DataModule, device: torch.device) -> Optional[Tensor]: + return self.run(dm=dm, device=device) @dataclass(repr=False, eq=False) class KmeansOnClipEncodings(DcModule, Labeller): + """Generate embeddings with CLIP (optionally finetuned) and then do k-means.""" + clip_version: ClipVersion = ClipVersion.RN50 download_root: Optional[str] = None ft: FineTuneParams = field(default_factory=lambda: FineTuneParams(steps=1000)) enc_batch_size: int = 64 - gpu: int = 0 spherical: bool = True fft_cluster_init: bool = False supervised_cluster_init: bool = False @@ -63,38 +64,31 @@ class KmeansOnClipEncodings(DcModule, Labeller): save_as_artifact: bool = True artifact_name: Optional[str] = None - cache_encoder: bool = False + # cache_encoder: bool = False encodings_path: Optional[Path] = None - encoder: Optional[ClipVisualEncoder] = field( - init=False, default=None, metadata={"omegaconf_ignore": True} - ) + # encoder: Optional[ClipVisualEncoder] = field( + # init=False, default=None, metadata={"omegaconf_ignore": True} + # ) _fitted_kmeans: Optional[KMeans] = field( init=False, default=None, metadata={"omegaconf_ignore": True} ) 
@override - def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor: - device = resolve_device(self.gpu) + def run(self, dm: DataModule, device: torch.device) -> Tensor: if self.encodings_path is not None and self.encodings_path.exists(): encodings = Encodings.from_npz(self.encodings_path) else: - if self.encoder is None or not use_cached_encoder: - encoder = ClipVisualEncoder( - version=self.clip_version, download_root=self.download_root - ) - if self.ft.steps > 0: - encoder.finetune(dm=dm, params=self.ft, device=device) - else: - encoder = self.encoder + encoder = ClipVisualEncoder(version=self.clip_version, download_root=self.download_root) + if self.ft.steps > 0: + encoder.finetune(dm=dm, params=self.ft, device=device) encodings = encoder.encode(dm=dm, batch_size_tr=self.enc_batch_size, device=device) if self.encodings_path is not None: encodings.save(self.encodings_path) - if self.cache_encoder: - self.encoder = encoder - else: - del encoder - torch.cuda.empty_cache() + # if self.cache_encoder: + # self.encoder = encoder + del encoder + torch.cuda.empty_cache() kmeans = KMeans( spherical=self.spherical, @@ -115,16 +109,17 @@ def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor: @dataclass(eq=False) class ClipClassifier(Labeller): + """Predict s and y with a fine-tuned CLIP classifier.""" + clip_version: ClipVersion = ClipVersion.RN50 download_root: Optional[str] = None ft: FineTuneParams = field(default_factory=lambda: FineTuneParams(steps=1000)) batch_size_te: int = 64 - gpu: int = 0 save_as_artifact: bool = True artifact_name: Optional[str] = None - cache_encoder: bool = False + # cache_encoder: bool = False @torch.no_grad() def evaluate( @@ -147,8 +142,7 @@ def evaluate( return metrics @override - def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor: - device = resolve_device(self.gpu) + def run(self, dm: DataModule, device: torch.device) -> Tensor: encoder = 
ClipVisualEncoder(version=self.clip_version, download_root=self.download_root) ft_model = encoder.finetune(dm=dm, params=self.ft, device=device) classifier = Classifier(model=ft_model, opt=OptimizerCfg()) @@ -172,12 +166,14 @@ def run(self, dm: DataModule, *, use_cached_encoder: bool = False) -> Tensor: @dataclass(eq=False) class LabelFromArtifact(Labeller): + """Load labels from W&B.""" + version: Optional[int] = None # latest by default artifact_name: Optional[str] = None root: Optional[Path] = None # artifacts/clustering by default @override - def run(self, dm: DataModule) -> Tensor: + def run(self, dm: DataModule, device: torch.device) -> Tensor: return load_labels_from_artifact( run=wandb.run, datamodule=dm, @@ -189,13 +185,17 @@ def run(self, dm: DataModule) -> Tensor: @dataclass(eq=False) class NullLabeller(Labeller): + """Don't do any bag balancing.""" + @override - def run(self, dm: DataModule) -> None: + def run(self, dm: DataModule, device: torch.device) -> None: return None @dataclass(eq=False) class GroundTruthLabeller(Labeller): + """Use ground truth for bag balancing.""" + seed: int = 47 @property @@ -203,12 +203,14 @@ def generator(self) -> torch.Generator: return torch.Generator().manual_seed(self.seed) @override - def run(self, dm: DataModule) -> Tensor: + def run(self, dm: DataModule, device: torch.device) -> Tensor: return dm.group_ids_dep @dataclass(eq=False) class LabelNoiser(Labeller): + """Base class for methods which take the ground truth labels to add noise to them.""" + level: float = 0.10 seed: int = 47 weighted_index_sampling: bool = True @@ -222,11 +224,13 @@ def generator(self) -> torch.Generator: return torch.Generator().manual_seed(self.seed) @abstractmethod - def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tensor: + def _noise( + self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule, device: torch.device + ) -> Tensor: raise NotImplementedError() @override - def run(self, dm: DataModule) -> Tensor: 
+ def run(self, dm: DataModule, device: torch.device) -> Tensor: group_ids = dm.group_ids_dep logger.info( f"Injecting noise into ground-truth labels with noise level '{self.level}'" @@ -240,14 +244,18 @@ def run(self, dm: DataModule) -> Tensor: weighted=self.weighted_index_sampling, ) # Inject label-noise into the group identifiers. - group_ids = self._noise(dep_ids=group_ids, flip_inds=flip_inds, dm=dm) + group_ids = self._noise(dep_ids=group_ids, flip_inds=flip_inds, dm=dm, device=device) return group_ids @dataclass(eq=False) class UniformLabelNoiser(LabelNoiser): + """Take the ground truth labels and flip them uniformly randomly.""" + @override - def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tensor: + def _noise( + self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule, device: torch.device + ) -> Tensor: return uniform_label_noise( labels=dep_ids, indices=flip_inds, generator=self.generator, inplace=True ) @@ -255,15 +263,17 @@ def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tenso @dataclass(eq=False) class CentroidalLabelNoiser(LabelNoiser): + """Get embeddings from (non-fine-tuned) CLIP and then flip to nearest centroid.""" + metric: ClnMetric = ClnMetric.COSINE clip_version: ClipVersion = ClipVersion.RN50 download_root: Optional[str] = None enc_batch_size: int = 64 - gpu: int = 0 @override - def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tensor: - device = resolve_device(self.gpu) + def _noise( + self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule, device: torch.device + ) -> Tensor: encoder = ClipVisualEncoder(version=self.clip_version, download_root=self.download_root) encodings, _ = encode_with_group_ids( model=encoder, @@ -271,9 +281,9 @@ def _noise(self, dep_ids: Tensor, *, flip_inds: Tensor, dm: DataModule) -> Tenso device=device, ) return centroidal_label_noise( - labels=dep_ids, - indices=flip_inds, - encodings=encodings, + labels=dep_ids.to(device), + 
indices=flip_inds.to(device), + encodings=encodings.to(device), generator=self.generator, inplace=True, ) diff --git a/src/logging.py b/src/logging.py index 82df40fa..c7a86a2f 100644 --- a/src/logging.py +++ b/src/logging.py @@ -50,9 +50,7 @@ def init( ) -> Run: if self.tags is None: self.tags = [] - self.tags.extend( - cfg_obj.__class__.__name__ for cfg_obj in cfgs_for_group - ) + self.tags.extend(cfg_obj.__class__.__name__ for cfg_obj in cfgs_for_group) if with_tag is not None: self.tags.append(with_tag) # TODO: not sure whether `reinit` really should be hardcoded diff --git a/src/relay/base.py b/src/relay/base.py index 2e65402f..de2f2ec5 100644 --- a/src/relay/base.py +++ b/src/relay/base.py @@ -29,13 +29,18 @@ class BaseRelay: } def init_dm( - self, ds: CdtVisionDataset[TernarySample, Tensor, Tensor], labeller: Labeller + self, + ds: CdtVisionDataset[TernarySample, Tensor, Tensor], + labeller: Labeller, + device: torch.device, ) -> DataModule: assert isinstance(self.split, DataSplitter) logger.info(f"Current working directory: '{os.getcwd()}'") random_seed(self.seed, use_cuda=True) torch.multiprocessing.set_sharing_strategy("file_system") - dm = DataModule.from_ds(config=self.dm, ds=ds, splitter=self.split, labeller=labeller) + dm = DataModule.from_ds( + config=self.dm, ds=ds, splitter=self.split, labeller=labeller, device=device + ) logger.info(str(dm)) return dm diff --git a/src/relay/supmatch.py b/src/relay/supmatch.py index 9ff7c555..2e15d592 100644 --- a/src/relay/supmatch.py +++ b/src/relay/supmatch.py @@ -101,7 +101,7 @@ def run(self, raw_config: Optional[dict[str, Any]] = None) -> Optional[float]: ds = self.ds() run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch)) - dm = self.init_dm(ds, self.labeller) + dm = self.init_dm(ds, self.labeller, self.alg.device) ae_pair = self.ae_arch(input_shape=dm.dim_x) ae = SplitLatentAe(opt=self.ae, model=ae_pair, feature_group_slices=dm.feature_group_slices) logger.info(f"Encoding dim: 
{ae.latent_dim}, {ae.encoding_size}")