
a-paulus/softjax



Soft differentiable programming in JAX


Looking for PyTorch? See SoftTorch.

What is SoftJAX?

SoftJAX provides soft, differentiable drop-in replacements for traditionally non-differentiable functions in JAX, including:

  • elementwise operators: abs, relu, clip, sign, round and heaviside;
  • array-valued operators: (arg)max, (arg)min, (arg)quantile, (arg)median, (arg)sort, (arg)top_k and rank;
  • comparison operators such as greater, equal or isclose;
  • logical operators such as logical_and, all or any;
  • selection operators such as where, take_along_axis, dynamic_index_in_dim or choose.
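To illustrate why such replacements are needed: in JAX, `jnp.round` has zero gradient almost everywhere, so nothing upstream of it can learn. The sketch below uses one known smooth relaxation, a tanh-based soft rounding that stays continuous at the integers. The function `soft_round` and its `alpha` parameter are hypothetical and chosen for illustration; SoftJAX's own `round` modes may be implemented differently.

```python
import jax
import jax.numpy as jnp

def soft_round(x, alpha=20.0):
    """Hypothetical tanh-based soft rounding (not SoftJAX's implementation).

    Returns floor(x) as frac(x) -> 0 and floor(x) + 1 as frac(x) -> 1,
    so it is continuous across integers. Larger alpha -> closer to jnp.round.
    """
    base = jnp.floor(x)                 # floor itself has zero gradient
    r = x - base - 0.5                  # offset from the midpoint, in [-0.5, 0.5)
    return base + 0.5 + jnp.tanh(alpha * r) / (2.0 * jnp.tanh(alpha / 2.0))

x = 0.4
print(jnp.round(x), jax.grad(jnp.round)(x))    # hard: value 0.0, gradient 0.0
print(soft_round(x), jax.grad(soft_round)(x))  # soft: value close to 0, gradient > 0
```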

All operators offer multiple modes and an adjustable softening strength, so the soft function can be tailored to the user's needs, e.g. to be smooth everywhere or to confine the softening to a bounded region.
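The trade-off between smoothness and a bounded softened region can be seen on the heaviside step. The minimal sketch below assumes two common softenings, a logistic sigmoid and a piecewise-linear ramp, with a temperature `tau` as the strength knob; the function names and `tau` are hypothetical and not SoftJAX's mode names.

```python
import jax
import jax.numpy as jnp

def heaviside_smooth(x, tau=0.1):
    # Logistic sigmoid: infinitely smooth, but never exactly 0 or 1,
    # so every input receives some (possibly tiny) gradient.
    return jax.nn.sigmoid(x / tau)

def heaviside_bounded(x, tau=0.1):
    # Piecewise-linear ramp: exactly 0 below -tau and exactly 1 above tau,
    # so softening (and nonzero gradient) is confined to [-tau, tau].
    return jnp.clip(0.5 + 0.5 * x / tau, 0.0, 1.0)

xs = jnp.array([-1.0, -0.05, 0.0, 0.05, 1.0])
for f in (heaviside_smooth, heaviside_bounded):
    values = f(xs)
    grads = jax.vmap(jax.grad(f))(xs)  # elementwise derivative
    print(f.__name__, values, grads)
```

Shrinking `tau` makes both functions approach the hard step; only the bounded variant returns exact 0/1 values outside its softened region.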

Moreover, we tightly integrate support for straight-through estimation, which uses the non-differentiable functions in the forward pass and their differentiable replacements in the backward pass.
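The usual way to express straight-through estimation in JAX is the `stop_gradient` trick: the hard value enters the forward pass, while only the soft function contributes to the derivative. A minimal sketch follows; the helper `straight_through` is hypothetical and is not the `sj.st` decorator used further below, whose internals may differ.

```python
import jax
import jax.numpy as jnp

def straight_through(hard_fn, soft_fn):
    """Hypothetical helper: forward value of hard_fn, gradient of soft_fn."""
    def wrapped(x):
        soft = soft_fn(x)
        # soft + stop_gradient(hard - soft) evaluates to hard_fn(x),
        # but its derivative is that of soft_fn alone.
        return soft + jax.lax.stop_gradient(hard_fn(x) - soft)
    return wrapped

hard_step = lambda x: jnp.where(x > 0, 1.0, 0.0)
soft_step = lambda x: jax.nn.sigmoid(x / 0.1)
step_st = straight_through(hard_step, soft_step)

x = 0.05
print(step_st(x))            # the hard value 1.0 (up to float rounding)
print(jax.grad(step_st)(x))  # same as jax.grad(soft_step)(x), not 0
```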

Note that while SoftJAX is designed to provide direct drop-in replacements for JAX's operators, soft axis-wise operators return a probability distribution over indices instead of a single index, which changes the shape of the function's output.
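The shape change can be seen with the simplest relaxation of argmax, a temperature-scaled softmax. This is a minimal sketch assuming that relaxation; `soft_argmax` is a hypothetical name, not SoftJAX's API.

```python
import jax
import jax.numpy as jnp

def soft_argmax(x, tau=0.1, axis=-1):
    """Hypothetical softmax relaxation of argmax (not SoftJAX's API).

    Returns a probability distribution over indices along `axis`,
    so the output keeps that axis instead of reducing it away.
    """
    return jax.nn.softmax(x / tau, axis=axis)

x = jnp.array([[0.1, 0.9, 0.3],
               [0.7, 0.2, 0.4]])
print(jnp.argmax(x, axis=-1).shape)  # (2,): one index per row
p = soft_argmax(x)
print(p.shape)                       # (2, 3): one distribution per row
# If a scalar is needed downstream, one option is the expected index.
print(p @ jnp.arange(x.shape[-1]))
```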

Installation

Requires Python 3.11+.

pip install softjax

Documentation

Available at https://a-paulus.github.io/softjax/.

Quick examples

Robust median regression: Minimize the median absolute residual to be robust to outliers.

import jax, jax.numpy as jnp, softjax as sj

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (20, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = X @ w_true
y = y.at[0].set(1e6)  # inject outlier

def median_regression_loss(w, X, y, mode="smooth"):
    residuals = y - X @ w
    return sj.median(sj.abs(residuals, mode=mode), mode=mode)

w = jnp.zeros(3)
print("Hard grad:", jax.grad(median_regression_loss)(w, X, y, mode="hard"))
print("Soft grad:", jax.grad(median_regression_loss)(w, X, y, mode="smooth"))

for _ in range(50):
    w = w - 0.1 * jax.grad(median_regression_loss)(w, X, y)
print("Learned w:", w, " (true:", w_true, ")")
Hard grad: [-0.5108  0.4321 -0.0122]
Soft grad: [-0.8061  0.5254  0.099 ]
Learned w: [ 1.  -2.   0.5]  (true: [ 1.  -2.   0.5] )
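Part of why the hard median is robust, and also why its gradient is uninformative, can be checked in plain JAX: for an odd number of residuals, the hard median of absolute residuals passes gradient to exactly one residual. The outlier gets none, but neither does anything else. The residual values below are made up for illustration, and this snippet does not use SoftJAX.

```python
import jax
import jax.numpy as jnp

residuals = jnp.array([0.5, -3.0, 1.0, 100.0, -0.2])  # 100.0 plays the outlier

def hard_median_abs(r):
    return jnp.median(jnp.abs(r))

g = jax.grad(hard_median_abs)(residuals)
print(g)  # only index 2 (|1.0|, the median) is nonzero
```

This one-hot gradient is the kind of sparsity a softened median is meant to relax.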

Top-k feature selection: Discover which features of a trained model are important.

n_features, k = 10, 3
k1, k2 = jax.random.split(jax.random.PRNGKey(42))
X = jax.random.normal(k1, (100, n_features))
w_model = jnp.array([0, 2.0, 0, -1.5, 0, 0, 0, 5.0, 0, 0])
y = X @ w_model + 0.1 * jax.random.normal(k2, (100,))

def feature_selection_loss(g, X, y, w_model, mode="smooth"):
    _, soft_idx = sj.top_k(g, k=k, mode=mode, gated_grad=False)
    mask = soft_idx.sum(axis=0)
    y_pred = (X * mask) @ w_model
    return jnp.mean(sj.abs(y_pred - y))

g = jnp.zeros(n_features)
print("Hard grad:", jax.grad(feature_selection_loss)(g, X, y, w_model, mode="hard"))
print("Soft grad:", jax.grad(feature_selection_loss)(g, X, y, w_model, mode="smooth"))

for _ in range(5):
    g = g - 0.001 * jax.grad(feature_selection_loss)(g, X, y, w_model)
print("Selected features:", jax.lax.top_k(g, k=k)[1])
Hard grad: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Soft grad: [  2268.416   -2371.1378   2268.416    1126.1998   2268.416    2268.416
   2268.416  -14633.9742   2268.416    2268.416 ]
Selected features: [7 1 3]
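The `soft_idx.sum(axis=0)` step above turns k soft index distributions into a soft k-hot mask. The sketch below reproduces that idea in plain JAX using the SoftSort relaxation of argsort (from the SoftSort paper listed at the end of this README). It is an illustrative assumption, not how `sj.top_k` is necessarily implemented; `softsort` and `soft_top_k_mask` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def softsort(s, tau=0.1):
    """SoftSort-style relaxation of a descending argsort.

    Row i is a probability distribution over input indices for the i-th largest score.
    """
    s_sorted = jnp.sort(s)[::-1]                    # hard sort, descending
    dist = jnp.abs(s_sorted[:, None] - s[None, :])  # (n, n) pairwise distances
    return jax.nn.softmax(-dist / tau, axis=-1)

def soft_top_k_mask(s, k, tau=0.1):
    # First k rows are the soft indices of the top k; their sum is a soft k-hot mask.
    return softsort(s, tau)[:k].sum(axis=0)

scores = jnp.array([0.1, 2.0, -1.0, 1.5, 0.3])
mask = soft_top_k_mask(scores, k=2)
print(mask)  # close to [0, 1, 0, 1, 0]

# Unlike a hard k-hot mask, the soft mask passes gradients back to the scores.
w = jnp.array([1.0, -1.0, 2.0, 0.5, 3.0])
g = jax.grad(lambda s: soft_top_k_mask(s, k=2) @ w)(scores)
print(g)
```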

Differentiable threshold filtering: Learn a threshold that gates inputs.

x = jnp.array([0.2, 0.8, 0.5, 1.2, 0.1])
target_sum = 2.0  # values above the threshold should sum to 2.0 (i.e. 0.8 + 1.2)

def filter_loss(t, x, target, mode="smooth"):
    mask = sj.greater(x, t, mode=mode)
    return (jnp.sum(mask * x) - target) ** 2

t = jnp.array(0.0)
print("Hard grad:", jax.grad(filter_loss)(t, x, target_sum, mode="hard"))
print("Soft grad:", jax.grad(filter_loss)(t, x, target_sum, mode="smooth"))

for _ in range(20):
    t = t - 0.1 * jax.grad(filter_loss)(t, x, target_sum)
print("Learned threshold:", t)
Hard grad: 0.0
Soft grad: -0.6600359275215457
Learned threshold: 0.6211048323197621
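Why the hard gradient is zero and the soft one is not: with a hard mask 1[x_i > t], the loss is piecewise constant in t, so its derivative vanishes almost everywhere. Assuming, for illustration only, that the soft comparison is a logistic sigmoid with temperature τ (SoftJAX's `greater` modes may differ), the chain rule gives, with c the target sum:

```latex
\begin{aligned}
m_i(t) &= \sigma\!\left(\frac{x_i - t}{\tau}\right), \qquad
L(t) = \Big(\textstyle\sum_i m_i(t)\,x_i - c\Big)^2, \\
\frac{\partial m_i}{\partial t} &= -\frac{1}{\tau}\, m_i(t)\,\big(1 - m_i(t)\big), \\
\frac{dL}{dt} &= -\frac{2}{\tau}\Big(\textstyle\sum_i m_i(t)\,x_i - c\Big)\sum_i x_i\, m_i(t)\,\big(1 - m_i(t)\big).
\end{aligned}
```

Since all x_i here are positive, the gradient is negative whenever the soft sum exceeds the target, so gradient descent raises t.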

Rule-based classifier: Learn decision boundaries [lo, hi] for a rule using soft logic and straight-through estimation. The rule is true for a sample if any of its features lies inside [lo, hi].

x = jnp.array([[0.2, 0.8], [0.5, 0.3], [0.9, 0.1], [0.4, 0.7], [0.1, 0.4], [0.2, 0.7], [0.4, 0.1], [0.4, 0.7],
               [0.7, 0.29], [0.3, 0.3], [0.61, 0.25], [0.4, 0.6], [0.0, 0.1], [0.5, 0.3], [0.4, 0.9], [0.1, 0.57]])
labels = jnp.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0,
                    0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0])

@sj.st
def rule_loss(params, x, labels, mode="smooth"):
    lo, hi = params[0], params[1]
    above = sj.greater(x, lo, mode=mode)
    below = sj.less(x, hi, mode=mode)
    in_range = sj.logical_and(above, below)
    preds = sj.any(in_range, axis=-1)
    return ((preds - labels) ** 2).sum()

params = jnp.array([0.0, 1.0])  # start with wide range [0, 1]
print("Hard grad:", jax.grad(rule_loss)(params, x, labels, mode="hard"))
print("Soft grad:", jax.grad(rule_loss)(params, x, labels, mode="smooth"))

for _ in range(20):
    params = params - 0.01 * jax.grad(rule_loss)(params, x, labels)
print("Learned [lo, hi]:", params)
Hard grad: [0. 0.]
Soft grad: [-4.2777  1.4152]
Learned [lo, hi]: [0.2925 0.5999]
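For intuition about what the soft rule computes, here is a plain-JAX sketch of "any feature lies in [lo, hi]". It assumes sigmoid comparisons, the product t-norm for AND and the probabilistic sum 1 − ∏(1 − p) for ANY. These are common fuzzy-logic choices but not necessarily the ones SoftJAX uses, and `soft_in_any_interval` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def soft_in_any_interval(x, lo, hi, tau=0.02):
    """Hypothetical soft version of: any(lo < x < hi, axis=-1)."""
    above = jax.nn.sigmoid((x - lo) / tau)   # soft x > lo
    below = jax.nn.sigmoid((hi - x) / tau)   # soft x < hi
    in_range = above * below                 # soft logical_and (product t-norm)
    return 1.0 - jnp.prod(1.0 - in_range, axis=-1)  # soft any over features

x = jnp.array([[0.2, 0.8],    # 0.2 lies inside [0.1, 0.5]
               [0.9, 0.0]])   # no feature inside
out = soft_in_any_interval(x, 0.1, 0.5)
print(out)  # first entry near 1, second near 0

# The bound lo receives a nonzero gradient, unlike with the hard rule.
grad_lo = jax.grad(lambda lo: soft_in_any_interval(x, lo, 0.5).sum())(0.1)
print(grad_lo)
```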

Optimization trajectories

Citation

If this library helped your academic work, please consider citing our arXiv paper:

@article{paulus2026softjax,
  title={{SoftJAX} \& {SoftTorch}: Empowering Automatic Differentiation Libraries with Informative Gradients},
  author={Paulus, Anselm and Geist, A.\ Ren\'e and Musil, V\'it and Hoffmann, Sebastian and Beker, Onur and Martius, Georg},
  journal={arXiv preprint},
  year={2026},
  eprint={2603.08824}
}

(Also consider starring the project on GitHub)

Special thanks and credit go to Patrick Kidger for the awesome JAX repositories that served as the basis for the documentation of this project.

Feedback

If you have any suggestions for improvement or other feedback, please reach out or raise a GitHub issue!

See also

Other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
paramax: parameterizations and constraints for PyTrees.

Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

Other libraries on differentiable programming

Differentiable sorting, top-k and rank
DiffSort: Differentiable sorting networks in PyTorch.
DiffTopK: Differentiable top-k in PyTorch.
FastSoftSort: Fast differentiable sorting and rank in JAX.
Differentiable Top-k with Optimal Transport in JAX.
SoftSort: Differentiable argsort in PyTorch and TensorFlow.

Other
DiffLogic: Differentiable logic gate networks in PyTorch.
SmoothOT: Smooth and Sparse Optimal Transport.
JaxOpt: Differentiable optimization in JAX.

Papers on differentiable algorithms

SoftJAX builds on or implements various algorithms for differentiable operators such as argtop_k, sorting and ranking, including:

Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application
Differentiable Ranks and Sorting using Optimal Transport
Differentiable Top-k with Optimal Transport
SoftSort: A Continuous Relaxation for the argsort Operator
Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances
Smooth and Sparse Optimal Transport
Smooth Approximations of the Rounding Function
Fast Differentiable Sorting and Ranking
Differentiable Sorting Networks for Scalable Sorting and Ranking Supervision
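As an example of the algorithms in this list, the first paper gives a short sort-based method for Euclidean projection onto the probability simplex. Below is a minimal JAX transcription of that method, written for illustration. It is not taken from SoftJAX's source, and `project_simplex` is a hypothetical name.

```python
import jax.numpy as jnp

def project_simplex(y):
    """Euclidean projection of a 1-D vector y onto {x : x >= 0, sum(x) = 1}."""
    u = jnp.sort(y)[::-1]                      # sort descending
    css = jnp.cumsum(u)                        # running sums of the sorted values
    j = jnp.arange(1, y.shape[0] + 1)
    cond = u - (css - 1.0) / j > 0             # always true for j = 1
    rho = jnp.max(jnp.where(cond, j, 0))       # largest j satisfying cond
    theta = (css[rho - 1] - 1.0) / rho         # shift that makes the result sum to 1
    return jnp.maximum(y - theta, 0.0)

print(project_simplex(jnp.array([0.5, 0.2, 0.1])))   # shifted up uniformly
print(project_simplex(jnp.array([2.0, 0.0, -1.0])))  # sparse: [1, 0, 0]
```

Entries that fall below the threshold are set exactly to zero, which is what makes simplex projections a source of sparse soft outputs.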

Please check the API Documentation for implementation details.