def UNet(input_size: Tuple[int, int, int] = (128, 128, 1), output_channel: int = 1) -> tf.keras.Model:
    """Build a four-level UNet with a sigmoid output head.

    Encoder: at each level, two 3x3 ReLU convolutions ('same' padding,
    He-normal init) followed by 2x2 max pooling. Filter counts are
    64, 128, 256 and 512. Dropout(0.5) is applied to the output of the
    deepest (512-filter) level before pooling.

    Bottleneck: two 3x3 convolutions with 1024 filters, then Dropout(0.5).

    Decoder: at each level, bilinear 2x upsampling followed by a 3x3
    convolution. The result is concatenated (channel axis) with the
    matching encoder output, then refined by two more 3x3 convolutions.

    Head: a final 1x1 convolution with sigmoid activation.

    Args:
        input_size: Shape of one input sample as (height, width, channels).
        output_channel: Number of channels in the sigmoid output map.

    Raises:
        ValueError: Length of `input_size` is not 3 (via `_check_input_size`).
        ValueError: `input_size`[0] or `input_size`[1] is not a multiple of 16
            (via `_check_input_size`).

    Returns:
        An uncompiled `Model` whose output has `output_channel` sigmoid channels.
    """
    _check_input_size(input_size)

    conv_kwargs = dict(activation='relu', padding='same', kernel_initializer='he_normal')
    upsample_kwargs = dict(size=(2, 2), interpolation='bilinear')

    def double_conv(tensor, filters):
        # Two stacked 3x3 convolutions sharing the same filter count.
        tensor = Conv2D(filters, 3, **conv_kwargs)(tensor)
        return Conv2D(filters, 3, **conv_kwargs)(tensor)

    inputs = Input(input_size)

    # Contracting path. Each level's pre-pooling output is kept for the
    # skip connection used by the mirrored decoder level.
    encoder_filters = (64, 128, 256, 512)
    skips = []
    x = inputs
    for depth, filters in enumerate(encoder_filters):
        x = double_conv(x, filters)
        if depth == len(encoder_filters) - 1:
            # Only the deepest encoder level gets dropout. The dropped-out
            # tensor (not the raw conv output) feeds both the pool and the skip.
            x = Dropout(0.5)(x)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck.
    x = Dropout(0.5)(double_conv(x, 1024))

    # Expanding path. Walk the levels deepest-first, pairing each with its skip.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        up = Conv2D(filters, 3, **conv_kwargs)(UpSampling2D(**upsample_kwargs)(x))
        x = concatenate([skip, up], axis=-1)
        x = double_conv(x, filters)

    outputs = Conv2D(output_channel, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)