Bases: Trace
Learning rate scheduler trace that changes the learning rate while training.
This class requires an input function which takes either 'epoch' or 'step' as input:
```python
s = LRScheduler(model=model, lr_fn=lambda step: fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3))
fe.Estimator(..., traces=[s]) # Learning rate will change based on step
s = LRScheduler(model=model, lr_fn=lambda epoch: fe.schedule.cosine_decay(epoch, cycle_length=3750, init_lr=1e-3))
fe.Estimator(..., traces=[s]) # Learning rate will change based on epoch
```
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A model instance compiled with `fe.build`. | required |
| `lr_fn` | `Callable[[int], float]` | A learning-rate scheduling function that takes a single argument named either `epoch` or `step`. | required |
Raises:

| Type | Description |
|------|-------------|
| `AssertionError` | If `lr_fn` is not callable, or does not take exactly one argument named `epoch` or `step`. |
Source code in fastestimator\fastestimator\trace\adapt\lr_scheduler.py
@traceable()
class LRScheduler(Trace):
    """Learning rate scheduler trace that changes the learning rate while training.

    This class requires an input function which takes either 'epoch' or 'step' as input:
    ```python
    s = LRScheduler(model=model, lr_fn=lambda step: fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3))
    fe.Estimator(..., traces=[s])  # Learning rate will change based on step
    s = LRScheduler(model=model, lr_fn=lambda epoch: fe.schedule.cosine_decay(epoch, cycle_length=3750, init_lr=1e-3))
    fe.Estimator(..., traces=[s])  # Learning rate will change based on epoch
    ```

    The *name* of `lr_fn`'s single parameter selects the schedule granularity:
    - a parameter named `step` is called with `system.global_step` at the beginning of every batch;
    - a parameter named `epoch` is called with `system.epoch_idx` at the beginning of every epoch.
    The returned value is cast to `np.float32` and applied to the model via `set_lr`.

    Args:
        model: A model instance compiled with fe.build.
        lr_fn: A learning-rate scheduling function whose single parameter is named either 'epoch' or 'step'.

    Raises:
        AssertionError: If `lr_fn` is not callable, or does not take exactly one parameter named 'step' or 'epoch'.
    """
    system: System

    def __init__(self, model: Union[tf.keras.Model, torch.nn.Module], lr_fn: Callable[[int], float]) -> None:
        # Explicit raises instead of `assert` statements so validation still happens under `python -O`.
        # AssertionError is kept as the exception type because it is the documented contract.
        if not callable(lr_fn):
            raise AssertionError("lr_fn must be a function")
        try:
            arg = list(inspect.signature(lr_fn).parameters.keys())
        except (TypeError, ValueError) as err:
            # Some callables (e.g. certain builtins) expose no introspectable signature.
            raise AssertionError("the lr_fn input arg must be either 'step' or 'epoch'") from err
        if len(arg) != 1 or arg[0] not in {"step", "epoch"}:
            raise AssertionError("the lr_fn input arg must be either 'step' or 'epoch'")
        # Only store state once the configuration is known to be valid.
        self.model = model
        self.lr_fn = lr_fn
        self.schedule_mode = arg[0]
        # The logged learning-rate key is derived from the model's name (set when built via fe.build).
        super().__init__(mode="train", outputs=self.model.model_name + "_lr")

    def on_epoch_begin(self, data: Data) -> None:
        # Epoch-granularity schedules are re-evaluated once per epoch from the epoch index.
        if self.schedule_mode == "epoch":
            new_lr = np.float32(self.lr_fn(self.system.epoch_idx))
            set_lr(self.model, new_lr)

    def on_batch_begin(self, data: Data) -> None:
        # Step-granularity schedules are re-evaluated before every batch from the global step counter.
        if self.schedule_mode == "step":
            new_lr = np.float32(self.lr_fn(self.system.global_step))
            set_lr(self.model, new_lr)

    def on_batch_end(self, data: Data) -> None:
        # Log the current learning rate on step 1 and on every `log_steps`-th step;
        # a falsy `log_steps` disables this logging entirely.
        if self.system.log_steps and (self.system.global_step % self.system.log_steps == 0
                                      or self.system.global_step == 1):
            current_lr = np.float32(get_lr(self.model))
            data.write_with_log(self.outputs[0], current_lr)
|