This is the final project of the CSE204 Machine Learning course at Ecole Polytechnique. We use a CNN and an EfficientNet model to perform OCR on single handwritten Chinese characters. The project is mainly based on the TensorFlow 2 API.
After making sure that all the packages in the requirements.txt
file are installed, we can start the process of OCRing the Chinese text.
# import necessary packages
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from alfred.utils.log import logger as logging
from tensorflow.keras.applications import EfficientNetB0
import pickle
import random
# load character mapping
# NOTE: "__file__" is a string literal here, not the module attribute, so
# abspath() resolves it against the current working directory and this_dir
# ends up being the working directory.
this_dir = os.path.dirname(os.path.abspath("__file__"))


def load_characters():
    """
    Read the character table used to map class indices to characters.

    The table is read from ``dataset/characters.txt`` under ``this_dir``.
    The path is built from separate components, so the same code works on
    Windows, macOS and Linux without editing the separator.

    NOTE(review): the file is assumed to be UTF-8 encoded -- confirm against
    the dataset that is actually shipped.

    :return: list of lines with surrounding whitespace stripped, in file order
    """
    path = os.path.join(this_dir, 'dataset', 'characters.txt')
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    # The context manager closes the handle, which the bare open() never did.
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
# Load the label table once; its length is the number of output classes.
all_characters = load_characters()
num_classes = len(all_characters)
print('There are {} classes of chinese characters in the dataset'.format(num_classes))
logging.info(f'all characters: {num_classes}')
18:10:37 05.30 INFO 1114677896.py:12]: all characters: 3755
There are 3755 classes of chinese characters in the dataset
We want to load the dataset from local files. The dataset is a set of tfrecord
files which contain the images, labels, and resolution information for the images. The reason for saving in tfrecord
format is to improve the training performance.
For more information about the dataset, please read the file README.md,
Part 3. Dataset Preparation. We also provide the converted dataset in the following link. You can directly use it for training.
# Necessary parameters
IMG_SIZE = 80 # NOTE(review): reassigned to 224 further down before any visible use; the value 80 does not appear to be read anywhere in this part
# decode tfrecord to images and labels
def parse_image(record):
    """
    Decode one serialized example into a float image and its label.

    :param record: one serialized example, i.e. a single element of the TFRecord dataset
    :return: dict with 'image' (2-D float32 tensor) and 'label' (int64 scalar)
    """
    scalar_int = tf.io.FixedLenFeature([], tf.int64)
    schema = {
        'width': scalar_int,
        'height': scalar_int,
        'label': scalar_int,
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record, features=schema)
    # The image is stored as raw bytes and decoded as one uint8 per element.
    pixels = tf.io.decode_raw(parsed['image'], out_type=tf.uint8)
    # NOTE(review): the buffer is reshaped as (width, height); confirm this
    # matches the order used when the records were written.
    grid = tf.reshape(pixels, (parsed['width'], parsed['height']))
    return {
        'image': tf.cast(grid, dtype=tf.float32),
        'label': tf.cast(parsed['label'], tf.int64),
    }
def load_datasets(filename):
    """
    Open a tfrecord file as a dataset of decoded examples.

    :param filename: path of the tfrecord file
    :return: dataset whose elements are the dicts produced by parse_image
    """
    return tf.data.TFRecordDataset([filename]).map(parse_image)
In this part, we can see what the datasets look like.
# Unbatched datasets of decoded examples, one per split.
train_ds = load_datasets('dataset/train.tfrecord')
test_ds = load_datasets('dataset/test.tfrecord')
# Batched, endlessly repeating views of both splits.
# NOTE(review): images are not resized here, so iterating these batched
# datasets may fail if the images in a batch differ in size.
train_mapped = train_ds.shuffle(100)
train_mapped = train_mapped.batch(32).repeat()
test_mapped = test_ds.batch(32)
test_mapped = test_mapped.repeat()
# Bare expression: displays the dataset description as the notebook cell output.
train_mapped
<RepeatDataset element_spec={'image': TensorSpec(shape=(None, None, None), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>
# Summary information of the datasets: collect both image dimensions of every
# sample. Index 0 holds the first dimension (plotted as width), index 1 the
# second (plotted as height).
shapes_plot_train = [[], []]
shapes_plot_test = [[], []]
for dataset, collected in ((train_ds, shapes_plot_train), (test_ds, shapes_plot_test)):
    # One full pass over the split; every record is decoded once.
    for sample in dataset:
        dims = sample['image'].shape
        collected[0].append(dims[0])
        collected[1].append(dims[1])
# show the number of pictures in the dataset
n_train = len(shapes_plot_train[0])
n_test = len(shapes_plot_test[0])
print(f'There are {n_train} pictures in the training dataset')
print(f'There are {n_test} pictures in the testing dataset')
There are 897758 pictures in the training dataset There are 223991 pictures in the testing dataset
def plot_info_dataset(shapes_plot):
    """
    Plot the image-size statistics of one dataset split.

    Draws three panels side by side: a histogram of the widths, a histogram
    of the heights, and a width-vs-height scatter plot.

    :param shapes_plot: pair of sequences; index 0 is plotted as width, index 1 as height
    """
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    # Both histograms get the same axis labels. The original labelled ax[0]
    # twice and left the height histogram without labels.
    for panel, values, title in ((ax[0], shapes_plot[0], 'width'),
                                 (ax[1], shapes_plot[1], 'height')):
        panel.hist(values, bins=100)
        panel.set_title(title)
        panel.set_xlabel('pixels')
        panel.set_ylabel('frequency')
    ax[2].scatter(shapes_plot[0], shapes_plot[1])
    ax[2].set_title('width vs height')
    ax[2].set_xlabel('width (pixels)')
    ax[2].set_ylabel('height (pixels)')
    plt.show()
# Show the size statistics of each split under a banner line.
for split_name, split_shapes in (('training', shapes_plot_train), ('testing', shapes_plot_test)):
    print(f'------------------- Information of {split_name} dataset -------------------')
    plot_info_dataset(split_shapes)
------------------- Information of training dataset -------------------
------------------- Information of testing dataset -------------------
We can also plot some images and labels from the datasets.
# default_font = FontProperties()
def plot_img(ds, characters, num=9):
    """
    Show the first ``num`` samples of a dataset together with their labels.

    Samples are laid out three per row. The number of rows grows with
    ``num``, so values above 9 no longer index past a fixed 3x3 grid.

    :param ds: dataset whose elements are dicts with 'image' and 'label'
    :param characters: sequence mapping a label index to its character
    :param num: number of samples to show (default 9, i.e. a 3x3 grid)
    """
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # Set font to avoid garbled characters
    plt.rcParams["axes.unicode_minus"] = False  # Companion setting to the font change above
    n_cols = 3
    n_rows = max(1, -(-num // n_cols))  # ceiling division, at least one row
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    cells = axes.ravel()
    shown = 0
    for cell, data in zip(cells, ds.take(num)):
        img = data['image'].numpy()
        label = data['label'].numpy()
        cell.imshow(img, cmap='gray')
        cell.set_title(f'label: {characters[label]}, size: {img.shape}', fontsize=15)
        shown += 1
    # Hide the grid cells that received no sample.
    for cell in cells[shown:]:
        cell.axis('off')
    plt.show()
# Preview the first nine samples of the (unshuffled) training set with their character labels.
plot_img(train_ds, all_characters, num=9)
Important, Please Read!!! If you are using Mac OS, you may find that the labels are not displayed correctly. This is because the package Matplotlib
is not compatible with SimHei
font. To solve this problem, you can read this MacOS Chinese Characters Label Problem.
In this part, we will define the models and training functions.
In the first
part, we will start from a simple CNN model.
# some basic parameters
TARGET_SIZE = 64 # side length (pixels) that preprocess() resizes images to for the CNN models
IMG_SIZE = 224 # side length used by the EfficientNetB0 pipeline further down; not read by the simple CNN code
ckpt_path = './checkpoints/simple_net/cn_ocr-{epoch}.ckpt' # '{epoch}' is a format placeholder filled in when a checkpoint is written
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'
# image preprocessing
def preprocess(ds):
    """
    Turn one parsed example into a normalized (image, label) pair for the CNN models.

    The input dict is left untouched; the original version overwrote
    ds['image'] in place.

    :param ds: a single example dict with 'image' and 'label' (despite the name, not a whole dataset)
    :return: image resized to (TARGET_SIZE, TARGET_SIZE, 1) and scaled as (x - 128) / 128, and the label
    """
    # tf.image.resize needs a channel axis, so add one to the 2-D image.
    image = tf.expand_dims(ds['image'], axis=-1)
    image = tf.image.resize(image, (TARGET_SIZE, TARGET_SIZE))
    image = (image - 128.) / 128.
    return image, ds['label']
# history log callback function
#
# This part is used to collecting the training statics and save them into a file.
#
class SaveHistoryCallback(tf.keras.callbacks.Callback):
    """
    Collect the per-epoch training statistics and pickle them when training ends.

    If a history file already exists at ``history_path`` it is loaded first,
    so that new epochs are appended to the earlier ones.
    """

    def __init__(self, history_path):
        """
        :param history_path: path of the pickle file holding the history dict
        """
        super().__init__()
        self.history_path = history_path
        self.history = {}
        if os.path.exists(self.history_path):
            # NOTE: unpickling can execute arbitrary code; only point this at
            # history files written by this notebook.
            with open(self.history_path, 'rb') as f:
                self.history = pickle.load(f)

    def on_epoch_end(self, epoch, logs=None):
        """Append every metric reported for this epoch to its history list."""
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)

    def on_train_end(self, logs=None):
        """Write the accumulated history to ``history_path``."""
        with open(self.history_path, 'wb') as f:
            pickle.dump(self.history, f)
# Define a simple CNN model
def simple_net(input_shape, n_classes):
    """
    Build the small CNN: two convolution/max-pool stages and a softmax classifier.

    :param input_shape: shape of one input image, e.g. (TARGET_SIZE, TARGET_SIZE, 1)
    :param n_classes: number of output classes
    :return: an uncompiled Sequential model
    """
    net = tf.keras.Sequential()
    net.add(layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3),
                          strides=(1, 1), padding='same', activation='relu'))
    net.add(layers.MaxPool2D(pool_size=(2, 2), padding='same'))
    # NOTE(review): this convolution has no activation, unlike the first one --
    # confirm that is intended.
    net.add(layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'))
    net.add(layers.MaxPool2D(pool_size=(2, 2), padding='same'))
    net.add(layers.Flatten())
    # layers.Dense(1024, activation='relu'),
    net.add(layers.Dense(n_classes, activation='softmax'))
    return net
# Define with a more complex CNN model - However, this model is not even converging
def CNN_ComplexModel_1(input_shape, n_classes):
    """
    Build the deeper CNN: four convolution/max-pool stages, a 1024-unit dense layer and a softmax classifier.

    :param input_shape: shape of one input image, e.g. (TARGET_SIZE, TARGET_SIZE, 1)
    :param n_classes: number of output classes
    :return: an uncompiled Sequential model
    """
    net = tf.keras.Sequential()
    net.add(layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3),
                          strides=(1, 1), padding='same', activation='relu'))
    net.add(layers.MaxPool2D(pool_size=(2, 2), padding='same'))
    # The remaining stages only differ in their filter count.
    # NOTE(review): these convolutions have no activation -- confirm that is
    # intended; it may be related to the convergence problem mentioned above.
    for n_filters in (128, 256, 512):
        net.add(layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding='same'))
        net.add(layers.MaxPool2D(pool_size=(2, 2), padding='same'))
    net.add(layers.Flatten())
    net.add(layers.Dense(1024, activation='relu'))
    net.add(layers.Dense(n_classes, activation='softmax'))
    return net
