hsahuja111 / reconstructing-training-data-from-trained-neural-network

This repository contains an implementation of a research paper that reconstructs training data from a trained neural network.

Jupyter Notebook 100.00%
reconstruction-algorithm training-data-generation binary-mlp mnist-dataset toy-dataset

Reconstructing-Training-Data-From-Trained-Neural-Network

This repository contains an implementation of the research paper "Reconstructing Training Data from Trained Neural Networks". I have tried to explain the paper below; for the full theoretical proofs, see the original paper, which is included in this repository as Research-Paper.pdf.

Let's first understand the basic idea of the paper.

Overview of the paper

Usually, we train a neural network on some training data and then use it to predict the output for unseen data. This paper goes in the opposite direction: it shows how to extract the training data back from the trained model.

More precisely:

The input is a trained MLP model, meaning that we have its saved weights/parameters. Using only these parameters, we have to reconstruct the training data that was used to train the model.

[Figure: overview of the setup]

Maths Behind

In the usual training process of a neural network, we apply gradient descent to a loss function that measures the difference between the predicted and the true outputs, and we keep iterating until the loss reaches a minimum.
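For contrast with what follows, here is a minimal sketch of ordinary training: gradient descent adjusts a single weight `w` of a toy model y ≈ w·x to minimise the mean squared error, while the data stay fixed. The model, data and learning rate are made up for illustration.

```python
def train_weight(xs, ys, lr=0.1, steps=200):
    """Ordinary training: fit y ~ w * x by gradient descent on the mean squared error."""
    w = 0.0                      # the *weight* is the variable being optimised
    n = len(xs)
    for _ in range(steps):
        # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        w -= lr * grad           # the training data xs, ys are never modified
    return w


w = train_weight([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(round(w, 6))  # 2.0
```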

But here, what should our loss functions be?

The intuition lies in the implicit bias of training: although an MLP could fit its training data in many different ways, gradient descent tends to prefer particular solutions. For a Binary MLP trained with gradient descent, under the assumptions of the underlying theory, the parameters converge (in direction) to a stationary (KKT) point of an SVM-style margin-maximization problem. From the conditions of this margin-maximization problem we derive 2 of our loss functions. The 3rd loss function is based on a simple prior on valid pixel values.

This is the SVM margin-maximization problem:

[Equation images: SVM margin-maximization problem]
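The equation images are not reproduced here. For reference, the margin-maximization problem and its KKT conditions as stated in the original paper (as we read it; labels y_i in {-1, +1}, network output Φ(θ; x) with parameters θ, n training samples) are:

```latex
\[
\min_{\theta}\ \tfrac{1}{2}\lVert\theta\rVert_2^2
\quad\text{s.t.}\quad y_i\,\Phi(\theta; x_i) \ge 1 \quad \forall i \in \{1,\dots,n\}
\]
% KKT conditions (at points where \Phi is differentiable):
\[
\theta = \sum_{i=1}^{n} \lambda_i\, y_i\, \nabla_\theta \Phi(\theta; x_i),
\qquad \lambda_i \ge 0,
\qquad \lambda_i = 0 \ \text{if}\ y_i\,\Phi(\theta; x_i) \neq 1
\]
```

The stationarity condition motivates the parameter loss, and the requirement λ_i ≥ 0 motivates the lambda loss.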

So here there are two

3 Loss functions used

1. Parameter Loss

[Equation image: parameter loss]
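A minimal sketch of the parameter loss for a tiny network, assuming a bias-free two-layer ReLU MLP Φ(θ; x) = Σ_j v_j·relu(w_j·x); the repository's actual architecture, which has biases, may differ. The gradient with respect to θ is written out by hand so the example needs only the standard library; in practice an autograd framework would compute it. The function names are illustrative.

```python
import random


def relu(z):
    return z if z > 0.0 else 0.0


def mlp_output(W, v, x):
    """Phi(theta; x) = sum_j v_j * relu(w_j . x) for a bias-free two-layer ReLU MLP."""
    return sum(vj * relu(sum(w * xk for w, xk in zip(wj, x))) for wj, vj in zip(W, v))


def grad_theta(W, v, x):
    """Gradient of Phi w.r.t. theta, flattened as (rows of W..., then v)."""
    grad_W, grad_v = [], []
    for wj, vj in zip(W, v):
        pre = sum(w * xk for w, xk in zip(wj, x))
        active = 1.0 if pre > 0.0 else 0.0       # ReLU derivative
        grad_W.extend(vj * active * xk for xk in x)
        grad_v.append(relu(pre))
    return grad_W + grad_v


def parameter_loss(W, v, xs, ys, lams):
    """|| theta - sum_i lam_i * y_i * grad_theta Phi(theta; x_i) ||^2 (weights stay fixed)."""
    theta = [w for wj in W for w in wj] + list(v)
    combo = [0.0] * len(theta)
    for x, y, lam in zip(xs, ys, lams):
        for k, g in enumerate(grad_theta(W, v, x)):
            combo[k] += lam * y * g
    return sum((t - c) ** 2 for t, c in zip(theta, combo))


# Usage with made-up weights and candidates: 4 hidden units, 3-pixel "images".
rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
v = [rng.uniform(-1, 1) for _ in range(4)]
xs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(2)]
print(parameter_loss(W, v, xs, ys=[1, -1], lams=[0.5, 0.5]))
```

Because this network is 2-homogeneous in θ, θ·∇θΦ(θ; x) = 2Φ(θ; x), which gives a quick sanity check of the hand-written gradient.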

2. Lambda Loss

[Equation image: lambda loss]
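The lambda loss can be sketched as below, assuming the form used in the original paper as we understand it, L_λ = Σ_i max(−λ_i, −λ_min). Its (sub)gradient is −1 for any λ_i below the threshold λ_min and 0 otherwise, so minimising it pushes negative or near-zero multipliers up to λ_min, in line with the KKT requirement λ_i ≥ 0. The default value of `lam_min` is an arbitrary choice.

```python
def lambda_loss(lams, lam_min=0.05):
    """Sum_i max(-lam_i, -lam_min): decreases while lam_i < lam_min, flat afterwards."""
    return sum(max(-lam, -lam_min) for lam in lams)


print(lambda_loss([0.5, 0.01, -1.0]))  # about 0.94 = -0.05 - 0.01 + 1.0
```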

3. Range Loss

Say our training data were grayscale images normalised so that pixel values lie between -1 and 1. We then penalise reconstructed images whose pixel values fall outside this range.
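A minimal sketch of the range loss on a flat list of pixel values, assuming pixels are normalised to [-1, 1] as described. Each pixel is penalised by how far it overshoots the range, so values inside the range contribute nothing.

```python
def range_loss(pixels, low=-1.0, high=1.0):
    """Sum over pixels of the distance by which each value leaves [low, high]."""
    return sum(max(p - high, 0.0) + max(low - p, 0.0) for p in pixels)


print(range_loss([0.5, 1.5, -2.0, -1.0]))  # 1.5  (0 + 0.5 + 1.0 + 0)
```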

Denote these loss functions by L1, L2 and L3, respectively. The total loss is a weighted sum of L1, L2 and L3.
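Written out over m candidate samples x_1 to x_m with fixed labels y_i and multipliers λ_i, one way to express the three losses and their weighted sum is shown below. The weights α_1, α_2, α_3 and the threshold λ_min are hyperparameters whose values are not given here, and the squared norm in L1 is our assumption.

```latex
\[
L_1 = \Big\lVert \theta - \sum_{i=1}^{m} \lambda_i\, y_i\, \nabla_\theta \Phi(\theta; x_i) \Big\rVert_2^2,
\qquad
L_2 = \sum_{i=1}^{m} \max(-\lambda_i, -\lambda_{\min}),
\]
\[
L_3 = \sum_{i=1}^{m} \sum_{k} \Big[ \max(x_{i,k} - 1,\, 0) + \max(-1 - x_{i,k},\, 0) \Big],
\qquad
L = \alpha_1 L_1 + \alpha_2 L_2 + \alpha_3 L_3
\]
```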

Datasets

  1. TOY DATASET

We first tried the method on a toy dataset. We selected 10 fully red images and 10 fully blue images and trained a Binary MLP model on them. We then took 100 randomly initialised images, and the objective is to turn them back into the original images: half of them should become red and the other half blue.

[Figure: toy dataset]
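A minimal sketch of the toy setup, assuming 8x8 RGB images with values in [-1, 1] (so "fully red" is (1, -1, -1)), with red labelled +1 and blue labelled -1; the image size and the label convention are assumptions, not taken from the repository. The 100 candidates are randomly initialised and split evenly between the two labels, matching the goal that half should become red and half blue.

```python
import random

RED = (1.0, -1.0, -1.0)    # full red after mapping [0, 255] to [-1, 1]
BLUE = (-1.0, -1.0, 1.0)   # full blue


def solid_image(rgb, size=8):
    """A size x size RGB image in which every pixel equals `rgb`."""
    return [[list(rgb) for _ in range(size)] for _ in range(size)]


def toy_training_set(n_per_class=10, size=8):
    """n_per_class red images (label +1) followed by n_per_class blue images (label -1)."""
    xs = [solid_image(RED, size) for _ in range(n_per_class)]
    xs += [solid_image(BLUE, size) for _ in range(n_per_class)]
    ys = [1] * n_per_class + [-1] * n_per_class
    return xs, ys


def random_candidates(m=100, size=8, seed=0):
    """m uniformly random images in [-1, 1]; first half labelled +1, second half -1."""
    rng = random.Random(seed)
    xs = [[[[rng.uniform(-1.0, 1.0) for _ in range(3)] for _ in range(size)]
           for _ in range(size)] for _ in range(m)]
    ys = [1] * (m // 2) + [-1] * (m - m // 2)
    return xs, ys


train_x, train_y = toy_training_set()
cand_x, cand_y = random_candidates()
print(len(train_x), len(cand_x), cand_y.count(1))  # 20 100 50
```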
  2. MNIST DATASET

We selected a pre-trained Binary MLP model that had been trained on a dataset of 500 samples: 250 images of odd digits and 250 images of even digits. The aim is to reconstruct these images starting from a new set of 1000 samples.

[Figure: MNIST dataset]
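The binary MNIST setup can be sketched as follows: digit labels are mapped to ±1 by parity, and 250 even-digit and 250 odd-digit indices are drawn at random. The sign convention (even = +1), the seed and the stand-in label list are assumptions; in practice the digit labels would come from the MNIST dataset itself (for example via torchvision), which is not loaded here.

```python
import random


def parity_label(digit):
    """Binary target for the odd/even task: +1 for even digits, -1 for odd (assumed)."""
    return 1 if digit % 2 == 0 else -1


def balanced_indices(digits, per_class=250, seed=0):
    """Randomly pick per_class indices of even digits and per_class of odd digits."""
    rng = random.Random(seed)
    even = [i for i, d in enumerate(digits) if d % 2 == 0]
    odd = [i for i, d in enumerate(digits) if d % 2 == 1]
    return rng.sample(even, per_class) + rng.sample(odd, per_class)


digits = [i % 10 for i in range(2000)]     # stand-in for the MNIST label list
idx = balanced_indices(digits)
labels = [parity_label(digits[i]) for i in idx]
print(len(idx), labels.count(1), labels.count(-1))  # 500 250 250
```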

Process Involved and Results

INPUT --> parameters of the trained model, randomly initialised input images, parameters, lambda

First, we create a Binary MLP model and load into it the weights of the trained Binary MLP model. The .pth file in this repository contains these parameters (weights and biases).
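The parameter loss needs the trained weights as a single vector θ. In PyTorch, assuming the .pth file stores a state dict, this would typically mean `torch.load`, then `model.load_state_dict(...)`, then flattening with `torch.nn.utils.parameters_to_vector(model.parameters())`. The dependency-free sketch below shows the same flattening on a state dict whose tensors have already been converted to nested lists (e.g. with `.tolist()`); the key names and values are hypothetical.

```python
def flatten_parameters(state):
    """Concatenate every tensor in `state` (nested lists of numbers) into one flat list.

    Keys are visited in insertion order, which must match the order used when
    computing gradients with respect to theta.
    """
    def walk(value):
        if isinstance(value, (list, tuple)):
            for item in value:
                yield from walk(item)
        else:
            yield float(value)

    return [p for tensor in state.values() for p in walk(tensor)]


# Hypothetical state dict of a 2-input, 2-hidden-unit, 1-output MLP.
state = {
    "fc1.weight": [[0.1, -0.2], [0.3, 0.4]],
    "fc1.bias": [0.0, 0.1],
    "fc2.weight": [[0.5, -0.5]],
    "fc2.bias": [0.2],
}
theta = flatten_parameters(state)
print(len(theta))  # 9
```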

CRUX

"Now, in the traditional process we apply gradient descent and try to find the values of the weights and biases that minimise the loss function. Here, instead, the weights and biases are fixed, and we have to find the pixel values that minimise the 3 loss functions explained above."

[Figure: process and results]
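The crux can be sketched end to end if the MLP is swapped for the simplest stand-in, a linear model Φ(θ; x) = θ·x, whose gradient ∇θΦ(θ; x) = x makes all derivatives closed-form. The sketch freezes θ and runs gradient descent on the candidate inputs, treating the multipliers λ_i as trainable too, with the three losses at unit weights. The model swap, trainable λ, the initial λ value, λ_min, the learning rate and the number of steps are all assumptions. For a linear model many inputs satisfy the conditions, so this shows only the mechanics of optimising pixels instead of weights, not reconstruction quality.

```python
import random


def reconstruct(theta, ys, steps=2000, lr=0.01, lam_init=0.5, lam_min=0.05, seed=0):
    """Optimise candidate inputs xs and multipliers lams while theta stays frozen.

    Stand-in model: Phi(theta; x) = theta . x, so grad_theta Phi(theta; x) = x.
    Total loss (unit weights):
        ||theta - sum_i lam_i * y_i * x_i||^2         parameter loss
      + sum_i max(-lam_i, -lam_min)                   lambda loss
      + sum_{i,k} overshoot of x_ik outside [-1, 1]   range loss
    """
    rng = random.Random(seed)
    d, m = len(theta), len(ys)
    xs = [[rng.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(m)]  # random init
    lams = [lam_init] * m
    history = []
    for _ in range(steps):
        # Residual r = theta - sum_i lam_i * y_i * x_i, shared by all gradients below.
        r = list(theta)
        for x, y, lam in zip(xs, ys, lams):
            for k in range(d):
                r[k] -= lam * y * x[k]
        history.append(sum(rk * rk for rk in r)
                       + sum(max(-lam, -lam_min) for lam in lams)
                       + sum(max(p - 1.0, 0.0) + max(-1.0 - p, 0.0) for x in xs for p in x))
        for i, (x, y) in enumerate(zip(xs, ys)):
            lam = lams[i]
            # dL/dlam_i: parameter-loss term plus the lambda-loss subgradient.
            g_lam = -2.0 * y * sum(rk * xk for rk, xk in zip(r, x))
            g_lam += -1.0 if lam < lam_min else 0.0
            # dL/dx_ik: parameter-loss term plus the range-loss subgradient.
            for k in range(d):
                g_x = -2.0 * lam * y * r[k]
                g_x += 1.0 if x[k] > 1.0 else (-1.0 if x[k] < -1.0 else 0.0)
                x[k] -= lr * g_x          # update the pixels
            lams[i] = lam - lr * g_lam    # and the multiplier; theta is never touched
    return xs, lams, history


theta = [1.0, -0.5, 0.25]   # frozen "trained" weights (made up)
ys = [1, 1, -1, -1]         # labels fixed in advance for the 4 candidates
xs, lams, history = reconstruct(theta, ys)
print(f"loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

Swapping in the MLP only changes how ∇θΦ and its derivatives with respect to the inputs are computed; an autograd framework provides those automatically.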

Contributors

franken14, hsahuja111
