Source code for easy_robot_control.utils.trajectories

from dataclasses import dataclass
import traceback
from typing import Callable, Dict, Optional, TypeVar, Union
import numpy as np
from numpy.typing import NDArray
import quaternion as qt
from rclpy.time import Time
from scipy.spatial import geometric_slerp
from tf2_ros import Duration


@dataclass
class Point:
    """Timestamped waypoint whose coordinates and/or rotation may be unset.

    Both ``coord`` and ``rot`` default to ``None``; a ``None`` field means
    "not specified" and is filled in from another point by the interpolation
    helpers of this module.
    """

    time: Time  # timestamp attached to this waypoint
    coord: Optional[NDArray] = None  # coordinates, or None when unspecified
    rot: Optional[qt.quaternion] = None  # rotation, or None when unspecified
@dataclass
class PointDefined:
    """Timestamped waypoint whose coordinates and rotation are both required.

    Same fields as ``Point`` but without ``None`` defaults.
    """

    time: Time  # timestamp attached to this waypoint
    coord: NDArray  # coordinates (always set)
    rot: qt.quaternion  # rotation (always set)
def assessPointDefine(point: Point) -> PointDefined:
    """Convert a ``Point`` into a ``PointDefined``.

    Args:
        point: Waypoint whose ``coord`` and ``rot`` are expected to be set.

    Returns:
        A new ``PointDefined`` carrying the same time, coord and rot.

    Raises:
        AssertionError: If ``coord`` or ``rot`` is ``None``.
    """
    coord, rot = point.coord, point.rot
    assert coord is not None
    assert rot is not None
    return PointDefined(point.time, coord, rot)
TypeP = TypeVar("TypeP", float, np.ndarray, qt.quaternion) # Am I too used to C++ ??? no that can't be IFunGlob = Callable[[TypeP, float], TypeP] IFunTime = Callable[[float, float], float] # Ok maybe yes
def smooth(relative_end: float, t: float) -> float:
    """Cosine ease-in/ease-out between 0 and ``relative_end``.

    Over ``t`` in [0, 1] the output goes from 0 to ``relative_end`` with a
    zero derivative at both ends.

    Args:
        relative_end: Value reached at ``t == 1``.
        t: Progress along the interval.

    Returns:
        ``relative_end * (1 - cos(pi * t)) / 2``.
    """
    raised_cosine = 1 - np.cos(np.pi * t)
    return relative_end * raised_cosine / 2
def linear(relative_end: TypeP, t: float) -> TypeP:
    """Linearly interpolate from the neutral element towards ``relative_end``.

    Floats and arrays are scaled by ``t``. Quaternions are interpolated on the
    unit sphere: ``t`` is split into its integer part (handled with a
    quaternion power) and its fractional part (handled with a slerp from the
    identity quaternion), because ``geometric_slerp`` only accepts
    ``0 <= t <= 1``.

    Args:
        relative_end: Value reached at ``t == 1``.
        t: Progress; may lie outside [0, 1].

    Returns:
        The interpolated value, of the same kind as ``relative_end``.
    """
    if isinstance(relative_end, qt.quaternion):
        whole_turns = np.floor(t)
        fraction = t - whole_turns
        # geometric_slerp needs both ends as 1-D float arrays, so the identity
        # quaternion is converted the same way as relative_end.
        partial: NDArray = geometric_slerp(
            start=qt.as_float_array(qt.one),
            end=qt.as_float_array(relative_end),
            t=fraction,  # only works with 0<=t<=1
        )
        # Flatten to exactly 4 components so a single quaternion is built even
        # if geometric_slerp hands back a (1, 4) array.
        partial_quat = qt.as_quat_array(np.reshape(partial, 4))
        return partial_quat * relative_end**whole_turns
    # floats, ints and numpy arrays all scale the same way
    return relative_end * t  # type: ignore
def triangle(relative_end: float, t: float) -> float:
    """Triangle wave: 0 at ``t == 0``, ``relative_end`` at ``t == 0.5``, 0 at ``t == 1``.

    Args:
        relative_end: Height of the peak.
        t: Progress along the interval.

    Returns:
        ``(1 - 2 * |t - 0.5|) * relative_end`` as a Python float.
    """
    distance_from_peak = abs(t - 0.5)
    height = 1 - 2 * distance_from_peak
    return float(height * relative_end)
