trellixvulnteam / mmd-gan_y00k

This project forked from richardwth/mmd-gan

Improving MMD-GAN training with repulsive loss function

License: Apache License 2.0

MMD-GAN with Repulsive Loss Function

MMD-GAN with Repulsive Loss Function

GAN: generative adversarial nets; MMD: maximum mean discrepancy; TF: TensorFlow

This repository contains the code for MMD-GAN and the repulsive loss proposed in the ICLR paper [1]:
Wei Wang, Yuan Sun, Saman Halgamuge. Improving MMD-GAN Training with Repulsive Loss Function. ICLR 2019. URL: https://openreview.net/forum?id=HygjqjR9Km.

About the code

The code defines neural network architectures as dictionaries and strings to make it easy to test different models. It also contains many other models I have tried, so sorry if you find it a little confusing.

The structure of code:

  1. DeepLearning/my_sngan/SNGan defines how a general GAN model is trained and evaluated.
  2. GeneralTools contains various tools:
    1. graph_func contains functions to run a model graph and metrics for evaluating generative models (Line 1595).
    2. input_func contains functions to handle datasets and input pipeline.
    3. layer_func contains functions to convert network architecture dictionaries to operations.
    4. math_func defines various mathematical operations. You may find spectral normalization at Line 397, GAN loss functions at Line 2088, the repulsive loss at Line 2505, and the repulsive loss with bounded kernel (referred to as rmb) at Line 2530.
    5. misc_fun contains FLAGs for the code.
  3. The my_test_ files contain the specific model architectures and hyperparameters.

Running the tests

  1. Modify GeneralTools/misc_func accordingly;
  2. Read Data/ReadMe.md; download and prepare the datasets;
  3. Run my_test_ with proper hyperparameters.

About the algorithms

Here we introduce the algorithms and tricks.

Proposed Methods

The paper [1] proposed three methods:

  1. Repulsive loss

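$$\mathcal{L}_D^{\text{rep}} = \mathbb{E}_{P_X}[k_D(\mathbf{x},\mathbf{x}')] - \mathbb{E}_{P_G}[k_D(\mathbf{y},\mathbf{y}')]$$

$$\mathcal{L}_G^{\text{mmd}} = \mathbb{E}_{P_X}[k_D(\mathbf{x},\mathbf{x}')] - 2\,\mathbb{E}_{P_X,P_G}[k_D(\mathbf{x},\mathbf{y})] + \mathbb{E}_{P_G}[k_D(\mathbf{y},\mathbf{y}')]$$

where $\mathbf{x},\mathbf{x}'$ are real samples, $\mathbf{y},\mathbf{y}'$ are generated samples, and $k_D = k \circ D$ is the kernel formed by the discriminator $D$ and the kernel $k$. The discriminator loss of the previous MMD-GAN [2], which we call the attractive loss, is $\mathcal{L}_D^{\text{att}} = -\mathcal{L}_G^{\text{mmd}}$.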
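The losses in this item can be estimated from a minibatch of discriminator outputs. Below is a minimal NumPy sketch, assuming a single Gaussian kernel $k$ with bandwidth `sigma` and the unbiased (diagonal-excluded) minibatch estimates of $\mathbb{E}[k_D(\mathbf{x},\mathbf{x}')]$ and $\mathbb{E}[k_D(\mathbf{y},\mathbf{y}')]$. The function names are hypothetical; the repository's own implementation in math_func may use a different kernel or estimator.

```python
import numpy as np


def gaussian_kernel(a, b, sigma=1.0):
    """k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2)) for rows of a and b."""
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-d2 / (2.0 * sigma ** 2))


def mean_off_diag(k):
    """Mean over i != j: unbiased minibatch estimate of E[k(x, x')]."""
    n = k.shape[0]
    return (k.sum() - np.trace(k)) / (n * (n - 1))


def mmd_gan_losses(d_real, d_fake, sigma=1.0):
    """Return (L_D^rep, L_D^att, L_G^mmd) computed from discriminator outputs.

    d_real: D(x) for a batch of real samples, shape (n, dim)
    d_fake: D(y) for a batch of generated samples, shape (m, dim)
    """
    e_xx = mean_off_diag(gaussian_kernel(d_real, d_real, sigma))
    e_yy = mean_off_diag(gaussian_kernel(d_fake, d_fake, sigma))
    e_xy = gaussian_kernel(d_real, d_fake, sigma).mean()
    loss_g_mmd = e_xx - 2.0 * e_xy + e_yy  # generator loss (squared MMD)
    loss_d_rep = e_xx - e_yy               # repulsive discriminator loss
    loss_d_att = -loss_g_mmd               # attractive loss of MMD-GAN [2]
    return loss_d_rep, loss_d_att, loss_g_mmd
```

Note that $\mathcal{L}_D^{\text{rep}}$ has no cross term: the discriminator only spreads out real outputs and pulls generated outputs together.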

Below is an illustration of the effects of MMD losses on free R(eal) and G(enerated) particles (code in the Figures folder). The particles stand for discriminator outputs of samples but, for illustration purposes, we allow them to move freely. These GIFs extend Figure 1 of paper [1].

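[GIF figure. First row: mmd_d_att and mmd_d_rep (particles under $\mathcal{L}_D^{\text{att}}$ and $\mathcal{L}_D^{\text{rep}}$). Second row: mmd_g_att and mmd_g_rep (particles under $\mathcal{L}_G^{\text{mmd}}$, paired with $\mathcal{L}_D^{\text{att}}$ and $\mathcal{L}_D^{\text{rep}}$ respectively).]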

In the first row, we randomly initialized the particles and applied $\mathcal{L}_D^{\text{att}}$ or $\mathcal{L}_D^{\text{rep}}$ for 600 steps. The velocity of each particle is proportional to the negative gradient of the applied loss with respect to that particle. In the second row, we took the particle positions at the 450th step of the first row and applied $\mathcal{L}_G^{\text{mmd}}$ for another 600 steps, again with velocity proportional to the negative gradient of the loss. The blue and orange arrows stand for the gradients of the attractive and repulsive components of the MMD losses, respectively. In summary, these GIFs indicate how MMD losses may move free particles. Of course, the actual case of MMD-GAN is much more complex, as we update the model parameters instead of the output scores directly, and both networks are updated at each step.
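A free-particle experiment of this kind can be sketched in NumPy. This is a minimal sketch, not the code in the Figures folder: it assumes a Gaussian kernel with `sigma=1`, a hypothetical step size `lr`, unbiased within-set kernel means, and that each particle moves by `lr` times the negative gradient of the applied loss. Every loss is written as $w_{xx}\mathbb{E}[k(\mathbf{x},\mathbf{x}')] + w_{xy}\mathbb{E}[k(\mathbf{x},\mathbf{y})] + w_{yy}\mathbb{E}[k(\mathbf{y},\mathbf{y}')]$ with hypothetical weight triples.

```python
import numpy as np

# Hypothetical weights (w_xx, w_xy, w_yy) for each loss.
LOSS_WEIGHTS = {
    "d_att": (-1.0, 2.0, -1.0),  # L_D^att = -L_G^mmd
    "d_rep": (1.0, 0.0, -1.0),   # L_D^rep
    "g_mmd": (1.0, -2.0, 1.0),   # L_G^mmd
}


def gauss_k(a, b, sigma):
    """Gaussian kernel matrix between particle sets a (n, d) and b (m, d)."""
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-d2 / (2.0 * sigma ** 2))


def kernel_grad_sum(a, b, sigma):
    """Row i holds sum_j dk(a_i, b_j)/da_i for the Gaussian kernel."""
    k = gauss_k(a, b, sigma)
    return -(a * k.sum(axis=1, keepdims=True) - k @ b) / sigma ** 2


def loss_grads(x, y, weights, sigma=1.0):
    """Gradients of the weighted loss w.r.t. real (x) and generated (y) particles."""
    w_xx, w_xy, w_yy = weights
    n, m = len(x), len(y)
    # E[k(x,x')] averages over i != j, so each particle appears in 2 ordered pairs.
    gx = (w_xx * 2.0 / (n * (n - 1)) * kernel_grad_sum(x, x, sigma)
          + w_xy / (n * m) * kernel_grad_sum(x, y, sigma))
    gy = (w_yy * 2.0 / (m * (m - 1)) * kernel_grad_sum(y, y, sigma)
          + w_xy / (n * m) * kernel_grad_sum(y, x, sigma))
    return gx, gy


def simulate(x, y, loss, steps=600, lr=0.1, move_real=True):
    """Move free particles with velocity -lr * gradient of the chosen loss."""
    x, y = x.astype(float).copy(), y.astype(float).copy()
    for _ in range(steps):
        gx, gy = loss_grads(x, y, LOSS_WEIGHTS[loss])
        if move_real:
            x -= lr * gx
        y -= lr * gy
    return x, y
```

For example, `simulate(x0, y0, "d_rep")` spreads the real particles apart and pulls the generated ones together, while `simulate(x0, y0, "g_mmd", move_real=False)` moves only the generated particles (a choice made here for illustration).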

We argue that $\mathcal{L}_D^{\text{att}}$ may cause opposite gradients from the attractive and repulsive components of both $\mathcal{L}_D^{\text{att}}$ and $\mathcal{L}_G^{\text{mmd}}$ during training, and thus slow down training. Note that this differs from the end stage of training, when the gradients should be opposite and cancel out to reach 0. Another interpretation is that, by minimizing $\mathcal{L}_D^{\text{att}}$, the discriminator maximizes the similarity between the outputs of real samples, which results in D focusing on the similarities among real images and possibly ignoring the fine details that separate them. The repulsive loss actively learns such fine details to make real sample outputs repel each other.
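The second interpretation follows from expanding both discriminator losses in terms of the kernel expectations used in this item. The worked step below is a derivation added for clarity: the real-pair term enters $\mathcal{L}_D^{\text{att}}$ with a negative sign (minimizing it raises $k_D(\mathbf{x},\mathbf{x}')$, so real outputs attract), whereas it enters $\mathcal{L}_D^{\text{rep}}$ with a positive sign (real outputs repel).

```latex
\begin{aligned}
\mathcal{L}_D^{\text{att}} &= -\mathcal{L}_G^{\text{mmd}}
  = -\mathbb{E}_{P_X}[k_D(\mathbf{x},\mathbf{x}')]
    + 2\,\mathbb{E}_{P_X,P_G}[k_D(\mathbf{x},\mathbf{y})]
    - \mathbb{E}_{P_G}[k_D(\mathbf{y},\mathbf{y}')] \\
\mathcal{L}_D^{\text{rep}} &= \mathbb{E}_{P_X}[k_D(\mathbf{x},\mathbf{x}')]
    - \mathbb{E}_{P_G}[k_D(\mathbf{y},\mathbf{y}')] \\
\mathcal{L}_D^{\text{rep}} - \mathcal{L}_D^{\text{att}}
  &= 2\,\mathbb{E}_{P_X}[k_D(\mathbf{x},\mathbf{x}')]
    - 2\,\mathbb{E}_{P_X,P_G}[k_D(\mathbf{x},\mathbf{y})]
\end{aligned}
```

The generated-pair term $-\mathbb{E}_{P_G}[k_D(\mathbf{y},\mathbf{y}')]$ is identical in both losses, so the two differ only in how they treat real pairs and real/generated pairs.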

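  2. Bounded kernel (used only in $\mathcal{L}_D^{\text{rep}}$)

$$k_D^{b}(\mathbf{x},\mathbf{x}') = \exp\Big(-\frac{1}{2\sigma^2}\max\big(\|D(\mathbf{x})-D(\mathbf{x}')\|^2,\ b_l\big)\Big)$$

$$k_D^{b}(\mathbf{y},\mathbf{y}') = \exp\Big(-\frac{1}{2\sigma^2}\min\big(\|D(\mathbf{y})-D(\mathbf{y}')\|^2,\ b_u\big)\Big)$$

where $\sigma$ is the kernel bandwidth and $b_l$, $b_u$ are lower and upper bounds on the squared distance.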

The gradient of the Gaussian kernel is near 0 when the input distance is too small or too large. The bounded kernel avoids kernel saturation by truncating the two tails of the distance distribution, an idea inspired by the hinge loss. This prevents the discriminator from becoming too confident.
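A minimal NumPy sketch of $\mathcal{L}_D^{\text{rep}}$ with the bounded kernel, following the two bounded-kernel formulas in this item (max with $b_l$ for real pairs, min with $b_u$ for generated pairs). The values `b_l=0.25` and `b_u=4.0` are illustrative placeholders, not values stated in this README, and the function names are hypothetical.

```python
import numpy as np


def sq_dists(a, b):
    """Squared Euclidean distances between rows of a (n, dim) and b (m, dim)."""
    return np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)


def mean_off_diag(k):
    """Mean over i != j (the diagonal is excluded)."""
    n = k.shape[0]
    return (k.sum() - np.trace(k)) / (n * (n - 1))


def rep_loss_bounded(d_real, d_fake, sigma=1.0, b_l=0.25, b_u=4.0):
    """L_D^rep computed with the bounded Gaussian kernel k_D^b.

    Real pairs: distances below b_l are raised to b_l, so k_D^b(x, x') never
    exceeds exp(-b_l / (2 sigma^2)) and is constant (zero gradient) there.
    Generated pairs: distances above b_u are clipped to b_u, so k_D^b(y, y')
    never drops below exp(-b_u / (2 sigma^2)) and is constant beyond b_u.
    """
    k_xx = np.exp(-np.maximum(sq_dists(d_real, d_real), b_l) / (2.0 * sigma ** 2))
    k_yy = np.exp(-np.minimum(sq_dists(d_fake, d_fake), b_u) / (2.0 * sigma ** 2))
    return mean_off_diag(k_xx) - mean_off_diag(k_yy)
```

As a consequence of the two clips, the loss is bounded above by $\exp(-b_l/2\sigma^2) - \exp(-b_u/2\sigma^2)$ no matter how the discriminator outputs are arranged.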

  3. Power iteration for convolution (used in spectral normalization)

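Finally, we proposed a method to calculate the spectral norm of a convolution kernel. At iteration $t$, for convolution kernel $W_c$, compute $\mathbf{u}^t = \text{conv}(W_c, \mathbf{v}^{t-1})$, $\hat{\mathbf{v}}^t = \text{transpose-conv}(W_c, \mathbf{u}^t)$, and $\mathbf{v}^t = \hat{\mathbf{v}}^t / \|\hat{\mathbf{v}}^t\|$. The spectral norm is estimated as $\sigma_W = \|\mathbf{u}^t\|$.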
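The power iteration above works on the convolution operator itself rather than on a reshaped kernel matrix. Below is a minimal NumPy sketch, assuming a stride-1 convolution without padding ("valid") and a single input of shape `(c_in, H, W)`; the loop-based `conv2d` and its adjoint `conv2d_transpose` are hypothetical helpers written for clarity, not the repository's TensorFlow ops.

```python
import numpy as np


def conv2d(w, v):
    """Stride-1, 'valid' multi-channel cross-correlation.

    w: kernel (c_out, c_in, kh, kw); v: input (c_in, H, W).
    Returns an array of shape (c_out, H - kh + 1, W - kw + 1).
    """
    c_out, _, kh, kw = w.shape
    _, h, wd = v.shape
    ho, wo = h - kh + 1, wd - kw + 1
    out = np.zeros((c_out, ho, wo))
    for i in range(kh):
        for j in range(kw):
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], v[:, i:i + ho, j:j + wo])
    return out


def conv2d_transpose(w, u, in_shape):
    """Adjoint of conv2d: maps (c_out, ho, wo) back to in_shape = (c_in, H, W)."""
    kh, kw = w.shape[2:]
    _, ho, wo = u.shape
    v = np.zeros(in_shape)
    for i in range(kh):
        for j in range(kw):
            v[:, i:i + ho, j:j + wo] += np.einsum("oc,ohw->chw", w[:, :, i, j], u)
    return v


def conv_spectral_norm(w, in_shape, n_iter=100, seed=0):
    """Estimate the largest singular value of the convolution operator (n_iter >= 1)."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=in_shape)
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        u = conv2d(w, v)                          # u^t = conv(W_c, v^{t-1})
        v_hat = conv2d_transpose(w, u, in_shape)  # v_hat^t = transpose-conv(W_c, u^t)
        v = v_hat / np.linalg.norm(v_hat)         # v^t = v_hat^t / ||v_hat^t||
    return np.linalg.norm(u)                      # sigma_W = ||u^t||
```

Because $\mathbf{v}^{t-1}$ has unit norm, $\|\mathbf{u}^t\|$ never exceeds the true spectral norm and approaches it as $\mathbf{v}^t$ converges to the top right-singular vector. During training, one would typically keep $\mathbf{v}$ between steps and run only a few iterations per step.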

Practical Tricks and Issues

We recommend using the following tricks.

  1. Spectral normalization, initially proposed in [3]. The idea is, at each layer, to use $W/\sigma_W$ for the convolution/dense multiplication, where $\sigma_W$ is the spectral norm of $W$. Here we multiply the signal by a constant $C$ after each spectral normalization to compensate for the decrease of the signal norm at each layer. In the main text of paper [1], the value of $C$ was chosen empirically; in Appendix C.3 of paper [1], we tested a variety of values.
  2. Two time-scale update rule (TTUR) [4]. The idea is to use different learning rates for the generator and discriminator.
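For a dense layer, spectral normalization followed by the constant $C$ can be sketched as below. This is a minimal NumPy version, assuming the standard power-iteration scheme of [3] with a persistent vector `u` carried between training steps; the function name and the default `coeff=1.0` are hypothetical placeholders (this README does not state the value of $C$ used in [1]). TTUR is not shown; it only means passing different learning rates to the generator and discriminator optimizers.

```python
import numpy as np


def spectral_normalize_dense(w, u, coeff=1.0, n_iter=1):
    """Return (coeff * w / sigma_w, updated u) for a dense weight w of shape (out, in).

    u: current estimate of the top left-singular vector, shape (out,), kept
       between training steps so that one iteration per step is usually enough.
    coeff: plays the role of the constant C applied after normalization.
    """
    for _ in range(n_iter):
        v = w.T @ u
        v /= np.linalg.norm(v)
        u = w @ v
        u /= np.linalg.norm(u)
    sigma_w = u @ w @ v  # estimate of the largest singular value of w
    return coeff * w / sigma_w, u
```

Because a dense layer is linear, scaling the normalized weight by `coeff` is equivalent to multiplying the layer's pre-bias output by $C$.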

Unlike the case of Wasserstein GAN, we do not recommend using the repulsive loss of the discriminator or the MMD loss of the generator as an indicator of training progress. You may find that, during training,

  • both $\mathcal{L}_D^{\text{rep}}$ and $\mathcal{L}_G^{\text{mmd}}$ may be close to 0 initially; this is because both G and D are weak.
  • $\mathcal{L}_G^{\text{mmd}}$ may gradually increase during training; this is because it becomes harder for G to generate high-quality samples and fool D (and G may not have the capacity to do so).

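For balanced and capable G and D, we would expect both $\mathcal{L}_D^{\text{rep}}$ and $\mathcal{L}_G^{\text{mmd}}$ to stay close to 0 during the whole training process, and every kernel value (i.e., $k_D(\mathbf{x},\mathbf{x}')$, $k_D(\mathbf{x},\mathbf{y})$ and $k_D(\mathbf{y},\mathbf{y}')$) to stay away from 0 and 1, in the middle range (e.g., 0.6).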
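One way to follow this advice is to log the mean kernel values for each minibatch and flag any that drift toward 0 or 1. Below is a minimal NumPy sketch, assuming a single Gaussian kernel on discriminator outputs; the function names and the saturation thresholds `low=0.05` and `high=0.95` are hypothetical choices, not values from the paper.

```python
import numpy as np


def kernel_stats(d_real, d_fake, sigma=1.0):
    """Mean kernel values k_D(x,x'), k_D(x,y), k_D(y,y') for one minibatch."""
    def k(a, b):
        d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-d2 / (2.0 * sigma ** 2))

    def off_diag_mean(m):
        n = m.shape[0]
        return (m.sum() - np.trace(m)) / (n * (n - 1))

    return {
        "k_xx": off_diag_mean(k(d_real, d_real)),
        "k_xy": k(d_real, d_fake).mean(),
        "k_yy": off_diag_mean(k(d_fake, d_fake)),
    }


def saturated(stats, low=0.05, high=0.95):
    """Names of kernel statistics that are close to 0 or 1 (hypothetical thresholds)."""
    return [name for name, val in stats.items() if val < low or val > high]
```

For instance, if all discriminator outputs collapse to the same point, every kernel value is 1 and all three statistics are flagged.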

In some cases, you may find that training with the repulsive loss diverges. Do not panic. The learning rate may be unsuitable; please try other learning rates or the bounded kernel.

Final Comments

Thank you for reading!

Please feel free to leave comments if things do not work or suddenly work, or if exploring my code ruins your day. :)

Reference

[1] Wei Wang, Yuan Sun, Saman Halgamuge. Improving MMD-GAN Training with Repulsive Loss Function. ICLR 2019. URL: https://openreview.net/forum?id=HygjqjR9Km.
[2] Chun-Liang Li, Wei-Cheng Chang, Yu Cheng, Yiming Yang, and Barnabas Poczos. MMD GAN: Towards deeper understanding of moment matching network. In NeurIPS, 2017.
[3] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, and Yuichi Yoshida. Spectral normalization for generative adversarial networks. In ICLR, 2018.
[4] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. GANs Trained by a Two Time-Scale Update Rule Converge to a Nash Equilibrium. In NeurIPS, 2017.
