PonderALBERT

Hi! This is an experimental project that combines a variable-depth pretrained Transformer with a halting mechanism. For this project, I chose ALBERT (Lan et al.), whose depth can be varied because its layers share weights, and the halting mechanism proposed in the recent "PonderNet: Learning to Ponder" (Banino et al.).

For a detailed description of the halting mechanism, I suggest reading the PonderNet paper or watching the amazing video explanation made by Yannic Kilcher.
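As a quick orientation, the core idea in the PonderNet paper is that each pass emits a conditional halt probability, and these are combined into a distribution over the pass at which the model halts. The snippet below is a minimal plain-Python sketch of that calculation. The function name `halting_distribution` and the constant probabilities are illustrative assumptions, not part of this repository's API.

```python
def halting_distribution(halt_probs):
    """Turn per-pass conditional halt probabilities into P(halt at pass n).

    halt_probs[n] is the probability of halting at pass n, given that
    no earlier pass halted. Returns the unconditional distribution and
    the leftover probability of not having halted after the last pass.
    """
    dist = []
    still_running = 1.0  # probability that no earlier pass halted
    for lam in halt_probs:
        dist.append(still_running * lam)
        still_running *= 1.0 - lam
    return dist, still_running


# Illustrative values only: a constant halt probability of 0.2 over 5 passes
dist, leftover = halting_distribution([0.2] * 5)
print(dist)      # probability of halting at each pass
print(leftover)  # mass that remains if the model never halted
```

In practice the leftover mass has to go somewhere, for example by forcing a halt at the last allowed pass. How this repository handles it is not stated here.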

Usage

Model loading

from transformers import AlbertConfig, AlbertTokenizer
from ponder_albert.models import PonderAlbertClassifier

# A blank classifier with an Albert encoder can be initialized directly using an AlbertConfig object
# and a trained tokenizer
config = AlbertConfig(num_hidden_layers=12)
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = PonderAlbertClassifier(config, tokenizer, target_halt_probability=0.2)

# Alternatively, you can initialize the classifier from pretrained weights on the Hugging Face model hub.
# Since ALBERT shares weights across layers, you can still choose any number of layers
model = PonderAlbertClassifier.from_pretrained('albert-base-v2', num_hidden_layers=43,
                                               target_halt_probability=0.2)

Model training

import torch
from ponder_albert.losses import PonderClassificationLoss

# Sample dataset
texts = ['The cat sat on the mat', 'The mat was sat on by the cat']
labels = [0, 1]

# Set up the optimizer and the PonderNet criterion for text classification
optimizer = torch.optim.Adam(model.parameters())
criterion = PonderClassificationLoss(kl_penalty_factor=1e-2)
model.train()

# Single parameter update
optimizer.zero_grad()  # clear gradients left over from any previous step
prediction = model(texts)
loss = criterion(prediction, labels)['total_loss']
loss.backward()
optimizer.step()
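To give a feel for what a PonderNet-style criterion computes, here is a rough plain-Python sketch based on the paper: the task loss at each pass is weighted by the probability of halting there, and a KL term pulls the halting distribution toward a geometric prior. The helper names, the truncated and renormalised prior, and the `task_loss` and `kl_loss` keys are assumptions made for illustration. This is not the code of `PonderClassificationLoss`.

```python
import math


def geometric_prior(target_halt_probability, max_steps):
    """Geometric prior over halting passes, truncated and renormalised (an assumption)."""
    raw = [target_halt_probability * (1.0 - target_halt_probability) ** n
           for n in range(max_steps)]
    total = sum(raw)
    return [r / total for r in raw]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def ponder_loss(step_losses, halt_dist, prior, kl_penalty_factor=1e-2):
    """Expected task loss under the halting distribution plus a weighted KL penalty."""
    task_loss = sum(p * loss for p, loss in zip(halt_dist, step_losses))
    kl_loss = kl_divergence(halt_dist, prior)
    return {
        'task_loss': task_loss,  # hypothetical key
        'kl_loss': kl_loss,      # hypothetical key
        'total_loss': task_loss + kl_penalty_factor * kl_loss,
    }


# Made-up numbers: three passes, each with its own task loss
halt_dist = [0.5, 0.3, 0.2]
step_losses = [1.2, 0.8, 0.5]
prior = geometric_prior(0.2, max_steps=3)
print(ponder_loss(step_losses, halt_dist, prior))
```

Under this reading, `kl_penalty_factor` plays the role of the regularisation weight and `target_halt_probability` sets the prior, which in turn controls how many passes the model is encouraged to take.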

Inference with halting mechanism

During inference, the halting mechanism can stop the model before it reaches its full depth. Note that the halting mechanism is stochastic, so the number of passes (and therefore the output) can vary between calls on the same input.

model.eval()

# Let's try it once
model(['My cool new text'])

# {'logits': ...,
#  'halt_probabilities': ...,
#  'model_halt_dist': GeneralizedGeometricDist(),
#  'target_halt_dist': GeneralizedGeometricDist(),
#  'passes': 5
#  }

# ...now twice
model(['My cool new text'])

# {'logits': ...,
#  'halt_probabilities': ...,
#  'model_halt_dist': GeneralizedGeometricDist(),
#  'target_halt_dist': GeneralizedGeometricDist(),
#  'passes': 8
#  }
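The differing `passes` values above are what you would expect if halting is sampled pass by pass, as described in the PonderNet paper. A minimal sketch of that sampling loop, assuming a Bernoulli draw at each pass and a forced halt at the maximum depth (the function `sample_passes` and the constant probabilities are hypothetical, not taken from this repository):

```python
import random


def sample_passes(halt_probs, rng):
    """Return the 1-based pass at which a Bernoulli halting draw first succeeds.

    If no draw succeeds, halt at the last pass (forced halt at maximum depth).
    """
    for step, lam in enumerate(halt_probs, start=1):
        if rng.random() < lam:
            return step
    return len(halt_probs)


# Illustrative values only: 12 passes with a constant halt probability of 0.2
halt_probs = [0.2] * 12
rng = random.Random(0)  # seeded so a run can be repeated
print([sample_passes(halt_probs, rng) for _ in range(5)])
```

Fixing the random seed is one way to make repeated calls comparable when a deterministic result is needed.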

Contributors

piero2c