# Build the simple CNN for single-channel TARGET_SIZE x TARGET_SIZE inputs and print its layer table.
model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), len(all_characters))
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 64, 64, 32) 320 max_pooling2d (MaxPooling2D (None, 32, 32, 32) 0 ) conv2d_1 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_1 (MaxPooling (None, 16, 16, 64) 0 2D) flatten (Flatten) (None, 16384) 0 dense (Dense) (None, 3755) 61525675 ================================================================= Total params: 61,544,491 Trainable params: 61,544,491 Non-trainable params: 0 _________________________________________________________________
# Build the deeper CNN with the same input shape and print its layer table for comparison.
model_cnn_complex = CNN_ComplexModel_1((TARGET_SIZE, TARGET_SIZE, 1), len(all_characters))
model_cnn_complex.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_2 (Conv2D) (None, 64, 64, 64) 640 max_pooling2d_2 (MaxPooling (None, 32, 32, 64) 0 2D) conv2d_3 (Conv2D) (None, 32, 32, 128) 73856 max_pooling2d_3 (MaxPooling (None, 16, 16, 128) 0 2D) conv2d_4 (Conv2D) (None, 16, 16, 256) 295168 max_pooling2d_4 (MaxPooling (None, 8, 8, 256) 0 2D) conv2d_5 (Conv2D) (None, 8, 8, 512) 1180160 max_pooling2d_5 (MaxPooling (None, 4, 4, 512) 0 2D) flatten_1 (Flatten) (None, 8192) 0 dense_1 (Dense) (None, 1024) 8389632 dense_2 (Dense) (None, 3755) 3848875 ================================================================= Total params: 13,788,331 Trainable params: 13,788,331 Non-trainable params: 0 _________________________________________________________________
In this part, we will train the model.
def train_simple(epochs=100, steps_per_epoch=1024, validation_steps=1000, batch_size=32):
    """
    Train the simple CNN, resuming from the latest checkpoint when one exists.

    The per-epoch statistics are collected by SaveHistoryCallback and written
    to 'history_simple.pickle'. Checkpoints are written to ``ckpt_path`` only
    for epochs that improve the monitored validation loss.

    The defaults reproduce the previously hard-coded settings.

    :param epochs: number of epochs to run
    :param steps_per_epoch: training batches per epoch
    :param validation_steps: validation batches evaluated after each epoch
    :param batch_size: batch size for both the training and the validation pipeline
    :return: the model, either after the full run or as it was when the run was interrupted
    """
    print(f'number of classes: {num_classes}')
    history_path = 'history_simple.pickle'
    save_history_callback = SaveHistoryCallback(history_path)

    # NOTE(review): a shuffle buffer of 100 mixes samples only locally; if the
    # records are stored grouped by class, batches will be strongly
    # correlated -- confirm the record order.
    train_dataset = load_datasets(train_path)
    test_dataset = load_datasets(test_path)
    train_dataset = train_dataset.map(preprocess).shuffle(100).batch(batch_size).repeat()
    test_dataset = test_dataset.shuffle(100).map(preprocess).batch(batch_size)
    print(f'train dataset: {train_dataset}')

    # build model
    model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), num_classes)
    model.summary()

    # latest checkpoints
    ckpt_dir = os.path.dirname(ckpt_path)
    latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if latest_ckpt:
        print(f'model resumed from: {latest_ckpt}')
        model.load_weights(latest_ckpt)
    else:
        print('training from scratch')

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                           save_weights_only=True,
                                           verbose=1,
                                           save_freq='epoch',
                                           save_best_only=True),
        save_history_callback
    ]
    try:
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            validation_steps=validation_steps,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks)
    except KeyboardInterrupt:
        # fit() was cut short, so flush the collected history explicitly.
        save_history_callback.on_train_end()
        # The old message claimed the model had been saved; nothing is saved
        # on this path apart from the checkpoints already written.
        logging.info('Training interrupted (KeyboardInterrupt). History saved; '
                     'final weights not saved, existing checkpoints kept.')
        return model
    # Full run finished: store the final weights and the whole model.
    model.save_weights(ckpt_path.format(epoch=0))
    model.save(os.path.join(ckpt_dir, 'cn_ocr_simple.h5'))
    logging.info('All epoch finished. keras model saved.')
    return model
