
hmcalister / tensorflow-intepretability-project


An investigation into sequential learning of tasks using feed-forward networks built with Tensorflow

Python 11.59% Jupyter Notebook 88.41%
catastrophic-forgetting deep-learning elastic-weight-consolidation interpretability interpretable-machine-learning sequential-learning tensorflow visualization

tensorflow-intepretability-project's Introduction

Tensorflow Sequential Learning

An implementation of sequential learning in Tensorflow with Keras Sequential models, including infrastructure for elastic weight consolidation techniques.

Sequential learning involves having a single model (or multiple models with some shared sections) learn a series of tasks in sequence: learn task A without knowledge of task B, then learn task B without access to the training data of task A. The goal is to prevent models from "forgetting" (perhaps catastrophically) previous tasks when new tasks are introduced.
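The failure mode being prevented can be demonstrated with a minimal sketch, using synthetic data and illustrative names rather than this project's actual code: train on task A, then on task B without task A's data, and the task-A loss often rises.

```python
import numpy as np
import tensorflow as tf

# Two synthetic linear regression "tasks" over the same inputs.
rng = np.random.default_rng(0)
x = rng.normal(size=(512, 4)).astype("float32")
w_a = rng.normal(size=(4, 1)).astype("float32")
w_b = rng.normal(size=(4, 1)).astype("float32")
y_a, y_b = x @ w_a, x @ w_b

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

model.fit(x, y_a, epochs=5, verbose=0)            # learn task A
loss_a_before = model.evaluate(x, y_a, verbose=0)

model.fit(x, y_b, epochs=5, verbose=0)            # learn task B, no task A data
loss_a_after = model.evaluate(x, y_a, verbose=0)  # task A performance often degrades
```

Without a constraint tying the weights back to the task-A optimum, nothing stops training on task B from overwriting them — which is exactly the gap EWC addresses.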

Project Architecture

Models are structured as a base model (shared weights) with additional task-specific layers placed on top. Models are first created (using the TensorFlow functional API so layers can be shared) and then passed to SequentialTasks (see SequentialLearning/SequentialTasks, and below), each of which manages an individual task, abstracting away the sequential learning aspect so attention can stay on a single model and a single dataset. The SequentialLearningManager handles sequencing tasks one after another, as well as orchestrating testing of tasks and any sequential learning techniques such as EWC.
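The shared-base layout can be sketched with the functional API as follows; the layer sizes and names here are illustrative, not the project's actual architecture.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(32,))
# One shared layer object: the base model's weights.
base = tf.keras.layers.Dense(64, activation="relu", name="shared_dense")(inputs)

# Each task gets its own head on top of the shared representation.
head_a = tf.keras.layers.Dense(10, name="task_a_head")(base)
head_b = tf.keras.layers.Dense(2, name="task_b_head")(base)

model_a = tf.keras.Model(inputs, head_a)
model_b = tf.keras.Model(inputs, head_b)

# Because the Dense layer object is shared, training model_a updates
# weights that model_b also uses — the root cause of forgetting.
```

Each task-specific model is what gets handed to a SequentialTask, which can then treat it as an ordinary single-task model.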

Sequential Task

Conceptually holds all information for a single task in a learning environment. Holds a reference to the task model (the base model plus this task's extra layers), the training and validation datasets for this task, and the loss function. Is responsible for compiling the model (see the EWC section below) as well as explicitly calling the fit method to train it. Has methods for evaluating the model on validation data, set up to be called at the end of each epoch.
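A hypothetical, cut-down analogue of such a task wrapper (the real SequentialTask class in this repository differs) might look like:

```python
import numpy as np
import tensorflow as tf

class Task:
    """Illustrative stand-in for a SequentialTask-style wrapper."""

    def __init__(self, model, train_data, val_data, loss_fn):
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.loss_fn = loss_fn

    def compile_model(self):
        # Compiling is the natural place to attach any EWC penalty terms
        # before training begins.
        self.model.compile(optimizer="adam", loss=self.loss_fn)

    def train(self, epochs, callbacks=()):
        return self.model.fit(self.train_data, epochs=epochs,
                              callbacks=list(callbacks), verbose=0)

    def evaluate(self):
        # Intended to be called at the end of each epoch by the manager.
        return self.model.evaluate(self.val_data, return_dict=True, verbose=0)

# Toy usage: one regression task with random data.
x = np.random.normal(size=(64, 8)).astype("float32")
y = np.random.normal(size=(64, 1)).astype("float32")
data = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)
task = Task(tf.keras.Sequential([tf.keras.layers.Dense(1)]), data, data, "mse")
task.compile_model()
task.train(epochs=1)
result = task.evaluate()
```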

Also has several seemingly superfluous attributes (x_lim, y_lim, data_fn...). These attributes exist to support possible future functionality for graphing model inputs/outputs for interpretability. They are currently unused and may be removed in the future.

SequentialLearningManager

Manages learning tasks sequentially. Holds references to all tasks and serves as the main interface for the main script. Instantiated with a list of SequentialTasks (with models already present) and parameters controlling the behavior of elastic weight consolidation.

Includes a callback for running task validation each epoch, so task performance can be graphed over all epochs for each task. On a technical note, this callback data is taken directly from model.evaluate(..., return_dict=True), so the stored data is a dictionary whose keys are the loss and metric names and whose values are tuples of per-task performances, recorded each epoch.
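A sketch of such an end-of-epoch validation callback, with illustrative attribute names (the manager's real callback differs):

```python
import numpy as np
import tensorflow as tf

class TaskValidationCallback(tf.keras.callbacks.Callback):
    """Evaluate every task's validation data at the end of each epoch."""

    def __init__(self, tasks):
        super().__init__()
        self.tasks = tasks   # list of (name, x_val, y_val) tuples
        self.history = {}    # metric name -> list of per-epoch tuples

    def on_epoch_end(self, epoch, logs=None):
        # model.evaluate(..., return_dict=True) yields {metric_name: value},
        # which is exactly the structure stored here.
        results = [self.model.evaluate(x, y, return_dict=True, verbose=0)
                   for _, x, y in self.tasks]
        for key in results[0]:
            self.history.setdefault(key, []).append(
                tuple(r[key] for r in results))

# Toy usage: track two tasks while training on one of them.
x = np.random.normal(size=(32, 4)).astype("float32")
y_a = x.sum(axis=1, keepdims=True)
y_b = -y_a
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
cb = TaskValidationCallback([("A", x, y_a), ("B", x, y_b)])
model.fit(x, y_a, epochs=2, callbacks=[cb], verbose=0)
```

After fitting, `cb.history["loss"]` holds one tuple per epoch, with one entry per tracked task — the shape the manager's graphs consume.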

EWC_Methods

Holds all implementations of elastic weight consolidation in this project: callbacks for data collection (relating to EWC), methods for calculating the weight importance matrix Omega, and related utilities. Each elastic weight consolidation method has a unique way of calculating (or approximating, or guessing) weight importance. The EWC_Methods.py file acts as a simple interface to these methods, allowing different methods to be used in a plug-and-play manner by selecting a method from the EWC_Method Enum within.
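Whatever way Omega is computed, the penalty it feeds has the same shape: Omega weights how important each parameter was to earlier tasks, and the loss term anchors new weights to the old optimum. A numerical sketch with toy values (lambda_ewc and the arrays below are illustrative):

```python
import numpy as np

def ewc_penalty(weights, old_weights, omega, lambda_ewc=1.0):
    """0.5 * lambda * sum_i Omega_i * (theta_i - theta*_i)^2."""
    return 0.5 * lambda_ewc * np.sum(
        [np.sum(o * (w - w0) ** 2)
         for w, w0, o in zip(weights, old_weights, omega)])

old = [np.array([1.0, -2.0])]     # theta* from the previous task
omega = [np.array([10.0, 0.1])]   # first weight mattered far more

# Moving the important weight by 1 costs 100x more than moving the other.
cost_important = ewc_penalty([np.array([2.0, -2.0])], old, omega)    # 5.0
cost_unimportant = ewc_penalty([np.array([1.0, -1.0])], old, omega)  # 0.05
```

The methods in this module differ only in how `omega` is produced; the penalty itself is method-agnostic.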

Interpretability

A second focus of this project is the interpretability of neural networks (particularly convolutional networks, though other architectures may be investigated in future). This module provides methods to probe a network for interpretable features, including kernel activation, occlusion sensitivity, and Grad-CAM.
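Occlusion sensitivity, for example, can be sketched model-agnostically: slide a masked patch over the input and record how much the model's score for the target class drops. The `predict` argument and the toy "model" below are illustrative stand-ins, not this module's actual API:

```python
import numpy as np

def occlusion_map(image, predict, patch=4):
    """Heatmap of score drop when each patch of the input is zeroed."""
    base_score = predict(image)
    h, w = image.shape
    heatmap = np.zeros((h // patch, w // patch))
    for i in range(0, h, patch):
        for j in range(0, w, patch):
            occluded = image.copy()
            occluded[i:i + patch, j:j + patch] = 0.0  # mask this patch
            heatmap[i // patch, j // patch] = base_score - predict(occluded)
    return heatmap  # large values = regions the model relies on

# Toy "model" that only looks at the top-left quadrant of an 8x8 image.
score = lambda img: float(img[:4, :4].sum())
heat = occlusion_map(np.ones((8, 8)), score, patch=4)
# heat is hot only over the quadrant the model actually uses
```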

This module also includes methods in the ModelAugmentation.py file that allow altering TensorFlow models based on the weight importance calculated during elastic weight consolidation. The goal is to see whether weight importance can provide useful insight into model interpretability, or boost current interpretability measures.
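One such augmentation can be sketched as importance-based pruning: zero out the weights EWC deems least important and inspect what the model still computes. The quantile threshold and toy values here are illustrative, not the project's actual method:

```python
import numpy as np

def prune_by_importance(weights, omega, keep_fraction=0.5):
    """Keep only the top keep_fraction of weights by importance."""
    flat = np.concatenate([o.ravel() for o in omega])
    cutoff = np.quantile(flat, 1.0 - keep_fraction)
    return [np.where(o >= cutoff, w, 0.0) for w, o in zip(weights, omega)]

w = [np.array([0.5, -1.2, 0.3, 2.0])]
imp = [np.array([9.0, 0.1, 0.2, 5.0])]   # EWC-style importance per weight
pruned = prune_by_importance(w, imp, keep_fraction=0.5)
# only the two most important weights survive
```

If the pruned model still performs well, the surviving weights are strong candidates for closer interpretability analysis.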

tensorflow-intepretability-project's People

Contributors: hmcalister
