pytorch-containers

This repository aims to help former Torchies more seamlessly transition to the "Containerless" world of PyTorch by providing a list of PyTorch implementations of Torch Table Layers.

Note: As a result of full integration with autograd, PyTorch requires networks to be defined in the following manner:

  • Define all layers to be used in the __init__ method of your network
  • Combine them however you want in the forward method of your network (avoiding in-place Tensor ops)

And that's all there is to it!

We will build upon a generic "TableModule" class that we initially define as:

import torch
import torch.nn as nn
from torch.autograd import Variable  # only needed on PyTorch < 0.4

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        self.layer1 = nn.Linear(5, 5).double()
        self.layer2 = nn.Linear(5, 10).double()

    def forward(self, x):
        ...
        ...
        ...
        return ...

ConcatTable

Torch

net = nn.ConcatTable()
net:add(nn.Linear(5, 5))
net:add(nn.Linear(5, 10))

input = torch.range(1, 5)
net:forward(input)

PyTorch

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        self.layer1 = nn.Linear(5, 5)
        self.layer2 = nn.Linear(5, 10)
        
    def forward(self, x):
        y = [self.layer1(x), self.layer2(x)]
        return y
        
input = Variable(torch.range(1, 5).unsqueeze(0))
net = TableModule()
net(input)

As you can see, PyTorch lets you apply each member module that would have been part of your Torch ConcatTable directly to the same input Variable. This offers much more flexibility as your architectures become more complex, and it is also a lot easier than remembering the exact functionality of ConcatTable or any of the other tables.

Two other things to note:

  • To work with autograd, we must wrap our input in a Variable (we can also pass a Python iterable of Variables)
  • PyTorch requires us to add a batch dimension, which is why we call .unsqueeze(0) on the input
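
On PyTorch 0.4 and later, Variable was merged into Tensor, so the wrapping step is no longer needed. The sketch below assumes a recent PyTorch release and uses the hypothetical name ConcatLike; it only illustrates the two notes above and is not part of the original repo.

```python
import torch
import torch.nn as nn

class ConcatLike(nn.Module):  # hypothetical name, mirrors TableModule above
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 5)
        self.layer2 = nn.Linear(5, 10)

    def forward(self, x):
        # Every layer receives the same input, as in Torch's ConcatTable
        return [self.layer1(x), self.layer2(x)]

# torch.arange(1., 6.) yields the floats 1..5; unsqueeze(0) adds the batch dimension
x = torch.arange(1., 6.).unsqueeze(0)  # shape (1, 5)
out = ConcatLike()(x)
shapes = [tuple(o.shape) for o in out]
print(shapes)  # [(1, 5), (1, 10)]
```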

ParallelTable

Torch

net = nn.ParallelTable()
net:add(nn.Linear(10, 5))
net:add(nn.Linear(5, 10))

input1 = torch.rand(1, 10)
input2 = torch.rand(1, 5)
output = net:forward{input1, input2}

PyTorch

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 10)

    def forward(self, x):
        y = [self.layer1(x[0]), self.layer2(x[1])]
        return y
        
input1 = Variable(torch.rand(1, 10))
input2 = Variable(torch.rand(1, 5))
net = TableModule()
output = net([input1, input2])
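
If the number of parallel branches is not fixed, one possible generalization is to hold them in an nn.ModuleList and pair each branch with its input. This is a minimal sketch assuming a recent PyTorch release; the class name ParallelLike is hypothetical.

```python
import torch
import torch.nn as nn

class ParallelLike(nn.Module):  # hypothetical name
    def __init__(self, branches):
        super().__init__()
        # nn.ModuleList registers each branch so its parameters are tracked
        self.branches = nn.ModuleList(branches)

    def forward(self, xs):
        # The i-th branch is applied to the i-th input
        return [branch(x) for branch, x in zip(self.branches, xs)]

net = ParallelLike([nn.Linear(10, 5), nn.Linear(5, 10)])
out = net([torch.rand(1, 10), torch.rand(1, 5)])
shapes = [tuple(o.shape) for o in out]
print(shapes)  # [(1, 5), (1, 10)]
```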

MapTable

Torch

net = nn.MapTable()
net:add(nn.Linear(5, 10))

input1 = torch.rand(1, 5)
input2 = torch.rand(1, 5)
input3 = torch.rand(1, 5)
output = net:forward{input1, input2, input3}

PyTorch

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        self.layer = nn.Linear(5, 10)
        
    def forward(self, x):
        y = [self.layer(member) for member in x]
        return y
        
input1 = Variable(torch.rand(1, 5))
input2 = Variable(torch.rand(1, 5))
input3 = Variable(torch.rand(1, 5))
net = TableModule()
output = net([input1, input2, input3])
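
Because the PyTorch version reuses a single layer, every input goes through the same weights. The following sketch, assuming a recent PyTorch release, checks this by feeding two identical inputs and counting the parameters of the one shared layer.

```python
import torch
import torch.nn as nn

layer = nn.Linear(5, 10)
x = torch.rand(1, 5)

# Apply the one layer to two copies of the same input
out = [layer(member) for member in [x, x.clone()]]

# Identical inputs give identical outputs because the weights are shared
same = torch.equal(out[0], out[1])

# Only one weight matrix (10x5) and one bias (10) exist
n_params = sum(p.numel() for p in layer.parameters())
print(same, n_params)  # True 60
```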

SplitTable

Torch

net = nn.SplitTable(2) # here we specify the dimension on which to split the input Tensor
input = torch.rand(2, 5)
output = net:forward(input)

PyTorch

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        
    def forward(self, x, dim):
        y = x.chunk(x.size(dim), dim)
        return y
        
input = Variable(torch.rand(2, 5))
net = TableModule()
output = net(input, 1)

Alternatively, we could have used torch.split() instead of torch.chunk(). See the docs.
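
The sketch below compares the two functions mentioned above on the same input, assuming a recent PyTorch release. It also shows torch.unbind(), which is not used in the original code: it removes the split dimension from each piece, which appears to be closer to what Torch's SplitTable returns.

```python
import torch

x = torch.rand(2, 5)
dim = 1

# chunk: the second argument is the number of pieces
chunks = x.chunk(x.size(dim), dim)   # 5 tensors, each of shape (2, 1)

# split: the second argument is the size of each piece
splits = torch.split(x, 1, dim)      # 5 tensors, each of shape (2, 1)

# unbind: like the above, but the split dimension is dropped
unbound = torch.unbind(x, dim)       # 5 tensors, each of shape (2,)

print(chunks[0].shape, splits[0].shape, unbound[0].shape)
```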

JoinTable

Torch

net = nn.JoinTable(1)
input1 = torch.rand(1, 5)
input2 = torch.rand(2, 5)
input3 = torch.rand(3, 5)
output = net:forward{input1, input2, input3}

PyTorch

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        
    def forward(self, x, dim):
        y = torch.cat(x, dim)
        return y
        
input1 = Variable(torch.rand(1, 5))
input2 = Variable(torch.rand(2, 5))
input3 = Variable(torch.rand(3, 5))
net = TableModule()
output = net([input1, input2, input3], 0)

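Note: torch.stack() is a related function, but it joins tensors along a new dimension and requires them all to be the same size, so it would not accept the differently sized inputs above. See the docs.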
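
To make the difference between torch.cat() and torch.stack() concrete, here is a small sketch assuming a recent PyTorch release. cat joins along an existing dimension, while stack creates a new one and therefore needs equally sized inputs.

```python
import torch

a, b, c = torch.rand(1, 5), torch.rand(2, 5), torch.rand(3, 5)

# cat: sizes may differ along the joined dimension
joined = torch.cat([a, b, c], 0)       # shape (6, 5)

# stack: all inputs must have the same size; a new dimension is added
x, y = torch.rand(2, 5), torch.rand(2, 5)
stacked = torch.stack([x, y], 0)       # shape (2, 2, 5)

# stack on the differently sized inputs raises an error
try:
    torch.stack([a, b, c], 0)
    stack_failed = False
except RuntimeError:
    stack_failed = True

print(joined.shape, stacked.shape, stack_failed)
```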

Math Tables

The math table implementations are fairly intuitive, so the Torch versions are omitted from this repo. As with the other tables, well-written descriptions and examples can be found in the official Torch docs.

PyTorch Math

Here we define one class that executes all of the math operations.

class TableModule(nn.Module):
    def __init__(self):
        super(TableModule, self).__init__()
        
    def forward(self, x1, x2):
        x_sum = x1 + x2  # could use Python's built-in sum() if the inputs were given as an iterable
        x_sub = x1 - x2
        x_div = x1 / x2
        x_mul = x1 * x2
        x_min = torch.min(x1, x2)
        x_max = torch.max(x1, x2)
        return x_sum, x_sub, x_div, x_mul, x_min, x_max

input1 = Variable(torch.range(1, 5).view(1, 5))
input2 = Variable(torch.range(6, 10).view(1, 5))
net = TableModule()
output = net(input1, input2)
print(output)

And we get:

(Variable containing:
  7   9  11  13  15
[torch.FloatTensor of size 1x5]
, Variable containing:
-5 -5 -5 -5 -5
[torch.FloatTensor of size 1x5]
, Variable containing:
 0.1667  0.2857  0.3750  0.4444  0.5000
[torch.FloatTensor of size 1x5]
, Variable containing:
  6  14  24  36  50
[torch.FloatTensor of size 1x5]
, Variable containing:
 1  2  3  4  5
[torch.FloatTensor of size 1x5]
, Variable containing:
  6   7   8   9  10
[torch.FloatTensor of size 1x5]
)
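
The comment on x_sum in the module above mentions inputs passed as a Python iterable. A minimal sketch of that case, assuming a recent PyTorch release and a hypothetical third input of ones:

```python
import torch

inputs = [
    torch.arange(1., 6.).view(1, 5),    # 1..5
    torch.arange(6., 11.).view(1, 5),   # 6..10
    torch.ones(1, 5),                   # hypothetical third input
]

# Python's built-in sum adds the tensors element-wise
total = sum(inputs)

# The same result with a single tensor op: stack, then reduce over the new dimension
stacked_sum = torch.stack(inputs).sum(0)

print(total)  # tensor([[ 8., 10., 12., 14., 16.]])
```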

The advantages that come with autograd when manipulating networks in these ways become much more apparent with more complex architectures, so let's combine some of the operations we defined above.

Intuitively Build Complex Architectures

Now we will look at a more complex example that combines several of the above operations. The graph below is an arbitrary network that I created using the Torch nngraph package. The Torch model definition using nngraph can be found here and a raw Torch implementation can be found here for comparison to the PyTorch code that follows.

class Branch(nn.Module):
    def __init__(self, b2):
        """
        Upon closer examination of the structure, note a
        MaxPool2d with the same params is used in each branch,
        so we can just reuse this and pass in the
        conv layer that is repeated in parallel right after
        it (reusing it as well).
        """
        super(Branch, self).__init__()
        self.b = nn.MaxPool2d(kernel_size=2, stride=2)
        self.b2 = b2

    def forward(self, x):
        x = self.b(x)
        y = [self.b2(x).view(-1), self.b2(x).view(-1)]  # pytorch 'ParallelTable'
        z = torch.cat((y[0], y[1]))  # pytorch 'JoinTable'
        return z

Now that we have a branch class general enough to handle both branches, we can define the base segments and piece it all together in a very natural way.

class ComplexNet(nn.Module):
    def __init__(self, m1, m2):
        super(ComplexNet, self).__init__()
        # define each piece of our network shown above
        self.net1 = m1 # segment 1 from VGG
        self.net2 = m2 # segment 2 from VGG
        self.net3 = nn.Conv2d(128, 256, kernel_size=3, padding=1) # last layer 
        self.branch1 = Branch(nn.Conv2d(64, 64, kernel_size=3, padding=1)) 
        self.branch2 = Branch(nn.Conv2d(128, 256, kernel_size=3, padding=1))
         
    def forward(self, x):
        """
        Here we see that autograd allows us to safely reuse Variables in 
        defining the computational graph.  We could also reuse Modules or even 
        use loops or conditional statements.
        Note: Some of this could be condensed, but it is laid out the way it 
        is for clarity.
        """
        x = self.net1(x)
        x1 = self.branch1(x) # SplitTable (implicitly)
        y = self.net2(x) 
        x2 = self.branch2(y) # SplitTable (implicitly)
        x3 = self.net3(y).view(-1)
        output = torch.cat((x1, x2, x3), 0) # JoinTable
        return output

The helper below builds our VGG-style conv layers in a loop. It is derived from pytorch/vision and is perhaps a little overkill for our small case.

def make_layers(params, ch):
    layers = []
    channels = ch
    for p in params:
        conv2d = nn.Conv2d(channels, p, kernel_size=3, padding=1)
        layers += [conv2d, nn.ReLU(inplace=True)]
        channels = p
    return nn.Sequential(*layers)

net = ComplexNet(make_layers([64, 64], 3), make_layers([128, 128], 64))

This documented python code can be found here.
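
As a quick sanity check, the sketch below runs a forward pass through a condensed copy of the classes above. It assumes a recent PyTorch release and a hypothetical 1x3x8x8 input; the input size is an illustration choice, not something the original specifies.

```python
import torch
import torch.nn as nn

def make_layers(params, ch):
    layers, channels = [], ch
    for p in params:
        layers += [nn.Conv2d(channels, p, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
        channels = p
    return nn.Sequential(*layers)

class Branch(nn.Module):
    def __init__(self, b2):
        super().__init__()
        self.b = nn.MaxPool2d(kernel_size=2, stride=2)
        self.b2 = b2

    def forward(self, x):
        x = self.b(x)
        # Same conv applied twice, flattened, then joined
        return torch.cat((self.b2(x).view(-1), self.b2(x).view(-1)))

class ComplexNet(nn.Module):
    def __init__(self, m1, m2):
        super().__init__()
        self.net1, self.net2 = m1, m2
        self.net3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.branch1 = Branch(nn.Conv2d(64, 64, kernel_size=3, padding=1))
        self.branch2 = Branch(nn.Conv2d(128, 256, kernel_size=3, padding=1))

    def forward(self, x):
        x = self.net1(x)
        x1 = self.branch1(x)
        y = self.net2(x)
        x2 = self.branch2(y)
        x3 = self.net3(y).view(-1)
        return torch.cat((x1, x2, x3), 0)

net = ComplexNet(make_layers([64, 64], 3), make_layers([128, 128], 64))
x = torch.rand(1, 3, 8, 8)  # hypothetical input size
with torch.no_grad():
    out = net(x)

# branch1: 2 * 64*4*4, branch2: 2 * 256*4*4, net3: 256*8*8
print(out.shape)  # torch.Size([26624])
```

The output is a flat vector because every path is flattened with .view(-1) before the final torch.cat.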
