Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
# pixi environments
.pixi
*.egg-info
site
41 changes: 16 additions & 25 deletions docs/api/core.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

::: drone_controllers.core

The core module provides the foundational functionality for controller parametrization and registration.
The core module provides the foundational functionality for controller parametrization.

## Key Concepts

### Controller Parametrization

The `parametrize` function allows you to automatically configure controllers with parameters for specific drone models:
The `parametrize` function automatically configures a controller with parameters for a specific drone model by inspecting the function's keyword-only arguments and filling them from the corresponding TOML file:

```python
from drone_controllers import parametrize
Expand All @@ -18,38 +18,29 @@ from drone_controllers.mellinger import state2attitude
controller = parametrize(state2attitude, "cf2x_L250")

# Use the controller (all parameters are automatically filled in)
rpyt, pos_err = controller(pos, quat, vel, ang_vel, cmd)
rpyt, pos_err = controller(pos, quat, vel, cmd)
```

### Parameter Registry
### Manual Parameter Loading

Controllers register their parameter types using the `@register_controller_parameters` decorator:
Use `load_params` to inspect or override parameters directly:

```python
@register_controller_parameters(MyControllerParams)
def my_controller(pos, vel, *, param1, param2, param3):
# Controller implementation
pass
```

### ControllerParams Protocol
from drone_controllers.core import load_params

All controller parameter classes must implement the `ControllerParams` protocol:
params = load_params("mellinger", "state2attitude", "cf2x_L250")
print(params["mass"]) # 0.029
print(params["kp"]) # position gain array
```

- `load(drone_model: str)` - Load parameters for a specific drone model
- `_asdict()` - Convert parameters to a dictionary
### Array Namespace Support

## Example Usage
Both `parametrize` and `load_params` accept an `xp` argument so that static parameters are placed in the correct array namespace before being bound to the function:

```python
from functools import partial
from drone_controllers.mellinger.params import StateParams

# Manual parameter loading
params = StateParams.load("cf2x_L250")
controller = partial(state2attitude, **params._asdict())

# Equivalent to using parametrize
import jax.numpy as jnp
from drone_controllers import parametrize
controller = parametrize(state2attitude, "cf2x_L250")
from drone_controllers.mellinger import state2attitude

controller = parametrize(state2attitude, "cf2x_L250", xp=jnp)
```
3 changes: 3 additions & 0 deletions docs/api/drones.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Drones

::: drone_controllers.drones
34 changes: 7 additions & 27 deletions docs/api/mellinger.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from drone_controllers.mellinger import state2attitude

controller = parametrize(state2attitude, "cf2x_L250")

rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, cmd)
rpyt, pos_err_i = controller(pos, quat, vel, cmd)
```

### attitude2force_torque
Expand All @@ -33,7 +33,7 @@ from drone_controllers.mellinger import attitude2force_torque

controller = parametrize(attitude2force_torque, "cf2x_L250")

force, torque, att_err_i = controller(pos, quat, vel, ang_vel, rpyt_cmd)
force, torque, att_err_i = controller(quat, ang_vel, rpyt_cmd)
```

### force_torque2rotor_vel
Expand All @@ -52,26 +52,6 @@ controller = parametrize(force_torque2rotor_vel, "cf2x_L250")
rotor_speeds = controller(force, torque)
```

## Parameter Classes

### StateParams

::: drone_controllers.mellinger.params.StateParams

Parameters for the position control loop.

### AttitudeParams

::: drone_controllers.mellinger.params.AttitudeParams

Parameters for the attitude control loop.

### ForceTorqueParams

::: drone_controllers.mellinger.params.ForceTorqueParams

Parameters for the force/torque to rotor speed conversion.

## Complete Controller Pipeline

Here's how to use all three components together:
Expand Down Expand Up @@ -100,8 +80,8 @@ ang_vel = np.array([0.0, 0.0, 0.0])
cmd = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

