Image Classification Using LeViT
[Paper] [Notebook] [Torch Implementation] [TensorFlow Implementation]
LeViT is a hybrid neural network designed for fast-inference image classification. It significantly outperforms existing convnets and vision transformers with respect to the speed/accuracy tradeoff. In this example, we will implement LeViT in PyTorch and show how to fine-tune an ImageNet pre-trained LeViT on a downstream task to get good results with minimal downstream training time.
Let's first import some necessary packages.
import numpy as np
import tempfile
import fastestimator as fe
from fastestimator.trace.io import BestModelSaver
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, Resize
from fastestimator.op.numpyop.univariate import ChannelTranspose, CoarseDropout, Normalize, Onehot
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.trace.adapt import LRScheduler
from fastestimator.trace.metric import Accuracy
from fastestimator.schedule import EpochScheduler, cosine_decay
Now let's define some parameters that will be used later:
batch_size=32
epochs=20
data_dir=None
train_steps_per_epoch=None
eval_steps_per_epoch=None
save_dir = tempfile.mkdtemp()
Let's build the basic building blocks for the LeViT model. We will define 3 LeViT variants (LeViT_128S, LeViT_256, LeViT_384), where LeViT_128S is the smallest and LeViT_384 is the largest.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
import torch
from torch.nn.init import trunc_normal_
specification = {
'LeViT_128S': {
'embed_dim': (128, 256, 384),
'key_dim': (16, 16, 16),
'num_heads': (4, 6, 8),
'depth': (2, 3, 4),
'drop_path': 0,
'weights': 'https://huggingface.co/facebook/levit-128S/resolve/main/pytorch_model.bin'
},
'LeViT_256': {
'embed_dim': (256, 384, 512),
'key_dim': (32, 32, 32),
'num_heads': (4, 6, 8),
'depth': (4, 4, 4),
'drop_path': 0,
'weights': 'https://huggingface.co/facebook/levit-256/resolve/main/pytorch_model.bin'
},
'LeViT_384': {
'embed_dim': (384, 512, 768),
'key_dim': (32, 32, 32),
'num_heads': (6, 9, 12),
'depth': (4, 4, 4),
'drop_path': 0.1,
'weights': 'https://huggingface.co/facebook/levit-384/resolve/main/pytorch_model.bin'
},
}
class ConvNorm(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
groups=1,
bn_weight_init=1):
super().__init__()
self.convolution = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias=False)
self.batch_norm = torch.nn.BatchNorm2d(out_channels)
torch.nn.init.constant_(self.batch_norm.weight, bn_weight_init)
def forward(self, x):
return self.batch_norm(self.convolution(x))
class Backbone(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
super().__init__()
self.convolution_layer1 = ConvNorm(in_channels,
out_channels // 8,
kernel_size=kernel_size,
stride=stride,
padding=padding)
self.activation_layer1 = torch.nn.Hardswish()
self.convolution_layer2 = ConvNorm(out_channels // 8,
out_channels // 4,
kernel_size=kernel_size,
stride=stride,
padding=padding)
self.activation_layer2 = torch.nn.Hardswish()
self.convolution_layer3 = ConvNorm(out_channels // 4,
out_channels // 2,
kernel_size=kernel_size,
stride=stride,
padding=padding)
self.activation_layer3 = torch.nn.Hardswish()
self.convolution_layer4 = ConvNorm(out_channels // 2,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
def forward(self, x):
x = self.activation_layer1(self.convolution_layer1(x))
x = self.activation_layer2(self.convolution_layer2(x))
x = self.activation_layer3(self.convolution_layer3(x))
return self.convolution_layer4(x)
class LinearNorm(torch.nn.Module):
def __init__(self, in_features, out_features, bn_weight_init=1):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=False)
self.batch_norm = torch.nn.BatchNorm1d(out_features)
torch.nn.init.constant_(self.batch_norm.weight, bn_weight_init)
def forward(self, x):
x = self.linear(x)
return self.batch_norm(x.flatten(0, 1)).reshape_as(x)
class Downsample(torch.nn.Module):
def __init__(self, stride, resolution, use_pool=False):
super().__init__()
self.stride = stride
self.resolution = resolution
self.pool = torch.nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None
def forward(self, x):
batch_size, _, channels = x.shape
x = x.view(batch_size, self.resolution, self.resolution, channels)
if self.pool is not None:
x = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
else:
x = x[:, ::self.stride, ::self.stride]
return x.reshape(batch_size, -1, channels)
class Residual(torch.nn.Module):
def __init__(self, module, drop_rate):
super().__init__()
self.module = module
self.drop_out = torch.nn.Dropout(p=drop_rate)
def forward(self, x):
if self.training:
return x + self.drop_out(self.module(x))
else:
return x + self.module(x)
class Attention(torch.nn.Module):
def __init__(self, dim, key_dim, num_attention_heads=8, attention_ratio=4, resolution=14):
super().__init__()
self.num_attention_heads = num_attention_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.attention_ratio = attention_ratio
self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
self.queries_keys_values = LinearNorm(dim, self.out_dim_keys_values)
self.activation = torch.nn.Hardswish()
self.projection = LinearNorm(self.out_dim_projection, dim, bn_weight_init=0)
points = list(itertools.product(range(resolution), range(resolution)))
len_points = len(points)
attention_offsets, indices = {}, []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
indices.append(attention_offsets[offset])
self.attention_bias_cache = {}
self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points))
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.attention_bias_cache:
self.attention_bias_cache = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, hidden_state):
batch_size, seq_length, _ = hidden_state.shape
queries_keys_values = self.queries_keys_values(hidden_state)
query, key, value = queries_keys_values.view(
batch_size, seq_length, self.num_attention_heads, -1).split([
self.key_dim, self.key_dim, self.attention_ratio * self.key_dim
],
dim=3)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
attention = attention.softmax(dim=-1)
hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
hidden_state = self.projection(self.activation(hidden_state))
return hidden_state
class AttentionDownsample(torch.nn.Module):
def __init__(
self,
input_dim,
output_dim,
key_dim,
num_attention_heads,
attention_ratio,
stride,
resolution_in,
resolution_out,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.attention_ratio = attention_ratio
self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
self.resolution_out = resolution_out
        # resolution_in is the initial resolution, resolution_out is the final resolution after downsampling
self.keys_values = LinearNorm(input_dim, self.out_dim_keys_values)
self.queries_subsample = Downsample(stride, resolution_in)
self.queries = LinearNorm(input_dim, key_dim * num_attention_heads)
self.activation = torch.nn.Hardswish()
self.projection = LinearNorm(self.out_dim_projection, output_dim)
self.attention_bias_cache = {}
points = list(itertools.product(range(resolution_in), range(resolution_in)))
points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
len_points, len_points_ = len(points), len(points_)
attention_offsets, indices = {}, []
for p1 in points_:
for p2 in points:
size = 1
offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
indices.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points))
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.attention_bias_cache:
self.attention_bias_cache = {} # clear ab cache
def get_attention_biases(self, device):
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, hidden_state):
batch_size, seq_length, _ = hidden_state.shape
key, value = (self.keys_values(hidden_state).view(
batch_size, seq_length, self.num_attention_heads,
-1).split([self.key_dim, self.attention_ratio * self.key_dim],
dim=3))
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
query = self.queries(self.queries_subsample(hidden_state))
query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads,
self.key_dim).permute(0, 2, 1, 3)
attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
attention = attention.softmax(dim=-1)
hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
hidden_state = self.projection(self.activation(hidden_state))
return hidden_state
class MLP(torch.nn.Module):
"""
MLP Layer with `2X` expansion in contrast to ViT with `4X`.
"""
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.linear_up = LinearNorm(input_dim, hidden_dim)
self.activation = torch.nn.Hardswish()
self.linear_down = LinearNorm(hidden_dim, input_dim)
def forward(self, hidden_state):
hidden_state = self.linear_up(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.linear_down(hidden_state)
return hidden_state
class NormLinear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.):
super().__init__()
self.batch_norm = torch.nn.BatchNorm1d(in_features)
self.drop = torch.nn.Dropout(drop)
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
trunc_normal_(self.linear.weight, std=std)
if self.linear.bias is not None:
torch.nn.init.constant_(self.linear.bias, 0)
def forward(self, x):
return self.linear(self.drop(self.batch_norm(x)))
class LeViT(torch.nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
embed_dim,
key_dim,
depth,
num_heads,
attention_ratio,
mlp_ratio,
down_ops,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
distillation=True,
drop_path=0):
super().__init__()
resolution = img_size // patch_size
self.stages = []
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
self.distillation = distillation
self.patch_embed = Backbone(in_chans, embed_dim[0])
down_ops.append([''])
for i, (ed, kd, dpth, nh, ar, mr,
do) in enumerate(zip(embed_dim, key_dim, depth, num_heads, attention_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.stages.append(
Residual(
Attention(
dim=ed,
key_dim=kd,
num_attention_heads=nh,
attention_ratio=ar,
resolution=resolution,
),
drop_path))
if mr > 0:
h = int(ed * mr)
self.stages.append(Residual(MLP(input_dim=ed, hidden_dim=h), drop_path))
if do[0] == 'Subsample':
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.stages.append(
AttentionDownsample(input_dim=embed_dim[i],
output_dim=embed_dim[i + 1],
key_dim=do[1],
num_attention_heads=do[2],
attention_ratio=do[3],
stride=do[5],
resolution_in=resolution,
resolution_out=resolution_))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.stages.append(Residual(MLP(input_dim=embed_dim[i + 1], hidden_dim=h), drop_path))
self.stages = torch.nn.Sequential(*self.stages)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
if self.distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
def forward(self, x):
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
x = self.stages(x)
x = x.mean(1)
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.head(x)
return x
def model_factory(embed_dim, key_dim, depth, num_heads, drop_path, weights, num_classes, distillation, pretrained):
model = LeViT(
patch_size=16,
embed_dim=embed_dim,
num_heads=num_heads,
key_dim=key_dim,
depth=depth,
attention_ratio=(2, 2, 2),
mlp_ratio=(2, 2, 2),
down_ops=[
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
['Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2],
['Subsample', key_dim[0], embed_dim[1] // key_dim[0], 4, 2, 2],
],
num_classes=num_classes,
drop_path=drop_path,
distillation=distillation)
if pretrained:
        # Since the cached file names are identical, running this script with a different model
        # (e.g. LeViT_256, LeViT_384) would throw an error: either clear the torch hub cache or
        # provide "model_dir" to load_state_dict_from_url.
checkpoint_dict = torch.hub.load_state_dict_from_url(weights, map_location='cpu')
model_dict = model.state_dict()
model_keys = list(model_dict.keys())
checkpoint_keys = list(checkpoint_dict.keys())
for i, _ in enumerate(model_keys):
if not (model_keys[i].startswith('head.linear') or model_keys[i].startswith('head_dist.linear')):
model_dict[model_keys[i]] = checkpoint_dict[checkpoint_keys[i]]
model.load_state_dict(model_dict)
return model
def LeViT_128S(num_classes=1000, distillation=False, pretrained=False):
return model_factory(**specification['LeViT_128S'],
num_classes=num_classes,
distillation=distillation,
pretrained=pretrained)
def LeViT_256(num_classes=1000, distillation=False, pretrained=False):
return model_factory(**specification['LeViT_256'],
num_classes=num_classes,
distillation=distillation,
pretrained=pretrained)
def LeViT_384(num_classes=1000, distillation=False, pretrained=False):
return model_factory(**specification['LeViT_384'],
num_classes=num_classes,
distillation=distillation,
pretrained=pretrained)
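Before wiring the model into FastEstimator, we can optionally run a quick smoke test of the architecture. The sketch below (not part of the original training flow) instantiates the smallest variant without pre-trained weights and checks the output shape on a random batch.
# Optional smoke test: random input, no pre-trained weights
levit = LeViT_128S(num_classes=10, pretrained=False)
levit.eval()  # use BatchNorm running statistics and skip the Residual dropout
with torch.no_grad():
    logits = levit(torch.randn(2, 3, 224, 224))  # (batch, channels, height, width)
print(logits.shape)  # expected: torch.Size([2, 10])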
Let's load the ciFAIR-10 dataset and define the data augmentation pipeline.
train_data, eval_data = cifair10.load_data(data_dir)
pipeline = fe.Pipeline(
train_data=train_data,
eval_data=eval_data,
batch_size=batch_size,
ops=[
Resize(image_in="x", image_out="x", height=224, width=224),
Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
ChannelTranspose(inputs="x", outputs="x"),
Onehot(inputs="y", outputs="y", mode="train", num_classes=10, label_smoothing=0.05)
])
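Before building the network, we can optionally pull a single batch through the pipeline to confirm the tensor shapes (this assumes Pipeline.get_results is available in your FastEstimator version).
# Optional shape check of one training batch
batch = pipeline.get_results(mode="train")
print(batch["x"].shape)  # expected: [32, 3, 224, 224] after Resize, Normalize and ChannelTranspose
print(batch["y"].shape)  # expected: [32, 10] after Onehot with label smoothing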
Let's define the FastEstimator model using LeViT-384. The downstream model reuses the LeViT encoder pre-trained on the ImageNet dataset.
model = fe.build(model_fn=lambda: LeViT_384(num_classes=10, pretrained=True), optimizer_fn="adam")
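As an optional check of the model size, we can count the trainable parameters; the LeViT paper reports LeViT-384 at roughly 39M parameters, and the exact count here also reflects the new, randomly initialized 10-class head.
# Optional: report the number of trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params / 1e6:.1f}M")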
Now that the model is defined, let's create some network ops for optimizing the model. We will be using CrossEntropy as our loss function.
network = fe.Network(ops=[
ModelOp(model=model, inputs="x", outputs="y_pred"),
CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
UpdateOp(model=model, loss_name="ce", mode="train")
])
Let's define some traces to save the model based on max accuracy and to schedule the learning rate. Since learning rate warmup can help transformer models optimize faster, we define a learning rate scheduler that slowly increases the learning rate for the first 3 epochs, followed by cosine decay of the learning rate for every epoch after that.
init_lr = 1e-2 / 64 * batch_size
def lr_schedule_warmup(step, train_steps_epoch, init_lr):
warmup_steps = train_steps_epoch * 3
if step < warmup_steps:
lr = init_lr / warmup_steps * step
else:
lr = init_lr
return lr
lr_schedule = {
1:
LRScheduler(
model=model,
lr_fn=lambda step: lr_schedule_warmup(
step, train_steps_epoch=np.ceil(len(train_data) / batch_size), init_lr=init_lr)),
4:
LRScheduler(
model=model,
lr_fn=lambda epoch: cosine_decay(
epoch, cycle_length=epochs - 3, init_lr=init_lr, min_lr=init_lr / 100, start=4))
}
traces = [
Accuracy(true_key="y", pred_key="y_pred"),
BestModelSaver(model=model, save_dir=save_dir, metric="accuracy", save_best_mode="max"),
EpochScheduler(lr_schedule)
]
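To make the warmup behaviour concrete, here is an optional check of the schedule endpoints; the first and last values below match the model_lr entries in the training log further down.
# Optional check of the learning rate schedule endpoints
steps_per_epoch = int(np.ceil(len(train_data) / batch_size))  # 1563 for ciFAIR-10 at batch_size=32
warmup_steps = 3 * steps_per_epoch
print(lr_schedule_warmup(1, train_steps_epoch=steps_per_epoch, init_lr=init_lr))             # ~1.07e-06 at the first step
print(lr_schedule_warmup(warmup_steps, train_steps_epoch=steps_per_epoch, init_lr=init_lr))  # init_lr (5e-03) at the end of warmup
print(cosine_decay(epochs, cycle_length=epochs - 3, init_lr=init_lr, min_lr=init_lr / 100, start=4))  # min_lr (5e-05) at the final epoch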
Finally, it's time for some actual training. To illustrate the effect of using the pre-trained encoder, we only train the downstream task for a limited time. So, let's define our Estimator and train for 20 epochs.
estimator = fe.Estimator(pipeline=pipeline,
network=network,
epochs=epochs,
traces=traces,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch,
log_steps=600)
estimator.fit(warmup=False)
FastEstimator-Warn: Expected PyTorch version 2.0.1 but found 2.0.0+cu118. The framework may not work as expected.
______ __ ______ __ _ __
/ ____/___ ______/ /_/ ____/____/ /_(_)___ ___ ____ _/ /_____ _____
/ /_ / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
/ __/ / /_/ (__ ) /_/ /___(__ ) /_/ / / / / / / /_/ / /_/ /_/ / /
/_/ \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/
FastEstimator-Start: step: 1; logging_interval: 600; num_device: 1;
FastEstimator-Train: step: 1; ce: 2.2783732; model_lr: 1.0663255e-06;
FastEstimator-Train: step: 600; ce: 0.69004774; model_lr: 0.00063979527; steps/sec: 15.51;
FastEstimator-Train: step: 1200; ce: 0.58633184; model_lr: 0.0012795905; steps/sec: 15.45;
FastEstimator-Train: step: 1563; epoch: 1; epoch_time(sec): 107.91;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 41.17;
Eval Progress: 208/313; steps/sec: 40.77;
Eval Progress: 313/313; steps/sec: 37.59;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 1563; epoch: 1; accuracy: 0.8829; ce: 0.3878323; max_accuracy: 0.8829; since_best_accuracy: 0;
FastEstimator-Train: step: 1800; ce: 0.77091324; model_lr: 0.0019193857; steps/sec: 13.14;
FastEstimator-Train: step: 2400; ce: 0.724447; model_lr: 0.002559181; steps/sec: 13.9;
FastEstimator-Train: step: 3000; ce: 0.587723; model_lr: 0.0031989764; steps/sec: 13.81;
FastEstimator-Train: step: 3126; epoch: 2; epoch_time(sec): 117.85;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 38.08;
Eval Progress: 208/313; steps/sec: 40.52;
Eval Progress: 313/313; steps/sec: 36.96;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 3126; epoch: 2; accuracy: 0.9156; ce: 0.28355816; max_accuracy: 0.9156; since_best_accuracy: 0;
FastEstimator-Train: step: 3600; ce: 0.46482837; model_lr: 0.0038387715; steps/sec: 12.14;
FastEstimator-Train: step: 4200; ce: 0.5003947; model_lr: 0.004478567; steps/sec: 13.75;
FastEstimator-Train: step: 4689; epoch: 3; epoch_time(sec): 117.97;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 40.45;
Eval Progress: 208/313; steps/sec: 42.25;
Eval Progress: 313/313; steps/sec: 37.54;
FastEstimator-Eval: step: 4689; epoch: 3; accuracy: 0.8961; ce: 0.35633534; max_accuracy: 0.9156; since_best_accuracy: 1;
FastEstimator-Train: step: 4800; ce: 0.54344803; model_lr: 0.0049578585; steps/sec: 12.73;
FastEstimator-Train: step: 5400; ce: 0.39572072; model_lr: 0.0049578585; steps/sec: 14.56;
FastEstimator-Train: step: 6000; ce: 0.39163283; model_lr: 0.0049578585; steps/sec: 14.61;
FastEstimator-Train: step: 6252; epoch: 4; epoch_time(sec): 112.53;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 40.48;
Eval Progress: 208/313; steps/sec: 42.51;
Eval Progress: 313/313; steps/sec: 39.6;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 6252; epoch: 4; accuracy: 0.9164; ce: 0.28092158; max_accuracy: 0.9164; since_best_accuracy: 0;
FastEstimator-Train: step: 6600; ce: 0.32644862; model_lr: 0.004832869; steps/sec: 12.85;
FastEstimator-Train: step: 7200; ce: 0.6165408; model_lr: 0.004832869; steps/sec: 14.61;
FastEstimator-Train: step: 7800; ce: 0.47856155; model_lr: 0.004832869; steps/sec: 14.52;
FastEstimator-Train: step: 7815; epoch: 5; epoch_time(sec): 112.57;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 43.0;
Eval Progress: 208/313; steps/sec: 42.07;
Eval Progress: 313/313; steps/sec: 39.3;
FastEstimator-Eval: step: 7815; epoch: 5; accuracy: 0.9079; ce: 0.30562693; max_accuracy: 0.9164; since_best_accuracy: 1;
FastEstimator-Train: step: 8400; ce: 0.66817826; model_lr: 0.0046292874; steps/sec: 12.82;
FastEstimator-Train: step: 9000; ce: 0.42860758; model_lr: 0.0046292874; steps/sec: 14.55;
FastEstimator-Train: step: 9378; epoch: 6; epoch_time(sec): 112.94;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 41.97;
Eval Progress: 208/313; steps/sec: 40.24;
Eval Progress: 313/313; steps/sec: 39.43;
FastEstimator-Eval: step: 9378; epoch: 6; accuracy: 0.9077; ce: 0.31278017; max_accuracy: 0.9164; since_best_accuracy: 2;
FastEstimator-Train: step: 9600; ce: 0.42492312; model_lr: 0.004354047; steps/sec: 12.91;
FastEstimator-Train: step: 10200; ce: 0.64375705; model_lr: 0.004354047; steps/sec: 14.68;
FastEstimator-Train: step: 10800; ce: 0.44151187; model_lr: 0.004354047; steps/sec: 14.7;
FastEstimator-Train: step: 10941; epoch: 7; epoch_time(sec): 112.01;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 43.45;
Eval Progress: 208/313; steps/sec: 43.06;
Eval Progress: 313/313; steps/sec: 39.2;
FastEstimator-Eval: step: 10941; epoch: 7; accuracy: 0.9008; ce: 0.32655126; max_accuracy: 0.9164; since_best_accuracy: 3;
FastEstimator-Train: step: 11400; ce: 0.42734885; model_lr: 0.004016521; steps/sec: 12.92;
FastEstimator-Train: step: 12000; ce: 0.36721689; model_lr: 0.004016521; steps/sec: 14.67;
FastEstimator-Train: step: 12504; epoch: 8; epoch_time(sec): 111.9;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 41.57;
Eval Progress: 208/313; steps/sec: 43.07;
Eval Progress: 313/313; steps/sec: 39.89;
FastEstimator-Eval: step: 12504; epoch: 8; accuracy: 0.888; ce: 0.3563145; max_accuracy: 0.9164; since_best_accuracy: 4;
FastEstimator-Train: step: 12600; ce: 0.5211024; model_lr: 0.0036282025; steps/sec: 12.96;
FastEstimator-Train: step: 13200; ce: 0.55599684; model_lr: 0.0036282025; steps/sec: 14.76;
FastEstimator-Train: step: 13800; ce: 0.38998234; model_lr: 0.0036282025; steps/sec: 14.72;
FastEstimator-Train: step: 14067; epoch: 9; epoch_time(sec): 111.9;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 39.02;
Eval Progress: 208/313; steps/sec: 41.62;
Eval Progress: 313/313; steps/sec: 37.75;
FastEstimator-Eval: step: 14067; epoch: 9; accuracy: 0.8861; ce: 0.36740774; max_accuracy: 0.9164; since_best_accuracy: 5;
FastEstimator-Train: step: 14400; ce: 0.59265655; model_lr: 0.003202316; steps/sec: 12.81;
FastEstimator-Train: step: 15000; ce: 0.6133712; model_lr: 0.003202316; steps/sec: 14.72;
FastEstimator-Train: step: 15600; ce: 0.5198127; model_lr: 0.003202316; steps/sec: 14.74;
FastEstimator-Train: step: 15630; epoch: 10; epoch_time(sec): 111.89;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 41.43;
Eval Progress: 208/313; steps/sec: 43.47;
Eval Progress: 313/313; steps/sec: 40.02;
FastEstimator-Eval: step: 15630; epoch: 10; accuracy: 0.9015; ce: 0.3149603; max_accuracy: 0.9164; since_best_accuracy: 6;
FastEstimator-Train: step: 16200; ce: 0.34475362; model_lr: 0.0027533642; steps/sec: 12.95;
FastEstimator-Train: step: 16800; ce: 0.37621388; model_lr: 0.0027533642; steps/sec: 14.68;
FastEstimator-Train: step: 17193; epoch: 11; epoch_time(sec): 112.13;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 40.54;
Eval Progress: 208/313; steps/sec: 42.77;
Eval Progress: 313/313; steps/sec: 39.55;
FastEstimator-Eval: step: 17193; epoch: 11; accuracy: 0.8998; ce: 0.3317904; max_accuracy: 0.9164; since_best_accuracy: 7;
FastEstimator-Train: step: 17400; ce: 0.32059616; model_lr: 0.0022966359; steps/sec: 12.95;
FastEstimator-Train: step: 18000; ce: 0.5899464; model_lr: 0.0022966359; steps/sec: 14.76;
FastEstimator-Train: step: 18600; ce: 0.34279728; model_lr: 0.0022966359; steps/sec: 14.78;
FastEstimator-Train: step: 18756; epoch: 12; epoch_time(sec): 111.15;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 39.45;
Eval Progress: 208/313; steps/sec: 42.5;
Eval Progress: 313/313; steps/sec: 39.38;
FastEstimator-Eval: step: 18756; epoch: 12; accuracy: 0.909; ce: 0.30617863; max_accuracy: 0.9164; since_best_accuracy: 8;
FastEstimator-Train: step: 19200; ce: 0.36847696; model_lr: 0.0018476842; steps/sec: 12.27;
FastEstimator-Train: step: 19800; ce: 0.33842516; model_lr: 0.0018476842; steps/sec: 13.7;
FastEstimator-Train: step: 20319; epoch: 13; epoch_time(sec): 119.7;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 41.31;
Eval Progress: 208/313; steps/sec: 41.69;
Eval Progress: 313/313; steps/sec: 39.67;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 20319; epoch: 13; accuracy: 0.92; ce: 0.27661368; max_accuracy: 0.92; since_best_accuracy: 0;
FastEstimator-Train: step: 20400; ce: 0.31592426; model_lr: 0.0014217976; steps/sec: 12.28;
FastEstimator-Train: step: 21000; ce: 0.33163732; model_lr: 0.0014217976; steps/sec: 13.82;
FastEstimator-Train: step: 21600; ce: 0.29629666; model_lr: 0.0014217976; steps/sec: 13.73;
FastEstimator-Train: step: 21882; epoch: 14; epoch_time(sec): 118.78;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 38.3;
Eval Progress: 208/313; steps/sec: 40.73;
Eval Progress: 313/313; steps/sec: 37.96;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 21882; epoch: 14; accuracy: 0.9266; ce: 0.26339; max_accuracy: 0.9266; since_best_accuracy: 0;
FastEstimator-Train: step: 22200; ce: 0.29879218; model_lr: 0.0010334792; steps/sec: 12.22;
FastEstimator-Train: step: 22800; ce: 0.2988974; model_lr: 0.0010334792; steps/sec: 13.71;
FastEstimator-Train: step: 23400; ce: 0.3045774; model_lr: 0.0010334792; steps/sec: 13.83;
FastEstimator-Train: step: 23445; epoch: 15; epoch_time(sec): 119.22;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 37.0;
Eval Progress: 208/313; steps/sec: 40.07;
Eval Progress: 313/313; steps/sec: 38.25;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 23445; epoch: 15; accuracy: 0.9331; ce: 0.2525196; max_accuracy: 0.9331; since_best_accuracy: 0;
FastEstimator-Train: step: 24000; ce: 0.2868572; model_lr: 0.0006959529; steps/sec: 12.21;
FastEstimator-Train: step: 24600; ce: 0.28900862; model_lr: 0.0006959529; steps/sec: 13.91;
FastEstimator-Train: step: 25008; epoch: 16; epoch_time(sec): 118.28;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 39.6;
Eval Progress: 208/313; steps/sec: 42.27;
Eval Progress: 313/313; steps/sec: 39.35;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 25008; epoch: 16; accuracy: 0.9404; ce: 0.23263757; max_accuracy: 0.9404; since_best_accuracy: 0;
FastEstimator-Train: step: 25200; ce: 0.28556466; model_lr: 0.00042071257; steps/sec: 12.59;
FastEstimator-Train: step: 25800; ce: 0.29081023; model_lr: 0.00042071257; steps/sec: 14.68;
FastEstimator-Train: step: 26400; ce: 0.39484438; model_lr: 0.00042071257; steps/sec: 14.6;
FastEstimator-Train: step: 26571; epoch: 17; epoch_time(sec): 112.03;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 42.82;
Eval Progress: 208/313; steps/sec: 40.72;
Eval Progress: 313/313; steps/sec: 39.3;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 26571; epoch: 17; accuracy: 0.9449; ce: 0.2264233; max_accuracy: 0.9449; since_best_accuracy: 0;
FastEstimator-Train: step: 27000; ce: 0.2842825; model_lr: 0.00021713124; steps/sec: 13.0;
FastEstimator-Train: step: 27600; ce: 0.28541481; model_lr: 0.00021713124; steps/sec: 14.68;
FastEstimator-Train: step: 28134; epoch: 18; epoch_time(sec): 111.57;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 42.22;
Eval Progress: 208/313; steps/sec: 42.56;
Eval Progress: 313/313; steps/sec: 38.89;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 28134; epoch: 18; accuracy: 0.9484; ce: 0.21373338; max_accuracy: 0.9484; since_best_accuracy: 0;
FastEstimator-Train: step: 28200; ce: 0.28641877; model_lr: 9.214158e-05; steps/sec: 13.02;
FastEstimator-Train: step: 28800; ce: 0.2853601; model_lr: 9.214158e-05; steps/sec: 14.66;
FastEstimator-Train: step: 29400; ce: 0.29596728; model_lr: 9.214158e-05; steps/sec: 14.66;
FastEstimator-Train: step: 29697; epoch: 19; epoch_time(sec): 111.86;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 42.4;
Eval Progress: 208/313; steps/sec: 42.63;
Eval Progress: 313/313; steps/sec: 39.49;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 29697; epoch: 19; accuracy: 0.9506; ce: 0.2108158; max_accuracy: 0.9506; since_best_accuracy: 0;
FastEstimator-Train: step: 30000; ce: 0.2867692; model_lr: 5e-05; steps/sec: 13.01;
FastEstimator-Train: step: 30600; ce: 0.28615165; model_lr: 5e-05; steps/sec: 14.63;
FastEstimator-Train: step: 31200; ce: 0.28824845; model_lr: 5e-05; steps/sec: 14.71;
FastEstimator-Train: step: 31260; epoch: 20; epoch_time(sec): 111.83;
Eval Progress: 1/313;
Eval Progress: 104/313; steps/sec: 39.09;
Eval Progress: 208/313; steps/sec: 38.57;
Eval Progress: 313/313; steps/sec: 39.78;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp9ozjqvv_/model_best_accuracy.pt
FastEstimator-Eval: step: 31260; epoch: 20; accuracy: 0.9524; ce: 0.20477045; max_accuracy: 0.9524; since_best_accuracy: 0;
FastEstimator-Finish: step: 31260; model_lr: 5e-05; total_time(sec): 2536.45;
We are getting 95.24% accuracy, which is not a bad result for 20 epochs.
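To reuse the best checkpoint afterwards, one option is the short sketch below. It assumes fe.backend.load_model accepts a weights path, that Pipeline.get_results is available, and that BestModelSaver wrote model_best_accuracy.pt into save_dir, as shown in the logs above.
import os

# Reload the best weights saved by BestModelSaver (file name taken from the logs above)
fe.backend.load_model(model, weights_path=os.path.join(save_dir, "model_best_accuracy.pt"))

# Classify one evaluation batch
model.eval()
batch = pipeline.get_results(mode="eval")
device = next(model.parameters()).device
with torch.no_grad():
    y_pred = model(batch["x"].to(device))
print(torch.argmax(y_pred, dim=-1)[:10])  # predicted class indices for the first 10 eval images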