attn_mask about coca-pytorch (OPEN)

lucidrains commented on July 29, 2024
attn_mask

from coca-pytorch.

Comments (3)

gshaikov-paige commented on July 29, 2024

@pldlgb We only mask the last row of sim because that row corresponds to the CLS token query. Without this mask, the CLS query would attend to all the keys before it, including PAD keys.

We don't need to mask the other queries. We don't care what PAD queries attend to, because their positions are masked out when we compute the CE loss. We also don't need to mask text queries: the causal mask already restricts them to looking backwards at earlier text tokens, and since padding sits at the end of the sequence, a real text token never sees a PAD key.
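
As a rough illustration of the point above, here is a minimal sketch of how the mask comes out for a toy batch. The F.pad call mirrors the one quoted in this thread; the pad_id value, the token ids, and the variable names are assumptions made up for the example, not taken from the repository.

```python
import torch
import torch.nn.functional as F

pad_id = 0  # assumed padding id, for illustration only

# Toy batch: the first sequence has two trailing PAD tokens, the second has none.
text = torch.tensor([[5, 7, 9, 0, 0],
                     [3, 4, 6, 8, 2]])  # (batch, seq)
batch, seq = text.shape

# True where the key is a real token, False where it is PAD.
cls_mask = (text != pad_id).unsqueeze(1)  # (batch, 1, seq)

# Last dim: add one column on the right (the CLS key itself, always visible).
# Second-to-last dim: add `seq` rows on top, all True, one per text query.
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)  # (batch, seq + 1, seq + 1)

# Only the last row (the CLS query) carries any False entries.
print(attn_mask[0, -1].tolist())  # [True, True, True, False, False, True]
```

Every row except the last is all True, so this mask on its own places no restriction on text or PAD queries; it only constrains the CLS query.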

skyerhxx commented on July 29, 2024

I have the same question. It seems that attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True) is not right.
Based on the original paper, attn_mask here should have the form of an inverted triangle, so that the feature at the current timestep cannot see features at future timesteps.

Discussion is welcome.

gshaikov-paige commented on July 29, 2024

@skyerhxx This is not the causal mask; it is a mask that prevents the CLS token from attending to PAD tokens in the batch.

We add PAD tokens to the text batch because text examples have different lengths but the tensor has a fixed dimension, so to concatenate them into a batch tensor one must pad the end of each shorter sequence with a dummy token, i.e. a PAD token. However, since we append the CLS token to the very end, it would attend to the entire sequence, including PAD tokens, which we don't want. So we mask them out.
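
To make the interaction with the causal mask concrete, the sketch below applies both masks to random attention logits and checks that the CLS row puts zero weight on PAD keys. It assumes the padding mask is applied with masked_fill before a softmax and that the causal mask is an upper-triangular boolean matrix; the shapes, pad_id, and names are illustrative assumptions, not the repository's exact code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

pad_id = 0  # assumed padding id, for illustration only
text = torch.tensor([[5, 7, 9, 0, 0]])  # one sequence with two trailing PADs
batch, seq = text.shape
n = seq + 1  # text tokens plus the CLS token appended at the end

# Padding mask: only the CLS (last) row can contain False.
cls_mask = (text != pad_id).unsqueeze(1)
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)  # (batch, n, n)

# Causal mask: True above the diagonal marks future keys to hide.
causal_mask = torch.ones(n, n, dtype=torch.bool).triu(1)

sim = torch.randn(batch, n, n)  # stand-in attention logits
neg = -torch.finfo(sim.dtype).max
sim = sim.masked_fill(~attn_mask, neg)   # hide PAD keys from the CLS query
sim = sim.masked_fill(causal_mask, neg)  # hide future keys from every query
attn = sim.softmax(dim=-1)

cls_row = attn[0, -1]
print(cls_row)  # the entries at the two PAD positions are zero
```

The two masks are complementary: the causal mask is the triangular one that hides future positions, while the padding mask only touches the CLS row, which would otherwise see every key because it sits last.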
