Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add api for set global solver parameters #223

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions genesis/engine/solvers/rigid/constraint_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,7 @@ def add_collision_constraints(self):

if ti.static(self.sparse_solve):
self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs

imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel
)
imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel)

diag = t + impact.friction * impact.friction * t
diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS)
Expand Down Expand Up @@ -199,12 +196,7 @@ def add_joint_limit_constraints(self):

jac = side
jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel
imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time,
self._solver.dofs_info[i_d].sol_params,
pos,
jac_qvel,
)
imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel)
diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS)
aref = aref * (pos < 0)
if pos < 0:
Expand Down
11 changes: 2 additions & 9 deletions genesis/engine/solvers/rigid/constraint_solver_decomp_island.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def add_collision_constraints(self, island, i_b):
if ti.static(self.sparse_solve):
self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs

imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel
)
imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel)

diag = t + impact.friction * impact.friction * t
diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS)
Expand Down Expand Up @@ -235,12 +233,7 @@ def add_joint_limit_constraints(self, island, i_b):

jac = side
jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel
imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time,
self._solver.dofs_info[i_d].sol_params,
pos,
jac_qvel,
)
imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel)
diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS)
aref = aref * (pos < 0)
if pos < 0:
Expand Down
25 changes: 24 additions & 1 deletion genesis/engine/solvers/rigid/rigid_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _kernel_init_dof_fields(
for j in ti.static(range(7)):
self.dofs_info[i].sol_params[j] = dofs_sol_params[i, j]

self.dofs_info[i].sol_params[0] = self._substep_dt * 2
self.dofs_info[i].sol_params[0] = self._sol_contact_resolve_time

self.dofs_info[i].armature = dofs_armature[i]
self.dofs_info[i].invweight = dofs_invweight[i]
Expand Down Expand Up @@ -776,6 +776,7 @@ def _kernel_init_geom_fields(
for j in ti.static(range(7)):
self.geoms_info[i].data[j] = geoms_data[i, j]
self.geoms_info[i].sol_params[j] = geoms_sol_params[i, j]
self.geoms_info[i].sol_params[0] = self._sol_contact_resolve_time

self.geoms_info[i].sol_params[0] = ti.max(self.geoms_info[i].sol_params[0], self._substep_dt * 2)

Expand Down Expand Up @@ -3319,6 +3320,28 @@ def _kernel_set_qpos(
for i_q_, i_b_ in ti.ndrange(qs_idx.shape[0], envs_idx.shape[0]):
self.qpos[qs_idx[i_q_], envs_idx[i_b_]] = qpos[i_b_, i_q_]

def set_global_sol_params(self, sol_params):
"""
Solver parameters (timeconst, dampratio, dmin, dmax, width, mid, power).
Reference: https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
"""
assert len(sol_params) == 7
self._kernel_set_global_sol_params(sol_params)

@ti.kernel
def _kernel_set_global_sol_params(self, sol_params: ti.types.ndarray()):
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
for i in range(self.n_geoms):
for j in ti.static(range(7)):
self.geoms_info[i].sol_params[j] = sol_params[j]

ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
for i, b in ti.ndrange(self.n_dofs, self._B):
for j in ti.static(range(7)):
self.dofs_info[i].sol_params[j] = sol_params[j]

self.dofs_info[i].sol_params[0] = self._substep_dt * 2

def set_dofs_kp(self, kp, dofs_idx):
kp, dofs_idx = self._validate_1D_io_variables(kp, dofs_idx, batched=False)
self._kernel_set_dofs_kp(kp, dofs_idx)
Expand Down
4 changes: 2 additions & 2 deletions genesis/utils/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,9 @@ def orthogonals2(a):


@ti.func
def imp_aref(timeconst, params, pos, vel):
def imp_aref(params, pos, vel):
# The first term in parms is the timeconst parsed from mjcf. However, we don't use it here but use the one passed in, which is 2*substep_dt.
_, dampratio, dmin, dmax, width, mid, power = params
timeconst, dampratio, dmin, dmax, width, mid, power = params
imp_x = ti.abs(pos) / width
imp_a = (1.0 / mid ** (power - 1)) * imp_x**power
imp_b = 1 - (1.0 / (1 - mid) ** (power - 1)) * (1 - imp_x) ** power
Expand Down