Hi, thank you for the great baseline repo.
I am trying to set up each dataset (each with a different number of classes) as a task and perform continual learning. However, I am a little lost about how the models handle multiple heads with potentially different output dimensions. Do you have any suggestions on how this might be addressed?
Currently I am using something like:
```python
import torch.nn as nn
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock

class MultiTaskModel(ResNet):
    def __init__(self):
        # Note: [2, 2, 2, 2] is the ResNet-18 layout; ResNet-34 would be
        # [3, 4, 6, 3]. The base layers are shadowed by `shared` below anyway.
        super(MultiTaskModel, self).__init__(BasicBlock, [2, 2, 2, 2])
        resnet = torchvision.models.resnet34(pretrained=True)
        self.in_feature = resnet.fc.in_features
        self.tasks = []
        self.fc = None
        # Copy all layers except the classifier head into a shared trunk.
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name != 'fc' and name != 'classifiers':
                self.shared.add_module(name, module)
        # self.classifiers.append(resnet.fc)

    def set_task(self, task):
        print("Setting task to", task)
        self.tasks.append(task)
        print(f"tasks are {self.tasks}")
        print(f"task index is {task_list.index(task)}")
        # Replace fc with a fresh head sized for the new task.
        # `task_list`, `classes_per_task`, and `kaiming_normal_init`
        # are defined elsewhere in my script.
        self.fc = nn.Linear(self.in_feature, classes_per_task[task_list.index(task)])
        self.fc.apply(kaiming_normal_init)
        print(f"fc is {self.fc}")

    def forward(self, x):
        x = self.shared(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
```
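For context, one common pattern I have seen for multi-head continual learning is to keep one head per task in an `nn.ModuleDict` and select it in `forward`, so earlier heads stay registered (and checkpointed) instead of being overwritten. A minimal sketch, not the repo's API (the stand-in trunk and names are mine):

```python
import torch
import torch.nn as nn

class MultiHeadNet(nn.Module):
    """Shared trunk with one output head per task, kept in a ModuleDict."""
    def __init__(self, in_features=512):
        super().__init__()
        # Stand-in trunk for illustration; in practice this would be the
        # shared ResNet feature extractor.
        self.trunk = nn.Sequential(nn.Flatten(), nn.Linear(8, in_features), nn.ReLU())
        self.heads = nn.ModuleDict()
        self.in_features = in_features

    def add_task(self, name, num_classes):
        # Each task gets its own fc; old heads remain registered, so their
        # weights are saved and loaded with the rest of the model.
        self.heads[name] = nn.Linear(self.in_features, num_classes)

    def forward(self, x, task):
        return self.heads[task](self.trunk(x))

model = MultiHeadNet()
model.add_task("flowers", 103)
model.add_task("svhn", 10)
out = model(torch.randn(4, 8), task="svhn")
print(out.shape)  # torch.Size([4, 10])
```

With this layout, evaluating an old task only requires passing its name; no head is destroyed when a new task arrives.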
However, I am unable to replicate the results for some methods, such as EWC. If I use the default lambda = 400 from the repo, the loss becomes NaN.
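For reference, the standard EWC regularizer scales linearly with lambda, which is presumably why a large value can blow up the total loss when the parameter drift or Fisher estimates are large. A minimal sketch of the usual penalty (my own toy implementation, not the repo's):

```python
import torch

def ewc_penalty(params, old_params, fisher, lam):
    # Standard EWC regularizer: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2
    loss = 0.0
    for name in params:
        loss = loss + (fisher[name] * (params[name] - old_params[name]) ** 2).sum()
    return 0.5 * lam * loss

# Toy numbers: the same parameter drift costs 400x more at lambda=400
# than at lambda=1, which can dominate (and destabilize) the total loss.
p = {"w": torch.tensor([1.0, 2.0])}
p_old = {"w": torch.tensor([0.0, 0.0])}
f = {"w": torch.tensor([1.0, 1.0])}
print(ewc_penalty(p, p_old, f, 1).item())    # 2.5
print(ewc_penalty(p, p_old, f, 400).item())  # 1000.0
```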
Currently, I am following the task sequence of the recogseq dataset but observing catastrophic forgetting for EWC and LwF. To be precise, I measured the performance of each task after training on the last task.
This is my training loop:
```python
for task in dataset_names:
    train_loader, val_loader, _, _, _ = get_dataloaders(task, 0.8, batch_size)
    all_train_loaders[task] = train_loader
    all_val_loaders[task] = val_loader

for idx, task in tqdm.tqdm(enumerate(task_list)):
    current_train_loader = all_train_loaders[task]
    current_val_loader = all_val_loaders[task]
    model = MultiTaskModel().to(device)
    if idx > 0:
        print(f"Previous task: {task_list[idx-1]}")
        ckpt = torch.load(f"epoch_{task_list[idx-1]}.pth.tar")
        model.load_state_dict(ckpt['state_dict'])
        model = model.to(device)
    start_time = time.time()
    model, acc = fine_tune_EWC_acuumelation(current_train_loader, current_val_loader, model, reg_lambda=1, num_epochs=num_epochs, lr=0.008, batch_size=batch_size, weight_decay=0, current_task=task)
```
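Since `set_task` overwrites `self.fc` in place, evaluating a previous task requires restoring that task's trained head first; otherwise the old task is scored with a freshly initialized classifier. A minimal sketch of saving and restoring per-task heads (the helper names are mine, not the repo's):

```python
import torch
import torch.nn as nn

# Hypothetical helpers: snapshot each task's head right after training it,
# then rebuild and reload that head when evaluating the task later.
saved_heads = {}

def save_head(task_name, fc):
    saved_heads[task_name] = {k: v.clone() for k, v in fc.state_dict().items()}

def restore_head(task_name, in_features):
    num_classes = saved_heads[task_name]["weight"].shape[0]
    fc = nn.Linear(in_features, num_classes)
    fc.load_state_dict(saved_heads[task_name])
    return fc

head = nn.Linear(512, 103)        # e.g. the flowers head after training
save_head("flowers", head)
restored = restore_head("flowers", 512)
print(torch.equal(restored.weight, head.weight))  # True
```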
Here's a sample output:
```bash
Model loaded for task svhn
Performance of previous task: flowers
fc set to Linear(in_features=512, out_features=103, bias=True)
Accuracy of the network on the 1311 test images: 3.75
Performance of previous task: scenes
fc set to Linear(in_features=512, out_features=67, bias=True)
Accuracy of the network on the 3123 test images: 5.013020833333333
Performance of previous task: birds
fc set to Linear(in_features=512, out_features=201, bias=True)
Accuracy of the network on the 2358 test images: 0.9982638888888888
Performance of previous task: cars
fc set to Linear(in_features=512, out_features=196, bias=True)
Accuracy of the network on the 1621 test images: 0.6875
Performance of previous task: aircraft
fc set to Linear(in_features=512, out_features=56, bias=True)
Accuracy of the network on the 2000 test images: 15.574596774193548
Performance of previous task: chars
fc set to Linear(in_features=512, out_features=63, bias=True)
Accuracy of the network on the 12599 test images: 43.247767857142854
Performance of previous task: svhn
fc set to Linear(in_features=512, out_features=10, bias=True)
Accuracy of the network on the 26032 test images: 96.13223522167488
Running accuracy for task svhn is [3.75, 5.013020833333333, 0.9982638888888888, 0.6875, 15.574596774193548, 43.247767857142854, 96.13223522167488]
Mean accuracy for task svhn is 23.629054939319072
```