Skip to content

Commit de2db8f

Browse files
committed
Tool for running a trained agent.
1 parent ce05d3e commit de2db8f

File tree

2 files changed

+163
-0
lines changed

2 files changed

+163
-0
lines changed

tools/run_agent.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import os

import numpy as np
import torch
from absl import app, flags
from hydra.experimental import initialize, compose
from moviepy.editor import *
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from omegaconf import OmegaConf, ListConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class

from arm import c2farm, qte, lpr
from arm.custom_rlbench_env import CustomRLBenchEnv
from arm.lpr.trajectory_action_mode import TrajectoryActionMode
from launch import _create_obs_config
from tools.utils import RLBenchCinematic
18+
19+
# NOTE(review): FREEZE_DURATION is not referenced anywhere in the visible
# source; presumably a hold time for the video -- confirm or remove.
FREEZE_DURATION = 2
# Frame rate passed to ImageSequenceClip when building the recorded clips.
FPS = 20

# Command-line options; `logdir`, `task` and `method` together select the
# run directory that the config and weights are loaded from.
flags.DEFINE_string(
    name='logdir', default='/path/to/log/dir', help='weight dir.')
flags.DEFINE_string(
    name='method', default='C2FARM', help='The method to run.')
flags.DEFINE_string(
    name='task', default='take_lid_off_saucepan', help='The task to run.')
flags.DEFINE_integer(
    name='episodes', default=1, help='The number of episodes to run.')

FLAGS = flags.FLAGS
28+
29+
30+
def _save_clips(clips, name):
31+
final_clip = concatenate_videoclips(clips)
32+
final_clip.write_videofile('%s.mp4' % name)
33+
34+
35+
def visualise(logdir, task, method, episodes=None):
    """Roll out a trained agent on an RLBench task and save videos of it.

    Loads the hydra config from `<logdir>/<task>/<method>/.hydra` and the
    highest-numbered checkpoint from `<logdir>/<task>/<method>/seed0/weights`,
    runs the agent deterministically on CPU and writes the frames captured by
    the cinematic camera to `<method>_<task>.mp4` in the working directory
    (`<method>_<task>_ep<N>.mp4` when more than one episode is run).

    Args:
        logdir: Root log directory containing the trained run.
        task: Name of the RLBench task file, also used as a path component.
        method: Method directory name under `<logdir>/<task>`.
        episodes: Number of episodes to roll out. When None, the value of the
            `--episodes` flag is used.

    Raises:
        ValueError: If the config or weights directory is missing, if no
            integer-named checkpoint folder exists, or if the method name in
            the loaded config is not supported.
    """
    if episodes is None:
        episodes = FLAGS.episodes

    config_path = os.path.join(logdir, task, method, '.hydra')
    weights_path = os.path.join(logdir, task, method, 'seed0', 'weights')

    if not os.path.exists(config_path):
        raise ValueError('No config in: ' + config_path)
    if not os.path.exists(weights_path):
        raise ValueError('No weights in: ' + weights_path)

    # Checkpoint folders are named by an integer and the largest one is
    # loaded. Stray entries (e.g. hidden files) are ignored, and an empty
    # directory is reported before the simulator is started.
    weight_folders = sorted(
        int(entry) for entry in os.listdir(weights_path)
        if entry.isdecimal())
    if not weight_folders:
        raise ValueError('No weights in: ' + weights_path)

    # NOTE(review): hydra documents `config_path` as relative to the calling
    # file, whereas relpath() is relative to the working directory -- confirm
    # this resolves correctly from wherever the tool is launched.
    with initialize(config_path=os.path.relpath(config_path)):
        cfg = compose(config_name="config")
    print(OmegaConf.to_yaml(cfg))

    # A single camera may be configured as a bare value; normalise to a list.
    if not isinstance(cfg.rlbench.cameras, ListConfig):
        cfg.rlbench.cameras = [cfg.rlbench.cameras]

    obs_config = _create_obs_config(
        cfg.rlbench.cameras, cfg.rlbench.camera_resolution)
    task_class = task_file_to_task_class(task)

    gripper_mode = Discrete()
    # NOTE(review): the trajectory action mode is keyed on 'PathARM' while the
    # agent selection below uses 'LPR' -- verify the two names are meant to
    # differ.
    if cfg.method.name == 'PathARM':
        arm_action_mode = TrajectoryActionMode(cfg.method.trajectory_points)
    else:
        arm_action_mode = EndEffectorPoseViaPlanning()
    action_mode = MoveArmThenGripper(arm_action_mode, gripper_mode)

    env = CustomRLBenchEnv(
        task_class=task_class, observation_config=obs_config,
        action_mode=action_mode, dataset_root=cfg.rlbench.demo_path,
        episode_length=cfg.rlbench.episode_length, headless=True,
        time_in_state=True)
    # Result is unused; presumably accessed for its side effects -- confirm.
    _ = env.observation_elements

    if cfg.method.name == 'C2FARM':
        agent = c2farm.launch_utils.create_agent(
            cfg, env, cfg.rlbench.scene_bounds,
            cfg.rlbench.camera_resolution)
    elif cfg.method.name == 'C2FARM+QTE':
        agent = qte.launch_utils.create_agent(
            cfg, env, cfg.rlbench.scene_bounds,
            cfg.rlbench.camera_resolution)
    elif cfg.method.name == 'LPR':
        agent = lpr.launch_utils.create_agent(
            cfg, env, cfg.rlbench.scene_bounds, cfg.rlbench.camera_resolution,
            cfg.method.trajectory_point_noise, cfg.method.trajectory_points,
            cfg.method.trajectory_mode, cfg.method.trajectory_samples)
    else:
        raise ValueError('Invalid method name.')

    agent.build(training=False, device=torch.device("cpu"))
    agent.load_weights(os.path.join(weights_path, str(weight_folders[-1])))

    env.launch()
    try:
        cinematic_cam = RLBenchCinematic()
        env.register_callback(cinematic_cam.callback)
        for ep in range(episodes):
            obs = env.reset()
            agent.reset()
            # Seed the history by repeating the first observation
            # `cfg.replay.timesteps` times.
            obs_history = {
                k: [np.array(v, dtype=_get_type(v))] * cfg.replay.timesteps
                for k, v in obs.items()}
            clips = []
            last = False
            for step in range(cfg.rlbench.episode_length):
                prepped_data = {
                    k: torch.FloatTensor([v]) for k, v in obs_history.items()}
                act_result = agent.act(step, prepped_data, deterministic=True)
                transition = env.step(act_result)

                # Turn whatever the camera callback recorded during this
                # step into one clip, then reset the camera's buffer.
                trajectory_frames = cinematic_cam.frames
                if len(trajectory_frames) > 0:
                    cinematic_cam.empty()
                    clips.append(
                        ImageSequenceClip(trajectory_frames, fps=FPS))

                # One further step is executed after the terminal transition
                # before the episode loop stops.
                if last:
                    break
                if transition.terminal:
                    last = True
                # Slide the observation window forward by one timestep.
                for k, history in obs_history.items():
                    history.append(transition.observation[k])
                    history.pop(0)

            # `_save_clips` appends the '.mp4' extension itself. Give every
            # episode its own file so later episodes do not overwrite
            # earlier ones.
            video_name = '%s_%s' % (method, task)
            if episodes > 1:
                video_name += '_ep%d' % ep
            if clips:
                _save_clips(clips, video_name)
            else:
                print('No frames recorded for episode %d; skipping video.'
                      % ep)
    finally:
        # Always stop the simulator, even if the rollout raised.
        print('Shutting down env...')
        env.shutdown()
