
Comments (11)

KerfuffleV2 commented on July 20, 2024

Some of the changes discussed here are now in #47; that doesn't include mmap or safetensors support (yet), but it could easily be added.

@BlinkDL

Can we save a strategy string in the .pth? Please test.

Yes, this seems to work fine. safetensors also has a facility for saving metadata, so it shouldn't be a problem there either.
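
For reference, a minimal sketch of both approaches (the _strategy key and file names are illustrative, not what the repo actually uses):

    import torch
    from safetensors.torch import save_file

    # Toy stand-in for a converted state dict.
    w = {"emb.weight": torch.zeros(4, 4)}

    # .pth: torch.save pickles the whole dict, so a plain string entry
    # survives alongside the tensors.
    w["_strategy"] = "cuda fp16i8 *15+ -> cuda fp16 *1"
    torch.save(w, "model-converted.pth")
    print(torch.load("model-converted.pth")["_strategy"])

    # safetensors: only tensors go in the file body, but save_file takes
    # an optional str->str metadata dict stored in the header.
    del w["_strategy"]
    save_file(w, "model-converted.safetensors",
              metadata={"strategy": "cuda fp16i8 *15+ -> cuda fp16 *1"})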

@Blealtan

The safetensors format has a lower-level interface, which might be more efficient and can load layer by layer directly onto any device.

You're almost certainly right. I don't know enough to do that kind of fancy stuff myself, but I could add the simplest, most naive approach to loading, and someone with a better idea of how the code actually works could make it more efficient later.
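
For the record, the naive version is basically a one-line swap (the file name is a placeholder):

    from safetensors.torch import load_file

    # Naive approach: materialize every tensor on one device up front and
    # hand the dict to the existing conversion code, as if it had come
    # from torch.load.
    w = load_file("model.safetensors", device="cpu")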

@alreadydone

Maybe ggerganov/llama.cpp#91 (comment) helps

Their motivation is basically the same as mine, but it's a bit of a different problem there since they do everything on CPU and never need to push the data to a GPU.

Basically, the only issue I saw with the mmap approach was when pinning memory for streaming to the GPU: that resulted in the memory being copied, so overall memory usage was higher. This was the same for both my hacky mmap code and safetensors, so I suspect PyTorch's internal loading mechanism does something special to avoid the problem.
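
The copy seems inherent to pinning: .pin_memory() returns a new tensor in a freshly allocated page-locked buffer, so an mmap-backed tensor ends up resident twice. A quick way to see it:

    import torch

    t = torch.zeros(1024, 1024)          # stand-in for an mmap-backed tensor
    p = t.pin_memory()                   # copies into a new page-locked buffer
    print(t.data_ptr() == p.data_ptr())  # False: the data now exists twice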


KerfuffleV2 commented on July 20, 2024

Actually, there's a simpler way to massively reduce load time and memory usage: just torch.save the model after the RWKV class finishes loading it. It can then be loaded again directly, as long as the conversion step is skipped.
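
Roughly like this (a sketch, assuming the converted tensors live in model.w as in the RWKV class; paths are examples):

    import torch
    from rwkv.model import RWKV

    # One-time conversion: pay the full load + strategy conversion cost once...
    model = RWKV(model="RWKV-4-Pile-7B-20230109-ctx4096",
                 strategy="cuda fp16i8 *15+ -> cuda fp16 *1")

    # ...then dump the already-converted weights for fast reloading later.
    torch.save(model.w, "RWKV-4-Pile-7B-converted.pth")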

For example, using strategy cuda fp16i8 *15+ -> cuda fp16 *1 with RWKV-4-Pile-7B-20230109-ctx4096.pth and running v2/chat.py until it's ready for user input:

Normal loading

    User time (seconds): 109.77
    System time (seconds): 98.65
    Percent of CPU this job got: 349%
    Elapsed (wall clock) time (h:mm:ss or m:ss): 0:59.66
    Maximum resident set size (kbytes): 21080216

Loading pre-converted with mmap

    User time (seconds): 14.84
    System time (seconds): 8.09
    Percent of CPU this job got: 108%
    Elapsed (wall clock) time (h:mm:ss or m:ss): 0:21.07
    Maximum resident set size (kbytes): 12911364

Loading pre-converted with torch.load

    User time (seconds): 14.20
    System time (seconds): 9.84
    Percent of CPU this job got: 102%
    Elapsed (wall clock) time (h:mm:ss or m:ss): 0:23.55
    Maximum resident set size (kbytes): 9198336

I don't know why, but the mmap approach actually uses more memory; probably something causes buffers to be copied instead of used in place. It's a little faster, but not worth the extra memory.

The pre-converted .pth file is about 7.3GB vs 14GB for the original. The mmap part may not be worth it, but "precompiling" the model for a given strategy is a huge advantage in both memory usage and load speed.


BlinkDL commented on July 20, 2024

Yes, that is my plan too, but I am working on other TODOs at the moment. It would be great if you could submit a pull request.

You can also use https://github.com/huggingface/safetensors :)


KerfuffleV2 commented on July 20, 2024

Thanks for the reply. Right now it's a little awkward since I'm stuck on an older version because of #38. Also, the current code is really nasty: it's basically cut-and-pasted from your code, which I don't fully understand.

I could create a draft pull request as a proof of concept if you want. I don't know if there are any issues with what I'm doing (it appears to work, but...).

Also, with some trial and error, I found that it's the pin_memory call that causes the mmap approach to use more memory. I'm not sure why it uses more overall, but if I comment it out, memory usage is about the same as with torch.load (I assume streaming to the GPU will be slower, though).


KerfuffleV2 commented on July 20, 2024

safetensors was really easy to drop in. Unfortunately, it seems to have the same issue as my mmap approach: it uses much more memory than loading via torch.load:

Via torch.load

Maximum resident set size (kbytes): 9088712

Via safetensors

Maximum resident set size (kbytes): 14935848

safetensors was a little faster (around 24 sec total vs. around 26 sec for torch.load), but it's a small difference.
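
For reference, the save side of the drop-in was essentially this (a sketch; it assumes the converted dict holds only tensors, since safetensors can't store anything else, and save_file also wants them contiguous):

    import torch
    from safetensors.torch import save_file

    w = torch.load("7B-converted.pth", map_location="cpu")
    # safetensors only stores tensors, and wants them contiguous.
    w = {k: v.contiguous() for k, v in w.items()}
    save_file(w, "7B-converted.safetensors")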

PyTorch must be doing something magical to pin those tensors without actually copying or allocating extra memory. I looked into it but couldn't find out what was going on.

Since memory seems to be the big bottleneck here, using torch.load unfortunately still seems best.

edit: These tests were done with the same data/strategy as above, so the 7B model was mostly converted to fp16i8.


KerfuffleV2 commented on July 20, 2024

@BlinkDL

Now that the latest version works for me, I'm able to proceed with this. I'm not completely sure what approach you want to take, though.

edit: Also, I think the memory increase for these two approaches only applies to .pin_memory(), which is (I think) only used when streaming. So you can potentially get the pros for free except in the streaming situation, where it becomes a tradeoff between memory usage and streaming performance when running the model.

mmap

Pros: Compatible with existing models. Loads faster than torch.load.

Cons: Custom code that will break if PyTorch changes its model format. Using mmap may be less portable. Modifying an mmapped file while it's in use will probably cause problems. Uses ~30% more memory than torch.load.

safetensors

Pros: Loads faster than torch.load. Safer than the pickle format.

Cons: Requires converting models to a new format. Uses ~30% more memory than torch.load.

There's probably no reason not to at least add support for loading safetensors models, though; the only disadvantage is adding a dependency on the safetensors package. Like I mentioned, I got it working trivially, so it's a super simple change. The only real question is whether to enable it by detecting a file extension like .safetensors on the model file, or by passing something to the RWKV constructor.
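
A minimal sketch of the extension-based dispatch (the helper name is hypothetical):

    import torch

    def load_weights(model_path: str):
        # Hypothetical helper: pick the loader from the file extension so
        # the RWKV constructor doesn't need an extra argument.
        if model_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            return load_file(model_path, device="cpu")
        return torch.load(model_path, map_location="cpu")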


The saving "precompiled" versions of models is a bit of a different issue. The big pro here is it allows loading models more much quickly and using a lot less memory doing the load process. Of course, it requires loading the full model first and then saving it. Additionally, unless other metadata is added, these models just won't work correctly if they're loaded again with a strategy that doesn't match the one they were saved with. One way would be to save the strategy in the file and basically just use that to override a strategy specified to the RWKV constructor or something like that.


BlinkDL commented on July 20, 2024

OK, how about we torch.save the converted weights, and then torch.load? :)

Can we save a strategy string in the .pth? Please test.


BlinkDL commented on July 20, 2024

@KerfuffleV2 Oh, and we can do it in a separate convert_model.py:
input: model, strategy
output: the saved converted model
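
A sketch of what that script could look like (argument names are illustrative, not necessarily the real script's interface; _strategy is the hypothetical key from above):

    # convert_model.py (sketch)
    import argparse

    import torch
    from rwkv.model import RWKV

    parser = argparse.ArgumentParser(description="Pre-convert a model for a strategy")
    parser.add_argument("--in", dest="model_in", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--strategy", required=True)
    args = parser.parse_args()

    # Loading through the RWKV class performs the strategy conversion.
    model = RWKV(model=args.model_in, strategy=args.strategy)
    w = model.w
    w["_strategy"] = args.strategy  # so loaders can detect/override later
    torch.save(w, args.out)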


alreadydone commented on July 20, 2024

Maybe ggerganov/llama.cpp#91 (comment) helps


Blealtan commented on July 20, 2024

The safetensors format has a lower-level interface, which might be more efficient and can load layer by layer directly onto any device. It might work better in certain cases.
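
Something along these lines, for reference (per-layer device routing would follow the strategy logic; file name and device are placeholders):

    from safetensors import safe_open

    # Lazy interface: the header is parsed once, then each tensor is read
    # on demand and materialized directly on the requested device.
    with safe_open("model.safetensors", framework="pt", device="cuda:0") as f:
        for name in f.keys():
            t = f.get_tensor(name)  # could pick a device per layer instead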


BlinkDL commented on July 20, 2024

Update: ChatRWKV v2 & the pip rwkv package (0.7.0):
Use v2/convert_model.py to convert a model for a strategy, for faster loading & lower CPU RAM usage.

