This is the final project of the CSE204 Machine Learning course at Ecole Polytechnique. We use a CNN and an EfficientNet model to perform OCR on single handwritten Chinese characters. The project is mainly based on the TensorFlow 2.0 API.
After installing all the packages listed in the requirements.txt file, we can start the OCR process on the Chinese text.
# import necessary packages
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from alfred.utils.log import logger as logging
from tensorflow.keras.applications import EfficientNetB0
import pickle
import random
# load character mapping
# NOTE: "__file__" is deliberately a string literal. Inside a notebook there is no
# __file__ variable, so abspath("__file__") resolves to "<cwd>/__file__" and its
# dirname is the current working directory.
this_dir = os.path.dirname(os.path.abspath("__file__"))


def load_characters(path=None, encoding='utf-8'):
    """Load the label-index -> character table.

    Line ``i`` (0-based) of the file is the character for class label ``i``.

    :param path: optional path to the character list. Defaults to
        ``<this_dir>/dataset/characters.txt``, built with ``os.path.join`` so the
        same code works on Windows and macOS/Linux.
    :param encoding: text encoding of the file. It is passed explicitly because the
        file holds non-ASCII (Chinese) text and the platform default encoding
        differs between systems. Assumes the file is UTF-8 by default -- TODO confirm.
    :return: list of characters with surrounding whitespace stripped.
    """
    if path is None:
        path = os.path.join(this_dir, 'dataset', 'characters.txt')
    # `with` guarantees the file handle is closed (the original leaked it).
    with open(path, 'r', encoding=encoding) as f:
        return [line.strip() for line in f]
# Class label i in the tfrecords corresponds to all_characters[i].
all_characters = load_characters()
num_classes = len(all_characters)
print('There are {} classes of chinese characters in the dataset'.format(num_classes))
logging.info(f'all characters: {num_classes}')
18:10:37 05.30 INFO 1114677896.py:12]: all characters: 3755
There are 3755 classes of chinese characters in the dataset
We load the dataset from local files. The dataset consists of TFRecord files that contain the images, their labels, and the width and height of each image. The data is stored in the TFRecord format to improve training performance.
For more information about the dataset, please read Part 3 (Dataset Preparation) of README.md. We also provide the converted dataset at the following link, so you can use it directly for training.
# Necessary parameters
# NOTE(review): IMG_SIZE is not read by any code before it is reassigned to 224 in
# the "basic parameters" cell further down, so this value of 80 currently has no effect.
IMG_SIZE = 80
# decode tfrecord to images and labels
def parse_image(record):
    """Decode one serialized ``tf.train.Example`` into an image/label dict.

    :param record: a single serialized example (scalar string tensor) as yielded
        by ``tf.data.TFRecordDataset``.
    :return: ``{'image': rank-2 float32 tensor, 'label': int64 scalar tensor}``.
    """
    feature_spec = {
        'width': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(record, features=feature_spec)

    pixels = tf.io.decode_raw(example['image'], out_type=tf.uint8)
    # NOTE(review): the flat byte buffer is reshaped to (width, height). That layout
    # is only correct if the converter wrote 'width' as the leading (row) dimension
    # -- TODO confirm against the tfrecord writer.
    image = tf.cast(tf.reshape(pixels, (example['width'], example['height'])),
                    dtype=tf.float32)
    return {'image': image, 'label': tf.cast(example['label'], tf.int64)}
def load_datasets(filename):
    """Build a dataset of decoded samples from one or more TFRecord files.

    :param filename: path to a TFRecord file, or an iterable of paths (the files
        are read in the given order).
    :return: ``tf.data.Dataset`` whose elements are the dicts produced by
        ``parse_image`` (``{'image': ..., 'label': ...}``).
    """
    if isinstance(filename, (str, bytes, os.PathLike)):
        filenames = [os.fspath(filename)]
    else:
        filenames = [os.fspath(name) for name in filename]
    dataset = tf.data.TFRecordDataset(filenames)
    # Decode records in parallel. Dataset.map keeps element order deterministic
    # unless told otherwise, so the output sequence matches a sequential map.
    return dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
In this part, we look at what the datasets look like.
# Raw datasets (no resizing/normalization yet) used by the exploration cells below.
train_ds, test_ds = (load_datasets(path)
                     for path in ('dataset/train.tfrecord', 'dataset/test.tfrecord'))
# NOTE(review): Dataset.batch() needs all elements of a batch to have the same shape.
# The raw images have a dynamic shape (element_spec is (None, None, None)) and vary in
# size, so iterating these batched datasets will likely fail. They are only displayed
# here; resize first (as `preprocess` does) or use padded_batch before iterating.
train_mapped = train_ds.shuffle(100).batch(32).repeat()
test_mapped = test_ds.batch(32).repeat()
train_mapped  # notebook display of the element_spec
<RepeatDataset element_spec={'image': TensorSpec(shape=(None, None, None), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>
def _collect_shapes(ds):
    """Return ``[first_dims, second_dims]`` of every decoded image in ``ds``.

    The first dimension is plotted as "width" because ``parse_image`` reshapes each
    image to (width, height).
    """
    first_dims, second_dims = [], []
    for sample in ds:
        shape = sample['image'].shape
        first_dims.append(shape[0])
        second_dims.append(shape[1])
    return [first_dims, second_dims]


# summary information of the dataset, used to plot the width and height of the images
shapes_plot_train = _collect_shapes(train_ds)
shapes_plot_test = _collect_shapes(test_ds)
# show the number of pictures in the dataset
print(f'There are {len(shapes_plot_train[0])} pictures in the training dataset')
print(f'There are {len(shapes_plot_test[0])} pictures in the testing dataset')
There are 897758 pictures in the training dataset There are 223991 pictures in the testing dataset
def plot_info_dataset(shapes_plot):
    """Plot size statistics of a dataset and show the figure.

    Three panels are drawn side by side:
    - a histogram of image widths;
    - a histogram of image heights;
    - a width-vs-height scatter plot.

    :param shapes_plot: ``[widths, heights]``, two equal-length sequences of
        per-image sizes in pixels (as built by the shape-collection cell).
    """
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    # Each histogram gets its own title and axis labels. The height panel used to
    # be left unlabeled because its labels were applied to ax[0] by mistake.
    for axis, values, title in ((ax[0], shapes_plot[0], 'width'),
                                (ax[1], shapes_plot[1], 'height')):
        axis.hist(values, bins=100)
        axis.set_title(title)
        axis.set_xlabel('pixels')
        axis.set_ylabel('frequency')
    ax[2].scatter(shapes_plot[0], shapes_plot[1])
    ax[2].set_title('width vs height')
    ax[2].set_xlabel('width (pixels)')
    ax[2].set_ylabel('height (pixels)')
    plt.show()
# Print a banner and the size plots for each split, training first.
for banner, split_shapes in (
    ('------------------- Information of training dataset -------------------', shapes_plot_train),
    ('------------------- Information of testing dataset -------------------', shapes_plot_test),
):
    print(banner)
    plot_info_dataset(split_shapes)
------------------- Information of training dataset -------------------
------------------- Information of testing dataset -------------------
We can also plot some images and labels from the datasets.
def plot_img(ds, characters, num=9):
    """Show the first ``num`` raw samples of ``ds`` in a grid with 3 columns.

    :param ds: dataset of ``{'image': 2-D tensor, 'label': int tensor}`` dicts
        (as returned by ``load_datasets``).
    :param characters: label-index -> character table (``all_characters``).
    :param num: number of samples to display (default 9, i.e. a 3x3 grid).
    """
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # Set font to avoid garbled characters
    plt.rcParams["axes.unicode_minus"] = False  # The same as above
    n_cols = 3
    # Enough rows for `num` samples. num=9 gives the original 3x3 grid of size 12x12.
    n_rows = max(1, (num + n_cols - 1) // n_cols)
    # squeeze=False always returns a 2-D array of Axes, so axes[row][col] works for any num.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
    for i, data in enumerate(ds.take(num)):
        img = data['image'].numpy()
        label = data['label'].numpy()
        ax = axes[i // n_cols][i % n_cols]
        ax.imshow(img, cmap='gray')
        ax.set_title(f'label: {characters[label]}, size: {img.shape}', fontsize=15)
    plt.show()


plot_img(train_ds, all_characters, num=9)
Important, please read! If you are using macOS, the labels may not be displayed correctly. This is because Matplotlib may not be able to use the SimHei font on macOS. To solve this problem, see the MacOS Chinese Characters Label Problem note.
In this part, we will define the models and training functions.
We start with a simple CNN model.
# some basic parameters
# Dataset locations.
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'
# Checkpoint location for the simple CNN. '{epoch}' is filled in by the
# ModelCheckpoint callback (and explicitly in train_simple).
ckpt_path = './checkpoints/simple_net/cn_ocr-{epoch}.ckpt'
# Side length images are resized to before entering the simple CNN.
TARGET_SIZE = 64
# NOTE(review): not used by the simple-CNN code; the EfficientNet cell sets it again.
IMG_SIZE = 224
# image preprocessing
def preprocess(ds):
    """Turn one decoded sample into an ``(image, label)`` pair for the simple CNN.

    The 2-D image is processed in three steps:
    - it gets a trailing channel axis;
    - it is resized to ``TARGET_SIZE x TARGET_SIZE``;
    - it is rescaled with ``(x - 128) / 128``.

    :param ds: one sample dict with ``'image'`` and ``'label'`` entries (as produced
        by ``parse_image``). Its ``'image'`` entry is overwritten with the result.
    :return: ``(image, label)`` tuple.
    """
    resized = tf.image.resize(tf.expand_dims(ds['image'], axis=-1),
                              (TARGET_SIZE, TARGET_SIZE))
    ds['image'] = (resized - 128.) / 128.
    return ds['image'], ds['label']
# history log callback function
#
# This part is used to collect the training statistics and save them into a file.
#
class SaveHistoryCallback(tf.keras.callbacks.Callback):
    """Keras callback that accumulates per-epoch logs and pickles them.

    If ``history_path`` already exists, it is loaded first. A resumed training run
    therefore appends to the previous history instead of replacing it. The history
    (``{metric_name: [value_epoch_1, value_epoch_2, ...]}``) is written back to
    ``history_path`` in ``on_train_end``.
    """

    def __init__(self, history_path):
        super().__init__()
        self.history_path = history_path
        self.history = {}
        if os.path.exists(history_path):
            # NOTE: pickle.load can execute code embedded in the file; only point
            # this at history files this notebook wrote itself.
            with open(history_path, 'rb') as f:
                self.history = pickle.load(f)

    def on_epoch_end(self, epoch, logs=None):
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)

    def on_train_end(self, logs=None):
        with open(self.history_path, 'wb') as f:
            pickle.dump(self.history, f)
# Define a simple CNN model
def simple_net(input_shape, n_classes):
    """Build the baseline CNN.

    The architecture is two conv/max-pool stages followed by a softmax classifier.

    :param input_shape: shape of one input image, e.g. ``(TARGET_SIZE, TARGET_SIZE, 1)``.
    :param n_classes: number of output classes.
    :return: an uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3),
                            strides=(1, 1), padding='same', activation='relu'))
    model.add(layers.MaxPool2D(pool_size=(2, 2), padding='same'))
    # NOTE(review): unlike the first conv, this one has no activation (it is linear).
    # Adding activation='relu' is likely intended, but it is left unchanged because the
    # checkpoints under ./checkpoints/simple_net were trained with this architecture.
    model.add(layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'))
    model.add(layers.MaxPool2D(pool_size=(2, 2), padding='same'))
    model.add(layers.Flatten())
    model.add(layers.Dense(n_classes, activation='softmax'))
    return model
# Define a more complex CNN model.
# NOTE: the non-convergence reported in this notebook was observed when only the
# first Conv2D had an activation and the other three were linear. All conv layers
# now use ReLU, consistent with the first one (parameter counts are unchanged).
def CNN_ComplexModel_1(input_shape, n_classes):
    """Build a deeper CNN.

    The architecture is four conv(ReLU)/max-pool stages, a 1024-unit ReLU layer,
    and a softmax classifier.

    :param input_shape: shape of one input image, e.g. ``(TARGET_SIZE, TARGET_SIZE, 1)``.
    :param n_classes: number of output classes.
    :return: an uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3), strides=(1, 1),
                      padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=512, kernel_size=(3, 3), padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Flatten(),
        layers.Dense(1024, activation='relu'),
        layers.Dense(n_classes, activation='softmax')
    ])
    return model
# show the summary of the model (num_classes == len(all_characters))
model = simple_net(input_shape=(TARGET_SIZE, TARGET_SIZE, 1), n_classes=num_classes)
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 64, 64, 32) 320
max_pooling2d (MaxPooling2D (None, 32, 32, 32) 0
)
conv2d_1 (Conv2D) (None, 32, 32, 64) 18496
max_pooling2d_1 (MaxPooling (None, 16, 16, 64) 0
2D)
flatten (Flatten) (None, 16384) 0
dense (Dense) (None, 3755) 61525675
=================================================================
Total params: 61,544,491
Trainable params: 61,544,491
Non-trainable params: 0
_________________________________________________________________
# Build the deeper CNN with the same input and class count and print its layer table.
model_cnn_complex = CNN_ComplexModel_1(input_shape=(TARGET_SIZE, TARGET_SIZE, 1),
                                       n_classes=num_classes)
