From c82442c4417813ab0a527a99065a8f90360823ec Mon Sep 17 00:00:00 2001
From: Naitian Zhou
Date: Thu, 16 Nov 2023 02:18:46 -0500
Subject: [PATCH] Add training code

---
 memes/representations/prepare_clip_dataset.py |  89 ++++++++
 memes/representations/prepare_dataset.py      |  39 ++++
 memes/representations/train_clip.py           | 209 ++++++++++++++++++
 memes/representations/train_roberta.py        | 199 +++++++++++++++++
 4 files changed, 536 insertions(+)
 create mode 100644 memes/representations/prepare_clip_dataset.py
 create mode 100644 memes/representations/prepare_dataset.py
 create mode 100644 memes/representations/train_clip.py
 create mode 100644 memes/representations/train_roberta.py

diff --git a/memes/representations/prepare_clip_dataset.py b/memes/representations/prepare_clip_dataset.py
new file mode 100644
index 0000000..1559a9d
--- /dev/null
+++ b/memes/representations/prepare_clip_dataset.py
@@ -0,0 +1,89 @@
+"""Prepares huggingface dataset."""
+
+from collections import Counter, defaultdict
+from pathlib import Path
+import random
+from datasets import load_dataset, DatasetDict
+from sklearn.model_selection import train_test_split
+from PIL import Image
+from tqdm.auto import tqdm
+
+from memes.utils import DATA_DIR
+
+
+dataset = load_dataset(
+    "csv", data_files=str(DATA_DIR / "representations/all-8-processed-filtered.csv")
+)
+dataset = dataset["train"]
+
+
+def get_template_group(dataset):
+    templates = dataset["meme_template"]
+    img_paths = dataset["img_path"]
+    template_group = defaultdict(set)
+    for i, tpl in enumerate(tqdm(templates)):
+        if not Path(img_paths[i]).exists():
+            continue
+        template_group[tpl].add(img_paths[i])
+    return template_group
+
+template_group = get_template_group(dataset)
+keeplist = {k for k, v in template_group.items() if len(v) >= 100}
+
+def split_ocr(row):
+    ocr = eval(row["ocr_output"])
+    # I use <|endoftext|> here because it looks like that token is used for all
+    # sorts of purposes in CLIP including unk and pad (see
+    # https://huggingface.co/transformers/model_doc/clip.html#transformers.CLIPTokenizer)
+    row["text"] = "<|endoftext|>".join([b[1] for b in ocr])
+    return row
+
+
+
+
+def filter_corrupt_images(examples):
+    """remove problematic images"""
+    valid_images = []
+    for image_file in examples["resampled_img_path"]:
+        try:
+            # assert Path(image_file).exists()
+            Image.open(image_file)
+            valid_images.append(True)
+        except Exception:
+            valid_images.append(False)
+    return valid_images
+
+def resample_image_path(examples):
+    output = []
+    for tpl in examples["meme_template"]:
+        output.append(random.choice(tuple(template_group[tpl])))
+    return {"resampled_img_path": output}
+
+print("Filtering low count templates")
+dataset = dataset.filter(lambda x: x["meme_template"] in keeplist)
+print("Resampling image paths")
+dataset = dataset.map(resample_image_path, batched=True, num_proc=8)
+# Skip filtering bc I'm pretty sure we would've caught any corrupt images much
+# earlier in the pipeline (e.g. during phashing)
+# print("Filtering corrupt images")
+# dataset = dataset.filter(
+#     filter_corrupt_images, batched=True, num_proc=16
+# )
+print("Splitting OCR")
+dataset = dataset.map(split_ocr, num_proc=8)
+
+print("Creating train val splits")
+unique_text = dataset.unique("text")
+train, val = train_test_split(unique_text, test_size=0.1, random_state=0xB1AB)
+train = set(train)
+val = set(val)
+
+train_subset = dataset.filter(lambda x: x["text"] in train)
+val_subset = dataset.filter(lambda x: x["text"] in val)
+
+splits = DatasetDict({"train": train_subset, "test": val_subset})
+
+dataset = dataset.class_encode_column("meme_template")
+
+# splits = dataset.train_test_split(test_size=0.05, stratify_by_column="meme_template")
+splits.save_to_disk(DATA_DIR / "representations/all-8-processed-clip-final")
diff --git a/memes/representations/prepare_dataset.py b/memes/representations/prepare_dataset.py
new file mode 100644
index 0000000..d828841
--- /dev/null
+++ b/memes/representations/prepare_dataset.py
@@ -0,0 +1,39 @@
+"""Prepares huggingface dataset."""
+
+from collections import Counter
+from datasets import load_dataset, DatasetDict
+from sklearn.model_selection import train_test_split
+
+from memes.utils import DATA_DIR
+
+
+
+dataset = load_dataset("csv", data_files=str(DATA_DIR / "representations/all-8-processed-filtered.csv"))
+dataset = dataset["train"]
+
+counts = Counter(dataset["meme_template"])
+keeplist = {k for k, v in counts.items() if v >= 100}
+
+def split_ocr(row):
+    ocr = eval(row["ocr_output"])
+    row["text"] = "".join([b[1] for b in ocr])
+    return row
+
+dataset = dataset.filter(lambda x: x["meme_template"] in keeplist)
+dataset = dataset.map(split_ocr, num_proc=8)
+
+dataset = dataset.class_encode_column("meme_template")
+
+print("Creating train val splits")
+unique_text = dataset.unique("text")
+train, val = train_test_split(unique_text, test_size=0.1, random_state=0xB1AB)
+train = set(train)
+val = set(val)
+
+train_subset = dataset.filter(lambda x: x["text"] in train)
+val_subset = dataset.filter(lambda x: x["text"] in val)
+
+splits = DatasetDict({"train": train_subset, "test": val_subset})
+
+# splits = dataset.train_test_split(test_size=0.05, stratify_by_column="meme_template")
+splits.save_to_disk(DATA_DIR / "representations/all-8-processed-dataset-final")
diff --git a/memes/representations/train_clip.py b/memes/representations/train_clip.py
new file mode 100644
index 0000000..ea95656
--- /dev/null
+++ b/memes/representations/train_clip.py
@@ -0,0 +1,209 @@
+from accelerate import Accelerator
+import argparse
+from datetime import datetime
+import multiprocessing
+import os
+
+import evaluate
+from datasets import load_from_disk, Image
+from transformers import (CLIPImageProcessor, CLIPTokenizer, CLIPModel,
+                          DataCollatorWithPadding, AutoConfig, get_scheduler)
+import torch
+from torch.utils.data import DataLoader
+from torch.optim import AdamW
+
+from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
+from torchvision.transforms.functional import InterpolationMode
+
+from tqdm.auto import tqdm
+import wandb
+
+
+from memes.utils import DATA_DIR, assert_dir, construct_output_filename
+
+
+
+def main(args):
+    TRAIN_BATCH_SIZE = 512
+    EVAL_BATCH_SIZE = 512
+    EVAL_STEPS = 5  # 1000 if not args.debug else 10
+    LOG_STEPS = 1  # 500 if not args.debug else 10
+    BASE_LR = 4e-6
+    CLASSIFIER_LR = 4e-6  # following the fine-tuning blog post
+    NUM_EPOCHS = 3
+    FROM_SCRATCH = False
+    MIXED_PRECISION = True
+
+    precision = "bf16" if MIXED_PRECISION else "no"
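+    # Accelerate handles device placement and (when enabled above) bf16 mixed-precision autocast.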
+    accelerator = Accelerator(log_with="wandb", mixed_precision=precision)
+
+    time = f"-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
+    debug = "-debug" if args.debug else ""
+    scratch = "-scratch" if FROM_SCRATCH else "-pretrain"
+    mixed = "-bf16" if MIXED_PRECISION else ""
+    run_name = f"clip{scratch}{mixed}{time}{debug}"
+    accelerator.init_trackers("latent-memes", init_kwargs={"wandb": {"name": run_name}})
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+    def tokenize(batch):
+        # set max length to 77 to match clip-vit-base-patch32 defaults
+        # should not be a big deal bc we have pretty short strings
+        return tokenizer(batch["text"], truncation=True, max_length=77, padding="max_length")
+
+
+    def process_image(batch):
+        image_paths = batch["resampled_img_path"]
+        processed = image_processor(images=image_paths, return_tensors="pt", padding=True)
+        batch["pixel_values"] = processed.pixel_values
+        return batch
+
+    dataset = load_from_disk(DATA_DIR / "representations/all-8-processed-clip").cast_column("resampled_img_path", Image())
+    dataset = dataset.remove_columns(['post_id', 'subreddit', 'meme_hash', 'ocr_output', 'img_path', 'meme_template'])
+    # num_labels = dataset["train"].features["labels"].num_classes
+
+    if args.debug:
+        train_dataset = dataset["train"].select(range(1000)).map(tokenize, batched=True, num_proc=multiprocessing.cpu_count())
+        eval_dataset = dataset["test"].select(range(100)).map(tokenize, batched=True, num_proc=multiprocessing.cpu_count())
+    else:
+        train_dataset = dataset["train"].map(
+            tokenize,
+            batched=True,
+            num_proc=multiprocessing.cpu_count()
+        )
+        eval_dataset = dataset["test"].map(
+            tokenize,
+            batched=True,
+            num_proc=multiprocessing.cpu_count()
+        ).select(range(1000))  # we just take 1000 because we're running into OOM issues
+
+    def collator(batch):
+        pixel_values = torch.stack([d["pixel_values"] for d in batch])
+        input_ids = torch.tensor([d["input_ids"] for d in batch], dtype=torch.long)
+        attention_mask = torch.tensor([d["attention_mask"] for d in batch], dtype=torch.long)
+        return {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "return_loss": True,
+        }
+
+    train_dataset.set_format("torch")
+    train_dataset = train_dataset.remove_columns("text")
+    eval_dataset.set_format("torch")
+    eval_dataset = eval_dataset.remove_columns("text")
+
+    train_dataset.set_transform(process_image, output_all_columns=True)
+    eval_dataset.set_transform(process_image, output_all_columns=True)
+
+    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=TRAIN_BATCH_SIZE, collate_fn=collator)
+    eval_dataloader = DataLoader(eval_dataset, shuffle=True, batch_size=EVAL_BATCH_SIZE, collate_fn=collator)
+
+    if not FROM_SCRATCH:
+        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+    else:
+        raise NotImplementedError()
+        # config = AutoConfig.from_pretrained("openai/clip-vit-base-patch32")
+        # model = CLIPModel.from_config(config)
+    # model.to(device)
+
+    optimizer = AdamW(model.parameters(), lr=BASE_LR)
+
+    num_epochs = NUM_EPOCHS
+    # num_training_steps = num_epochs * len(train_dataloader)
+    # lr_scheduler = get_scheduler(
+    #     name="constant",
+    #     optimizer=optimizer,
+    #     num_warmup_steps=0,
+    #     num_training_steps=num_training_steps
+    # )
+
+    metric = evaluate.load("accuracy")
+
+
+    output_dir = DATA_DIR / "representations/clip/checkpoints" / run_name
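+    # assert_dir comes from memes.utils; presumably it ensures the checkpoint directory exists.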
+    assert_dir(output_dir)
+
+    accelerator.print(output_dir)
+
+    model, optimizer, train_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader
+    )
+    # accelerator.register_for_checkpointing(lr_scheduler)
+
+    @accelerator.on_main_process
+    def _save_model(model, epoch, step, best=False):
+        suffix = "best" if best else f"checkpoint-{epoch}_{step}"
+        model.module.save_pretrained(
+            output_dir / suffix,
+            is_main_process=accelerator.is_main_process,
+            save_function=accelerator.save,
+        )
+
+    num_training_steps = num_epochs * len(train_dataloader)
+    progress_bar = tqdm(range(num_training_steps))
+    model.train()
+    step = 0
+    best_val_loss = float("inf")
+    for epoch in range(num_epochs):
+        for batch in train_dataloader:
+            # batch = {k: (v.to(device) if k != "return_loss" else v) for k, v in batch.items()}
+            # import pdb; pdb.set_trace()
+            outputs = model(**batch)
+            loss = outputs.loss
+            # print(loss)
+            # loss.backward()
+            accelerator.backward(loss)
+            optimizer.step()
+            # lr_scheduler.step()
+            optimizer.zero_grad()
+            progress_bar.update(1)
+
+            # TODO: modify the accuracy metric for contrastive context!
+
+            if step % EVAL_STEPS == 0 and step > 0:
+                model.eval()
+                for batch in eval_dataloader:
+                    # batch = {k: (v.to(device) if k != "return_loss" else v) for k, v in batch.items()}
+                    with torch.no_grad():
+                        outputs = model(**batch)
+                    logits = outputs.logits_per_image
+                    predictions = torch.argmax(logits, dim=-1)
+                    metric.add_batch(predictions=predictions, references=torch.tensor(range(predictions.shape[0])))
+                accuracy = metric.compute()
+                if outputs.loss < best_val_loss:
+                    _save_model(model, epoch, step, best=True)
+                accelerator.log({
+                    "eval/epoch": epoch,
+                    "eval/loss": outputs.loss,
+                    "eval/accuracy": accuracy
+                })
+
+                model.train()
+
+            if step % LOG_STEPS == 0:
+                logits = outputs.logits_per_image
+                predictions = torch.argmax(logits, dim=-1).detach()
+                accuracy = metric.compute(predictions=predictions, references=torch.tensor(range(predictions.shape[0])))
+                # base_lr = lr_scheduler.get_last_lr()
+                base_lr = BASE_LR
+                accelerator.log({
+                    "train/epoch": epoch,
+                    "train/loss": loss,
+                    "train/accuracy": accuracy,
+                    "train/lr": base_lr,
+                })
+
+            step += 1
+    accelerator.print("Saving checkpoint")
+    _save_model(model, epoch, step)
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-d", "--debug", action="store_true", default=False)
+    main(parser.parse_args())
diff --git a/memes/representations/train_roberta.py b/memes/representations/train_roberta.py
new file mode 100644
index 0000000..1e8506f
--- /dev/null
+++ b/memes/representations/train_roberta.py
@@ -0,0 +1,199 @@
+import argparse
+from datetime import datetime
+import multiprocessing
+import os
+from accelerate import Accelerator
+
+import evaluate
+from datasets import load_from_disk
+from transformers import (AutoModelForSequenceClassification, AutoTokenizer, AutoConfig,
+                          Trainer, TrainingArguments, get_scheduler, DataCollatorWithPadding, create_optimizer)
+import torch
+from torch.utils.data import DataLoader
+from torch.optim import AdamW
+from tqdm.auto import tqdm
+import wandb
+
+
+from memes.utils import DATA_DIR, assert_dir, construct_output_filename
+
+def main(args):
+    TRAIN_BATCH_SIZE = 32
+    EVAL_BATCH_SIZE = 32
+    EVAL_STEPS = 1000 if not args.debug else 10
+    LOG_STEPS = 500 if not args.debug else 10
+    BASE_LR = 1e-6
+    CLASSIFIER_LR = 1e-5
+    NUM_EPOCHS = 3
+    FROM_SCRATCH = False
+    MIXED_PRECISION = False
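+    # The RoBERTa backbone and the classification head get separate learning rates (BASE_LR vs. CLASSIFIER_LR) in the optimizer below.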
+
+    time = f"-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
+    debug = "-debug" if args.debug else ""
+    scratch = "-scratch" if FROM_SCRATCH else "-pretrain"
+    run_name = f"c-roberta_pt_lowlr{scratch}{time}{debug}"
+
+    precision = "bf16" if MIXED_PRECISION else "no"
+    accelerator = Accelerator(log_with="wandb", mixed_precision=precision)
+
+    accelerator.init_trackers("latent-memes", init_kwargs={"wandb": {"name": run_name}})
+    # wandb.init(
+    #     project="latent-memes",
+    #     name=run_name,
+    # )
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+    # training_args = TrainingArguments(
+    #     output_dir=DATA_DIR / f"representations/classifier/checkpoints/roberta-{time}",
+    #     evaluation_strategy="steps",
+    #     eval_steps=500 if not args.debug else 100,
+    #     logging_steps=100 if not args.debug else 5,
+    #     report_to="wandb",
+    #     run_name=f"c-roberta-{time}{debug}",
+    #     auto_find_batch_size=True,
+    #     save_steps=0.1,
+    #     learning_rate=1e-3,
+    #     lr_scheduler_type="linear",
+    #     warmup_steps=2000,
+    # )
+
+    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+
+    def tokenize(batch):
+        return tokenizer(batch["text"], truncation=True, max_length=128)
+
+    dataset = load_from_disk(DATA_DIR / "representations/all-8-processed-dataset")
+    dataset = dataset.remove_columns(['post_id', 'subreddit', 'meme_hash', 'ocr_output', 'img_path'])
+    dataset = dataset.rename_column("meme_template", "labels")
+    num_labels = dataset["train"].features["labels"].num_classes
+
+    if args.debug:
+        train_dataset = dataset["train"].select(range(1000)).map(tokenize, batched=True, num_proc=multiprocessing.cpu_count())
+        eval_dataset = dataset["test"].select(range(100)).map(tokenize, batched=True, num_proc=multiprocessing.cpu_count())
+    else:
+        train_dataset = dataset["train"].map(
+            tokenize,
+            batched=True,
+            num_proc=multiprocessing.cpu_count()
+        )
+        eval_dataset = dataset["test"].map(
+            tokenize,
+            batched=True,
+            num_proc=multiprocessing.cpu_count()
+        ).select(range(1000))  # we just take 1000 because we're running into OOM issues
+
+
+    collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+    train_dataset.set_format("torch")
+    train_dataset = train_dataset.remove_columns("text")
+    eval_dataset.set_format("torch")
+    eval_dataset = eval_dataset.remove_columns("text")
+
+    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=TRAIN_BATCH_SIZE, collate_fn=collator)
+    eval_dataloader = DataLoader(eval_dataset, shuffle=True, batch_size=EVAL_BATCH_SIZE, collate_fn=collator)
+
+    if not FROM_SCRATCH:
+        model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=num_labels)
+    else:
+        config = AutoConfig.from_pretrained("roberta-base", num_labels=num_labels)
+        model = AutoModelForSequenceClassification.from_config(config)
+    model.to(device)
+
+    optimizer = AdamW([
+        {"params": model.roberta.parameters(), "lr": BASE_LR},
+        {"params": model.classifier.parameters(), "lr": CLASSIFIER_LR}
+    ])
+
+    num_epochs = NUM_EPOCHS
+    num_training_steps = num_epochs * len(train_dataloader)
+    lr_scheduler = get_scheduler(
+        name="constant",
+        optimizer=optimizer,
+        num_warmup_steps=0,
+        num_training_steps=num_training_steps
+    )
+
+    metric = evaluate.load("accuracy")
+
+
+    output_dir = DATA_DIR / "representations/classifier/checkpoints" / run_name
+    assert_dir(output_dir)
+
+    @accelerator.on_main_process
+    def _save_model(model, epoch, step, best=False):
+        suffix = "best" if best else f"checkpoint-{epoch}_{step}"
+        model_path = construct_output_filename(
+            subdir=output_dir,
+            prefix=None,
+            suffix=suffix,
+            ext="pt",
+        )
+        torch.save(model.state_dict(), model_path)
+
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
+    )
+
+    progress_bar = tqdm(range(num_training_steps))
+    model.train()
+    step = 0
+    best_val_loss = float("inf")
+    for epoch in range(num_epochs):
+        for batch in train_dataloader:
+            # batch = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**batch)
+            # import pdb; pdb.set_trace()
+            loss = outputs.loss
+            # print(loss)
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+            progress_bar.update(1)
+
+            if step % EVAL_STEPS == 0 and step > 0:
+                model.eval()
+                for batch in eval_dataloader:
+                    batch = {k: v.to(device) for k, v in batch.items()}
+                    with torch.no_grad():
+                        outputs = model(**batch)
+                    logits = outputs.logits
+                    predictions = torch.argmax(logits, dim=-1)
+                    metric.add_batch(predictions=predictions, references=batch["labels"])
+                accuracy = metric.compute()
+                if outputs.loss < best_val_loss:
+                    _save_model(model, epoch, step, best=True)
+                accelerator.log({
+                    "eval/epoch": epoch,
+                    "eval/loss": outputs.loss,
+                    "eval/accuracy": accuracy
+                })
+
+                model.train()
+
+            if step % LOG_STEPS == 0:
+                logits = outputs.logits
+                predictions = torch.argmax(logits, dim=-1).detach()
+                accuracy = metric.compute(references=batch["labels"], predictions=predictions)
+                base_lr, classifier_lr = lr_scheduler.get_last_lr()
+                accelerator.log({
+                    "train/epoch": epoch,
+                    "train/loss": loss,
+                    "train/accuracy": accuracy,
+                    "train/base_lr": base_lr,
+                    "train/classifier_lr": classifier_lr,
+                })
+
+            step += 1
+    accelerator.print("Saving checkpoint")
+    _save_model(model, epoch, step)
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-d", "--debug", action="store_true", default=False)
+    main(parser.parse_args())
\ No newline at end of file