
Comments (1)

LarsHill commented on May 9, 2024

I actually encountered a similar scenario.
The standard Hugging Face bert-base-cased model, trained with 16-bit mixed precision (via pytorch-lightning), a vocab size of 100K, and a sequence length of 1024, uses around 34 GB of GPU memory at batch size 8 (on an Nvidia A100). If I switch to full 32-bit precision, the usage almost doubles to 67 GB, which is expected.
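For context, this is roughly how I toggle the precision (a minimal sketch; the "16-mixed"/"32-true" strings are the Lightning 2.x spelling, and the module and dataloader names are placeholders, not my actual setup):

import pytorch_lightning as pl

# 16-bit mixed precision run (torch.autocast under the hood)
trainer_fp16 = pl.Trainer(accelerator="gpu", devices=1, precision="16-mixed")

# full 32-bit run for comparison
trainer_fp32 = pl.Trainer(accelerator="gpu", devices=1, precision="32-true")

# trainer_fp16.fit(bert_module, train_loader)  # placeholder names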

However, if I use the x-transformers "BERT-like" implementation (mimicking the Hugging Face config):

from x_transformers import TransformerWrapper, Encoder

self.bert = TransformerWrapper(
    num_tokens=100_000,
    max_seq_len=1024,
    emb_dropout=0.1,
    tie_embedding=True,  # tie input and output token embeddings
    attn_layers=Encoder(
        dim=768,
        depth=12,
        heads=12,
        attn_flash=False,
        layer_dropout=0.1,  # stochastic depth - dropout entire layers
        attn_dropout=0.1,  # dropout post-attention
        ff_dropout=0.1,  # feedforward dropout
        use_abs_pos_emb=True,
    ),
)

the memory usage does not change when I switch between 16-bit mixed and full 32-bit precision. The overall usage (same batch size and hardware) stays constant at 52 GB, which is substantially higher than the HF model's 34 GB.
Why does the precision setting of the Lightning trainer not affect the x-transformers implementation?

I would love to use the x-transformers implementation because of its many new features. However, I am wondering where these significant GPU RAM differences come from, and why torch.autocast, which I believe Lightning uses under the hood, shows no effect.
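The sanity check I would run is a sketch like the following (assuming x-transformers is installed and a CUDA GPU is available; the small batch is only for the check): do one forward pass under torch.autocast, as Lightning does internally, and inspect the output dtype and peak memory.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens=100_000,
    max_seq_len=1024,
    attn_layers=Encoder(dim=768, depth=12, heads=12),
).cuda()

tokens = torch.randint(0, 100_000, (2, 1024), device="cuda")

torch.cuda.reset_peak_memory_stats()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(tokens)

# If autocast reaches the matmuls, the logits come back as float16.
print(logits.dtype)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GiB")

If the logits come back as float32 here, autocast is not being applied to the forward pass at all, which would be consistent with the constant memory usage I am seeing.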