From our training results, of the two CNN models defined above, the simpler one performs better. The more complex CNN model does not even converge, which is quite an interesting result.
# train the model (computation heavy); train_simple() resumes from the latest checkpoint when one exists
model = train_simple()
number of classes: 3755 train dataset: <RepeatDataset element_spec=(TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))> Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_6 (Conv2D) (None, 64, 64, 32) 320 max_pooling2d_6 (MaxPooling (None, 32, 32, 32) 0 2D) conv2d_7 (Conv2D) (None, 32, 32, 64) 18496 max_pooling2d_7 (MaxPooling (None, 16, 16, 64) 0 2D) flatten_2 (Flatten) (None, 16384) 0 dense_3 (Dense) (None, 3755) 61525675 ================================================================= Total params: 61,544,491 Trainable params: 61,544,491 Non-trainable params: 0 _________________________________________________________________ model resumed from: ./checkpoints/simple_net\cn_ocr-1.ckpt Epoch 1/100 4/1024 [..............................] - ETA: 20s - loss: 0.0467 - accuracy: 0.9766 WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0068s vs `on_train_batch_end` time: 0.0108s). Check your callbacks. 1022/1024 [============================>.] - ETA: 0s - loss: 0.0242 - accuracy: 0.9927 Epoch 1: val_loss improved from inf to 4.37403, saving model to ./checkpoints/simple_net\cn_ocr-1.ckpt 1024/1024 [==============================] - 20s 16ms/step - loss: 0.0242 - accuracy: 0.9927 - val_loss: 4.3740 - val_accuracy: 0.4745 Epoch 2/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.1772 - accuracy: 0.9523 Epoch 2: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 15ms/step - loss: 0.1772 - accuracy: 0.9523 - val_loss: 5.0706 - val_accuracy: 0.3959 Epoch 3/100 1021/1024 [============================>.] 
- ETA: 0s - loss: 0.1651 - accuracy: 0.9527 Epoch 3: val_loss did not improve from 4.37403 1024/1024 [==============================] - 14s 14ms/step - loss: 0.1652 - accuracy: 0.9528 - val_loss: 5.4529 - val_accuracy: 0.4083 Epoch 4/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.1733 - accuracy: 0.9509 Epoch 4: val_loss did not improve from 4.37403 1024/1024 [==============================] - 14s 14ms/step - loss: 0.1732 - accuracy: 0.9509 - val_loss: 4.7948 - val_accuracy: 0.4309 Epoch 5/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.1912 - accuracy: 0.9467 Epoch 5: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 14ms/step - loss: 0.1910 - accuracy: 0.9467 - val_loss: 4.8460 - val_accuracy: 0.4083 Epoch 6/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.2448 - accuracy: 0.9330 Epoch 6: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 14ms/step - loss: 0.2444 - accuracy: 0.9331 - val_loss: 5.6845 - val_accuracy: 0.3722 Epoch 7/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.2070 - accuracy: 0.9426 Epoch 7: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 14ms/step - loss: 0.2068 - accuracy: 0.9426 - val_loss: 5.0307 - val_accuracy: 0.3918 Epoch 8/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.2148 - accuracy: 0.9402 Epoch 8: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 14ms/step - loss: 0.2151 - accuracy: 0.9401 - val_loss: 5.2210 - val_accuracy: 0.4089 Epoch 9/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.2385 - accuracy: 0.9332 Epoch 9: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 14ms/step - loss: 0.2385 - accuracy: 0.9332 - val_loss: 4.9118 - val_accuracy: 0.4042 Epoch 10/100 1020/1024 [============================>.] 
- ETA: 0s - loss: 0.2330 - accuracy: 0.9342 Epoch 10: val_loss did not improve from 4.37403 1024/1024 [==============================] - 15s 14ms/step - loss: 0.2326 - accuracy: 0.9343 - val_loss: 6.6977 - val_accuracy: 0.3193 Epoch 11/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.1729 - accuracy: 0.9498 Epoch 11: val_loss did not improve from 4.37403 1024/1024 [==============================] - 14s 14ms/step - loss: 0.1730 - accuracy: 0.9497 - val_loss: 4.9370 - val_accuracy: 0.4512 Epoch 12/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.2010 - accuracy: 0.9420 Epoch 12: val_loss did not improve from 4.37403 1024/1024 [==============================] - 14s 14ms/step - loss: 0.2009 - accuracy: 0.9420 - val_loss: 5.2270 - val_accuracy: 0.4319 Epoch 13/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.1908 - accuracy: 0.9446 Epoch 13: val_loss did not improve from 4.37403 1024/1024 [==============================] - 14s 14ms/step - loss: 0.1910 - accuracy: 0.9445 - val_loss: 4.6441 - val_accuracy: 0.4633 Epoch 14/100 1021/1024 [============================>.] - ETA: 0s - loss: 0.9264 - accuracy: 0.7941
18:20:18 05.30 INFO 3396354300.py:49]: keras model saved. KeyboardInterrupt
# Read back the statistics written by SaveHistoryCallback.
with open('history_simple.pickle', 'rb') as f:
    history = pickle.load(f)

# Plot accuracy
plt.figure(figsize=(10, 7))
for series in ('accuracy', 'val_accuracy'):
    plt.plot(history[series])
plt.title('Model accuracy - Simple CNN', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
arrow_style = {'facecolor': 'black', 'shrink': 0.03}
# Mark the highest validation accuracy; the arrow points at the epoch where it occurred.
max_acc = max(history['val_accuracy'])
best_val_epoch = np.argmax(history['val_accuracy'])
plt.annotate(f'max accuracy: {max_acc:.4f}',
             xy=(best_val_epoch, max_acc),
             xytext=(best_val_epoch - 10, max_acc + 0.1),
             arrowprops=arrow_style, fontsize=15)
# Same for the highest training accuracy.
max_acc_train = max(history['accuracy'])
best_train_epoch = np.argmax(history['accuracy'])
plt.annotate(f'max accuracy: {max_acc_train:.4f}',
             xy=(best_train_epoch, max_acc_train),
             xytext=(best_train_epoch - 30, max_acc_train - 0.1),
             arrowprops=arrow_style, fontsize=15)
plt.show()

# Plot loss
plt.figure(figsize=(10, 7))
for series in ('loss', 'val_loss'):
    plt.plot(history[series])
plt.title('Model loss - Simple CNN', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Rebuild the architecture and restore the most recently written checkpoint.
# NOTE(review): this is the *latest* checkpoint in the directory, which is not
# necessarily the best one -- confirm which file the checkpoint index points at.
model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), num_classes)
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
if latest_ckpt is None:
    # latest_checkpoint() returns None when the directory holds no checkpoint;
    # fail with a clear message instead of passing None to load_weights().
    raise FileNotFoundError(
        f'no checkpoint found in {os.path.dirname(ckpt_path)}; train the model first')
