Source code for motion_stack.core.utils.joint_mapper

import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, overload

from .joint_state import JState

SubShaper = Optional[Callable[[float], float]]


[docs] def operate_sub_shapers( shaper1: SubShaper, shaper2: SubShaper, op: Callable[[float, float], float] ) -> SubShaper: if shaper1 and shaper2: return lambda x: op(shaper1(x), shaper2(x)) return shaper1 or shaper2
[docs] def eggify_shapers(inner: SubShaper, outer: SubShaper) -> SubShaper: if inner and outer: return lambda x: outer(inner(x)) return inner or outer
[docs] @dataclass class Shaper: """Holds and applies functions to position, velocity and effort fields. If None, the indentity is used. """ position: SubShaper = None velocity: SubShaper = None effort: SubShaper = None def _combine(self, other: "Shaper", op: Callable) -> "Shaper": return Shaper( position=operate_sub_shapers(self.position, other.position, op), velocity=operate_sub_shapers(self.velocity, other.velocity, op), effort=operate_sub_shapers(self.effort, other.effort, op), ) # Arithmetic operations def __add__(self, other: "Shaper") -> "Shaper": return self._combine(other, operator.add) def __sub__(self, other: "Shaper") -> "Shaper": return self._combine(other, operator.sub) def __mul__(self, other: "Shaper") -> "Shaper": return self._combine(other, operator.mul) def __truediv__(self, other: "Shaper") -> "Shaper": return NotImplemented # return self._combine(other, operator.truediv) @overload def __call__(self, other: "Shaper") -> "Shaper": ... @overload def __call__(self, other: JState) -> None: ... def __call__(self, other: Union["Shaper", JState]) -> Union["Shaper", None]: if isinstance(other, Shaper): return Shaper( position=eggify_shapers(other.position, self.position), velocity=eggify_shapers(other.velocity, self.velocity), effort=eggify_shapers(other.effort, self.effort), ) elif isinstance(other, JState): apply_shaper(other, self) return else: return NotImplemented
URDFJointName = str NameMap = Dict[URDFJointName, URDFJointName] StateMap = Dict[URDFJointName, Shaper]
[docs] def reverse_dict(d: Dict) -> Dict: return dict(zip(d.values(), d.keys()))
[docs] def remap_names(states: List[JState], mapping: NameMap): names_in: List[Optional[str]] = list(map(lambda s: s.name, states)) shared = set(names_in) & set(mapping.keys()) for name in shared: if name is None: continue ind = names_in.index(name) new_name = mapping.get(name) if new_name is not None: states[ind].name = new_name
[docs] def apply_shaper(state: JState, shaper: Shaper): for attr in shaper.__annotations__.keys(): sub_shaper: SubShaper = getattr(shaper, attr, None) if sub_shaper is None: continue sub_state: Optional[float] = getattr(state, attr, None) if sub_state is None: continue setattr(state, attr, sub_shaper(sub_state))
# print(f"{attr}: {sub_state} -> {sub_shaper(sub_state)}")
[docs] def shape_states(states: List[JState], mapping: StateMap): names_in: List[Optional[str]] = list(map(lambda s: s.name, states)) shared = set(names_in) & set(mapping.keys()) for name in shared: if name is None: continue ind = names_in.index(name) shaper = mapping.get(name) if shaper is not None: shaper(states[ind])