model_cnn_complex.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_2 (Conv2D) (None, 64, 64, 64) 640
max_pooling2d_2 (MaxPooling (None, 32, 32, 64) 0
2D)
conv2d_3 (Conv2D) (None, 32, 32, 128) 73856
max_pooling2d_3 (MaxPooling (None, 16, 16, 128) 0
2D)
conv2d_4 (Conv2D) (None, 16, 16, 256) 295168
max_pooling2d_4 (MaxPooling (None, 8, 8, 256) 0
2D)
conv2d_5 (Conv2D) (None, 8, 8, 512) 1180160
max_pooling2d_5 (MaxPooling (None, 4, 4, 512) 0
2D)
flatten_1 (Flatten) (None, 8192) 0
dense_1 (Dense) (None, 1024) 8389632
dense_2 (Dense) (None, 3755) 3848875
=================================================================
Total params: 13,788,331
Trainable params: 13,788,331
Non-trainable params: 0
_________________________________________________________________
In this part, we train the model.
def train_simple(epochs=100, steps_per_epoch=1024, validation_steps=1000, batch_size=32):
    """Train (or resume training) the simple CNN on the tfrecord datasets.

    - Resumes from the newest checkpoint in ``dirname(ckpt_path)`` if one exists.
    - Per-epoch metrics are appended to ``history_simple.pickle`` by
      ``SaveHistoryCallback``.
    - ``ModelCheckpoint`` saves the weights whenever ``val_loss`` improves.
    - Training can be stopped with Ctrl-C (KeyboardInterrupt). The history is still
      written and the partially trained model is returned.

    :param epochs: number of epochs passed to ``model.fit`` (default 100).
    :param steps_per_epoch: training batches per epoch (default 1024).
    :param validation_steps: validation batches evaluated per epoch (default 1000).
    :param batch_size: batch size for both datasets (default 32).
    :return: the trained ``tf.keras`` model.
    """
    print(f'number of classes: {num_classes}')
    history_path = 'history_simple.pickle'
    save_history_callback = SaveHistoryCallback(history_path)

    train_dataset = (load_datasets(train_path)
                     .map(preprocess).shuffle(100).batch(batch_size).repeat())
    test_dataset = load_datasets(test_path).shuffle(100).map(preprocess).batch(batch_size)
    print(f'train dataset: {train_dataset}')

    # build model
    model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), num_classes)
    model.summary()

    # latest checkpoints
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        print(f'model resumed from: {latest_ckpt}')
        model.load_weights(latest_ckpt)
    else:
        print('training from scratch')

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                           save_weights_only=True,
                                           verbose=1,
                                           save_freq='epoch',
                                           save_best_only=True),
        save_history_callback,
    ]
    try:
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            validation_steps=validation_steps,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks)
    except KeyboardInterrupt:
        # fit() was interrupted before its own on_train_end call, so flush the history here.
        save_history_callback.on_train_end()
        # No weights are written on this path. Only checkpoints already saved by
        # ModelCheckpoint (on a val_loss improvement) exist on disk.
        logging.info('KeyboardInterrupt: training stopped, history saved. '
                     'Saved weights are those of the last val_loss improvement (if any).')
        return model
    # NOTE(review): this TF-format save also becomes the directory's newest checkpoint.
    # tf.train.latest_checkpoint() afterwards returns these final weights, not the
    # best (lowest val_loss) ones saved by ModelCheckpoint.
    model.save_weights(ckpt_path.format(epoch=0))
    model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr_simple.h5'))
    logging.info('All epoch finished. keras model saved.')
    return model
