iesl / box-embeddings

Box Embeddings as Modules

Home Page: https://www.iesl.cs.umass.edu/box-embeddings

License: Apache License 2.0

Python 66.75% Makefile 2.02% Shell 0.04% CSS 6.92% HTML 0.49% Jupyter Notebook 23.79%
box-embedding deep-learning machine-learning

box-embeddings's People

Contributors

dhruvdcoder, mboratko, purujitgoyal, ssdasgupta, tejas4888, trangtran72

box-embeddings's Issues

Demo models

  1. Embedding WordNet, as in the Gumbel box paper.

  2. 2-class NLI using Multi-NLI.

Implement each one in TensorFlow and PyTorch, in separate notebooks in an examples folder at the root level.

Add document for steps before release.

Add a RELEASE.md describing steps that need to be taken before making a release.

  1. Update the version in setup.py
  2. Check if the test coverage is above a threshold.
  3. Check if the CHANGELOG is up-to-date.
  4. Create a tag matching the version in the setup.py. Copy the "unreleased" section of the CHANGELOG into the description for the tag and release.

Usage Documentation

Examples of the following:

  1. Train shallow box representations (can / should be on a simple toy dataset, e.g. mammals, birds, etc.; maybe do both BCE and max-margin).
  2. How to use it on the output of BERT (maybe NLI?).
  3. (Maybe more in the future.)

Probably do this in notebooks: https://mybinder.org/

Change to "temperature" everywhere

Currently we have both gumbel_beta and the beta in softplus, which is confusing.

We should change to using intersection_temperature (taking the place of current gumbel_beta) and volume_temperature (where 1/volume_temperature takes the place of current beta in softplus calculation).
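As a minimal sketch of the proposed naming (the helper soft_side_lengths is hypothetical and not part of the library), the softplus beta can be derived from volume_temperature as shown here. It relies only on torch.nn.functional.softplus, whose beta argument computes (1/beta) * log(1 + exp(beta * x)).

```python
import torch
import torch.nn.functional as F


def soft_side_lengths(z, Z, volume_temperature=1.0):
    """Hypothetical helper: softplus side lengths with beta = 1 / volume_temperature."""
    return F.softplus(Z - z, beta=1.0 / volume_temperature)


z = torch.tensor([0.0, 0.5])  # min corners
Z = torch.tensor([1.0, 0.2])  # max corners (second side is "inverted")
T = 0.5                       # volume_temperature

sides = soft_side_lengths(z, Z, volume_temperature=T)
# Same quantity written out by hand: T * log(1 + exp((Z - z) / T))
manual = T * torch.log1p(torch.exp((Z - z) / T))
```

With this convention a smaller temperature gives a sharper approximation to the hard side length max(Z - z, 0), which matches the usual meaning of "temperature".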

Exact Bessel Volume

Currently, the volume function in this repo is an approximation of the Bessel volume. I have previously tried to implement an exact version of the Bessel volume, but it was not numerically stable. Could you have a look at the code snippet below and see how it could be added to this repo?

The bessel function wrapper

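import torch
from scipy import special


class Bessel(torch.autograd.Function):
    """Modified Bessel function of the second kind, K0, with a custom gradient."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # SciPy works on CPU NumPy arrays, so evaluate there and move the result back.
        x = special.k0(input.detach().cpu().numpy())
        return torch.as_tensor(x, dtype=input.dtype, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # d/dx K0(x) = -K1(x)
        k1 = special.k1(input.detach().cpu().numpy())
        k1 = torch.as_tensor(k1, dtype=grad_output.dtype, device=grad_output.device)
        return grad_output * (-k1)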

The volume function

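    @classmethod
    def _log_bessel_volume(cls,
                           z: Tensor,
                           Z: Tensor,
                           gumbel_beta: float = 1.,
                           scale: Union[float, Tensor] = 1.) -> Tensor:
        # Smallest positive normal value of the dtype, used to keep log() finite.
        eps = torch.finfo(z.dtype).tiny
        # Accept a Python number or a tensor; match the dtype and device of z.
        s = torch.as_tensor(scale, dtype=z.dtype, device=z.device)
        # Argument of K0: 2 * exp(-(Z - z) / (2 * gumbel_beta)), clamped from above.
        element = (2*torch.exp((z-Z)/(2*gumbel_beta))).clamp_max(100)
        return (torch.sum(
            torch.log(2*gumbel_beta*Bessel.apply(element).clamp_min(eps)),
            dim=-1) + torch.log(s)
        )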

Subclass from `torch.Tensor`

One issue with subclassing will be the creation from zZ. We want to avoid the inverse-forward roundtrips whenever possible.

tf.Variable is not an instance of tf.Tensor

isinstance(tf.Variable([1]), tf.Tensor) returns False, whereas isinstance(tf.constant([1]), tf.Tensor) returns True. Support for tf.Variable will be required for computing gradients, as constants cannot be modified in TensorFlow. Currently the code works for the tf.constant([1]) case.

Return torch.sum instead of torch.mean for l2_side_regularizer

def l2_side_regularizer(
    box_tensor: BoxTensor, log_scale: bool = False
) -> Union[float, torch.Tensor]:
    z = box_tensor.z  # (..., box_dim)
    Z = box_tensor.Z  # (..., box_dim)
    if not log_scale:
        return torch.mean((Z - z) ** 2)
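    else:
        # NOTE: eps is not defined in this snippet; it is assumed to be a small
        # positive module-level constant that keeps the log finite.
        return torch.mean(torch.log(torch.abs(Z - z) + eps))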
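To illustrate what the requested change would affect, here is a small sketch comparing the reductions on a toy input. The per-box variant (sum over box_dim, then mean over the batch) is an assumption about one possible reading of the request, not something stated in the issue.

```python
import torch

z = torch.zeros(8, 16)        # (batch, box_dim) min corners
Z = torch.full((8, 16), 2.0)  # max corners; every side has length 2
sq = (Z - z) ** 2             # squared side lengths, all equal to 4

mean_penalty = torch.mean(sq)                   # averages over batch and box_dim
sum_penalty = torch.sum(sq)                     # grows with batch size and box_dim
per_box_penalty = torch.sum(sq, dim=-1).mean()  # assumed variant: total per box, averaged over batch
```

The mean stays at the per-side value no matter how many dimensions a box has, while either sum-based variant makes the penalty scale with box_dim.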

Embedding module

Currently we need a separate embedding layer whose output we then pass to some box parameterization. It would be nice to wrap this so that box embeddings can be created directly.
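A rough sketch of what such a wrapper could look like follows. The class name BoxEmbeddingSketch and the parameterization (min corner plus a softplus-positive side length) are illustrative assumptions, not the library's design; the sketch returns plain (z, Z) tensors rather than a BoxTensor.

```python
import torch
import torch.nn.functional as F
from torch import nn


class BoxEmbeddingSketch(nn.Module):
    """Hypothetical wrapper: one nn.Embedding row holds both corners of a box."""

    def __init__(self, num_embeddings: int, box_dim: int):
        super().__init__()
        self.box_dim = box_dim
        # Each row stores box_dim values for the min corner and box_dim for the side lengths.
        self.embedding = nn.Embedding(num_embeddings, 2 * box_dim)

    def forward(self, indices: torch.Tensor):
        theta = self.embedding(indices)          # (..., 2 * box_dim)
        z = theta[..., : self.box_dim]           # min corner
        # softplus keeps side lengths non-negative, so Z >= z by construction.
        Z = z + F.softplus(theta[..., self.box_dim:])
        return z, Z


emb = BoxEmbeddingSketch(num_embeddings=100, box_dim=16)
z, Z = emb(torch.tensor([[1, 2, 3], [4, 5, 6]]))
```

In the library itself, the returned corners would presumably be passed to whichever box parameterization is configured instead of being fixed as above.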

Update description

The current description mentions "Pytorch implementation for box..". It should be updated to mention TensorFlow too.

Create a wrapper for Conditional probability

We need this because of the numerical issues with PyTorch's softplus. When computing a conditional probability we need to concatenate, take the volume, and split again.

Optional step: Find the failure points: for which shapes and values does this happen?
First step: Profile the difference between applying softplus to two tensors separately and applying softplus to their concatenation and then splitting the result. Profile both runtime and memory usage.
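The concatenate, volume, split pattern described above can be sketched as follows. The function names, the softplus-based volume, and the convention log P(A | B) = log vol(A ∩ B) - log vol(B) are assumptions made for illustration; the corners of the intersection box are taken as given.

```python
import torch
import torch.nn.functional as F


def log_soft_volume(z, Z, volume_temperature=1.0):
    """Log of the product of softplus side lengths (illustrative volume)."""
    sides = F.softplus(Z - z, beta=1.0 / volume_temperature)
    tiny = torch.finfo(sides.dtype).tiny
    return torch.log(sides.clamp_min(tiny)).sum(dim=-1)


def log_conditional_prob(inter_z, inter_Z, b_z, b_Z, volume_temperature=1.0):
    """log P(A | B) from the intersection box and box B, using one softplus call."""
    n_inter, n_b = inter_z.shape[0], b_z.shape[0]
    # Concatenate along the batch dimension so the volume is computed in one pass.
    z = torch.cat((inter_z, b_z), dim=0)
    Z = torch.cat((inter_Z, b_Z), dim=0)
    log_vol = log_soft_volume(z, Z, volume_temperature)
    # Split the result back into the two groups.
    log_vol_inter, log_vol_b = torch.split(log_vol, [n_inter, n_b], dim=0)
    return log_vol_inter - log_vol_b


b_z = torch.zeros(4, 8)
b_Z = torch.ones(4, 8)
inter_z = b_z + 0.25  # an intersection box contained in B
inter_Z = b_Z - 0.25
log_p = log_conditional_prob(inter_z, inter_Z, b_z, b_Z)
```

Profiling this against two separate log_soft_volume calls, as the first step suggests, would show whether the single concatenated call is worth the extra copy.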

BoxTensor Indexing

Hi,

I'm new to the box embeddings. I'm wondering if the current implementation has some sort of indexing method? E.g. if a box tensor contains 64 16-dimensional boxes, what's the best way to access specific boxes in this tensor? I think constructing a new box tensor with .from_zZ() could work but just wondering if we can do this more efficiently.

Thanks!
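One way to think about the indexing question above is sketched here with a stand-in class. ToyBox is hypothetical and is not the library's BoxTensor; it only shows that indexing the two corner tensors with the same index yields the selected boxes, and that basic slicing returns views rather than copies.

```python
import torch


class ToyBox:
    """Hypothetical stand-in for a box tensor: stores min (z) and max (Z) corners."""

    def __init__(self, z: torch.Tensor, Z: torch.Tensor):
        self.z, self.Z = z, Z

    def __getitem__(self, idx):
        # Apply the same index to both corners; the box dimension stays last.
        return ToyBox(self.z[idx], self.Z[idx])


boxes = ToyBox(torch.zeros(64, 16), torch.ones(64, 16))  # 64 boxes, 16 dimensions
single = boxes[3]           # one box
subset = boxes[[0, 5, 7]]   # three boxes picked by index
first_ten = boxes[:10]      # a slice, backed by views of the original storage
```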

Make `gumbel_intersection` work with broadcasting semantic in box_shape.

Currently, we stack up the two BoxTensors to compute intersection using logsumexp.

torch.stack((t1.z / gumbel_beta, t2.z / gumbel_beta), 0)

This does not support broadcasting. For instance, if we have two BoxTensors with box_shape (batch1, 1, box_dim) and (1, batch2, box_dim), all the other intersection ops can produce an intersection box with box_shape (batch1, batch2, box_dim), but gumbel_intersection cannot. This operation is required for multilabel classification using boxes.

Plan:

  1. Implement a logsumexp that works with two tensors and that is API consistent with torch.maximum(). Call this operator real_softmax because it is essentially a differentiable max operation.

  2. Update the current gumbel_intersection to use real_softmax instead of stack + logsumexp.

  3. Write boundary test-cases to check for numerical issues (underflow as well as overflow).
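The plan above can be sketched roughly as follows. The names logsumexp2 and gumbel_intersection_sketch are placeholders, and the smooth max / smooth min form of the intersection is an assumption about what the stack + logsumexp code computes. Subtracting the elementwise maximum before exponentiating keeps the result finite for large inputs, and torch.maximum provides the broadcasting.

```python
import torch


def logsumexp2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise log(exp(a) + exp(b)) that broadcasts like torch.maximum."""
    m = torch.maximum(a, b)
    # exp(-|a - b|) is in (0, 1], so this never overflows.
    return m + torch.log1p(torch.exp(-torch.abs(a - b)))


def gumbel_intersection_sketch(z1, Z1, z2, Z2, gumbel_beta=1.0):
    """Smooth max of the min corners and smooth min of the max corners."""
    z = gumbel_beta * logsumexp2(z1 / gumbel_beta, z2 / gumbel_beta)
    Z = -gumbel_beta * logsumexp2(-Z1 / gumbel_beta, -Z2 / gumbel_beta)
    return z, Z


z1 = torch.rand(3, 1, 4)   # (batch1, 1, box_dim)
Z1 = z1 + 1.0
z2 = torch.rand(1, 5, 4)   # (1, batch2, box_dim)
Z2 = z2 + 1.0
z, Z = gumbel_intersection_sketch(z1, Z1, z2, Z2, gumbel_beta=0.1)

# Boundary check in the spirit of step 3: very large and very small inputs.
extreme = logsumexp2(torch.tensor([1e4]), torch.tensor([-1e4]))
```

Recent PyTorch versions also ship torch.logaddexp, which computes the same pairwise quantity and may be a simpler drop-in if the minimum supported version allows it.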

Visualization

General code for visualizing multiple boxes.
Two functions:

  1. One which can plot boxes in 2d
  2. One which plots boxes in n dimensions, using a horizontal / vertical stack of one dimensional boxes

For the API, we should look into "grammar of graphics".
