from typing import Optional, Union

import tensorflow as tf
import tensorflow_addons as tfa
import torch

from fastestimator.backend.get_lr import get_lr


def set_lr(model: Union[tf.keras.Model, torch.nn.Module], lr: float, weight_decay: Optional[float] = None):
    """Set the learning rate of a given `model` generated by `fe.build`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
    ```

    Args:
        model: A neural network instance to modify.
        lr: The learning rate to assign to the `model`.
        weight_decay: The weight decay parameter. This is only relevant when using
            `tfa.DecoupledWeightDecayExtension`.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "set_lr only accepts models built by fe.build"
    if isinstance(model, tf.keras.Model):
        # When using decoupled weight decay like SGDW or AdamW, the weight decay factor needs to change together with
        # the lr. See https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/DecoupledWeightDecayExtension
        # for details.
        if isinstance(model.current_optimizer, tfa.optimizers.DecoupledWeightDecayExtension) or (
                hasattr(model.current_optimizer, "inner_optimizer") and isinstance(
                    model.current_optimizer.inner_optimizer, tfa.optimizers.DecoupledWeightDecayExtension)):
            if weight_decay is None:
                # Rescale the existing weight decay by the same factor as the learning rate change
                weight_decay = tf.keras.backend.get_value(model.current_optimizer.weight_decay) * lr / get_lr(model)
            tf.keras.backend.set_value(model.current_optimizer.weight_decay, weight_decay)
        tf.keras.backend.set_value(model.current_optimizer.lr, lr)
    elif isinstance(model, torch.nn.Module):
        for param_group in model.current_optimizer.param_groups:
            param_group['lr'] = lr
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
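
# The proportional weight-decay rescaling above is easy to miss, so here is a
# minimal sketch of how it plays out with `tfa.optimizers.AdamW`. The use of
# `LeNet` and the specific numeric values are just illustrative assumptions;
# the key point is that when `weight_decay` is not passed explicitly, `set_lr`
# scales the optimizer's existing decay by the same factor as the learning rate.
#
#     import fastestimator as fe
#     import tensorflow_addons as tfa
#
#     # Build a model with a decoupled-weight-decay optimizer (AdamW from TF Addons).
#     m = fe.build(fe.architecture.tensorflow.LeNet,
#                  optimizer_fn=lambda: tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=1e-3))
#
#     # Raising lr by 10x also raises weight_decay by 10x, keeping their ratio constant:
#     fe.backend.set_lr(m, lr=1e-2)  # weight_decay becomes 1e-4 * 1e-2 / 1e-3 == 1e-3
#
#     # Passing weight_decay explicitly overrides the proportional rescaling:
#     fe.backend.set_lr(m, lr=5e-3, weight_decay=2e-4)
#
# Note that the PyTorch branch only updates `lr` in each param group; it does not
# apply any comparable rescaling for decoupled-weight-decay optimizers such as
# `torch.optim.AdamW`.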