model.load_weights(latest_ckpt)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x25053771510>
# display the result
import random
def display_result(model, ds, characters):
    """
    Show up to nine samples with the predicted and the true character.

    One batch of 32 preprocessed samples is drawn from ``ds`` and the first
    nine (or fewer, if the batch is smaller) are plotted in a 3x3 grid.

    :param model: trained model that accepts the output of preprocess()
    :param ds: dataset of example dicts with 'image' and 'label'
    :param characters: sequence mapping a label index to its character
    """
    # Only one batch is consumed, so the pipeline does not need .repeat().
    batches = ds.shuffle(100).map(preprocess).batch(32)
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # Set font to avoid garbled characters
    plt.rcParams["axes.unicode_minus"] = False  # Companion setting to the font change above
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    for images, labels in batches.take(1):
        # Convert once to plain NumPy so indices and the comparison below are
        # ordinary Python/NumPy values rather than tensors.
        predicted = np.argmax(model.predict(images), axis=1)
        truth = labels.numpy()
        pixels = images.numpy()
        # zip() stops at the shortest input, so a batch with fewer than nine
        # samples no longer raises an IndexError.
        for cell, image, y_pred, y_true in zip(axes.ravel(), pixels, predicted, truth):
            # Drop the channel axis instead of hard-coding a 64x64 reshape.
            cell.imshow(image[..., 0], cmap='gray')
            cell.axis('on')
            cell.set_title(f'pred: {characters[y_pred]}, label: {characters[y_true]}, '
                           f'Is Correct: {y_pred == y_true}')
    plt.show()
# Reload the test split and show the model's predictions for nine test samples.
test_ds = load_datasets(test_path)
display_result(model, test_ds, all_characters)
1/1 [==============================] - 0s 65ms/step
This model only reaches a validation accuracy
of around 40%
, which is not satisfactory. We decided to train a far more sophisticated model called EfficientNet
instead. It is a state-of-the-art model for image classification. We will use the B0
version of the model. For more information about the model, please read the paper and this webpage.
# some basic parameters
IMG_SIZE = 224 # side length the EfficientNetB0 pipeline resizes images to; matches the 224x224x3 input shown in the model summary
ckpt_path = './checkpoints/efficient_net/cn_ocr-{epoch}.ckpt' # NOTE: rebinds the global that train_simple() also reads; from here on it points at the EfficientNet checkpoints
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'
def preprocess_efficientnet(ds):
    """
    Turn one parsed example into an (image, label) pair for EfficientNetB0.

    The grayscale image gets a channel axis, is expanded to three channels
    and is resized to (IMG_SIZE, IMG_SIZE). The input dict is left untouched;
    the original version overwrote ds['image'] in place.

    No normalization is needed here for EfficientNetB0.

    :param ds: a single example dict with 'image' and 'label' (despite the name, not a whole dataset)
    :return: image of shape (IMG_SIZE, IMG_SIZE, 3) and the label
    """
    image = tf.expand_dims(ds['image'], axis=-1)  # grayscale_to_rgb needs a trailing channel axis of size 1
    image = tf.image.grayscale_to_rgb(image)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, ds['label']
# model
def effcientnetB0_model():
    """
    Build the classifier: an ImageNet-initialized EfficientNetB0 backbone without its top,
    followed by global average pooling and a softmax layer over ``num_classes`` classes.

    :return: an uncompiled keras.Model taking (IMG_SIZE, IMG_SIZE, 3) images
    """
    # EfficientNetB0 expects 3 channels
    image_in = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    backbone = EfficientNetB0(include_top=False, input_tensor=image_in, weights='imagenet')
    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    probabilities = layers.Dense(num_classes, activation='softmax')(pooled)
    return keras.Model(inputs=image_in, outputs=probabilities)