From our training results, the simpler of the two CNN models above performs better. The more complex CNN model does not even converge, which is quite an interesting result.
# train the model (computation heavy)
# train_simple() resumes from the newest checkpoint in ./checkpoints/simple_net if one exists.
model = train_simple()
number of classes: 3755
train dataset: <RepeatDataset element_spec=(TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_6 (Conv2D) (None, 64, 64, 32) 320
max_pooling2d_6 (MaxPooling (None, 32, 32, 32) 0
2D)
conv2d_7 (Conv2D) (None, 32, 32, 64) 18496
max_pooling2d_7 (MaxPooling (None, 16, 16, 64) 0
2D)
flatten_2 (Flatten) (None, 16384) 0
dense_3 (Dense) (None, 3755) 61525675
=================================================================
Total params: 61,544,491
Trainable params: 61,544,491
Non-trainable params: 0
_________________________________________________________________
model resumed from: ./checkpoints/simple_net\cn_ocr-1.ckpt
Epoch 1/100
4/1024 [..............................] - ETA: 20s - loss: 0.0467 - accuracy: 0.9766 WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0068s vs `on_train_batch_end` time: 0.0108s). Check your callbacks.
1022/1024 [============================>.] - ETA: 0s - loss: 0.0242 - accuracy: 0.9927
Epoch 1: val_loss improved from inf to 4.37403, saving model to ./checkpoints/simple_net\cn_ocr-1.ckpt
1024/1024 [==============================] - 20s 16ms/step - loss: 0.0242 - accuracy: 0.9927 - val_loss: 4.3740 - val_accuracy: 0.4745
Epoch 2/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1772 - accuracy: 0.9523
Epoch 2: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 15ms/step - loss: 0.1772 - accuracy: 0.9523 - val_loss: 5.0706 - val_accuracy: 0.3959
Epoch 3/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1651 - accuracy: 0.9527
Epoch 3: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1652 - accuracy: 0.9528 - val_loss: 5.4529 - val_accuracy: 0.4083
Epoch 4/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1733 - accuracy: 0.9509
Epoch 4: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1732 - accuracy: 0.9509 - val_loss: 4.7948 - val_accuracy: 0.4309
Epoch 5/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1912 - accuracy: 0.9467
Epoch 5: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.1910 - accuracy: 0.9467 - val_loss: 4.8460 - val_accuracy: 0.4083
Epoch 6/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2448 - accuracy: 0.9330
Epoch 6: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2444 - accuracy: 0.9331 - val_loss: 5.6845 - val_accuracy: 0.3722
Epoch 7/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2070 - accuracy: 0.9426
Epoch 7: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2068 - accuracy: 0.9426 - val_loss: 5.0307 - val_accuracy: 0.3918
Epoch 8/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2148 - accuracy: 0.9402
Epoch 8: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2151 - accuracy: 0.9401 - val_loss: 5.2210 - val_accuracy: 0.4089
Epoch 9/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2385 - accuracy: 0.9332
Epoch 9: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2385 - accuracy: 0.9332 - val_loss: 4.9118 - val_accuracy: 0.4042
Epoch 10/100
1020/1024 [============================>.] - ETA: 0s - loss: 0.2330 - accuracy: 0.9342
Epoch 10: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2326 - accuracy: 0.9343 - val_loss: 6.6977 - val_accuracy: 0.3193
Epoch 11/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1729 - accuracy: 0.9498
Epoch 11: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1730 - accuracy: 0.9497 - val_loss: 4.9370 - val_accuracy: 0.4512
Epoch 12/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2010 - accuracy: 0.9420
Epoch 12: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.2009 - accuracy: 0.9420 - val_loss: 5.2270 - val_accuracy: 0.4319
Epoch 13/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1908 - accuracy: 0.9446
Epoch 13: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1910 - accuracy: 0.9445 - val_loss: 4.6441 - val_accuracy: 0.4633
Epoch 14/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.9264 - accuracy: 0.7941
18:20:18 05.30 INFO 3396354300.py:49]: keras model saved. KeyboardInterrupt
# Per-epoch metrics written by SaveHistoryCallback during train_simple().
with open('history_simple.pickle', 'rb') as f:
    history = pickle.load(f)

# Plot accuracy
plt.figure(figsize=(10, 7))
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.title('Model accuracy - Simple CNN', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
# show the highest accuracy in the graph
# Each label is placed a fixed number of epochs left of its maximum. The x position
# is clamped at epoch 0 so that in short runs (maximum in the first epochs) the
# label is not placed left of the plotted data.
best_val_epoch = int(np.argmax(history['val_accuracy']))
max_acc = max(history['val_accuracy'])
plt.annotate(f'max accuracy: {max_acc:.4f}', xy=(best_val_epoch, max_acc),
             xytext=(max(best_val_epoch - 10, 0), max_acc + 0.1),
             arrowprops=dict(facecolor='black', shrink=0.03), fontsize=15)
best_train_epoch = int(np.argmax(history['accuracy']))
max_acc_train = max(history['accuracy'])
plt.annotate(f'max accuracy: {max_acc_train:.4f}', xy=(best_train_epoch, max_acc_train),
             xytext=(max(best_train_epoch - 30, 0), max_acc_train - 0.1),
             arrowprops=dict(facecolor='black', shrink=0.03), fontsize=15)
plt.show()

# Plot loss
plt.figure(figsize=(10, 7))
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('Model loss - Simple CNN', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Load the newest checkpoint of the simple CNN.
# NOTE(review): ModelCheckpoint(save_best_only=True) only writes when val_loss improves.
# However, after a completed run train_simple() also writes the final weights as
# cn_ocr-0.ckpt, and those then become the newest checkpoint.
model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), num_classes)
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
if latest_ckpt is None:
    # load_weights(None) would fail with an unhelpful error; say what is missing instead.
    raise FileNotFoundError(
        f'no checkpoint found in {os.path.dirname(ckpt_path)!r}; run train_simple() first')