# Run the complete pipeline
rpyt, pos_err_i = state_ctrl(pos, quat, vel, ang_vel, cmd)
force, torque, att_err_i = attitude_ctrl(pos, quat, vel, ang_vel, rpyt)
rpyt, pos_err_i = state_ctrl(pos, quat, vel, cmd)
force, torque, att_err_i = attitude_ctrl(quat, ang_vel, rpyt)
rotor_speeds = rotor_ctrl(force, torque)

print(f"Final rotor speeds: {rotor_speeds} rad/s")
Expand All @@ -121,10 +101,10 @@ for step in range(100):

# Pass previous integral errors
ctrl_errors = (pos_err_i,) if pos_err_i is not None else None
rpyt, pos_err_i = state_ctrl(pos, quat, vel, ang_vel, cmd, ctrl_errors=ctrl_errors)
rpyt, pos_err_i = state_ctrl(pos, quat, vel, cmd, ctrl_errors=ctrl_errors)

ctrl_errors = (att_err_i,) if att_err_i is not None else None
force, torque, att_err_i = attitude_ctrl(pos, quat, vel, ang_vel, rpyt, ctrl_errors=ctrl_errors)
force, torque, att_err_i = attitude_ctrl(quat, ang_vel, rpyt, ctrl_errors=ctrl_errors)

rotor_speeds = rotor_ctrl(force, torque)
```
Expand All @@ -145,7 +125,7 @@ quat_jax = jnp.array([0.0, 0.0, 0.0, 1.0])
# JIT compile the controller
jit_controller = jit(parametrize(state2attitude, "cf2x_L250"))

rpyt, pos_err_i = jit_controller(pos_jax, quat_jax, vel_jax, ang_vel_jax, cmd_jax)
rpyt, pos_err_i = jit_controller(pos_jax, quat_jax, vel_jax, cmd_jax)
```