# model summary of EfficientNetB0: build the classifier and print its layer table
model = effcientnetB0_model()
model.summary()
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 224, 224, 3 0 [] )] rescaling (Rescaling) (None, 224, 224, 3) 0 ['input_1[0][0]'] normalization (Normalization) (None, 224, 224, 3) 7 ['rescaling[0][0]'] rescaling_1 (Rescaling) (None, 224, 224, 3) 0 ['normalization[0][0]'] stem_conv_pad (ZeroPadding2D) (None, 225, 225, 3) 0 ['rescaling_1[0][0]'] stem_conv (Conv2D) (None, 112, 112, 32 864 ['stem_conv_pad[0][0]'] ) stem_bn (BatchNormalization) (None, 112, 112, 32 128 ['stem_conv[0][0]'] ) stem_activation (Activation) (None, 112, 112, 32 0 ['stem_bn[0][0]'] ) block1a_dwconv (DepthwiseConv2 (None, 112, 112, 32 288 ['stem_activation[0][0]'] D) ) block1a_bn (BatchNormalization (None, 112, 112, 32 128 ['block1a_dwconv[0][0]'] ) ) block1a_activation (Activation (None, 112, 112, 32 0 ['block1a_bn[0][0]'] ) ) block1a_se_squeeze (GlobalAver (None, 32) 0 ['block1a_activation[0][0]'] agePooling2D) block1a_se_reshape (Reshape) (None, 1, 1, 32) 0 ['block1a_se_squeeze[0][0]'] block1a_se_reduce (Conv2D) (None, 1, 1, 8) 264 ['block1a_se_reshape[0][0]'] block1a_se_expand (Conv2D) (None, 1, 1, 32) 288 ['block1a_se_reduce[0][0]'] block1a_se_excite (Multiply) (None, 112, 112, 32 0 ['block1a_activation[0][0]', ) 'block1a_se_expand[0][0]'] block1a_project_conv (Conv2D) (None, 112, 112, 16 512 ['block1a_se_excite[0][0]'] ) block1a_project_bn (BatchNorma (None, 112, 112, 16 64 ['block1a_project_conv[0][0]'] lization) ) block2a_expand_conv (Conv2D) (None, 112, 112, 96 1536 ['block1a_project_bn[0][0]'] ) block2a_expand_bn (BatchNormal (None, 112, 112, 96 384 ['block2a_expand_conv[0][0]'] ization) ) block2a_expand_activation (Act (None, 112, 112, 96 0 ['block2a_expand_bn[0][0]'] ivation) ) block2a_dwconv_pad (ZeroPaddin (None, 113, 113, 96 0 
['block2a_expand_activation[0][0] g2D) ) '] block2a_dwconv (DepthwiseConv2 (None, 56, 56, 96) 864 ['block2a_dwconv_pad[0][0]'] D) block2a_bn (BatchNormalization (None, 56, 56, 96) 384 ['block2a_dwconv[0][0]'] ) block2a_activation (Activation (None, 56, 56, 96) 0 ['block2a_bn[0][0]'] ) block2a_se_squeeze (GlobalAver (None, 96) 0 ['block2a_activation[0][0]'] agePooling2D) block2a_se_reshape (Reshape) (None, 1, 1, 96) 0 ['block2a_se_squeeze[0][0]'] block2a_se_reduce (Conv2D) (None, 1, 1, 4) 388 ['block2a_se_reshape[0][0]'] block2a_se_expand (Conv2D) (None, 1, 1, 96) 480 ['block2a_se_reduce[0][0]'] block2a_se_excite (Multiply) (None, 56, 56, 96) 0 ['block2a_activation[0][0]', 'block2a_se_expand[0][0]'] block2a_project_conv (Conv2D) (None, 56, 56, 24) 2304 ['block2a_se_excite[0][0]'] block2a_project_bn (BatchNorma (None, 56, 56, 24) 96 ['block2a_project_conv[0][0]'] lization) block2b_expand_conv (Conv2D) (None, 56, 56, 144) 3456 ['block2a_project_bn[0][0]'] block2b_expand_bn (BatchNormal (None, 56, 56, 144) 576 ['block2b_expand_conv[0][0]'] ization) block2b_expand_activation (Act (None, 56, 56, 144) 0 ['block2b_expand_bn[0][0]'] ivation) block2b_dwconv (DepthwiseConv2 (None, 56, 56, 144) 1296 ['block2b_expand_activation[0][0] D) '] block2b_bn (BatchNormalization (None, 56, 56, 144) 576 ['block2b_dwconv[0][0]'] ) block2b_activation (Activation (None, 56, 56, 144) 0 ['block2b_bn[0][0]'] ) block2b_se_squeeze (GlobalAver (None, 144) 0 ['block2b_activation[0][0]'] agePooling2D) block2b_se_reshape (Reshape) (None, 1, 1, 144) 0 ['block2b_se_squeeze[0][0]'] block2b_se_reduce (Conv2D) (None, 1, 1, 6) 870 ['block2b_se_reshape[0][0]'] block2b_se_expand (Conv2D) (None, 1, 1, 144) 1008 ['block2b_se_reduce[0][0]'] block2b_se_excite (Multiply) (None, 56, 56, 144) 0 ['block2b_activation[0][0]', 'block2b_se_expand[0][0]'] block2b_project_conv (Conv2D) (None, 56, 56, 24) 3456 ['block2b_se_excite[0][0]'] block2b_project_bn (BatchNorma (None, 56, 56, 24) 96 ['block2b_project_conv[0][0]'] 
lization) block2b_drop (Dropout) (None, 56, 56, 24) 0 ['block2b_project_bn[0][0]'] block2b_add (Add) (None, 56, 56, 24) 0 ['block2b_drop[0][0]', 'block2a_project_bn[0][0]'] block3a_expand_conv (Conv2D) (None, 56, 56, 144) 3456 ['block2b_add[0][0]'] block3a_expand_bn (BatchNormal (None, 56, 56, 144) 576 ['block3a_expand_conv[0][0]'] ization) block3a_expand_activation (Act (None, 56, 56, 144) 0 ['block3a_expand_bn[0][0]'] ivation) block3a_dwconv_pad (ZeroPaddin (None, 59, 59, 144) 0 ['block3a_expand_activation[0][0] g2D) '] block3a_dwconv (DepthwiseConv2 (None, 28, 28, 144) 3600 ['block3a_dwconv_pad[0][0]'] D) block3a_bn (BatchNormalization (None, 28, 28, 144) 576 ['block3a_dwconv[0][0]'] ) block3a_activation (Activation (None, 28, 28, 144) 0 ['block3a_bn[0][0]'] ) block3a_se_squeeze (GlobalAver (None, 144) 0 ['block3a_activation[0][0]'] agePooling2D) block3a_se_reshape (Reshape) (None, 1, 1, 144) 0 ['block3a_se_squeeze[0][0]'] block3a_se_reduce (Conv2D) (None, 1, 1, 6) 870 ['block3a_se_reshape[0][0]'] block3a_se_expand (Conv2D) (None, 1, 1, 144) 1008 ['block3a_se_reduce[0][0]'] block3a_se_excite (Multiply) (None, 28, 28, 144) 0 ['block3a_activation[0][0]', 'block3a_se_expand[0][0]'] block3a_project_conv (Conv2D) (None, 28, 28, 40) 5760 ['block3a_se_excite[0][0]'] block3a_project_bn (BatchNorma (None, 28, 28, 40) 160 ['block3a_project_conv[0][0]'] lization) block3b_expand_conv (Conv2D) (None, 28, 28, 240) 9600 ['block3a_project_bn[0][0]'] block3b_expand_bn (BatchNormal (None, 28, 28, 240) 960 ['block3b_expand_conv[0][0]'] ization) block3b_expand_activation (Act (None, 28, 28, 240) 0 ['block3b_expand_bn[0][0]'] ivation) block3b_dwconv (DepthwiseConv2 (None, 28, 28, 240) 6000 ['block3b_expand_activation[0][0] D) '] block3b_bn (BatchNormalization (None, 28, 28, 240) 960 ['block3b_dwconv[0][0]'] ) block3b_activation (Activation (None, 28, 28, 240) 0 ['block3b_bn[0][0]'] ) block3b_se_squeeze (GlobalAver (None, 240) 0 ['block3b_activation[0][0]'] agePooling2D) 
block3b_se_reshape (Reshape) (None, 1, 1, 240) 0 ['block3b_se_squeeze[0][0]'] block3b_se_reduce (Conv2D) (None, 1, 1, 10) 2410 ['block3b_se_reshape[0][0]'] block3b_se_expand (Conv2D) (None, 1, 1, 240) 2640 ['block3b_se_reduce[0][0]'] block3b_se_excite (Multiply) (None, 28, 28, 240) 0 ['block3b_activation[0][0]', 'block3b_se_expand[0][0]'] block3b_project_conv (Conv2D) (None, 28, 28, 40) 9600 ['block3b_se_excite[0][0]'] block3b_project_bn (BatchNorma (None, 28, 28, 40) 160 ['block3b_project_conv[0][0]'] lization) block3b_drop (Dropout) (None, 28, 28, 40) 0 ['block3b_project_bn[0][0]'] block3b_add (Add) (None, 28, 28, 40) 0 ['block3b_drop[0][0]', 'block3a_project_bn[0][0]'] block4a_expand_conv (Conv2D) (None, 28, 28, 240) 9600 ['block3b_add[0][0]'] block4a_expand_bn (BatchNormal (None, 28, 28, 240) 960 ['block4a_expand_conv[0][0]'] ization) block4a_expand_activation (Act (None, 28, 28, 240) 0 ['block4a_expand_bn[0][0]'] ivation) block4a_dwconv_pad (ZeroPaddin (None, 29, 29, 240) 0 ['block4a_expand_activation[0][0] g2D) '] block4a_dwconv (DepthwiseConv2 (None, 14, 14, 240) 2160 ['block4a_dwconv_pad[0][0]'] D) block4a_bn (BatchNormalization (None, 14, 14, 240) 960 ['block4a_dwconv[0][0]'] ) block4a_activation (Activation (None, 14, 14, 240) 0 ['block4a_bn[0][0]'] ) block4a_se_squeeze (GlobalAver (None, 240) 0 ['block4a_activation[0][0]'] agePooling2D) block4a_se_reshape (Reshape) (None, 1, 1, 240) 0 ['block4a_se_squeeze[0][0]'] block4a_se_reduce (Conv2D) (None, 1, 1, 10) 2410 ['block4a_se_reshape[0][0]'] block4a_se_expand (Conv2D) (None, 1, 1, 240) 2640 ['block4a_se_reduce[0][0]'] block4a_se_excite (Multiply) (None, 14, 14, 240) 0 ['block4a_activation[0][0]', 'block4a_se_expand[0][0]'] block4a_project_conv (Conv2D) (None, 14, 14, 80) 19200 ['block4a_se_excite[0][0]'] block4a_project_bn (BatchNorma (None, 14, 14, 80) 320 ['block4a_project_conv[0][0]'] lization) block4b_expand_conv (Conv2D) (None, 14, 14, 480) 38400 ['block4a_project_bn[0][0]'] block4b_expand_bn 
(BatchNormal (None, 14, 14, 480) 1920 ['block4b_expand_conv[0][0]'] ization) block4b_expand_activation (Act (None, 14, 14, 480) 0 ['block4b_expand_bn[0][0]'] ivation) block4b_dwconv (DepthwiseConv2 (None, 14, 14, 480) 4320 ['block4b_expand_activation[0][0] D) '] block4b_bn (BatchNormalization (None, 14, 14, 480) 1920 ['block4b_dwconv[0][0]'] ) block4b_activation (Activation (None, 14, 14, 480) 0 ['block4b_bn[0][0]'] ) block4b_se_squeeze (GlobalAver (None, 480) 0 ['block4b_activation[0][0]'] agePooling2D) block4b_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block4b_se_squeeze[0][0]'] block4b_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block4b_se_reshape[0][0]'] block4b_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block4b_se_reduce[0][0]'] block4b_se_excite (Multiply) (None, 14, 14, 480) 0 ['block4b_activation[0][0]', 'block4b_se_expand[0][0]'] block4b_project_conv (Conv2D) (None, 14, 14, 80) 38400 ['block4b_se_excite[0][0]'] block4b_project_bn (BatchNorma (None, 14, 14, 80) 320 ['block4b_project_conv[0][0]'] lization) block4b_drop (Dropout) (None, 14, 14, 80) 0 ['block4b_project_bn[0][0]'] block4b_add (Add) (None, 14, 14, 80) 0 ['block4b_drop[0][0]', 'block4a_project_bn[0][0]'] block4c_expand_conv (Conv2D) (None, 14, 14, 480) 38400 ['block4b_add[0][0]'] block4c_expand_bn (BatchNormal (None, 14, 14, 480) 1920 ['block4c_expand_conv[0][0]'] ization) block4c_expand_activation (Act (None, 14, 14, 480) 0 ['block4c_expand_bn[0][0]'] ivation) block4c_dwconv (DepthwiseConv2 (None, 14, 14, 480) 4320 ['block4c_expand_activation[0][0] D) '] block4c_bn (BatchNormalization (None, 14, 14, 480) 1920 ['block4c_dwconv[0][0]'] ) block4c_activation (Activation (None, 14, 14, 480) 0 ['block4c_bn[0][0]'] ) block4c_se_squeeze (GlobalAver (None, 480) 0 ['block4c_activation[0][0]'] agePooling2D) block4c_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block4c_se_squeeze[0][0]'] block4c_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block4c_se_reshape[0][0]'] block4c_se_expand (Conv2D) (None, 1, 1, 480) 
10080 ['block4c_se_reduce[0][0]'] block4c_se_excite (Multiply) (None, 14, 14, 480) 0 ['block4c_activation[0][0]', 'block4c_se_expand[0][0]'] block4c_project_conv (Conv2D) (None, 14, 14, 80) 38400 ['block4c_se_excite[0][0]'] block4c_project_bn (BatchNorma (None, 14, 14, 80) 320 ['block4c_project_conv[0][0]'] lization) block4c_drop (Dropout) (None, 14, 14, 80) 0 ['block4c_project_bn[0][0]'] block4c_add (Add) (None, 14, 14, 80) 0 ['block4c_drop[0][0]', 'block4b_add[0][0]'] block5a_expand_conv (Conv2D) (None, 14, 14, 480) 38400 ['block4c_add[0][0]'] block5a_expand_bn (BatchNormal (None, 14, 14, 480) 1920 ['block5a_expand_conv[0][0]'] ization) block5a_expand_activation (Act (None, 14, 14, 480) 0 ['block5a_expand_bn[0][0]'] ivation) block5a_dwconv (DepthwiseConv2 (None, 14, 14, 480) 12000 ['block5a_expand_activation[0][0] D) '] block5a_bn (BatchNormalization (None, 14, 14, 480) 1920 ['block5a_dwconv[0][0]'] ) block5a_activation (Activation (None, 14, 14, 480) 0 ['block5a_bn[0][0]'] ) block5a_se_squeeze (GlobalAver (None, 480) 0 ['block5a_activation[0][0]'] agePooling2D) block5a_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block5a_se_squeeze[0][0]'] block5a_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block5a_se_reshape[0][0]'] block5a_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block5a_se_reduce[0][0]'] block5a_se_excite (Multiply) (None, 14, 14, 480) 0 ['block5a_activation[0][0]', 'block5a_se_expand[0][0]'] block5a_project_conv (Conv2D) (None, 14, 14, 112) 53760 ['block5a_se_excite[0][0]'] block5a_project_bn (BatchNorma (None, 14, 14, 112) 448 ['block5a_project_conv[0][0]'] lization) block5b_expand_conv (Conv2D) (None, 14, 14, 672) 75264 ['block5a_project_bn[0][0]'] block5b_expand_bn (BatchNormal (None, 14, 14, 672) 2688 ['block5b_expand_conv[0][0]'] ization) block5b_expand_activation (Act (None, 14, 14, 672) 0 ['block5b_expand_bn[0][0]'] ivation) block5b_dwconv (DepthwiseConv2 (None, 14, 14, 672) 16800 ['block5b_expand_activation[0][0] D) '] block5b_bn (BatchNormalization 
(None, 14, 14, 672) 2688 ['block5b_dwconv[0][0]'] ) block5b_activation (Activation (None, 14, 14, 672) 0 ['block5b_bn[0][0]'] ) block5b_se_squeeze (GlobalAver (None, 672) 0 ['block5b_activation[0][0]'] agePooling2D) block5b_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5b_se_squeeze[0][0]'] block5b_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5b_se_reshape[0][0]'] block5b_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5b_se_reduce[0][0]'] block5b_se_excite (Multiply) (None, 14, 14, 672) 0 ['block5b_activation[0][0]', 'block5b_se_expand[0][0]'] block5b_project_conv (Conv2D) (None, 14, 14, 112) 75264 ['block5b_se_excite[0][0]'] block5b_project_bn (BatchNorma (None, 14, 14, 112) 448 ['block5b_project_conv[0][0]'] lization) block5b_drop (Dropout) (None, 14, 14, 112) 0 ['block5b_project_bn[0][0]'] block5b_add (Add) (None, 14, 14, 112) 0 ['block5b_drop[0][0]', 'block5a_project_bn[0][0]'] block5c_expand_conv (Conv2D) (None, 14, 14, 672) 75264 ['block5b_add[0][0]'] block5c_expand_bn (BatchNormal (None, 14, 14, 672) 2688 ['block5c_expand_conv[0][0]'] ization) block5c_expand_activation (Act (None, 14, 14, 672) 0 ['block5c_expand_bn[0][0]'] ivation) block5c_dwconv (DepthwiseConv2 (None, 14, 14, 672) 16800 ['block5c_expand_activation[0][0] D) '] block5c_bn (BatchNormalization (None, 14, 14, 672) 2688 ['block5c_dwconv[0][0]'] ) block5c_activation (Activation (None, 14, 14, 672) 0 ['block5c_bn[0][0]'] ) block5c_se_squeeze (GlobalAver (None, 672) 0 ['block5c_activation[0][0]'] agePooling2D) block5c_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5c_se_squeeze[0][0]'] block5c_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5c_se_reshape[0][0]'] block5c_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5c_se_reduce[0][0]'] block5c_se_excite (Multiply) (None, 14, 14, 672) 0 ['block5c_activation[0][0]', 'block5c_se_expand[0][0]'] block5c_project_conv (Conv2D) (None, 14, 14, 112) 75264 ['block5c_se_excite[0][0]'] block5c_project_bn (BatchNorma (None, 14, 14, 112) 448 
['block5c_project_conv[0][0]'] lization) block5c_drop (Dropout) (None, 14, 14, 112) 0 ['block5c_project_bn[0][0]'] block5c_add (Add) (None, 14, 14, 112) 0 ['block5c_drop[0][0]', 'block5b_add[0][0]'] block6a_expand_conv (Conv2D) (None, 14, 14, 672) 75264 ['block5c_add[0][0]'] block6a_expand_bn (BatchNormal (None, 14, 14, 672) 2688 ['block6a_expand_conv[0][0]'] ization) block6a_expand_activation (Act (None, 14, 14, 672) 0 ['block6a_expand_bn[0][0]'] ivation) block6a_dwconv_pad (ZeroPaddin (None, 17, 17, 672) 0 ['block6a_expand_activation[0][0] g2D) '] block6a_dwconv (DepthwiseConv2 (None, 7, 7, 672) 16800 ['block6a_dwconv_pad[0][0]'] D) block6a_bn (BatchNormalization (None, 7, 7, 672) 2688 ['block6a_dwconv[0][0]'] ) block6a_activation (Activation (None, 7, 7, 672) 0 ['block6a_bn[0][0]'] ) block6a_se_squeeze (GlobalAver (None, 672) 0 ['block6a_activation[0][0]'] agePooling2D) block6a_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block6a_se_squeeze[0][0]'] block6a_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block6a_se_reshape[0][0]'] block6a_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block6a_se_reduce[0][0]'] block6a_se_excite (Multiply) (None, 7, 7, 672) 0 ['block6a_activation[0][0]', 'block6a_se_expand[0][0]'] block6a_project_conv (Conv2D) (None, 7, 7, 192) 129024 ['block6a_se_excite[0][0]'] block6a_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6a_project_conv[0][0]'] lization) block6b_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6a_project_bn[0][0]'] block6b_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block6b_expand_conv[0][0]'] ization) block6b_expand_activation (Act (None, 7, 7, 1152) 0 ['block6b_expand_bn[0][0]'] ivation) block6b_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 28800 ['block6b_expand_activation[0][0] D) '] block6b_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block6b_dwconv[0][0]'] ) block6b_activation (Activation (None, 7, 7, 1152) 0 ['block6b_bn[0][0]'] ) block6b_se_squeeze (GlobalAver (None, 1152) 0 ['block6b_activation[0][0]'] 
agePooling2D) block6b_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block6b_se_squeeze[0][0]'] block6b_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block6b_se_reshape[0][0]'] block6b_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block6b_se_reduce[0][0]'] block6b_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block6b_activation[0][0]', 'block6b_se_expand[0][0]'] block6b_project_conv (Conv2D) (None, 7, 7, 192) 221184 ['block6b_se_excite[0][0]'] block6b_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6b_project_conv[0][0]'] lization) block6b_drop (Dropout) (None, 7, 7, 192) 0 ['block6b_project_bn[0][0]'] block6b_add (Add) (None, 7, 7, 192) 0 ['block6b_drop[0][0]', 'block6a_project_bn[0][0]'] block6c_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6b_add[0][0]'] block6c_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block6c_expand_conv[0][0]'] ization) block6c_expand_activation (Act (None, 7, 7, 1152) 0 ['block6c_expand_bn[0][0]'] ivation) block6c_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 28800 ['block6c_expand_activation[0][0] D) '] block6c_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block6c_dwconv[0][0]'] ) block6c_activation (Activation (None, 7, 7, 1152) 0 ['block6c_bn[0][0]'] ) block6c_se_squeeze (GlobalAver (None, 1152) 0 ['block6c_activation[0][0]'] agePooling2D) block6c_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block6c_se_squeeze[0][0]'] block6c_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block6c_se_reshape[0][0]'] block6c_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block6c_se_reduce[0][0]'] block6c_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block6c_activation[0][0]', 'block6c_se_expand[0][0]'] block6c_project_conv (Conv2D) (None, 7, 7, 192) 221184 ['block6c_se_excite[0][0]'] block6c_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6c_project_conv[0][0]'] lization) block6c_drop (Dropout) (None, 7, 7, 192) 0 ['block6c_project_bn[0][0]'] block6c_add (Add) (None, 7, 7, 192) 0 ['block6c_drop[0][0]', 'block6b_add[0][0]'] block6d_expand_conv 
(Conv2D) (None, 7, 7, 1152) 221184 ['block6c_add[0][0]'] block6d_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block6d_expand_conv[0][0]'] ization) block6d_expand_activation (Act (None, 7, 7, 1152) 0 ['block6d_expand_bn[0][0]'] ivation) block6d_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 28800 ['block6d_expand_activation[0][0] D) '] block6d_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block6d_dwconv[0][0]'] ) block6d_activation (Activation (None, 7, 7, 1152) 0 ['block6d_bn[0][0]'] ) block6d_se_squeeze (GlobalAver (None, 1152) 0 ['block6d_activation[0][0]'] agePooling2D) block6d_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block6d_se_squeeze[0][0]'] block6d_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block6d_se_reshape[0][0]'] block6d_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block6d_se_reduce[0][0]'] block6d_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block6d_activation[0][0]', 'block6d_se_expand[0][0]'] block6d_project_conv (Conv2D) (None, 7, 7, 192) 221184 ['block6d_se_excite[0][0]'] block6d_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6d_project_conv[0][0]'] lization) block6d_drop (Dropout) (None, 7, 7, 192) 0 ['block6d_project_bn[0][0]'] block6d_add (Add) (None, 7, 7, 192) 0 ['block6d_drop[0][0]', 'block6c_add[0][0]'] block7a_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6d_add[0][0]'] block7a_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block7a_expand_conv[0][0]'] ization) block7a_expand_activation (Act (None, 7, 7, 1152) 0 ['block7a_expand_bn[0][0]'] ivation) block7a_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 10368 ['block7a_expand_activation[0][0] D) '] block7a_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block7a_dwconv[0][0]'] ) block7a_activation (Activation (None, 7, 7, 1152) 0 ['block7a_bn[0][0]'] ) block7a_se_squeeze (GlobalAver (None, 1152) 0 ['block7a_activation[0][0]'] agePooling2D) block7a_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block7a_se_squeeze[0][0]'] block7a_se_reduce (Conv2D) (None, 1, 1, 48) 55344 
['block7a_se_reshape[0][0]'] block7a_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block7a_se_reduce[0][0]'] block7a_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block7a_activation[0][0]', 'block7a_se_expand[0][0]'] block7a_project_conv (Conv2D) (None, 7, 7, 320) 368640 ['block7a_se_excite[0][0]'] block7a_project_bn (BatchNorma (None, 7, 7, 320) 1280 ['block7a_project_conv[0][0]'] lization) top_conv (Conv2D) (None, 7, 7, 1280) 409600 ['block7a_project_bn[0][0]'] top_bn (BatchNormalization) (None, 7, 7, 1280) 5120 ['top_conv[0][0]'] top_activation (Activation) (None, 7, 7, 1280) 0 ['top_bn[0][0]'] global_average_pooling2d (Glob (None, 1280) 0 ['top_activation[0][0]'] alAveragePooling2D) dense_5 (Dense) (None, 3755) 4810155 ['global_average_pooling2d[0][0]' ] ================================================================================================== Total params: 8,859,726 Trainable params: 8,817,703 Non-trainable params: 42,023 __________________________________________________________________________________________________
def train_efficientnet(*, epochs=100, steps_per_epoch=1024,
                       validation_steps=1000, batch_size=32):
    """
    Train the EfficientNetB0 model, resuming from the latest checkpoint if any.

    The training and test records are loaded with ``load_datasets``, passed
    through ``preprocess_efficientnet``, shuffled, batched and repeated.
    The model is compiled with Adam and sparse categorical cross-entropy.
    After each epoch ``ModelCheckpoint`` writes the weights when the monitored
    value improved (``save_best_only=True``).

    :param epochs: number of epochs passed to ``model.fit``
    :param steps_per_epoch: training batches per epoch
    :param validation_steps: validation batches evaluated after each epoch
    :param batch_size: number of samples per batch for both datasets
    :return: the model; after a KeyboardInterrupt it is returned as-is,
        without the final weight / ``.h5`` save
    """
    print(f'number of classes: {num_classes}')
    history_path = 'history_efficientnet.pickle'
    save_history_callback = SaveHistoryCallback(history_path)

    train_dataset = load_datasets(train_path)
    test_dataset = load_datasets(test_path)
    # Both datasets repeat forever; epoch length is set by the step counts.
    train_dataset = (train_dataset.map(preprocess_efficientnet)
                     .shuffle(100).batch(batch_size).repeat())
    test_dataset = (test_dataset.shuffle(100).map(preprocess_efficientnet)
                    .batch(batch_size).repeat())
    print(f'train dataset: {train_dataset}')

    # build model
    model = effcientnetB0_model()
    # model.summary() # too long to display

    # Resume from the latest checkpoint stored next to ckpt_path, if present
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        print(f'model resumed from: {latest_ckpt}')
        model.load_weights(latest_ckpt)
    else:
        print('training from scratch')

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                           save_weights_only=True,
                                           verbose=1,
                                           save_freq='epoch',
                                           save_best_only=True),
        save_history_callback,
    ]

    try:
        # use_multiprocessing is intentionally not passed: Keras documents it
        # for generator / Sequence inputs only, and the input here is a
        # tf.data dataset.
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            validation_steps=validation_steps,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks)
    except KeyboardInterrupt:
        # Nothing is saved in this branch, so the log must not claim otherwise.
        logging.info('KeyboardInterrupt: training stopped early, '
                     'final model was not saved.')
        # fit() was aborted, so trigger the history callback's
        # end-of-training hook by hand.
        save_history_callback.on_train_end()
        return model

    # Training ran to completion: store the final weights and the full model
    model.save_weights(ckpt_path.format(epoch=0))
    model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr_eff.h5'))
    logging.info('All epoch finished. keras model saved.')
    return model