model.load_weights(latest_ckpt)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x25053771510>
# display the result
import random
def display_result(model, ds, characters):
    """Predict on one shuffled batch and show up to 9 samples with prediction vs. label.

    :param model: trained simple CNN taking ``(TARGET_SIZE, TARGET_SIZE, 1)`` inputs.
    :param ds: raw dataset from ``load_datasets``; it is preprocessed here with ``preprocess``.
    :param characters: label-index -> character table.
    """
    # take(1) below reads a single batch, so the dataset does not need .repeat().
    batches = ds.shuffle(100).map(preprocess).batch(32)
    f, axe = plt.subplots(3, 3, figsize=(15, 15))
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # Set font to avoid garbled characters
    plt.rcParams["axes.unicode_minus"] = False  # The same as above
    # take 9 samples
    for images, labels in batches.take(1):
        predicted = np.argmax(model.predict(images), axis=-1)
        # Convert to NumPy so the titles show plain values instead of tf.Tensor reprs.
        images = images.numpy()
        labels = labels.numpy()
        for i in range(min(9, len(labels))):
            ax = axe[i // 3, i % 3]
            # Drop the channel axis instead of hard-coding a 64x64 reshape (follows TARGET_SIZE).
            ax.imshow(np.squeeze(images[i], axis=-1), cmap='gray')
            ax.axis('on')
            ax.set_title(f'pred: {characters[predicted[i]]}, label: {characters[labels[i]]}, '
                         f'Is Correct: {predicted[i] == labels[i]}')
    plt.show()


test_ds = load_datasets(test_path)
display_result(model, test_ds, all_characters)
1/1 [==============================] - 0s 65ms/step
This model only reaches around 40% validation accuracy, which is not satisfying. We therefore decided to train a far more complex model called EfficientNet, a state-of-the-art model for image classification. We use its B0 version. For more information about the model, please read the paper and this webpage.
# some basic parameters for the EfficientNetB0 experiment
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'
# EfficientNet weights are checkpointed separately from the simple CNN's.
ckpt_path = './checkpoints/efficient_net/cn_ocr-{epoch}.ckpt'
# Input resolution fed to EfficientNetB0. 224 is its default resolution; with
# include_top=False Keras also accepts other input sizes, so this is a choice.
IMG_SIZE = 224
def preprocess_efficientnet(ds):
    """Turn one decoded sample into an ``(image, label)`` pair for EfficientNetB0.

    The 2-D grayscale image gets a channel axis, is replicated to 3 channels, and is
    resized to ``IMG_SIZE x IMG_SIZE``. Pixel values keep their original range, so
    no normalization is done here: the Keras EfficientNet model starts with its own
    Rescaling/Normalization layers (visible in its summary).

    :param ds: one sample dict (``'image'``, ``'label'``) from ``parse_image``. Its
        ``'image'`` entry is overwritten with the result.
    :return: ``(image, label)`` tuple.
    """
    rgb = tf.image.grayscale_to_rgb(tf.expand_dims(ds['image'], axis=-1))
    ds['image'] = tf.image.resize(rgb, (IMG_SIZE, IMG_SIZE))
    return ds['image'], ds['label']
# model
def effcientnetB0_model(n_classes=None, img_size=None, weights='imagenet'):
    """Build a classifier on top of EfficientNetB0.

    The architecture is EfficientNetB0 (no top), global average pooling, and a
    softmax classifier.

    :param n_classes: number of output classes; defaults to the global ``num_classes``.
    :param img_size: input side length; defaults to the global ``IMG_SIZE``.
    :param weights: forwarded to ``EfficientNetB0``. Use ``'imagenet'`` (default) for
        pretrained weights, ``None`` for random initialization, or a weights-file path.
    :return: an uncompiled ``keras.Model``.
    """
    # The globals are read at call time, as in the original zero-argument version.
    if n_classes is None:
        n_classes = num_classes
    if img_size is None:
        img_size = IMG_SIZE
    inputs = layers.Input(shape=(img_size, img_size, 3))  # EfficientNetB0 expects 3 channels
    base_model = EfficientNetB0(include_top=False, input_tensor=inputs, weights=weights)
    pooled = layers.GlobalAveragePooling2D()(base_model.output)
    outputs = layers.Dense(n_classes, activation='softmax')(pooled)
    return keras.Model(inputs=inputs, outputs=outputs)
# model summary of EfficientNetB0
# Builds the model with the global num_classes / IMG_SIZE and prints its layer table.
model = effcientnetB0_model()
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 3 0 []
)]
rescaling (Rescaling) (None, 224, 224, 3) 0 ['input_1[0][0]']
normalization (Normalization) (None, 224, 224, 3) 7 ['rescaling[0][0]']
rescaling_1 (Rescaling) (None, 224, 224, 3) 0 ['normalization[0][0]']
stem_conv_pad (ZeroPadding2D) (None, 225, 225, 3) 0 ['rescaling_1[0][0]']
stem_conv (Conv2D) (None, 112, 112, 32 864 ['stem_conv_pad[0][0]']
)
stem_bn (BatchNormalization) (None, 112, 112, 32 128 ['stem_conv[0][0]']
)
stem_activation (Activation) (None, 112, 112, 32 0 ['stem_bn[0][0]']
)
block1a_dwconv (DepthwiseConv2 (None, 112, 112, 32 288 ['stem_activation[0][0]']
D) )
block1a_bn (BatchNormalization (None, 112, 112, 32 128 ['block1a_dwconv[0][0]']
) )
block1a_activation (Activation (None, 112, 112, 32 0 ['block1a_bn[0][0]']
) )
block1a_se_squeeze (GlobalAver (None, 32) 0 ['block1a_activation[0][0]']
agePooling2D)
block1a_se_reshape (Reshape) (None, 1, 1, 32) 0 ['block1a_se_squeeze[0][0]']
block1a_se_reduce (Conv2D) (None, 1, 1, 8) 264 ['block1a_se_reshape[0][0]']
block1a_se_expand (Conv2D) (None, 1, 1, 32) 288 ['block1a_se_reduce[0][0]']
block1a_se_excite (Multiply) (None, 112, 112, 32 0 ['block1a_activation[0][0]',
) 'block1a_se_expand[0][0]']
block1a_project_conv (Conv2D) (None, 112, 112, 16 512 ['block1a_se_excite[0][0]']
)
block1a_project_bn (BatchNorma (None, 112, 112, 16 64 ['block1a_project_conv[0][0]']
lization) )
block2a_expand_conv (Conv2D) (None, 112, 112, 96 1536 ['block1a_project_bn[0][0]']
)
block2a_expand_bn (BatchNormal (None, 112, 112, 96 384 ['block2a_expand_conv[0][0]']
ization) )
block2a_expand_activation (Act (None, 112, 112, 96 0 ['block2a_expand_bn[0][0]']
ivation) )
block2a_dwconv_pad (ZeroPaddin (None, 113, 113, 96 0 ['block2a_expand_activation[0][0]
g2D) ) ']
block2a_dwconv (DepthwiseConv2 (None, 56, 56, 96) 864 ['block2a_dwconv_pad[0][0]']
D)
block2a_bn (BatchNormalization (None, 56, 56, 96) 384 ['block2a_dwconv[0][0]']
)
block2a_activation (Activation (None, 56, 56, 96) 0 ['block2a_bn[0][0]']
)
block2a_se_squeeze (GlobalAver (None, 96) 0 ['block2a_activation[0][0]']
agePooling2D)
block2a_se_reshape (Reshape) (None, 1, 1, 96) 0 ['block2a_se_squeeze[0][0]']
block2a_se_reduce (Conv2D) (None, 1, 1, 4) 388 ['block2a_se_reshape[0][0]']
block2a_se_expand (Conv2D) (None, 1, 1, 96) 480 ['block2a_se_reduce[0][0]']
block2a_se_excite (Multiply) (None, 56, 56, 96) 0 ['block2a_activation[0][0]',
'block2a_se_expand[0][0]']
block2a_project_conv (Conv2D) (None, 56, 56, 24) 2304 ['block2a_se_excite[0][0]']
block2a_project_bn (BatchNorma (None, 56, 56, 24) 96 ['block2a_project_conv[0][0]']
lization)
block2b_expand_conv (Conv2D) (None, 56, 56, 144) 3456 ['block2a_project_bn[0][0]']
block2b_expand_bn (BatchNormal (None, 56, 56, 144) 576 ['block2b_expand_conv[0][0]']
ization)
block2b_expand_activation (Act (None, 56, 56, 144) 0 ['block2b_expand_bn[0][0]']
ivation)
block2b_dwconv (DepthwiseConv2 (None, 56, 56, 144) 1296 ['block2b_expand_activation[0][0]
D) ']
block2b_bn (BatchNormalization (None, 56, 56, 144) 576 ['block2b_dwconv[0][0]']
)
block2b_activation (Activation (None, 56, 56, 144) 0 ['block2b_bn[0][0]']
)
block2b_se_squeeze (GlobalAver (None, 144) 0 ['block2b_activation[0][0]']
agePooling2D)
block2b_se_reshape (Reshape) (None, 1, 1, 144) 0 ['block2b_se_squeeze[0][0]']
block2b_se_reduce (Conv2D) (None, 1, 1, 6) 870 ['block2b_se_reshape[0][0]']
block2b_se_expand (Conv2D) (None, 1, 1, 144) 1008 ['block2b_se_reduce[0][0]']
block2b_se_excite (Multiply) (None, 56, 56, 144) 0 ['block2b_activation[0][0]',
'block2b_se_expand[0][0]']
block2b_project_conv (Conv2D) (None, 56, 56, 24) 3456 ['block2b_se_excite[0][0]']
block2b_project_bn (BatchNorma (None, 56, 56, 24) 96 ['block2b_project_conv[0][0]']
lization)
block2b_drop (Dropout) (None, 56, 56, 24) 0 ['block2b_project_bn[0][0]']
block2b_add (Add) (None, 56, 56, 24) 0 ['block2b_drop[0][0]',
'block2a_project_bn[0][0]']
block3a_expand_conv (Conv2D) (None, 56, 56, 144) 3456 ['block2b_add[0][0]']
block3a_expand_bn (BatchNormal (None, 56, 56, 144) 576 ['block3a_expand_conv[0][0]']
ization)
block3a_expand_activation (Act (None, 56, 56, 144) 0 ['block3a_expand_bn[0][0]']
ivation)
block3a_dwconv_pad (ZeroPaddin (None, 59, 59, 144) 0 ['block3a_expand_activation[0][0]
g2D) ']
block3a_dwconv (DepthwiseConv2 (None, 28, 28, 144) 3600 ['block3a_dwconv_pad[0][0]']
D)
block3a_bn (BatchNormalization (None, 28, 28, 144) 576 ['block3a_dwconv[0][0]']
)
block3a_activation (Activation (None, 28, 28, 144) 0 ['block3a_bn[0][0]']
)
block3a_se_squeeze (GlobalAver (None, 144) 0 ['block3a_activation[0][0]']
agePooling2D)
block3a_se_reshape (Reshape) (None, 1, 1, 144) 0 ['block3a_se_squeeze[0][0]']
block3a_se_reduce (Conv2D) (None, 1, 1, 6) 870 ['block3a_se_reshape[0][0]']
block3a_se_expand (Conv2D) (None, 1, 1, 144) 1008 ['block3a_se_reduce[0][0]']
block3a_se_excite (Multiply) (None, 28, 28, 144) 0 ['block3a_activation[0][0]',
'block3a_se_expand[0][0]']
block3a_project_conv (Conv2D) (None, 28, 28, 40) 5760 ['block3a_se_excite[0][0]']
block3a_project_bn (BatchNorma (None, 28, 28, 40) 160 ['block3a_project_conv[0][0]']
lization)
block3b_expand_conv (Conv2D) (None, 28, 28, 240) 9600 ['block3a_project_bn[0][0]']
block3b_expand_bn (BatchNormal (None, 28, 28, 240) 960 ['block3b_expand_conv[0][0]']
ization)
block3b_expand_activation (Act (None, 28, 28, 240) 0 ['block3b_expand_bn[0][0]']
ivation)
block3b_dwconv (DepthwiseConv2 (None, 28, 28, 240) 6000 ['block3b_expand_activation[0][0]
D) ']
block3b_bn (BatchNormalization (None, 28, 28, 240) 960 ['block3b_dwconv[0][0]']
)
block3b_activation (Activation (None, 28, 28, 240) 0 ['block3b_bn[0][0]']
)
block3b_se_squeeze (GlobalAver (None, 240) 0 ['block3b_activation[0][0]']
agePooling2D)
block3b_se_reshape (Reshape) (None, 1, 1, 240) 0 ['block3b_se_squeeze[0][0]']
block3b_se_reduce (Conv2D) (None, 1, 1, 10) 2410 ['block3b_se_reshape[0][0]']
block3b_se_expand (Conv2D) (None, 1, 1, 240) 2640 ['block3b_se_reduce[0][0]']
block3b_se_excite (Multiply) (None, 28, 28, 240) 0 ['block3b_activation[0][0]',
'block3b_se_expand[0][0]']
block3b_project_conv (Conv2D) (None, 28, 28, 40) 9600 ['block3b_se_excite[0][0]']
block3b_project_bn (BatchNorma (None, 28, 28, 40) 160 ['block3b_project_conv[0][0]']
lization)
block3b_drop (Dropout) (None, 28, 28, 40) 0 ['block3b_project_bn[0][0]']
block3b_add (Add) (None, 28, 28, 40) 0 ['block3b_drop[0][0]',
'block3a_project_bn[0][0]']
block4a_expand_conv (Conv2D) (None, 28, 28, 240) 9600 ['block3b_add[0][0]']
block4a_expand_bn (BatchNormal (None, 28, 28, 240) 960 ['block4a_expand_conv[0][0]']
ization)
block4a_expand_activation (Act (None, 28, 28, 240) 0 ['block4a_expand_bn[0][0]']
ivation)
block4a_dwconv_pad (ZeroPaddin (None, 29, 29, 240) 0 ['block4a_expand_activation[0][0]
g2D) ']
block4a_dwconv (DepthwiseConv2 (None, 14, 14, 240) 2160 ['block4a_dwconv_pad[0][0]']
D)
block4a_bn (BatchNormalization (None, 14, 14, 240) 960 ['block4a_dwconv[0][0]']
)
block4a_activation (Activation (None, 14, 14, 240) 0 ['block4a_bn[0][0]']
)
block4a_se_squeeze (GlobalAver (None, 240) 0 ['block4a_activation[0][0]']
agePooling2D)
block4a_se_reshape (Reshape) (None, 1, 1, 240) 0 ['block4a_se_squeeze[0][0]']
block4a_se_reduce (Conv2D) (None, 1, 1, 10) 2410 ['block4a_se_reshape[0][0]']
block4a_se_expand (Conv2D) (None, 1, 1, 240) 2640 ['block4a_se_reduce[0][0]']
block4a_se_excite (Multiply) (None, 14, 14, 240) 0 ['block4a_activation[0][0]',
'block4a_se_expand[0][0]']
block4a_project_conv (Conv2D) (None, 14, 14, 80) 19200 ['block4a_se_excite[0][0]']
block4a_project_bn (BatchNorma (None, 14, 14, 80) 320 ['block4a_project_conv[0][0]']
lization)
block4b_expand_conv (Conv2D) (None, 14, 14, 480) 38400 ['block4a_project_bn[0][0]']
block4b_expand_bn (BatchNormal (None, 14, 14, 480) 1920 ['block4b_expand_conv[0][0]']
ization)
block4b_expand_activation (Act (None, 14, 14, 480) 0 ['block4b_expand_bn[0][0]']
ivation)
block4b_dwconv (DepthwiseConv2 (None, 14, 14, 480) 4320 ['block4b_expand_activation[0][0]
D) ']
block4b_bn (BatchNormalization (None, 14, 14, 480) 1920 ['block4b_dwconv[0][0]']
)
block4b_activation (Activation (None, 14, 14, 480) 0 ['block4b_bn[0][0]']
)
block4b_se_squeeze (GlobalAver (None, 480) 0 ['block4b_activation[0][0]']
agePooling2D)
block4b_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block4b_se_squeeze[0][0]']
block4b_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block4b_se_reshape[0][0]']
block4b_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block4b_se_reduce[0][0]']
block4b_se_excite (Multiply) (None, 14, 14, 480) 0 ['block4b_activation[0][0]',
'block4b_se_expand[0][0]']
block4b_project_conv (Conv2D) (None, 14, 14, 80) 38400 ['block4b_se_excite[0][0]']
block4b_project_bn (BatchNorma (None, 14, 14, 80) 320 ['block4b_project_conv[0][0]']
lization)
block4b_drop (Dropout) (None, 14, 14, 80) 0 ['block4b_project_bn[0][0]']
block4b_add (Add) (None, 14, 14, 80) 0 ['block4b_drop[0][0]',
'block4a_project_bn[0][0]']
block4c_expand_conv (Conv2D) (None, 14, 14, 480) 38400 ['block4b_add[0][0]']
block4c_expand_bn (BatchNormal (None, 14, 14, 480) 1920 ['block4c_expand_conv[0][0]']
ization)
block4c_expand_activation (Act (None, 14, 14, 480) 0 ['block4c_expand_bn[0][0]']
ivation)
block4c_dwconv (DepthwiseConv2 (None, 14, 14, 480) 4320 ['block4c_expand_activation[0][0]
D) ']
block4c_bn (BatchNormalization (None, 14, 14, 480) 1920 ['block4c_dwconv[0][0]']
)
block4c_activation (Activation (None, 14, 14, 480) 0 ['block4c_bn[0][0]']
)
block4c_se_squeeze (GlobalAver (None, 480) 0 ['block4c_activation[0][0]']
agePooling2D)
block4c_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block4c_se_squeeze[0][0]']
block4c_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block4c_se_reshape[0][0]']
block4c_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block4c_se_reduce[0][0]']
block4c_se_excite (Multiply) (None, 14, 14, 480) 0 ['block4c_activation[0][0]',
'block4c_se_expand[0][0]']
block4c_project_conv (Conv2D) (None, 14, 14, 80) 38400 ['block4c_se_excite[0][0]']
block4c_project_bn (BatchNorma (None, 14, 14, 80) 320 ['block4c_project_conv[0][0]']
lization)
block4c_drop (Dropout) (None, 14, 14, 80) 0 ['block4c_project_bn[0][0]']
block4c_add (Add) (None, 14, 14, 80) 0 ['block4c_drop[0][0]',
'block4b_add[0][0]']
block5a_expand_conv (Conv2D) (None, 14, 14, 480) 38400 ['block4c_add[0][0]']
block5a_expand_bn (BatchNormal (None, 14, 14, 480) 1920 ['block5a_expand_conv[0][0]']
ization)
block5a_expand_activation (Act (None, 14, 14, 480) 0 ['block5a_expand_bn[0][0]']
ivation)
block5a_dwconv (DepthwiseConv2 (None, 14, 14, 480) 12000 ['block5a_expand_activation[0][0]
D) ']
block5a_bn (BatchNormalization (None, 14, 14, 480) 1920 ['block5a_dwconv[0][0]']
)
block5a_activation (Activation (None, 14, 14, 480) 0 ['block5a_bn[0][0]']
)
block5a_se_squeeze (GlobalAver (None, 480) 0 ['block5a_activation[0][0]']
agePooling2D)
block5a_se_reshape (Reshape) (None, 1, 1, 480) 0 ['block5a_se_squeeze[0][0]']
block5a_se_reduce (Conv2D) (None, 1, 1, 20) 9620 ['block5a_se_reshape[0][0]']
block5a_se_expand (Conv2D) (None, 1, 1, 480) 10080 ['block5a_se_reduce[0][0]']
block5a_se_excite (Multiply) (None, 14, 14, 480) 0 ['block5a_activation[0][0]',
'block5a_se_expand[0][0]']
block5a_project_conv (Conv2D) (None, 14, 14, 112) 53760 ['block5a_se_excite[0][0]']
block5a_project_bn (BatchNorma (None, 14, 14, 112) 448 ['block5a_project_conv[0][0]']
lization)
block5b_expand_conv (Conv2D) (None, 14, 14, 672) 75264 ['block5a_project_bn[0][0]']
block5b_expand_bn (BatchNormal (None, 14, 14, 672) 2688 ['block5b_expand_conv[0][0]']
ization)
block5b_expand_activation (Act (None, 14, 14, 672) 0 ['block5b_expand_bn[0][0]']
ivation)
block5b_dwconv (DepthwiseConv2 (None, 14, 14, 672) 16800 ['block5b_expand_activation[0][0]
D) ']
block5b_bn (BatchNormalization (None, 14, 14, 672) 2688 ['block5b_dwconv[0][0]']
)
block5b_activation (Activation (None, 14, 14, 672) 0 ['block5b_bn[0][0]']
)
block5b_se_squeeze (GlobalAver (None, 672) 0 ['block5b_activation[0][0]']
agePooling2D)
block5b_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5b_se_squeeze[0][0]']
block5b_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5b_se_reshape[0][0]']
block5b_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5b_se_reduce[0][0]']
block5b_se_excite (Multiply) (None, 14, 14, 672) 0 ['block5b_activation[0][0]',
'block5b_se_expand[0][0]']
block5b_project_conv (Conv2D) (None, 14, 14, 112) 75264 ['block5b_se_excite[0][0]']
block5b_project_bn (BatchNorma (None, 14, 14, 112) 448 ['block5b_project_conv[0][0]']
lization)
block5b_drop (Dropout) (None, 14, 14, 112) 0 ['block5b_project_bn[0][0]']
block5b_add (Add) (None, 14, 14, 112) 0 ['block5b_drop[0][0]',
'block5a_project_bn[0][0]']
block5c_expand_conv (Conv2D) (None, 14, 14, 672) 75264 ['block5b_add[0][0]']
block5c_expand_bn (BatchNormal (None, 14, 14, 672) 2688 ['block5c_expand_conv[0][0]']
ization)
block5c_expand_activation (Act (None, 14, 14, 672) 0 ['block5c_expand_bn[0][0]']
ivation)
block5c_dwconv (DepthwiseConv2 (None, 14, 14, 672) 16800 ['block5c_expand_activation[0][0]
D) ']
block5c_bn (BatchNormalization (None, 14, 14, 672) 2688 ['block5c_dwconv[0][0]']
)
block5c_activation (Activation (None, 14, 14, 672) 0 ['block5c_bn[0][0]']
)
block5c_se_squeeze (GlobalAver (None, 672) 0 ['block5c_activation[0][0]']
agePooling2D)
block5c_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block5c_se_squeeze[0][0]']
block5c_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block5c_se_reshape[0][0]']
block5c_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block5c_se_reduce[0][0]']
block5c_se_excite (Multiply) (None, 14, 14, 672) 0 ['block5c_activation[0][0]',
'block5c_se_expand[0][0]']
block5c_project_conv (Conv2D) (None, 14, 14, 112) 75264 ['block5c_se_excite[0][0]']
block5c_project_bn (BatchNorma (None, 14, 14, 112) 448 ['block5c_project_conv[0][0]']
lization)
block5c_drop (Dropout) (None, 14, 14, 112) 0 ['block5c_project_bn[0][0]']
block5c_add (Add) (None, 14, 14, 112) 0 ['block5c_drop[0][0]',
'block5b_add[0][0]']
block6a_expand_conv (Conv2D) (None, 14, 14, 672) 75264 ['block5c_add[0][0]']
block6a_expand_bn (BatchNormal (None, 14, 14, 672) 2688 ['block6a_expand_conv[0][0]']
ization)
block6a_expand_activation (Act (None, 14, 14, 672) 0 ['block6a_expand_bn[0][0]']
ivation)
block6a_dwconv_pad (ZeroPaddin (None, 17, 17, 672) 0 ['block6a_expand_activation[0][0]
g2D) ']
block6a_dwconv (DepthwiseConv2 (None, 7, 7, 672) 16800 ['block6a_dwconv_pad[0][0]']
D)
block6a_bn (BatchNormalization (None, 7, 7, 672) 2688 ['block6a_dwconv[0][0]']
)
block6a_activation (Activation (None, 7, 7, 672) 0 ['block6a_bn[0][0]']
)
block6a_se_squeeze (GlobalAver (None, 672) 0 ['block6a_activation[0][0]']
agePooling2D)
block6a_se_reshape (Reshape) (None, 1, 1, 672) 0 ['block6a_se_squeeze[0][0]']
block6a_se_reduce (Conv2D) (None, 1, 1, 28) 18844 ['block6a_se_reshape[0][0]']
block6a_se_expand (Conv2D) (None, 1, 1, 672) 19488 ['block6a_se_reduce[0][0]']
block6a_se_excite (Multiply) (None, 7, 7, 672) 0 ['block6a_activation[0][0]',
'block6a_se_expand[0][0]']
block6a_project_conv (Conv2D) (None, 7, 7, 192) 129024 ['block6a_se_excite[0][0]']
block6a_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6a_project_conv[0][0]']
lization)
block6b_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6a_project_bn[0][0]']
block6b_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block6b_expand_conv[0][0]']
ization)
block6b_expand_activation (Act (None, 7, 7, 1152) 0 ['block6b_expand_bn[0][0]']
ivation)
block6b_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 28800 ['block6b_expand_activation[0][0]
D) ']
block6b_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block6b_dwconv[0][0]']
)
block6b_activation (Activation (None, 7, 7, 1152) 0 ['block6b_bn[0][0]']
)
block6b_se_squeeze (GlobalAver (None, 1152) 0 ['block6b_activation[0][0]']
agePooling2D)
block6b_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block6b_se_squeeze[0][0]']
block6b_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block6b_se_reshape[0][0]']
block6b_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block6b_se_reduce[0][0]']
block6b_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block6b_activation[0][0]',
'block6b_se_expand[0][0]']
block6b_project_conv (Conv2D) (None, 7, 7, 192) 221184 ['block6b_se_excite[0][0]']
block6b_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6b_project_conv[0][0]']
lization)
block6b_drop (Dropout) (None, 7, 7, 192) 0 ['block6b_project_bn[0][0]']
block6b_add (Add) (None, 7, 7, 192) 0 ['block6b_drop[0][0]',
'block6a_project_bn[0][0]']
block6c_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6b_add[0][0]']
block6c_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block6c_expand_conv[0][0]']
ization)
block6c_expand_activation (Act (None, 7, 7, 1152) 0 ['block6c_expand_bn[0][0]']
ivation)
block6c_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 28800 ['block6c_expand_activation[0][0]
D) ']
block6c_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block6c_dwconv[0][0]']
)
block6c_activation (Activation (None, 7, 7, 1152) 0 ['block6c_bn[0][0]']
)
block6c_se_squeeze (GlobalAver (None, 1152) 0 ['block6c_activation[0][0]']
agePooling2D)
block6c_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block6c_se_squeeze[0][0]']
block6c_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block6c_se_reshape[0][0]']
block6c_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block6c_se_reduce[0][0]']
block6c_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block6c_activation[0][0]',
'block6c_se_expand[0][0]']
block6c_project_conv (Conv2D) (None, 7, 7, 192) 221184 ['block6c_se_excite[0][0]']
block6c_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6c_project_conv[0][0]']
lization)
block6c_drop (Dropout) (None, 7, 7, 192) 0 ['block6c_project_bn[0][0]']
block6c_add (Add) (None, 7, 7, 192) 0 ['block6c_drop[0][0]',
'block6b_add[0][0]']
block6d_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6c_add[0][0]']
block6d_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block6d_expand_conv[0][0]']
ization)
block6d_expand_activation (Act (None, 7, 7, 1152) 0 ['block6d_expand_bn[0][0]']
ivation)
block6d_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 28800 ['block6d_expand_activation[0][0]
D) ']
block6d_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block6d_dwconv[0][0]']
)
block6d_activation (Activation (None, 7, 7, 1152) 0 ['block6d_bn[0][0]']
)
block6d_se_squeeze (GlobalAver (None, 1152) 0 ['block6d_activation[0][0]']
agePooling2D)
block6d_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block6d_se_squeeze[0][0]']
block6d_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block6d_se_reshape[0][0]']
block6d_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block6d_se_reduce[0][0]']
block6d_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block6d_activation[0][0]',
'block6d_se_expand[0][0]']
block6d_project_conv (Conv2D) (None, 7, 7, 192) 221184 ['block6d_se_excite[0][0]']
block6d_project_bn (BatchNorma (None, 7, 7, 192) 768 ['block6d_project_conv[0][0]']
lization)
block6d_drop (Dropout) (None, 7, 7, 192) 0 ['block6d_project_bn[0][0]']
block6d_add (Add) (None, 7, 7, 192) 0 ['block6d_drop[0][0]',
'block6c_add[0][0]']
block7a_expand_conv (Conv2D) (None, 7, 7, 1152) 221184 ['block6d_add[0][0]']
block7a_expand_bn (BatchNormal (None, 7, 7, 1152) 4608 ['block7a_expand_conv[0][0]']
ization)
block7a_expand_activation (Act (None, 7, 7, 1152) 0 ['block7a_expand_bn[0][0]']
ivation)
block7a_dwconv (DepthwiseConv2 (None, 7, 7, 1152) 10368 ['block7a_expand_activation[0][0]
D) ']
block7a_bn (BatchNormalization (None, 7, 7, 1152) 4608 ['block7a_dwconv[0][0]']
)
block7a_activation (Activation (None, 7, 7, 1152) 0 ['block7a_bn[0][0]']
)
block7a_se_squeeze (GlobalAver (None, 1152) 0 ['block7a_activation[0][0]']
agePooling2D)
block7a_se_reshape (Reshape) (None, 1, 1, 1152) 0 ['block7a_se_squeeze[0][0]']
block7a_se_reduce (Conv2D) (None, 1, 1, 48) 55344 ['block7a_se_reshape[0][0]']
block7a_se_expand (Conv2D) (None, 1, 1, 1152) 56448 ['block7a_se_reduce[0][0]']
block7a_se_excite (Multiply) (None, 7, 7, 1152) 0 ['block7a_activation[0][0]',
'block7a_se_expand[0][0]']
block7a_project_conv (Conv2D) (None, 7, 7, 320) 368640 ['block7a_se_excite[0][0]']
block7a_project_bn (BatchNorma (None, 7, 7, 320) 1280 ['block7a_project_conv[0][0]']
lization)
top_conv (Conv2D) (None, 7, 7, 1280) 409600 ['block7a_project_bn[0][0]']
top_bn (BatchNormalization) (None, 7, 7, 1280) 5120 ['top_conv[0][0]']
top_activation (Activation) (None, 7, 7, 1280) 0 ['top_bn[0][0]']
global_average_pooling2d (Glob (None, 1280) 0 ['top_activation[0][0]']
alAveragePooling2D)
dense_5 (Dense) (None, 3755) 4810155 ['global_average_pooling2d[0][0]'
]
==================================================================================================
Total params: 8,859,726
Trainable params: 8,817,703
Non-trainable params: 42,023
__________________________________________________________________________________________________
def train_efficientnet():
    """Train (or resume training of) the EfficientNetB0 character classifier.

    Loads the train/test TFRecord datasets, builds the model with
    ``effcientnetB0_model()``, resumes from the latest checkpoint found in
    ``os.path.dirname(ckpt_path)`` if there is one, and fits for up to 100
    epochs of 1024 steps each. The test split is used as validation data.

    Training can be stopped with Ctrl-C (KeyboardInterrupt). In that case the
    recorded history is flushed, the full model is written to
    ``cn_ocr_eff.h5`` next to the checkpoints, and the model is returned.

    :return: the trained (or partially trained, if interrupted) keras model
    """
    print(f'number of classes: {num_classes}')
    history_path = 'history_efficientnet.pickle'
    save_history_callback = SaveHistoryCallback(history_path)

    train_dataset = load_datasets(train_path)
    test_dataset = load_datasets(test_path)
    # repeat() makes both pipelines infinite; epoch length is controlled by
    # steps_per_epoch / validation_steps below.
    train_dataset = train_dataset.map(preprocess_efficientnet).shuffle(100).batch(32).repeat()
    test_dataset = test_dataset.shuffle(100).map(preprocess_efficientnet).batch(32).repeat()
    print(f'train dataset: {train_dataset}')

    # build model (model.summary() is too long to display here)
    model = effcientnetB0_model()

    ckpt_dir = os.path.dirname(ckpt_path)
    full_model_path = os.path.join(ckpt_dir, 'cn_ocr_eff.h5')

    # resume from the latest checkpoint, if any
    latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if latest_ckpt:
        print(f'model resumed from: {latest_ckpt}')
        model.load_weights(latest_ckpt)
    else:
        print('training from scratch')

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])

    # NOTE(review): on resume, ModelCheckpoint's best val_loss starts at inf
    # (the run log shows "improved from inf"), so the first epoch of a resumed
    # run is always saved even if it is worse than the checkpoint it resumed
    # from. Consider `initial_value_threshold` if that matters.
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                           save_weights_only=True,
                                           verbose=1,
                                           save_freq='epoch',
                                           save_best_only=True),
        save_history_callback,
    ]

    try:
        # `use_multiprocessing` was dropped: it only applies to generator /
        # Sequence inputs (ignored for tf.data) and newer Keras rejects it.
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            validation_steps=1000,
            epochs=100,
            steps_per_epoch=1024,
            callbacks=callbacks)
    except KeyboardInterrupt:
        save_history_callback.on_train_end()
        # Actually persist the model so the log message below is true. It is
        # written as a separate .h5 file rather than as a TF checkpoint, so the
        # checkpoint found by tf.train.latest_checkpoint() is still the one
        # written by ModelCheckpoint.
        model.save(full_model_path)
        logging.info('keras model saved. KeyboardInterrupt')
        return model

    model.save_weights(ckpt_path.format(epoch=0))
    model.save(full_model_path)
    logging.info('All epoch finished. keras model saved.')
    return model
