
set_lr

Set the learning rate of a given model generated by fe.build.

This method can be used with TensorFlow models:

```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
```

This method can be used with PyTorch models:

```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Union[Model, Module] | A neural network instance to modify. | required |
| lr | float | The learning rate to assign to the model. | required |
| weight_decay | Optional[float] | The weight decay parameter. This is only relevant when using tfa.DecoupledWeightDecayExtension. | None |
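
When the model uses a decoupled weight-decay optimizer, set_lr also rescales the weight decay in proportion to the new learning rate so that the decay/lr ratio stays constant (see the source below). A minimal sketch, assuming tensorflow_addons is installed and an optimizer lambda is passed to fe.build:

```python
import fastestimator as fe
import tensorflow_addons as tfa

# Build with a decoupled weight-decay optimizer (AdamW).
m = fe.build(fe.architecture.tensorflow.LeNet,
             optimizer_fn=lambda: tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3))
fe.backend.set_lr(m, lr=1e-2)  # lr grows 10x, so weight decay is rescaled 1e-4 -> 1e-3
fe.backend.set_lr(m, lr=1e-2, weight_decay=5e-4)  # or override the decay explicitly
```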

Raises:

| Type | Description |
| --- | --- |
| ValueError | If model is an unacceptable data type. |

Source code in fastestimator/fastestimator/backend/_set_lr.py
from typing import Optional, Union

import tensorflow as tf
import tensorflow_addons as tfa
import torch

from fastestimator.backend._get_lr import get_lr


def set_lr(model: Union[tf.keras.Model, torch.nn.Module], lr: float, weight_decay: Optional[float] = None):
    """Set the learning rate of a given `model` generated by `fe.build`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
    ```

    Args:
        model: A neural network instance to modify.
        lr: The learning rate to assign to the `model`.
        weight_decay: The weight decay parameter. This is only relevant when using `tfa.DecoupledWeightDecayExtension`.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "set_lr only accepts models from fe.build"
    if isinstance(model, tf.keras.Model):
        # when using decoupled weight decay like SGDW or AdamW, weight decay factor needs to change together with lr
        # see https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/DecoupledWeightDecayExtension for detail
        if isinstance(model.current_optimizer, tfa.optimizers.DecoupledWeightDecayExtension) or (
                hasattr(model.current_optimizer, "inner_optimizer")
                and isinstance(model.current_optimizer.inner_optimizer,
                               tfa.optimizers.DecoupledWeightDecayExtension)):
            if weight_decay is None:
                weight_decay = tf.keras.backend.get_value(model.current_optimizer.weight_decay) * lr / get_lr(model)
            tf.keras.backend.set_value(model.current_optimizer.weight_decay, weight_decay)
        tf.keras.backend.set_value(model.current_optimizer.lr, lr)
    elif isinstance(model, torch.nn.Module):
        for param_group in model.current_optimizer.param_groups:
            param_group['lr'] = lr
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
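
For context, a common use of set_lr is inside a custom Trace to implement a simple learning-rate schedule. The sketch below is illustrative, not part of this module; it assumes the standard fastestimator.trace.Trace base class and the companion fe.backend.get_lr:

```python
import fastestimator as fe
from fastestimator.trace import Trace

class HalveLREveryEpoch(Trace):
    """Hypothetical trace that halves the model's learning rate after each training epoch."""
    def __init__(self, model):
        super().__init__(mode="train")
        self.model = model

    def on_epoch_end(self, data):
        # Read the current lr and write back half of it.
        fe.backend.set_lr(self.model, lr=fe.backend.get_lr(self.model) / 2)
```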