Comments (11)
The stalling seems to be due to a weird bug with importing torch after decord (link). I've updated requirements.txt to remove the torch dependency, and it now seems to run fine on GPU (tested on an A100, CUDA 12.3). You will need to delete and recreate your environment, or uninstall torch / torchvision.
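If you'd rather keep torch installed, the reported trigger is the import order, so importing torch before decord should sidestep the stall. A minimal sketch, assuming either package may be absent (missing ones are simply skipped):

```python
# Workaround sketch for the reported stall: import torch *before* decord.
# Neither package is required here; absent packages are skipped.
def import_in_safe_order():
    loaded = []
    for name in ("torch", "decord"):  # torch first avoids the reported hang
        try:
            __import__(name)
            loaded.append(name)
        except ImportError:
            pass
    return loaded
```

Calling this at the very top of your entry script keeps the ordering consistent no matter which module other code imports first.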
I also added to the README more detailed installation instructions that worked for me (also shown below):
conda create -n lwm python=3.10
conda activate lwm
pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
from lwm.
Using Python 3.10 instead of Python 3.11, I'm able to install the requirements as stated in the repo (i.e. tensorflow 2.11.0). However, I still run into the ImportError: cannot import name 'linear_util' from 'jax' error when trying to run run_vision_chat.sh.
Do you use pyenv or conda? I switched to Python 3.10 and still got the error: Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)
Updating flax, jax, chex and tux to the latest versions worked for me.
pip install flax -U
pip install tux -U
pip install chex -U
When updating jax, make sure to install the GPU-compatible version if you're using a GPU:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Bumping tensorflow to 2.14.1 allows for successful environment setup. Downloading LWM-Chat-32K-Jax and attempting to run bash scripts/run_vision_chat.sh with a small test video raises the following error:
ImportError: cannot import name 'linear_util' from 'jax'
This is with flax==0.7.0, jax==0.4.24, and jaxlib==0.4.24.
Some fiddling around reveals that linear_util lives in jax.extend but not in jax. Not sure if this is related to using the different tensorflow version, but I'd appreciate some advice on getting this up and running outside a TPU environment! Thank you.
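One way around this without editing flax's sources is to alias the moved module back onto the top-level jax namespace before flax is imported. A sketch (the helper name is mine), assuming a jax recent enough that linear_util lives under jax.extend:

```python
import importlib

# Alias jax.extend.linear_util back to jax.linear_util so older
# `from jax import linear_util as lu` lines in flax keep working.
def alias_linear_util():
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return False  # jax not installed; nothing to patch
    if not hasattr(jax, "linear_util"):
        jax.linear_util = importlib.import_module("jax.extend.linear_util")
    return True
```

Run it once at startup, before any `import flax`; `from X import Y` resolves attributes on the module object, so the alias is enough.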
Using Python 3.10 instead of Python 3.11, I'm able to install the requirements as stated in the repo (i.e. tensorflow 2.11.0). However, I still run into the ImportError: cannot import name 'linear_util' from 'jax' error when trying to run run_vision_chat.sh.
I think I ran into a similar hanging issue before on GPU, due to something in the transformers package stalling on some FlaxSampleOutput or other. I didn't get a chance to look that deeply into it before, since we were using almost all TPUs anyway, but I'll look into it now.
Switching the tensorflow version to 2.12.1 (not that I know whether that still works yet) creates another version conflict:
tensorflow 2.12.1 depends on numpy<=1.24.3 and >=1.22
chex 0.1.82 depends on numpy>=1.25.0
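As another comment in this thread notes, bumping tensorflow to 2.14.1 sets up successfully; tensorflow 2.14 no longer caps numpy at <=1.24.3, so it can coexist with chex 0.1.82's numpy>=1.25.0 floor. A hypothetical requirements fragment along those lines:

```
tensorflow==2.14.1   # relaxed numpy pin; reported in this thread to install cleanly
chex==0.1.82         # requires numpy>=1.25.0
```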
Taking the naive approach of going into each flax source file that has a from jax import linear_util as lu line and replacing it with from jax.extend import linear_util as lu brings the script to a new error:
AttributeError: module 'jax.random' has no attribute 'KeyArray'
This one does have some hits on Google, such as this and this. It looks to be a versioning issue, though the suggested fixes differ.
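Recent jax releases removed jax.random.KeyArray, while older flax still references it. The same shim trick as for linear_util works here: restore the name before importing flax. A sketch (helper name is mine), assuming a jax version in which PRNG keys are plain jax.Array values:

```python
# jax.random.KeyArray was removed in recent jax; older flax code still
# references it. Restore it as an alias of jax.Array before importing flax.
def restore_key_array():
    try:
        import jax
    except ImportError:
        return False  # jax not installed; nothing to patch
    if not hasattr(jax.random, "KeyArray"):
        jax.random.KeyArray = jax.Array
    return True
```

This is a stopgap; upgrading flax to a release that no longer uses KeyArray is the cleaner fix.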
Have you verified that LWM works outside a TPU environment? Can you share an example environment and params for run_vision_chat.sh that work on a Linux GPU or multi-GPU machine? Thank you!
Do you use pyenv or conda? I switched to Python 3.10 and still got the error: Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)
Same issue with Python 3.10.13. Seems to work with Python 3.10.12 on Colab
Do you use pyenv or conda? I switched to Python 3.10 and still got the error: Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)
Am using Python 3.10.6, neither pyenv nor conda, just installing everything in a venv. I presume tensorflow just specifies a range of compatible Python versions for each earlier release.
Using tensorflow 2.14.1 doesn't seem to have affected the flax/jax errors I encountered down the line, though who knows whether it would cause other problems eventually.
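The "could not find a version" error usually just means pip found no wheel for your interpreter: tensorflow publishes wheels per Python minor version, and 2.11.0's wheels covered Python 3.7 through 3.10. A quick stdlib check (the bounds here are for tensorflow 2.11.0 specifically, and this only checks the minor version, not wheel platform tags, which may also matter):

```python
import sys

# tensorflow ships wheels per Python minor version; 2.11.0 covered 3.7-3.10.
def python_supported(lo=(3, 7), hi=(3, 10)):
    return lo <= sys.version_info[:2] <= hi
```

If this returns False, pip will only offer whatever later tensorflow releases do have wheels for your interpreter, which matches the version list in the error above.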
Updating flax, jax, chex and tux to the latest versions worked for me.
pip install flax -U
pip install tux -U
pip install chex -U
When updating jax, make sure to install the GPU-compatible version if you're using a GPU:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
This worked for me! Now run_vision_chat.sh runs, though it appears to hang after something completes. I get:
I0214 22:05:40.405170 140030517383168 xla_bridge.py:689] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0214 22:05:40.408375 140030517383168 xla_bridge.py:689] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-14 22:05:42.728080: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 1/1 [00:05<00:00, 5.89s/it]
With ~18846MB of VRAM allocated on each of my 4 old GPUs on the server (P40s) but no activity.
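When VRAM is allocated but nothing progresses, a first step is confirming jax actually picked the CUDA backend and sees all four devices, rather than silently falling back to CPU. A small check (helper name is mine), assuming jax is importable:

```python
# Report which backend jax selected and how many devices it sees.
def report_backend():
    try:
        import jax
    except ImportError:
        return "jax not installed"
    return f"{jax.default_backend()}: {len(jax.devices())} device(s)"
```

On your machine this should report something like "gpu: 4 device(s)"; if it says "cpu", the jaxlib CUDA build isn't being picked up and the hang is likely just very slow CPU compilation.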