Source code for easy_robot_control.utils.hyper_sphere_clamp

"""
Vectorized functions to clamp onto a hypersphere

Author: Elian NEPPEL
Lab: SRL, Moonshot team
"""

from typing import Any, Tuple

import nptyping as nt
import numpy as np
import quaternion as qt
from nptyping import NDArray, Shape

from easy_robot_control.utils.math import Quaternion, qt_normalize

# Float array of any shape (not referenced in the visible part of this file).
Farr = NDArray[Any, nt.Float]
# Boolean array of any shape (not referenced in the visible part of this file).
Barr = NDArray[Any, nt.Bool]
# Floating-point array with exactly 3 elements; used below for xyz coordinates.
Flo3 = NDArray[Shape["3"], nt.Floating]

SAMPLING_STEP = 0.01  # will sample every 0.01 for a unit hypersphere
# if you use the radii, it is equivalent to sampling every 0.01 * radii


def clamp_to_unit_hs(
    start: NDArray[Shape["N"], nt.Float],
    end: NDArray[Shape["N"], nt.Float],
    sampling_step: float = SAMPLING_STEP,
) -> NDArray[Shape["N"], nt.Float]:
    """Finds the farthest point on the segment that is inside the unit hypersphere.

    The segment going from ``start`` to ``end`` is sampled at regular
    intervals (roughly ``sampling_step`` apart, both endpoints always
    included). The last sample whose norm is strictly below 1 is returned.

    Args:
        start: first point of the segment, 1-D array of length N > 0.
        end: last point of the segment, same shape as ``start``.
        sampling_step: approximate distance between two consecutive samples.

    Returns:
        ``end`` if every sample is inside the hypersphere, ``start`` if no
        sample is inside (even when ``start`` itself lies outside), otherwise
        the last sample found inside.

    Raises:
        ValueError: if ``sampling_step`` is not strictly positive.
    """
    assert start.shape == end.shape
    assert len(start.shape) == 1
    assert start.shape[0] > 0
    # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
    if not sampling_step > 0:
        raise ValueError(f"sampling_step must be > 0, got {sampling_step}")
    dimensionality: int = start.shape[0]

    # Always keep at least two samples. With a single sample only `start`
    # would be tested, and a segment shorter than `sampling_step` starting
    # inside the hypersphere would return `end` without ever checking it.
    sample_count: int = max(
        2, int(np.linalg.norm(end - start) / sampling_step) + 1
    )
    t = np.linspace(0, 1, sample_count, endpoint=True).reshape(-1, 1)
    # Linear interpolation: one row per sample, from `start` (t=0) to `end` (t=1).
    interp = end * t + start * (1 - t)
    inside_hyper = np.linalg.norm(interp, axis=1) < 1
    assert inside_hyper.shape[0] == sample_count

    selection = interp[inside_hyper]
    if selection.shape[0] == 0:
        # Nothing is inside: no better candidate than the starting point.
        return start
    elif selection.shape[0] == sample_count:
        # Every sample, `end` included, is inside.
        return end

    furthest = selection[-1, :]
    assert furthest.shape[0] == dimensionality
    return furthest
def clamp_to_sqewed_hs(
    center: NDArray[Shape["N"], nt.Floating],
    start: NDArray[Shape["N"], nt.Floating],
    end: NDArray[Shape["N"], nt.Floating],
    radii: NDArray[Shape["N"], nt.Floating],
) -> NDArray[Shape["N"], nt.Floating]:
    """Finds the farthest point on the segment that is inside the skewed hypersphere.

    The hypersphere has its own radius in each dimension. The problem is
    solved by stretching the space so the shape becomes the unit hypersphere
    centered on the origin, clamping there with ``clamp_to_unit_hs``, then
    undoing the stretch.

    Args:
        center: center of the skewed hypersphere, 1-D array.
        start: first point of the segment, same shape as ``center``.
        end: last point of the segment, same shape as ``center``.
        radii: radius along each dimension, same shape as ``center``. Points
            are divided by it, so zero entries produce inf/NaN values.

    Returns:
        The clamped point, expressed in the original (unstretched) space.
    """
    assert len(center.shape) == 1
    assert all(arr.shape == center.shape for arr in (start, end, radii))

    def to_unit_space(point):
        # Re-center on the origin, then scale every axis to a radius of 1.
        return (point - center) / radii

    clamped_unit = clamp_to_unit_hs(
        start=to_unit_space(start), end=to_unit_space(end)
    )
    # Back to the original space: undo the scaling, then the re-centering.
    result = clamped_unit * radii + center
    assert end.shape == result.shape
    return result
