Centroids Matching: an efficient Continual Learning approach operating in the embedding space

ArXiv preprint

Abstract 📑

Catastrophic forgetting (CF) occurs when a neural network loses the information previously learned while training on a set of samples from a different distribution, i.e., a new task. Existing approaches have achieved remarkable results in mitigating CF, especially in a scenario called task incremental learning. However, this scenario is not realistic, and limited work has been done to achieve good results on more realistic scenarios. In this paper, we propose a novel regularization method called Centroids Matching, that, inspired by meta-learning approaches, fights CF by operating in the feature space produced by the neural network, achieving good results while requiring a small memory footprint. Specifically, the approach classifies the samples directly using the feature vectors produced by the neural network, by matching those vectors with the centroids representing the classes from the current task, or all the tasks up to that point. Centroids Matching is faster than competing baselines, and it can be exploited to efficiently mitigate CF, by preserving the distances between the embedding space produced by the model when past tasks were over, and the one currently produced, leading to a method that achieves high accuracy on all the tasks, without using an external memory when operating on easy scenarios, or using a small one for more realistic ones. The novelty of our proposal is that it works in the embedding space of the tasks, both during the training and during the mitigation of the CF. Extensive experiments demonstrate that CM achieves accuracy gains on multiple datasets and scenarios.

Centroids Matching (CM) in short 🎯

CM is an approach that alleviates catastrophic forgetting while training a neural network on a new set of samples. It relies on a neural network that operates directly in the embedding space: it creates a centroid for each class and forces the samples to be as close as possible to the correct centroid. It can alleviate CF in two scenarios: Task Incremental Learning and Class Incremental Learning. In the first, the model is simply regularized by reducing the distance between the outputs obtained while training on the current task and those obtained from the past model.
In the second, we also use an external memory containing samples from past tasks, and the embedding spaces of past tasks are merged while training on the new one.
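
As a rough illustration of the two ideas above (classifying by the nearest class centroid, and keeping current embeddings close to those of the past model), here is a minimal, framework-free sketch in plain Python. It is not the code in this repository. The function names, the use of the mean as the centroid, and the Euclidean distance are all assumptions made for the example; the actual method trains a neural network to produce the embeddings and may use a different distance and loss.

```python
import math


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def compute_centroids(embeddings, labels):
    """Map each class label to the mean embedding of its samples.

    Assumption: the centroid is the plain average of the class embeddings.
    """
    grouped = {}
    for vector, label in zip(embeddings, labels):
        grouped.setdefault(label, []).append(vector)
    return {label: mean_vector(vectors) for label, vectors in grouped.items()}


def classify(embedding, centroids):
    """Return the label whose centroid is closest (Euclidean, by assumption)."""
    return min(centroids, key=lambda label: math.dist(embedding, centroids[label]))


def embedding_drift(past_embeddings, current_embeddings):
    """Mean distance between past-model and current-model embeddings.

    Both lists must hold embeddings of the same samples, in the same order.
    A penalty of this kind, added to the training loss, discourages the
    current model from moving away from what the past model produced.
    """
    distances = [math.dist(p, c) for p, c in zip(past_embeddings, current_embeddings)]
    return sum(distances) / len(distances)


# Toy 2-D embeddings standing in for the network outputs (hypothetical data).
embeddings = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
labels = ["cat", "cat", "dog", "dog"]

centroids = compute_centroids(embeddings, labels)
prediction = classify([1.0, 1.0], centroids)
drift = embedding_drift([[0.0, 0.0], [1.0, 1.0]], [[3.0, 4.0], [1.0, 1.0]])
```

In this toy example the centroids are `[0.0, 1.0]` for `cat` and `[5.0, 4.0]` for `dog`, so the query `[1.0, 1.0]` is assigned to `cat`. The drift between the two small embedding sets is `2.5`, the average of the distances 5 and 0.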

A visual overview of how the approach works is the following:

How to use the code and replicate the results

The entry point used to run the code is the script main.py, whose command-line input is a hydra configuration; please check the hydra library documentation to understand how it works.

For reference, the folder configs contains all the config files used, and the folder bash contains all the scripts used to run the experiments.

Requirements

The required libraries are:

Cite

Please cite our work if you find it useful:

@article{pomponi2022centroids,
  title={Centroids Matching: an efficient Continual Learning approach operating in the embedding space},
  author={Pomponi, Jary and Scardapane, Simone and Uncini, Aurelio},
  journal={arXiv preprint arXiv:2208.02048},
  year={2022}
}
