Compute the k=0 branch of the Lambert W function.
See https://en.wikipedia.org/wiki/Lambert_W_function for details. Only valid for inputs >= -1/e (approx. -0.368). For
speed, the domain is not checked: an out-of-domain input may produce an arbitrary or inconsistent value, or even NaN.
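To make the definition concrete, here is a minimal pure-Python sketch of the principal (k=0) branch: find the w >= -1 that solves w·e^w = x. It is not FastEstimator's implementation, which delegates to `tfp.math.lambertw`, SciPy, or a PyTorch helper depending on the input type. The name `lambertw0`, the initial guesses, and the tolerances are assumptions made for illustration. Unlike the method documented here, the sketch checks the domain and raises for x < -1/e.

```python
import math


def lambertw0(x: float, tol: float = 1e-12, max_iter: int = 50) -> float:
    """Hypothetical helper: principal branch W0(x), the w >= -1 with w * exp(w) == x."""
    if x < -1.0 / math.e:
        raise ValueError(f"lambertw0 is only defined for x >= -1/e, got {x}")
    if x == 0.0:
        return 0.0
    if x < -0.25:
        # Near the branch point, start from the series
        # W = -1 + p - p^2/3 + 11/72 p^3 + ..., with p = sqrt(2 * (e * x + 1)).
        # max() guards against e * x + 1 rounding slightly below zero.
        p = math.sqrt(max(0.0, 2.0 * (math.e * x + 1.0)))
        w = -1.0 + p - p * p / 3.0 + 11.0 / 72.0 * p ** 3
        if p < 1e-3:
            # The series is already accurate here, while iterative steps are
            # ill-conditioned because d/dw (w * exp(w)) vanishes at w = -1.
            return w
    elif x < 3.0:
        w = math.log1p(x)
    else:
        log_x = math.log(x)
        w = log_x - math.log(log_x)
    for _ in range(max_iter):
        # Halley's method on f(w) = w * exp(w) - x.
        ew = math.exp(w)
        f = w * ew - x
        wp1 = w + 1.0
        step = f / (ew * wp1 - (w + 2.0) * f / (2.0 * wp1))
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return w
```

Halley's method fits this problem because the first and second derivatives of w·e^w are cheap to evaluate. The branch point -1/e is special-cased because the derivative is zero there.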
The examples below assume `import math` and `import fastestimator as fe`, plus the framework import each one uses
(`import numpy as np`, `import tensorflow as tf`, or `import torch`).
This method can be used with NumPy data:
```python
n = np.array([-1.0/math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100])
b = fe.backend.lambertw(n)  # [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]
```
This method can be used with TensorFlow tensors:
```python
t = tf.constant([-1.0/math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100])
b = fe.backend.lambertw(t)  # [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]
```
This method can be used with PyTorch tensors:
```python
p = torch.tensor([-1.0/math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100])
b = fe.backend.lambertw(p)  # [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]
```
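Because W is defined implicitly by w·e^w = x, the rounded outputs listed in the comments of the three examples can be sanity-checked with plain Python, without any backend. This is a sketch of such a check. The tolerances are an assumption, chosen loosely to absorb the three-decimal rounding of the documented values.

```python
import math

# Inputs shared by the NumPy, TensorFlow, and PyTorch examples, paired with
# the rounded outputs shown in their comments.
inputs = [-1.0 / math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100]
outputs = [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]

for x, w in zip(inputs, outputs):
    # The k=0 branch value w must satisfy w * exp(w) == x (up to rounding).
    assert math.isclose(w * math.exp(w), x, rel_tol=5e-3, abs_tol=1e-3), (x, w)
```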
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tensor` | `Tensor` | The input value. | *required* |
Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The lambertw function evaluated at `tensor`. |
Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `tensor` is not a TensorFlow tensor, a PyTorch tensor, or a NumPy array. |
Source code in fastestimator/fastestimator/backend/_lambertw.py
````python
def lambertw(tensor: Tensor) -> Tensor:
    """Compute the k=0 branch of the Lambert W function.

    See https://en.wikipedia.org/wiki/Lambert_W_function for details. Only valid for inputs >= -1/e (approx -0.368). We
    do not check this for the sake of speed, but if an input is out of domain the return value may be random /
    inconsistent or even NaN.

    This method can be used with Numpy data:
    ```python
    n = np.array([-1.0/math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100])
    b = fe.backend.lambertw(n) # [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.constant([-1.0/math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100])
    b = fe.backend.lambertw(t) # [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.tensor([-1.0/math.e, -0.34, -0.32, -0.2, 0, 0.12, 0.15, math.e, 5, math.exp(1 + math.e), 100])
    b = fe.backend.lambertw(p) # [-1, -0.654, -0.560, -0.259, 0, 0.108, 0.132, 1, 1.327, 2.718, 3.386]
    ```

    Args:
        tensor: The input value.

    Returns:
        The lambertw function evaluated at `tensor`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        return tfp.math.lambertw(tensor)
    if isinstance(tensor, torch.Tensor):
        return _torch_lambertw(tensor)
    elif isinstance(tensor, np.ndarray):
        # scipy implementation is numerically unstable at exactly -1/e, but the result should be -1.0
        return np.nan_to_num(lamw(tensor, k=0, tol=1e-6).real.astype(tensor.dtype), nan=-1.0)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
````
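The NumPy branch post-processes SciPy's complex-valued result in two steps: it keeps the real part cast back to the input dtype, then replaces NaN with -1.0. The sketch below reproduces only those steps on a hand-made stand-in array (`raw` is hypothetical data, not real SciPy output, and no SciPy call is made). It also shows a caveat: a NaN caused by an out-of-domain input is reported as -1.0 as well. The masked variant at the end is a hypothetical alternative, not FastEstimator's behavior, that patches only NaNs located at the branch point.

```python
import numpy as np

x = np.array([-1.0 / np.e, 0.0, np.e, -0.5], dtype=np.float32)

# Hypothetical stand-in for complex solver output: NaN at the branch point,
# exact values for 0 and e, and NaN for the out-of-domain input -0.5.
raw = np.array([complex(np.nan, 0.0), 0.0 + 0.0j, 1.0 + 0.0j, complex(np.nan, 0.0)])

# Same post-processing as the NumPy branch: real part, cast back, NaN -> -1.0.
patched = np.nan_to_num(raw.real.astype(x.dtype), nan=-1.0)
# patched holds [-1, 0, 1, -1] as float32: the out-of-domain NaN became -1.0 too.

# Hypothetical alternative: only replace NaNs whose input sits at -1/e,
# leaving NaNs from out-of-domain inputs visible.
real = raw.real.astype(x.dtype)
at_branch = np.isclose(x, np.float32(-1.0 / np.e))
masked = np.where(np.isnan(real) & at_branch, np.float32(-1.0), real)
# masked holds [-1, 0, 1, nan].
```

The trade-off is cost versus diagnosability: the blanket `nan_to_num` call is a single cheap pass, while the mask needs an extra comparison but does not hide out-of-domain inputs.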