# Layout of the fused pose vector: indices of the xyz coordinate come first,
# indices of the quaternion's four floats follow.
xyz_slots = list(range(3))
quat_slots = list(range(len(xyz_slots), len(xyz_slots) + 4))
dims = len(xyz_slots) + len(quat_slots)  # total length of the fused vector
def fuse_xyz_quat(xyz: Flo3, quat: Quaternion) -> NDArray[Shape["7"], nt.Floating]:
    """Packs a coordinate and a quaternion into one flat float array.

    The quaternion goes through ``qt_normalize`` before being stored.

    Args:
        xyz: coordinate, written at the ``xyz_slots`` indices.
        quat: quaternion, written as floats at the ``quat_slots`` indices.

    Returns:
        Array of shape ``(dims,)`` holding the coordinate then the quaternion.
    """
    quat_components = qt.as_float_array(qt_normalize(quat))
    # Every slot is overwritten below, so uninitialized memory is fine here.
    packed = np.empty(dims, dtype=float)
    packed[xyz_slots] = xyz
    packed[quat_slots] = quat_components
    assert packed.shape == (dims,)
    return packed
def unfuse_xyz_quat(arr: NDArray[Shape["7"], nt.Floating]) -> Tuple[Flo3, Quaternion]:
    """Splits a fused array back into its coordinate and quaternion.

    Inverse layout of ``fuse_xyz_quat``: the ``xyz_slots`` entries are
    returned as the coordinate, the ``quat_slots`` entries are converted to a
    quaternion and passed through ``qt_normalize``.

    Args:
        arr: fused array indexed by ``xyz_slots`` and ``quat_slots``.

    Returns:
        Tuple ``(xyz, quaternion)``.
    """
    orientation = qt_normalize(qt.from_float_array(arr[quat_slots]))
    position = arr[xyz_slots]
    return position, orientation
def clamp_xyz_quat(
    center: Tuple[Flo3, Quaternion],
    start: Tuple[Flo3, Quaternion],
    end: Tuple[Flo3, Quaternion],
    radii: Tuple[float, float],
) -> Tuple[Flo3, Quaternion]:
    """Wrapper for clamp_to_sqewed_hs specialized in one 3D coordinate + one quaternion.

    Each (xyz, quaternion) pair is fused into a single 7-element vector, the
    segment is clamped in that 7-D space, and the result is split back.

    The math for the quaternion is wrong (lerp instead of slerp). So:
        Center and start quat should not be opposite from each other.
        Precision goes down if they are far apart.

    Args:
        center: center from which not to diverge
        start: start point of the interpolation
        end: end point of the interpolation
        radii: allowed divergence for coord and quat, as
            ``(xyz_radius, quat_radius)``

    Returns:
        The clamped ``(xyz, quaternion)`` pair.
    """
    xyz_radius, quat_radius = radii

    # One radius per fused dimension: the xyz radius on every coordinate
    # slot, the quaternion radius on every quaternion slot.
    per_dim_radii = np.array(
        [xyz_radius for _ in xyz_slots] + [quat_radius for _ in quat_slots],
        dtype=float,
    )

    center_vec = fuse_xyz_quat(*center)
    end_vec = fuse_xyz_quat(*end)
    start_vec = fuse_xyz_quat(*start)

    clamped_vec = clamp_to_sqewed_hs(center_vec, start_vec, end_vec, per_dim_radii)
    clamped_xyz, clamped_quat = unfuse_xyz_quat(arr=clamped_vec)

    assert clamped_xyz.shape == (3,)
    assert qt.as_float_array(clamped_quat).shape == (4,)
    return clamped_xyz, clamped_quat