# Start (or resume from the latest checkpoint) EfficientNetB0 training.
# This is computation heavy; stopping it with Ctrl-C still returns the model
# trained so far.
model = train_efficientnet()
number of classes: 3755 train dataset: <RepeatDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))> model resumed from: ./checkpoints/efficient_net\cn_ocr-1.ckpt Epoch 1/100 6/1024 [..............................] - ETA: 2:09 - loss: 0.0037 - accuracy: 1.0000WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0558s vs `on_train_batch_end` time: 0.0590s). Check your callbacks. 1024/1024 [==============================] - ETA: 0s - loss: 0.0130 - accuracy: 0.9972 Epoch 1: val_loss improved from inf to 0.54211, saving model to ./checkpoints/efficient_net\cn_ocr-1.ckpt 1024/1024 [==============================] - 163s 152ms/step - loss: 0.0130 - accuracy: 0.9972 - val_loss: 0.5421 - val_accuracy: 0.8860 Epoch 2/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1858 - accuracy: 0.9505 Epoch 2: val_loss did not improve from 0.54211 1024/1024 [==============================] - 155s 152ms/step - loss: 0.1858 - accuracy: 0.9505 - val_loss: 0.7122 - val_accuracy: 0.8370 Epoch 3/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1169 - accuracy: 0.9684 Epoch 3: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.1169 - accuracy: 0.9684 - val_loss: 0.5773 - val_accuracy: 0.8697 Epoch 4/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1093 - accuracy: 0.9713 Epoch 4: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.1093 - accuracy: 0.9713 - val_loss: 0.5893 - val_accuracy: 0.8614 Epoch 5/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1342 - accuracy: 0.9635 Epoch 5: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.1342 - accuracy: 0.9635 - val_loss: 0.5836 - val_accuracy: 0.8711 Epoch 6/100 
1024/1024 [==============================] - ETA: 0s - loss: 0.2706 - accuracy: 0.9327 Epoch 6: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 150ms/step - loss: 0.2706 - accuracy: 0.9327 - val_loss: 0.6092 - val_accuracy: 0.8580 Epoch 7/100 1024/1024 [==============================] - ETA: 0s - loss: 0.1408 - accuracy: 0.9673 Epoch 7: val_loss did not improve from 0.54211 1024/1024 [==============================] - 154s 151ms/step - loss: 0.1408 - accuracy: 0.9673 - val_loss: 0.9957 - val_accuracy: 0.7847 Epoch 8/100 40/1024 [>.............................] - ETA: 2:02 - loss: 0.0888 - accuracy: 0.9766
18:38:56 05.30 INFO 2887911528.py:50]: keras model saved. KeyboardInterrupt
# Load the training history written during training (a local, trusted file).
with open('history_efficientnet.pickle', 'rb') as f:
    history = pickle.load(f)

