Commit ece8b41

add replay buffer
KuoHaoZeng committed Sep 6, 2024
1 parent 497b4d1 commit ece8b41
Showing 3 changed files with 43 additions and 14 deletions.
33 changes: 19 additions & 14 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -1200,8 +1200,8 @@ def __init__(
         save_ckpt_after_every_pipeline_stage: bool = True,
         first_local_worker_id: int = 0,
         save_ckpt_at_every_host: bool = False,
-        offpolicy_batch_size: int = 32,
-        replay_buffer_max_size: int = 640,
+        offpolicy_batch_size: int = 0,
+        offpolicy_max_batch_size: int = 640,
         **kwargs,
     ):
         kwargs["mode"] = TRAIN_MODE_STR
@@ -1231,14 +1231,17 @@ def __init__(
         self.training_pipeline: TrainingPipeline = config.training_pipeline()
 
         # [OFFP] Only build a replay buffer when off-policy training is enabled.
-        self.replay_buffer = ReplayBuffer(
-            storage=LazyMemmapStorage(
-                max_size=replay_buffer_max_size,
-                device=torch.device("cpu"),
-                scratch_dir="/tmp/replay_buffer/",
-            ),
-            batch_size=offpolicy_batch_size,
-        )
+        if offpolicy_batch_size > 0:
+            self.replay_buffer = ReplayBuffer(
+                storage=LazyMemmapStorage(
+                    max_size=offpolicy_max_batch_size,
+                    device=torch.device("cpu"),
+                    scratch_dir="/tmp/replay_buffer/",
+                ),
+                batch_size=offpolicy_batch_size,
+            )
+        else:
+            self.replay_buffer = None
 
         if self.num_workers != 1:
             # Ensure that we're only using early stopping criterions in the non-distributed setting.
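
Note for reviewers: ReplayBuffer and LazyMemmapStorage come from torchrl (with TensorDict from the companion tensordict package). LazyMemmapStorage keeps up to max_size entries in memory-mapped files under scratch_dir, while batch_size fixes how many entries each sample() call returns. A minimal standalone sketch of the same construction (the tensor names and shapes below are illustrative, not taken from allenact):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, ReplayBuffer

    buffer = ReplayBuffer(
        storage=LazyMemmapStorage(
            max_size=640,                       # buffer capacity, not the sample size
            device=torch.device("cpu"),
            scratch_dir="/tmp/replay_buffer/",  # entries are memory-mapped to disk here
        ),
        batch_size=32,                          # size of each sampled batch
    )

    # Fill the buffer with a batch of illustrative transitions ...
    transitions = TensorDict(
        {"observations": torch.randn(64, 4), "rewards": torch.randn(64, 1)},
        batch_size=[64],
    )
    buffer.extend(transitions)

    # ... and draw a training batch from it.
    batch = buffer.sample()  # TensorDict with batch_size == 32
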
@@ -1851,17 +1854,19 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
         for storage in self.training_pipeline.current_stage_storage.values():
             storage.before_updates(**before_update_info)
 
-        adapted_storage = StorageAdapter(storage, torch.device("cpu"))
-        tensordict = adapted_storage.to_tensordict(batch_size=[storage.rewards.shape[1]])
-        self.replay_buffer.extend(tensordict)
+        if self.replay_buffer is not None:
+            adapted_storage = StorageAdapter(storage, torch.device("cpu"))
+            tensordict = adapted_storage.to_tensordict(batch_size=[storage.rewards.shape[1]])
+            self.replay_buffer.extend(tensordict)
 
         for sc in self.training_pipeline.current_stage.stage_components:
             component_storage = uuid_to_storage[sc.storage_uuid]
 
             self.compute_losses_track_them_and_backprop(
                 stage=self.training_pipeline.current_stage,
                 stage_component=sc,
-                storage=component_storage,
+                storage=component_storage if self.replay_buffer is None else None,
+                replay_buffer=self.replay_buffer,
             )
 
         for storage in self.training_pipeline.current_stage_storage.values():
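
StorageAdapter itself is not part of this diff; the guarded block above assumes it flattens an allenact rollout storage, whose tensors are laid out as [steps, samplers, ...], into a TensorDict whose batch dimension is the number of samplers (hence batch_size=[storage.rewards.shape[1]]). A hypothetical stand-in, with field names assumed for illustration:

    import torch
    from tensordict import TensorDict

    def rollouts_to_tensordict(rewards: torch.Tensor, actions: torch.Tensor) -> TensorDict:
        """Hypothetical StorageAdapter stand-in: rewards/actions are assumed to
        be laid out as [steps, samplers, ...]; the sampler dimension is moved to
        the front so it can serve as the TensorDict batch dimension."""
        num_samplers = rewards.shape[1]
        return TensorDict(
            {
                "rewards": rewards.transpose(0, 1).cpu(),  # -> [samplers, steps, ...]
                "actions": actions.transpose(0, 1).cpu(),
            },
            batch_size=[num_samplers],
        )

    td = rollouts_to_tensordict(torch.randn(128, 8, 1), torch.randn(128, 8, 2))
    assert td.batch_size == torch.Size([8])
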
4 changes: 4 additions & 0 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -502,6 +502,8 @@ def start_train(
         valid_on_initial_weights: bool = False,
         try_restart_after_task_error: bool = False,
         save_ckpt_at_every_host: bool = False,
+        offpolicy_batch_size: Optional[int] = 0,
+        offpolicy_max_batch_size: Optional[int] = 640,
     ):
         self._initialize_start_train_or_start_test()
 
@@ -574,6 +576,8 @@ def start_train(
             valid_on_initial_weights=valid_on_initial_weights,
             try_restart_after_task_error=try_restart_after_task_error,
             save_ckpt_at_every_host=save_ckpt_at_every_host,
+            offpolicy_batch_size=offpolicy_batch_size,
+            offpolicy_max_batch_size=offpolicy_max_batch_size,
         )
         train: BaseProcess = self.mp_ctx.Process(
             target=self.train_loop,
20 changes: 20 additions & 0 deletions allenact/main.py
@@ -284,6 +284,24 @@ def get_argument_parser():
     )
     parser.set_defaults(save_ckpt_at_every_host=False)
 
+    parser.add_argument(
+        "--offpolicy_batch_size",
+        dest="offpolicy_batch_size",
+        required=False,
+        type=int,
+        default=0,
+        help="Batch size for off-policy training (default: 0, i.e. on-policy training).",
+    )
+
+    parser.add_argument(
+        "--offpolicy_max_batch_size",
+        dest="offpolicy_max_batch_size",
+        required=False,
+        type=int,
+        default=640,
+        help="Maximum size of the replay buffer used for off-policy training.",
+    )
+
     parser.add_argument(
         "--callbacks",
         dest="callbacks",
@@ -495,6 +513,8 @@ def main():
             valid_on_initial_weights=args.valid_on_initial_weights,
             try_restart_after_task_error=args.enable_crash_recovery,
             save_ckpt_at_every_host=args.save_ckpt_at_every_host,
+            offpolicy_batch_size=args.offpolicy_batch_size,
+            offpolicy_max_batch_size=args.offpolicy_max_batch_size,
         )
     else:
         OnPolicyRunner(
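
With these defaults, off-policy training is strictly opt-in: leaving --offpolicy_batch_size at 0 keeps the previous on-policy behavior, and --offpolicy_max_batch_size caps the replay buffer capacity rather than the sampled batch. A quick sanity check of the new flags (the positional experiment argument "my_experiment" is a placeholder):

    from allenact.main import get_argument_parser

    # "my_experiment" stands in for the parser's positional experiment argument.
    args = get_argument_parser().parse_args(
        ["my_experiment", "--offpolicy_batch_size", "32", "--offpolicy_max_batch_size", "1280"]
    )
    assert args.offpolicy_batch_size == 32
    assert args.offpolicy_max_batch_size == 1280
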
