classLeNet(torch.nn.Module):"""A standard LeNet implementation in pytorch. This class is intentionally not @traceable (models and layers are handled by a different process). The LeNet model has 3 convolution layers and 2 dense layers. Args: input_shape: The shape of the model input (channels, height, width). classes: The number of outputs the model should generate. Raises: ValueError: Length of `input_shape` is not 3. ValueError: `input_shape`[1] or `input_shape`[2] is smaller than 18. """def__init__(self,input_shape:Tuple[int,int,int]=(1,28,28),classes:int=10)->None:LeNet._check_input_shape(input_shape)super().__init__()conv_kernel=3self.pool_kernel=2self.conv1=nn.Conv2d(input_shape[0],32,conv_kernel)self.conv2=nn.Conv2d(32,64,conv_kernel)self.conv3=nn.Conv2d(64,64,conv_kernel)flat_x=((((input_shape[1]-(conv_kernel-1))//self.pool_kernel)-(conv_kernel-1))//self.pool_kernel)-(conv_kernel-1)flat_y=((((input_shape[2]-(conv_kernel-1))//self.pool_kernel)-(conv_kernel-1))//self.pool_kernel)-(conv_kernel-1)self.fc1=nn.Linear(flat_x*flat_y*64,64)self.fc2=nn.Linear(64,classes)defforward(self,x:torch.Tensor)->torch.Tensor:x=fn.relu(self.conv1(x))x=fn.max_pool2d(x,self.pool_kernel)x=fn.relu(self.conv2(x))x=fn.max_pool2d(x,self.pool_kernel)x=fn.relu(self.conv3(x))x=x.view(x.size(0),-1)x=fn.relu(self.fc1(x))x=fn.softmax(self.fc2(x),dim=-1)returnx@staticmethoddef_check_input_shape(input_shape):iflen(input_shape)!=3:raiseValueError("Length of `input_shape` is not 3 (channel, height, width)")_,height,width=input_shapeifheight<18orwidth<18:raiseValueError("Both height and width of input_shape need to not smaller than 18")