Skip to content

Commit fc6baff

Browse files
committed
Fix type hints
1 parent a32ad15 commit fc6baff

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

tests/unit/test_functional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import jax.numpy as jnp
77
import numpy as np
88
import pytest
9+
from jax import Array
910

1011
import crazyflow.sim.functional as F
1112
from crazyflow.control import Control
@@ -89,7 +90,7 @@ def test_functional_attitude_control(attitude_freq: int):
8990
# Check if we can apply inside of jax jit which does not permit device tracking. See
9091
# https://github.com/jax-ml/jax/issues/26000 for more context.
9192
@jax.jit
92-
def apply_control(data: SimData, cmd: jnp.ndarray) -> SimData:
93+
def apply_control(data: SimData, cmd: Array) -> SimData:
9394
return F.attitude_control(data, cmd)
9495

9596
jax.block_until_ready(apply_control(data, cmd))
@@ -103,7 +104,7 @@ def test_functional_attitude_control_device(device: str):
103104
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 4)
104105
data = F.attitude_control(data, cmd)
105106
controls = data.controls.attitude
106-
assert isinstance(controls.staged_cmd, jnp.ndarray), "Buffers must remain JAX arrays"
107+
assert isinstance(controls.staged_cmd, Array), "Buffers must remain JAX arrays"
107108
assert jnp.all(controls.staged_cmd == cmd), "Buffers must match command"
108109

109110

0 commit comments

Comments
 (0)