# Plot accuracy
plt.figure(figsize=(10, 7))
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.title('Model accuracy - EfficientNetB0', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')

# Show the highest accuracies in the graph. The labels are placed to the left
# of each peak; the x position is clamped at epoch 0 so short (e.g.
# interrupted) runs do not push the text off the left edge of the axes.
best_val_epoch = int(np.argmax(history['val_accuracy']))
max_acc = max(history['val_accuracy'])
plt.annotate(f'max accuracy: {max_acc:.4f}', xy=(best_val_epoch, max_acc),
             xytext=(max(best_val_epoch - 20, 0), max_acc - 0.2),
             arrowprops=dict(facecolor='black', shrink=0.001), fontsize=15)

best_train_epoch = int(np.argmax(history['accuracy']))
max_acc_train = max(history['accuracy'])
plt.annotate(f'max accuracy: {max_acc_train:.4f}', xy=(best_train_epoch, max_acc_train),
             xytext=(max(best_train_epoch - 40, 0), max_acc_train + 0.01),
             arrowprops=dict(facecolor='black', shrink=0.001), fontsize=15)
plt.show()

# Plot loss
plt.figure(figsize=(10, 7))
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('Model loss - EfficientNetB0', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Load the weights from the most recent checkpoint. ModelCheckpoint uses
# save_best_only=True, so after an interrupted run this is the best val_loss of
# that run. A run that finishes all epochs also writes its final weights as the
# latest checkpoint, so in that case this is the final model, not necessarily
# the best one.
model = effcientnetB0_model()
ckpt_to_load = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
if ckpt_to_load is None:
    # latest_checkpoint() returns None when the directory has no checkpoint;
    # fail with a clear message instead of an obscure error from load_weights.
    raise FileNotFoundError(
        f'no checkpoint found in {os.path.dirname(ckpt_path)!r}; train the model first')
model.load_weights(ckpt_to_load)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x2504d884130>
def display_result(model, ds, characters):
    """Plot the model's predictions for one random batch in a 3x3 grid.

    :param model: trained keras model; ``np.argmax`` over its ``predict``
        output is used as the predicted class index
    :param ds: dataset as returned by ``load_datasets``; it is shuffled, mapped
        through ``preprocess_efficientnet`` and batched by 32 here, and each
        element is expected to be an ``(image, label)`` pair
    :param characters: list mapping class index -> character
    """
    ds = ds.shuffle(100).map(preprocess_efficientnet).batch(32)
    _, axes = plt.subplots(3, 3, figsize=(15, 15))
    # Use a font with CJK glyphs so Chinese characters render in the titles.
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    # Fix rendering of the minus sign "-" when a non-default font is set.
    plt.rcParams["axes.unicode_minus"] = False

    for images, labels in ds.take(1):
        predictions = model.predict(images)
        # Convert to numpy up front: comparing a numpy int against a tf.Tensor
        # label produces a tf.Tensor, which would print in the title as
        # "tf.Tensor(True, shape=(), dtype=bool)" instead of True/False.
        images = images.numpy()
        labels = labels.numpy()
        pred_ids = np.argmax(predictions, axis=-1)
        n_shown = min(9, len(labels))  # guard against batches smaller than 9

        for i, ax in enumerate(axes.flat):
            if i >= n_shown:
                ax.axis('off')
                continue
            pred_id = pred_ids[i]
            label_id = labels[i]
            ax.imshow(images[i].astype('uint8'), cmap='gray')
            ax.axis('on')
            ax.set_title(f'pred: {characters[pred_id]}, label: {characters[label_id]}, '
                         f'Is correct: {pred_id == label_id}')
    plt.show()
# Visualize the model's predictions on a random batch from the test split.
test_ds = load_datasets(test_path)
display_result(model=model, ds=test_ds, characters=all_characters)
1/1 [==============================] - 1s 949ms/step
We can see that with EfficientNetB0 the validation accuracy reaches about 85% after roughly 80 epochs, a large improvement over the simple CNN model.