Commit
Add training code
naitian committed Nov 16, 2023
1 parent 2d0d9c8 commit c82442c
Showing 4 changed files with 536 additions and 0 deletions.
89 changes: 89 additions & 0 deletions memes/representations/prepare_clip_dataset.py
@@ -0,0 +1,89 @@
"""Prepares huggingface dataset."""

from ast import literal_eval
from collections import defaultdict
from pathlib import Path
import random

from datasets import load_dataset, DatasetDict
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm.auto import tqdm

from memes.utils import DATA_DIR


dataset = load_dataset(
"csv", data_files=str(DATA_DIR / "representations/all-8-processed-filtered.csv")
)
dataset = dataset["train"]


def get_template_group(dataset):
templates = dataset["meme_template"]
img_paths = dataset["img_path"]
template_group = defaultdict(set)
for i, tpl in enumerate(tqdm(templates)):
if not Path(img_paths[i]).exists():
continue
template_group[tpl].add(img_paths[i])
return template_group

template_group = get_template_group(dataset)
keeplist = {k for k, v in template_group.items() if len(v) >= 100}

def split_ocr(row):
    # ocr_output is stored as a stringified Python list, so parse it with
    # literal_eval rather than eval.
    ocr = literal_eval(row["ocr_output"])
# I use <|endoftext|> here because it looks like that token is used for all
# sorts of purposes in CLIP including unk and pad (see
# https://huggingface.co/transformers/model_doc/clip.html#transformers.CLIPTokenizer)
row["text"] = "<|endoftext|>".join([b[1] for b in ocr])
return row
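# Illustrative example (hypothetical values, not taken from the dataset):
#   row["ocr_output"] = "[([0, 0, 10, 10], 'TOP TEXT', 0.99), ([0, 20, 10, 30], 'BOTTOM TEXT', 0.98)]"
#   split_ocr(row)["text"] == "TOP TEXT<|endoftext|>BOTTOM TEXT"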




def filter_corrupt_images(examples):
"""remove problematic images"""
valid_images = []
for image_file in examples["resampled_img_path"]:
try:
# assert Path(image_file).exists()
Image.open(image_file)
valid_images.append(True)
except Exception:
valid_images.append(False)
return valid_images

def resample_image_path(examples):
output = []
for tpl in examples["meme_template"]:
output.append(random.choice(tuple(template_group[tpl])))
return {"resampled_img_path": output}
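# Note: each caption is paired with an image drawn uniformly at random from its
# template's image set, so the pairing is template-level rather than tied to the
# specific instance the caption came from.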

print("Filtering low count templates")
dataset = dataset.filter(lambda x: x["meme_template"] in keeplist)
print("Resampling image paths")
dataset = dataset.map(resample_image_path, batched=True, num_proc=8)
# Skip filtering bc I'm pretty sure we would've caught any corrupt images much
# earlier in the pipeline (e.g. during phashing)
# print("Filtering corrupt images")
# dataset = dataset.filter(
# filter_corrupt_images, batched=True, num_proc=16
# )
print("Splitting OCR")
dataset = dataset.map(split_ocr, num_proc=8)

print("Creating train val splits")
unique_text = dataset.unique("text")
train, val = train_test_split(unique_text, test_size=0.1, random_state=0xB1AB)
train = set(train)
val = set(val)

train_subset = dataset.filter(lambda x: x["text"] in train)
val_subset = dataset.filter(lambda x: x["text"] in val)

splits = DatasetDict({"train": train_subset, "test": val_subset})

# Note: this encodes `dataset`, not the `splits` saved below; it appears to be
# left over from the stratified-split alternative commented out underneath.
dataset = dataset.class_encode_column("meme_template")

# splits = dataset.train_test_split(test_size=0.05, stratify_by_column="meme_template")
splits.save_to_disk(DATA_DIR / "representations/all-8-processed-clip-final")
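
# For reference (not executed here), a minimal sketch of reloading the saved
# splits elsewhere, assuming the save path above:
#
#   from datasets import load_from_disk
#   from memes.utils import DATA_DIR
#   splits = load_from_disk(DATA_DIR / "representations/all-8-processed-clip-final")
#   print(splits["train"][0]["text"])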
39 changes: 39 additions & 0 deletions memes/representations/prepare_dataset.py
@@ -0,0 +1,39 @@
"""Prepares huggingface dataset."""

from ast import literal_eval
from collections import Counter
from datasets import load_dataset, DatasetDict
from sklearn.model_selection import train_test_split

from memes.utils import DATA_DIR



dataset = load_dataset("csv", data_files=str(DATA_DIR / "representations/all-8-processed-filtered.csv"))
dataset = dataset["train"]

counts = Counter(dataset["meme_template"])
keeplist = {k for k, v in counts.items() if v >= 100}

def split_ocr(row):
    # ocr_output is stored as a stringified Python list, so parse it with literal_eval.
    ocr = literal_eval(row["ocr_output"])
    row["text"] = "</s></s>".join([b[1] for b in ocr])
return row
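# "</s></s>" is presumably chosen to match the </s> separator/eos token of a
# RoBERTa-style tokenizer; a different text encoder would need a different joiner.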

dataset = dataset.filter(lambda x: x["meme_template"] in keeplist)
dataset = dataset.map(split_ocr, num_proc=8)

dataset = dataset.class_encode_column("meme_template")

print("Creating train val splits")
unique_text = dataset.unique("text")
train, val = train_test_split(unique_text, test_size=0.1, random_state=0xB1AB)
train = set(train)
val = set(val)

train_subset = dataset.filter(lambda x: x["text"] in train)
val_subset = dataset.filter(lambda x: x["text"] in val)

splits = DatasetDict({"train": train_subset, "test": val_subset})

# splits = dataset.train_test_split(test_size=0.05, stratify_by_column="meme_template")
splits.save_to_disk(DATA_DIR / "representations/all-8-processed-dataset-final")
209 changes: 209 additions & 0 deletions memes/representations/train_clip.py
@@ -0,0 +1,209 @@
from accelerate import Accelerator
import argparse
from datetime import datetime
import multiprocessing
import os

import evaluate
from datasets import load_from_disk, Image
from transformers import (
    AutoConfig,
    CLIPImageProcessor,
    CLIPModel,
    CLIPTokenizer,
    get_scheduler,
)
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW


from tqdm.auto import tqdm
import wandb


from memes.utils import DATA_DIR, assert_dir, construct_output_filename



def main(args):
TRAIN_BATCH_SIZE = 512
EVAL_BATCH_SIZE = 512
EVAL_STEPS = 5 # 1000 if not args.debug else 10
LOG_STEPS = 1 # 500 if not args.debug else 10
BASE_LR = 4e-6
CLASSIFIER_LR = 4e-6 # following the fine-tuning blog post
NUM_EPOCHS = 3
FROM_SCRATCH = False
MIXED_PRECISION = True

precision = "bf16" if MIXED_PRECISION else "no"
accelerator = Accelerator(log_with="wandb", mixed_precision=precision)

time = f"-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
debug = "-debug" if args.debug else ""
scratch = "-scratch" if FROM_SCRATCH else "-pretrain"
mixed = "-bf16" if MIXED_PRECISION else ""
run_name = f"clip{scratch}{mixed}{time}{debug}"
accelerator.init_trackers("latent-memes", init_kwargs={"wandb": {"name":run_name}})

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

def tokenize(batch):
# set max length to 77 to match clip-vit-base-patch32 defaults
# should not be a big deal bc we have pretty short strings
return tokenizer(batch["text"], truncation=True, max_length=77, padding="max_length")


def process_image(batch):
image_paths = batch["resampled_img_path"]
processed = image_processor(images=image_paths, return_tensors="pt", padding=True)
batch["pixel_values"] = processed.pixel_values
return batch

dataset = load_from_disk(DATA_DIR / "representations/all-8-processed-clip").cast_column("resampled_img_path", Image())
dataset = dataset.remove_columns(['post_id', 'subreddit', 'meme_hash', 'ocr_output', 'img_path', 'meme_template'])
# num_labels = dataset["train"].features["labels"].num_classes

if args.debug:
train_dataset = dataset["train"].select(range(1000)).map(tokenize, batched=True, num_proc=multiprocessing.cpu_count())
eval_dataset = dataset["test"].select(range(100)).map(tokenize, batched=True, num_proc=multiprocessing.cpu_count())
else:
train_dataset = dataset["train"].map(
tokenize,
batched=True,
num_proc=multiprocessing.cpu_count()
)
eval_dataset = dataset["test"].map(
tokenize,
batched=True,
num_proc=multiprocessing.cpu_count()
).select(range(1000)) # we just take 1000 because we're running into OOM issues

def collator(batch):
pixel_values = torch.stack([d["pixel_values"] for d in batch])
input_ids = torch.tensor([d["input_ids"] for d in batch], dtype=torch.long)
attention_mask = torch.tensor([d["attention_mask"] for d in batch], dtype=torch.long)
return {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"return_loss": True,
}

train_dataset.set_format("torch")
train_dataset = train_dataset.remove_columns("text")
eval_dataset.set_format("torch")
eval_dataset = eval_dataset.remove_columns("text")

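    # set_transform applies process_image lazily on access, so pixel values are
    # computed per batch at load time rather than being precomputed and cached.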
train_dataset.set_transform(process_image, output_all_columns=True)
eval_dataset.set_transform(process_image, output_all_columns=True)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=TRAIN_BATCH_SIZE, collate_fn=collator)
eval_dataloader = DataLoader(eval_dataset, shuffle=True, batch_size=EVAL_BATCH_SIZE, collate_fn=collator)

if not FROM_SCRATCH:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
else:
raise NotImplementedError()
# config = AutoConfig.from_pretrained("openai/clip-vit-base-patch32")
# model = CLIPModel.from_config(config)
# model.to(device)

optimizer = AdamW(model.parameters(), lr=BASE_LR)

num_epochs = NUM_EPOCHS
# num_training_steps = num_epochs * len(train_dataloader)
# lr_scheduler = get_scheduler(
# name="constant",
# optimizer=optimizer,
# num_warmup_steps=0,
# num_training_steps=num_training_steps
# )

metric = evaluate.load("accuracy")


output_dir = DATA_DIR / "representations/clip/checkpoints" / run_name
assert_dir(output_dir)

accelerator.print(output_dir)

    # prepare the eval dataloader as well so its batches are placed on the right device
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )
# accelerator.register_for_checkpointing(lr_scheduler)

@accelerator.on_main_process
def _save_model(model, epoch, step, best=False):
suffix = "best" if best else f"checkpoint-{epoch}_{step}"
        # unwrap to the underlying CLIPModel whether or not accelerate wrapped it (e.g. DDP)
        accelerator.unwrap_model(model).save_pretrained(
            output_dir / suffix,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )

num_training_steps = num_epochs * len(train_dataloader)
progress_bar = tqdm(range(num_training_steps))
model.train()
step = 0
best_val_loss = float("inf")
for epoch in range(num_epochs):
for batch in train_dataloader:
# batch = {k: (v.to(device) if k != "return_loss" else v) for k, v in batch.items()}
# import pdb; pdb.set_trace()
outputs = model(**batch)
loss = outputs.loss
# print(loss)
# loss.backward()
accelerator.backward(loss)
optimizer.step()
# lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)

# TODO: modify the accuracy metric for contrastive context!
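            # In this in-batch contrastive setup, logits_per_image has shape
            # (batch_size, batch_size) and the matching caption for image i is at
            # index i, so accuracy below uses 0..batch_size-1 as the references.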

if step % EVAL_STEPS == 0 and step > 0:
model.eval()
for batch in eval_dataloader:
# batch = {k: (v.to(device) if k != "return_loss" else v) for k, v in batch.items()}
with torch.no_grad():
outputs = model(**batch)
logits = outputs.logits_per_image
predictions = torch.argmax(logits, dim=-1)
metric.add_batch(predictions=predictions, references=torch.tensor(range(predictions.shape[0])))
                accuracy = metric.compute()
                # track the best checkpoint by the loss of the final eval batch
                if outputs.loss < best_val_loss:
                    best_val_loss = outputs.loss.item()
                    _save_model(model, epoch, step, best=True)
accelerator.log({
"eval/epoch": epoch,
"eval/loss": outputs.loss,
"eval/accuracy": accuracy
})

model.train()

            if step % LOG_STEPS == 0:
                # in-batch accuracy from the most recent forward pass (the train
                # batch, or the last eval batch if an eval just ran at this step)
                logits = outputs.logits_per_image
                predictions = torch.argmax(logits, dim=-1).detach()
                accuracy = metric.compute(predictions=predictions, references=torch.tensor(range(predictions.shape[0])))
# base_lr = lr_scheduler.get_last_lr()
base_lr = BASE_LR
accelerator.log({
"train/epoch": epoch,
"train/loss": loss,
"train/accuracy": accuracy,
"train/lr": base_lr,
})

step += 1
accelerator.print("Saving checkpoint")
_save_model(model, epoch, step)
accelerator.end_training()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--debug", action="store_true", default=False)
main(parser.parse_args())
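    # Typical invocation (assuming the repo root is importable, e.g. installed or on PYTHONPATH):
    #   python -m memes.representations.train_clip --debug
    # or the equivalent `accelerate launch` command for multi-GPU runs.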