
Tensorflow-Cookbook

Contributions

Currently, this repo contains general architectures and functions that are useful for GANs and classification.

I will continue to add useful things for other areas as well.

Also, your pull requests and issues are always welcome.

If you write what you want implemented in an issue, I'll implement it.

How to use

Import

  • ops.py
    • operations
    • from ops import *
  • utils.py
    • image processing
    • from utils import *

Network template

def network(x, is_training=True, reuse=False, scope="network"):
    with tf.variable_scope(scope, reuse=reuse):
        x = conv(...)
        
        ...
        
        return logit
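For example, a discriminator can score both real and fake batches by reusing the same variables on the second call (a minimal sketch; real_images and fake_images are hypothetical input tensors):

# the first call creates the variables, the second call reuses them
real_logit = network(real_images, is_training=True, reuse=False, scope="discriminator")
fake_logit = network(fake_images, is_training=True, reuse=True, scope="discriminator")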

Feed data to the network using the Dataset API

Image_Data_Class = ImageData(img_size, img_ch, augment_flag)

trainA_dataset = ['./dataset/cat/trainA/a.jpg', 
                  './dataset/cat/trainA/b.png', 
                  './dataset/cat/trainA/c.jpeg', 
                  ...]
trainA = tf.data.Dataset.from_tensor_slices(trainA_dataset)
trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16)
trainA = trainA.shuffle(buffer_size=10000).batch(batch_size).repeat().prefetch(buffer_size=1)

trainA_iterator = trainA.make_one_shot_iterator()
data_A = trainA_iterator.get_next()

logit = network(data_A)
  • See this for more information.
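A minimal sketch of driving this pipeline in a TF 1.x session (num_steps is a hypothetical training length; in practice you would run a train_op built on logit):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):
        # each sess.run advances the one-shot iterator by one batch
        logit_value = sess.run(logit)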

Option

  • padding='SAME'
    • pad = ceil((kernel - stride) / 2) (see the worked example after this list)
  • pad_type
    • 'zero' or 'reflect'
  • sn
    • spectral normalization; if sn=True, it is applied to the layer's weights
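For example, with kernel=3 and stride=2 the formula gives pad = 1, which is the value passed in the basic conv example below:

import math

kernel, stride = 3, 2
pad = math.ceil((kernel - stride) / 2)  # ceil(0.5) = 1, matches pad=1 in conv(...)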

Caution

  • If you don't want to share variables, give every scope a different name.

Weight

weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)

Initialization

  • Xavier : tf.contrib.layers.xavier_initializer()
  • He : tf.contrib.layers.variance_scaling_initializer()
  • Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
  • Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
  • Orthogonal : tf.orthogonal_initializer(1.0) # use gain = sqrt(2) for ReLU, 1.0 otherwise

Regularization

  • l2_decay : tf.contrib.layers.l2_regularizer(0.0001)
  • orthogonal_regularizer : orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)
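These are meant to be passed where a layer creates its variables; a minimal sketch with a raw tf.get_variable (the variable name and shape are illustrative):

w = tf.get_variable("w", shape=[3, 3, 64, 64],
                    initializer=weight_init,
                    regularizer=weight_regularizer)

The collected regularization terms can then be added to the total loss, e.g. via tf.losses.get_regularization_loss().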

Convolution

basic conv

x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')

partial conv (NVIDIA Partial Convolution)

x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')


dilated conv

x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='VALID', sn=True, scope='dilate_conv')

Deconvolution

basic deconv

x = deconv(x, channels=64, kernel=3, stride=1, padding='SAME', use_bias=True, sn=True, scope='deconv')

Fully-connected

x = fully_connected(x, units=64, use_bias=True, sn=True, scope='fully_connected')

Pixel shuffle

x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down')
x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')
  • down ===> [height, width] -> [height // scale_factor, width // scale_factor]
  • up ===> [height, width] -> [height * scale_factor, width * scale_factor]
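The shuffle itself just moves data between the spatial and channel dimensions; a minimal sketch of both directions using TensorFlow's built-in primitives (the conv_pixel_shuffle_* ops above presumably combine this with a convolution):

import tensorflow as tf

def pixel_shuffle_up_sketch(x, scale_factor=2):
    # [B, H, W, C] -> [B, H * scale_factor, W * scale_factor, C // scale_factor**2]
    return tf.depth_to_space(x, block_size=scale_factor)

def pixel_shuffle_down_sketch(x, scale_factor=2):
    # [B, H, W, C] -> [B, H // scale_factor, W // scale_factor, C * scale_factor**2]
    return tf.space_to_depth(x, block_size=scale_factor)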



Block

residual block

x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block')
x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down')
x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')
  • down ===> [height, width] -> [height // 2, width // 2]
  • up ===> [height, width] -> [height * 2, width * 2]

dense block

x = denseblock(x, channels=64, n_db=6, is_training=is_training, use_bias=True, sn=True, scope='denseblock')
  • n_db ===> the number of dense blocks

residual-dense block

x = res_denseblock(x, channels=64, n_rdb=20, n_rdb_conv=6, is_training=is_training, use_bias=True, sn=True, scope='res_denseblock')
  • n_rdb ===> the number of residual dense blocks (RDBs)
  • n_rdb_conv ===> the number of conv layers per RDB

attention block

x = self_attention(x, channels=64, use_bias=True, sn=True, scope='self_attention')
x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='self_attention_version_2')

x = squeeze_excitation(x, channels=64, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation')

x = convolution_block_attention(x, channels=64, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')



Normalization

x = batch_norm(x, is_training=is_training, scope='batch_norm')
x = layer_norm(x, scope='layer_norm')
x = instance_norm(x, scope='instance_norm')
x = group_norm(x, groups=32, scope='group_norm')

x = pixel_norm(x)

x = batch_instance_norm(x, scope='batch_instance_norm')
x = switch_norm(x, scope='switch_norm')

x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm')

x = adaptive_instance_norm(x, gamma, beta)
  • See this for how to use condition_batch_norm
  • See this for how to use adaptive_instance_norm
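adaptive_instance_norm replaces the per-channel statistics of x with externally supplied gamma and beta (typically predicted from a style code); a minimal sketch of the standard AdaIN formulation, assuming NHWC inputs:

import tensorflow as tf

def adaptive_instance_norm_sketch(x, gamma, beta, eps=1e-5):
    # normalize each sample and channel over its spatial dimensions ...
    mean, var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    x_norm = (x - mean) / tf.sqrt(var + eps)
    # ... then re-scale and re-shift with the style parameters
    return gamma * x_norm + beta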

Activation

x = relu(x)
x = lrelu(x, alpha=0.2)
x = tanh(x)
x = sigmoid(x)
x = swish(x)
x = elu(x)

Pooling & Resize

x = up_sample(x, scale_factor=2)

x = max_pooling(x, pool_size=2)
x = avg_pooling(x, pool_size=2)

x = global_max_pooling(x)
x = global_avg_pooling(x)

x = flatten(x)
x = hw_flatten(x)
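flatten collapses everything except the batch dimension, while hw_flatten keeps the channel axis ([B, H, W, C] -> [B, H*W, C]), which is the layout the self-attention block works on; a minimal sketch, assuming NHWC:

import tensorflow as tf

def hw_flatten_sketch(x):
    # [B, H, W, C] -> [B, H*W, C]; batch and channels stay, space is flattened
    return tf.reshape(x, shape=[tf.shape(x)[0], -1, x.get_shape().as_list()[-1]])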

Loss

classification loss

loss, accuracy = classification_loss(logit, label)

loss = dice_loss(logit, label, n_classes=10)

pixel loss

loss = L1_loss(x, y)
loss = L2_loss(x, y)
loss = huber_loss(x, y)
loss = histogram_loss(x, y)

loss = gram_style_loss(x, y)

loss = color_consistency_loss(x, y)
  • histogram_loss measures the difference in the color distribution of the image pixel values.
  • gram_style_loss measures the difference between styles using the Gram matrix.
  • color_consistency_loss measures the color difference between the generated image and the input image.
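A minimal sketch of the two simplest pixel losses, assuming the usual mean-reduced definitions:

import tensorflow as tf

def L1_loss_sketch(x, y):
    # mean absolute error over all pixels
    return tf.reduce_mean(tf.abs(x - y))

def L2_loss_sketch(x, y):
    # mean squared error over all pixels
    return tf.reduce_mean(tf.square(x - y))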

gan loss

d_loss = discriminator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit)
g_loss = generator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit)
  • Ra ===> if True, use the relativistic average GAN (RaGAN) formulation
  • loss_func
    • gan
    • lsgan
    • hinge
    • wgan-gp
    • dragan
  • See this for how to use gradient_penalty
d_bottleneck_loss = vdb_loss(real_mu, real_logvar, i_c) + vdb_loss(fake_mu, fake_logvar, i_c)
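To show the shape of these functions, a minimal sketch of the plain (non-relativistic) lsgan case; the repo's versions also handle Ra and the other loss_func choices:

import tensorflow as tf

def discriminator_loss_lsgan_sketch(real, fake):
    # least-squares GAN: push real logits toward 1 and fake logits toward 0
    return tf.reduce_mean(tf.squared_difference(real, 1.0)) + tf.reduce_mean(tf.square(fake))

def generator_loss_lsgan_sketch(fake):
    # push fake logits toward 1
    return tf.reduce_mean(tf.squared_difference(fake, 1.0))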

kl-divergence (z ~ N(0, 1))

loss = kl_loss(mean, logvar)
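For z ~ N(0, I) the KL term has the usual closed form; a minimal sketch, assuming mean and logvar parameterize a diagonal Gaussian and the result is averaged over the batch:

import tensorflow as tf

def kl_loss_sketch(mean, logvar):
    # KL( N(mean, exp(logvar)) || N(0, 1) ), summed over latent dims
    kl = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1.0 - logvar, axis=-1)
    return tf.reduce_mean(kl)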

Author

Junho Kim
