
set_lr


Set the learning rate of a given model's current optimizer.

This method can be used with TensorFlow models:

```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
```

This method can be used with PyTorch models:

```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
```
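For PyTorch models, `set_lr` writes `lr` into every entry of the optimizer's `param_groups`, not only the last one checked in the example. The minimal sketch below uses plain PyTorch without FastEstimator and assumes an optimizer built with two parameter groups that start at different learning rates; the helper `set_all_group_lrs` is hypothetical and only mirrors the loop in the PyTorch branch of `set_lr`.

```python
import torch

# Illustrative model whose two layers get different learning rates.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
optimizer = torch.optim.Adam([
    {"params": model[0].parameters(), "lr": 1e-3},
    {"params": model[1].parameters(), "lr": 1e-4},
])


def set_all_group_lrs(opt: torch.optim.Optimizer, lr: float) -> None:
    """Hypothetical helper: overwrite 'lr' in every parameter group, as set_lr does."""
    for param_group in opt.param_groups:
        param_group["lr"] = lr


set_all_group_lrs(optimizer, 0.8)
print([g["lr"] for g in optimizer.param_groups])  # [0.8, 0.8]
```

Because a single value is assigned to every group, any per-group learning rates configured when the optimizer was built are replaced rather than scaled.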

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A neural network instance to modify. | required |
| `lr` | `float` | The learning rate to assign to the `model`. | required |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `model` is neither a `tf.keras.Model` nor a `torch.nn.Module`. |

Source code in fastestimator\fastestimator\backend\set_lr.py
from typing import Union

import tensorflow as tf
import torch


def set_lr(model: Union[tf.keras.Model, torch.nn.Module], lr: float):
    """Set the learning rate of a given `model`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
    ```

    Args:
        model: A neural network instance to modify.
        lr: The learning rate to assign to the `model`.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
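    if isinstance(model, tf.keras.Model):
        # TensorFlow: write the new value into the current optimizer's learning-rate variable
        tf.keras.backend.set_value(model.current_optimizer.lr, lr)
    elif isinstance(model, torch.nn.Module):
        # PyTorch: each parameter group stores its own 'lr', so update all of them
        for param_group in model.current_optimizer.param_groups:
            param_group['lr'] = lr
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))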