Skip to content

Commit

Permalink
add api for set global solver parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
YilingQiao committed Dec 21, 2024
1 parent 6c3c938 commit 249402b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 20 deletions.
12 changes: 2 additions & 10 deletions genesis/engine/solvers/rigid/constraint_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,7 @@ def add_collision_constraints(self):

if ti.static(self.sparse_solve):
self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs

imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel
)
imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel)

diag = t + impact.friction * impact.friction * t
diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS)
Expand Down Expand Up @@ -199,12 +196,7 @@ def add_joint_limit_constraints(self):

jac = side
jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel
imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time,
self._solver.dofs_info[i_d].sol_params,
pos,
jac_qvel,
)
imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel)
diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS)
aref = aref * (pos < 0)
if pos < 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def add_collision_constraints(self, island, i_b):
self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs

imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel
impact.sol_params, -impact.penetration, jac_qvel
)

diag = t + impact.friction * impact.friction * t
Expand Down Expand Up @@ -235,12 +235,7 @@ def add_joint_limit_constraints(self, island, i_b):

jac = side
jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel
imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time,
self._solver.dofs_info[i_d].sol_params,
pos,
jac_qvel,
)
imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel)
diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS)
aref = aref * (pos < 0)
if pos < 0:
Expand Down
24 changes: 23 additions & 1 deletion genesis/engine/solvers/rigid/rigid_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _kernel_init_dof_fields(
for j in ti.static(range(7)):
self.dofs_info[i].sol_params[j] = dofs_sol_params[i, j]

self.dofs_info[i].sol_params[0] = self._substep_dt * 2
self.dofs_info[i].sol_params[0] = self._sol_contact_resolve_time

self.dofs_info[i].armature = dofs_armature[i]
self.dofs_info[i].invweight = dofs_invweight[i]
Expand Down Expand Up @@ -776,6 +776,7 @@ def _kernel_init_geom_fields(
for j in ti.static(range(7)):
self.geoms_info[i].data[j] = geoms_data[i, j]
self.geoms_info[i].sol_params[j] = geoms_sol_params[i, j]
self.geoms_info[i].sol_params[0] = self._sol_contact_resolve_time

self.geoms_info[i].sol_params[0] = ti.max(self.geoms_info[i].sol_params[0], self._substep_dt * 2)

Expand Down Expand Up @@ -3319,6 +3320,27 @@ def _kernel_set_qpos(
for i_q_, i_b_ in ti.ndrange(qs_idx.shape[0], envs_idx.shape[0]):
self.qpos[qs_idx[i_q_], envs_idx[i_b_]] = qpos[i_b_, i_q_]

def set_global_sol_params(self, sol_params):
"""
Solver parameters (timeconst, dampratio, dmin, dmax, width, mid, power).
Reference: https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
"""
assert len(sol_params) == 7
self._kernel_set_global_sol_params(sol_params)

@ti.kernel
def _kernel_set_global_sol_params(self, sol_params: ti.types.ndarray()):
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
for i in range(self.n_geoms):
for j in ti.static(range(7)):
self.geoms_info[i].sol_params[j] = sol_params[j]

ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
for i, b in ti.ndrange(self.n_dofs, self._B):
for j in ti.static(range(7)):
self.dofs_info[i].sol_params[j] = sol_params[j]

self.dofs_info[i].sol_params[0] = self._substep_dt * 2
def set_dofs_kp(self, kp, dofs_idx):
kp, dofs_idx = self._validate_1D_io_variables(kp, dofs_idx, batched=False)
self._kernel_set_dofs_kp(kp, dofs_idx)
Expand Down
4 changes: 2 additions & 2 deletions genesis/utils/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,9 @@ def orthogonals2(a):


@ti.func
def imp_aref(timeconst, params, pos, vel):
def imp_aref(params, pos, vel):
# The first term in parms is the timeconst parsed from mjcf. However, we don't use it here but use the one passed in, which is 2*substep_dt.
_, dampratio, dmin, dmax, width, mid, power = params
timeconst, dampratio, dmin, dmax, width, mid, power = params
imp_x = ti.abs(pos) / width
imp_a = (1.0 / mid ** (power - 1)) * imp_x**power
imp_b = 1 - (1.0 / (1 - mid) ** (power - 1)) * (1 - imp_x) ** power
Expand Down

0 comments on commit 249402b

Please sign in to comment.