121+
122+
123+
def _get_type(x):
124+
if x.dtype == np.float64:
125+
return np.float32
126+
return x.dtype
127+
128+
129+
def main(argv):
    """Entry point handed to `app.run`: visualise using the flag values.

    Args:
        argv: Positional command-line arguments; ignored.
    """
    del argv  # Unused.
    logdir, task, method = FLAGS.logdir, FLAGS.task, FLAGS.method
    visualise(logdir, task, method)
132+
133+
134+
if __name__ == "__main__":
    # Hand control to absl, which invokes `main`.
    app.run(main)

tools/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
from pyrep.const import RenderMode
3+
from pyrep.objects import Dummy, VisionSensor
4+
5+
6+
class RLBenchCinematic:
    """Collects RGB frames from an extra vision sensor added to the scene.

    `callback` captures one frame per call; captured frames accumulate until
    `empty` is called and can be read through the `frames` property.
    """

    def __init__(self):
        # Both dummies are looked up by name, so they must already exist in
        # the loaded scene.
        placeholder = Dummy('cam_cinematic_placeholder')
        self._cam_base = Dummy('cam_cinematic_base')
        # Create a 640x480 sensor, place it at the placeholder's pose and
        # parent it to the placeholder.
        self._cam = camera = VisionSensor.create([640, 480])
        camera.set_explicit_handling(True)
        camera.set_pose(placeholder.get_pose())
        camera.set_parent(placeholder)
        camera.set_render_mode(RenderMode.OPENGL3)
        self._frames = []

    def callback(self):
        """Render the sensor and store the image as a uint8 frame."""
        # The sensor is explicitly handled, so it must be triggered here.
        self._cam.handle_explicitly()
        rgb = self._cam.capture_rgb()
        # Scaled by 255 before the uint8 cast; presumably the capture is a
        # float image in [0, 1] -- confirm.
        frame = (rgb * 255).astype(np.uint8)
        self._frames.append(frame)

    def empty(self):
        """Discard all stored frames, in place."""
        del self._frames[:]

    @property
    def frames(self):
        """A shallow copy of the frames captured so far."""
        return self._frames[:]

0 commit comments

Comments
 (0)