def boomrang(relative_end: float, t: float) -> float:
    """Half sine arch: leaves 0, peaks at ``relative_end`` for ``t == 0.5``, returns to 0.

    Args:
        relative_end: Height of the peak.
        t: Progress along the interval.

    Returns:
        ``sin(pi * t) * relative_end``.
    """
    return np.sin(np.pi * t) * relative_end
def singo(relative_end: float, t: float) -> float:
    """Quarter sine wave: fast start at ``t == 0``, flat arrival at ``t == 1``.

    Args:
        relative_end: Value reached at ``t == 1``.
        t: Progress along the interval.

    Returns:
        ``sin(pi * t / 2) * relative_end``.
    """
    quarter_turn = 0.5 * t * np.pi
    return np.sin(quarter_turn) * relative_end
def sincom(relative_end: float, t: float) -> float:
    """Evaluate the ``reverseFun`` mirror of ``singo``.

    Args:
        relative_end: Forwarded unchanged to the mirrored function.
        t: Progress along the interval.

    Returns:
        ``reverseFun(singo)(relative_end, t)``.
    """
    mirrored_singo = reverseFun(singo)
    return mirrored_singo(relative_end, t)
def prolongate(func: IFunTime) -> IFunTime:
    """Extend a shaping function defined on [0, 1] to any ``t``.

    ``t`` is split into an integer part and a fractional part. The fractional
    part is shaped by ``func`` on even integer parts and by its
    ``reverseFun`` mirror on odd ones; the integer part is then added and the
    sum is multiplied by ``r``.

    Args:
        func: Shaping function to extend.

    Returns:
        The extended function, with the same ``(r, t)`` signature.
    """
    mirrored = reverseFun(func)

    def newfunc(r: float, t: float) -> float:
        whole = np.floor(t)
        fraction = t - whole
        # alternate between the function and its mirror on each unit interval
        shaper = mirrored if whole % 2 == 1 else func
        return (shaper(r, fraction) + whole) * r

    return newfunc
def reverseFun(func: IFunTime) -> IFunTime:
    """Build the point-mirror of ``func`` about ``(t, value) = (0.5, 0.5)``.

    Args:
        func: Shaping function to mirror.

    Returns:
        A function computing ``-func(relative_end, 1 - t) + 1``.
    """

    def mirrored(relative_end, t):
        flipped = func(relative_end, 1 - t)
        return -flipped + 1

    return mirrored
# Shaping functions selectable by name. Apart from "linear" and "triangle",
# every entry is wrapped with prolongate().
temporal: Dict[str, IFunTime] = {
    "linear": linear,
    "triangle": triangle,
    # triangle() output is used as the time argument of the prolonged singo
    "boomrang": lambda r,t: prolongate(singo)(r, triangle(r, t)),
    "cos": prolongate(sincom),
    "sin": prolongate(singo),
    "smooth": prolongate(smooth),
    # smooth applied twice: the inner result is the time of the outer call
    "smooth2": prolongate(lambda x, y: smooth(x, smooth(x, y))),
}
# Interpolation functions selectable by name for coordinates and rotations.
spatial: Dict[str, IFunGlob] = {
    "linear": linear,
}
def get_interp(
    interp_str: str, start: Point, end: Point
) -> Callable[[Time], PointDefined]:
    """Build a function of time that interpolates from ``start`` to ``end``.

    The same name selects both the spatial and the temporal function, so it
    must be a key of both the ``spatial`` and ``temporal`` registries.

    Args:
        interp_str: Name of the interpolation to use.
        start: Starting waypoint. A missing ``coord`` defaults to a zero
            vector of length 3 and a missing ``rot`` to the identity
            quaternion. ``start`` itself is left unmodified.
        end: Target waypoint; ``None`` fields are handled by
            ``globalInterpolator``.

    Returns:
        A callable mapping a ``Time`` to the interpolated ``PointDefined``.

    Raises:
        AssertionError: If ``interp_str`` is missing from either registry.
    """
    assert interp_str in spatial
    assert interp_str in temporal

    # Fill the defaults into a fresh object instead of writing them back into
    # the caller's Point.
    definedStart = PointDefined(
        time=start.time,
        coord=np.zeros(3, dtype=float) if start.coord is None else start.coord,
        rot=qt.one.copy() if start.rot is None else start.rot,
    )

    def func(t: Time) -> PointDefined:
        return globalInterpolator(
            definedStart,
            end,
            spatial[interp_str],
            temporal[interp_str],
            t,
        )

    return func
