octo-models / octo
Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
Home Page: https://octo-models.github.io/
License: MIT License
Hi, great work, and thanks for sharing!
I have read your paper and found it very inspiring. However, some things are still unclear to me:
The history frames. You mentioned that a one-frame history is beneficial for pre-training. How do you organize the input data then? Would it look like ----, e.g. in an interleaved style? In other words, do we need to split all the original trajectories into 2-step chunks? And how do you handle the first frame: do you repeat it?
Shuffle buffer size. Does the buffer refer to "sampling frames from different trajectories across datasets"? Please correct me if I have misunderstood.
Heads. It seems that the diffusion policy head is the most robust and efficient one.
Step modelling. I am curious how the step information is modelled. Since one trajectory has one task instruction but many steps, do we need to differentiate between the steps? Also, how do you decide when the robot should stop? Using some heuristics?
Thanks again for sharing!
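One common way to build such history windows is sketched below under our own assumptions; it is not taken from Octo's data pipeline. Each timestep becomes a sliding window over its last `window_size` frames, so nothing has to be pre-split into 2-step chunks. Window positions that fall before the first frame reuse frame 0 and are flagged False in a pad mask so the model can ignore them. `window_trajectory` is a hypothetical helper name.

```python
import numpy as np


def window_trajectory(obs, window_size=2):
    """Turn one trajectory [T, ...] into sliding windows [T, window_size, ...].

    Returns the windows and a boolean pad mask of shape [T, window_size].
    Indices before the first frame are clamped to 0 (frame 0 is repeated)
    and marked False in the pad mask.
    """
    T = obs.shape[0]
    # history_idx[t] = [t - window_size + 1, ..., t]
    history_idx = np.arange(T)[:, None] + np.arange(-window_size + 1, 1)[None, :]
    pad_mask = history_idx >= 0            # False where the window runs off the start
    clamped = np.maximum(history_idx, 0)   # repeat frame 0 at those positions
    return obs[clamped], pad_mask


obs = np.arange(4)[:, None]  # toy trajectory: 4 frames with 1 feature each
windows, pad_mask = window_trajectory(obs, window_size=2)
print(windows[:, :, 0])  # frame index held at each window position
print(pad_mask)
```

With this layout every frame of the original trajectory is still a training sample; only the first sample carries a padded (repeated) history slot.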
Hi!
I managed to convert my (very small for now) training dataset to RLDS format and then to run 02_finetune_new_observation_action.py, after mainly modifying the following to match the action space of my robot and the image key name:
logging.info("Loading finetuning dataset...")
dataset = make_single_dataset(
dataset_kwargs=dict(
name="ReachyDataset",
data_dir=FLAGS.data_dir,
image_obs_keys={"primary": "head"},
state_obs_keys=["state"],
language_key="language_instruction",
action_proprio_normalization_type=NormalizationType.NORMAL,
absolute_action_mask=[True] * 19,
),
[...]
The training curves look like this:
My finetuning data was collected in the real world (20 examples of 10 seconds each, with me teleoperating the robot to grab a small wooden cube). Below is an example of one recorded episode.
Then I tried to run inference, inspired by 01_inference_pretrained.ipynb. My code looks like this:
import time
import jax
import numpy as np
from octo.model.octo_model import OctoModel

model = OctoModel.load_pretrained("/data1/apirrone/octo/trainings/")
task = model.create_tasks(texts=["Grab the wooden cube"])
while True:
observation = {
"image_primary": get_image(),
"proprio": get_state(),
"pad_mask": np.array([[True]]),
}
actions = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))[0]
actions = (
actions * model.dataset_statistics["action"]["std"]
+ model.dataset_statistics["action"]["mean"]
)
for step in actions:
set_joints(np.array(step))
time.sleep(0.02)
time.sleep(0.5)
I am running it in our simulator for now, as I don't have access to a robot at the moment.
This is what I get (I actually fixed the head in place in this example):
Is this expected with such a small finetuning dataset? Would having a lot more data solve it on its own?
I also get a few warnings that are a little concerning :)
INFO:root:No task inputs matching image_primary were found. Replacing with zero padding.
WARNING:root:'observations' is missing items compared to example_batch: {'pad_mask_dict/timestep', 'timestep', 'pad_mask_dict/image_primary', 'pad_mask_dict/proprio'}
WARNING:root:No pad_mask_dict found. Nothing will be masked.
Overall, I don't think I properly understand what I am doing yet. For example:
What does pad_mask in the observations dict mean in this context?
model.sample_actions([...]): does this return the full trajectory that is supposed to "solve" the task? Or should I sample it multiple times with new observations?
What is pad_mask_dict?
Any help would be greatly appreciated!
Thanks,
Antoine
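On the sample_actions question: the eval wrappers used elsewhere in this thread (HistoryWrapper, and RHCWrapper with exec_horizon, described as "receding horizon control, i.e. action chunking") suggest that the policy predicts a short chunk of actions that is executed for a few steps before the robot re-observes and queries again, rather than one full trajectory. The loop below is a minimal, framework-free sketch of that pattern. `policy`, `get_obs` and `apply_action` are hypothetical callables, and the NumPy generator stands in for the random key. With JAX you would typically derive a fresh key per query, e.g. `rng, key = jax.random.split(rng)`, instead of passing `jax.random.PRNGKey(0)` every time, which reuses the same sampling noise on every call.

```python
import numpy as np


def run_receding_horizon(policy, get_obs, apply_action, n_queries, exec_horizon, seed=0):
    """Query the policy for an action chunk, execute only the first
    `exec_horizon` actions, then re-observe and query again.

    policy(obs, rng) -> array [pred_horizon, action_dim]  (hypothetical signature)
    Returns the total number of executed actions.
    """
    rng = np.random.default_rng(seed)  # advances on each use, unlike a fixed key
    executed = 0
    for _ in range(n_queries):
        obs = get_obs()                # fresh observation for every query
        chunk = policy(obs, rng)       # one chunk of future actions, not a full trajectory
        for action in chunk[:exec_horizon]:
            apply_action(action)
            executed += 1
    return executed


# Usage with stand-ins: the "policy" predicts a 4-step chunk filled with the observation value.
applied = []
n = run_receding_horizon(
    policy=lambda obs, rng: np.zeros((4, 2)) + obs,
    get_obs=lambda: len(applied),
    apply_action=applied.append,
    n_queries=3,
    exec_horizon=2,
)
```

Executing fewer steps than the chunk length trades extra model queries for more frequent feedback from new observations.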
I am training the model from scratch on a V100. CPU memory keeps growing during training, and after about 10,000 more steps the machine's 512 GB of CPU memory is used up.
How can I locate this problem? Thank you.
I am using the following command:
python scripts/train.py --config scripts/configs/octo_pretrain_config.py:vit_s
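A first step in locating the growth is to log the process's memory at regular intervals and check whether it rises steadily or levels off. The sketch below uses only the standard library (`resource`, which is Unix-only) and reports peak resident memory; `log_memory` is a hypothetical helper you would call from the training loop. As a heuristic, not a diagnosis: steady growth long after any data buffers should be full points toward a leak, whereas a plateau suggests a buffer (for example a large tf.data shuffle buffer) filling up.

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of this process, in MB.

    ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    divisor = 1024 ** 2 if sys.platform == "darwin" else 1024
    return peak / divisor


def log_memory(step, every=1000, log=print):
    """Log peak memory every `every` steps; call once per training step."""
    if step % every == 0:
        log(f"step {step}: peak RSS {peak_rss_mb():.1f} MB")


# Usage: collect the messages instead of printing them.
messages = []
for step in range(3000):
    log_memory(step, every=1000, log=messages.append)
```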
Is the Franka implementation mentioned in the paper available anywhere? I would like to run the code and finetune the model for our Franka manipulator :)
Polymetis is already installed; what I really need is the integration with a gym Env, specifically the step function :)
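A hypothetical skeleton of a Gym-style environment (using the five-tuple `step` and `(obs, info)` `reset` API of recent Gym versions) that wraps a robot client might look like the following. `robot.get_joint_positions()`, `robot.move_to_joint_positions()` and `camera.read()` are assumed method names standing in for your own Polymetis and camera code, and treating actions as absolute joint targets is also an assumption. The fake robot and camera at the bottom exist only so the sketch runs without hardware.

```python
import numpy as np


class FrankaEnvSketch:
    """Hypothetical Gym-style environment around a robot client and a camera.

    Assumed interfaces (placeholders, not a confirmed API):
      robot.get_joint_positions() -> array of joint angles
      robot.move_to_joint_positions(q) -> moves to the joint targets q
      camera.read() -> HxWx3 uint8 image
    """

    def __init__(self, robot, camera, max_steps=100):
        self.robot = robot
        self.camera = camera
        self.max_steps = max_steps
        self.t = 0

    def _get_obs(self):
        return {
            "image_primary": self.camera.read(),
            "proprio": np.asarray(self.robot.get_joint_positions(), dtype=np.float32),
        }

    def reset(self):
        self.t = 0
        return self._get_obs(), {}

    def step(self, action):
        # Assumption: actions are absolute joint targets. Delta or end-effector
        # actions would need a different controller call here.
        self.robot.move_to_joint_positions(np.asarray(action, dtype=np.float32))
        self.t += 1
        obs = self._get_obs()
        reward = 0.0                          # no reward signal during evaluation
        terminated = False                    # success has to be judged externally
        truncated = self.t >= self.max_steps  # episode time limit
        return obs, reward, terminated, truncated, {}


# Stand-ins so the sketch runs without hardware.
class _FakeRobot:
    def __init__(self):
        self.q = np.zeros(7)

    def get_joint_positions(self):
        return self.q

    def move_to_joint_positions(self, q):
        self.q = q


class _FakeCamera:
    def read(self):
        return np.zeros((256, 256, 3), np.uint8)


env = FrankaEnvSketch(_FakeRobot(), _FakeCamera(), max_steps=2)
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(np.ones(7))
```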
Hi, thanks for your great work!
I have finetuned the model using examples/02_finetune_new_observation_action.py, and I'm running examples/03_eval_finetuned.py to see the finetuned results.
I followed the instructions in octo/examples/03_eval_finetuned.py (lines 9 to 11 at commit 8fe7497) and added sys.path.append("/path/to/act"), but I still cannot get gym.make("aloha-sim-cube-v0") to succeed.
Another problem is that I cannot load the finetuned model. Here's the traceback.
Traceback (most recent call last):
File "/code/octo/examples/03_eval_finetuned.py", line 101, in <module>
app.run(main)
File "/opt/conda/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/opt/conda/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/code/octo/examples/03_eval_finetuned.py", line 35, in main
model = OctoModel.load_pretrained(FLAGS.finetuned_path)
File "/code/octo/octo/model/octo_model.py", line 274, in load_pretrained
params = checkpointer.restore(step, params_shape)
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 550, in restore
restored_items = self._restore_impl(
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 582, in _restore_impl
restored[item_name] = self._checkpointers[item_name].restore(
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 165, in restore
restored = self._restore_with_args(directory, *args, **kwargs)
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 103, in _restore_with_args
restored = self._handler.restore(directory, args=ckpt_args)
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1063, in restore
restored_item = _transform_checkpoint(
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 601, in _transform_checkpoint
item = utils.deserialize_tree(restored, item)
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 281, in deserialize_tree
return jax.tree_util.tree_map_with_path(
File "/opt/conda/envs/octo/lib/python3.10/site-packages/jax/_src/tree_util.py", line 857, in tree_map_with_path
return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
File "/opt/conda/envs/octo/lib/python3.10/site-packages/jax/_src/tree_util.py", line 857, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
File "/opt/conda/envs/octo/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 278, in _reconstruct_from_keypath
result = result[key_name]
KeyError: 'diffusion_model'
It looks like I didn't save the diffusion model in the training process. Did I miss something in the configuration?
Thanks.
I assume the paper's evaluations use the Base size, but I'm wondering how big the performance gap between Small and Base is.
Hi, when following 03_eval_finetuned.py to create the environment like this:
env = gym.make("aloha-sim-cube-v0")
# add wrappers for history and "receding horizon control", i.e. action chunking
env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)
# wrap env to handle action/proprio normalization -- match normalization type to the one used during finetuning
env = UnnormalizeActionProprio(
env, model.dataset_statistics, normalization_type="normal"
)
The environment is ultimately wrapped by UnnormalizeActionProprio, which is both an ActionWrapper and an ObservationWrapper.
I noticed that its unnormalize function is called on env.step(actions), which is expected since the action needs to be unnormalized.
However, the normalize function, as well as the observation function, is not called in env.step. They are called in env.reset but not in step. Could this lead to any critical problems?
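One possible explanation, assuming UnnormalizeActionProprio lists ActionWrapper before ObservationWrapper as its bases (we have not checked its definition): Python's method resolution order then picks ActionWrapper.step, which in recent Gym versions only transforms the action and returns the observation untouched, while reset, which ActionWrapper does not override, comes from ObservationWrapper and does transform the observation. That matches what you describe, and it would mean that proprio returned by step is never normalized. The classes below are simplified stand-ins mimicking Gym's wrappers (with a 4-tuple step), not Gym itself; `BothFixed` shows one way to apply both transforms explicitly.

```python
class Wrapper:
    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()

    def step(self, action):
        return self.env.step(action)


class ActionWrapper(Wrapper):
    def step(self, action):
        # Mimics Gym: transforms the action only; the observation passes through.
        return self.env.step(self.action(action))


class ObservationWrapper(Wrapper):
    def reset(self):
        return self.observation(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.observation(obs), reward, done, info


class Both(ActionWrapper, ObservationWrapper):
    """MRO: Both -> ActionWrapper -> ObservationWrapper -> Wrapper."""

    def __init__(self, env):
        super().__init__(env)
        self.calls = []

    def action(self, action):
        self.calls.append("action")
        return action

    def observation(self, obs):
        self.calls.append("observation")
        return obs


class BothFixed(Both):
    def step(self, action):
        # Apply both transforms explicitly instead of relying on the MRO.
        obs, reward, done, info = self.env.step(self.action(action))
        return self.observation(obs), reward, done, info


class DummyEnv:
    def reset(self):
        return 0

    def step(self, action):
        return 1, 0.0, False, {}


env = Both(DummyEnv())
env.reset()
env.step(0)
print(env.calls)    # observation() ran only during reset

fixed = BothFixed(DummyEnv())
fixed.reset()
fixed.step(0)
print(fixed.calls)  # observation() now also runs during step
```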
Hello,
Thank you for this amazing project. Great job!
I have a small question regarding finetuning. The code mentions using a custom version of the Bridge dataset (https://rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/), in which some features have changed (e.g. names, action dimensions, etc.). Was any normalization applied to the actions in your custom version of the Bridge dataset?
One more question: in general, to finetune your model, should the actions in a new dataset be normalized when the dataset is created?
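Judging from the finetuning config earlier in this thread (action_proprio_normalization_type=NormalizationType.NORMAL) and the inference loop that multiplies by the dataset std and adds the mean, normalization appears to be applied by the data loader using dataset statistics rather than baked into the stored data. Under that assumption, the sketch below only illustrates what "normal" normalization computes: a per-dimension z-score and its inverse. The `mask` argument, for leaving some dimensions (e.g. a binary gripper) raw, is our own illustrative addition.

```python
import numpy as np


def normalize(x, stats, mask=None, eps=1e-8):
    """'Normal' normalization: z-score x with the dataset mean/std.

    Dimensions where `mask` is False are passed through unchanged.
    """
    mean = np.asarray(stats["mean"])
    std = np.asarray(stats["std"])
    out = (x - mean) / (std + eps)
    return out if mask is None else np.where(mask, out, x)


def unnormalize(x, stats, mask=None, eps=1e-8):
    """Inverse of normalize(): x * std + mean, as in the inference loop above (plus eps)."""
    mean = np.asarray(stats["mean"])
    std = np.asarray(stats["std"])
    out = x * (std + eps) + mean
    return out if mask is None else np.where(mask, out, x)


# Statistics are computed once over the raw training actions.
actions = np.array([[0.1, -0.2, 1.0], [0.3, 0.0, 0.0], [0.2, 0.2, 1.0]])
stats = {"mean": actions.mean(axis=0), "std": actions.std(axis=0)}
mask = np.array([True, True, False])  # leave the last (gripper-like) dimension raw
normed = normalize(actions, stats, mask)
restored = unnormalize(normed, stats, mask)
```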
I was able to fine-tune with a modified version of example 2 with the following action head:
config["model"]["heads"]["action"] = ModuleSpec.create(
L1ActionHead,
pred_horizon=9,
action_dim=11,
readout_key="readout_action",
)
The policy works reasonably well on the robot.
Afterwards, I tried finetuning with a diffusion head, but with it the robot goes out of control.
config["model"]["heads"]["action"] = ModuleSpec.create(
# L1ActionHead,
DiffusionActionHead,
use_map=False,
pred_horizon=9,
action_dim=11,
readout_key="readout_action",
)
The rest of the script is unchanged.
What else could be the problem?
Update: The diffusion-based model seems to output fairly extreme action values, e.g. values below the minimum or above the maximum found in the dataset.
These are the action statistics:
{'max': array([6.92096353e-03, 7.15068638e-01, 2.65712190e+00, 1.22000003e+00, 4.09653854e+00, 9.43594933e-01, 1.17203128e+00, 2.65219069e+00, 1.00000000e+00, 1.50034428e-01, 4.94167267e-04]),
'mean': array([-1.45641267e+00, 2.27537051e-01, -2.96192672e-02, -6.16574585e-01, -7.61023015e-02, 2.06921268e-02, 4.98067914e-03, -1.26738450e-04, 5.67098975e-01, 7.91745202e-04, -7.36813426e-01]),
'min': array([-2.61999989, -0.05653883, -1.91999996, -1.75 , -2.45071816, -0.88507879, -1.3124876 , -1.5 , 0. , -0.09809657, -1.35774779]),
'std': array([0.50778913, 0.21539085, 0.41647774, 0.72304732, 0.8022927, 0.11436515, 0.08398118, 0.30631647, 0.49528143, 0.01135448, 0.26658419])}
and these are sample outputs after unnormalization:
act: [-3.9953585 1.3044913 2.0527694 -4.231811 3.9353614 0.59251785 -0.41492522 -1.5317091 3.0435061 -0.05598066 -2.0697343 ]
act: [ 1.082533 1.3044913 2.0527694 2.998662 -4.087566 0.59251785 -0.41492522 1.5314556 3.0435061 -0.05598066 -2.0697343 ]
which are clearly out of bounds.
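To quantify "out of bounds", it helps to compare each unnormalized dimension with the dataset min/max and express it in standard deviations. For instance, the ninth dimension's 3.0435 is about (3.0435 - 0.5671) / 0.4953 ≈ 5.0 standard deviations above its mean, far outside anything seen in training. The hypothetical helper below reports such dimensions and can optionally clip them. Clipping is only a safety stopgap on the robot; it does not address why the diffusion head produces these values.

```python
import numpy as np


def check_action_bounds(action, stats, clip=False):
    """Return (action, out_of_bounds_dims, z_scores) for one unnormalized action.

    out_of_bounds_dims lists the indices outside [stats["min"], stats["max"]].
    If clip is True, the returned action is clipped to that range.
    """
    lo, hi = np.asarray(stats["min"]), np.asarray(stats["max"])
    mean, std = np.asarray(stats["mean"]), np.asarray(stats["std"])
    action = np.asarray(action, dtype=np.float64)
    out_of_bounds = np.flatnonzero((action < lo) | (action > hi))
    z = (action - mean) / std  # distance from the mean in dataset std units
    if clip:
        action = np.clip(action, lo, hi)
    return action, out_of_bounds, z


# Usage with small made-up statistics.
stats = {
    "min": np.zeros(3), "max": np.ones(3),
    "mean": np.full(3, 0.5), "std": np.full(3, 0.25),
}
clipped, oob, z = check_action_bounds([1.5, 0.5, -0.25], stats, clip=True)
print(oob, z, clipped)
```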
Hello, thanks for the amazing work!
I want to run the script examples/02_finetune_new_observation_action.py to obtain a finetuned version of Octo to test in the ALOHA gym environment. I started the training on an A100, but I'm noticing some differences from what is written in the report:
Also, I'm finetuning octo-small, while the 24GB and 5h finetuning figures refer to octo-base, so I was expecting it to take even less.
I was wondering whether finetuning on the ALOHA benchmark requires different hyperparameters, which would explain the inconsistencies, or whether there is an error in the finetuning script.
The only modifications I've done to the code are:
flags.DEFINE_string(
"pretrained_path", 'hf://rail-berkeley/octo-small', "Path to pre-trained Octo checkpoint directory."
)
flags.DEFINE_string("data_dir", '.', "Path to finetuning dataset, in RLDS format.")
flags.DEFINE_string("save_dir", './aloha_octo', "Directory for saving finetuning checkpoints.")
since all of these values were previously set to None.
Do you have any idea what the issue might be?
Thanks in advance 🐙
Hello,
I am attempting to finetune a model using multiple nodes, each equipped with 8 A100 GPUs, and I'm encountering some difficulties. Octo is implemented in JAX, and I initially thought that jax.pmap could be used to parallelize the work across multiple GPUs and nodes. However, the batch data is in dictionary form, which seems to be incompatible with jax.pmap, so the process does not work as expected.
Here are the details of my environment:
Given this situation, I am looking for advice or solutions to achieve multi-node finetuning with my setup. Is there a recommended approach for handling dictionary-form batches with jax.pmap, or perhaps an alternative way to finetune across multiple nodes in JAX?
Any insights or suggestions would be greatly appreciated.
Thank you.
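For what it's worth, jax.pmap accepts arbitrary pytrees as arguments, including nested dicts, and maps over the leading axis of every leaf; a dict batch should therefore work as long as each leaf is reshaped to [num_devices, per_device_batch, ...]. The sketch below shows that pattern on the local devices of one process; the batch keys and the loss are hypothetical. For multiple nodes, JAX additionally provides jax.distributed.initialize(), which each process calls before using devices; this sketch does not exercise it, and Octo's own training scripts may use a different mechanism.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(batch, n_devices):
    """Reshape every leaf of a (nested) dict from [B, ...] to [n_devices, B // n_devices, ...]."""
    def _reshape(x):
        assert x.shape[0] % n_devices == 0, "batch size must divide evenly across devices"
        return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])
    return jax.tree_util.tree_map(_reshape, batch)


def per_device_loss(batch):
    # Hypothetical loss; any function of the dict works the same way.
    err = batch["action"] - batch["observation"]["proprio"]
    loss = jnp.mean(err ** 2)
    # Average across devices; axis_name must match the one given to pmap.
    return jax.lax.pmean(loss, axis_name="devices")


p_loss = jax.pmap(per_device_loss, axis_name="devices")

n = jax.local_device_count()
batch = {
    "observation": {"proprio": np.zeros((8 * n, 7), np.float32)},
    "action": np.ones((8 * n, 7), np.float32),
}
losses = p_loss(shard_batch(batch, n))  # shape [n]: one (identical) value per device
```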
Do you have any insights into the differences in configuration between the pretrained model and the ALOHA finetuned one?
In particular I was wondering
I've finetuned the Octo model on my custom dataset for our Spot robot, and now I'm trying to adapt this example and have some questions! Mainly, it doesn't seem like this example could actually run as-is.
jax.jit: I can get it working, but I'm not sure that's the best approach. What's the purpose of the jit?
Traceback (most recent call last):
File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/peter/Documents/octo/scripts/eval_conq.py", line 234, in main
action = np.array(policy_fn(obs, task), dtype=np.float64)
File "/home/peter/Documents/octo/scripts/eval_conq.py", line 189, in sample_actions
actions = pretrained_model.sample_actions(
File "/home/peter/Documents/octo/octo/model/octo_model.py", line 185, in sample_actions
pad_mask = observations["pad_mask"]
File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 728, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 341, in _getitem
return lax_numpy._rewriting_take(self, item)
File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4362, in _rewriting_take
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
File "/home/peter/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4436, in _split_index_for_jit
raise TypeError(f"JAX does not support string indexing; got {idx=}")
TypeError: JAX does not support string indexing; got idx=('pad_mask',)
ResizeImageWrapper: self.resize_size should be set somehow, but it isn't. Is this just a typo?
I'm happy to make a PR for the things listed here that seem small and easy.
Hi,
In finetune_config.py I selected only text_conditioning, but the finetuning code fails with the following error:
File "/home/users/m/octo/octo/data/dataset.py", line 98, in apply_trajectory_transforms
getattr(goal_relabeling, goal_relabeling_strategy),
AttributeError: module 'octo.data.utils.goal_relabeling' has no attribute 'no_image_conditioning'
The goal_relabeling module indeed has no function named "no_image_conditioning". Do you plan to release it?
Thanks in advance.
In the paper, on the last page (19), it is mentioned that in the coffee demo "a single third-person wrist" were used. What do you mean by this? Doesn't this contradict what you said in the paper about wrist images giving worse performance due to lack of data?
Also, is the pod picking and placement done as one motion/task or two?
Thank you!
The zero-shot capability of the model is mentioned several times in the paper and compared against RT-1-X and RT-2-X.
How is zero-shot capability defined in the evaluation experiments?
(Are the manipulated objects different, while the robot arm used in the evaluation stage already appears in the training data?)
In the comparison, is RT-1-X evaluated using the open-source model, while the RT-2-X numbers are taken from the cited paper?
Thank you!
Hello,
I have installed octo as instructed. When I tried to test the data loading, I got the error:
ModuleNotFoundError: No module named 'dlimp'
Inside requirements.txt there is this line:
dlimp @ git+https://github.com/kvablack/dlimp@d08da3852c149548aaa8551186d619d87375df08
When I follow that link, it leads to a nonexistent repo.
Hi,
Using scripts/finetuning.py, I have been trying to train models on different manually collected datasets for pick/pick-and-place tasks, but I ran into a problem: the validation MSE loss keeps rising. To simplify the experiments (mentioned in #29), I collected a rather simple dataset of picking one object from the same location with no object rotation, so most trajectories look more or less the same. Still, the validation MSE loss keeps rising for no apparent reason. The train/validation split is 50 and 7 trajectories, respectively.
Do you have any idea what the problem could be here?
Maybe adding the diffusion policy MSE (used in the training loss) to validation would give a clearer picture?
Thank you for your great work! I ran into a dependency conflict when installing the project environment: scipy must be older than the latest release, 1.13.0 (2024.4.2). Otherwise, the error "AttributeError: module 'scipy.linalg' has no attribute 'tril'" occurs. This can be solved by downgrading scipy to 1.12.0. I hope you can pin the version in requirements.txt, e.g. scipy>=1.12.0,<1.13.0. Thank you very much!
I want to try a new data augmentation strategy that is not covered by https://github.com/kvablack/dlimp/blob/main/dlimp/augmentations.py
I'm looking at this part of the example: octo/examples/02_finetune_new_observation_action.py, lines 68 to 86 at commit 8fe7497.
Is there a suggested way of implementing this? For example, if I have a custom function fn(image) that returns the augmented image, is there a way to integrate this into training easily?
Edit:
Specifically, I want to implement something like mixup, so I'll have to sample from an image dataset on every call. I'm not very familiar with TensorFlow, so maybe I'll try switching to the PyTorch dataloader...
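The per-image mixup operation itself is small: draw a weight from Beta(alpha, alpha) and blend the frame with a second image. The NumPy sketch below shows only that operation; `mixup_images` is a hypothetical helper, and mixing only the image while leaving the actions untouched is our assumption about how you intend to use it. Wiring it into Octo's TensorFlow pipeline could, for example, mean zipping the trajectory dataset with a shuffled image dataset, or calling the function through tf.numpy_function; both are standard tf.data tools, but how they fit Octo's frame transforms is something we have not verified.

```python
import numpy as np


def mixup_images(image, other, alpha=0.2, rng=None):
    """Blend `image` with `other` using a Beta(alpha, alpha) mixing weight.

    Both inputs are uint8 arrays of the same shape [H, W, C].
    Returns the uint8 blended image and the weight lam applied to `image`.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = float(rng.beta(alpha, alpha))
    mixed = lam * image.astype(np.float32) + (1.0 - lam) * other.astype(np.float32)
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8), lam


rng = np.random.default_rng(0)
image = np.zeros((4, 4, 3), np.uint8)
distractor = np.full((4, 4, 3), 200, np.uint8)
mixed, lam = mixup_images(image, distractor, alpha=0.2, rng=rng)
```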
Hi,
In your paper, it was mentioned that you finetuned on an NVIDIA A5000 with 24 GB of VRAM. What batch size did you use in that case? Thanks
I'm running the 04_eval_finetuned_on_robot.py script provided in examples, straight out of the box. I tried the language conditioning option and entered a new instruction, "move the green object to the bottom right burner", when prompted.
I'm running into this error during the forward pass through the model. Is this a common problem, or are there any hints on fixing it?
File "/home/arhan/projects/widowx_octo_inference/04_eval_finetuned_on_robot.py", line 238, in <module>
app.run(main)
File "/home/arhan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/arhan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/arhan/projects/widowx_octo_inference/04_eval_finetuned_on_robot.py", line 212, in main
action = np.array(policy_fn(obs, task), dtype=np.float64)
File "/home/arhan/projects/widowx_octo_inference/04_eval_finetuned_on_robot.py", line 127, in sample_actions
actions = pretrained_model.sample_actions(
File "/home/arhan/projects/octo/octo/model/octo_model.py", line 187, in sample_actions
transformer_outputs = self.run_transformer(
File "/home/arhan/projects/octo/octo/model/octo_model.py", line 152, in run_transformer
return self.module.apply(
File "/home/arhan/projects/octo/octo/model/octo_module.py", line 274, in __call__
outputs["obs"] = TokenGroup.concatenate(
File "/home/arhan/projects/octo/octo/model/components/base.py", line 31, in concatenate
data = jnp.concatenate([t.tokens for t in group_list], axis=axis)
File "/home/arhan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1839, in concatenate
raise ValueError("Need at least one array to concatenate.")
ValueError: Need at least one array to concatenate.
When I ran the finetuning script, I noticed warnings related to CUDA, cuDNN, TensorRT, etc.
I followed the environment setup in the README. I suspect these might be due to an incompatibility between JAX and my environment.
(octo) wenbo@wenbo-4090:~/Documents/data/octo/scripts$ python finetune.py
/media/wenbo/12T/octo/scripts/finetune.py:3: DeprecationWarning: the imp module is deprecated in favour of importlib and slated for removal in Python 3.12; see the module's documentation for alternative uses
import imp
2024-03-27 16:43:55.322631: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-27 16:43:55.322684: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-27 16:43:55.453763: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-27 16:43:57.157329: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
W0327 16:44:02.909117 128929550399296 compilation_cache.py:59] Initialized persistent compilation cache at /home/wenbo/.jax_compilation_cache
I0327 16:44:03.371840 128929550399296 xla_bridge.py:633] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0327 16:44:03.382258 128929550399296 xla_bridge.py:633] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
I0327 16:44:03.382936 128929550399296 finetune_rlbench.py:66]
I am trying to pretrain on a dataset, and my intended use case is to have three images tokenized as inputs to the transformer and an action head with 2 outputs. When I run the script, I do see that this is indeed the case:
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ t=0 obs_instruction (16 tokens) ┃ t=0 obs_primary (16 tokens) ┃ t=0 obs_secondary (16 tokens) ┃ t=0 readout_action (1 tokens) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ t=0 obs_instruction (16 tokens) │ x │ x │ x │ x │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 obs_primary (16 tokens) │ x │ x │ x │ x │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 obs_secondary (16 tokens) │ x │ x │ x │ x │
├──────────────────────────────────┼──────────────────────────────────┼──────────────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ t=0 readout_action (1 tokens) │ │ │ │ x │
└──────────────────────────────────┴──────────────────────────────────┴──────────────────────────────┴────────────────────────────────┴────────────────────────────────┘
But I get these INFO messages:
I1229 11:29:38.141532 140098393864000 tokenizers.py:123] No task inputs matching image_instruction were found. Replacing with zero padding.
I1229 11:29:38.191684 140098393864000 tokenizers.py:123] No task inputs matching image_primary were found. Replacing with zero padding.
I1229 11:29:38.241426 140098393864000 tokenizers.py:123] No task inputs matching image_secondary were found. Replacing with zero padding.
My observation tokenizer config is the following:
config["model"]["observation_tokenizers"] = {
"primary": ModuleSpec.create(
ImageTokenizer,
obs_stack_keys=["image_primary"],
task_stack_keys=["image_primary"],
encoder=ModuleSpec.create(SmallStem16),
),
"secondary": ModuleSpec.create(
ImageTokenizer,
obs_stack_keys=["image_secondary"],
task_stack_keys=["image_secondary"],
encoder=ModuleSpec.create(SmallStem16),
),
"instruction": ModuleSpec.create(
ImageTokenizer,
obs_stack_keys=["image_instruction"],
task_stack_keys=["image_instruction"],
encoder=ModuleSpec.create(SmallStem16),
),
}
Going through the image tokenizer code, it seems that obs_stack_keys is used when I want to stack inputs. Then there are task inputs (task_stack_keys), and I'm not sure what they are meant for. Am I doing this the right way?
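Judging from the INFO message, `task_stack_keys` appear to name goal images from the task dict that are stacked onto the observation images channel-wise, with zeros substituted when no matching goal image is provided. The NumPy sketch below illustrates that behaviour under this assumption; it is not Octo's tokenizer code (see tokenizers.py, named in the log, for the real logic), and `stack_obs_and_task` is a hypothetical helper. If this reading is right, the messages mean your tasks contain no goal images, and leaving `task_stack_keys` empty would presumably be the way to condition on observations only.

```python
import numpy as np


def stack_obs_and_task(observations, tasks, obs_stack_keys, task_stack_keys):
    """Channel-stack observation images with goal images taken from `tasks`.

    observations[k]: [B, T, H, W, C]; tasks[k]: [B, H, W, C] (one goal per example).
    A task key with no entry in `tasks` is replaced by zeros shaped like the
    first observation image, i.e. "zero padding".
    """
    obs = np.concatenate([observations[k] for k in obs_stack_keys], axis=-1)
    T = obs.shape[1]
    template = observations[obs_stack_keys[0]]
    parts = [obs]
    for k in task_stack_keys:
        if k in tasks:
            # repeat the single goal image across the history window
            parts.append(np.repeat(tasks[k][:, None], T, axis=1))
        else:
            parts.append(np.zeros_like(template))
    return np.concatenate(parts, axis=-1)


observations = {"image_primary": np.ones((2, 1, 8, 8, 3), np.float32)}
lang_only = stack_obs_and_task(observations, {}, ["image_primary"], ["image_primary"])
with_goal = stack_obs_and_task(
    observations,
    {"image_primary": np.full((2, 8, 8, 3), 0.5, np.float32)},
    ["image_primary"],
    ["image_primary"],
)
```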
To finetune the model, I will have to create a dataset of episodes.
Do you have any resources on how you recorded/created an RLDS dataset? The format seems somewhat niche, with little documentation.
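RLDS stores each episode as a sequence of steps, each holding an observation, an action and a few bookkeeping fields (`is_first`, `is_last`, `is_terminal`, plus `reward` and `discount`). The sketch below only assembles that structure in memory from recorded arrays; writing it to disk still requires a TFDS dataset builder whose feature spec matches it. The observation keys (`image`, `state`), the per-step `language_instruction` field and the sparse final-step reward are illustrative assumptions, and `make_episode` is a hypothetical helper.

```python
import numpy as np


def make_episode(frames, states, actions, instruction):
    """Build one episode as a dict of RLDS-style step dicts.

    frames, states and actions are per-timestep sequences of equal length.
    """
    T = len(actions)
    assert len(frames) == len(states) == T, "need one frame and one state per action"
    steps = []
    for t in range(T):
        steps.append({
            "observation": {"image": frames[t], "state": states[t]},
            "action": actions[t],
            "reward": float(t == T - 1),  # e.g. a sparse success reward on the last step
            "discount": 1.0,
            "is_first": t == 0,
            "is_last": t == T - 1,
            "is_terminal": t == T - 1,    # the episode ended by completing the task
            "language_instruction": instruction,
        })
    return {"steps": steps}


T = 3
episode = make_episode(
    frames=[np.zeros((64, 64, 3), np.uint8) for _ in range(T)],
    states=[np.zeros(7, np.float32) for _ in range(T)],
    actions=[np.zeros(7, np.float32) for _ in range(T)],
    instruction="grab the wooden cube",
)
```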
When running the examples/03_eval_finetuned.py script on a GPU (A100 80G, cuda12.2), I encountered the following error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to allocate 98697216 bytes for new constant
Here are the details:
Traceback (most recent call last):
File "/home/houxuan/code/robot/octo/examples/03_eval_finetuned1.py", line 217, in
app.run(main)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/houxuan/code/robot/octo/examples/03_eval_finetuned1.py", line 193, in main
actions = policy_fn(
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 256, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 167, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/core.py", line 2656, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/core.py", line 388, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/core.py", line 868, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 1212, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 1196, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/pjit.py", line 1152, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/profiler.py", line 340, in wrapper
return func(*args, **kwargs)
File "/home/houxuan/miniconda3/envs/octo/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1152, in call
results = self.xla_executable.execute_sharded(input_bufs)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to allocate 98697216 bytes for new constant
Making the following modification in the code resolves the problem:
# jit model action prediction function for faster inference
# policy_fn = jax.jit(model.sample_actions)
policy_fn = model.sample_actions
Any suggestions for how to make the jax.jit transformation work?
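One possible explanation, which we have not confirmed: `jax.jit(model.sample_actions)` jits a bound method, so the model's parameters are not arguments of the jitted function but values captured from `self`. Closed-over arrays can be embedded in the compiled program as constants, and the error text ("Failed to allocate ... bytes for new constant") is consistent with that. The toy example below contrasts the two patterns with a hypothetical `TinyModel`; both compute the same result, but only the second passes the weights in as runtime inputs. If OctoModel is itself a JAX pytree (we have not verified this), the analogous change would be to jit a function that takes the model as an argument, e.g. `jax.jit(lambda m, obs, task, rng: m.sample_actions(obs, task, rng=rng))`.

```python
import jax
import jax.numpy as jnp


class TinyModel:
    """Hypothetical stand-in for a model whose weights live on the object."""

    def __init__(self, w):
        self.w = w

    def predict(self, x):
        return x @ self.w


model = TinyModel(jnp.ones((512, 512)))
x = jnp.ones((1, 512))

# Jitting a bound method: self.w is captured by closure, not passed as an
# argument, so it can be baked into the compiled program as a constant.
predict_closure = jax.jit(model.predict)


# Passing the weights explicitly makes them runtime inputs (device buffers).
@jax.jit
def predict_explicit(w, x):
    return x @ w


y1 = predict_closure(x)
y2 = predict_explicit(model.w, x)
```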
Are learning curves available for the HF models, and would it be possible to share them to help others reproduce the results of the paper? I completely understand if they are not stored/available; I'm asking just in case they are.
The Google Colab notebook at https://colab.research.google.com/drive/1z0vELj_lX9OWeoMG_WvXnQs43aPOEAhz?usp=sharing cannot run.
The line action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0)) in step 1 seems to be the issue. I get this error:
AssertionError Traceback (most recent call last)
<ipython-input-23-97a79f13d085> in <cell line: 8>()
6 task = model.create_tasks(texts=["pick up the fork"])
7
----> 8 action = model.sample_actions(observation, task, rng=3)#jax.random.PRNGKey(0))
9 print(action) # [batch, action_chunk, action_dim]
[... skipping hidden 12 frame]
[... skipping hidden 12 frame]
/content/octo/octo/model/octo_model.py in _verify_shapes(pytree, name, example_pytree, starting_dim, strict, raise_error, silent)
488
489 if raise_error and (fail or (weak_fail and strict)):
--> 490 raise AssertionError(f"{name} does not match example batch.")
491
492 return weak_fail or fail
AssertionError: observations does not match example batch.
I am running on a T4 GPU. I have not tried an A100, but I am not sure whether that could be the root of the issue.
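"observations does not match example batch" is raised when the observation dict differs from the example batch stored with the model. The hypothetical helper below (not Octo's `_verify_shapes`) lists missing keys and shape mismatches so you can see exactly what differs. The example dict in the usage is made up; in practice you would compare against the model's stored example batch. One possible cause of this kind of mismatch, illustrated in the usage, is a missing leading batch or history-window axis on the image.

```python
import numpy as np


def compare_to_example(batch, example, starting_dim=0, prefix=""):
    """List keys missing from `batch` and leaves whose shapes differ from `example`.

    Both arguments are nested dicts of arrays. Shapes are compared from
    `starting_dim` onward; a differing number of dimensions always counts.
    """
    problems = []
    for key, ex_val in example.items():
        path = f"{prefix}{key}"
        if key not in batch:
            problems.append(f"missing: {path}")
        elif isinstance(ex_val, dict):
            problems += compare_to_example(batch[key], ex_val, starting_dim, path + "/")
        else:
            got, want = np.shape(batch[key]), np.shape(ex_val)
            if len(got) != len(want) or got[starting_dim:] != want[starting_dim:]:
                problems.append(f"shape: {path} {got} vs {want}")
    return problems


example = {
    "image_primary": np.zeros((1, 1, 256, 256, 3), np.uint8),
    "pad_mask": np.ones((1, 1), bool),
}
observation = {"image_primary": np.zeros((256, 256, 3), np.uint8)}
problems = compare_to_example(observation, example)
for p in problems:
    print(p)
```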
I installed octo in my virtualenv using pip, as instructed in the README.
After that, when I run the test code, it gets stuck at the message "Very slow compile?"
I also tried reinstalling jax, but I still don't know the cause.
How can I solve this?
octo/scripts/finetune.py:3: DeprecationWarning: the imp module is deprecated in favour of importlib and slated for removal in Python 3.12; see the module's documentation for alternative uses
import imp
2024-03-18 10:58:02.165197: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-18 10:58:02.165237: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-18 10:58:02.166325: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-18 10:58:02.959024: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-03-18 10:58:05.042673: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
W0318 10:58:05.043408 140289948624704 compilation_cache.py:59] Initialized persistent compilation cache at /home/haha/.jax_compilation_cache
I0318 10:58:05.063319 140289948624704 xla_bridge.py:633] Unable to initialize backend 'cuda': Unable to load CUDA.
I0318 10:58:05.063727 140289948624704 xla_bridge.py:633] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0318 10:58:05.065797 140289948624704 xla_bridge.py:633] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0318 10:58:05.065938 140289948624704 xla_bridge.py:697] CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0318 10:58:05.066064 140289948624704 finetune.py:63]
Octo Finetuning Script
======================
Pretrained model: hf://rail-berkeley/octo-small
Finetuning Dataset: bridge_dataset
Data dir: ./tests/debug_dataset
Task Modality: multimodal
Finetuning Mode: full
# Devices: 1
Batch size: 256 (256 per device)
# Steps: 50000
Fetching 8 files: 100%|█████████████████████████| 8/8 [00:00<00:00, 4329.05it/s]
I0318 10:58:13.746319 140289948624704 checkpointer.py:164] Restoring item from /home/haha/.cache/huggingface/hub/models--rail-berkeley--octo-small/snapshots/03d88976c54a58e10480d2043a8c762b35bc2611/270000/default.
I0318 10:58:14.635694 140289948624704 checkpointer.py:166] Finished restoring checkpoint from /home/haha/.cache/huggingface/hub/models--rail-berkeley--octo-small/snapshots/03d88976c54a58e10480d2043a8c762b35bc2611/270000/default.
I0318 10:58:16.732701 140289948624704 dataset_info.py:578] Load dataset info from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:58:16.845350 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split all, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:58:17.743555 140289948624704 data_utils.py:110] Loading existing dataset statistics from tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_38c366095a9a20c72a862f6a1a4a5ae4fe98ed4507d67cfa7314de76b872c6c4.json.
I0318 10:58:17.802858 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split train, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:08.408026 140289948624704 train_utils.py:248] Freezing parameters that include the following keys: ['*hf_model*'].
I0318 10:59:08.411441 140289948624704 train_utils.py:284] Num trainable params: 27,042,060.
I0318 10:59:08.411522 140289948624704 train_utils.py:285] Num frozen params: 109,628,544.
I0318 10:59:08.411549 140289948624704 train_utils.py:286] To see a detailed list of frozen params, set logging level to DEBUG.
W0318 10:59:09.221674 140289948624704 finetune.py:261] save_dir not passed in, not saving checkpoints
I0318 10:59:09.225611 140289948624704 dataset_info.py:578] Load dataset info from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:09.294973 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split all, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:09.553126 140289948624704 data_utils.py:110] Loading existing dataset statistics from tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_38c366095a9a20c72a862f6a1a4a5ae4fe98ed4507d67cfa7314de76b872c6c4.json.
I0318 10:59:09.608441 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split val, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:10.278485 140289948624704 dataset_info.py:578] Load dataset info from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:10.341788 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split all, from tests/debug_dataset/bridge_dataset/1.0.0
I0318 10:59:10.596746 140289948624704 data_utils.py:110] Loading existing dataset statistics from tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_38c366095a9a20c72a862f6a1a4a5ae4fe98ed4507d67cfa7314de76b872c6c4.json.
I0318 10:59:10.650095 140289948624704 logging_logger.py:49] Constructing tf.data.Dataset bridge_dataset for split val, from tests/debug_dataset/bridge_dataset/1.0.0
0%| | 0/50000 [00:00<?, ?it/s]2024-03-18 11:01:32.364956: E external/xla/xla/service/slow_operation_alarm.cc:65]
********************************
[Compiling module jit_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-03-18 11:03:52.031034: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4m19.665375883s
********************************
[Compiling module jit_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
0%| | 9/50000 [30:23<2447:58:38, 176.29s/it]
Hi, thanks for this great work! The currently recommended configurations are NVIDIA driver > 520 + CUDA 11.8 + cuDNN 8.6, or NVIDIA driver > 525 + CUDA 12.2 + cuDNN 8.9, but unfortunately my PC doesn't support such recent versions. Has anyone tried an older configuration, such as NVIDIA driver 470?
Any related info would be warmly appreciated. JAX is really difficult for me.
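The log above ends with `CUDA_ERROR_SYSTEM_DRIVER_MISMATCH` and "CUDA backend failed to initialize", after which JAX falls back to the CPU, which likely explains the ~176 s per training step. Before experimenting with older drivers, it can help to confirm which backend JAX actually picked up. A minimal sketch using `jax.default_backend()` and `jax.devices()`; the helper name `describe_jax_backend` is hypothetical, and it returns a message instead of failing when JAX is not installed.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Return a short description of the backend JAX will use.

    Hypothetical helper: reports "jax not installed" instead of raising.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax

    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    # Where CUDA failed to initialize (as in the log above), this is
    # expected to report backend=cpu.
    print(describe_jax_backend())
```

If this reports `backend=cpu` on a GPU machine, the driver/CUDA mismatch is the thing to fix first; training speed will not improve until JAX lists a GPU device.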
From inspecting the model, it seems there is actually no "wrist" observation tokenizer in the small model; only the base model has one. Also, I only see the "language" task tokenizer in both the base and small models. I'm checking by looking at pretrained_model.config['model'].
Perhaps the descriptions on HF should be updated?
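To compare the released models programmatically, one can list the tokenizer groups found in the loaded config. A minimal sketch, assuming the config is a nested dict with `observation_tokenizers` and `task_tokenizers` entries under `config["model"]` (an assumption about Octo's config layout; verify against your checkout). The example config is a hand-written stand-in, not the real octo-small config.

```python
from typing import Any, Dict, List


def list_tokenizers(config: Dict[str, Any]) -> Dict[str, List[str]]:
    """Return the observation/task tokenizer names found in a model config.

    Assumes (hypothetically) that tokenizers live under
    config["model"]["observation_tokenizers"] and
    config["model"]["task_tokenizers"].
    """
    model_cfg = config.get("model", {})
    return {
        "observation": sorted(model_cfg.get("observation_tokenizers", {})),
        "task": sorted(model_cfg.get("task_tokenizers", {})),
    }


# Hand-written stand-in mirroring what the comment above reports for the
# small model: a primary image tokenizer and a language task tokenizer only.
example_config = {
    "model": {
        "observation_tokenizers": {"primary": {}},
        "task_tokenizers": {"language": {}},
    }
}
print(list_tokenizers(example_config))
# {'observation': ['primary'], 'task': ['language']}
```

With a loaded model, the same check would be `list_tokenizers(pretrained_model.config)`.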
Could you provide the STL file for 3D-printing the wrist camera holder for the Franka Panda robot?
Hi, I'm trying to train all parameters by modifying the example, but it seems like it's not working for me. What might be the issue?
With the original script, I get the following log:
I1218 11:34:09.575279 140267533906560 train_utils.py:403] ########## Parameters skipped during model loading: ##########
I1218 11:34:09.575732 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.bias
I1218 11:34:09.575860 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.scale
I1218 11:34:09.575954 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.bias
I1218 11:34:09.576039 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.kernel
I1218 11:34:09.576118 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.bias
I1218 11:34:09.576195 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.kernel
I1218 11:34:09.576265 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.bias
I1218 11:34:09.576333 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.kernel
I1218 11:34:09.576400 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.bias
I1218 11:34:09.576465 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.kernel
I1218 11:34:09.576580 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.bias
I1218 11:34:09.576655 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.kernel
I1218 11:34:09.576721 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.bias
I1218 11:34:09.576784 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.kernel
I1218 11:34:09.576846 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.probe
I1218 11:34:09.576910 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.bias
I1218 11:34:09.576974 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.kernel
I1218 11:34:09.577038 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_pos_embedding
I1218 11:34:09.577101 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.bias
I1218 11:34:09.577165 140267533906560 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.kernel
I1218 11:34:09.578479 140267533906560 train_utils.py:248] Freezing parameters that include the following keys: ['*hf_model*'].
I1218 11:34:09.581623 140267533906560 train_utils.py:284] Num trainable params: 27,178,734.
I1218 11:34:09.581736 140267533906560 train_utils.py:285] Num frozen params: 109,628,544.
I1218 11:34:09.581822 140267533906560 train_utils.py:286] To see a detailed list of frozen params, set logging level to DEBUG.
I1218 11:34:10.776679 140267533906560 02_finetune_new_observation_action.py:186] Starting finetuning...
Then, as a crude attempt at full fine-tuning, I set frozen_keys to an empty list and got the following:
I1218 18:19:57.928706 139778262782592 train_utils.py:403] ########## Parameters skipped during model loading: ##########
I1218 18:19:57.929069 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.bias
I1218 18:19:57.929162 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.LayerNorm_0.scale
I1218 18:19:57.929238 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.bias
I1218 18:19:57.929309 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_0.kernel
I1218 18:19:57.929384 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.bias
I1218 18:19:57.929462 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MlpBlock_0.Dense_1.kernel
I1218 18:19:57.929538 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.bias
I1218 18:19:57.929613 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.key.kernel
I1218 18:19:57.929690 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.bias
I1218 18:19:57.929764 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.out.kernel
I1218 18:19:57.929843 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.bias
I1218 18:19:57.929908 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.query.kernel
I1218 18:19:57.929970 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.bias
I1218 18:19:57.930034 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.MultiHeadDotProductAttention_0.value.kernel
I1218 18:19:57.930099 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.map_head.probe
I1218 18:19:57.930162 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.bias
I1218 18:19:57.930227 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: heads_action.mean_proj.kernel
I1218 18:19:57.930290 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_pos_embedding
I1218 18:19:57.930354 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.bias
I1218 18:19:57.930418 139778262782592 train_utils.py:405] Param missing in pre-trained model, skipping: octo_transformer.obs_proprio_projection.kernel
I1218 18:19:57.931729 139778262782592 train_utils.py:248] Freezing parameters that include the following keys: [].
I1218 18:19:57.934117 139778262782592 train_utils.py:284] Num trainable params: 136,807,278.
I1218 18:19:57.934224 139778262782592 train_utils.py:285] Num frozen params: 0.
I1218 18:19:57.934307 139778262782592 train_utils.py:286] To see a detailed list of frozen params, set logging level to DEBUG.
I1218 18:19:59.401661 139778262782592 02_finetune_new_observation_action.py:186] Starting finetuning...
The log reports 0 frozen params, so this seemed right. However, VRAM usage did not change (honestly, I expected a GPU OOM), and the loss curve stayed mostly the same as with the default script:
(I'm using a custom dataset)
Am I missing or misunderstanding something?
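One way to sanity-check the two logs is to partition a parameter tree with the freeze patterns and count each side. The logs are consistent with each other: 27,178,734 trainable + 109,628,544 frozen = 136,807,278, exactly the total reported when nothing is frozen. The sketch below assumes frozen keys are matched as `fnmatch` glob patterns against dot-joined parameter paths (suggested by the `'*hf_model*'` pattern in the log, but an assumption about Octo's implementation); the toy parameter shapes are made up, and only the key names echo the logs.

```python
import fnmatch
import math
from typing import Dict, Iterator, List, Tuple, Union

# Nested dict of parameter name -> shape (a stand-in for real arrays).
ParamTree = Dict[str, Union["ParamTree", Tuple[int, ...]]]


def iter_params(tree: ParamTree, prefix: str = "") -> Iterator[Tuple[str, Tuple[int, ...]]]:
    """Yield (dot-joined path, shape) for every leaf in the tree."""
    for name, value in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            yield from iter_params(value, path)
        else:
            yield path, value


def count_partition(tree: ParamTree, frozen_keys: List[str]) -> Tuple[int, int]:
    """Return (num_trainable, num_frozen) under glob-style frozen_keys."""
    trainable = frozen = 0
    for path, shape in iter_params(tree):
        size = math.prod(shape)
        if any(fnmatch.fnmatch(path, key) for key in frozen_keys):
            frozen += size
        else:
            trainable += size
    return trainable, frozen


toy_params = {
    "octo_transformer": {
        "hf_model": {"encoder": {"kernel": (512, 512)}},
        "obs_proprio_projection": {"kernel": (6, 384), "bias": (384,)},
    },
    "heads_action": {"mean_proj": {"kernel": (384, 11), "bias": (11,)}},
}

print(count_partition(toy_params, ["*hf_model*"]))  # (6923, 262144)
print(count_partition(toy_params, []))              # (269067, 0)
```

The counts only say which parameters receive updates; they do not by themselves determine how much memory a training step uses.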
It would be great to have a Discord server or a meeting group, here or elsewhere, to discuss developments.
Supporting installation via python -m pip install git+https://github.com/octo-models/octo would be useful, for example in Colab, but it currently produces an error:
https://colab.research.google.com/drive/1Arhc2gvfMmWUEDCYqeopSKr0zzzfx8Q-#scrollTo=2ZDF2sSQ8hmB
collecting git+https://github.com/octo-models/octo
Cloning https://github.com/octo-models/octo to /tmp/pip-req-build-6fhdexfl
Running command git clone --filter=blob:none --quiet https://github.com/octo-models/octo /tmp/pip-req-build-6fhdexfl
Resolved https://github.com/octo-models/octo to commit 653c54acde686fde619855f2eac0dd6edad7116b
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: octo
Building wheel for octo (pyproject.toml) ... done
Created wheel for octo: filename=octo-0.0.0-py3-none-any.whl size=1898 sha256=3e00c63e651d2ed51d54253c47ca02ed3ce75281d633d63e6bedadd356ad5be8
Stored in directory: /tmp/pip-ephem-wheel-cache-fx6kt1v_/wheels/ca/2e/83/3e87bf03d7e424b0e89f061c2ed34eb2a6cee22ed0cfa40b52
Successfully built octo
Installing collected packages: octo
Successfully installed octo-0.0.0
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
[<ipython-input-1-230dc17a8c11>](https://localhost:8080/#) in <cell line: 9>()
7 """
8 import numpy as np
----> 9 from octo.data.dataset import make_interleaved_dataset
10 from octo.data.oxe import make_oxe_dataset_kwargs_and_weights
11 import tensorflow as tf
ModuleNotFoundError: No module named 'octo.data'
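The pip log above shows a wheel of only 1,898 bytes, which suggests the installed `octo` distribution contains little more than metadata and that subpackages such as `octo.data` were not included. A hedged way to check which modules are actually importable after installation is `importlib.util.find_spec`; the module list is taken from the failing imports in the traceback, and `check_importable` is a hypothetical helper.

```python
import importlib.util
from typing import Dict, Iterable


def check_importable(module_names: Iterable[str]) -> Dict[str, bool]:
    """Map each dotted module name to whether Python can locate it.

    find_spec imports parent packages for dotted names, so a missing
    parent raises instead of returning None; treat that as "not importable".
    """
    results = {}
    for name in module_names:
        try:
            results[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            results[name] = False
    return results


# Modules imported by the failing cell above.
print(check_importable(["octo", "octo.data.dataset", "octo.data.oxe"]))
```

If `octo` is found but `octo.data.dataset` is not, the package itself installed without its subpackages, rather than a dependency being missing.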
Whether I set this flag to False or True, I get the same number of parameters:
flags.DEFINE_bool(
    "freeze_transformer",
    True,
    "Whether pre-trained transformer weights should be frozen.",
)
When it is set to False, the log shows:
Freezing parameters that include the following keys: ['*hf_model*'].
Num trainable params: 25,639,496.
Num frozen params: 109,628,544.
If I set it to True:
Freezing parameters that include the following keys: ['*hf_model*', 'head_mlp_only'].
Num trainable params: 25,639,496.
Num frozen params: 109,628,544.
Both report the same number of parameters!
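If the freeze keys are matched as glob patterns against full parameter paths (an assumption suggested by the wildcards in `'*hf_model*'`, not confirmed here), a key without wildcards such as `'head_mlp_only'` only matches a parameter whose entire path is exactly that string. The sketch checks which of a few paths each pattern would select; the paths follow the dot-joined style of the logs, and the `hf_model` path is a made-up example.

```python
import fnmatch
from typing import List, Sequence

# Parameter paths in the dot-joined style of the logs; the hf_model path
# is hypothetical.
PARAM_PATHS = (
    "octo_transformer.hf_model.encoder.block_0.kernel",
    "octo_transformer.obs_proprio_projection.kernel",
    "heads_action.mean_proj.kernel",
)


def matched_paths(pattern: str, paths: Sequence[str] = PARAM_PATHS) -> List[str]:
    """Return the paths that a glob-style freeze key would match."""
    return [p for p in paths if fnmatch.fnmatch(p, pattern)]


print(matched_paths("*hf_model*"))     # only the hf_model path
print(matched_paths("head_mlp_only"))  # [] (no wildcards, no exact match)
```

Under this assumption, adding `'head_mlp_only'` would freeze nothing extra, which would be consistent with the identical counts in both logs.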
Hi,
Great work! And thank you very much for sharing the code and pretrained weights!
After browsing the repo and project page, I am a little confused about how to apply Octo to a new robot.
I would like to run it on a humanoid robot (Reachy, https://www.pollen-robotics.com/), whose geometry is quite different from the robot arms you showcase.
I guess I would need to collect data for fine-tuning, provide a description of the robot to the system somehow, and write some kind of interface? But it is not clear to me what the expected formats are for the data, the robot description (URDF?), etc.
Could you give an overview of the required steps?
Thank you very much!
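For orientation, RLDS stores each episode as a sequence of steps, each with an `observation` dict, an `action`, and boolean flags such as `is_first`, `is_last`, and `is_terminal`. The sketch builds one such episode as plain Python dicts, with keys chosen to match the `make_single_dataset` arguments used elsewhere in this thread (`image_obs_keys={"primary": "head"}`, `state_obs_keys=["state"]`, `language_key="language_instruction"`). It only illustrates the layout: a real dataset would be written with a TFDS dataset builder, other standard RLDS fields such as `reward` and `discount` are omitted, and all sizes are placeholders.

```python
from typing import Any, Dict, List, Tuple


def make_episode(num_steps: int, action_dim: int, state_dim: int,
                 instruction: str,
                 image_hw: Tuple[int, int] = (64, 64)) -> Dict[str, List[Dict[str, Any]]]:
    """Build one RLDS-style episode as nested Python dicts (illustration only).

    Images are nested lists of zeros standing in for uint8 HxWx3 frames;
    every size here is a placeholder.
    """
    h, w = image_hw
    steps: List[Dict[str, Any]] = []
    for t in range(num_steps):
        steps.append({
            "observation": {
                # keys match image_obs_keys={"primary": "head"} and state_obs_keys=["state"]
                "head": [[[0, 0, 0] for _ in range(w)] for _ in range(h)],
                "state": [0.0] * state_dim,
            },
            "action": [0.0] * action_dim,
            "language_instruction": instruction,  # matches language_key
            "is_first": t == 0,
            "is_last": t == num_steps - 1,
            # assumes each demo ends by completing the task, so the last step is terminal
            "is_terminal": t == num_steps - 1,
        })
    return {"steps": steps}


episode = make_episode(num_steps=3, action_dim=19, state_dim=19,
                       instruction="pick up the cup")
print([(s["is_first"], s["is_last"]) for s in episode["steps"]])
# [(True, False), (False, False), (False, True)]
```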
Excited by the work, great paper and open release.
I am interested in testing some ideas that will involve pretraining (e.g. architecture changes, etc.), likely without access to a real-world setup, at least at first. Just starting to look at the codebase.
I'm curious about recommendations for sim/offline evaluation:
1) For evaluation, are there any recommendations or best practices for splitting datasets into train/validation/test, or for holding out RT-X datasets? What seem to be the most useful proxies for real-world performance?
2) I saw that examples are provided for sim fine-tuning; are there any results you could share for simulated environments? Are there any sim environments that "work" for zero-shot evaluation in addition to fine-tuning?
Thanks!
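As one possible scheme (a suggestion, not the authors' protocol), held-out evaluation can combine two levels: leave entire datasets out to probe transfer, and split the remaining datasets by episode so that frames from one trajectory never land in both train and validation. A minimal sketch; the dataset names and the 5% validation fraction are placeholders, and `episode_split` / `plan_splits` are hypothetical helpers.

```python
import hashlib
from typing import Dict, Iterable, List, Set


def episode_split(dataset: str, episode_id: int, val_fraction: float = 0.05) -> str:
    """Deterministically assign an episode to "train" or "val".

    Uses SHA-256 so the assignment is stable across runs
    (Python's built-in hash() is randomized per process).
    """
    digest = hashlib.sha256(f"{dataset}/{episode_id}".encode()).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64  # uniform in [0, 1)
    return "val" if bucket < val_fraction else "train"


def plan_splits(episodes: Dict[str, Iterable[int]], held_out: Set[str],
                val_fraction: float = 0.05) -> Dict[str, List[str]]:
    """Return {"train", "val", "test"} lists of "dataset/episode" ids.

    Datasets in held_out go entirely to "test"; the rest are split per episode.
    """
    plan: Dict[str, List[str]] = {"train": [], "val": [], "test": []}
    for name, ids in episodes.items():
        for ep in ids:
            split = "test" if name in held_out else episode_split(name, ep, val_fraction)
            plan[split].append(f"{name}/{ep}")
    return plan


plan = plan_splits({"bridge_dataset": range(100), "berkeley_cable_routing": range(10)},
                   held_out={"berkeley_cable_routing"})
print({k: len(v) for k, v in plan.items()})
```

Splitting by episode rather than by frame matters because neighbouring frames of one trajectory are nearly identical, so a frame-level split would leak training data into validation.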
I think it should be (x - mean) * std?
octo/octo/utils/gym_wrappers.py
Lines 283 to 285 in 37951e4
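For reference when checking those lines: standard z-score normalization is z = (x - mean) / std, and its inverse is x = z * std + mean. A quick round-trip test in plain Python shows whether a normalize/unnormalize pair really inverts; any proposed expression can be substituted into the same check. The statistics are made-up values and the helper names are hypothetical.

```python
from typing import List


def normalize(x: List[float], mean: List[float], std: List[float]) -> List[float]:
    """z-score normalization: (x - mean) / std, element-wise."""
    return [(xi - m) / s for xi, m, s in zip(x, mean, std)]


def unnormalize(z: List[float], mean: List[float], std: List[float]) -> List[float]:
    """Inverse of normalize: z * std + mean, element-wise."""
    return [zi * s + m for zi, m, s in zip(z, mean, std)]


mean, std = [0.5, -1.0], [2.0, 0.25]  # made-up statistics
x = [1.5, -0.5]
z = normalize(x, mean, std)
print(z)                           # [0.5, 2.0]
print(unnormalize(z, mean, std))   # [1.5, -0.5]
```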
Hi! When I use your example in https://github.com/octo-models/octo/blob/main/examples/06_pytorch_oxe_dataloader.py, loading data with multiple GPUs in parallel is extremely slow, taking hundreds of times longer than with a single GPU. How can I solve this problem?
The following code reproduces the issue.
import os
import random
import tqdm
import torch
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
import time
from octo.data.dataset import make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights

DATA_PATH = "your_path_to_oxe"


def setup(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


class TorchRLDSDataset(torch.utils.data.IterableDataset):
    """Thin wrapper around RLDS dataset for use with PyTorch dataloaders."""

    def __init__(
        self,
        rlds_dataset,
        train=True,
    ):
        self._rlds_dataset = rlds_dataset
        self._is_train = train

    def __iter__(self):
        for sample in self._rlds_dataset.as_numpy_iterator():
            yield sample

    def __len__(self):
        lengths = np.array(
            [
                stats["num_transitions"]
                for stats in self._rlds_dataset.dataset_statistics
            ]
        )
        if hasattr(self._rlds_dataset, "sample_weights"):
            # out-of-place multiply: in-place *= fails when integer counts meet float weights
            lengths = lengths * np.array(self._rlds_dataset.sample_weights)
        total_len = lengths.sum()
        if self._is_train:
            return int(0.95 * total_len)
        else:
            return int(0.05 * total_len)


def experiment(rank, devices, port):
    device = devices[rank]
    device = f"cuda:{device}"
    setup(rank, world_size=len(devices), port=port)
    dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
        "oxe_magic_soup",
        DATA_PATH,
        load_camera_views=("primary", "wrist"),
    )
    dataset = make_interleaved_dataset(
        dataset_kwargs_list,
        sample_weights,
        train=True,
        shuffle_buffer_size=1000,  # change to 500k for training, large shuffle buffers are important, but adjust to your RAM
        batch_size=None,  # batching will be handled in PyTorch Dataloader object
        balance_weights=True,
        traj_transform_kwargs=dict(
            goal_relabeling_strategy="uniform",
            window_size=2,
            future_action_window_size=3,
            subsample_length=100,
        ),
        frame_transform_kwargs=dict(
            image_augment_kwargs={
                "primary": dict(
                    random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
                    random_brightness=[0.1],
                    random_contrast=[0.9, 1.1],
                    random_saturation=[0.9, 1.1],
                    random_hue=[0.05],
                    augment_order=[
                        "random_resized_crop",
                        "random_brightness",
                        "random_contrast",
                        "random_saturation",
                        "random_hue",
                    ],
                ),
                "wrist": dict(
                    random_brightness=[0.1],
                    random_contrast=[0.9, 1.1],
                    random_saturation=[0.9, 1.1],
                    random_hue=[0.05],
                    augment_order=[
                        "random_brightness",
                        "random_contrast",
                        "random_saturation",
                        "random_hue",
                    ],
                ),
            },
            resize_size=dict(
                primary=(256, 256),
                wrist=(128, 128),
            ),
            num_parallel_calls=200,
        ),
        traj_transform_threads=48,
        traj_read_threads=48,
    )
    pytorch_dataset = TorchRLDSDataset(dataset)
    dataloader = DataLoader(
        pytorch_dataset,
        batch_size=16,
        num_workers=0,  # important to keep this to 0 so PyTorch does not mess with the parallelism
    )
    for i, sample in tqdm.tqdm(enumerate(dataloader)):
        if i == 5000:
            break


if __name__ == "__main__":
    devices = [0, 1, 2, 3]
    port = (random.randint(0, 3000) % 3000) + 27000
    mp.spawn(experiment, args=(devices, port), nprocs=len(devices), join=True)
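One hypothesis worth ruling out (not confirmed in this thread): with `mp.spawn`, each of the 4 ranks builds its own `tf.data` pipeline with `traj_transform_threads=48`, `traj_read_threads=48`, and `num_parallel_calls=200`, so the machine may be running several times more loader threads than it has cores. The sketch below, a hypothetical helper, divides a fixed CPU budget across ranks; the ratio of parallel calls to cores is an arbitrary assumption.

```python
import os
from typing import Dict, Optional


def per_rank_thread_budget(world_size: int, cpu_count: Optional[int] = None,
                           parallel_calls_per_core: int = 4) -> Dict[str, int]:
    """Split the available cores evenly across ranks (hypothetical helper).

    Returns per-rank values for the threading arguments used in the script
    above; parallel_calls_per_core is an arbitrary assumption.
    """
    cores = cpu_count if cpu_count is not None else (os.cpu_count() or 1)
    per_rank = max(1, cores // world_size)
    return {
        "traj_transform_threads": per_rank,
        "traj_read_threads": per_rank,
        "num_parallel_calls": per_rank * parallel_calls_per_core,
    }


print(per_rank_thread_budget(world_size=4, cpu_count=48))
# {'traj_transform_threads': 12, 'traj_read_threads': 12, 'num_parallel_calls': 48}
```

Inside `experiment()`, the first two values would replace the hard-coded `traj_transform_threads` / `traj_read_threads` arguments and the third would go into `frame_transform_kwargs["num_parallel_calls"]`. If per-step time drops toward the single-GPU number, oversubscription was at least part of the problem.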
Hello,
Setup: new environment
Robot: Franka Panda
Data de-normalization: I am using the Berkeley cable routing dataset, since it also uses a Franka robot (would you recommend a different one?)
I am trying to run your base model on a Franka Panda robot. I have not fine-tuned it yet; I wanted to test its initial behavior first.
So far the robot produces random actions, and sometimes fixed noisy actions (always moving down, regardless of the instruction/goal image or even the robot's initial position). It fails at simple tasks such as moving to an object or picking it up.
Is this the expected behavior, or should it at least move toward the object?
Thank you.
I'm trying to fine-tune Octo on custom data by modifying one of the example scripts:
dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="custom",
        data_dir=FLAGS.data_dir,
        image_obs_keys={"primary": "head", "wrist": "hand"},
        state_obs_keys=["state"],
        language_key="language_instruction",
        action_proprio_normalization_type=NormalizationType.NORMAL,
        absolute_action_mask=[True] * 11,
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        future_action_window_size=49,  # so we get 50 actions for our action chunk
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256), "wrist": (128, 128)},
    ),
    train=True,
)
The original data has head and hand images (both 224x224), 6-D proprioception, and an 11-D action space. However, the output seems to indicate that the wrist image is being ignored: the token table below has no obs_wrist group. What might I be doing wrong?
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ task_language (16 tokens) ┃ t=0 obs_primary (256 tokens) ┃ t=0 obs_proprio (6 tokens) ┃ t=0 readout_action (1 tokens) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ task_language (16 tokens) │ x │ x │ x │ x │
├────────────────────────────────┼───────────────────────────┼───────────────────────────────┼─────────────────────────────┼────────────────────────────────┤
│ t=0 obs_primary (256 tokens) │ │ x │ x │ x │
├────────────────────────────────┼───────────────────────────┼───────────────────────────────┼─────────────────────────────┼────────────────────────────────┤
│ t=0 obs_proprio (6 tokens) │ │ x │ x │ x │
├────────────────────────────────┼───────────────────────────┼───────────────────────────────┼─────────────────────────────┼────────────────────────────────┤
│ t=0 readout_action (1 tokens) │ │ │ │ x │
└────────────────────────────────┴───────────────────────────┴───────────────────────────────┴─────────────────────────────┴────────────────────────────────┘
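A quick way to read this table is to compute how many tokens each image should contribute. The 256 tokens for `obs_primary` at 256x256 are consistent with a 16x16 patch grid, i.e. (256/16)^2 = 256. Assuming the same patch size would apply to the wrist camera, a 128x128 wrist image would add a `t=0 obs_wrist (64 tokens)` group, which is absent here; that absence suggests no wrist tokenizer is active in the model being fine-tuned. The helper below is hypothetical.

```python
from typing import Dict, Tuple


def image_token_counts(resize_size: Dict[str, Tuple[int, int]],
                       patch_size: int = 16) -> Dict[str, int]:
    """Tokens per image = (H / patch) * (W / patch) for a patch-based tokenizer.

    patch_size=16 is inferred from 256x256 -> 256 tokens, an assumption.
    """
    counts = {}
    for name, (h, w) in resize_size.items():
        if h % patch_size or w % patch_size:
            raise ValueError(f"{name}: {h}x{w} not divisible by {patch_size}")
        counts[name] = (h // patch_size) * (w // patch_size)
    return counts


print(image_token_counts({"primary": (256, 256), "wrist": (128, 128)}))
# {'primary': 256, 'wrist': 64}
```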
After installing the dependencies as described in the 03 example script, I still get this error:
File "/dataSSD/3zhou/miniconda/envs/octo/lib/python3.10/site-packages/gym/envs/registration.py", line 219, in _check_version_exists
_check_name_exists(ns, name)
File "/dataSSD/3zhou/miniconda/envs/octo/lib/python3.10/site-packages/gym/envs/registration.py", line 197, in _check_name_exists
raise error.NameNotFound(
gym.error.NameNotFound: Environment aloha-sim-cube doesn't exist.
My sys.path.append call also points to the local act code.
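`gym.make` raises `NameNotFound` when no registered environment has the requested id. Environments are registered as a side effect of importing the module that calls `gym.register`, so the check below first confirms that module is importable, then imports it and looks up the id in gym's registry. This is a hedged sketch: `register_module` is a placeholder for whichever module registers `aloha-sim-cube` in your setup, and the registry lookup assumes either a dict registry (newer gym, such as the 0.26-style `registration.py` in the traceback) or an older registry exposing `.env_specs`.

```python
import importlib
import importlib.util


def diagnose_env(env_id: str, register_module: str) -> str:
    """Explain a likely cause of gym's NameNotFound for env_id.

    register_module is a placeholder for whichever module calls
    gym.register(id=env_id, ...) when it is imported.
    """
    try:
        found = importlib.util.find_spec(register_module) is not None
    except (ImportError, ValueError):
        found = False
    if not found:
        return f"module {register_module!r} is not importable (check sys.path)"
    try:
        importlib.import_module(register_module)  # side effect: gym.register(...)
    except Exception as exc:  # surface import-time failures instead of hiding them
        return f"importing {register_module!r} failed: {exc!r}"
    try:
        from gym.envs.registration import registry
    except ImportError:
        return "gym is not installed"
    # Newer gym versions keep a dict of specs; older ones expose .env_specs.
    specs = getattr(registry, "env_specs", registry)
    if env_id in specs:
        return f"{env_id!r} is registered"
    return f"{env_id!r} is still unregistered after importing {register_module!r}"


print(diagnose_env("aloha-sim-cube", "my_aloha_envs"))  # hypothetical module name
```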
Hi, thanks for your great work. I noticed that you provide the fine-tuning code; would you consider also releasing the checkpoint from the fine-tuning example?
Hi, congratulations on this work!
I am trying to evaluate/train Octo on my in-house data, which is basically a gym environment that can render RGB images. I have no background in the RLDS format or the OXE datasets. Do you have any resources you could share to help me get started quickly? Thanks!
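As a starting point, a gym rollout can be recorded directly in an RLDS-like per-step layout: an `observation`, the `action` taken, the `reward`, and the `is_first` / `is_last` / `is_terminal` flags. The sketch assumes the gym >= 0.26 API (`reset()` returning `(obs, info)` and `step()` returning a 5-tuple) and takes the policy as a plain callable. It stores an action on every step and marks the final recorded step with `is_last`; check the exact end-of-episode convention against the RLDS documentation before building a real dataset. `record_episode` and `ToyEnv` are hypothetical names, and writing the result to disk would still need a TFDS dataset builder.

```python
from typing import Any, Callable, Dict, List


def record_episode(env: Any, policy: Callable[[Any], Any], instruction: str,
                   max_steps: int = 100) -> Dict[str, List[Dict[str, Any]]]:
    """Roll out policy in env and return {"steps": [...]} in an RLDS-like layout.

    Assumes the gym>=0.26 API: reset() -> (obs, info) and
    step(a) -> (obs, reward, terminated, truncated, info).
    """
    obs, _ = env.reset()
    steps: List[Dict[str, Any]] = []
    for t in range(max_steps):
        action = policy(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated or t == max_steps - 1
        steps.append({
            "observation": obs,
            "action": action,
            "reward": reward,
            "language_instruction": instruction,
            "is_first": t == 0,
            "is_last": done,
            "is_terminal": bool(terminated),
        })
        if done:
            break
        obs = next_obs
    return {"steps": steps}


class ToyEnv:
    """Hypothetical 1-D env: the state counts up and terminates at 3."""

    def reset(self):
        self.state = 0
        return {"state": self.state}, {}

    def step(self, action):
        self.state += action
        terminated = self.state >= 3
        return {"state": self.state}, float(terminated), terminated, False, {}


episode = record_episode(ToyEnv(), policy=lambda obs: 1, instruction="count to 3")
print([s["observation"]["state"] for s in episode["steps"]])  # [0, 1, 2]
```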
Octo's action space consists of end-effector velocities, representing changes in ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']. I want to assess the model's zero-shot capability in a simulator. Although I understand there is a significant domain gap, my goal is to verify that the pipeline runs without errors. I'm using RLBench, whose action space is defined by joint angles, unlike Octo's.
My questions:
Regarding zero-shot: should I compute the corresponding joint angles from Octo's outputs using inverse kinematics, so that the model's actions match the environment's action space?
Regarding few-shot: should the action space in the demonstrations be end-effector velocities rather than joint angles? I understand that joint angles are directly observable, whereas end-effector velocities require a conversion.
Thanks for your great work!
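For the zero-shot route, the first step is turning each predicted per-step change into a target end-effector pose, which an IK solver can then map to joint angles. A minimal sketch, assuming the 7-D layout listed above (`x, y, z, yaw, pitch, roll, grasp`), that the deltas are expressed in the same frame as the current pose, and that rotation deltas are small enough to add to Euler angles directly (an approximation; proper rotation composition depends on the dataset's convention). The IK step is a hypothetical `ik_solver` callable, since RLBench's own IK utilities are not shown here.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Pose = Tuple[float, ...]  # (x, y, z, yaw, pitch, roll)


def apply_delta(pose: Pose, action: Sequence[float]) -> Tuple[Pose, float]:
    """Return (target_pose, grasp) for a 7-D delta action.

    Rotation deltas are added to the Euler angles directly: a small-angle
    approximation, not a proper composition of rotations.
    """
    if len(pose) != 6 or len(action) != 7:
        raise ValueError("expected a 6-D pose and a 7-D action")
    target = tuple(p + d for p, d in zip(pose, action[:6]))
    return target, float(action[6])


def to_joint_command(
    pose: Pose,
    action: Sequence[float],
    ik_solver: Callable[[Pose], Optional[List[float]]],
) -> Tuple[Optional[List[float]], float]:
    """Map a delta action to (joint_angles, grasp) via a caller-supplied IK solver.

    ik_solver is hypothetical: it should return joint angles for the target
    pose, or None if the pose is unreachable. The grasp value is passed
    through for a separate gripper command.
    """
    target, grasp = apply_delta(pose, action)
    return ik_solver(target), grasp


# Usage with a dummy solver that always returns a 7-joint zero configuration.
joints, grasp = to_joint_command(
    (0.4, 0.0, 0.3, 0.0, 0.0, 0.0),
    [0.01, 0.0, -0.02, 0.0, 0.0, 0.05, 1.0],
    ik_solver=lambda target: [0.0] * 7,
)
print(joints, grasp)  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 1.0
```

For the few-shot route the same conversion runs in reverse: consecutive end-effector poses from forward kinematics give the per-step deltas to store as demonstration actions.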
Great work! But I still have several questions about the model.
First question: can a general vision model make use of depth information from a camera? Does depth require special processing?
Second question: Dual-arm robots and single-arm robots may have different behavioral patterns; does Octo distinguish between these two types of robots?
Third question: For the wrist camera of a dual-arm robot, is special processing required?
Thank you very much for your time and assistance.