lumenpallidium / energy_transformer

PyTorch implementation of an energy transformer, an energy-based recurrent variant of the transformer.

License: MIT License

Python 100.00%
energy-transformer hopfield-network masked-image-modeling pytorch transformer

Introduction

Description

This is a tiny repository, but I have been reusing the energy transformer across multiple projects, so I wanted it to have a pip-installable home (for myself, and for anyone else who is interested).

This repository contains an implementation of energy transformers, which may be the only PyTorch implementation at the moment. The JAX implementation can be found here; this repository is a fairly direct port of it, with some consolidation and adaptation for PyTorch. The main file includes an example with the full self-supervised masked image reconstruction training used in the paper (except on CIFAR instead of ImageNet, for speed). This example is optional and requires some extra (common) packages that are not installed during setup.
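As a rough illustration of the masking step used in that kind of training (not this repository's actual code), here is a minimal sketch of random patch masking; the function name, mask-token convention, and mask ratio are all assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def mask_patches(patches, mask_ratio=0.5, mask_token=0.0):
    # Randomly replace a fraction of patch embeddings with a mask token;
    # a masked-image model is trained to reconstruct the originals.
    n = patches.shape[0]
    n_masked = int(n * mask_ratio)
    idx = rng.choice(n, size=n_masked, replace=False)
    masked = patches.copy()
    masked[idx] = mask_token
    return masked, idx

patches = rng.standard_normal((16, 32))   # 16 patch embeddings from one image
masked, idx = mask_patches(patches)
print(len(idx))  # 8
```

The model then sees `masked` and is penalized on how well it reconstructs the original patches at the indices in `idx`.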

Briefly, an energy transformer is a transformer variant that runs a form of attention in parallel with a Hopfield network. It is effectively recurrent: it acts on its input iteratively, descending the gradient of its energy function. The paper above contains the full mathematical details of the energy function. Note that, unlike a conventional transformer, this model has no feedforward layer: positional embeddings are added to the inputs, which are then normalized and passed through the network; the input is iteratively updated by subtracting the network output and running the residual through the network (including normalization) again.
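The iterative update described above can be sketched in a toy form. This is an illustrative sketch, not the package's implementation: `energy_grad` is a hypothetical stand-in for the gradient of the real attention-plus-Hopfield energy, and the step size and iteration count are arbitrary:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Per-token normalization (a stand-in for the model's LayerNorm).
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

def energy_grad(g):
    # Hypothetical placeholder for the gradient of the block's energy
    # (attention + Hopfield terms in the real model).
    return 0.5 * g

def energy_descent(x, pos_embedding, n_steps=12, alpha=1.0):
    # Add positional embeddings once, then iterate the recurrent update:
    # x <- x - alpha * grad E(norm(x)), i.e. gradient descent on the energy.
    x = x + pos_embedding
    for _ in range(n_steps):
        g = layer_norm(x)
        x = x - alpha * energy_grad(g)
    return x

rng = np.random.default_rng(0)
tokens = rng.standard_normal((4, 8))   # 4 tokens, embedding dim 8
pos = rng.standard_normal((4, 8))
out = energy_descent(tokens, pos)
print(out.shape)  # (4, 8)
```

The key structural point is that the same block is applied repeatedly to its own residual, rather than stacking distinct feedforward layers.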

The Modern Hopfield variants (SoftmaxModernHopfield and BaseModernHopfield) that are used in the energy transformer are also available for import.
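For intuition on what the softmax variant computes, here is a self-contained sketch of a softmax modern Hopfield network in the style of Ramsauer et al., not the classes exported by this package; the energy below omits additive constants, and `beta` plays the role of an inverse temperature:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def hopfield_energy(x, memories, beta=2.0):
    # Softmax ("modern") Hopfield energy, up to constants:
    # -(1/beta) * logsumexp(beta * M @ x) + 0.5 * ||x||^2
    scores = beta * memories @ x
    lse = np.log(np.exp(scores - scores.max()).sum()) + scores.max()
    return -lse / beta + 0.5 * x @ x

def hopfield_update(x, memories, beta=2.0):
    # One retrieval step: a softmax-weighted mixture of stored memories.
    return memories.T @ softmax(beta * memories @ x)

memories = np.eye(3)               # three stored patterns
query = np.array([0.9, 0.1, 0.0])  # noisy version of pattern 0
for _ in range(5):
    query = hopfield_update(query, memories, beta=8.0)
print(np.argmax(query))  # retrieves pattern 0
```

Iterating the update descends the energy and snaps the query onto the nearest stored pattern, which is the mechanism the energy transformer runs in parallel with attention.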

Installation

To install this package, run:

pip install git+https://github.com/LumenPallidium/energy_transformer.git

The only requirement is PyTorch (>=2.0). If you run the optional masked image reconstruction example, you will also need torchvision, einops, matplotlib, and tqdm. The pip install command above will pull in PyTorch, but I would recommend installing it yourself beforehand, so you can configure any necessary environments, CUDA tools, etc.

energy_transformer's People

Contributors: lumenpallidium

