Bases: Module
A standard LeNet implementation in pytorch.
This class is intentionally not @traceable (models and layers are handled by a different process).
The LeNet model has 3 convolution layers and 2 dense layers.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_shape | Tuple[int, int, int] | The shape of the model input (channels, height, width). | (1, 28, 28) |
| classes | int | The number of outputs the model should generate. | 10 |
Raises:

| Type | Description |
|---|---|
| ValueError | Length of input_shape is not 3. |
| ValueError | input_shape[1] or input_shape[2] is smaller than 18. |
Source code in fastestimator/fastestimator/architecture/pytorch/lenet.py
class LeNet(torch.nn.Module):
    """A standard LeNet implementation in pytorch.

    This class is intentionally not @traceable (models and layers are handled by a different process).

    The LeNet model has 3 convolution layers and 2 dense layers.

    Args:
        input_shape: The shape of the model input (channels, height, width).
        classes: The number of outputs the model should generate.

    Raises:
        ValueError: Length of `input_shape` is not 3.
        ValueError: `input_shape`[1] or `input_shape`[2] is smaller than 18.
    """
    def __init__(self, input_shape: Tuple[int, int, int] = (1, 28, 28), classes: int = 10) -> None:
        # Validate before any module state is created so a bad shape fails fast.
        LeNet._check_input_shape(input_shape)
        super().__init__()
        conv_kernel = 3
        self.pool_kernel = 2
        self.conv1 = nn.Conv2d(input_shape[0], 32, conv_kernel)
        self.conv2 = nn.Conv2d(32, 64, conv_kernel)
        self.conv3 = nn.Conv2d(64, 64, conv_kernel)
        # Mirror the spatial arithmetic of `forward`: every convolution (default stride/padding) trims
        # `conv_kernel - 1` pixels per axis, and each of the two poolings floor-divides by `pool_kernel`.
        trim = conv_kernel - 1
        flat_x, flat_y = (((size - trim) // self.pool_kernel - trim) // self.pool_kernel - trim
                          for size in input_shape[1:])
        self.fc1 = nn.Linear(flat_x * flat_y * 64, 64)
        self.fc2 = nn.Linear(64, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of inputs.

        Args:
            x: The input batch. Its first dimension is treated as the batch dimension; the remaining dimensions are
                expected to match the `input_shape` the model was built with.

        Returns:
            The softmax of the final dense layer taken over the last dimension, with one row per batch element and
            one column per class.
        """
        x = fn.relu(self.conv1(x))
        x = fn.max_pool2d(x, self.pool_kernel)
        x = fn.relu(self.conv2(x))
        x = fn.max_pool2d(x, self.pool_kernel)
        x = fn.relu(self.conv3(x))
        # Collapse everything except the batch dimension. `torch.flatten` returns a view when it can and copies
        # otherwise, so it also handles activations whose strides make `Tensor.view` raise (e.g. channels-last).
        x = torch.flatten(x, start_dim=1)
        x = fn.relu(self.fc1(x))
        x = fn.softmax(self.fc2(x), dim=-1)
        return x

    @staticmethod
    def _check_input_shape(input_shape: Tuple[int, int, int]) -> None:
        """Verify that `input_shape` can be consumed by this architecture.

        Args:
            input_shape: The candidate model input shape (channels, height, width).

        Raises:
            ValueError: Length of `input_shape` is not 3.
            ValueError: The height or the width is smaller than 18.
        """
        if len(input_shape) != 3:
            raise ValueError("Length of `input_shape` is not 3 (channel, height, width)")
        _, height, width = input_shape
        # 18 is the smallest extent that still leaves at least one pixel after the three convolutions and two poolings.
        if height < 18 or width < 18:
            raise ValueError("Both height and width of input_shape need to not smaller than 18")
|