def _saturated_end(start: PointDefined, end: Point) -> PointDefined:
    """Return the end waypoint, with ``None`` fields taken from ``start``."""
    return PointDefined(
        time=end.time,
        coord=start.coord if end.coord is None else end.coord,
        rot=start.rot if end.rot is None else end.rot,
    )


def globalInterpolator(
    start: PointDefined,
    end: Point,
    spatial_interp: IFunGlob,
    temporal_interp: IFunGlob,
    now: Time,
) -> PointDefined:
    """Interpolate between ``start`` and ``end`` at time ``now``.

    Args:
        start: Fully defined starting waypoint.
        end: Target waypoint; a ``None`` coord or rot means "same as start".
        spatial_interp: Function applied to the coordinates and rotation.
        temporal_interp: Function called as ``temporal_interp(1, progress)``
            to reshape the normalised time before the spatial step.
        now: Time at which to evaluate the trajectory.

    Returns:
        ``start`` itself when ``now`` is before ``start.time``; the end
        waypoint (timestamped ``end.time``) when ``now`` is after
        ``end.time`` or the trajectory has zero duration; otherwise a new
        waypoint timestamped ``now``. Neither ``start`` nor ``end`` is
        modified.

    Raises:
        AssertionError: If the computed progress falls outside [0, 1].
    """
    if now < start.time:
        return start
    if now > end.time:
        return _saturated_end(start, end)

    elapsed: Duration = now - start.time
    total: Duration = end.time - start.time
    if total.nanoseconds == 0:
        # start and end share the same timestamp: nothing to interpolate and
        # the progress ratio below would divide by zero.
        return _saturated_end(start, end)

    progress: float = elapsed.nanoseconds / total.nanoseconds
    assert 0 <= progress <= 1

    t = temporal_interp(1, progress)
    coord = coord_interp(start.coord, end.coord, spatial_interp, t)
    rot = quat_interp(start.rot, end.rot, spatial_interp, t)
    return PointDefined(
        time=now,
        coord=coord,
        rot=rot,
    )
def time_interp(ifunc: IFunGlob, t: float):
    """Evaluate ``ifunc`` with a relative end of 1.

    Args:
        ifunc: Interpolation function to evaluate.
        t: Progress passed to ``ifunc``.

    Returns:
        ``ifunc(1, t)``.
    """
    return ifunc(1, t)
def coord_interp(
    start: NDArray, end: Optional[NDArray], ifunc: "IFunGlob", t: float
) -> NDArray:
    """Interpolate coordinates between ``start`` and ``end``.

    Args:
        start: Coordinates at the beginning of the motion.
        end: Target coordinates, or ``None`` to stay at ``start``.
        ifunc: Interpolation function, called as ``ifunc(end - start, t)``.
        t: Progress passed to ``ifunc``.

    Returns:
        ``start`` when ``end`` is ``None``; otherwise ``start`` plus the
        interpolated displacement, so the result is expressed in the same
        frame as ``start`` and ``end``.
    """
    if end is None:
        return start
    relative_end = end - start
    # ifunc only yields the displacement travelled so far; add the origin back
    # so the result is absolute like the `end is None` branch.
    return start + ifunc(relative_end, t)
