
lucidrains / glom-pytorch



License: MIT License

Python 100.00%
artificial-intelligence deep-learning geoffrey-hinton

glom-pytorch's Introduction

GLOM - Pytorch

An implementation of Glom, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.

Yannic Kilcher's video was instrumental in helping me understand this paper.

Install

$ pip install glom-pytorch

Usage

import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension of the column / level embeddings
    levels = 6,        # number of levels in each column
    image_size = 224,  # input image size
    patch_size = 14    # patch size (224 / 14 = 16 patches per side)
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch, patches, levels, dimension)

Pass the return_all = True keyword argument on forward, and you will be returned all the column and level states for every iteration (including the initial state, so iterations + 1 states in total). You can then attach losses to the outputs of any level at any time step.

It also gives you access to the level data across all iterations for clustering, in which one can look for the islands of agreement theorized in the paper. A clustering sketch follows the example below.

import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after the 7th iteration (index 0 is the initial state)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)
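
As a rough illustration of that inspection (the clustering itself is my own sketch, not part of the library), one could cluster the patch embeddings at a single level and iteration and look for contiguous regions of agreement:

import torch
from sklearn.cluster import KMeans

# reusing all_levels from the example above; pick one level (here the 5th of 6)
embeds = all_levels[7, 0, :, 4]                     # (256, 512) - (patches, dimension)
labels = KMeans(n_clusters = 8).fit_predict(embeds.detach().numpy())
islands = torch.from_numpy(labels).reshape(16, 16)  # back onto the 16 x 16 patch grid
print(islands)                                      # contiguous runs of one label ~ islands of agreement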

Denoising self-supervised learning for encouraging emergence, as described by Hinton

import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # top level embeddings after the 7th iteration (index 0 is the initial state)
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()

You can pass the column and level states back into the model to continue where you left off (for example, if you are processing consecutive frames of a slow video, as mentioned in the paper).

import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations

Appreciation

Thanks goes out to Cfoster0 for reviewing the code

Todo

  • contrastive / consistency regularization of top-ish levels (a sketch of the consistency variant follows)
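
A rough sketch of what the consistency variant might look like (purely illustrative; the augmentations and loss here are my own assumptions, not anything implemented in this repo):

import torch
import torch.nn.functional as F
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img = torch.randn(1, 3, 224, 224)
view1 = img + 0.1 * torch.randn_like(img)  # two lightly noised "views" stand in for real augmentations
view2 = img + 0.1 * torch.randn_like(img)

top1 = model(view1, iters = 12)[:, :, -1]  # (1, 256, 512) - top-level states of each view
top2 = model(view2, iters = 12)[:, :, -1]

# consistency regularization: the top levels of both views should agree
consistency_loss = (1 - F.cosine_similarity(top1, top2, dim = -1)).mean()
consistency_loss.backward()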

Citations

@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network}, 
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

glom-pytorch's Issues

glom effectiveness

Hey -- I was wondering if you published any results about the effectiveness of the model or if you are running any experiments?

Implementing geometric mean for consensus opinion/levels_mean

Hi, I'm trying to implement the consensus opinion (levels_mean) as a geometric mean of the top-down predictions, bottom-up predictions, attention-weighted average of same-level embeddings, and embeddings of the previous time step, as described in the original paper. Any ideas on how the weights should be set?

At first I thought this could be a learnable parameter, but section 9.1 reads

For interpreting a static image with no temporal context, the weights used for this weighted geometric mean need to change during the iterations that occur after a new fixation.

which leads me to believe that these might need to be produced on the fly, à la vanilla attention, as opposed to being learned. Maybe an MLP that takes in the four source embeddings and outputs four scalars as weights?
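
That MLP idea could look something like this (a minimal sketch; the module name, shapes, and the positivity constraint are my assumptions, not anything from the paper or this repo):

import torch
from torch import nn

# hypothetical module: produce the four geometric-mean weights on the fly
# from the four source embeddings, as suggested above
class ConsensusWeights(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_weights = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.GELU(),
            nn.Linear(dim, 4)
        )

    def forward(self, top_down, bottom_up, attended, previous):
        # each source: (batch, patches, dim)
        sources = torch.stack((top_down, bottom_up, attended, previous), dim = -2)  # (b, p, 4, d)
        weights = self.to_weights(sources.flatten(-2)).softmax(dim = -1)            # (b, p, 4)
        # weighted geometric mean = exp of the weighted mean of logs; logs need
        # positive inputs, so real embeddings would have to be constrained first
        logs = sources.clamp(min = 1e-5).log()
        return torch.exp((weights.unsqueeze(-1) * logs).sum(dim = -2))              # (b, p, d)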

Bug in forward?

Hello, thank you for making this code available! I think there may be a bug in the first line of the forward function:

b, h, w, _, device = *img.shape, img.device

but the input image shape is of kind b c h w, so it could be fixed by replacing it with

b, _, h, w, device = *img.shape, img.device

Am I wrong?

Classification

Hi @lucidrains !
Do you have any idea/insight on how to supervise classification (let's say, for example, MNIST digits classification) after having trained GLOM in an unsupervised way as a denoising autoencoder?
In the paper that seems to be the final goal.
However, it's not clear to me which columns and/or levels should be used for the classification. Also, since GLOM is dealing with patches, how can single black patches vote towards a certain digit?

In other words, after training GLOM as a denoising autoencoder on MNIST, what we have is:

  • p × p columns, where p is the number of patches per dimension (e.g. 7 × 7 = 49 patches)
  • 6 levels for each column, where the top-most levels should in theory represent higher-level entities, so it seems natural to search for the digit information in these layers
  • 6*2=12 iterations, to allow for information to be passed by both top-down and bottom-up networks

Just applying dimensionality reduction to the top-most level at different iterations does not seem to be enough to make the digit clusters emerge. So I'm wondering if you (or anybody else) have some insights on this.
Cheers!
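
Not an answer from the paper, but one baseline worth trying (the pooling and probe here are my assumptions) is a linear probe over the pooled top-level embeddings:

import torch
from torch import nn
from glom_pytorch import Glom

# hypothetical MNIST-sized config: 28 x 28 images, 4 x 4 patches -> 7 x 7 = 49 columns
model = Glom(dim = 512, levels = 6, image_size = 28, patch_size = 4)
classifier = nn.Linear(512, 10)  # linear probe over 10 digit classes

img = torch.randn(1, 3, 28, 28)  # stand-in for an MNIST digit repeated to 3 channels
levels = model(img, iters = 12)  # (1, 49, 6, 512)

top = levels[:, :, -1].mean(dim = 1)  # mean-pool the top level across all 49 columns
logits = classifier(top.detach())     # (1, 10); detach to probe without finetuning GLOM

Mean-pooling lets every column contribute, so all-black patches merely add weak evidence instead of having to vote for a digit on their own.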

help

Hello, when I tried to reproduce your model, I got this error. I'm not sure how to correct it; can you help me?

Traceback (most recent call last):
  File "main.py", line 172, in <module>
    outputs = custom_model(images, iters = 12)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/class/glom_pytorch/glom_pytorch.py", line 109, in forward
    consensus = self.attention(levels)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/class/glom_pytorch/glom_pytorch.py", line 49, in forward
    sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)
RuntimeError: Expected object of scalar type Bool but got scalar type Float for argument #2 'mask' in call to _th_masked_fill_bool_
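
The error indicates masked_fill_ received a float mask where a BoolTensor is required, so a likely fix (an assumption on my part; the exact cause depends on your torch version) is to cast the mask at that line:

# in glom_pytorch/glom_pytorch.py, line 49 - masked_fill_ needs a boolean mask
sim.masked_fill_(self_mask.bool(), TOKEN_ATTEND_SELF_VALUE)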
