Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DROID Policy Learning/Evaluation #144

Draft
wants to merge 44 commits into
base: r2d2
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c28c407
first commit, lots of WIP needed
ashwin-balakrishna96 Feb 28, 2024
851ca6e
remove act stuff
ashwin-balakrishna96 Feb 28, 2024
94fe84b
start clean
ashwin-balakrishna96 Feb 28, 2024
c0ddf13
release code
ashwin-balakrishna96 Mar 7, 2024
721ab48
fix args
ashwin-balakrishna96 Mar 7, 2024
1943abf
fix, using this as default from now on
ashwin-balakrishna96 Mar 8, 2024
c49ec03
cleanup, r2d2 --> droid
ashwin-balakrishna96 Mar 8, 2024
fe2e60c
update README, install deps, tested fresh install
ashwin-balakrishna96 Mar 8, 2024
21cab26
clean up requirements
ashwin-balakrishna96 Mar 8, 2024
03b19a8
update todos
ashwin-balakrishna96 Mar 8, 2024
fa1215f
clean
ashwin-balakrishna96 Mar 8, 2024
dfe3ed2
update README
ashwin-balakrishna96 Mar 8, 2024
bbe0c5b
clean
ashwin-balakrishna96 Mar 9, 2024
0b022fe
clean
ashwin-balakrishna96 Mar 9, 2024
56808a5
more cleanup
ashwin-balakrishna96 Mar 9, 2024
03a731d
more clean
ashwin-balakrishna96 Mar 9, 2024
e1d2e41
clean visual core
ashwin-balakrishna96 Mar 9, 2024
f814f8f
some hdf5 cleanup
suraj-nair-tri Mar 9, 2024
b21c4d7
Cleanup and removing unused files
suraj-nair-tri Mar 9, 2024
a804476
Update README.md
kpertsch Mar 9, 2024
d551ced
fix readmes
ashwin-balakrishna96 Mar 10, 2024
858f27f
fix merge
ashwin-balakrishna96 Mar 10, 2024
3b6f36c
clean more, gotta do bc transformer stuff
ashwin-balakrishna96 Mar 11, 2024
0d8c41b
remove unnecessary files
ashwin-balakrishna96 Mar 11, 2024
c96fa71
clean config
ashwin-balakrishna96 Mar 11, 2024
77239a8
clean configs
ashwin-balakrishna96 Mar 11, 2024
fe27799
clean
ashwin-balakrishna96 Mar 11, 2024
1269b0f
minor fix
ashwin-balakrishna96 Mar 11, 2024
c38876a
small fix
ashwin-balakrishna96 Mar 11, 2024
8f2a0f9
clean up install instructions
ashwin-balakrishna96 Mar 11, 2024
d55376c
octo dataloader fixes
ashwin-balakrishna96 Mar 11, 2024
51e1221
clean readme
ashwin-balakrishna96 Mar 12, 2024
f31a742
fix tiny bug
ashwin-balakrishna96 Mar 12, 2024
50d17eb
add dataloader example
ashwin-balakrishna96 Mar 12, 2024
e2df06b
fix
ashwin-balakrishna96 Mar 12, 2024
14c1c42
add in filter support
ashwin-balakrishna96 Mar 12, 2024
713de70
fix
ashwin-balakrishna96 Mar 12, 2024
1e06296
fix shuffle buffer size
ashwin-balakrishna96 Mar 12, 2024
2232921
full clean
ashwin-balakrishna96 Mar 13, 2024
cb5a935
clean out norm stuff
ashwin-balakrishna96 Mar 13, 2024
0725feb
fix comment
ashwin-balakrishna96 Mar 13, 2024
4ffb383
fix typo
ashwin-balakrishna96 Mar 13, 2024
eae359b
fix abs action
ashwin-balakrishna96 Mar 13, 2024
58e48fd
Update README.md
kpertsch Mar 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
octo dataloader fixes
  • Loading branch information
ashwin-balakrishna96 committed Mar 11, 2024
commit d55376cbb147ac24044c603bf9b1f7a721fee849
11 changes: 8 additions & 3 deletions robomimic/scripts/train.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,8 @@
from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
from robomimic.utils.rlds_utils import droid_dataset_transform, robomimic_transform, DROID_TO_RLDS_OBS_KEY_MAP, DROID_TO_RLDS_LOW_DIM_OBS_KEY_MAP, TorchRLDSDataset

from octo.data.dataset import make_interleaved_dataset
from octo.data.dataset import make_dataset_from_rlds, make_interleaved_dataset
from octo.data.utils.data_utils import combine_dataset_statistics


def train(config, device):
@@ -96,7 +97,7 @@ def train(config, device):
"image_obs_keys": {"primary": DROID_TO_RLDS_OBS_KEY_MAP[obs_modalities[0]], "secondary": DROID_TO_RLDS_OBS_KEY_MAP[obs_modalities[1]]},
"state_obs_keys": [DROID_TO_RLDS_LOW_DIM_OBS_KEY_MAP[a] for a in config.observation.modalities.obs.low_dim],
"language_key": "language_instruction",
"keys_to_normalize": {"action": "action"},
"norm_skip_keys": ["proprio"],
"action_proprio_normalization_type": "bounds",
"absolute_action_mask": [True] * ac_dim,
"action_normalization_mask": [True] * ac_dim,
@@ -108,6 +109,10 @@ def train(config, device):
dataset_kwargs_list = [
{"name": d_name, **BASE_DATASET_KWARGS} for d_name in dataset_names
]
# Compute combined normalization stats
combined_dataset_statistics = combine_dataset_statistics(
[make_dataset_from_rlds(**dataset_kwargs, train=train)[1] for dataset_kwargs in dataset_kwargs_list]
)

dataset = make_interleaved_dataset(
dataset_kwargs_list,
@@ -116,7 +121,7 @@ def train(config, device):
shuffle_buffer_size=config.train.shuffle_buffer_size,
batch_size=None, # batching will be handles in PyTorch Dataloader object
balance_weights=False,
do_combined_normalization=True,
dataset_statistics=combined_dataset_statistics,
traj_transform_kwargs=dict(
# NOTE(Ashwin): window_size and future_action_window_size may break if
# not using diffusion policy
2 changes: 1 addition & 1 deletion robomimic/utils/action_utils.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ def get_action_stats_dict(rlds_dataset_stats, action_keys, action_shapes):
end_idx = start_idx + this_act_dim
action_stats[key] = dict()
for sub_key in rlds_dataset_stats.keys():
action = rlds_dataset_stats[sub_key]
action = np.array(rlds_dataset_stats[sub_key])
action_stats[key][sub_key] = action[...,start_idx:end_idx].reshape(
action.shape[:-1]+tuple(this_act_shape))
start_idx = end_idx