# Train the model (computation heavy); the returned Keras model is kept in
# `model` for the following cells.
model = train_efficientnet()
number of classes: 3755 train dataset: <RepeatDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))> model resumed from: ./checkpoints/efficient_net\cn_ocr-1.ckpt Epoch 1/100 6/1024 [..............................] - ETA: 2:09 - loss: 0.0037 - accuracy: 1.0000WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0558s vs `on_train_batch_end` time: 0.0590s). Check your callbacks. 1024/1024 [==============================] - ETA: 0s - loss: 0.0130 - accuracy: 0.9972 Epoch 1: val_loss improved from inf to 0.54211, saving model to ./checkpoints/efficient_net\cn_ocr-1.ckpt 1024/1024 [==============================] - 163s 152ms/step - loss: 0.0130 - accuracy: 0.9972 - val_loss: 0.5421 - val_accuracy: 0.8860 Epoch 2/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1858 - accuracy: 0.9505 Epoch 2: val_loss did not improve from 0.54211 1024/1024 [==============================] - 155s 152ms/step - loss: 0.1858 - accuracy: 0.9505 - val_loss: 0.7122 - val_accuracy: 0.8370 Epoch 3/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1169 - accuracy: 0.9684 Epoch 3: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.1169 - accuracy: 0.9684 - val_loss: 0.5773 - val_accuracy: 0.8697 Epoch 4/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1093 - accuracy: 0.9713 Epoch 4: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.1093 - accuracy: 0.9713 - val_loss: 0.5893 - val_accuracy: 0.8614 Epoch 5/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1342 - accuracy: 0.9635 Epoch 5: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.1342 - accuracy: 0.9635 - val_loss: 0.5836 - val_accuracy: 0.8711 Epoch 6/100 
1024/1024 [==============================] - ETA: 0s - loss: 0.2706 - accuracy: 0.9327 Epoch 6: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.2706 - accuracy: 0.9327 - val_loss: 0.6092 - val_accuracy: 0.8580 Epoch 7/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1408 - accuracy: 0.9673 Epoch 7: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 151ms/step - loss: 0.1408 - accuracy: 0.9673 - val_loss: 0.9957 - val_accuracy: 0.7847 Epoch 8/100 40/1024 [>.............................] - ETA: 2:02 - loss: 0.0888 - accuracy: 0.9766
18:38:56 05.30 INFO 2887911528.py:50]: keras model saved. KeyboardInterrupt
# Reload the per-epoch metrics recorded during training
with open('history_efficientnet.pickle', 'rb') as f:
    history = pickle.load(f)

