lr_shedule

cosine_decay

Learning rate cosine decay function (using half of cosine curve).

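This method is useful for scheduling learning rates that oscillate over time, decaying from init_lr to min_lr along half a cosine curve and restarting at the beginning of each cycle: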

s = fe.schedule.LRScheduler(model=model, lr_fn=lambda step: cosine_decay(step, cycle_length=3750, init_lr=1e-3))
fe.Estimator(..., traces=[s])

For more information, check out SGDR: https://arxiv.org/pdf/1608.03983.pdf.
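
As a reading aid, the schedule computed in the source listing can be written out as a formula. The symbols $t'$, $m$, $k$, $L_k$ and $T_k$ are shorthand introduced here and are not names used by the library: $m$ stands for cycle_multiplier, $k$ is the zero-based index of the current cycle, $L_k$ is that cycle's length, and $T_k$ is the number of steps consumed by all earlier cycles.

```latex
t' = \text{time} - \text{start} + 1, \qquad
L_k = \text{cycle\_length} \cdot m^{k}, \qquad
T_k = \sum_{j=0}^{k-1} L_j

\text{lr}(\text{time}) =
\begin{cases}
  \text{init\_lr}, & \text{time} < \text{start} \\[4pt]
  \text{min\_lr} + \dfrac{\text{init\_lr} - \text{min\_lr}}{2}
    \left( 1 + \cos\!\left( \pi \, \dfrac{t' - T_k}{L_k} \right) \right),
    & T_k < t' \le T_k + L_k
\end{cases}
```

The cosine argument runs from just above 0 to exactly $\pi$ within a cycle, so the rate reaches min_lr on the last step of each cycle and jumps back toward init_lr on the next step.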

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| time | int | The current step or epoch during training starting from 1. | required |
| cycle_length | int | The decay cycle length. | required |
| init_lr | float | Initial learning rate to decay from. | required |
| min_lr | float | Minimum learning rate. | 1e-06 |
| start | int | The step or epoch to start the decay schedule. | 1 |
| cycle_multiplier | int | The factor by which next cycle length will be multiplied. | 1 |
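
To make the interaction of cycle_length, cycle_multiplier and start concrete, the sketch below lists the first and last step of the first few cycles. `cycle_ranges` and its `n_cycles` argument are hypothetical helpers written for this illustration, not part of FastEstimator; the boundaries follow the cumulative-length logic in the source listing.

```python
def cycle_ranges(cycle_length, cycle_multiplier=1, start=1, n_cycles=3):
    """Return (first_step, last_step) for each of the first n_cycles cycles.

    Hypothetical helper for illustration only; not a FastEstimator API.
    """
    ranges = []
    first = start          # the first cycle begins at `start`
    length = cycle_length  # the first cycle lasts `cycle_length` steps
    for _ in range(n_cycles):
        ranges.append((first, first + length - 1))
        first += length
        length *= cycle_multiplier  # each later cycle is `cycle_multiplier` times longer
    return ranges


print(cycle_ranges(10))                               # [(1, 10), (11, 20), (21, 30)]
print(cycle_ranges(10, cycle_multiplier=2))           # [(1, 10), (11, 30), (31, 70)]
print(cycle_ranges(10, cycle_multiplier=2, start=5))  # [(5, 14), (15, 34), (35, 74)]
```

The learning rate is at its lowest on the last step of each range and restarts near init_lr on the first step of the next one.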

Returns:

| Name | Type | Description |
|---|---|---|
| lr | | learning rate given current step or epoch. |

Source code in fastestimator\fastestimator\schedule\lr_shedule.py
import math


def cosine_decay(time: int,
                 cycle_length: int,
                 init_lr: float,
                 min_lr: float = 1e-6,
                 start: int = 1,
                 cycle_multiplier: int = 1):
    """Learning rate cosine decay function (using half of cosine curve).

    This method is useful for scheduling learning rates which oscillate over time:
    ```python
    s = fe.schedule.LRScheduler(model=model, lr_fn=lambda step: cosine_decay(step, cycle_length=3750, init_lr=1e-3))
    fe.Estimator(..., traces=[s])
    ```

    For more information, check out SGDR: https://arxiv.org/pdf/1608.03983.pdf.

    Args:
        time: The current step or epoch during training starting from 1.
        cycle_length: The decay cycle length.
        init_lr: Initial learning rate to decay from.
        min_lr: Minimum learning rate.
        start: The step or epoch to start the decay schedule.
        cycle_multiplier: The factor by which next cycle length will be multiplied.

    Returns:
        lr: learning rate given current step or epoch.
    """
    if time < start:
        lr = init_lr
    else:
        time = time - start + 1
        if cycle_multiplier > 1:
            current_cycle_idx = math.ceil(
                math.log(time * (cycle_multiplier - 1) / cycle_length + 1) / math.log(cycle_multiplier)) - 1
            cumulative = cycle_length * (cycle_multiplier**current_cycle_idx - 1) / (cycle_multiplier - 1)
        elif cycle_multiplier == 1:
            current_cycle_idx = math.ceil(time / cycle_length) - 1
            cumulative = current_cycle_idx * cycle_length
        else:
            raise ValueError("multiplier must be at least 1")
        current_cycle_length = cycle_length * cycle_multiplier**current_cycle_idx
        time_in_cycle = (time - cumulative) / current_cycle_length
        lr = (init_lr - min_lr) / 2 * math.cos(time_in_cycle * math.pi) + (init_lr + min_lr) / 2
    return lr
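
For checking the schedule by hand, the following is a minimal loop-based sketch of the same computation. It finds the current cycle by subtracting cycle lengths one at a time instead of using the logarithm. `cosine_decay_reference` is a hypothetical name chosen for this example, and the function is an illustration rather than the library's implementation.

```python
import math


def cosine_decay_reference(time, cycle_length, init_lr, min_lr=1e-6, start=1, cycle_multiplier=1):
    """Loop-based sketch of the cosine decay schedule (illustrative, not a FastEstimator API)."""
    if cycle_multiplier < 1:
        raise ValueError("cycle_multiplier must be at least 1")
    if time < start:
        return init_lr  # the decay schedule has not started yet
    elapsed = time - start + 1  # 1-based position within the schedule
    length = cycle_length
    # Step through whole cycles until `elapsed` falls inside the current one.
    while elapsed > length:
        elapsed -= length
        length *= cycle_multiplier
    fraction = elapsed / length  # in (0, 1]; 1 is the last step of a cycle
    return (init_lr - min_lr) / 2 * math.cos(fraction * math.pi) + (init_lr + min_lr) / 2


# Print the rate at a few steps of a schedule whose cycles last 10, then 20 steps.
for t in (1, 5, 10, 11, 30):
    lr = cosine_decay_reference(t, cycle_length=10, init_lr=1.0, min_lr=0.0, cycle_multiplier=2)
    print(t, round(lr, 4))
```

With these settings the rate falls to min_lr at steps 10 and 30 (the last steps of the first two cycles) and restarts close to init_lr at step 11.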