octo's People

Contributors

andrearosasco, andrehe02, dibyaghosh, ericjang, homerw, kpertsch, kvablack, mees, moojink, nicklashansen, ojh6404, sudeepdasari, youliangtan

octo's Issues

Some questions regarding the paper's Appendix E

Hi, great work, and thanks for your sharing!

I have read your paper and drew a lot of inspiration from it. However, a few things are still unclear to me:

  1. The history frames. You mentioned that a one-frame history is beneficial for pre-training. How do you arrange the input data then: would it be like ----, i.e. in an interleaved style? In other words, do we need to split all the original trajectories into 2-step chunks? And how do you handle the first frame: do you repeat it?

  2. Shuffle buffer size. Does the buffer refer to "sampling frames from different trajectories across datasets"? Please correct me if I have misunderstood.

  3. Heads. It seems that the diffusion policy head is the most robust and efficient one.

  4. Step modeling. I am curious how the step information is modeled. Since one trajectory has one task instruction but many steps, do we need to distinguish between steps? Also, how do you decide when the robot should stop? Using some heuristic?
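To make question 1 concrete, here is a minimal sketch of one common way to build history windows. Each timestep gets a window of the previous `window_size` frames, indices before the start of the trajectory are clamped to the first frame, and a boolean pad mask marks those repeated frames as padding. The helper `make_history_windows` is hypothetical; this is an illustration of the technique, not necessarily how Octo implements it.

```python
import numpy as np

def make_history_windows(frames, window_size):
    """Return (windows, pad_mask) for every timestep of one trajectory.

    frames:   array of shape (T, ...) with one observation per step.
    windows:  shape (T, window_size, ...); windows[t] ends at frame t.
    pad_mask: shape (T, window_size); False where a frame is padding.
    """
    T = frames.shape[0]
    # Offsets -(window_size - 1) .. 0 relative to each timestep t.
    offsets = np.arange(-(window_size - 1), 1)
    indices = np.arange(T)[:, None] + offsets[None, :]  # (T, window_size)
    pad_mask = indices >= 0                              # True only for real frames
    clamped = np.maximum(indices, 0)                     # repeat frame 0 before the start
    return frames[clamped], pad_mask

# A 4-step trajectory of scalar "frames" 10, 11, 12, 13 with a 2-frame window.
frames = np.array([10, 11, 12, 13])
windows, pad_mask = make_history_windows(frames, window_size=2)
print(windows.tolist())   # [[10, 10], [10, 11], [11, 12], [12, 13]]
print(pad_mask.tolist())  # [[False, True], [True, True], [True, True], [True, True]]
```

With this layout the trajectory is not cut into disjoint 2-step chunks. Every step becomes its own sample that looks back one frame, and the first step's missing history is a masked copy of frame 0.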

Thanks again for your sharing!

Random behavior after fine-tuning Octo on a new robot

Hi!

I managed to convert my (very small for now) training dataset to RLDS format and then run 02_finetune_new_observation_action.py. I mainly modified the following to match my robot's action space and the image key name:

    logging.info("Loading finetuning dataset...")
    dataset = make_single_dataset(
        dataset_kwargs=dict(
            name="ReachyDataset",
            data_dir=FLAGS.data_dir,
            image_obs_keys={"primary": "head"},
            state_obs_keys=["state"],
            language_key="language_instruction",
            action_proprio_normalization_type=NormalizationType.NORMAL,
            absolute_action_mask=[True] * 19,
        ),
[...]

The training curves look like this:
[Screenshot from 2024-01-07 11-25-32: training curves]

My fine-tuning data was collected in the real world: 20 episodes of 10 seconds each, in which I teleoperated the robot to grab a small wooden cube. Below is an example of one recorded episode:
[Attachment: recorded episode]

Then I tried to run inference, inspired by 01_inference_pretrained.ipynb. My code looks like this:

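import time

import jax
import numpy as np
from octo.model.octo_model import OctoModel

# get_image(), get_state() and set_joints() are robot/simulator helpers defined elsewhere
model = OctoModel.load_pretrained("/data1/apirrone/octo/trainings/")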
task = model.create_tasks(texts=["Grab the wooden cube"])

while True:
    observation = {
        "image_primary": get_image(),
        "proprio": get_state(),
        "pad_mask": np.array([[True]]),
    }
    actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))[0]
    actions = (
        actions * model.dataset_statistics["action"]["std"]
        + model.dataset_statistics["action"]["mean"]
    )
    for step in actions:
        set_joints(np.array(step))
        time.sleep(0.02)

    time.sleep(0.5)

For now I am running it in our simulator, because I don't currently have access to a robot.
This is what I get (in this example I fixed the head in place):

[Video: octo_2-2024-01-07_11.53.39.mp4]

Is this expected with such a small fine-tuning dataset? Would having a lot more data solve it on its own?

I also have a few warnings that are a little concerning :)

INFO:root:No task inputs matching image_primary were found. Replacing with zero padding.
WARNING:root:'observations' is missing items compared to example_batch: {'pad_mask_dict/timestep', 'timestep', 'pad_mask_dict/image_primary', 'pad_mask_dict/proprio'}
WARNING:root:No pad_mask_dict found. Nothing will be masked.

Overall, I don't think I properly understand what I am doing yet. For example:

  • I don't know what role pad_mask plays in the observations dict in this context.
  • Also, when calling model.sample_actions([...]), does this return the full trajectory that is supposed to "solve" the task ? Or should I sample it multiple times with new observations ?
  • What is pad_mask_dict ?
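Regarding the second bullet, one common way to use a policy that predicts an action chunk is receding-horizon control. You query the model, execute only the first few actions of the returned chunk, and then query again with a fresh observation. The sketch below assumes a hypothetical `policy(obs)` that returns a normalized chunk of shape `(pred_horizon, action_dim)`. It reuses the mean/std unnormalization from the inference code above. `get_obs` and `apply_action` are hypothetical stand-ins for helpers such as `get_image()` and `set_joints()`.

```python
import numpy as np

def unnormalize(actions, mean, std):
    """Invert NORMAL (z-score) normalization: a = a_norm * std + mean."""
    return np.asarray(actions) * std + mean

def run_receding_horizon(policy, get_obs, apply_action, mean, std,
                         exec_horizon, num_queries):
    """Re-query the policy every `exec_horizon` steps instead of replaying a whole chunk."""
    executed = []
    for _ in range(num_queries):
        chunk = unnormalize(policy(get_obs()), mean, std)  # (pred_horizon, action_dim)
        for action in chunk[:exec_horizon]:                # execute only the first steps
            apply_action(action)
            executed.append(action)
    return np.stack(executed)

# Toy usage: a dummy policy that always predicts an all-zero (i.e. "mean") 4-step chunk.
mean, std = np.array([1.0, 2.0]), np.array([0.5, 0.5])
trace = run_receding_horizon(
    policy=lambda obs: np.zeros((4, 2)),
    get_obs=lambda: None,
    apply_action=lambda a: None,
    mean=mean, std=std, exec_horizon=2, num_queries=3,
)
print(trace.shape)  # (6, 2)
```

The choice of `exec_horizon` trades smoothness against reactivity. Executing fewer actions per query lets the model respond sooner to what it sees, at the cost of more inference calls.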

Any help would be greatly appreciated !

Thanks,

Antoine

There appears to be a memory leak during the training phase?

I am training the model from scratch on a V100. CPU memory usage grows during training, and after about 10,000 more steps the machine's 512 GB of CPU memory is used up.

How can I track down the cause? Thank you.

Using the following script:
python scripts/train.py --config scripts/configs/octo_pretrain_config.py:vit_s
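One low-tech way to narrow down a host-memory leak like this is to log the process's memory every N steps. You can then compare runs: for example, one iterating only the data pipeline and one running the full training step. The sketch below uses only the standard library (Unix only). `train_with_memory_log` and the `train_step` callable are hypothetical stand-ins for the training script's own loop. Note that `ru_maxrss` is the *peak* resident set size, reported in kilobytes on Linux and bytes on macOS, so a leak shows up as a value that keeps increasing.

```python
import resource
import sys

def peak_rss_mb():
    """Peak resident set size of this process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # Linux reports kilobytes, macOS reports bytes.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor

def train_with_memory_log(num_steps, train_step, log_every=1000):
    """Run a (hypothetical) train_step callable and record peak RSS periodically."""
    history = []
    for step in range(1, num_steps + 1):
        train_step(step)
        if step % log_every == 0:
            history.append((step, peak_rss_mb()))
            print(f"step {step}: peak RSS {history[-1][1]:.1f} MiB")
    return history

# Toy usage: a step that deliberately leaks by appending to a global list.
_leak = []
history = train_with_memory_log(
    3000, lambda step: _leak.append(bytearray(10_000)), log_every=1000
)
```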

Franka implementation

Is the Franka implementation, mentioned in the paper, available anywhere? I would like to run the code and finetune it for our Franka manipulator :)

Polymetis is already installed; what I really need is the integration as a gym Env, specifically the step function :)
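For reference, a gym-style environment mainly needs `reset()` and `step()`. The sketch below follows the Gymnasium convention, where `step` returns `obs, reward, terminated, truncated, info`, without importing gym. It talks to a hypothetical `FakeRobot` with `read_joints()` and `command_joints()` methods. It is not the Franka/Polymetis integration used in the paper, and a real version would also need camera capture and control-rate timing.

```python
import numpy as np

class FakeRobot:
    """Hypothetical stand-in for a real robot client (e.g. one built on Polymetis)."""
    def __init__(self, num_joints=7):
        self.q = np.zeros(num_joints)

    def read_joints(self):
        return self.q.copy()

    def command_joints(self, target):
        self.q = np.asarray(target, dtype=float)

class JointPositionEnv:
    """Minimal gym-style env: absolute joint-position actions, proprio-only observations."""
    def __init__(self, robot, max_steps=100):
        self.robot = robot
        self.max_steps = max_steps
        self.t = 0

    def _get_obs(self):
        # A real env would also return camera images here, e.g. under "image_primary".
        return {"proprio": self.robot.read_joints()}

    def reset(self, seed=None, options=None):
        self.t = 0
        return self._get_obs(), {}

    def step(self, action):
        self.robot.command_joints(action)
        self.t += 1
        obs = self._get_obs()
        reward = 0.0                       # no reward signal during policy evaluation
        terminated = False                 # task success would need an external check
        truncated = self.t >= self.max_steps
        return obs, reward, terminated, truncated, {}

env = JointPositionEnv(FakeRobot(num_joints=7), max_steps=2)
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(np.ones(7))
```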

Failed to eval finetuned model on `aloha-sim-cube` gym environment

Hi, thanks for your great work!

I fine-tuned the model using examples/02_finetune_new_observation_action.py, and I'm running examples/03_eval_finetuned.py to evaluate the fine-tuned model.

I followed the instructions:

Finally modify the sys.path.append statement below to add the ACT repo to your path and start a virtual display:
Xvfb :1 -screen 0 1024x768x16 &
export DISPLAY=:1

and added sys.path.append("/path/to/act"), but gym.make("aloha-sim-cube-v0") still fails.

Another problem is that I cannot load the fine-tuned model. Here is the traceback:

Traceback (most recent call last):
  File "/code/octo/examples/03_eval_finetuned.py", line 101, in <module>
    app.run(main)
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/code/octo/examples/03_eval_finetuned.py", line 35, in main
    model = OctoModel.load_pretrained(FLAGS.finetuned_path)
  File "/code/octo/octo/model/octo_model.py", line 274, in load_pretrained
    params = checkpointer.restore(step, params_shape)
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 550, in restore
    restored_items = self._restore_impl(
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 582, in _restore_impl
    restored[item_name] = self._checkpointers[item_name].restore(
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 165, in restore
    restored = self._restore_with_args(directory, *args, **kwargs)
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 103, in _restore_with_args
    restored = self._handler.restore(directory, args=ckpt_args)
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1063, in restore
    restored_item = _transform_checkpoint(
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 601, in _transform_checkpoint
    item = utils.deserialize_tree(restored, item)
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 281, in deserialize_tree
    return jax.tree_util.tree_map_with_path(
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/jax/_src/tree_util.py", line 857, in tree_map_with_path
    return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/jax/_src/tree_util.py", line 857, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
  File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 278, in _reconstruct_from_keypath
    result = result[key_name]
KeyError: 'diffusion_model'

It looks like the diffusion model was not saved during training. Did I miss something in the configuration?
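In the traceback, the `KeyError` comes from `_reconstruct_from_keypath`. This most likely means that the parameter structure expected at load time contains a key the checkpoint does not have, here `diffusion_model`. One way to see the mismatch is to flatten both nested parameter dicts into key paths and diff them. The helper below is a hypothetical sketch on plain nested dicts, and the two trees in the usage example are made up for illustration.

```python
def flatten_keys(tree, prefix=()):
    """Return the set of key paths (tuples) leading to the leaves of a nested dict."""
    paths = set()
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            paths |= flatten_keys(value, path)
        else:
            paths.add(path)
    return paths

def diff_param_trees(expected, restored):
    """Report key paths present in only one of the two trees."""
    exp, res = flatten_keys(expected), flatten_keys(restored)
    return {
        "missing_in_checkpoint": sorted(exp - res),
        "unexpected_in_checkpoint": sorted(res - exp),
    }

# Illustrative trees: the model expects a key that the checkpoint does not contain.
expected = {"head": {"diffusion_model": {"kernel": 0}}, "encoder": {"w": 0}}
restored = {"head": {"dense": {"kernel": 0}}, "encoder": {"w": 0}}
report = diff_param_trees(expected, restored)
print(report)
```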

Thanks.

`UnnormalizeActionProprio`'s observation function is not called in env.step

Hi, following 03_eval_finetuned.py, I create the environment like this:

    env = gym.make("aloha-sim-cube-v0")

    # add wrappers for history and "receding horizon control", i.e. action chunking
    env = HistoryWrapper(env, horizon=1)
    env = RHCWrapper(env, exec_horizon=50)

    # wrap env to handle action/proprio normalization -- match normalization type to the one used during finetuning
    env = UnnormalizeActionProprio(
        env, model.dataset_statistics, normalization_type="normal"
    )

The outermost wrapper is UnnormalizeActionProprio, which is both an ActionWrapper and an ObservationWrapper.

I noticed that unnormalize is called in env.step(actions). This is expected, since the action needs to be unnormalized.

However, normalize and the observation function are not called in env.step; they are called only in env.reset. Could this lead to any critical problems?

Bridge dataset action normalization

Hello,

Thank you for this amazing project. Great job!

I have a small question about fine-tuning. The code says to use a custom version of the Bridge dataset (https://rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/), in which some features have been changed (e.g. names, action dimensions). Was any normalization applied to the actions in your custom version of the Bridge dataset?

One more question: in general, to fine-tune your model, should the actions in a new dataset already be normalized when the dataset is created?
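For context on what a NORMAL normalization type computes, here is a minimal sketch of z-score normalization with a per-dimension mask. The mask lets dimensions such as a binary gripper command stay in raw units. The statistics and the mask are illustrative. Which dimensions Octo actually normalizes, and whether it does so on the fly or at dataset-creation time, should be checked in the repository rather than inferred from this sketch.

```python
import numpy as np

def normalize(x, mean, std, mask, eps=1e-8):
    """Z-score the dimensions where mask is True; leave the others unchanged."""
    return np.where(mask, (x - mean) / (std + eps), x)

def unnormalize(x, mean, std, mask):
    """Approximate inverse of `normalize` for the masked dimensions."""
    return np.where(mask, x * std + mean, x)

mean = np.array([0.1, -0.2, 0.5])
std = np.array([0.2, 0.4, 0.5])
mask = np.array([True, True, False])  # e.g. keep a gripper dimension in raw units

actions = np.array([[0.3, -0.6, 1.0],
                    [0.1, 0.2, 0.0]])
normed = normalize(actions, mean, std, mask)
print(normed)  # approximately [[ 1. -1.  1.] [ 0.  1.  0.]]
```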

Issue with diffusion head

I was able to fine-tune with a modified version of example 2 with the following action head:

config["model"]["heads"]["action"] = ModuleSpec.create(
    L1ActionHead,
    pred_horizon=9,
    action_dim=11,
    readout_key="readout_action",
)

The policy works reasonably well on the robot.
Afterwards, I tried fine-tuning with a diffusion head, but with it the robot goes out of control.

config["model"]["heads"]["action"] = ModuleSpec.create(
    # L1ActionHead,
    DiffusionActionHead,
    use_map=False,

    pred_horizon=9,
    action_dim=11,
    readout_key="readout_action",
)

The rest of the script is unchanged.
What else could be the problem?

Update: the diffusion-based model seems to be outputting fairly extreme action values, i.e. values below the minimum or above the maximum in the dataset.

These are the action statistics:

'max': array([6.92096353e-03, 7.15068638e-01, 2.65712190e+00, 1.22000003e+00, 4.09653854e+00, 9.43594933e-01, 1.17203128e+00, 2.65219069e+00, 1.00000000e+00, 1.50034428e-01, 4.94167267e-04]),
'mean': array([-1.45641267e+00,  2.27537051e-01, -2.96192672e-02, -6.16574585e-01, -7.61023015e-02,  2.06921268e-02,  4.98067914e-03, -1.26738450e-04, 5.67098975e-01,  7.91745202e-04, -7.36813426e-01]),
'min': array([-2.61999989, -0.05653883, -1.91999996, -1.75      , -2.45071816, -0.88507879, -1.3124876 , -1.5       ,  0.        , -0.09809657, -1.35774779]),
'std': array([0.50778913, 0.21539085, 0.41647774, 0.72304732, 0.8022927, 0.11436515, 0.08398118, 0.30631647, 0.49528143, 0.01135448, 0.26658419])}

and these are sample outputs after unnormalization:

act: [-3.9953585   1.3044913   2.0527694  -4.231811    3.9353614   0.59251785  -0.41492522 -1.5317091   3.0435061  -0.05598066 -2.0697343 ]
act: [ 1.082533    1.3044913   2.0527694   2.998662   -4.087566    0.59251785  -0.41492522  1.5314556   3.0435061  -0.05598066 -2.0697343 ]

which are clearly out of bounds.
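A quick way to quantify how far out of range these samples are is to map them back into normalized units with the statistics above, z = (a - mean) / std. For both sample outputs, every dimension comes out at about +5 or -5. In other words, the predictions all sit at the same magnitude in normalized space rather than being merely noisy. That pattern may suggest the diffusion samples are saturating or being clipped before unnormalization. This is an inference from the numbers, not something the issue confirms.

```python
import numpy as np

# Statistics and sample outputs copied from the issue above.
mean = np.array([-1.45641267e+00, 2.27537051e-01, -2.96192672e-02, -6.16574585e-01,
                 -7.61023015e-02, 2.06921268e-02, 4.98067914e-03, -1.26738450e-04,
                 5.67098975e-01, 7.91745202e-04, -7.36813426e-01])
std = np.array([0.50778913, 0.21539085, 0.41647774, 0.72304732, 0.8022927, 0.11436515,
                0.08398118, 0.30631647, 0.49528143, 0.01135448, 0.26658419])
acts = np.array([
    [-3.9953585, 1.3044913, 2.0527694, -4.231811, 3.9353614, 0.59251785,
     -0.41492522, -1.5317091, 3.0435061, -0.05598066, -2.0697343],
    [1.082533, 1.3044913, 2.0527694, 2.998662, -4.087566, 0.59251785,
     -0.41492522, 1.5314556, 3.0435061, -0.05598066, -2.0697343],
])

z = (acts - mean) / std  # back to normalized units
print(np.round(z, 3))    # every entry is close to +5 or -5
```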

Inconsistencies in model fine-tuning

Hello, thanks for the amazing work!
I want to run the script examples/02_finetune_new_observation_action.py to obtain a fine-tuned version of Octo to test in the ALOHA gym environment. I started training on an A100, but I'm noticing some differences from what is written in the report:

  • The script trains the model for 5k steps instead of 50k steps
  • The fine-tuning process is currently taking 77GB of VRAM while it should be able to run on 24GB
  • The current estimate to get to 5k steps is 80h instead of 5h

Also, I'm fine-tuning octo-small, whereas the 24GB / 5h figures refer to octo-base, so I was expecting it to take even less.

I was wondering whether fine-tuning on the ALOHA benchmark requires different hyperparameters, which would explain the inconsistencies, or whether there is an error in the fine-tuning script.

The only modifications I've done to the code are:

flags.DEFINE_string(
    "pretrained_path", 'hf://rail-berkeley/octo-small', "Path to pre-trained Octo checkpoint directory."
)
flags.DEFINE_string("data_dir", '.', "Path to finetuning dataset, in RLDS format.")
flags.DEFINE_string("save_dir", './aloha_octo', "Directory for saving finetuning checkpoints.")

since all of these values were previously set to None.

Do you have any idea what the issue might be?

Thanks in advance 🐙

Multi-node Fine-tuning with JAX on A100 Clusters

Hello,

I am attempting to fine-tune a model across multiple nodes, each equipped with 8 A100 GPUs, and I'm running into difficulties. Octo is implemented in JAX, and I initially thought that jax.pmap could be used to parallelize the work across multiple GPUs and nodes. However, the batch data is a dictionary, which seems to be incompatible with jax.pmap, so the process does not work as expected.

Here are the details of my environment:

  • OS: Ubuntu 20.04
  • CUDA Version: 11.3
  • Python Version: 3.10

Given the situation, I am looking for advice or solutions to achieve multi-node fine-tuning with my setup. Is there a recommended approach to handle batches in dictionary form with jax.pmap, or perhaps an alternative method to perform fine-tuning across multiple nodes in JAX?
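On the dictionary question: `jax.pmap` maps over pytrees, and Python dicts (including nested ones) are pytrees. A dict batch can therefore in principle be passed directly, as long as every leaf has a leading axis equal to the number of local devices. The sketch below shows only that reshaping step, in NumPy so that it runs without accelerators. `shard_batch` is a hypothetical helper and the batch contents are made up. For multiple nodes, JAX provides `jax.distributed.initialize()`, but how Octo's training script would need to change for that is not covered here.

```python
import numpy as np

def shard_batch(batch, num_devices):
    """Reshape every array leaf of a nested dict from (B, ...) to (num_devices, B // num_devices, ...)."""
    if isinstance(batch, dict):
        return {k: shard_batch(v, num_devices) for k, v in batch.items()}
    batch = np.asarray(batch)
    if batch.shape[0] % num_devices != 0:
        raise ValueError(f"batch size {batch.shape[0]} not divisible by {num_devices} devices")
    return batch.reshape(num_devices, batch.shape[0] // num_devices, *batch.shape[1:])

# Illustrative dict batch with a nested observation dict (shapes are made up).
batch = {
    "observation": {
        "image_primary": np.zeros((16, 1, 64, 64, 3)),
        "pad_mask": np.ones((16, 1), dtype=bool),
    },
    "action": np.zeros((16, 4, 7)),
}
sharded = shard_batch(batch, num_devices=8)
print(sharded["observation"]["image_primary"].shape)  # (8, 2, 1, 64, 64, 3)
```

With JAX installed, `jax.tree_util.tree_map(lambda x: x.reshape(n, -1, *x.shape[1:]), batch)` performs the same reshape over any pytree. The resulting dict can then be passed to a pmapped step function.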

Any insights or suggestions would be greatly appreciated.

Thank you.

Aloha Finetuning Configuration

Do you have any insights about the differences in configuration between the pre-trained model and the aloha fine-tuned one?

[Screenshot 2024-01-22 143318]

In particular, I was wondering:

  • why the aloha fine-tuning does not use a diffusion head
  • why the pretrained model only has action chunks of dimension 4
  • why the pretrained model misses a tokenizer for the proprioception

Issues running eval on new robot platform

I've fine-tuned the Octo model on my custom dataset for our Spot robot, and now I'm trying to adapt this example. I have some questions, mainly because it doesn't seem like this example could actually run as-is.

  1. It seems the arguments "argmax" and "temperature" are not actually used; do they need to be removed?
  2. I get the following error, which I haven't debugged much yet. If I remove the jax.jit, I can get it working, but I'm not sure that's the best approach. What is the purpose of the jit?
Traceback (most recent call last):
  File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/peter/Documents/octo/scripts/eval_conq.py", line 234, in main
    action = np.array(policy_fn(obs, task), dtype=np.float64)
  File "/home/peter/Documents/octo/scripts/eval_conq.py", line 189, in sample_actions
    actions = pretrained_model.sample_actions(
  File "/home/peter/Documents/octo/octo/model/octo_model.py", line 185, in sample_actions
    pad_mask = observations["pad_mask"]
  File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 728, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 341, in _getitem
    return lax_numpy._rewriting_take(self, item)
  File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4362, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
  File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4436, in _split_index_for_jit
    raise TypeError(f"JAX does not support string indexing; got {idx=}")
TypeError: JAX does not support string indexing; got idx=('pad_mask',)
  3. In ResizeImageWrapper, self.resize_size should be set somewhere, but it isn't. Is this just a typo?

I'm happy to make a PR for the things listed here that seem small & easy.

Running finetune script with text conditioning only

Hi,

In finetune_config.py I selected only text conditioning, but the following error occurs in the fine-tuning code:

  File "/home/users/m/octo/octo/data/dataset.py", line 98, in apply_trajectory_transforms
    getattr(goal_relabeling, goal_relabeling_strategy),
AttributeError: module 'octo.data.utils.goal_relabeling' has no attribute 'no_image_conditioning'

The goal_relabeling module indeed has no function "no_image_conditioning", so I was wondering whether you plan to release it?

Thanks in advance.

Third-person wrist image in the paper

On the last page of the paper (p. 19), it is mentioned that a single third-person wrist image was used in the coffee demo. What do you mean by this? Doesn't this contradict what the paper says about wrist images giving worse performance due to a lack of data?

Also, is the pod picking and placing done as one motion/task or two?

Thank you!

Discuss the zero-shot capability of the model in the paper

The model's zero-shot capability is mentioned several times in the paper and compared with RT-1-X and RT-2-X.

  1. How is zero-shot capability defined in the evaluation experiments?
    (Are the manipulated objects different, while the robot arm used in the evaluation has already appeared in the training data?)

  2. In the experimental comparison, do the RT-1-X results come from the open-source model, while the RT-2-X numbers are taken from the cited paper?

Thank you!

Rising validation loss

Hi,

Using scripts/finetuning.py, I have been training models on different manually collected datasets for pick/pick_and_place tasks, but the validation MSE loss keeps rising. To simplify the experiments (mentioned in #29), I collected a rather simple dataset of picking one object from the same location with no object rotation, so most episodes look more or less the same. Even so, the validation MSE loss still rises for no apparent reason. The data is split into 50 training and 7 validation trajectories.

[Image: loss curves]

So, I was wondering if you have any idea what could be the problem here?

Maybe adding the diffusion-policy MSE (used in the training loss) to validation would give a better picture?

AttributeError: module 'scipy.linalg' has no attribute 'tril'

Thank you for your great work! I ran into a dependency conflict when installing the project environment: scipy must be older than the latest release, 1.13.0 (2024-04-02). Otherwise, the error "AttributeError: module 'scipy.linalg' has no attribute 'tril'" occurs. Downgrading scipy to 1.12.0 solves it. I hope you can fix the scipy version constraint in requirements.txt so that it excludes 1.13.0 (e.g. scipy<1.13.0). Thank you very much!

Custom data augmentation

I want to try a new data augmentation strategy that is not covered by https://github.com/kvablack/dlimp/blob/main/dlimp/augmentations.py

I'm looking at this part of the example:

dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="aloha_sim_cube_scripted_dataset",
        data_dir=FLAGS.data_dir,
        image_obs_keys={"primary": "top"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        absolute_action_mask=[True] * 14,
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        future_action_window_size=49,  # so we get 50 actions for our action chunk
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256)},
    ),
    train=True,
)

Is there a suggested way of implementing this? For example, if I have a custom function fn(image) that returns the augmented image, is there a way to integrate this into training easily?

Edit:
Specifically, I want to implement something like mixup, so I'll have to sample from an image dataset on every call. I'm not very familiar with TensorFlow, so maybe I'll try switching to the PyTorch dataloader...
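As a starting point, the image half of mixup can be written as a plain function. Draw λ from a Beta(α, α) distribution and blend the current image with a randomly chosen image from a pool. This is a NumPy sketch only: `mixup_image`, `image_pool` and `alpha` are hypothetical names, and it does not mix actions or labels. Wiring it into Octo's TensorFlow pipeline (for example through `tf.numpy_function` inside a `map`) is an untested assumption.

```python
import numpy as np

def mixup_image(image, image_pool, alpha=0.2, rng=None):
    """Blend `image` with a random image from `image_pool` using lam ~ Beta(alpha, alpha).

    image: (H, W, C) uint8 array; image_pool: (N, H, W, C) uint8 array.
    Returns the mixed image as uint8 and the mixing coefficient lam.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    other = image_pool[rng.integers(len(image_pool))]
    mixed = lam * image.astype(np.float32) + (1.0 - lam) * other.astype(np.float32)
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8), lam

rng = np.random.default_rng(0)
image = np.full((256, 256, 3), 200, dtype=np.uint8)
pool = np.zeros((10, 256, 256, 3), dtype=np.uint8)
mixed, lam = mixup_image(image, pool, alpha=0.2, rng=rng)
```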

Fine-tuning on A5000

Hi,

Your paper mentions fine-tuning on an NVIDIA A5000 with 24 GB of VRAM. What batch_size did you use in that case? Thanks!

Evaluating Octo on WidowX - Example Script Error

I'm running the 04_eval_finetuned_on_robot.py script provided in examples, straight out of the box. I tried using the language conditioning option, and providing a new instruction of "move the green object to the bottom right burner" when prompted.

I'm running into this issue in the forward pass through the model. Is this a common problem, or are there any hints on fixing it?

  File "/home/arhan/projects/widowx_octo_inference/04_eval_finetuned_on_robot.py", line 238, in <module>
    app.run(main)
  File "/home/arhan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/arhan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/arhan/projects/widowx_octo_inference/04_eval_finetuned_on_robot.py", line 212, in main
    action = np.array(policy_fn(obs, task), dtype=np.float64)
  File "/home/arhan/projects/widowx_octo_inference/04_eval_finetuned_on_robot.py", line 127, in sample_actions
    actions = pretrained_model.sample_actions(
  File "/home/arhan/projects/octo/octo/model/octo_model.py", line 187, in sample_actions
    transformer_outputs = self.run_transformer(
  File "/home/arhan/projects/octo/octo/model/octo_model.py", line 152, in run_transformer
    return self.module.apply(
  File "/home/arhan/projects/octo/octo/model/octo_module.py", line 274, in __call__
    outputs["obs"] = TokenGroup.concatenate(
  File "/home/arhan/projects/octo/octo/model/components/base.py", line 31, in concatenate
    data = jnp.concatenate([t.tokens for t in group_list], axis=axis)
  File "/home/arhan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1839, in concatenate
    raise ValueError("Need at least one array to concatenate.")
ValueError: Need at least one array to concatenate.

Warnings related to CUDA, cuDNN, TensorRT, etc.

When I ran the fine-tuning script, I noticed warnings related to CUDA, cuDNN, TensorRT, etc.
I followed the environment setup in the README. I suspect these might be caused by an incompatibility between JAX and the environment.

(octo) wenbo@wenbo-4090:~/Documents/data/octo/scripts$ python finetune.py
/media/wenbo/12T/octo/scripts/finetune.py:3: DeprecationWarning: the imp module is deprecated in favour of importlib and slated for removal in Python 3.12; see the module's documentation for alternative uses
import imp
2024-03-27 16:43:55.322631: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-27 16:43:55.322684: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-27 16:43:55.453763: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-27 16:43:57.157329: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
W0327 16:44:02.909117 128929550399296 compilation_cache.py:59] Initialized persistent compilation cache at /home/wenbo/.jax_compilation_cache
I0327 16:44:03.371840 128929550399296 xla_bridge.py:633] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0327 16:44:03.382258 128929550399296 xla_bridge.py:633] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
I0327 16:44:03.382936 128929550399296 finetune_rlbench.py:66]

Question: What is the intended use case for task stack keys?

I am trying to pretrain on a dataset. My intended setup is to have three images tokenized as inputs to the transformer and an action head with 2 outputs. When I run the script, I do see that this is the case:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                  ┃ t=0 obs_instruction (16 tokens)  ┃ t=0 obs_primary (16 tokens)  ┃ t=0 obs_secondary (16 tokens)  ┃ t=0 readout_action (1 tokens)  ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ t=0 obs_instruction (16 tokens)  │ x                                │ x                            │ x                              │ x                              │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 obs_primary (16 tokens)      │ x                                │ x                            │ x                              │ x                              │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 obs_secondary (16 tokens)    │ x                                │ x                            │ x                              │ x                              │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 readout_action (1 tokens)    │                                  │                              │                                │ x                              │
└──────────────────────────────────┴──────────────────────────────────┴──────────────────────────────┴────────────────────────────────┴────────────────────────────────┘

But I get these INFO messages:

I1229 11:29:38.141532 140098393864000 tokenizers.py:123] No task inputs matching image_instruction were found. Replacing with zero padding.
I1229 11:29:38.191684 140098393864000 tokenizers.py:123] No task inputs matching image_primary were found. Replacing with zero padding.
I1229 11:29:38.241426 140098393864000 tokenizers.py:123] No task inputs matching image_secondary were found. Replacing with zero padding.

My input config for observation is the following

    config["model"]["observation_tokenizers"] = {
        "primary": ModuleSpec.create(
            ImageTokenizer,
            obs_stack_keys=["image_primary"],
            task_stack_keys=["image_primary"],
            encoder=ModuleSpec.create(SmallStem16),
        ),
        "secondary": ModuleSpec.create(
            ImageTokenizer,
            obs_stack_keys=["image_secondary"],
            task_stack_keys=["image_secondary"],
            encoder=ModuleSpec.create(SmallStem16),
        ),
        "instruction": ModuleSpec.create(
            ImageTokenizer,
            obs_stack_keys=["image_instruction"],
            task_stack_keys=["image_instruction"],
            encoder=ModuleSpec.create(SmallStem16),
        ),

    }

Going through the image tokenizer code, it seems that obs_stack_keys is for the case where I want to stack inputs. Then there are task inputs, and I am not sure what they are meant for. Am I doing this the right way?
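The INFO messages hint at how task inputs are used. For each key in `task_stack_keys`, the tokenizer appears to look for a matching entry in the *task* dict (e.g. a goal image) and substitutes zeros when none is found. In the sketch below, the matching task image is stacked with the observation image along the channel axis ("early fusion"). That channel-stacking detail is an assumption about the implementation, and `stack_obs_and_task` is a hypothetical helper, not Octo's `ImageTokenizer`.

```python
import numpy as np

def stack_obs_and_task(observations, tasks, obs_keys, task_keys):
    """Concatenate observation images and (optional) task images along the channel axis.

    observations[k]: (batch, window, H, W, C). tasks[k]: (batch, H, W, C), e.g. a goal image.
    Missing task images are replaced by zeros, mirroring the "zero padding" log message.
    Assumes each task key also names an observation image of the same shape.
    """
    obs = np.concatenate([observations[k] for k in obs_keys], axis=-1)
    window = obs.shape[1]
    task_parts = []
    for k in task_keys:
        if k in tasks:
            # Broadcast the single goal image across every timestep of the window.
            task_parts.append(np.repeat(tasks[k][:, None], window, axis=1))
        else:
            task_parts.append(np.zeros_like(observations[k]))
    return np.concatenate([obs] + task_parts, axis=-1) if task_parts else obs

obs = {"image_primary": np.ones((2, 1, 64, 64, 3))}
with_goal = stack_obs_and_task(obs, {"image_primary": np.ones((2, 64, 64, 3))},
                               ["image_primary"], ["image_primary"])
no_goal = stack_obs_and_task(obs, {}, ["image_primary"], ["image_primary"])
print(with_goal.shape, no_goal.shape)  # (2, 1, 64, 64, 6) (2, 1, 64, 64, 6)
```

Under this reading, listing observation keys in `task_stack_keys` only adds information when the task dict actually carries goal images under those keys. Otherwise the extra channels are always zeros.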

Finetuning the model

To fine-tune the model, I will have to create a dataset of episodes.
Do you have any resources on how you recorded/created an RLDS dataset? The format seems somewhat niche, with little documentation.

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to allocate 98697216 bytes for new constant

When running the examples/03_eval_finetuned.py script on a GPU (A100 80G, cuda12.2), I encountered the following error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to allocate 98697216 bytes for new constant

Here are details:

Traceback (most recent call last):
  File "/home/houxuan/code/robot/octo/examples/03_eval_finetuned1.py", line 217, in <module>
    app.run(main)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/houxuan/code/robot/octo/examples/03_eval_finetuned1.py", line 193, in main
    actions = policy_fn(
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 256, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 167, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/core.py", line 2656, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/core.py", line 388, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/core.py", line 868, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 1212, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 1196, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 1152, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/profiler.py", line 340, in wrapper
    return func(*args, **kwargs)
  File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1152, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to allocate 98697216 bytes for new constant

Making the following modification in the code resolves the problem:

    # jit model action prediction function for faster inference
    # policy_fn = jax.jit(model.sample_actions)
    policy_fn = model.sample_actions

Any suggestions for how to make the jax.jit transformation work?
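One possible cause is that jax.jit(model.sample_actions) wraps a bound method. Treat this as a hypothesis, not a confirmed diagnosis. Every array stored on model, including the parameters, is closed over instead of passed in, and JAX treats closed-over arrays as constants of the traced program. That would match the "for new constant" wording of the error. The toy sketch below uses a plain dict of arrays, not Octo's API, to contrast closing over parameters with passing them as arguments. Both produce the same result, but only the second keeps the parameters as runtime inputs.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for model parameters (a plain dict of arrays, not Octo's API).
params = {"w": jnp.ones((256, 256)), "b": jnp.zeros((256,))}


def apply(params, x):
    """A single dense layer standing in for model.sample_actions."""
    return x @ params["w"] + params["b"]


# Closing over `params`: the captured arrays are not inputs to the compiled
# function, so JAX treats them as constants of the traced program.
closed_fn = jax.jit(lambda x: apply(params, x))

# Passing `params` as an argument: the arrays stay runtime inputs.
arg_fn = jax.jit(apply)

x = jnp.ones((1, 256))
y_closed = closed_fn(x)
y_arg = arg_fn(params, x)
```

For Octo, the same idea would mean passing the model object through jit as an argument, e.g. jax.jit(lambda m, obs, task, rng: m.sample_actions(obs, task, rng=rng)). That only works if the model object is a JAX pytree, which is an assumption you should check against your Octo version.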

Official Learning Curves for Published Models

Are learning curves available for the HF models, and would it be possible to share them to help others reproduce the results of the paper? I completely understand if they are not stored or available; I'm asking just in case they are.

Google Colab not working

The Google Colab notebook at https://colab.research.google.com/drive/1z0vELj_lX9OWeoMG_WvXnQs43aPOEAhz?usp=sharing cannot run. This line in step 1 seems to be the issue: action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0)). I get this error:


AssertionError                            Traceback (most recent call last)

<ipython-input-23-97a79f13d085> in <cell line: 8>()
      6 task = model.create_tasks(texts=["pick up the fork"])
      7 
----> 8 action = model.sample_actions(observation, task, rng=3)#jax.random.PRNGKey(0))
      9 print(action)   # [batch, action_chunk, action_dim]

    [... skipping hidden 12 frame]

2 frames

    [... skipping hidden 12 frame]

/content/octo/octo/model/octo_model.py in _verify_shapes(pytree, name, example_pytree, starting_dim, strict, raise_error, silent)
    488 
    489     if raise_error and (fail or (weak_fail and strict)):
--> 490         raise AssertionError(f"{name} does not match example batch.")
    491 
    492     return weak_fail or fail

AssertionError: observations does not match example batch.

I am running on a T4 GPU. I have not tried an A100, and I am not sure whether the GPU could be the root of the issue.
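The assertion is raised when the observation dict does not match the model's example batch. One possible cause is that the image is missing its leading batch and window axes. The helper below is hypothetical and not part of Octo. It mimics that kind of check with NumPy and assumes observations are laid out as [batch, window, ...] with the image under image_primary. Separately, the traceback shows rng=3. JAX random functions expect a key such as jax.random.PRNGKey(0), so that argument may cause its own error once the shapes match.

```python
import numpy as np


def mismatched_keys(observation, example, batch_dims=2):
    """Hypothetical helper: keys whose layout differs from `example`.

    Assumes both dicts map keys to arrays shaped [batch, window, ...]; this is
    an assumption about the example-batch layout, not Octo's actual check.
    """
    bad = []
    for key, ref in example.items():
        if key not in observation:
            bad.append(key)
            continue
        arr = np.asarray(observation[key])
        # Batch/window sizes may differ; rank and trailing dims must agree.
        if arr.ndim != ref.ndim or arr.shape[batch_dims:] != ref.shape[batch_dims:]:
            bad.append(key)
    return bad


# Reference layout: 1 batch element, 1 history step, 256x256 RGB image.
example = {"image_primary": np.zeros((1, 1, 256, 256, 3), dtype=np.uint8)}

img = np.zeros((256, 256, 3), dtype=np.uint8)
missing_axes = {"image_primary": img}  # batch and window axes forgotten
fixed = {"image_primary": img[np.newaxis, np.newaxis, ...]}

print(mismatched_keys(missing_axes, example))  # → ['image_primary']
print(mismatched_keys(fixed, example))  # → []
```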

Installation and compilation issues

  • Env
    • RTX 4090
    • Ubuntu20.04
    • Python -> Used virtual env (python3.10)
    • CUDA & cudnn in local: cuda-11.2, cudnn 8.6.0

I installed octo in my virtualenv using pip, as instructed in the README.
After that, when I run the test code, it never gets past the "Very slow compile?" message.
I also tried reinstalling jax, but I still don't know the cause.
How can I solve this?

octo/scripts/finetune.py:3: DeprecationWarning: the imp module is deprecated in favour of importlib and slated for removal in Python 3.12; see the module's documentation for alternative uses
  import imp
2024-03-18 10:58:02.165197: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-18 10:58:02.165237: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-18 10:58:02.166325: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-18 10:58:02.959024: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-03-18 10:58:05.042673: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
W0318 10:58:05.043408 140289948624704 compilation_cache.py:59] Initialized persistent compilation cache at /home/haha/.jax_compilation_cache
I0318 10:58:05.063319 140289948624704 xla_bridge.py:633] Unable to initialize backend 'cuda': Unable to load CUDA.
I0318 10:58:05.063727 140289948624704 xla_bridge.py:633] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0318 10:58:05.065797 140289948624704 xla_bridge.py:633] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0318 10:58:05.065938 140289948624704 xla_bridge.py:697] CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0318 10:58:05.066064 140289948624704 finetune.py:63] 
        Octo Finetuning Script
        ======================
        Pretrained model: hf://rail-berkeley/octo-small
        Finetuning Dataset: bridge_dataset
        Data dir: ./tests/debug_dataset
        Task Modality: multimodal
        Finetuning Mode: full

        # Devices: 1
        Batch size: 256 (256 per device)
        # Steps: 50000
    
Fetching 8 files: 100%|█████████████████████████| 8/8 [00:00<00:00, 4329.05it/s]
I0318 10:58:13.746319 140289948624704 checkpointer.py:164] Restoring item from /home/haha/.cache/huggingface/hub/models--rail-berkeley--octo-small/snapshots/03d88976c54a58e10480d2043a8c762b35bc2611/270000/default.
I0318 10:58:14.635694 140289948624704 checkpointer.py:166] Finished restoring checkpoint from /home/haha/.cache/huggingface/hub/models--rail-berkeley--octo-small/snapshots/03d88976c54a58e10480d2043a8c762b35bc2611/270000/default.
I0318 10:58:16.732701 140289948624704 dataset_info.py:578] Load dataset info from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:58:16.845350 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split all, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:58:17.743555 140289948624704 data_utils.py:110] Loading existing dataset statistics from tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_38c366095a9a20c72a862f6a1a4a5ae4fe98ed4507d67cfa7314de76b872c6c4.json.
I0318 10:58:17.802858 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split train, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:08.408026 140289948624704 train_utils.py:248] Freezing parameters that include the following keys: ['*hf_model*'].
I0318 10:59:08.411441 140289948624704 train_utils.py:284] Num trainable params: 27,042,060.
I0318 10:59:08.411522 140289948624704 train_utils.py:285] Num frozen params: 109,628,544.
I0318 10:59:08.411549 140289948624704 train_utils.py:286] To see a detailed list of frozen params, set logging level to DEBUG.
W0318 10:59:09.221674 140289948624704 finetune.py:261] save_dir not passed in, not saving checkpoints
I0318 10:59:09.225611 140289948624704 dataset_info.py:578] Load dataset info from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:09.294973 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split all, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:09.553126 140289948624704 data_utils.py:110] Loading existing dataset statistics from tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_38c366095a9a20c72a862f6a1a4a5ae4fe98ed4507d67cfa7314de76b872c6c4.json.
I0318 10:59:09.608441 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split val, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:10.278485 140289948624704 dataset_info.py:578] Load dataset info from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:10.341788 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split all, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:10.596746 140289948624704 data_utils.py:110] Loading existing dataset statistics from tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_38c366095a9a20c72a862f6a1a4a5ae4fe98ed4507d67cfa7314de76b872c6c4.json.
I0318 10:59:10.650095 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split val, from tests/debug_dataset/bridge_dataset/1.0.0
  0%|                                                 | 0/50000 [00:00<?, ?it/s]2024-03-18 11:01:32.364956: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-03-18 11:03:52.031034: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4m19.665375883s

********************************
[Compiling module jit_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
  0%|                                   | 9/50000 [30:23<2447:58:38, 176.29s/it]
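The log suggests that JAX never reached the GPU. cuInit fails with CUDA_ERROR_SYSTEM_DRIVER_MISMATCH, and the next lines report "CUDA backend failed to initialize". XLA is therefore probably compiling and running on CPU, which would fit the multi-minute compile and the ~176 s/it rate. The sketch below is a quick way to confirm which backend JAX picked. It uses only the public jax.default_backend() and jax.devices() functions.

```python
import jax


def describe_backend():
    """Return JAX's default backend name and the devices it can see."""
    return jax.default_backend(), [str(d) for d in jax.devices()]


backend, devices = describe_backend()
print(backend, devices)
if backend == "cpu":
    # Consistent with the log above: CUDA failed to initialize, so XLA targets CPU.
    print("JAX is running on CPU; compilation and training will be very slow.")
```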

Has anyone tried other versions (jax, NVIDIA driver, CUDA, and cuDNN)?

Hi, thanks for this great work! The currently recommended configurations are NVIDIA driver >520 + CUDA 11.8 + cuDNN 8.6 and NVIDIA driver >525 + CUDA 12.2 + cuDNN 8.9, but unfortunately my PC doesn't support such recent versions. Has anyone tried an older configuration, such as NVIDIA driver 470?

Any related information is warmly welcome. JAX is really difficult for me.

MSE loss in logs

Hi,

After fine-tuning, I checked the MSE plots in two tabs, "offline_metrics_red_cube/text_conditioned" and "training", and they were quite different.

I was wondering: is the same function used in both evaluations?
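One common reason two MSE curves differ is that one is computed on normalized actions and the other on raw, unnormalized actions. This is a general possibility, not a statement about how Octo's logging code works. With mean/std normalization the mean cancels in the error, so each raw-unit MSE is the normalized MSE scaled by std² for that dimension. The statistics below are made up for illustration.

```python
import numpy as np

# Hypothetical per-dimension action statistics (not from any real dataset).
mean = np.array([0.0, 0.0, 0.0, 0.5])
std = np.array([0.01, 0.01, 0.02, 0.5])

rng = np.random.default_rng(0)
pred_norm = rng.normal(size=(1000, 4))
target_norm = rng.normal(size=(1000, 4))

# Per-dimension MSE in normalized units.
mse_norm = np.mean((pred_norm - target_norm) ** 2, axis=0)

# Map both back to raw units with x = x_norm * std + mean; the mean cancels.
pred_raw = pred_norm * std + mean
target_raw = target_norm * std + mean
mse_raw = np.mean((pred_raw - target_raw) ** 2, axis=0)

# Raw MSE equals normalized MSE scaled by std**2, dimension by dimension.
print(mse_raw / mse_norm)
```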

Hugging Face description for octo-small may be incorrect

From inspecting the model, it seems there is actually no "wrist" observation tokenizer for the small model, only for the "base" model. Also, I'm only seeing the "language" task tokenizer for both the "base" and "small" models. I'm checking by looking in pretrained_model.config['model'].

Perhaps the descriptions on HF should be updated?
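To see which tokenizers a checkpoint actually defines, one can walk the nested config. The helper below is hypothetical. It assumes config["model"] holds observation_tokenizers and task_tokenizers mappings, so check those key names against your checkpoint. The toy config mirrors what this issue reports for octo-small.

```python
def list_tokenizers(model_config):
    """Hypothetical helper: names of observation and task tokenizers.

    The keys "observation_tokenizers" and "task_tokenizers" are assumptions
    about the layout of config["model"].
    """
    model_cfg = model_config["model"]
    return {
        "observation": sorted(model_cfg.get("observation_tokenizers", {})),
        "task": sorted(model_cfg.get("task_tokenizers", {})),
    }


# Toy config mirroring what the issue reports for octo-small.
config = {
    "model": {
        "observation_tokenizers": {"primary": {}},
        "task_tokenizers": {"language": {}},
    }
}
print(list_tokenizers(config))  # → {'observation': ['primary'], 'task': ['language']}
```

On a loaded model, the equivalent call would be list_tokenizers(pretrained_model.config).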

Full fine-tuning in 02_finetune example

Hi, I'm trying to train all parameters by modifying the example, but it seems like it's not working for me. What might be the issue?

With the original script, I get the following log:

I1218 11:34:09.575279 140267533906560 train_utils.py:403] ########## Parameters skipped during model loading: ##########
I1218 11:34:09.575732 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.bias
I1218 11:34:09.575860 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.scale
I1218 11:34:09.575954 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.bias
I1218 11:34:09.576039 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.kernel
I1218 11:34:09.576118 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.bias
I1218 11:34:09.576195 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.kernel
I1218 11:34:09.576265 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.bias
I1218 11:34:09.576333 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.kernel
I1218 11:34:09.576400 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.bias
I1218 11:34:09.576465 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.kernel
I1218 11:34:09.576580 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.bias
I1218 11:34:09.576655 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.kernel
I1218 11:34:09.576721 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.bias
I1218 11:34:09.576784 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.kernel
I1218 11:34:09.576846 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.probe
I1218 11:34:09.576910 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.bias
I1218 11:34:09.576974 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.kernel
I1218 11:34:09.577038 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_pos_embedding
I1218 11:34:09.577101 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.bias
I1218 11:34:09.577165 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.kernel
I1218 11:34:09.578479 140267533906560 train_utils.py:248] Freezing parameters that include the following keys: ['*hf_model*'].
I1218 11:34:09.581623 140267533906560 train_utils.py:284] Num trainable params: 27,178,734.
I1218 11:34:09.581736 140267533906560 train_utils.py:285] Num frozen params: 109,628,544.
I1218 11:34:09.581822 140267533906560 train_utils.py:286] To see a detailed list of frozen params, set logging level to DEBUG.
I1218 11:34:10.776679 140267533906560 02_finetune_new_observation_action.py:186] Starting finetuning...

Then I set frozen_keys to an empty list, as a crude attempt to fine-tune all parameters, and got the following:

I1218 18:19:57.928706 139778262782592 train_utils.py:403] ########## Parameters skipped during model loading: ##########
I1218 18:19:57.929069 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.bias
I1218 18:19:57.929162 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.scale
I1218 18:19:57.929238 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.bias
I1218 18:19:57.929309 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.kernel
I1218 18:19:57.929384 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.bias
I1218 18:19:57.929462 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.kernel
I1218 18:19:57.929538 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.bias
I1218 18:19:57.929613 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.kernel
I1218 18:19:57.929690 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.bias
I1218 18:19:57.929764 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.kernel
I1218 18:19:57.929843 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.bias
I1218 18:19:57.929908 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.kernel
I1218 18:19:57.929970 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.bias
I1218 18:19:57.930034 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.kernel
I1218 18:19:57.930099 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.probe
I1218 18:19:57.930162 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.bias
I1218 18:19:57.930227 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.kernel
I1218 18:19:57.930290 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_pos_embedding
I1218 18:19:57.930354 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.bias
I1218 18:19:57.930418 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.kernel
I1218 18:19:57.931729 139778262782592 train_utils.py:248] Freezing parameters that include the following keys: [].
I1218 18:19:57.934117 139778262782592 train_utils.py:284] Num trainable params: 136,807,278.
I1218 18:19:57.934224 139778262782592 train_utils.py:285] Num frozen params: 0.
I1218 18:19:57.934307 139778262782592 train_utils.py:286] To see a detailed list of frozen params, set logging level to DEBUG.
I1218 18:19:59.401661 139778262782592 02_finetune_new_observation_action.py:186] Starting finetuning...

There were 0 frozen params, so this seemed right. However, VRAM usage did not change (honestly, I expected to hit a GPU OOM), and the loss curve remained mostly unchanged from the default script.

(I'm using a custom dataset)

So I'm guessing I'm missing/misunderstanding something?
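A likely reason the VRAM reading did not change is JAX's default memory behavior. JAX preallocates 75% of total GPU memory when the first JAX operation runs. As a result, nvidia-smi shows roughly the same usage whether 27M or 137M parameters are trainable. JAX documents environment variables that change this behavior, and they must be set before JAX initializes, so the safest place is before import jax. The helper name below is hypothetical. The XLA_PYTHON_CLIENT_* variable names come from JAX's GPU memory documentation.

```python
import os
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False, fraction: Optional[float] = None) -> None:
    """Hypothetical helper: set JAX GPU memory env vars (call before `import jax`)."""
    # With preallocation off, reported GPU usage reflects actual allocations.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if fraction is not None:
        # Fraction of total GPU memory JAX may claim when it does preallocate.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


configure_jax_gpu_memory(preallocate=False)
# import jax  # import only after the environment has been configured
```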

Discord or meeting group?

I think it would be cool to have a discord or meeting group to discuss developments here or somewhere.

Unable to install octo with pip

Supporting python -m pip install git+https://github.com/octo-models/octo would be useful, for example in Colab. Currently, however, the install succeeds but the imports then fail:
https://colab.research.google.com/drive/1Arhc2gvfMmWUEDCYqeopSKr0zzzfx8Q-#scrollTo=2ZDF2sSQ8hmB

Collecting git+https://github.com/octo-models/octo
  Cloning https://github.com/octo-models/octo to /tmp/pip-req-build-6fhdexfl
  Running command git clone --filter=blob:none --quiet https://github.com/octo-models/octo /tmp/pip-req-build-6fhdexfl
  Resolved https://github.com/octo-models/octo to commit 653c54acde686fde619855f2eac0dd6edad7116b
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: octo
  Building wheel for octo (pyproject.toml) ... done
  Created wheel for octo: filename=octo-0.0.0-py3-none-any.whl size=1898 sha256=3e00c63e651d2ed51d54253c47ca02ed3ce75281d633d63e6bedadd356ad5be8
  Stored in directory: /tmp/pip-ephem-wheel-cache-fx6kt1v_/wheels/ca/2e/83/3e87bf03d7e424b0e89f061c2ed34eb2a6cee22ed0cfa40b52
Successfully built octo
Installing collected packages: octo
Successfully installed octo-0.0.0
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-1-230dc17a8c11> in <cell line: 9>()
      7 """
      8 import numpy as np
----> 9 from octo.data.dataset import make_interleaved_dataset
     10 from octo.data.oxe import make_oxe_dataset_kwargs_and_weights
     11 import tensorflow as tf

ModuleNotFoundError: No module named 'octo.data'
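The built wheel is only 1,898 bytes, which suggests it contains package metadata but no Python modules. That would explain why octo.data cannot be found. A common cause of this is a setup.py that does not list the package directories. That is an assumption about the repository at that commit, not something confirmed here. The stdlib helper below is hypothetical and reports which dotted submodules cannot be located after installation.

```python
import importlib.util


def missing_submodules(package, submodules):
    """Hypothetical helper: dotted names under `package` that cannot be found."""
    missing = []
    for sub in submodules:
        name = f"{package}.{sub}"
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised when the parent package itself cannot be imported.
            found = False
        if not found:
            missing.append(name)
    return missing


# Example usage after installing: missing_submodules("octo", ["data", "model"])
```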


Bug in the finetune_new_observation_action.py

Whether I set this to False or True, I get the same number of parameters:

from absl import flags

flags.DEFINE_bool(
    "freeze_transformer",
    True,
    "Whether pre-trained transformer weights should be frozen.",
)

When it is set to False, I get in the log:

Freezing parameters that include the following keys: ['*hf_model*'].
Num trainable params: 25,639,496.
Num frozen params: 109,628,544.

If I set it to True:

Freezing parameters that include the following keys: ['*hf_model*', 'head_mlp_only'].
Num trainable params: 25,639,496.
Num frozen params: 109,628,544.

Both have the same number of parameters!
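The identical counts are consistent with frozen_keys being glob-style patterns matched against parameter paths. "*hf_model*" contains wildcards. "head_mlp_only" contains none, so it would only match a parameter whose full path is literally head_mlp_only. That string looks more like a finetuning-mode name than a parameter pattern. The sketch below assumes fnmatch-style matching on dot-joined paths, which is an assumption about Octo's train_utils and not a confirmed detail. The parameter paths and sizes are made up.

```python
from fnmatch import fnmatchcase

# Hypothetical parameter paths and sizes, loosely modeled on names in the logs.
param_sizes = {
    "octo_transformer.task_tokenizers_language.hf_model.encoder.kernel": 1000,
    "octo_transformer.BlockTransformer_0.encoderblock_0.kernel": 500,
    "heads_action.mean_proj.kernel": 50,
}


def count_frozen(param_sizes, frozen_keys):
    """Split parameter counts into (trainable, frozen) using glob patterns."""
    frozen = sum(
        size
        for path, size in param_sizes.items()
        if any(fnmatchcase(path, key) for key in frozen_keys)
    )
    return sum(param_sizes.values()) - frozen, frozen


print(count_frozen(param_sizes, ["*hf_model*"]))  # → (550, 1000)
# A bare name without wildcards must match the whole path, so it freezes nothing.
print(count_frozen(param_sizes, ["*hf_model*", "head_mlp_only"]))  # → (550, 1000)
print(count_frozen(param_sizes, ["*hf_model*", "*BlockTransformer_0*"]))  # → (50, 1500)
```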

Run on new robot ?

Hi,

Great work ! And thank you very much for sharing the code and pretrained weights !

After browsing the repo and project page, I am a little confused about how to apply Octo to a new robot.

I would like to run it on a humanoid robot (Reachy https://www.pollen-robotics.com/), which has a geometry that is quite different from the robot arms you showcase.

I guess I would need to collect data for fine-tuning, provide a description of the robot to the system somehow, and write some kind of interface? But it is not clear to me what the expected formats are for the data, the robot description (URDF?), etc.

Could you give an overview of the required steps?

Thank you very much!

offline/sim evaluation recommendations

Excited by the work, great paper and open release.

I am interested in testing some ideas that will involve pretraining (e.g. architecture changes, etc.), likely without access to a real-world setup, at least at first. Just starting to look at the codebase.

I'm curious about recommendations for sim/offline evaluation. 1) Are there any recommendations or best practices for splitting datasets into train/validation/test, or for holding out RT-X datasets? What seem to be the most useful proxies for real-world performance? 2) I saw the provided examples for sim finetuning; are there any results for simulated environments that could be shared? Are there any sim environments that "work" for zero-shot evaluation, in addition to finetuning?

Thanks!
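A generic practice for offline validation is to split by whole trajectories rather than individual frames. This is not a recommendation from the Octo authors. Adjacent frames within an episode are nearly identical, so a frame-level split leaks training data into validation. The sketch below is a hypothetical stdlib helper. The 5% default only mirrors the 0.95/0.05 split that appears in the dataloader code elsewhere on this page.

```python
import random


def split_episodes(episode_ids, val_fraction=0.05, seed=0):
    """Hypothetical helper: hold out whole episodes (not frames) for validation."""
    ids = sorted(episode_ids)
    random.Random(seed).shuffle(ids)  # deterministic shuffle for reproducibility
    n_val = max(1, int(round(len(ids) * val_fraction)))
    return ids[n_val:], ids[:n_val]


train_ids, val_ids = split_episodes(range(100), val_fraction=0.05)
print(len(train_ids), len(val_ids))  # → 95 5
```

Holding out entire datasets (e.g. one RT-X source) is the same idea at a coarser level.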

Data loading is very slow when using 06_pytorch_oxe_dataloader.py with multiple GPUs

Hi! When I use your example in https://github.com/octo-models/octo/blob/main/examples/06_pytorch_oxe_dataloader.py, loading data with multiple GPUs in parallel is extremely slow, taking hundreds of times longer than with a single GPU. How can I solve this problem?

The following code reproduces the issue.

import os
import random
import tqdm
import torch
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
import time
from octo.data.dataset import make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights

DATA_PATH = "your_path_to_oxe"

def setup(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

class TorchRLDSDataset(torch.utils.data.IterableDataset):
    """Thin wrapper around RLDS dataset for use with PyTorch dataloaders."""

    def __init__(
        self,
        rlds_dataset,
        train=True,
    ):
        self._rlds_dataset = rlds_dataset
        self._is_train = train

    def __iter__(self):
        for sample in self._rlds_dataset.as_numpy_iterator():
            yield sample

    def __len__(self):
        lengths = np.array(
            [
                stats["num_transitions"]
                for stats in self._rlds_dataset.dataset_statistics
            ]
        )
        if hasattr(self._rlds_dataset, "sample_weights"):
            # out-of-place multiply: in-place *= fails when int lengths meet float weights
            lengths = lengths * np.array(self._rlds_dataset.sample_weights)
        total_len = lengths.sum()
        if self._is_train:
            return int(0.95 * total_len)
        else:
            return int(0.05 * total_len)

def experiment(rank, devices, port):
    device = devices[rank]
    device = f"cuda:{device}"
    setup(rank, world_size=len(devices), port=port)

    dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
        "oxe_magic_soup",
        DATA_PATH,
        load_camera_views=("primary", "wrist"),
    )

    dataset = make_interleaved_dataset(
        dataset_kwargs_list,
        sample_weights,
        train=True,
        shuffle_buffer_size=1000,  # change to 500k for training, large shuffle buffers are important, but adjust to your RAM
        batch_size=None,  # batching will be handled in the PyTorch DataLoader object
        balance_weights=True,
        traj_transform_kwargs=dict(
            goal_relabeling_strategy="uniform",
            window_size=2,
            future_action_window_size=3,
            subsample_length=100,
        ),
        frame_transform_kwargs=dict(
            image_augment_kwargs={
                "primary": dict(
                    random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
                    random_brightness=[0.1],
                    random_contrast=[0.9, 1.1],
                    random_saturation=[0.9, 1.1],
                    random_hue=[0.05],
                    augment_order=[
                        "random_resized_crop",
                        "random_brightness",
                        "random_contrast",
                        "random_saturation",
                        "random_hue",
                    ],
                ),
                "wrist": dict(
                    random_brightness=[0.1],
                    random_contrast=[0.9, 1.1],
                    random_saturation=[0.9, 1.1],
                    random_hue=[0.05],
                    augment_order=[
                        "random_brightness",
                        "random_contrast",
                        "random_saturation",
                        "random_hue",
                    ],
                ),
            },
            resize_size=dict(
                primary=(256, 256),
                wrist=(128, 128),
            ),
            num_parallel_calls=200,
        ),
        traj_transform_threads=48,
        traj_read_threads=48,
    )


    pytorch_dataset = TorchRLDSDataset(dataset)
    dataloader = DataLoader(
        pytorch_dataset,
        batch_size=16,
        num_workers=0,  # important to keep this to 0 so PyTorch does not mess with the parallelism
    )

    for i, sample in tqdm.tqdm(enumerate(dataloader)):
        if i == 5000:
            break


if __name__ == "__main__":
    devices = [0, 1, 2, 3]
    port = (random.randint(0, 3000) % 3000) + 27000
    mp.spawn(experiment, args=(devices, port), nprocs=len(devices), join=True)
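One plausible contributor is CPU oversubscription, offered here as a hypothesis. Each of the 4 spawned processes builds its own full pipeline with traj_read_threads=48, traj_transform_threads=48, and num_parallel_calls=200. The machine is therefore asked for roughly four times those thread counts. Every rank also iterates the same unsharded stream. The hypothetical helper below divides a CPU thread budget across ranks so that each rank can pass smaller values to make_interleaved_dataset.

```python
import os
from typing import Optional


def per_rank_threads(world_size: int, total: Optional[int] = None, minimum: int = 1) -> int:
    """Hypothetical helper: split a CPU thread budget evenly across ranks."""
    total = total if total is not None else (os.cpu_count() or 1)
    return max(minimum, total // world_size)


threads = per_rank_threads(world_size=4, total=96)
print(threads)  # → 24
# Inside each rank, one could then pass traj_read_threads=threads,
# traj_transform_threads=threads and num_parallel_calls=threads.
```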

Random / Noisy behavior of a new robot

Hello,

Setup: new environment
Robot: Franka panda
Data de-normalization: I am using the Berkeley cable routing dataset statistics, since that dataset also uses a Franka robot (would you recommend a different one?)

I am trying to run your base model on a Franka Panda robot. I have not done fine-tuning yet; I wanted to check the initial behavior first.

For now, I am getting random actions from the robot and sometimes fixed noisy actions (always moving down, regardless of the instruction/goal image or the robot's initial position). It does not succeed at simple tasks like moving to an object or picking up an object.

Is this the expected behavior, or should it at least move in the direction of the object?

Thank you.
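Borrowed normalization statistics can turn an average prediction into a constant bias. This is offered as a hypothesis, not a diagnosis. Under mean/std normalization, a normalized action is mapped back with a = a_norm · std + mean. If the borrowed mean has a nonzero z component, even a "neutral" prediction becomes a steady drift along z. All numbers below are hypothetical.

```python
import numpy as np


def unnormalize(action_norm, mean, std):
    """Inverse of (a - mean) / std, i.e. mean/std normalization."""
    return action_norm * std + mean


# Hypothetical statistics for a 3-D translation action (x, y, z).
borrowed = {"mean": np.array([0.0, 0.0, -0.004]), "std": np.array([0.01, 0.01, 0.01])}
own = {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([0.005, 0.005, 0.005])}

# A prediction equal to the dataset average in normalized units.
pred_norm = np.zeros(3)

with_borrowed = unnormalize(pred_norm, **borrowed)  # constant drift along -z
with_own = unnormalize(pred_norm, **own)  # no drift
print(with_borrowed, with_own)
```

Statistics computed from one's own demonstrations (e.g. actions.mean(axis=0) and actions.std(axis=0)) avoid this mismatch, though that requires collecting data first.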

Can't get primary + wrist camera

I'm trying to finetune Octo on custom data by modifying a script in the examples:

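from octo.data.dataset import make_single_dataset
from octo.data.utils.data_utils import NormalizationType

dataset = make_single_dataset(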
    dataset_kwargs=dict(
        name="custom",
        data_dir=FLAGS.data_dir,
        image_obs_keys={"primary": "head", "wrist": "hand"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        absolute_action_mask=[True] * 11,
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        future_action_window_size=49,  # so we get 50 actions for our action chunk
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256), "wrist": (128, 128)},
    ),
    train=True,
)

The original data has head and hand images both at 224x224, 6-D proprioception, and an 11-D action space. But the output seems to indicate that the wrist image is being ignored. What might I be doing incorrectly?

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                ┃ task_language (16 tokens) ┃ t=0 obs_primary (256 tokens)  ┃ t=0 obs_proprio (6 tokens)  ┃ t=0 readout_action (1 tokens)  ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ task_language (16 tokens)      │ x                         │ x                             │ x                           │ x                              │
├────────────────────────────────┼───────────────────────────┼───────────────────────────────┼─────────────────────────────┼────────────────────────────────┤
│ t=0 obs_primary (256 tokens)   │                           │ x                             │ x                           │ x                              │
├────────────────────────────────┼───────────────────────────┼───────────────────────────────┼─────────────────────────────┼────────────────────────────────┤
│ t=0 obs_proprio (6 tokens)     │                           │ x                             │ x                           │ x                              │
├────────────────────────────────┼───────────────────────────┼───────────────────────────────┼─────────────────────────────┼────────────────────────────────┤
│ t=0 readout_action (1 tokens)  │                           │                               │                             │ x                              │
└────────────────────────────────┴───────────────────────────┴───────────────────────────────┴─────────────────────────────┴────────────────────────────────┘

Still can't find the aloha-sim-cube after installation

After installing with the instructions provided in the 03 script, I still get this error:

 File "/dataSSD/3zhou/miniconda/envs/octo/lib/python3.10/site-packages/gym/envs/registration.py", line 219, in _check_version_exists
    _check_name_exists(ns, name)
  File "/dataSSD/3zhou/miniconda/envs/octo/lib/python3.10/site-packages/gym/envs/registration.py", line 197, in _check_name_exists
    raise error.NameNotFound(
gym.error.NameNotFound: Environment aloha-sim-cube doesn't exist. 

I also used sys.path.append to point to the local act code.

Can you provide finetuning ckpt

Hi, thanks for your great work. I noticed that you provide the finetuning code; would you consider also releasing the checkpoint from the finetuning example?

Guideline for training on in-house data?

Hi, congratulations on this work!

I am trying to evaluate/train Octo on my in-house data, which is basically a gym environment that can render RGB images. I have no background in the RLDS format or the OXE datasets. Do you have any resources to share so that I can get started quickly? Thanks!
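RLDS stores each trajectory as an episode containing a sequence of steps. is_first, is_last, and is_terminal are standard RLDS step fields. The sketch below collects one gym rollout into that shape in memory, before it is written out with a TFDS dataset builder. The observation key "image" and the step-level "language_instruction" field are assumptions. They only need to agree with whatever image_obs_keys and language_key you later pass to the Octo dataset config.

```python
import numpy as np


def rollout_to_episode(frames, actions, instruction):
    """Convert one gym rollout into an RLDS-style episode dict (sketch).

    `frames` is a list of HxWx3 images and `actions` a list of action vectors of
    the same length. The "observation" key names are hypothetical.
    """
    assert len(frames) == len(actions), "need one action per frame"
    n = len(frames)
    steps = []
    for t in range(n):
        steps.append(
            {
                "observation": {"image": np.asarray(frames[t], dtype=np.uint8)},
                "action": np.asarray(actions[t], dtype=np.float32),
                "language_instruction": instruction,
                "is_first": t == 0,
                "is_last": t == n - 1,
                "is_terminal": t == n - 1,  # assumes episodes end on task completion
            }
        )
    return {"steps": steps}


episode = rollout_to_episode([np.zeros((64, 64, 3))] * 3, [np.zeros(7)] * 3, "pick up the cube")
print(len(episode["steps"]))  # → 3
```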

Question about action space

Octo's action space comprises end-effector velocities, representing changes in ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']. I intend to assess the model's zero-shot capability in a simulator. Despite understanding the significant domain gap, my goal is to verify the pipeline's error-free operation. I'm utilizing RLBench, where I observed that the action space is defined by joint angles, differing from Octo's.

My questions:

  1. For zero-shot evaluation, should I use inverse kinematics to compute the joint angles corresponding to Octo's outputs, so that the model's actions match the environment's action space? Is this correct?

  2. For few-shot finetuning, should the actions in the demonstrations be end-effector velocities rather than joint angles? I understand that joint angles are directly observable, whereas end-effector velocities require an extra conversion step.

Thanks for your great work!
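A minimal sketch of the conversion both questions revolve around, assuming poses are stored as [x, y, z, three Euler angles, gripper] in a single fixed frame. It turns consecutive recorded end-effector poses into delta actions (question 2) and applies a predicted delta to the current pose to obtain a target pose that an IK solver could then map to joint angles (question 1). The pose layout, the plain per-angle Euler differencing (only a reasonable approximation for small rotations), and keeping the gripper as an absolute command are all assumptions; verify them against the action convention of the data Octo was trained on.

```python
import numpy as np


def wrap_angle(a):
    """Map angles (radians) into [-pi, pi)."""
    return (a + np.pi) % (2 * np.pi) - np.pi


def poses_to_deltas(poses):
    """Convert a (T, 7) array of [x, y, z, r1, r2, r3, grip] poses into (T-1, 7) deltas."""
    poses = np.asarray(poses, dtype=np.float64)
    deltas = np.empty((len(poses) - 1, 7))
    deltas[:, :3] = poses[1:, :3] - poses[:-1, :3]              # translation change
    deltas[:, 3:6] = wrap_angle(poses[1:, 3:6] - poses[:-1, 3:6])  # small-rotation approx.
    deltas[:, 6] = poses[1:, 6]                                   # gripper kept absolute (assumption)
    return deltas


def apply_delta(pose, delta):
    """Integrate one delta action into the current pose to get a target pose for IK."""
    pose = np.asarray(pose, dtype=np.float64)
    delta = np.asarray(delta, dtype=np.float64)
    target = np.empty(7)
    target[:3] = pose[:3] + delta[:3]
    target[3:6] = wrap_angle(pose[3:6] + delta[3:6])
    target[6] = delta[6]
    return target


# Usage: three recorded poses, including a yaw that crosses the +/-pi boundary.
poses = np.array([
    [0.00, 0.00, 0.0, 0.0, 0.0, 3.1, 1.0],
    [0.01, 0.00, 0.0, 0.0, 0.0, -3.1, 0.0],
    [0.02, 0.01, 0.0, 0.1, 0.0, -3.0, 0.0],
])
deltas = poses_to_deltas(poses)
```

The angle wrapping matters: without it, a yaw going from 3.1 to -3.1 rad would look like a rotation of about -6.2 rad instead of roughly +0.08 rad.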

Does Octo support two wrist cameras?

Great work! But I still have several questions about the model.
First question: Can a general vision model understand depth information from a camera? Does the depth information require special processing?

Second question: Dual-arm robots and single-arm robots may have different behavioral patterns; does Octo distinguish between these two types of robots?

Third question: For the wrist camera of a dual-arm robot, is special processing required?

Thank you very much for your time and assistance.
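To make the multi-camera question concrete, here is a minimal sketch of how observations from several cameras could be packed into one windowed observation dict with per-camera and per-timestep padding masks. This only illustrates data packing; it is not Octo's tokenizer configuration. The camera key names, the `pad_mask_dict` / `timestep_pad_mask` layout, and repeating the first frame to fill history before the episode start are assumptions to check against the model and dataset configs you use.

```python
import numpy as np

# Hypothetical key names; check your model/dataset config for the real ones.
CAMERA_KEYS = ("image_primary", "image_wrist_left", "image_wrist_right")


def build_windowed_obs(frames, t, window_size=2, image_shape=(128, 128, 3)):
    """Stack the last `window_size` frames ending at timestep t for every camera.

    frames: list over timesteps of dicts {camera_key: HxWx3 uint8 array or None}.
    Returns one (window, H, W, 3) array per camera, a per-camera pad mask
    (False where that camera is missing) and a timestep pad mask (False for
    history slots that fall before the start of the episode).
    """
    obs = {}
    pad_mask_dict = {}
    # History slot w corresponds to timestep t - (window_size - 1) + w.
    src_steps = [t - (window_size - 1) + w for w in range(window_size)]
    timestep_pad_mask = np.array([s >= 0 for s in src_steps], dtype=bool)
    for key in CAMERA_KEYS:
        stack = np.zeros((window_size, *image_shape), dtype=np.uint8)
        present = np.zeros(window_size, dtype=bool)
        for w, s in enumerate(src_steps):
            img = frames[max(s, 0)].get(key)  # repeat first frame before episode start
            if img is not None:
                stack[w] = img
                present[w] = True
        obs[key] = stack
        pad_mask_dict[key] = present
    obs["pad_mask_dict"] = pad_mask_dict
    obs["timestep_pad_mask"] = timestep_pad_mask
    return obs


# Usage: three timesteps with a primary and a left wrist camera; right wrist missing.
frames = [
    {
        "image_primary": np.full((4, 4, 3), i, np.uint8),
        "image_wrist_left": np.full((4, 4, 3), 10 + i, np.uint8),
        "image_wrist_right": None,
    }
    for i in range(3)
]
obs0 = build_windowed_obs(frames, t=0, window_size=2, image_shape=(4, 4, 3))
obs2 = build_windowed_obs(frames, t=2, window_size=2, image_shape=(4, 4, 3))
```

Under this layout, a depth stream could in principle be added as one more key, but whether the pretrained encoders make use of it is exactly the open question above.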
