diff --git a/embodichain/lab/sim/solvers/base_solver.py b/embodichain/lab/sim/solvers/base_solver.py index c7fc70f2..c5a26e5d 100644 --- a/embodichain/lab/sim/solvers/base_solver.py +++ b/embodichain/lab/sim/solvers/base_solver.py @@ -172,9 +172,13 @@ def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs): device=self.device, ) + def _fk_end_matrix(th): + return self.pk_serial_chain.forward_kinematics( + th, end_only=True + ).get_matrix() + self.compiled_fk = torch.compile( - self.pk_serial_chain.forward_kinematics_tensor, - fullgraph=True, + _fk_end_matrix, dynamic=True, ) @@ -433,7 +437,7 @@ def get_fk(self, qpos: torch.tensor, **kwargs) -> torch.Tensor: logger.log_error("Kinematic chain is not initialized.") return torch.eye(4, device=self.device) # Compute forward kinematics - ee_link_xpos = self.compiled_fk(qpos)[-1, :, :, :] + ee_link_xpos = self.compiled_fk(qpos) # Ensure batch format for TCP batch_size = qpos.shape[0]