
lwm's People

Contributors

eltociear, lhao499, wilson1yan


lwm's Issues

Mistral

Hi,
LWM is incredible! Any plans to release a Mistral version?
Thanks!

AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

Hello, please help with this error when running vision_chat.py, with jax==0.4.23 and tux==0.0.2.
The specific error is as follows:
File "/home/LWM-main/lwm/vision_chat.py", line 12, in
from tux import (
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/init.py", line 17, in
from .optimizers import (AdamWOptimizerFactory, get_weight_decay_mask, optax_add_scheduled_weight_decay,
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/optimizers.py", line 193, in
class OptaxScheduledWeightDecayState(NamedTuple):
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/optimizers.py", line 194, in OptaxScheduledWeightDecayState
count: jnp.DeviceArray
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
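
jnp.DeviceArray was deprecated and then removed in the JAX 0.4 series; its replacement is jax.Array. Until tux publishes a fix, a common workaround is to patch the annotation locally, roughly like this (a sketch of the one affected class, not the whole file):

# Sketch of a local patch to tux/optimizers.py (line 194): the removed
# jnp.DeviceArray annotation becomes jax.Array.
from typing import NamedTuple

import jax

class OptaxScheduledWeightDecayState(NamedTuple):
    count: jax.Array  # was: count: jnp.DeviceArray

Pinning an older jax that still had the alias would also work, but the CUDA wheels used here are built against 0.4.23, so patching tux is the lighter fix.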

Memory requirements

It would be worth providing measured memory requirements for inference with the text models at 32K, 128K, 256K, 512K, and 1M token context windows, in both PyTorch and JAX.
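
A rough back-of-envelope while waiting for measured numbers: for a 7B LLaMA-style model with an unsharded bf16 KV cache (assumed 32 layers and 4096 hidden size; LWM's ring attention shards the cache across devices, so per-device numbers would be lower), the cache grows linearly with context:

# Hypothetical KV-cache estimate, not a measured LWM number.
def kv_cache_gib(tokens, layers=32, hidden=4096, bytes_per_elem=2):
    return 2 * layers * hidden * bytes_per_elem * tokens / 2**30  # K and V

for t in (32_768, 131_072, 262_144, 524_288, 1_048_576):
    print(f"{t:>9d} tokens -> ~{kv_cache_gib(t):.0f} GiB")
# ~16, 64, 128, 256, and 512 GiB respectively, on top of ~13 GiB of bf16 weights.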

Running into issues for mac M1

I am trying to run run_sample_video.sh from the scripts folder and am running into a lot of dependency issues on a Mac M1.
Has anyone been successful in running it on an M1?

vision chat error

Hi,

I'm trying to run run_vision_chat.sh but getting the following error:

(lwm) minyoung@claw2:~/Projects/LWM$ bash scripts/run_vision_chat.sh 
I0215 18:19:20.605390 140230836105600 xla_bridge.py:689] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0215 18:19:20.607900 140230836105600 xla_bridge.py:689] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-15 18:19:29.755994: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Traceback (most recent call last):
  File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/minyoung/Projects/LWM/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/minyoung/Projects/LWM/lwm/vision_chat.py", line 249, in main
    sampler = Sampler()
  File "/home/minyoung/Projects/LWM/lwm/vision_chat.py", line 42, in __init__
    self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
  File "/home/minyoung/Projects/LWM/lwm/llama.py", line 260, in get_jax_mesh
    return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp'))
  File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 140, in get_jax_mesh
    mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
ValueError: cannot reshape array of size 1 into shape (1,newaxis,32,1)

These are the model configs I used.

export llama_tokenizer_path="./LWM-Chat-1M-Jax/tokenizer.model"
export vqgan_checkpoint="./LWM-Chat-1M-Jax/vqgan"
export lwm_checkpoint="./LWM-Chat-1M-Jax/params"
export input_file="./traj0.mp4"
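
The reshape failure is a device-count mismatch: the error shows the mesh being reshaped to (1, -1, 32, 1), i.e. the script's default mesh assumes 32 devices, while jax.device_count() on this machine is 1. A quick way to sanity-check a mesh string, mirroring the reshape inside tux's get_jax_mesh:

import jax
import numpy as np

# The (dp, fsdp, tp, sp) grid must multiply out to jax.device_count().
# '1,-1,32,1' needs a multiple of 32 devices; on a single GPU use '1,1,1,1'.
dims = (1, -1, 1, 1)  # hypothetical single-device mesh
print(np.arange(jax.device_count()).reshape(dims).shape)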

Chat format?

Hello, what is the chat format for the chat models?

Is it the Llama format?

Getting memory error while running run_vision.sh

I0228 17:55:33.471474 139972342939648 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0228 17:55:33.473013 139972342939648 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory

0%| | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:05<00:00, 5.02s/it]
2024-02-28 17:56:12.936238: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 29.79GiB (rounded to 31985762560)requested by op
2024-02-28 17:56:12.936491: W external/tsl/tsl/framework/bfc_allocator.cc:497] ____************************************************************
Fatal Python error: Segmentation fault

Thread 0x00007f4b5e7fc640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Current thread 0x00007f4dd9c78000 (most recent call first):
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/compiler.py", line 256 in backend_compile
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/profiler.py", line 336 in wrapper
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/compiler.py", line 333 in compile_or_get_cached
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2528 in _cached_compilation
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2659 in from_hlo
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2219 in compile
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 1165 in _pjit_call_impl_python
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 1229 in call_impl_cache_miss
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 1245 in _pjit_call_impl
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/core.py", line 935 in process_primitive
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/core.py", line 447 in bind_with_trace
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/core.py", line 2740 in bind
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 168 in _python_pjit_helper
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 257 in cache_miss
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179 in reraise_with_filtered_traceback
  File "/home/azureuser/LWM1/LWM/lwm/vision_chat.py", line 230 in __call__
  File "/home/azureuser/LWM1/LWM/lwm/vision_chat.py", line 250 in main
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/absl/app.py", line 254 in _run_main
  File "/home/azureuser/LWM_env/lib/python3.10/site-packages/absl/app.py", line 308 in run
  File "/home/azureuser/LWM1/LWM/lwm/vision_chat.py", line 254 in <module>
  File "/usr/lib/python3.10/runpy.py", line 86 in _run_code
  File "/usr/lib/python3.10/runpy.py", line 196 in _run_module_as_main

Extension modules: PIL._imaging, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, jaxlib.cpu_feature_guard, charset_normalizer.md, yaml._yaml, msgpack._cmsgpack, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._flinalg, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._ansari_swilk_statistics, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, scipy.stats._unuran.unuran_wrapper, multidict._multidict, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, sentencepiece._sentencepiece (total: 129)
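
The segfault here follows a plain OOM: a 29.79 GiB buffer simply does not fit. Before changing the model config, two standard JAX/XLA allocator knobs (generic JAX settings, not LWM-specific) are worth trying; they must be set before jax is imported:

import os

# Standard XLA client allocator settings; set these before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.90"  # cap preallocation

import jax
print(jax.devices())

If the buffer still does not fit, reducing max_sequence_length or the scan chunk sizes in the launch script is the next lever.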

out of memory error

I ran bash scripts/run_vision_chat.sh with the --mesh_dim param removed, using the LWM-Chat-32K-Jax model, and got an out-of-memory error. How can I solve it?

My card is an NVIDIA 2080 Super (8 GB).

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708500656.672727   10871 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
I0221 15:30:57.202437 140383335174272 xla_bridge.py:513] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 15:30:57.202921 140383335174272 xla_bridge.py:513] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-21 15:36:18.340692: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.00GiB (rounded to 2147483648)requested by op 
2024-02-21 15:36:18.340908: W external/tsl/tsl/framework/bfc_allocator.cc:497] *________**********************************************************************_____________________
2024-02-21 15:36:18.340944: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.00GiB
     preallocated temp allocation:         0B
                 total allocation:    3.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 2.00GiB
                Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
                XLA Label: fusion
                Shape: f32[32,4096,4096]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                Entry Parameter Subshape: bf16[32,4096,4096]
                ==========================


jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 249, in main
    sampler = Sampler()
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 51, in __init__
    self._load_model()
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 199, in _load_model
    self.params = tree_apply(shard_fns, self.params)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in tree_apply
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in <lambda>
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 95, in shard_fn
    return jax_shard_function(tensor).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.00GiB
     preallocated temp allocation:         0B
                 total allocation:    3.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 2.00GiB
                Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
                XLA Label: fusion
                Shape: f32[32,4096,4096]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                Entry Parameter Subshape: bf16[32,4096,4096]
                ==========================


I0000 00:00:1708500978.900009   10871 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(lwm) test@test-3:/mnt/data/test/LWM$ nvidia-smi
Wed Feb 21 15:47:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 S...    Off| 00000000:01:00.0 Off |                  N/A |
|  0%   40C    P0               23W / 250W|      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
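
The peak-buffer report shows why an 8 GB card cannot get through loading: the float32 conversion at vision_chat.py:199 materializes a 2 GiB copy of each 1 GiB bf16 shard. Checking the arithmetic on the reported shapes:

# Sizes of the two peak buffers from the OOM report above.
n_elems = 32 * 4096 * 4096
print(n_elems * 4 / 2**30)  # Buffer 1, float32 -> 2.0 GiB
print(n_elems * 2 / 2**30)  # Buffer 2, bf16    -> 1.0 GiB
# A 7B-parameter model is ~13 GiB in bf16 alone, so it cannot fit in
# 8 GB of VRAM regardless of allocator settings.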

LlamaForCausalLM requires the PyTorch library but it was not found in your environment

Traceback (most recent call last):
  File "/output/LWM/scripts/sample_pyt.py", line 8, in <module>
    model = LlamaForCausalLM.from_pretrained(args.model)
  File "/usr/local/envs/lwm/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1112, in __getattribute__
    requires_backends(cls, cls._backends)
  File "/usr/local/envs/lwm/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1100, in requires_backends
    raise ImportError("".join(failed))
ImportError: 
LlamaForCausalLM requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.

Generated video: only the first frame has an image, the other frames are random pixels

I run bash scripts/run_sample_video.sh with the LWM-Chat-1M-Jax model; the script is:

...

python3 -u -m lwm.vision_generation \
    --prompt='A long big pig is walking across the street' \
    --output_file='fireworks.mp4' \
    --temperature_image=1.0 \
    --temperature_video=1.0 \
    --top_k_image=8192 \
    --top_k_video=1000 \
    --cfg_scale_image=5.0 \
    --cfg_scale_video=1.0 \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --n_frames=8 \
    --mesh_dim='!1,1,2,1' \
    --dtype='bf16' \
    --load_llama_config='7b' \
    --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=256,scan_key_chunk_size=256,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path"
read

After generation, only the first frame of the output video is meaningful; all the other frames are random pixels.

Incorrect input shape when reading .png files

Using run_vision_chat.sh with a .PNG image results in

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 250, in main
    output = sampler(prompts, FLAGS.max_n_frames)[0]
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 228, in __call__
    batch = self.construct_input(prompts, max_n_frames)
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 123, in construct_input
    vision = self._read_process_vision(prompt['input_path'], max_n_frames)
  File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 102, in _read_process_vision
    enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 53, in encode
    return self._encode(pixel_values)
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 35, in fn
    return self.model.apply(
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 122, in encode
    hidden_states = self.encoder(pixel_values)
  File "/mnt/vol_f/LWM/lwm/vqgan.py", line 155, in __call__
    hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)
  File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 429, in __call__
    kernel = self.param('kernel', self.kernel_init, kernel_shape,
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (3, 3, 3, 128) but got shape (3, 3, 4, 128) instead for parameter "kernel" in "/encoder/Conv_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

when the number of channels in the input is > 3 (i.e., when an alpha/transparency channel is present).
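
The mismatch ((3, 3, 4, 128) vs. the expected (3, 3, 3, 128)) is the alpha channel reaching the VQGAN's first Conv. A minimal workaround, assuming you can preprocess the image before the script reads it, is to force RGB:

import numpy as np
from PIL import Image

# RGBA PNGs load as (H, W, 4); the encoder's first Conv kernel expects
# 3 input channels, so drop the alpha channel up front.
img = np.asarray(Image.open("input.png").convert("RGB"))
print(img.shape)  # (H, W, 3)
Image.fromarray(img).save("input_rgb.png")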

Safetensors

Do you have any plans on creating safetensors for the models?

Mesh dim setting error

Hi, when I follow the script run_vision_chat.sh from #13 and remove --mesh_dim, I still hit this error:
attn_output = ring_attention_sharded(
ValueError: shard_map applied to the function 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

  • args[0] of shape float32[1,2560,32,128], where args[0] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

  • args[1] of shape float32[1,2560,32,128], where args[1] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

  • args[2] of shape float32[1,2560,32,128], where args[2] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

  • args[3] of shape float32[1,1,2560,2560], where args[3] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'attn_mask', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')' appropriately.

Any suggestions for a solution?
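
The constraint is purely arithmetic: with mesh (1, 8, 1, 1) the batch axis of q/k/v (size 1) is sharded over dp*fsdp = 8 devices, and 8 does not divide 1. A small pre-flight check against any candidate mesh, with the axis assignments copied from the PartitionSpecs in the error:

# q/k/v are [batch, seq, heads, head_dim] with spec (('dp','fsdp'), 'sp', 'tp', None).
def check_mesh(dp, fsdp, tp, sp, batch=1, seq=2560, heads=32):
    assert batch % (dp * fsdp) == 0, "batch not divisible by dp*fsdp"
    assert seq % sp == 0, "seq not divisible by sp"
    assert heads % tp == 0, "heads not divisible by tp"

# check_mesh(1, 8, 1, 1)  # fails: batch=1 vs dp*fsdp=8
check_mesh(1, 1, 1, 8)     # passes: 8-way sequence parallelism, 2560 % 8 == 0

So with 8 devices and batch size 1, a mesh like --mesh_dim='1,1,1,8' (all devices on the sequence axis) is the usual choice.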

ValueError: bytes is too large when running scripts/run_train_text.sh

Detailed message:
Traceback (most recent call last):
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/user_work_path/train/codes/LWM/lwm/train.py", line 396, in <module>
    run(main)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/user_work_path/train/codes/LWM/lwm/train.py", line 387, in main
    save_checkpoint(train_state, milestone=True)
  File "/user_work_path/train/codes/LWM/lwm/train.py", line 325, in save_checkpoint
    checkpointer.save_all(
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 102, in save_all
    self.save_checkpoint(
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 46, in save_checkpoint
    self.save_train_state_to_file(
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 78, in save_train_state_to_file
    fout.write(packer.pack((key, to_bytes(value))))
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 826, in pack
    self._pack(obj)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 803, in _pack
    self._pack(obj[i], nest_limit - 1)
  File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 750, in _pack
    raise ValueError("%s is too large" % type(obj).__name__)
ValueError: bytes is too large
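
For context, msgpack's bin format cannot encode a single bytes object of 2^32 bytes or more, and a milestone checkpoint of a 13B train state serializes some leaves past that limit. A minimal reproduction (note: it allocates 4 GiB):

import msgpack

# One byte past the bin-32 limit (2**32 - 1) raises the same error as
# the checkpoint save above.
try:
    msgpack.packb(b"\x00" * 2**32)
except ValueError as e:
    print(e)  # e.g. "bytes is too large"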

train script
#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export LLAMA_TOKENIZER_PATH=/user_work_path/tokenizer.model
export DATASET_PATH=/user_work_path/sample.jsonl
export SEED=1025

export PROJECT_ID='lwm'
export EXPERIMENT_NOTE=''
export EXPERIMENT_ID='example-text-train'
export OUTPUT_DIR=${PROJECT_DIR}/output

export COORDINATOR_ADDRESS=localhost:12345
export NUM_PROCESSES=1
export PROCESS_ID=0
export INITIALIZE_JAX_DISTRIBUTED=true

python3 -u -m lwm.train \
    --jax_distributed.coordinator_address ${COORDINATOR_ADDRESS} \
    --jax_distributed.initialize_jax_distributed ${INITIALIZE_JAX_DISTRIBUTED} \
    --jax_distributed.num_processes ${NUM_PROCESSES} \
    --jax_distributed.process_id ${PROCESS_ID} \
    --modality='text' \
    --mesh_dim='1,1,1,8' \
    --dtype='bf16' \
    --seed=${SEED} \
    --total_steps=10 \
    --log_freq=1 \
    --save_model_freq=0 \
    --save_milestone_freq=5 \
    --load_llama_config='13b' \
    --update_llama_config="dict(theta=10000,max_sequence_length=4096,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
    --tokenizer.vocab_file="$LLAMA_TOKENIZER_PATH" \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=8e-5 \
    --optimizer.adamw_optimizer.end_lr=8e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=5 \
    --optimizer.adamw_optimizer.lr_decay_steps=200 \
    --use_data_sharded_loader=True \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path="$DATASET_PATH" \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=8 \
    --train_dataset.json_dataset.tokenizer_processes=4 \
    --train_dataset.json_dataset.tokenizer_parallel_chunk_size=2 \
    --train_dataset.json_dataset.tokenizer_parallel_batch_size=8 \
    --train_dataset.json_dataset.use_data_sharded_loader=True \
    --checkpointer.save_optimizer_state=True \
    --autoresume=False \
    --logger.append_uuid=False \
    --logger.online=False \
    --logger.project_id="$PROJECT_ID" \
    --logger.experiment_id="$EXPERIMENT_ID" \
    --logger.experiment_note="$EXPERIMENT_NOTE" \
    --logger.output_dir="$OUTPUT_DIR" \
    --logger.wandb_dir="$HOME/experiment_output/$PROJECT_ID"
read

environment
Package Version


absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.3
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.17.1
decorator 5.1.1
decord 0.6.0
dill 0.3.8
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2023.10.0
gcsfs 2023.10.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.1
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib_resources 6.1.2
ipdb 0.13.13
ipython 8.22.1
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.16
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.1
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.54
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.10.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0

Could someone help me?

Requirements version error

Trying to install on an Ubuntu 22.04 system with pip 24.0 and python 3.11.

pip install -r requirements.txt yields the error: Could not find a version that satisfies the requirement tensorflow==2.11.0. The minimum version that shows up for me is 2.12.0rc0 (tensorflow 2.11 has no Python 3.11 wheels; 2.12 was the first release to support Python 3.11).

run_sample_video.sh config

What does each parameter in this file mean? Could you add some comments? If I want to generate higher-resolution videos, or longer videos, which parameters do I need to change? Also, are there any limits on video size or resolution?

NCCL Error when running the Jax LWM-Chat-1M-Jax

Environment

GPUs: 4x80G
Package Version


absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.13.0
decorator 5.1.1
decord 0.6.0
dill 0.3.6
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
gcsfs 2024.2.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.1
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib-resources 6.1.1
ipdb 0.13.13
ipython 8.21.0
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.14
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.0
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.53
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0

Error Message

I0222 09:24:21.054814 140683333334848 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0222 09:24:21.056322 140683333334848 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-22 09:24:21.097023: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
0%| | 0/1 [00:00<?, ?it/s]2024-02-22 09:25:32.707642: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.707708: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
2024-02-22 09:25:32.807973: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.808024: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
2024-02-22 09:25:32.821063: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.821116: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
2024-02-22 09:25:32.825532: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.825585: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
0%| | 0/1 [00:07<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/data/jeff/Git/LLM/src/T2V/LWM/lwm/vision_generation.py", line 258, in <module>
    run(main)
  File "/root/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/root/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/data/jeff/Git/LLM/src/T2V/LWM/lwm/vision_generation.py", line 184, in main
    img_enc, img = generate_first_frame(prompts, max_input_length=128)
  File "/data/jeff/Git/LLM/src/T2V/LWM/lwm/vision_generation.py", line 158, in generate_first_frame
    output, sharded_rng = _sharded_forward_generate(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed)

Error when running run_sample_img.sh

Hi, I hit this error while running the script run_sample_img.sh with mesh dimension 1,1,3,1, as I am using 3 A100 GPUs.
Here is the script:

export llama_tokenizer_path="LWM-Chat-1M-Jax/tokenizer.model"
export vqgan_checkpoint="LWM-Chat-1M-Jax/vqgan"
export lwm_checkpoint="LWM-Chat-1M-Jax/params"

python3 -u -m lwm.vision_generation \
    --prompt='Fireworks over the city' \
    --output_file='fireworks.png' \
    --temperature_image=1.0 \
    --top_k_image=8192 \
    --cfg_scale_image=5.0 \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --n_frames=1 \
    --mesh_dim='1,1,3,1' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path"

Here is the error info:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/lwm/vision_generation.py", line 258, in <module>
    run(main)
  File "/opt/conda/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/conda/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/workspace/lwm/vision_generation.py", line 94, in main
    model = FlaxVideoLLaMAForCausalLM(
  File "/workspace/lwm/vision_llama.py", line 145, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 213, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/workspace/lwm/vision_llama.py", line 170, in init_weights
    random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
  File "/workspace/lwm/vision_llama.py", line 401, in __call__
    outputs = self.transformer(
  File "/workspace/lwm/vision_llama.py", line 320, in __call__
    outputs = self.h(
  File "/workspace/lwm/llama.py", line 991, in __call__
    hidden_states, _ = nn.scan(
  File "/opt/conda/lib/python3.10/site-packages/flax/core/axes_scan.py", line 139, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/opt/conda/lib/python3.10/site-packages/flax/core/axes_scan.py", line 115, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/workspace/lwm/llama.py", line 766, in __call__
    attn_outputs = self.attention(
  File "/workspace/lwm/llama.py", line 657, in __call__
    attn_output = ring_attention_sharded(
ValueError: shard_map applied to the function 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f98223cb6d0>, axis_name='sp')' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 1, 3, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
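
The truncated message is the same shard_map divisibility check as in the earlier mesh issue: with mesh_dim='1,1,3,1' the 'tp' axis must evenly divide the attention-head axis, and a 7B LLaMA-style config has 32 heads (an assumption here), which 3 does not divide:

heads = 32  # assumed head count for the '7b' config
for tp in (1, 2, 3, 4):
    print(f"tp={tp}:", "ok" if heads % tp == 0 else "32 heads not divisible")
# tp=3 fails, so with exactly 3 GPUs you would have to shard a different
# axis instead, e.g. '1,1,1,3', provided the sequence length divides by 3.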

How long to train the model?

Thank you for your contribution! Amazing performance! I just wonder about the computational requirements for training such world models, e.g., how many GPUs did you use and how long did training take?

How to set up a conversation with vision chat?

I'm currently able to use run_vision_chat.sh with a limited number of video frames being passed in for a single text query. The text result is output from the model and then the process ends. However, the paper shows examples of a continuous dialogue about a video and I was wondering if it's possible to set this up.

Please provide example uses of the scripts

Only got image generation with the vision Jax model to work, and even then I had to remove the mesh_dim arg.

Everything else has failed.

E.g. needle fails like:

#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

export llama_tokenizer_path="LWM-Chat-1M-Jax/tokenizer.model"
export lwm_text_checkpoint="LWM-Chat-1M-Jax/params"
# jsonl file containing text for haystack. Each line should be a json
# with a single key "text" containing the text.
export haystack_file="../ultrachat_qa_mix_128K/data.jsonl"
export output_file="output"

python3 -u scripts/eval_needle.py \
    --mesh_dim='!1,-1,4,1' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(theta=10000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
    --load_checkpoint="params::$lwm_text_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    --max_tokens_per_batch=5000 \
    --output_file="$output_file" \
    --haystack_file="$haystack_file" \
    --context_lengths_min=1000 \
    --context_lengths_max=10000 \
    --n_context_length_intervals=20 \
    --n_document_depth_intervals=20 \
    --n_rounds=3
read
(lwm) jon@gpu:~/LWM$ bash scripts/run_eval_needle.sh
I0216 10:25:24.068257 139879088207680 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0216 10:25:24.070914 139879088207680 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Starting Needle In A Haystack Testing...
- Context Lengths: 20, Min: 1000, Max: 10000
- Document Depths: 20, Min: 0%, Max: 100%
- Needle: The special magic {city} number is: {rnd_number}



W0216 10:26:39.398258 139879088207680 _metadata.py:139] Compute Engine Metadata server unavailable on attempt 1 of 3. Reason: timed out
W0216 10:26:39.447406 139879088207680 _metadata.py:139] Compute Engine Metadata server unavailable on attempt 2 of 3. Reason: [Errno 113] No route to host
W0216 10:26:42.451228 139879088207680 _metadata.py:139] Compute Engine Metadata server unavailable on attempt 3 of 3. Reason: timed out
W0216 10:26:42.451697 139879088207680 _default.py:338] Authentication failed using Compute Engine authentication due to unavailable metadata server.
W0216 10:26:42.530295 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 1 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f4c0>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.607035 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 2 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430efb0>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.686556 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 3 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f130>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.767113 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 4 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f160>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.851304 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 5 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f7f0>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
completed 0
Traceback (most recent call last):
  File "/home/jon/LWM/scripts/eval_needle.py", line 447, in <module>
    run(main)
  File "/home/jon/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/jon/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/jon/LWM/scripts/eval_needle.py", line 444, in main
    ht.start_test()
  File "/home/jon/LWM/scripts/eval_needle.py", line 306, in start_test
    self.run_test()
  File "/home/jon/LWM/scripts/eval_needle.py", line 230, in run_test
    full_contexts = self.read_context_files(FLAGS.n_rounds)
  File "/home/jon/LWM/scripts/eval_needle.py", line 129, in read_context_files
    text = json.loads(f.readline())['text']
KeyError: 'text'

i.e. a specific haystack file format is required that isn't shared, and some access to Google (the Compute Engine metadata server) is attempted, which isn't explained.
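
Two separate things are happening here. The Compute Engine metadata warnings are just gcsfs/google-auth probing for GCP credentials and are harmless off Google Cloud. The KeyError is the haystack format: per the script's own comment, every line of haystack_file must be a JSON object with a single "text" key, which the ultrachat file evidently is not. A minimal generator for a synthetic haystack (filler content is arbitrary):

import json

# Write a haystack in the format scripts/eval_needle.py reads:
# one {"text": ...} JSON object per line.
with open("haystack.jsonl", "w") as f:
    for i in range(1_000):
        f.write(json.dumps({"text": f"Filler document {i}. " * 200}) + "\n")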

Great work! Any plan to train a smaller version, e.g. around 3B?

Hello,

It's really great work that contributes a lot to the community!

Do you have any plan to train a smaller version of the large world model (e.g., 1-3B), perhaps based on smaller models like Phi-2? It should be much easier and use fewer computing resources.

bash run_vision_chat.sh causes flax.errors.ScopeParamNotFoundError: Could not find parameter named "embedding" in scope "/transformer/wte"

While running bash scripts/run_vision_chat.sh, the error below happened. How can I fix it?

(lwm) llm@llm-PowerEdge-R730xd:~/projects/LWM-main$ bash scripts/run_vision_chat.sh
I0221 14:02:43.257625 139932541391232 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 14:02:43.260045 139932541391232 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
100%|██████████| 1/1 [00:05<00:00, 5.59s/it]
Traceback (most recent call last):
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 254, in
run(main)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 250, in main
output = sampler(prompts, FLAGS.max_n_frames)[0]
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 230, in call
output, self.sharded_rng = self._forward_generate(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
return common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers
)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 206, in fn
output = self.model.generate(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 429, in generate
return self._sample(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 733, in _sample
state = sample_search_body_fn(state)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 704, in sample_search_body_fn
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 232, in call
outputs = self.module.apply(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1511, in apply
return apply(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 934, in wrapper
y = fn(root, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 2082, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 401, in call
outputs = self.transformer(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 313, in call
input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 836, in _call_wrapped_method
self._try_setup()
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1094, in _try_setup
self.setup()
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 771, in setup
self.embedding = self.param('embedding',
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1263, in param
v = self.scope.param(name, init_fn, *init_args, unbox=unbox)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 842, in param
raise errors.ScopeParamNotFoundError(name, self.path_text)
flax.errors.ScopeParamNotFoundError: Could not find parameter named "embedding" in scope "/transformer/wte". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)
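
For anyone debugging this: below is a minimal sketch (an aid under stated assumptions, not an official fix) that restores the checkpoint into a plain parameter tree and lists its keys, so you can check whether `transformer/wte/embedding` exists and at what nesting. The path is a placeholder, and it assumes the checkpoint is readable with Flax's msgpack utilities; LWM's streaming checkpoints may require the repo's own loader instead.

`
# Hypothetical sketch: list the parameter paths in a restored Flax
# parameter tree to check for "transformer/wte/embedding".
# The checkpoint path below is a placeholder.
from flax import serialization, traverse_util

with open("params.msgpack", "rb") as f:  # placeholder path
    params = serialization.msgpack_restore(f.read())

for path in traverse_util.flatten_dict(params, sep="/"):
    print(path)
`

A common cause of this error is a prefix mismatch between the checkpoint tree and what the model expects: if the keys sit one level deeper or shallower than `/transformer/wte`, switching the `--load_checkpoint` prefix (for example `params::` versus a full-train-state prefix) may resolve it.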

pytorch model & ring attention

Thanks for sharing this excellent great work. We want to use pytorch models to try the effect of ring attention. Are there any plans to develop ring attention implementation under pytorch?

NCCL error when running LWM-Chat-32K-Jax

Environment
GPUs: 8x4090

Package Version

absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.13.0
decorator 5.1.1
decord 0.6.0
dill 0.3.6
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
gcsfs 2024.2.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.0
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib-resources 6.1.1
ipdb 0.13.13
ipython 8.21.0
jax 0.4.23
jaxlib 0.4.23+cuda11.cudnn86
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.14
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.0
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.53
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0

Error Message

I0223 22:58:05.579038 140312230876992 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0223 22:58:05.579842 140312230876992 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
100%|██████████| 1/1 [00:09<00:00, 9.21s/it]
2024-02-23 23:00:08.992159: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992208: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992237: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992261: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992281: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992348: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992392: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992430: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992459: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992472: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992483: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992499: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992510: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992522: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992536: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992551: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/users/yu01.xie/project/mllm/LWM-main/lwm/vision_chat.py", line 254, in
run(main)
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/users/yu01.xie/project/mllm/LWM-main/lwm/vision_chat.py", line 250, in main
output = sampler(prompts, FLAGS.max_n_frames)[0]
File "/home/users/yu01.xie/project/mllm/LWM-main/lwm/vision_chat.py", line 230, in call
output, self.sharded_rng = self._forward_generate(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
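
For anyone hitting this: the root cause is that jaxlib cannot load the NCCL library, so every cross-GPU collective fails. Two things stand out in the environment above: no `nvidia-nccl-*` wheel appears in the package list at all, and the CUDA 11 jaxlib build (`0.4.23+cuda11.cudnn86`) is paired with `nvidia-*-cu12` wheels, so even an installed CUDA 12 NCCL might not be found by a CUDA 11 jaxlib. Installing a jaxlib and NCCL built for the same CUDA major version is the likely fix. Below is a minimal smoke test (a sketch, not LWM-specific) that exercises the same NCCL code path:

`
# Multi-GPU collective smoke test: if NCCL can be loaded, the psum
# succeeds across all local GPUs; if not, it reproduces the
# "Unable to load NCCL library" failure above.
import jax
import jax.numpy as jnp

print(jax.devices())  # expect one CUDA device per GPU

n = jax.local_device_count()
# psum across all local devices uses the same NCCL path as the failing run
out = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jnp.ones(n))
print(out)  # expect n copies of float(n), e.g. [8. 8. ... 8.] on 8 GPUs
`

If this psum fails with the same "Unable to load NCCL library" message, the problem is the JAX/NCCL install, not LWM.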

LWM-Chat in PyTorch

Hi, thanks for releasing the code, it looks pretty interesting! I noticed that the LWM-Chat (multimodal) model checkpoints are only released in Jax. It would be great if you could release them in PyTorch as well, as you did for the text-only models!

License

Hi,
Great work on LWM! I noticed the weights are licensed under the Apache license but are derived from Llama 2. Do both the Llama 2 license and the Apache license apply to the weights?
Thanks!

(Related to #10)

Request for guidance on fine-tuning a model with custom data

Hello, I am a big fan of your project and I am interested in using your model for my own data. However, I am new to fine-tuning models and I am not sure how to proceed. Could you please provide some guidance on the steps I need to take to fine-tune your model with my own data?

Specifically, I would like to know:

1. What format my data needs to be in
2. How to preprocess my data
3. How to configure the model for fine-tuning
4. How to train the model on my data
5. How to evaluate the performance of the fine-tuned model

I would greatly appreciate any help you can provide. Thank you!


TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

@wilson1yan It didn't work. More samples are needed, covering both the language and vision versions.

./scripts/run_vision_chat.sh
Traceback (most recent call last):
File "/home/jiapeiyang/anaconda3/envs/nlp/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/jiapeiyang/anaconda3/envs/nlp/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/jiapeiyang/workspace/LWM/lwm/vision_chat.py", line 18, in
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
File "/home/jiapeiyang/workspace/LWM/lwm/vision_llama.py", line 21, in
from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm
File "/home/jiapeiyang/workspace/LWM/lwm/llama.py", line 31, in
from lwm.ring_attention import blockwise_ffn, ring_flash_attention_tpu,
File "/home/jiapeiyang/workspace/LWM/lwm/ring_attention.py", line 557, in
class BlockSizes:
File "/home/jiapeiyang/workspace/LWM/lwm/ring_attention.py", line 563, in BlockSizes
block_q_major_dkv: int | None = None
TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

`
export llama_tokenizer_path="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/tokenizer.model"
export vqgan_checkpoint="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/vqgan"
export lwm_checkpoint="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/params"
export input_file="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/test_a.jpg"

python3 -u -m lwm.vision_chat \
    --prompt="What is the video about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --mesh_dim='!1,1,8,1' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    2>&1 | tee ~/output.log
read

`
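
This is a Python version issue rather than an LWM bug: the `int | None` annotation (PEP 604 union syntax) requires Python 3.10+, and the traceback shows a Python 3.9 environment (`envs/nlp/lib/python3.9`), while other reports in this thread run Python 3.10. The simplest fix is to use a Python 3.10 environment. Failing that, here is a sketch of the two usual source-level workarounds for `lwm/ring_attention.py`, assuming `BlockSizes` is a plain dataclass as the traceback suggests:

`
# Option 1: postponed annotation evaluation (add at the very top of the file)
from __future__ import annotations

# Option 2: replace the | union syntax with typing.Optional
import dataclasses
from typing import Optional

@dataclasses.dataclass  # keep whatever decorator the original class uses
class BlockSizes:
    block_q_major_dkv: Optional[int] = None  # was: int | None = None
`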

Error running run_sample_image.sh

When the prompt is longer than about 8,000 tokens, the script fails with `Argument list too long`.
(screenshot attached)
How can I solve this?
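
`Argument list too long` is the kernel's `ARG_MAX` limit, not an LWM error: the prompt is passed to the script as one huge command-line argument. One workaround is to read the prompt from a file instead; the sketch below is hypothetical, and the `prompt_file` flag is an assumption you would have to add to `lwm/vision_chat.py` (or your own fork) yourself:

`
# Hypothetical patch sketch: read a long prompt from disk so it never
# appears in argv. `prompt_file` is an assumed new flag, not in the repo.
def resolve_prompt(flags):
    if getattr(flags, "prompt_file", ""):  # assumed new flag
        with open(flags.prompt_file) as f:
            return f.read()
    return flags.prompt                    # existing --prompt behavior

# Shell side (no long argv entry is ever created):
#   python3 -u -m lwm.vision_chat --prompt_file=/path/to/long_prompt.txt ...
`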

Non-TPU inference

We need some examples that actually run inference on vision/image samples on GPU.

(lwm) ➜ LWM git:(main) ./scripts/run_sample_image.sh
WARNING: Logging before InitGoogle() is written to STDERR
I0000 00:00:1707909968.893833 11548 common_lib.cc:148] Failed to fetch URL on try 1 out of 6: Couldn't connect to server
I0000 00:00:1707909972.473827 11548 common_lib.cc:148] Failed to fetch URL on try 2 out of 6: Couldn't connect to server
I0000 00:00:1707909976.025912 11548 common_lib.cc:148] Failed to fetch URL on try 3 out of 6: Couldn't connect to server
^C^CI0000 00:00:1707909979.577878 11548 common_lib.cc:148] Failed to fetch URL on try 4 out of 6: Couldn't connect to server
^CI0000 00:00:1707909983.129666 11548 common_lib.cc:148] Failed to fetch URL on try 5 out of 6: Couldn't connect to server
I0000 00:00:1707909986.681892 11548 common_lib.cc:148] Failed to fetch URL on try 6 out of 6: Couldn't connect to server
Failed to get 'tpu-env' from instance metadata: INTERNAL: Couldn't connect to server
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:145
learning/45eac/tfrc/runtime/common_lib.cc:162
learning/45eac/tfrc/runtime/common_lib.cc:188

I0000 00:00:1707909990.233946 11548 common_lib.cc:148] Failed to fetch URL on try 1 out of 6: Couldn't connect to server
I0000 00:00:1707909993.785913 11548 common_lib.cc:148] Failed to fetch URL on try 2 out of 6: Couldn't connect to server
I0000 00:00:1707909997.337871 11548 common_lib.cc:148] Failed to fetch URL on try 3 out of 6: Couldn't connect to server
I0000 00:00:1707910000.890123 11548 common_lib.cc:148] Failed to fetch URL on try 4 out of 6: Couldn't connect to server
I0000 00:00:1707910004.442720 11548 common_lib.cc:148] Failed to fetch URL on try 5 out of 6: Couldn't connect to server
I0000 00:00:1707910007.994498 11548 common_lib.cc:148] Failed to fetch URL on try 6 out of 6: Couldn't connect to server
Failed to get 'tpu-env' from instance metadata: INTERNAL: Couldn't connect to server
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:145
learning/45eac/tfrc/runtime/common_lib.cc:162
learning/45eac/tfrc/runtime/common_lib.cc:188
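
Those messages are libtpu repeatedly trying to reach the GCP TPU metadata server on a machine that has no TPU. A minimal sketch, assuming a CUDA-enabled jaxlib is installed: pin JAX to the CUDA backend before it initializes so the TPU probe never runs.

`
# Pin JAX to the CUDA backend before it initializes, skipping the
# TPU metadata probe. Must run before the first `import jax`
# (equivalently: export JAX_PLATFORMS=cuda in the shell).
import os
os.environ["JAX_PLATFORMS"] = "cuda"

import jax
print(jax.devices())  # expect CUDA devices and no 'tpu-env' fetch attempts
`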

Any experiments on reasoning or world modeling?

The paper is about long context length for a language model that is extended into a vision-language model, so I wonder why it is called a "World Model"; the name is not obvious from the paper, which seems to focus on long context and on evaluating the related retrieval ability, with little discussion of world modeling itself.

I wonder whether there are any specific findings about model abilities that improve with long-context training. Does it make the model more robust to prompt variations? More robust at reasoning? More semantically rich in its concept representations? Better at ontological/hierarchical learning of meaning?

I'd be curious to hear more about the authors' findings.

Thanks a lot in advance for any insights :)