def quat_interp(
    start: qt.quaternion, end: Optional[qt.quaternion], ifunc: IFunGlob, t: float
):
    """Interpolate a rotation between ``start`` and ``end``.

    Args:
        start: Rotation at the beginning of the motion.
        end: Target rotation, or ``None`` to stay at ``start``.
        ifunc: Interpolation function, called as ``ifunc(end / start, t)``.
        t: Progress passed to ``ifunc``.

    Returns:
        ``start`` when ``end`` is ``None``; otherwise the interpolated
        relative rotation composed with ``start``, so the result is an
        absolute rotation like ``start`` and ``end``.
    """
    if end is None:
        return start
    relative_end = end / start
    # relative_end * start == end, so the partial relative rotation is
    # composed with start on the same (right-hand) side.
    return ifunc(relative_end, t) * start
class Interpolator:
    """Trajectory between two waypoints, evaluated on demand.

    Attributes:
        start: Fully defined starting waypoint.
        end: Target waypoint.
        spatial_interp: Function applied to coordinates and rotation.
        temporal_interp: Function reshaping the normalised time.
        last: Result of ``compute`` evaluated at ``end.time``.
    """

    def __init__(
        self,
        start: PointDefined,
        end: Point,
        spatial_interp: Union[None, IFunGlob, str] = None,
        temporal_interp: Union[None, IFunGlob, str] = None,
    ) -> None:
        """Store the waypoints and resolve both interpolation functions.

        Args:
            start: Fully defined starting waypoint.
            end: Target waypoint.
            spatial_interp: Callable, key of the ``spatial`` registry, or
                ``None`` for ``linear``.
            temporal_interp: Callable, key of the ``temporal`` registry, or
                ``None`` for ``linear``.

        Raises:
            AssertionError: If a name is not a key of its registry.
        """
        self.start: PointDefined = start
        self.end: Point = end
        # Each name is checked against its own registry.
        self.temporal_interp: IFunGlob = self._resolve(temporal_interp, temporal)
        self.spatial_interp: IFunGlob = self._resolve(spatial_interp, spatial)
        self.last: PointDefined = self.compute(now=end.time)

    @staticmethod
    def _resolve(choice, registry):
        """Turn ``None`` / a registry key / a callable into a callable."""
        if choice is None:
            return linear
        if isinstance(choice, str):
            assert choice in registry
            return registry[choice]
        return choice

    def expired(self, now: Time) -> bool:
        """Return True when ``now`` is strictly after the end waypoint's time."""
        return self.end.time < now

    def early(self, now: Time) -> bool:
        """Return True when ``now`` is strictly before the start waypoint's time."""
        return self.start.time > now

    def compute(self, now: Time) -> PointDefined:
        """Evaluate the trajectory at ``now`` through ``globalInterpolator``."""
        return globalInterpolator(
            self.start,
            self.end,
            self.spatial_interp,
            self.temporal_interp,
            now,
        )
if __name__ == "__main__":
    # Debug helper: plot every registered temporal shaper on one figure.
    import matplotlib.pyplot as plt

    col = 3
    row = (len(temporal.keys()) + col - 1) // col
    print(len(temporal.keys()), row, col)
    # squeeze=False keeps `axes` two-dimensional even with a single row
    fig, axes = plt.subplots(
        row,
        col,
        figsize=(6 * col, 3 * row),
        squeeze=False,
    )
    # the sampling grid extends past [0, 1] to show how each shaper prolongs
    x = np.linspace(-1, 2, 1000, dtype=float)
    for idx, (k, shaper) in enumerate(temporal.items()):
        try:
            y = np.array([shaper(1, n) for n in x])
        except Exception:
            # keep plotting the remaining shapers, but report the failure
            print(f"FAILED {k}")
            m = f"Exception intercepted {traceback.format_exc()}"
            print(m)
            continue
        ax = axes[idx // col, idx % col]
        ax.plot(x, y)
        ax.grid()
        ax.set_title(f"{k}")
        # outline the nominal unit square (t in [0, 1], output in [0, 1])
        rectangle = plt.Rectangle(
            (0, 0), 1, 1, fill=None, edgecolor="black", linewidth=2
        )
        ax.add_patch(rectangle)
        print(f"done {k}.png")
    plt.savefig("shapers.png")
    # these were bare attribute references before; they must be called
    plt.cla()
    plt.clf()
    plt.close()