

set_lr

Set the learning rate of a given `model` generated by `fe.build`.

This method can be used with TensorFlow models:

```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
```

This method can be used with PyTorch models:

```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
```
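Because `set_lr` operates on a live optimizer, it can drive a custom learning-rate schedule from a Trace. A minimal sketch, assuming a model built via `fe.build`; the `HalveLR` trace and its halving policy are hypothetical, purely for illustration:

```python
import fastestimator as fe
from fastestimator.trace import Trace

class HalveLR(Trace):  # hypothetical trace name, illustration only
    """Halve the model's learning rate at the start of every training epoch."""
    def __init__(self, model):
        super().__init__(mode="train")
        self.model = model

    def on_epoch_begin(self, data):
        # Read the current rate, then write back half of it
        current_lr = fe.backend.get_lr(self.model)
        fe.backend.set_lr(self.model, lr=current_lr / 2)
```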

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Union[Model, Module]` | A neural network instance to modify. | *required* |
| `lr` | `float` | The learning rate to assign to the `model`. | *required* |
| `weight_decay` | `Optional[float]` | The weight decay parameter. | `None` |
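With a decoupled-weight-decay optimizer, `weight_decay` can also be pinned explicitly rather than letting it rescale with the learning rate (see the source below). A minimal sketch, assuming `tf.keras.optimizers.AdamW` is available in your TensorFlow version:

```python
import tensorflow as tf
import fastestimator as fe

# optimizer_fn accepts a callable; AdamW availability depends on your TF version
m = fe.build(fe.architecture.tensorflow.LeNet,
             optimizer_fn=lambda: tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=4e-5))
fe.backend.set_lr(m, lr=0.01, weight_decay=1e-5)  # pin the decay instead of letting it auto-rescale
```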

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `model` is an unacceptable data type. |

Source code in fastestimator/fastestimator/backend/_set_lr.py
````python
from typing import Optional, Union

import tensorflow as tf
import torch

from fastestimator.backend._get_lr import get_lr


def set_lr(model: Union[tf.keras.Model, torch.nn.Module], lr: float, weight_decay: Optional[float] = None):
    """Set the learning rate of a given `model` generated by `fe.build`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
    ```

    Args:
        model: A neural network instance to modify.
        lr: The learning rate to assign to the `model`.
        weight_decay: The weight decay parameter.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "set_lr only accepts models from fe.build"
    if isinstance(model, tf.keras.Model):
        # when using decoupled weight decay like SGDW or AdamW, weight decay factor needs to change together with lr
        if hasattr(model.current_optimizer, "weight_decay") and tf.keras.backend.get_value(
                model.current_optimizer.weight_decay) is not None:
            if weight_decay is None:
                weight_decay = tf.keras.backend.get_value(model.current_optimizer.weight_decay) * lr / get_lr(model)
            model.current_optimizer.weight_decay = weight_decay
        model.current_optimizer.lr = lr
    elif isinstance(model, torch.nn.Module):
        for param_group in model.current_optimizer.param_groups:
            param_group['lr'] = lr
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
````
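For decoupled-weight-decay optimizers (SGDW, AdamW), the TensorFlow branch above keeps the decay proportional to the learning rate whenever no explicit `weight_decay` is supplied; note also that, as written, the PyTorch branch updates only `lr`. A quick sanity check of the rescaling arithmetic, with illustrative numbers:

```python
# Mirrors the TensorFlow branch: new_wd = old_wd * new_lr / old_lr
old_lr, old_wd = 1e-3, 4e-5  # illustrative starting values
new_lr = 5e-4                # halving the learning rate...
new_wd = old_wd * new_lr / old_lr
print(new_wd)  # 2e-05 -- ...halves the decay too, preserving the wd/lr ratio
```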