nphdang / fs-bbt

Black-box Few-shot Knowledge Distillation

Batchfile 1.41% Python 98.59%
black-box black-box-model computer-vision conditional-variational-autoencoder cvae data-augmentation data-generation deep-learning few-shot few-shot-classification

FS-BBT: Knowledge Distillation with Few Samples and Black-box Teacher

This repository contains the implementation of the FS-BBT method proposed in the paper "Black-box Few-shot Knowledge Distillation" (ECCV 2022): https://eccv2022.ecva.net/

Introduction

Knowledge distillation (KD) is an efficient approach to transferring knowledge from a large “teacher” network to a smaller “student” network. Traditional KD methods require many labeled training samples and a white-box teacher (one whose parameters are accessible) to train a good student. However, these resources are not always available in real-world applications: the distillation often happens on an external party's side, where little data is available and the teacher does not disclose its parameters due to security and privacy concerns.

To overcome these challenges, we propose a black-box few-shot KD method that trains the student with only a few unlabeled training samples and a black-box teacher. Our main idea is to expand the training set by generating a diverse set of out-of-distribution synthetic images using MixUp and a conditional variational auto-encoder (CVAE). These synthetic images, together with their labels obtained by querying the teacher, are used to train the student. Extensive experiments show that our method significantly outperforms recent state-of-the-art (SOTA) few-shot and zero-shot KD methods on image classification tasks.
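The core idea of expanding a few unlabeled images with MixUp and then labelling everything by querying the teacher can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the FS-BBT code: it uses standard MixUp with the mixing weight drawn from Beta(alpha, alpha), `black_box_teacher` is a hypothetical stand-in that only exposes predicted class probabilities, and `distillation_loss` is a plain cross-entropy against those soft labels. The mixing scheme, label type (soft or hard), and loss used in the repository may differ.

```python
import numpy as np


def mixup(x, rng, alpha=1.0):
    """Standard MixUp on images only: convex combination of randomly paired samples.

    x: array of shape (n, ...) holding the few unlabeled training images.
    Returns n synthetic images with the same shape as x.
    """
    n = x.shape[0]
    lam = rng.beta(alpha, alpha, size=n)            # one mixing weight per pair
    perm = rng.permutation(n)                       # random partner for each sample
    lam = lam.reshape((n,) + (1,) * (x.ndim - 1))   # broadcast over image dimensions
    return lam * x + (1.0 - lam) * x[perm]


def black_box_teacher(x):
    """Hypothetical black-box teacher: it can only be queried for class
    probabilities; its parameters are never read."""
    logits = x.reshape(x.shape[0], -1)[:, :10] * 5.0
    e = np.exp(logits - logits.max(axis=1, keepdims=True))  # numerically stable softmax
    return e / e.sum(axis=1, keepdims=True)


def distillation_loss(student_probs, teacher_probs, eps=1e-12):
    """Cross-entropy between the teacher's soft labels and the student's predictions."""
    return float(-np.mean(np.sum(teacher_probs * np.log(student_probs + eps), axis=1)))


rng = np.random.default_rng(0)
x_few = rng.random((20, 28, 28))                    # e.g. 20 MNIST-sized unlabeled images
x_syn = mixup(x_few, rng)                           # synthetic images
x_train = np.concatenate([x_few, x_syn])            # expanded training set
y_soft = black_box_teacher(x_train)                 # labels obtained by querying the teacher

# An untrained student that predicts a uniform distribution over 10 classes.
student_probs = np.full_like(y_soft, 1.0 / y_soft.shape[1])
loss = distillation_loss(student_probs, y_soft)
```

Because the teacher is only queried and never inspected, the same loop applies to any model exposed through a prediction interface; in practice the student's probabilities would come from a network trained to minimize this loss.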

FS-BBT framework

[Figure: FS-BBT framework]
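The CVAE in the framework generates new images conditioned on a class label. The NumPy sketch below shows only the two sampling steps a CVAE relies on: the reparameterized encoding z = mu + sigma * eps of an image and its label, and label-conditioned decoding of a latent vector drawn from the N(0, I) prior. It is an illustration, not the repository's model: single linear layers with random, untrained weights (`W_enc`, `W_dec`) stand in for the trained encoder and decoder, and `LATENT_DIM`, `NUM_CLASSES`, and `IMG_DIM` are hypothetical MNIST-sized values.

```python
import numpy as np

LATENT_DIM = 2      # hypothetical: tiny latent space for illustration
NUM_CLASSES = 10
IMG_DIM = 784       # flattened 28x28 image

rng = np.random.default_rng(0)
# Hypothetical random weights standing in for trained encoder/decoder networks.
W_enc = rng.normal(scale=0.01, size=(IMG_DIM + NUM_CLASSES, 2 * LATENT_DIM))
W_dec = rng.normal(scale=0.1, size=(LATENT_DIM + NUM_CLASSES, IMG_DIM))


def one_hot(labels, k=NUM_CLASSES):
    return np.eye(k)[labels]


def encode(x, y):
    """Encoder q(z | x, y): returns the mean and log-variance of a Gaussian."""
    h = np.concatenate([x, y], axis=1) @ W_enc
    return h[:, :LATENT_DIM], h[:, LATENT_DIM:]


def reparameterize(mu, log_var):
    """z = mu + sigma * eps with eps ~ N(0, I), which keeps sampling differentiable."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps


def decode(z, y):
    """Decoder p(x | z, y): a sigmoid keeps pixel values in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-(np.concatenate([z, y], axis=1) @ W_dec)))


labels = np.array([0, 1, 2, 3])
y = one_hot(labels)

# Training-time path: encode real images with their labels, then reconstruct.
x = rng.random((len(labels), IMG_DIM))
mu, log_var = encode(x, y)
x_rec = decode(reparameterize(mu, log_var), y)

# Generation path: sample z from the prior and condition on the desired labels.
z = rng.standard_normal((len(labels), LATENT_DIM))
x_gen = decode(z, y)
```

In a trained CVAE, varying z while fixing the label yields diverse images of that class; with random weights the outputs are noise, but the shapes and data flow are the same.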

Results on MNIST and Fashion-MNIST

[Figure: results on MNIST and Fashion-MNIST]

Results on CIFAR-10 and CIFAR-100

[Figure: results on CIFAR-10 and CIFAR-100]

Installation

  1. Python 3.6.7
  2. numpy 1.19.5
  3. scikit-learn 0.23
  4. scipy 1.3.1
  5. TensorFlow 1.15
  6. Keras 2.2.5

How to run

  • Each folder corresponds to a dataset.
  • Run the ".bat" files to train the teacher network, the CVAE model, and the student network.
  • Pre-trained teacher, CVAE, and student models are stored in the corresponding folders and can be used directly to save time.

Reference

Dang Nguyen, Sunil Gupta, Kien Do, Svetha Venkatesh (2022). Black-box Few-shot Knowledge Distillation. ECCV 2022, Tel Aviv, Israel

Issues

What variable should I pass as the third input of the cvae model?

Epoch 1/100
Traceback (most recent call last):
File "cvae.py", line 137, in
validation_split=0.05, verbose=1)
File "/home/109511244/.local/lib/python3.6/site-packages/keras/engine/training.py", line 1184, in fit
tmp_logs = self.train_function(iterator)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 885, in call
result = self._call(*args, **kwds)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 760, in _initialize
*args, **kwds))
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3066, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3308, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 668, in wrapped_fn
out = weak_wrapped_fn().wrapped(*args, **kwds)
File "/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 994, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

/home/109511244/.local/lib/python3.6/site-packages/keras/engine/training.py:853 train_function  *
    return step_function(self, iterator)
/home/109511244/.local/lib/python3.6/site-packages/keras/engine/training.py:842 step_function  **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
/home/109511244/.local/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
    return fn(*args, **kwargs)
/home/109511244/.local/lib/python3.6/site-packages/keras/engine/training.py:835 run_step  **
    outputs = model.train_step(data)
/home/109511244/.local/lib/python3.6/site-packages/keras/engine/training.py:787 train_step
    y_pred = self(x, training=True)
/home/109511244/.local/lib/python3.6/site-packages/keras/engine/base_layer.py:1020 __call__
    input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
/home/109511244/.local/lib/python3.6/site-packages/keras/engine/input_spec.py:202 assert_input_compatibility
    ' input tensors. Inputs received: ' + str(inputs))

ValueError: Layer cvae expects 3 input(s), but it received 2 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 784) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, 10) dtype=float32>]