26 changes: 7 additions & 19 deletions tensorflow_probability/python/bijectors/sigmoid.py
@@ -28,37 +28,25 @@

 JAX_MODE = False  # Overwritten by rewrite script.

-# TODO(b/155501444): Remove when tf.math.sigmoid and tf.nn.softplus are fixed.
+# tf.nn.softplus is now numerically stable and does not require a custom gradient.
+# The previous custom_gradient wrapper leaked memory. See b/155501444.
+_stable_grad_softplus = tf.nn.softplus
+
+# tf.math.sigmoid is numerically stable for large negative inputs in TF 2.x,
+# but the wrapper below is kept to guard against underflow in the extreme tail.
 if JAX_MODE:
   _stable_sigmoid = tf.math.sigmoid
-  _stable_grad_softplus = tf.nn.softplus
 else:

   def _stable_sigmoid(x):
-    """A (more) numerically stable sigmoid than `tf.math.sigmoid`."""
+    """A numerically stable sigmoid that avoids underflow for large negative x."""
     x = tf.convert_to_tensor(x)
     if x.dtype == tf.float64:
       cutoff = -20
     else:
       cutoff = -9
     return tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))

-  @tf.custom_gradient
-  def _stable_grad_softplus(x):
-    """A (more) numerically stable softplus than `tf.nn.softplus`."""
-    x = tf.convert_to_tensor(x)
-    if x.dtype == tf.float64:
-      cutoff = -20
-    else:
-      cutoff = -9
-
-    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
-
-    def grad_fn(dy):
-      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
-
-    return y, grad_fn
-

 class Sigmoid(
     bijector.CoordinatewiseBijectorMixin,
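For readers skimming the diff: the cutoff trick used by the retained `_stable_sigmoid` can be stated as a self-contained snippet. The sketch below is illustrative only; the helper name and the test inputs are not part of this change, and it simply assumes the same dtype-dependent cutoffs as the code above.

import tensorflow as tf

def cutoff_sigmoid_sketch(x):
  # Mirrors `_stable_sigmoid` above: for x far below zero,
  # sigmoid(x) = exp(x) / (1 + exp(x)) is approximately exp(x), and exp(x)
  # cannot round to an exact zero until exp itself underflows, so the wrapper
  # returns exp(x) directly instead of relying on the generic sigmoid kernel.
  x = tf.convert_to_tensor(x)
  cutoff = -20 if x.dtype == tf.float64 else -9
  return tf.where(x < cutoff, tf.exp(x), tf.math.sigmoid(x))

# Spot check in float32: both paths should agree closely on either side of the cutoff.
xs = tf.constant([-30.0, -9.5, -8.5, 0.0, 5.0], dtype=tf.float32)
print(cutoff_sigmoid_sketch(xs).numpy())
print(tf.math.sigmoid(xs).numpy())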
24 changes: 4 additions & 20 deletions tensorflow_probability/python/bijectors/softplus.py
@@ -32,26 +32,10 @@
 JAX_MODE = False  # Overwritten by rewrite script.


-# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
-if JAX_MODE:
-  _stable_grad_softplus = tf.nn.softplus
-else:
-
-  @tf.custom_gradient
-  def _stable_grad_softplus(x):
-    """A (more) numerically stable softplus than `tf.nn.softplus`."""
-    x = tf.convert_to_tensor(x)
-    if x.dtype == tf.float64:
-      cutoff = -20
-    else:
-      cutoff = -9
-
-    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
-
-    def grad_fn(dy):
-      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
-
-    return y, grad_fn
+# tf.nn.softplus is now numerically stable (uses log1p and sigmoid since 2019-2020)
+# and does not require a custom gradient. The previous custom_gradient wrapper
+# leaked memory by capturing tensors in TF's gradient registry. See b/155501444.
+_stable_grad_softplus = tf.nn.softplus


 class Softplus(
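The added comment claims tf.nn.softplus is numerically stable on its own. One hedged way to sanity-check that claim on a particular TF build is to evaluate softplus and its gradient at large-magnitude inputs and verify everything stays finite; the input values and the finiteness criterion below are illustrative assumptions, not part of the change.

import tensorflow as tf

# A stable softplus should satisfy softplus(x) -> 0 as x -> -inf and
# softplus(x) -> x as x -> +inf, with finite values and gradients throughout.
xs = tf.constant([-1e4, -100.0, -20.0, 0.0, 20.0, 100.0, 1e4], dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(xs)
  ys = tf.nn.softplus(xs)
grads = tape.gradient(ys, xs)
print("values finite:", bool(tf.reduce_all(tf.math.is_finite(ys))))
print("grads finite: ", bool(tf.reduce_all(tf.math.is_finite(grads))))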
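The removed grad_fn multiplied the upstream gradient by sigmoid(x) in effect (it used exp(x) below the cutoff, which approximates sigmoid there), and sigmoid(x) is the analytic derivative of softplus. Assuming tf.nn.softplus's registered gradient computes the same thing, the alias `_stable_grad_softplus = tf.nn.softplus` should leave downstream gradients unchanged; the comparison below, with illustrative inputs, checks exactly that.

import tensorflow as tf

# d/dx softplus(x) = sigmoid(x); compare TF's autodiff gradient against the
# closed form that the removed custom grad_fn was effectively computing.
xs = tf.constant([-15.0, -5.0, 0.0, 5.0, 15.0], dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(xs)
  ys = tf.nn.softplus(xs)
autodiff = tape.gradient(ys, xs)
closed_form = tf.math.sigmoid(xs)
print(tf.reduce_max(tf.abs(autodiff - closed_form)).numpy())

If the maximum difference is within a few float32 ulps across the range of interest, the alias is a drop-in replacement as far as gradients are concerned.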