Comments (11)

wilson1yan commented on September 22, 2024

The stalling seems to be due to a weird bug when torch is imported after decord (link). I've updated requirements.txt to remove the torch dependency, and it now runs fine on GPU (tested on an A100 with CUDA 12.3). You will need to delete and recreate your environment, or uninstall torch/torchvision.
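If you'd rather not uninstall torch, one general mitigation for import-order bugs like this is to control when each module is loaded, e.g. by deferring the later import until first use. A minimal lazy-import sketch using stdlib stand-ins (this illustrates the general technique only, not a verified fix for the decord bug):

```python
import importlib

def lazy_import(name: str):
    """Return a zero-arg accessor that imports `name` on first call."""
    module = None
    def get():
        nonlocal module
        if module is None:
            module = importlib.import_module(name)
        return module
    return get

# Stand-in demo: defer loading `csv` until it is actually needed.
get_csv = lazy_import("csv")
reader_cls = get_csv().reader  # the module is imported here, not at startup
```

Applied here, the idea would be to defer decord's import until after torch has loaded, so the problematic order never occurs.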

I've also added more detailed installation instructions (which worked for me) to the README, also shown below:

conda create -n lwm python=3.10
conda activate lwm
pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt

from lwm.

Alpslee commented on September 22, 2024

> Using Python 3.10 instead of Python 3.11, I'm able to install the requirements as stated in the repo (i.e. tensorflow 2.11.0). However, I still run into the
>
> ImportError: cannot import name 'linear_util' from 'jax'
>
> error when trying to run run_vision_chat.sh

Do you use pyenv or conda? I switched to Python 3.10 and still got the error: Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)

anubhavashok commented on September 22, 2024

Updating flax, jax, chex and tux to the latest versions worked for me.

pip install flax -U
pip install tux -U
pip install chex -U

When updating jax, make sure to install the GPU-compatible version if you're using a GPU:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
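A quick way to confirm which versions pip actually resolved after these upgrades (useful when reporting what worked) is to query the installed distribution metadata. A small stdlib-only sketch; the package names simply mirror the commands above:

```python
from importlib import metadata

def installed_version(dist: str):
    """Return the installed version of a distribution, or None if it's absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None

# After upgrading, print the versions pip actually resolved:
for dist in ("flax", "jax", "chex", "tux"):
    print(dist, installed_version(dist) or "not installed")
```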

heyitsguay commented on September 22, 2024

Bumping tensorflow to 2.14.1 allows for successful environment setup. Downloading LWM-Chat-32K-Jax and attempting to run bash scripts/run_vision_chat.sh with a small test video raises the following error:

ImportError: cannot import name 'linear_util' from 'jax'

This is with flax==0.7.0, jax==0.4.24, and jaxlib==0.4.24.

Some fiddling around reveals that linear_util now lives in jax.extend rather than jax. Not sure if this is related to using a different tensorflow version, but I would appreciate some advice on getting this up and running outside a TPU environment! Thank you.

heyitsguay commented on September 22, 2024

Using Python 3.10 instead of Python 3.11, I'm able to install the requirements as stated in the repo (i.e. tensorflow 2.11.0). However, I still run into the

ImportError: cannot import name 'linear_util' from 'jax'

error when trying to run run_vision_chat.sh

wilson1yan commented on September 22, 2024

I think I ran into a similar hanging issue on GPU before, caused by something in the transformers package stalling (something involving FlaxSampleOutput). I didn't get a chance to dig into it at the time, since we were almost entirely on TPUs anyway, but I'll look into it now.

heyitsguay commented on September 22, 2024

Switching the tensorflow version to 2.12.1 (though I don't yet know whether that version works) creates another version conflict:

    tensorflow 2.12.1 depends on numpy<=1.24.3 and >=1.22
    chex 0.1.82 depends on numpy>=1.25.0
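The conflict is mechanical: tensorflow's upper bound on numpy sits below chex's lower bound, so no numpy version can satisfy both. A stdlib-only sketch making the empty intersection explicit (the bounds are copied from the pip output above; the candidate versions are illustrative):

```python
def parse(version: str) -> tuple:
    """Turn '1.24.3' into (1, 24, 3) for tuple comparison."""
    return tuple(int(part) for part in version.split("."))

def satisfies(version: str, lower=None, upper=None) -> bool:
    """Check version against inclusive lower/upper bounds."""
    v = parse(version)
    return ((lower is None or v >= parse(lower)) and
            (upper is None or v <= parse(upper)))

# tensorflow 2.12.1 needs numpy >=1.22,<=1.24.3; chex 0.1.82 needs numpy >=1.25.0
candidates = ["1.22.0", "1.24.3", "1.25.0", "1.26.4"]
compatible = [v for v in candidates
              if satisfies(v, "1.22", "1.24.3") and satisfies(v, "1.25.0")]
print(compatible)  # -> [] : no numpy version can satisfy both constraints
```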

heyitsguay commented on September 22, 2024

Taking the naive approach of going into each flax source file that has a from jax import linear_util as lu line and replacing it with from jax.extend import linear_util as lu brings the script to a new error:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

This one does have some hits on Google, such as this and this. These look like versioning issues, though the suggested fixes differ.
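As an alternative to editing flax's source files one at a time, the same effect can be had by registering a compatibility alias before flax is imported, so that `from jax import linear_util` resolves to the relocated module. A sketch of the technique, demonstrated with stdlib stand-ins since jax isn't assumed to be installed here:

```python
import sys

def alias_submodule(pkg_name: str, old_name: str, new_module) -> None:
    """Make `from <pkg_name> import <old_name>` resolve to new_module."""
    pkg = sys.modules.get(pkg_name)
    if pkg is not None:
        setattr(pkg, old_name, new_module)              # attribute lookup path
    sys.modules[f"{pkg_name}.{old_name}"] = new_module  # submodule lookup path

# Stand-in demo: pretend `json.scanner_old` is the pre-move name of json.scanner.
import json
import json.scanner
alias_submodule("json", "scanner_old", json.scanner)

from json import scanner_old  # resolves via the alias we just registered
assert scanner_old is json.scanner

# For the real case (untested assumption): import jax.extend.linear_util, then
# alias_submodule("jax", "linear_util", jax.extend.linear_util) before importing flax.
```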

Have you verified that LWM works outside a TPU environment? Can you share an example environment and params for run_vision_chat.sh that works on a Linux GPU or multi-GPU machine? Thank you!

jayavanth commented on September 22, 2024

> Do you use pyenv or conda? I switched to Python 3.10 and still got the error: Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)

Same issue with Python 3.10.13. It seems to work with Python 3.10.12 on Colab.

heyitsguay commented on September 22, 2024

> Do you use pyenv or conda? I switched to Python 3.10 and still got the error: Could not find a version that satisfies the requirement tensorflow==2.11.0 (from versions: 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.13.1, 2.14.0rc0, 2.14.0rc1, 2.14.0, 2.14.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)

I'm using Python 3.10.6 with neither pyenv nor conda, just installing everything in a venv. I presume each tensorflow release simply specifies a range of compatible Python versions.

Using tensorflow 2.14.1 doesn't seem to have affected the flax/jax errors I encountered down the line, though who knows whether it would cause other problems eventually.

heyitsguay commented on September 22, 2024

> Updating flax, jax, chex and tux to the latest versions worked for me.
>
> pip install flax -U
> pip install tux -U
> pip install chex -U
>
> When updating jax, make sure to install the GPU-compatible version if you're using a GPU:
>
> pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This worked for me! run_vision_chat.sh now runs, though it appears to hang after something completes. I get:

I0214 22:05:40.405170 140030517383168 xla_bridge.py:689] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0214 22:05:40.408375 140030517383168 xla_bridge.py:689] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-14 22:05:42.728080: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 1/1 [00:05<00:00,  5.89s/it]

with ~18846 MB of VRAM allocated on each of the server's 4 old GPUs (P40s), but no activity.
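On the large idle allocation: by default JAX preallocates a large fraction of each visible GPU's memory at startup, so high VRAM usage with no activity doesn't by itself mean work is happening. To make allocation reflect actual usage while debugging, JAX honors this environment variable (it must be set before jax is imported):

```python
import os

# Must be set before `import jax`. With preallocation disabled, JAX allocates
# GPU memory on demand instead of grabbing a large block up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])
```

Equivalently, export the variable in the shell before running the script: `XLA_PYTHON_CLIENT_PREALLOCATE=false bash scripts/run_vision_chat.sh`.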
