lr_scheduler

LRScheduler

Bases: Trace

Learning rate scheduler trace that changes the learning rate while training.

This class requires a scheduling function that takes either 'epoch' or 'step' as its input:

```python
s = LRScheduler(model=model, lr_fn=lambda step: fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3))
fe.Estimator(..., traces=[s])  # Learning rate will change based on step
s = LRScheduler(model=model, lr_fn=lambda epoch: fe.schedule.cosine_decay(epoch, cycle_length=3750, init_lr=1e-3))
fe.Estimator(..., traces=[s])  # Learning rate will change based on epoch
```
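
The `cosine_decay` helper used above anneals the learning rate over repeating cycles. As a rough sketch, a generic cosine-decay formula looks like the following; this is for illustration only and is not necessarily the exact implementation of `fe.schedule.cosine_decay`, which supports further options (such as a minimum learning rate):

```python
import math

def cosine_decay_sketch(time, cycle_length, init_lr, min_lr=0.0):
    # Generic cosine decay, for illustration only. `time` is 1-indexed,
    # matching the step/epoch counters that LRScheduler passes in.
    progress = (time - 1) % cycle_length  # position within the current cycle
    return min_lr + 0.5 * (init_lr - min_lr) * (1 + math.cos(math.pi * progress / cycle_length))
```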

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A model instance compiled with fe.build. | required |
| `lr_fn` | `Callable[[int], float]` | A learning rate scheduling function that takes either 'epoch' or 'step' as its input. | required |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If the `lr_fn` is not configured properly. |
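
Concretely, `lr_fn` must be a callable whose single parameter is named either `step` or `epoch`; anything else trips the assertions in `__init__` (shown in the source below). Here `model` stands for any fe.build-compiled model:

```python
LRScheduler(model=model, lr_fn=lambda step: 1e-3)  # OK: one argument named 'step'
LRScheduler(model=model, lr_fn=lambda t: 1e-3)     # AssertionError: argument must be 'step' or 'epoch'
LRScheduler(model=model, lr_fn=lambda s, e: 1e-3)  # AssertionError: exactly one argument is required
```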

Source code in fastestimator/fastestimator/trace/adapt/lr_scheduler.py
@traceable()
class LRScheduler(Trace):
    """Learning rate scheduler trace that changes the learning rate while training.

    This class requires a scheduling function that takes either 'epoch' or 'step' as its input:
    ```python
    s = LRScheduler(model=model, lr_fn=lambda step: fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3))
    fe.Estimator(..., traces=[s])  # Learning rate will change based on step
    s = LRScheduler(model=model, lr_fn=lambda epoch: fe.schedule.cosine_decay(epoch, cycle_length=3750, init_lr=1e-3))
    fe.Estimator(..., traces=[s])  # Learning rate will change based on epoch
    ```

    Args:
        model: A model instance compiled with fe.build.
        lr_fn: A learning rate scheduling function that takes either 'epoch' or 'step' as its input.

    Raises:
        AssertionError: If the `lr_fn` is not configured properly.
    """
    system: System

    def __init__(self, model: Union[tf.keras.Model, torch.nn.Module], lr_fn: Callable[[int], float]) -> None:
        self.model = model
        self.lr_fn = lr_fn
        assert hasattr(lr_fn, "__call__"), "lr_fn must be a function"
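        # The name of lr_fn's single parameter ('step' or 'epoch') determines
        # whether the schedule is applied per batch or per epoch.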
        arg = list(inspect.signature(lr_fn).parameters.keys())
        assert len(arg) == 1 and arg[0] in {"step", "epoch"}, "the lr_fn input arg must be either 'step' or 'epoch'"
        self.schedule_mode = arg[0]
        super().__init__(mode="train", outputs=self.model.model_name + "_lr")

    def on_epoch_begin(self, data: Data) -> None:
        if self.schedule_mode == "epoch":
            new_lr = np.float32(self.lr_fn(self.system.epoch_idx))
            set_lr(self.model, new_lr)

    def on_batch_begin(self, data: Data) -> None:
        if self.schedule_mode == "step":
            new_lr = np.float32(self.lr_fn(self.system.global_step))
            set_lr(self.model, new_lr)

    def on_batch_end(self, data: Data) -> None:
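        # Record the current learning rate, but only on logging steps (and step 1),
        # so it appears alongside the other logged metrics.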
        if self.system.log_steps and (self.system.global_step % self.system.log_steps == 0
                                      or self.system.global_step == 1):
            current_lr = np.float32(get_lr(self.model))
            data.write_with_log(self.outputs[0], current_lr)
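
Putting it together, a minimal usage sketch follows; `my_model_fn`, `pipeline`, and `network` are hypothetical placeholders for your own model function, Pipeline, and Network:

```python
import fastestimator as fe
from fastestimator.trace.adapt import LRScheduler

model = fe.build(model_fn=my_model_fn, optimizer_fn="adam")  # my_model_fn: placeholder
scheduler = LRScheduler(
    model=model,
    lr_fn=lambda step: fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3))
estimator = fe.Estimator(pipeline=pipeline,  # pipeline/network: placeholders
                         network=network,
                         epochs=2,
                         traces=[scheduler])
estimator.fit()  # "<model_name>_lr" appears in the training log at each log step
```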