-
Notifications
You must be signed in to change notification settings - Fork 20
/
openvla_agent.py
60 lines (49 loc) · 2.11 KB
/
openvla_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from mbodied.agents.motion.motor_agent import MotorAgent
from mbodied.types.motion.control import HandControl, Motion
from mbodied.types.sense.vision import Image
class OpenVlaAgent(MotorAgent):
    """OpenVLA agent that generates robot actions through a remote inference server.

    Specify the gradio server endpoint in ``model_src`` to run inference via its API.
    See openvla_example_server.py for an example of the gradio server code.
    ``self.actor`` (set up by the ``MotorAgent`` base class — presumably a gradio
    client for ``model_src``) is called with: image, instruction, and unnorm_key.

    Examples:
        >>> openvla_agent = OpenVlaAgent(model_src="https://api.mbodi.ai/community-models/")  # doctest: +SKIP
        >>> openvla_agent.act("move hand forward", Image(size=(224, 224)))  # doctest: +SKIP
        HandControl(pose=Pose6D(x=1,y=2,z=3,roll=0,pitch=0,yaw=0), grasp=JointControl(value=0))
    """

    def __init__(
        self,
        recorder="omit",
        recorder_kwargs=None,
        model_src=None,
        model_kwargs=None,
        **kwargs,
    ):
        """Initialize the agent; every argument is forwarded unchanged to ``MotorAgent``.

        Args:
            recorder: Recorder setting forwarded to the base class (defaults to ``"omit"``).
            recorder_kwargs: Extra keyword arguments for the recorder.
            model_src: Gradio server endpoint used for remote inference.
            model_kwargs: Extra keyword arguments forwarded to the base class for the model.
            **kwargs: Any further keyword arguments accepted by ``MotorAgent``.
        """
        super().__init__(
            recorder=recorder,
            recorder_kwargs=recorder_kwargs,
            model_src=model_src,
            model_kwargs=model_kwargs,
            **kwargs,
        )

    def act(self, instruction: str, image: Image, unnorm_key: str = "bridge_orig") -> Motion:
        """Act based on the instruction and image using the remote server.

        Args:
            instruction (str): The natural-language instruction to act on.
            image (Image): The observation image; sent to the server as base64.
            unnorm_key (str): Forwarded to the server. In OpenVLA this names the dataset
                whose statistics are used to un-normalize the predicted action
                (e.g. ``"bridge_orig"``).

        Returns:
            Motion: The HandControl built from the server's action vector.

        Raises:
            ValueError: If the remote actor is not initialized, or the server's
                response is empty or contains a non-numeric value.
        """
        if self.actor is None:
            raise ValueError("Remote actor for OpenVLA not initialized.")
        response = self.actor.predict(image.base64, instruction, unnorm_key)
        action = self._parse_action(response)
        return HandControl.unflatten(action)

    @staticmethod
    def _parse_action(response) -> list:
        """Convert the server's raw response into a flat list of floats.

        Accepts a string such as ``"[0.1 0.2 ...]"`` (numpy-style), ``"[0.1, 0.2, ...]"``
        (comma-separated), possibly nested or surrounded by whitespace/newlines, or any
        iterable of numbers.

        Raises:
            ValueError: If the response holds no values or a value is not numeric.
        """
        if isinstance(response, str):
            # Map brackets and commas to spaces in one pass. Unlike str.strip("[]"), this
            # also handles trailing newlines ("[...]\n") and nested brackets.
            items = response.translate(str.maketrans("[],", "   ")).split()
        else:
            # The client may already have decoded the output (e.g. a JSON list).
            items = list(response)
        if not items:
            raise ValueError(f"OpenVLA server returned an empty action: {response!r}")
        return [float(item) for item in items]
# # Example usage:
# if __name__ == "__main__":
# openvla_agent = OpenVlaAgent(model_src="https://api.mbodi.ai/community-models/")
# image = Image("resources/xarm.jpeg")
# response = openvla_agent.act("move forward", image)
# print(response)