# Accuracy curves: training first, then validation (order fixes the legend)
plt.figure(figsize=(10, 7))
for metric in ('accuracy', 'val_accuracy'):
    plt.plot(history[metric])
plt.title('Model accuracy - EfficientNetB0', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')

# Mark the best validation accuracy on the graph
max_acc = max(history['val_accuracy'])
best_val_epoch = np.argmax(history['val_accuracy'])
plt.annotate(
    f'max accuracy: {max_acc:.4f}',
    xy=(best_val_epoch, max_acc),
    xytext=(best_val_epoch - 20, max_acc - 0.2),
    arrowprops={'facecolor': 'black', 'shrink': 0.001},
    fontsize=15,
)

# Mark the best training accuracy on the graph
max_acc_train = max(history['accuracy'])
best_train_epoch = np.argmax(history['accuracy'])
plt.annotate(
    f'max accuracy: {max_acc_train:.4f}',
    xy=(best_train_epoch, max_acc_train),
    xytext=(best_train_epoch - 40, max_acc_train + 0.01),
    arrowprops={'facecolor': 'black', 'shrink': 0.001},
    fontsize=15,
)
plt.show()

# Loss curves: training first, then validation
plt.figure(figsize=(10, 7))
for metric in ('loss', 'val_loss'):
    plt.plot(history[metric])
plt.title('Model loss - EfficientNetB0', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Rebuild the network and restore the most recent checkpoint found in the
# checkpoint directory
model = effcientnetB0_model()
checkpoint_dir = os.path.dirname(ckpt_path)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x2504d884130>
def display_result(model, ds, characters):
    """
    Show up to nine samples from one shuffled batch with their predictions.

    :param model: model whose ``predict`` is called on the batch of images
    :param ds: dataset of raw samples; it is shuffled, mapped through
        ``preprocess_efficientnet`` and batched here
    :param characters: sequence mapping a class index to its character
    :return: None, the figure is displayed with ``plt.show()``
    """
    # take(1) below reads a single batch, so the dataset is not repeated
    ds = ds.shuffle(100).map(preprocess_efficientnet).batch(32)
    _, axe = plt.subplots(3, 3, figsize=(15, 15))
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # set the font
    # fixes the garbled "-" (minus sign) in figures
    plt.rcParams["axes.unicode_minus"] = False
    for images, labels in ds.take(1):
        # One predicted class index per sample
        predicted = np.argmax(model.predict(images), axis=1)
        labels = labels.numpy()
        # Guard against a batch holding fewer than 9 samples
        for i in range(min(9, len(predicted))):
            cell = axe[i // 3, i % 3]
            # Plain ints keep list indexing and the True/False text
            # independent of tensor / NumPy scalar comparison rules
            pred_idx, label_idx = int(predicted[i]), int(labels[i])
            cell.imshow(images[i].numpy().astype('uint8'), cmap='gray')
            cell.axis('on')
            cell.set_title(f'pred: {characters[pred_idx]}, label: {characters[label_idx]}, Is correct: {pred_idx == label_idx}')
    plt.show()
# Visualize the model's predictions on samples drawn from the test set
test_ds = load_datasets(test_path)
display_result(model, test_ds, all_characters)
1/1 [==============================] - 1s 949ms/step
We can see that for the model with EfficientNet
, the validation accuracy
can reach 85%
with about 80 epochs, which is a huge improvement compared with the simple CNN model.