# References
Expand Down
42 changes: 12 additions & 30 deletions docs/getting-started/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ ang_vel = np.array([0.0, 0.0, 0.0]) # Current angular velocity [wx, wy, wz]
cmd = np.array([1.0, 0.0, 1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

# Step 1: State to attitude control
rpyt_cmd, pos_error_integral = state_ctrl(pos, quat, vel, ang_vel, cmd)
rpyt_cmd, pos_error_integral = state_ctrl(pos, quat, vel, cmd)
print(f"Attitude command (R,P,Y,T): {rpyt_cmd}")

# Step 2: Attitude to force/torque
force, torque, att_error_integral = attitude_ctrl(pos, quat, vel, ang_vel, rpyt_cmd)
force, torque, att_error_integral = attitude_ctrl(quat, ang_vel, rpyt_cmd)
print(f"Desired force: {force[0]:.3f} N")
print(f"Desired torque: {torque}")

Expand Down Expand Up @@ -66,39 +66,23 @@ cmd_batch = np.zeros((*batch_shape, 13))
cmd_batch[..., :3] = pos_batch + np.random.randn(*batch_shape, 3) * 0.5 # Target positions

# Process entire batch at once
rpyt_batch, pos_err_batch = controller(pos_batch, quat_batch, vel_batch, ang_vel_batch, cmd_batch)
rpyt_batch, pos_err_batch = controller(pos_batch, quat_batch, vel_batch, cmd_batch)

print(f"Batch output shape: {rpyt_batch.shape}") # Should be (3, 5, 4)
print(f"Per-drone commands: {rpyt_batch[0, 0, :]}") # First drone, first timestep
```

## Manual Parameter Loading

You can also load parameters manually without using the `parametrize` decorator:
You can inspect or override parameters using `load_params`:

```python
import numpy as np
from functools import partial
from drone_controllers.mellinger import state2attitude
from drone_controllers.mellinger.params import StateParams

# Load parameters manually
params = StateParams.load("cf2x_L250")
print(f"Position gains: {params.kp}")
print(f"Velocity gains: {params.kd}")
print(f"Drone mass: {params.mass} kg")

# Create controller with custom parameters
controller = partial(state2attitude, **params._asdict())

# Use as before
pos = np.array([0.0, 0.0, 1.0])
quat = np.array([0.0, 0.0, 0.0, 1.0])
vel = np.array([0.0, 0.0, 0.0])
ang_vel = np.array([0.0, 0.0, 0.0])
cmd = np.ones(13)
from drone_controllers.core import load_params

rpyt, pos_err = controller(pos, quat, vel, ang_vel, cmd, ctrl_freq=100)
params = load_params("mellinger", "state2attitude", "cf2x_L250")
print(f"Position gains: {params['kp']}")
print(f"Velocity gains: {params['kd']}")
print(f"Drone mass: {params['mass']} kg")
```

## Array API Compatibility
Expand All @@ -124,7 +108,7 @@ controller = parametrize(state2attitude, "cf2x_L250")
from jax import jit
jit_controller = jit(controller)

rpyt, pos_err = jit_controller(pos, quat, vel, ang_vel, cmd)
rpyt, pos_err = jit_controller(pos, quat, vel, cmd)
print(f"Output type: {type(rpyt)}") # JAX array
```

Expand All @@ -143,16 +127,15 @@ controller = parametrize(state2attitude, "cf2x_L250")
pos = np.array([0.0, 0.0, 0.5])
quat = np.array([0.0, 0.0, 0.0, 1.0])
vel = np.array([0.0, 0.0, 0.0])
ang_vel = np.array([0.0, 0.0, 0.0])

# Target hover at 1m altitude
cmd = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

# First call - no integral error history
rpyt1, pos_err_i1 = controller(pos, quat, vel, ang_vel, cmd, ctrl_errors=None)
rpyt1, pos_err_i1 = controller(pos, quat, vel, cmd, ctrl_errors=None)

# Subsequent calls - pass integral error from previous step
rpyt2, pos_err_i2 = controller(pos, quat, vel, ang_vel, cmd, ctrl_errors=(pos_err_i1,))
rpyt2, pos_err_i2 = controller(pos, quat, vel, cmd, ctrl_errors=(pos_err_i1,))

print(f"Integral error evolution: {np.linalg.norm(pos_err_i1)} -> {np.linalg.norm(pos_err_i2)}")
```
Expand All @@ -176,7 +159,6 @@ for drone in Drones:

Now that you've seen the basics, explore:

- **[Concepts](../concepts/overview.md)** - Understand the theory behind the controllers
- **[API Reference](../api/core.md)** - Complete API documentation

## Common Issues
Expand Down
8 changes: 3 additions & 5 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,20 @@ controller = parametrize(state2attitude, "cf2x_L250")
pos = np.array([0.0, 0.0, 1.0]) # position [x, y, z]
quat = np.array([0.0, 0.0, 0.0, 1.0]) # quaternion [x, y, z, w]
vel = np.array([0.0, 0.0, 0.0]) # velocity [vx, vy, vz]
ang_vel = np.array([0.0, 0.0, 0.0]) # angular velocity [wx, wy, wz]

# Command: [x, y, z, vx, vy, vz, ax, ay, az, yaw, r_rate, p_rate, y_rate]
cmd = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

# Compute control output
rpyt, pos_err_i = controller(pos, quat, vel, ang_vel, cmd)
rpyt, pos_err_i = controller(pos, quat, vel, cmd)
print(f"Roll-Pitch-Yaw-Thrust command: {rpyt}")
```

## Key Features

### Implemented Controllers

- **[Mellinger Controller](api/drone_controllers/mellinger/control.md)** — Geometric tracking controller based on the original Crazyflie implementation
- **[Mellinger Controller](api/mellinger.md)** — Geometric tracking controller based on the original Crazyflie implementation

### Supported Drone Models

Expand Down Expand Up @@ -89,6 +88,5 @@ All controllers support the Python Array API standard, meaning you can use them
## Getting Help

- Read the [Getting Started](getting-started/installation.md) guide
- Browse the [API Reference](api/core.md)
- Check out [Concepts](concepts/overview.md) for theory
- Browse the [API Reference](api/core.md)
- Report issues on [GitHub](https://github.com/learnsyslab/drone-controllers/issues)
Loading
Loading