File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 66import jax .numpy as jnp
77import numpy as np
88import pytest
9+ from jax import Array
910
1011import crazyflow .sim .functional as F
1112from crazyflow .control import Control
@@ -89,7 +90,7 @@ def test_functional_attitude_control(attitude_freq: int):
8990 # Check if we can apply inside of jax jit which does not permit device tracking. See
9091 # https://github.com/jax-ml/jax/issues/26000 for more context.
9192 @jax .jit
92- def apply_control (data : SimData , cmd : jnp . ndarray ) -> SimData :
93+ def apply_control (data : SimData , cmd : Array ) -> SimData :
9394 return F .attitude_control (data , cmd )
9495
9596 jax .block_until_ready (apply_control (data , cmd ))
@@ -103,7 +104,7 @@ def test_functional_attitude_control_device(device: str):
103104 cmd = np .random .rand (sim .n_worlds , sim .n_drones , 4 )
104105 data = F .attitude_control (data , cmd )
105106 controls = data .controls .attitude
106- assert isinstance (controls .staged_cmd , jnp . ndarray ), "Buffers must remain JAX arrays"
107+ assert isinstance (controls .staged_cmd , Array ), "Buffers must remain JAX arrays"
107108 assert jnp .all (controls .staged_cmd == cmd ), "Buffers must match command"
108109
109110
You can’t perform that action at this time.
0 commit comments