Skip to content

Commit

Permalink
Merge pull request #65 from christsa/main
Browse files Browse the repository at this point in the history
MANO support and z-up feature for Meshes
  • Loading branch information
kaufManu authored Sep 23, 2024
2 parents 380a2c2 + 9113538 commit 8fb6d46
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 17 deletions.
59 changes: 58 additions & 1 deletion aitviewer/models/smpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,65 @@ def fk(

return output.vertices, output.joints

def fk_mano(
self,
hand_pose,
betas,
global_orient=None,
trans=None,
normalize_root=False,
mano=True,
):
"""
Convert mano pose data (joint angles and shape parameters) to positional data (joint and mesh vertex positions).
:param hand_pose: A tensor of shape (N, N_JOINTS*3), i.e. joint angles in angle-axis format or PCA format (N, N_PCA_COMPONENTS). This contains all
body joints which are not the root.
:param betas: A tensor of shape (N, N_BETAS) containing the betas/shape parameters.
:param global_orient: Orientation of the root or None. If specified expected shape is (N, 3).
:param trans: translation that is applied to vertices and joints or None, this is the 'transl' parameter
of the MANO Model. If specified expected shape is (N, 3).
:param normalize_root: If set, it will normalize the root such that its orientation is the identity in the
first frame and its position starts at the origin.
:return: The resulting vertices and joints.
"""

batch_size = hand_pose.shape[0]
device = hand_pose.device

if global_orient is None:
global_orient = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device)
if trans is None:
trans = torch.zeros([batch_size, 3]).to(dtype=hand_pose.dtype, device=device)

# Batch shapes if they don't match batch dimension.
if len(betas.shape) == 1 or betas.shape[0] == 1:
betas = betas.repeat(hand_pose.shape[0], 1)
betas = betas[:, : self.num_betas]

if normalize_root:
# Make everything relative to the first root orientation.
root_ori = aa2rot(global_orient)
first_root_ori = torch.inverse(root_ori[0:1])
root_ori = torch.matmul(first_root_ori, root_ori)
global_orient = rot2aa(root_ori)
trans = torch.matmul(first_root_ori.unsqueeze(0), trans.unsqueeze(-1)).squeeze()
trans = trans - trans[0:1]

output = self.bm(
hand_pose=hand_pose,
betas=betas,
global_orient=global_orient,
transl=trans,
)

return output.vertices, output.joints

def forward(self, *args, **kwargs):
"""
Forward pass using forward kinematics
"""
return self.fk(*args, **kwargs)

if "mano" in kwargs.keys():
return self.fk_mano(*args, **kwargs)
else:
return self.fk(*args, **kwargs)
4 changes: 4 additions & 0 deletions aitviewer/renderables/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
draw_edges=False,
draw_outline=False,
instance_transforms=None,
z_up=False,
icon="\u008d",
**kwargs,
):
Expand Down Expand Up @@ -147,6 +148,9 @@ def _maybe_unsqueeze(x):
self.clip_control = np.array((0, 0, 0), np.int32)
self.clip_value = np.array((0, 0, 0), np.float32)

if z_up:
self.rotation = np.matmul(np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), self.rotation)

@classmethod
def instanced(cls, *args, positions=None, rotations=None, scales=None, **kwargs):
"""
Expand Down
53 changes: 37 additions & 16 deletions aitviewer/renderables/smpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,28 @@ def __init__(

# First convert the relative joint angles to global joint angles in rotation matrix form.
if self.smpl_layer.model_type != "flame":
global_oris = local_to_global(
torch.cat([self.poses_root, self.poses_body], dim=-1),
self.skeleton[:, 0],
output_format="rotmat",
)
if self.smpl_layer.model_type != "mano":
global_oris = local_to_global(
torch.cat([self.poses_root, self.poses_body, self.poses_left_hand, self.poses_right_hand], dim=-1),
self.skeleton[:, 0],
output_format="rotmat",
)
else:
global_oris = local_to_global(
torch.cat([self.poses_root, self.poses_body], dim=-1),
self.skeleton[:, 0],
output_format="rotmat",
)
global_oris = c2c(global_oris.reshape((self.n_frames, -1, 3, 3)))
else:
global_oris = np.tile(np.eye(3), self.joints.shape[:-1])[np.newaxis]

if self._z_up and not C.z_up:
self.rotation = np.matmul(np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]), self.rotation)

self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles")
self._add_node(self.rbs, enabled=self._show_joint_angles)
if self.smpl_layer.model_type != "mano":
self.rbs = RigidBodies(self.joints, global_oris, length=0.1, gui_affine=False, name="Joint Angles")
self._add_node(self.rbs, enabled=self._show_joint_angles)

self.mesh_seq = Meshes(
self.vertices,
Expand Down Expand Up @@ -397,20 +405,33 @@ def fk(self, current_frame_only=False):
trans = self.trans
betas = self.betas

verts, joints = self.smpl_layer(
poses_root=poses_root,
poses_body=poses_body,
poses_left_hand=poses_left_hand,
poses_right_hand=poses_right_hand,
betas=betas,
trans=trans,
)
if self.smpl_layer.model_type == "mano":
verts, joints = self.smpl_layer(
hand_pose=poses_body,
betas=betas,
global_orient=poses_root,
trans=trans,
mano=True,
)
else:
verts, joints = self.smpl_layer(
poses_root=poses_root,
poses_body=poses_body,
poses_left_hand=poses_left_hand,
poses_right_hand=poses_right_hand,
betas=betas,
trans=trans,
)

# Apply post_fk_func if specified.
if self.post_fk_func:
verts, joints = self.post_fk_func(self, verts, joints, current_frame_only)

skeleton = self.smpl_layer.skeletons()["body"].T
skeleton = (
self.smpl_layer.skeletons()["body"].T
if not self.smpl_layer.model_type == "mano"
else self.smpl_layer.skeletons()["all"].T
)
faces = self.smpl_layer.bm.faces.astype(np.int64)
joints = joints[:, : skeleton.shape[0]]

Expand Down

0 comments on commit 8fb6d46

Please sign in to comment.