Learning rate schedules can be implemented using the LRScheduler Trace. LRScheduler takes the model and the learning rate schedule through its lr_fn parameter. lr_fn should be a function (or lambda) whose single input parameter is named 'step' or 'epoch'; the parameter name determines whether the schedule is applied at the step or the epoch level.
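For instance, the only structural difference between the two modes is the name of the function's parameter. The schedules below are illustrative placeholders, not ones used later in this tutorial:

def lr_fn_by_epoch(epoch):
    # Named 'epoch': the schedule is evaluated once per epoch
    return 1e-3 * 0.9 ** (epoch - 1)  # e.g. decay 10% each epoch

def lr_fn_by_step(step):
    # Named 'step': the schedule is evaluated on every training step
    return 1e-3 * 0.99 ** (step // 100)  # e.g. decay 1% every 100 steps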
For more details on traces, you can visit Beginner Tutorial 7 and Advanced Tutorial 4.
Let's create a function to generate the pipeline, model, and network to be used for this tutorial:
import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.dataset.data import mnist
from fastestimator.op.numpyop.univariate import ExpandDims, Minmax
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp

def get_pipeline_model_network(model_name="LeNet"):
    train_data, _ = mnist.load_data()
    pipeline = fe.Pipeline(train_data=train_data,
                           batch_size=32,
                           ops=[ExpandDims(inputs="x", outputs="x"),
                                Minmax(inputs="x", outputs="x")])
    model = fe.build(model_fn=LeNet, optimizer_fn="adam", model_name=model_name)
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")
    ])
    return pipeline, model, network
Customizing a Learning Rate Schedule Function¶
We can specify a custom learning rate schedule by passing a custom function to the lr_fn parameter of LRScheduler. The schedule can be applied at either the epoch or the step level; both epoch and step counts start from 1.
Epoch-wise¶
To apply learning rate scheduling at the epoch level, the custom function should have 'epoch' as its parameter. The example below demonstrates this. We will use the summary parameter of the fit method so that we can visualize the learning rate later. You can go through Advanced Tutorial 6 for more details on accessing training history.
from fastestimator.summary.logs import visualize_logs
from fastestimator.trace.adapt import LRScheduler
def lr_schedule(epoch):
    lr = 0.001*(20-epoch+1)/20
    return lr
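As a quick sanity check, we can evaluate the schedule by hand for the three epochs we are about to train. The printed values match the LeNet_lr column in the training log below:

for epoch in [1, 2, 3]:
    print(epoch, lr_schedule(epoch))  # 0.001, 0.00095, 0.0009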
pipeline, model, network = get_pipeline_model_network()
traces = LRScheduler(model=model, lr_fn=lr_schedule)
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=3, traces=traces)
history = estimator.fit(summary="Experiment_1")
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.3051612; LeNet_lr: 0.001;
FastEstimator-Train: step: 100; ce: 0.0635467; LeNet_lr: 0.001; steps/sec: 68.63;
...
FastEstimator-Train: step: 1800; ce: 0.05694783; LeNet_lr: 0.001; steps/sec: 72.08;
FastEstimator-Train: step: 1875; epoch: 1; epoch_time: 28.25 sec;
FastEstimator-Train: step: 1900; ce: 0.015518977; LeNet_lr: 0.00095; steps/sec: 53.08;
...
FastEstimator-Train: step: 3700; ce: 0.10233892; LeNet_lr: 0.00095; steps/sec: 67.73;
FastEstimator-Train: step: 3750; epoch: 2; epoch_time: 28.34 sec;
FastEstimator-Train: step: 3800; ce: 0.043337934; LeNet_lr: 0.0009; steps/sec: 50.85;
...
FastEstimator-Train: step: 5600; ce: 0.0105727; LeNet_lr: 0.0009; steps/sec: 66.34;
FastEstimator-Train: step: 5625; epoch: 3; epoch_time: 28.49 sec;
FastEstimator-Finish: step: 5625; LeNet_lr: 0.0009; total_time: 85.12 sec;
The learning rate is reported in the training log at the interval specified by the log_steps parameter of the Estimator. By default, training is logged every 100 steps.
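If you want the learning rate reported at a different granularity, you can change the logging interval through that same parameter. A minimal sketch (the value 300 is an arbitrary choice for illustration):

# Log every 300 steps instead of the default 100
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=3,
                         traces=traces, log_steps=300)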
visualize_logs(history, include_metrics="LeNet_lr")
As you can see, the learning rate changes only once per epoch.
Step-wise¶
For step-based learning rate schedules, the custom function should have 'step' as its parameter.
def lr_schedule(step):
    lr = 0.001*(7500-step+1)/7500
    return lr
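With 1875 steps per epoch at batch size 32, this schedule decays the learning rate linearly toward zero over 7500 steps. Working out a few values by hand, which match the log below:

for step in [1, 100, 1875, 3750]:
    print(step, lr_schedule(step))  # 0.001, 0.0009868, 0.00075013, 0.00050013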
pipeline, model, network = get_pipeline_model_network()
traces = LRScheduler(model=model, lr_fn=lr_schedule)
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=2, traces=traces)
history2 = estimator.fit(summary="Experiment_2")
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.3127494; LeNet_lr: 0.001;
FastEstimator-Train: step: 100; ce: 0.48936948; LeNet_lr: 0.0009868; steps/sec: 69.11;
FastEstimator-Train: step: 200; ce: 0.20337304; LeNet_lr: 0.00097346667; steps/sec: 68.69;
...
FastEstimator-Train: step: 1800; ce: 0.014999421; LeNet_lr: 0.0007601333; steps/sec: 64.47;
FastEstimator-Train: step: 1875; epoch: 1; epoch_time: 29.07 sec;
FastEstimator-Train: step: 1900; ce: 0.025787469; LeNet_lr: 0.0007468; steps/sec: 49.44;
...
FastEstimator-Train: step: 3700; ce: 0.029427588; LeNet_lr: 0.0005068; steps/sec: 65.24;
FastEstimator-Train: step: 3750; epoch: 2; epoch_time: 29.08 sec;
FastEstimator-Finish: step: 3750; LeNet_lr: 0.0005001333; total_time: 58.17 sec;
visualize_logs(history2, include_metrics="LeNet_lr")
Using Built-In lr_schedule Function¶
Some learning rate schedules are widely popular in the deep learning community. We have implemented some of them in FastEstimator so that you don't need to write custom functions for them. We will showcase the cosine decay schedule below.
cosine_decay¶
We can specify the length of the decay cycle and the initial learning rate using cycle_length and init_lr respectively. As with custom learning rate schedules, the function passed as lr_fn should have 'step' or 'epoch' as its parameter.
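Within a cycle, cosine decay sweeps the learning rate from init_lr down to a small floor along half a cosine wave, then restarts at init_lr for the next cycle. Below is a minimal sketch of the idea, assuming a floor of min_lr=1e-6 (consistent with the final LeNet_lr reported in the log below); refer to fastestimator.schedule.cosine_decay for the actual implementation:

import math

def cosine_decay_sketch(step, cycle_length, init_lr, min_lr=1e-6):
    # Position within the current cycle, in [1, cycle_length] (steps start at 1)
    step_in_cycle = (step - 1) % cycle_length + 1
    # Half a cosine wave from ~init_lr (cycle start) down to min_lr (cycle end)
    return min_lr + (init_lr - min_lr) / 2 * (1 + math.cos(math.pi * step_in_cycle / cycle_length))

With cycle_length=1875 (one epoch at batch size 32), the decay restarts at every epoch boundary, producing the sawtooth pattern visible in the log and plot below. The FastEstimator cosine decay can be used as follows: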
from fastestimator.schedule import cosine_decay
pipeline, model, network = get_pipeline_model_network()
traces = LRScheduler(model=model, lr_fn=lambda step: cosine_decay(step, cycle_length=1875, init_lr=1e-3))
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=2, traces=traces)
history3 = estimator.fit(summary="Experiment_3")
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.297008; LeNet_lr: 0.0009999993;
FastEstimator-Train: step: 100; ce: 0.33457872; LeNet_lr: 0.000993005; steps/sec: 65.37;
...
FastEstimator-Train: step: 1800; ce: 0.03704284; LeNet_lr: 4.9387068e-06; steps/sec: 67.37;
FastEstimator-Train: step: 1875; epoch: 1; epoch_time: 29.06 sec;
FastEstimator-Train: step: 1900; ce: 0.017644638; LeNet_lr: 0.0009995619; steps/sec: 47.01;
...
FastEstimator-Train: step: 3700; ce: 0.0007971178; LeNet_lr: 2.7518167e-06; steps/sec: 52.77;
FastEstimator-Train: step: 3750; epoch: 2; epoch_time: 30.31 sec;
FastEstimator-Finish: step: 3750; LeNet_lr: 1e-06; total_time: 59.39 sec;
visualize_logs(history3, include_metrics="LeNet_lr")
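Since each fit call returned its own history object, we can also overlay the three experiments on a single plot to compare the schedules side by side (visualize_logs accepts a list of histories):

# Compare all three learning rate schedules in one figure
visualize_logs([history, history2, history3], include_metrics="LeNet_lr")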