Set the learning rate of a given model generated by `fe.build`.
This method can be used with TensorFlow models:

```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
```

This method can be used with PyTorch models:

```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
```
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Union[Model, Module]` | A neural network instance to modify. | required |
| `lr` | `float` | The learning rate to assign to the `model`. | required |
| `weight_decay` | `Optional[float]` | The weight decay parameter; only relevant when using `tfa.DecoupledWeightDecayExtension` (see the sketch after this table). | `None` |
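When the model's optimizer uses decoupled weight decay (e.g. `tfa.optimizers.SGDW` or `tfa.optimizers.AdamW`), `set_lr` rescales the decay factor along with the learning rate, as the source code below shows. A minimal sketch of that interaction, assuming `tfa.optimizers.SGDW` as the optimizer and the same `LeNet` model function as above:

```python
import tensorflow_addons as tfa
import fastestimator as fe

# Sketch: with a decoupled-weight-decay optimizer such as tfa.optimizers.SGDW,
# set_lr keeps the weight_decay/lr ratio fixed unless weight_decay is given.
m = fe.build(fe.architecture.tensorflow.LeNet,
             optimizer_fn=lambda: tfa.optimizers.SGDW(weight_decay=1e-4, learning_rate=0.01))
fe.backend.set_lr(m, lr=0.001)  # weight decay is rescaled to 1e-5 (same wd/lr ratio)
fe.backend.set_lr(m, lr=0.01, weight_decay=5e-5)  # or pin the decay explicitly
```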
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `model` is an unacceptable data type. |
Source code in fastestimator/fastestimator/backend/_set_lr.py
````python
from typing import Optional, Union

import tensorflow as tf
import tensorflow_addons as tfa
import torch

from fastestimator.backend._get_lr import get_lr


def set_lr(model: Union[tf.keras.Model, torch.nn.Module], lr: float, weight_decay: Optional[float] = None):
    """Set the learning rate of a given `model` generated by `fe.build`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
    ```

    Args:
        model: A neural network instance to modify.
        lr: The learning rate to assign to the `model`.
        weight_decay: The weight decay parameter. This is only relevant when using
            `tfa.DecoupledWeightDecayExtension`.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "set_lr only accepts models from fe.build"
    if isinstance(model, tf.keras.Model):
        # When using decoupled weight decay like SGDW or AdamW, the weight decay factor needs to change together
        # with the lr. See https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/DecoupledWeightDecayExtension
        if isinstance(model.current_optimizer, tfa.optimizers.DecoupledWeightDecayExtension) or hasattr(
                model.current_optimizer, "inner_optimizer") and isinstance(
                    model.current_optimizer.inner_optimizer, tfa.optimizers.DecoupledWeightDecayExtension):
            if weight_decay is None:
                # Preserve the wd/lr ratio: scale the existing decay by the ratio of the new lr to the old lr
                weight_decay = tf.keras.backend.get_value(model.current_optimizer.weight_decay) * lr / get_lr(model)
            tf.keras.backend.set_value(model.current_optimizer.weight_decay, weight_decay)
        tf.keras.backend.set_value(model.current_optimizer.lr, lr)
    elif isinstance(model, torch.nn.Module):
        # PyTorch optimizers expose the learning rate through their param_groups
        for param_group in model.current_optimizer.param_groups:
            param_group['lr'] = lr
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
````
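Because `set_lr` writes the rate directly into the underlying optimizer, it can also drive simple manual schedules. An illustrative sketch (not part of the library) of a linear warmup, using `fe.backend.get_lr` to read the value back:

```python
import fastestimator as fe

# Illustrative sketch: ramp the learning rate linearly over the first few steps.
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
base_lr, warmup_steps = 0.001, 5
for step in range(1, warmup_steps + 1):
    fe.backend.set_lr(m, lr=base_lr * step / warmup_steps)
    print(fe.backend.get_lr(m))  # 0.0002, 0.0004, ..., 0.001
```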