
Commit

add simplecil and zsclip

12132321313 committed Oct 25, 2024
1 parent 8637cc5 commit 6c0bdb9
Showing 5 changed files with 289 additions and 1 deletion.
15 changes: 15 additions & 0 deletions exps/simplecil.json
@@ -0,0 +1,15 @@
{
"prefix": "reproduce",
"dataset": "cifar224",
"memory_size": 2000,
"memory_per_class": 20,
"fixed_memory": false,
"shuffle": true,
"init_cls": 10,
"increment": 10,
"model_name": "simplecil",
"convnet_type": "clip",
"device": ["0"],
"seed": [1993]
}

15 changes: 15 additions & 0 deletions exps/zs_clip.json
@@ -0,0 +1,15 @@
{
"prefix": "reproduce",
"dataset": "cifar224",
"memory_size": 2000,
"memory_per_class": 20,
"fixed_memory": false,
"shuffle": true,
"init_cls": 10,
"increment": 10,
"model_name": "zs_clip",
"convnet_type": "clip",
"device": ["0"],
"seed": [1993]
}
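
The two configs are identical apart from "model_name"; they differ only in which learner the factory (see utils/factory.py below) instantiates. Assuming "cifar224" denotes CIFAR-100 resized to 224x224 (an assumption based on the naming convention; the commit does not say), init_cls=10 with increment=10 implies ten tasks of ten classes each. A quick sketch:

# Hypothetical illustration of the task split implied by init_cls=10,
# increment=10, assuming cifar224 is CIFAR-100 (100 classes) at 224x224.
init_cls, increment, total_classes = 10, 10, 100
task_sizes = [init_cls] + [increment] * ((total_classes - init_cls) // increment)
print(task_sizes)       # [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
print(len(task_sizes))  # 10 tasks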

88 changes: 88 additions & 0 deletions models/simplecil.py
@@ -0,0 +1,88 @@
import logging
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils.inc_net import SimpleVitNet
from models.base import BaseLearner
from utils.toolkit import get_attribute


num_workers = 8

class Learner(BaseLearner):
def __init__(self, args):
super().__init__(args)
self._network = SimpleVitNet(args, True)
        self.args = args

        self.batch_size = get_attribute(args, "batch_size", 48)
        self.init_lr = get_attribute(args, "init_lr", 0.01)
        self.weight_decay = get_attribute(args, "weight_decay", 0.0005)
        self.min_lr = get_attribute(args, "min_lr", 1e-8)

def after_task(self):
self._known_classes = self._total_classes

    def replace_fc(self, trainloader, model, args):
        """Set each class's fc weight to the mean embedding (prototype) of that class."""
        model = model.eval()
        embedding_list = []
        label_list = []
        with torch.no_grad():
            for i, batch in enumerate(trainloader):
                (_, data, label) = batch
                data = data.to(self._device)
                label = label.to(self._device)
                embedding = model.convnet.encode_image(data)
                embedding_list.append(embedding.cpu())
                label_list.append(label.cpu())
        embedding_list = torch.cat(embedding_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        class_list = np.unique(self.train_dataset.labels)
        for class_index in class_list:
            logging.info("Replacing fc weight for class {}".format(class_index))
            data_index = (label_list == class_index).nonzero().squeeze(-1)
            embedding = embedding_list[data_index]
            proto = embedding.mean(0)
            # The class prototype becomes the classifier weight; no gradient step is needed.
            self._network.fc.weight.data[class_index] = proto

        return model

def incremental_train(self, data_manager):
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
self._network.update_fc(self._total_classes)
logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))

train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train",
mode="train", )
self.train_dataset = train_dataset
self.data_manager = data_manager
self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)
test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),
source="train", mode="test", )
self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=self.batch_size,
shuffle=True, num_workers=num_workers)

if len(self._multiple_gpus) > 1:
print('Multiple GPUs')
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)
if len(self._multiple_gpus) > 1:
self._network = self._network.module

    def _train(self, train_loader, test_loader, train_loader_for_protonet):
        # SimpleCIL performs no gradient-based training: it only refreshes the
        # prototype fc weights from the frozen encoder's features.
        self._network.to(self._device)
        self.replace_fc(train_loader_for_protonet, self._network, None)
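
For reference, the rule that replace_fc implements fits in a few self-contained lines: the fc weight of each class is the mean embedding of that class's training images, and inference is an inner product against the prototype matrix. A minimal sketch with toy data and illustrative shapes (not the repository's API):

import torch

def class_mean_prototypes(feats, labels, num_classes):
    # One prototype per class: the mean of that class's embeddings,
    # mirroring what replace_fc writes into fc.weight.data.
    protos = torch.zeros(num_classes, feats.size(1))
    for c in range(num_classes):
        protos[c] = feats[labels == c].mean(dim=0)
    return protos

# Toy usage: 100 dummy 512-d embeddings evenly covering 10 classes.
feats = torch.randn(100, 512)
labels = torch.arange(100) % 10
protos = class_mean_prototypes(feats, labels, 10)
preds = (torch.randn(4, 512) @ protos.T).argmax(dim=1)  # [batch] class predictions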
165 changes: 165 additions & 0 deletions models/zs_clip.py
@@ -0,0 +1,165 @@
import logging
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils.inc_net import SimpleClipNet
from models.base import BaseLearner
from utils.toolkit import tensor2numpy, get_attribute

# Zero-shot CLIP baseline: the model is never trained; classification uses prompt-based text embeddings.

num_workers = 8


class Learner(BaseLearner):
def __init__(self, args):
super().__init__(args)

self._network = SimpleClipNet(args, True)
self.batch_size = get_attribute(args, "batch_size", 48)
self.init_lr = get_attribute(args, "init_lr", 0.01)
self.weight_decay = get_attribute(args, "weight_decay", 0.0005)
self.min_lr = get_attribute(args, "min_lr", 1e-8)
self.args = args


def after_task(self):
self._known_classes = self._total_classes

def incremental_train(self, data_manager):
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
self._network.update_fc(self._total_classes)
logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))

train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train",
mode="train", )
self.train_dataset = train_dataset
self.data_manager = data_manager
self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)
test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        # Zero-shot CLIP needs no training pass, so nothing happens between the
        # DataParallel wrap and unwrap; the structure is kept for interface
        # parity with the trained learners, and the network is moved to the device.
        if len(self._multiple_gpus) > 1:
            print('Multiple GPUs')
            self._network = nn.DataParallel(self._network, self._multiple_gpus)

        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

        self._network.to(self._device)

def _compute_accuracy(self, model, loader):
self._network.eval()
class_to_label = self.data_manager._class_to_label
templates = self.data_manager._data_to_prompt
total_labels = class_to_label[:self._total_classes] # mask all known classes
text_features = []
with torch.no_grad():
for l in total_labels:
texts = [t.format(l) for t in templates]
texts = self._network.tokenizer(texts).to(self._device)
class_embeddings = self._network.convnet.encode_text(texts) # num_str, dim
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
class_embeddings = class_embeddings.mean(dim=0)
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
text_features.append(class_embeddings)
text_features = torch.stack(text_features, dim=0) # num_classes, dim

        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                image_features = self._network.convnet.encode_image(inputs)
                image_features /= image_features.norm(dim=-1, keepdim=True)  # bs, dim
                outputs = image_features @ text_features.T  # bs, num_classes
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)
        print('Accuracy: {:.2f}%'.format(correct * 100 / total))
return np.around(tensor2numpy(correct) * 100 / total, decimals=2)

def _eval_cnn(self, loader):
self._network.eval()
class_to_label = self.data_manager._class_to_label
templates = self.data_manager._data_to_prompt
total_labels = class_to_label[:self._total_classes] # mask all known classes
text_features = []
with torch.no_grad():
for l in total_labels:
texts = [t.format(l) for t in templates]
                texts = self._network.tokenizer(texts).to(self._device)
class_embeddings = self._network.convnet.encode_text(texts)
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
class_embeddings = class_embeddings.mean(dim=0)
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
text_features.append(class_embeddings)
text_features = torch.stack(text_features, dim=0)

y_pred, y_true = [], []
for _, (_, inputs, targets) in enumerate(loader):
inputs = inputs.to(self._device)
with torch.no_grad():
image_features = self._network.convnet.encode_image(inputs)
image_features /= image_features.norm(dim=-1, keepdim=True)
outputs = image_features @ text_features.T
                predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]  # [bs, topk]
y_pred.append(predicts.cpu().numpy())
y_true.append(targets.cpu().numpy())

return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk]

def _eval_zero_shot(self):
self._network.eval()
class_to_label = self.data_manager._class_to_label
templates = self.data_manager._data_to_prompt
        total_labels = class_to_label  # zero-shot: evaluate over all dataset classes, not only those seen so far
text_features = []
with torch.no_grad():
for l in total_labels:
texts = [t.format(l) for t in templates]
                texts = self._network.tokenizer(texts).to(self._device)
                class_embeddings = self._network.convnet.encode_text(texts)
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
class_embeddings = class_embeddings.mean(dim=0)
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
text_features.append(class_embeddings)
text_features = torch.stack(text_features, dim=0)

test_dataset = self.data_manager.get_dataset(np.arange(0, len(total_labels)), source="test", mode="test")
        loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        y_pred, y_true = [], []
for _, (_, inputs, targets) in enumerate(loader):
inputs = inputs.to(self._device)
with torch.no_grad():
image_features = self._network.convnet.encode_image(inputs)
image_features /= image_features.norm(dim=-1, keepdim=True)
outputs = image_features @ text_features.T
predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]
y_pred.append(predicts.cpu().numpy())
y_true.append(targets.cpu().numpy())

return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk]
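
_compute_accuracy, _eval_cnn, and _eval_zero_shot each rebuild the same per-class text classifier. A factored-out sketch of that shared step, relying only on the tokenizer/convnet.encode_text interface used above (the function name and shapes are illustrative, not part of the repository):

import torch

@torch.no_grad()
def build_text_classifier(network, labels, templates, device):
    # For each class: encode every template-filled prompt, L2-normalize,
    # average over templates, then re-normalize the mean embedding.
    weights = []
    for label in labels:
        texts = network.tokenizer([t.format(label) for t in templates]).to(device)
        emb = network.convnet.encode_text(texts)   # [num_templates, dim]
        emb = emb / emb.norm(dim=-1, keepdim=True)
        mean_emb = emb.mean(dim=0)
        weights.append(mean_emb / mean_emb.norm())
    return torch.stack(weights)                    # [num_classes, dim]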


7 changes: 6 additions & 1 deletion utils/factory.py
@@ -3,6 +3,11 @@ def get_model(model_name, args):
    if name == "proof":
from models.proof import Learner
return Learner(args)

elif name == "simplecil":
from models.simplecil import Learner
return Learner(args)
    elif name == "zs_clip":
from models.zs_clip import Learner
return Learner(args)
    else:
        raise ValueError("Unknown model_name: {}".format(model_name))
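
With the factory extended, a config file can be turned into a learner in a few lines. A sketch of the wiring, assuming a JSON-driven entry point like the one these exps/ files are written for (the surrounding training loop is omitted and the exact driver script is an assumption):

import json
from utils.factory import get_model

with open("exps/simplecil.json") as f:
    args = json.load(f)

learner = get_model(args["model_name"], args)  # -> models.simplecil.Learner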
