From a3f2cfafbbffb15fbaca534f65ebac8816cad36c Mon Sep 17 00:00:00 2001 From: YilingQiao Date: Sat, 21 Dec 2024 17:55:09 -0500 Subject: [PATCH] add api for set global solver parameters --- .../solvers/rigid/constraint_solver_decomp.py | 12 ++------- .../rigid/constraint_solver_decomp_island.py | 11 ++------ .../solvers/rigid/rigid_solver_decomp.py | 25 ++++++++++++++++++- genesis/utils/geom.py | 4 +-- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index 592200e..c680c31 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -168,10 +168,7 @@ def add_collision_constraints(self): if ti.static(self.sparse_solve): self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs - - imp, aref = gu.imp_aref( - self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel - ) + imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel) diag = t + impact.friction * impact.friction * t diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS) @@ -199,12 +196,7 @@ def add_joint_limit_constraints(self): jac = side jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel - imp, aref = gu.imp_aref( - self._solver._sol_contact_resolve_time, - self._solver.dofs_info[i_d].sol_params, - pos, - jac_qvel, - ) + imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel) diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS) aref = aref * (pos < 0) if pos < 0: diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py index 1e67dc6..e33867a 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py @@ -195,9 +195,7 @@ def add_collision_constraints(self, island, i_b): if ti.static(self.sparse_solve): self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs - imp, aref = gu.imp_aref( - self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel - ) + imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel) diag = t + impact.friction * impact.friction * t diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS) @@ -235,12 +233,7 @@ def add_joint_limit_constraints(self, island, i_b): jac = side jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel - imp, aref = gu.imp_aref( - self._solver._sol_contact_resolve_time, - self._solver.dofs_info[i_d].sol_params, - pos, - jac_qvel, - ) + imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel) diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS) aref = aref * (pos < 0) if pos < 0: diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 0e0786b..a6abeb5 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -351,7 +351,7 @@ def _kernel_init_dof_fields( for j in ti.static(range(7)): self.dofs_info[i].sol_params[j] = dofs_sol_params[i, j] - self.dofs_info[i].sol_params[0] = self._substep_dt * 2 + self.dofs_info[i].sol_params[0] = self._sol_contact_resolve_time self.dofs_info[i].armature = dofs_armature[i] self.dofs_info[i].invweight = dofs_invweight[i] @@ -776,6 +776,7 @@ def _kernel_init_geom_fields( for j in ti.static(range(7)): self.geoms_info[i].data[j] = geoms_data[i, j] self.geoms_info[i].sol_params[j] = geoms_sol_params[i, j] + self.geoms_info[i].sol_params[0] = self._sol_contact_resolve_time self.geoms_info[i].sol_params[0] = ti.max(self.geoms_info[i].sol_params[0], self._substep_dt * 2) @@ -3319,6 +3320,28 @@ def _kernel_set_qpos( for i_q_, i_b_ in ti.ndrange(qs_idx.shape[0], envs_idx.shape[0]): self.qpos[qs_idx[i_q_], envs_idx[i_b_]] = qpos[i_b_, i_q_] + def set_global_sol_params(self, sol_params): + """ + Solver parameters (timeconst, dampratio, dmin, dmax, width, mid, power). + Reference: https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters + """ + assert len(sol_params) == 7 + self._kernel_set_global_sol_params(sol_params) + + @ti.kernel + def _kernel_set_global_sol_params(self, sol_params: ti.types.ndarray()): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + for i in range(self.n_geoms): + for j in ti.static(range(7)): + self.geoms_info[i].sol_params[j] = sol_params[j] + + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL) + for i, b in ti.ndrange(self.n_dofs, self._B): + for j in ti.static(range(7)): + self.dofs_info[i].sol_params[j] = sol_params[j] + + self.dofs_info[i].sol_params[0] = self._substep_dt * 2 + def set_dofs_kp(self, kp, dofs_idx): kp, dofs_idx = self._validate_1D_io_variables(kp, dofs_idx, batched=False) self._kernel_set_dofs_kp(kp, dofs_idx) diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index c272b1f..acfd651 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -306,9 +306,9 @@ def orthogonals2(a): @ti.func -def imp_aref(timeconst, params, pos, vel): +def imp_aref(params, pos, vel): # The first term in parms is the timeconst parsed from mjcf. However, we don't use it here but use the one passed in, which is 2*substep_dt. - _, dampratio, dmin, dmax, width, mid, power = params + timeconst, dampratio, dmin, dmax, width, mid, power = params imp_x = ti.abs(pos) / width imp_a = (1.0 / mid ** (power - 1)) * imp_x**power imp_b = 1 - (1.0 / (1 - mid) ** (power - 1)) * (1 - imp_x) ** power