Chinese OCR¶


This is the final project of CSE204 Machine Learning course in Ecole Polytechnique. We use CNN and EfficientNet model to do OCR for singular handwriting Chinese Characters. The project is mainly based on Tensorflow 2.0 API.

Making sure to install all the packages in the requirements.txt file, we can now start the process of OCRing the Chinese text.

Contributors¶


  • Junyuan Wang
  • Yubo Cai
In [1]:
# import necessary packages
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from alfred.utils.log import logger as logging
from tensorflow.keras.applications import EfficientNetB0
import pickle
import random
In [2]:
# load character mapping
this_dir = os.path.dirname(os.path.abspath("__file__"))

def load_characters():
    # a = open(os.path.join(this_dir, 'dataset\\characters.txt'), 'r').readlines() # If you are using Windows
    a = open(os.path.join(this_dir, 'dataset/characters.txt'), 'r').readlines() # If you are using Mac OS
    return [i.strip() for i in a]

all_characters = load_characters()
num_classes = len(all_characters)
print(f'There are {num_classes} classes of chinese characters in the dataset')
logging.info('all characters: {}'.format(num_classes))
18:10:37 05.30 INFO 1114677896.py:12]: all characters: 3755
There are 3755 classes of chinese characters in the dataset

Load datasets¶


We want to load the dataset from the local file. The dataset is some tfrecord files which conotains images, labels, and resolution information for the images. The reason for saving as tfrecord format is to improve the training performance.

For more information about the dataset, please read the file README.md Part 3. Dataset Preparation. We also provide the converted dataset in the following link. You can directly use it for training.

  • test.tfrecord
  • train.tfrecord
In [3]:
# Necessary parameters
IMG_SIZE = 80 # This size is fixed for EfficientNetB0
In [4]:
# decode tfrecord to images and labels

def parse_image(record):
    """
    :param record: tfrecord file
    :return: image and label
    """
    features = tf.io.parse_single_example(record,
                                          features={
                                              'width':
                                                  tf.io.FixedLenFeature([], tf.int64),
                                              'height':
                                                  tf.io.FixedLenFeature([], tf.int64),
                                              'label':
                                                  tf.io.FixedLenFeature([], tf.int64),
                                              'image':
                                                  tf.io.FixedLenFeature([], tf.string),
                                          })
    img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
    w = features['width']
    h = features['height']
    img = tf.cast(tf.reshape(img, (w, h)), dtype=tf.float32)
    label = tf.cast(features['label'], tf.int64)
    return {'image': img, 'label': label}

def load_datasets(filename):
    """
    :param filename: tfrecord file
    :return: dataset
    """
    dataset = tf.data.TFRecordDataset([filename])
    dataset = dataset.map(parse_image)
    return dataset

Explore datasets¶

In this part, we can see what are the datasets look like.

In [5]:
train_ds = load_datasets('dataset/train.tfrecord') # read train.tfrecord
test_ds = load_datasets('dataset/test.tfrecord')

train_mapped = train_ds.shuffle(100).batch(32).repeat()
test_mapped = test_ds.batch(32).repeat()
train_mapped
Out[5]:
<RepeatDataset element_spec={'image': TensorSpec(shape=(None, None, None), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>
In [8]:
# summary inforamtion of the dataset, plot the width and height of the images
shapes_plot_train = [[], []]
shapes_plot_test = [[], []]
for i in train_ds:
    shapes_plot_train[0].append(i['image'].shape[0])
    shapes_plot_train[1].append(i['image'].shape[1])

for i in test_ds:
    shapes_plot_test[0].append(i['image'].shape[0])
    shapes_plot_test[1].append(i['image'].shape[1])
In [9]:
# show the number of pictures in the dataset
print(f'There are {len(shapes_plot_train[0])} pictures in the training dataset')
print(f'There are {len(shapes_plot_test[0])} pictures in the testing dataset')
There are 897758 pictures in the training dataset
There are 223991 pictures in the testing dataset
In [10]:
def plot_info_dataset(shapes_plot):
    # plot the histogram of the width and height of the images, and the third plot is the scatter plot
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    ax[0].hist(shapes_plot[0], bins=100)
    ax[0].set_title('width')
    ax[0].set_xlabel('pixels')
    ax[0].set_ylabel('frequency')
    ax[1].hist(shapes_plot[1], bins=100)
    ax[1].set_title('height')
    ax[0].set_xlabel('pixels')
    ax[0].set_ylabel('frequency')
    ax[2].scatter(shapes_plot[0], shapes_plot[1])
    ax[2].set_title('width vs height')
    ax[2].set_xlabel('width (pixels)')
    ax[2].set_ylabel('height (pixels)')
    plt.show()
In [11]:
print('------------------- Information of training dataset -------------------')
plot_info_dataset(shapes_plot_train)
print('------------------- Information of testing dataset -------------------')
plot_info_dataset(shapes_plot_test)
------------------- Information of training dataset -------------------
------------------- Information of testing dataset -------------------

We can also plot some images and labels from the datasets.

In [12]:
# default_font = FontProperties()
def plot_img(ds, characters, num =9):
    plt.rcParams["font.sans-serif"]=["SimHei"] #Set font to avoid garbled characters
    plt.rcParams["axes.unicode_minus"]=False # The same as above
    ax, fig = plt.subplots(3, 3, figsize=(12, 12))
    for i, data in enumerate(ds.take(num)):
        img = data['image'].numpy()
        label = data['label'].numpy()
        fig[i//3][i%3].imshow(img, cmap='gray')
        fig[i//3][i%3].set_title(f'label: {characters[label]}, size: {img.shape}', fontsize=15)
    
    plt.show()
In [13]:
plot_img(train_ds, all_characters, num=9)

Issue with the labels display in Mac OS¶

Important, Please Read!!! If you are using Mac OS, you may find that the labels are not displayed correctly. This is because the package Matplotlib is not compatible with SimHei font. To solve this problem, you can read this MacOS Chinese Characters Label Problem.

Training - With Simple CNN Model¶


In this part, we will define the models and training functions.
In the first part, We will first start from a simple CNN model.

In [14]:
# some basic parameters
TARGET_SIZE = 64
IMG_SIZE = 224 # This size is fixed for EfficientNetB0
ckpt_path = './checkpoints/simple_net/cn_ocr-{epoch}.ckpt'
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'
In [15]:
# image preprocessing
def preprocess(ds):
    """
    :param ds: dataset
    :return: image and label
    """
    ds['image'] = tf.expand_dims(ds['image'], axis=-1)
    ds['image'] = tf.image.resize(ds['image'], (TARGET_SIZE, TARGET_SIZE))
    ds['image'] = (ds['image'] - 128.) / 128.
    return ds['image'], ds['label']
In [16]:
# history log callback function
#
# This part is used to collecting the training statics and save them into a file.
#
class SaveHistoryCallback(tf.keras.callbacks.Callback):
    def __init__(self, history_path):
        super().__init__()
        self.history_path = history_path
        if os.path.exists(self.history_path):
            with open(self.history_path, 'rb') as f:
                self.history = pickle.load(f)
        else:
            self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        for key in logs:
            if key in self.history:
                self.history[key].append(logs[key])
            else:
                self.history[key] = [logs[key]]

    def on_train_end(self, logs=None):
        with open(self.history_path, 'wb') as f:
            pickle.dump(self.history, f)
In [17]:
# Define a simple CNN model
def simple_net(input_shape, n_classes):
    model = tf.keras.Sequential([
        layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
                      padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),

        layers.Flatten(),
        # layers.Dense(1024, activation='relu'),
        layers.Dense(n_classes, activation='softmax')
    ])
    return model

# Define with a more complex CNN model - However, this model is not even converging
def CNN_ComplexModel_1(input_shape, n_classes):
    model = tf.keras.Sequential([
        layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3), strides=(1, 1),
                      padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=512, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),

        layers.Flatten(),
        layers.Dense(1024, activation='relu'),
        layers.Dense(n_classes, activation='softmax')
    ])
    return model
In [18]:
# show the summary of the model
model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), len(all_characters))
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 64, 64, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 32, 32, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 16, 16, 64)       0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 16384)             0         
                                                                 
 dense (Dense)               (None, 3755)              61525675  
                                                                 
=================================================================
Total params: 61,544,491
Trainable params: 61,544,491
Non-trainable params: 0
_________________________________________________________________
In [19]:
model_cnn_complex = CNN_ComplexModel_1((TARGET_SIZE, TARGET_SIZE, 1), len(all_characters))
model_cnn_complex.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_2 (Conv2D)           (None, 64, 64, 64)        640       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 32, 32, 64)       0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 32, 32, 128)       73856     
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 16, 16, 128)      0         
 2D)                                                             
                                                                 
 conv2d_4 (Conv2D)           (None, 16, 16, 256)       295168    
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 8, 8, 256)        0         
 2D)                                                             
                                                                 
 conv2d_5 (Conv2D)           (None, 8, 8, 512)         1180160   
                                                                 
 max_pooling2d_5 (MaxPooling  (None, 4, 4, 512)        0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 8192)              0         
                                                                 
 dense_1 (Dense)             (None, 1024)              8389632   
                                                                 
 dense_2 (Dense)             (None, 3755)              3848875   
                                                                 
=================================================================
Total params: 13,788,331
Trainable params: 13,788,331
Non-trainable params: 0
_________________________________________________________________

In this part. We will try to train the model.

In [20]:
def train_simple():
    print(f'number of classes: {num_classes}')

    history_path = 'history_simple.pickle'
    save_history_callback = SaveHistoryCallback(history_path)

    train_dataset = load_datasets(train_path)
    test_dataset = load_datasets(test_path)

    train_dataset = train_dataset.map(preprocess).shuffle(100).batch(32).repeat()
    test_dataset = test_dataset.shuffle(100).map(preprocess).batch(32)

    print(f'train dataset: {train_dataset}')

    # build model
    model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), num_classes)
    model.summary()

    # latest checkpoints
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        print(f'model resumed from: {latest_ckpt}')
        model.load_weights(latest_ckpt)
    else:
        print('training from scratch')
        
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                            save_weights_only=True,
                                            verbose=1,
                                            save_freq='epoch',
                                            save_best_only=True),
        save_history_callback
    ]
    try:
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            validation_steps=1000,
            epochs=100,
            steps_per_epoch=1024,
            callbacks=callbacks)
    except KeyboardInterrupt:
        logging.info('keras model saved. KeyboardInterrupt')
        save_history_callback.on_train_end()
        return model 
    
    model.save_weights(ckpt_path.format(epoch=0))
    model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr_simple.h5'))
    logging.info('All epoch finished. keras model saved.')
    return model

From our trainning result, In the two CNN models we provide above, the simpler one has better performance. The more complex CNN model not even converge which is a quite interesting result.

In [21]:
# train the model (computation heavy)
model = train_simple()
number of classes: 3755
train dataset: <RepeatDataset element_spec=(TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_6 (Conv2D)           (None, 64, 64, 32)        320       
                                                                 
 max_pooling2d_6 (MaxPooling  (None, 32, 32, 32)       0         
 2D)                                                             
                                                                 
 conv2d_7 (Conv2D)           (None, 32, 32, 64)        18496     
                                                                 
 max_pooling2d_7 (MaxPooling  (None, 16, 16, 64)       0         
 2D)                                                             
                                                                 
 flatten_2 (Flatten)         (None, 16384)             0         
                                                                 
 dense_3 (Dense)             (None, 3755)              61525675  
                                                                 
=================================================================
Total params: 61,544,491
Trainable params: 61,544,491
Non-trainable params: 0
_________________________________________________________________
model resumed from: ./checkpoints/simple_net\cn_ocr-1.ckpt
Epoch 1/100
   4/1024 [..............................] - ETA: 20s - loss: 0.0467 - accuracy: 0.9766        WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0068s vs `on_train_batch_end` time: 0.0108s). Check your callbacks.
1022/1024 [============================>.] - ETA: 0s - loss: 0.0242 - accuracy: 0.9927
Epoch 1: val_loss improved from inf to 4.37403, saving model to ./checkpoints/simple_net\cn_ocr-1.ckpt
1024/1024 [==============================] - 20s 16ms/step - loss: 0.0242 - accuracy: 0.9927 - val_loss: 4.3740 - val_accuracy: 0.4745
Epoch 2/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1772 - accuracy: 0.9523
Epoch 2: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 15ms/step - loss: 0.1772 - accuracy: 0.9523 - val_loss: 5.0706 - val_accuracy: 0.3959
Epoch 3/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1651 - accuracy: 0.9527
Epoch 3: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1652 - accuracy: 0.9528 - val_loss: 5.4529 - val_accuracy: 0.4083
Epoch 4/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1733 - accuracy: 0.9509
Epoch 4: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1732 - accuracy: 0.9509 - val_loss: 4.7948 - val_accuracy: 0.4309
Epoch 5/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1912 - accuracy: 0.9467
Epoch 5: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.1910 - accuracy: 0.9467 - val_loss: 4.8460 - val_accuracy: 0.4083
Epoch 6/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2448 - accuracy: 0.9330
Epoch 6: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2444 - accuracy: 0.9331 - val_loss: 5.6845 - val_accuracy: 0.3722
Epoch 7/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2070 - accuracy: 0.9426
Epoch 7: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2068 - accuracy: 0.9426 - val_loss: 5.0307 - val_accuracy: 0.3918
Epoch 8/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2148 - accuracy: 0.9402
Epoch 8: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2151 - accuracy: 0.9401 - val_loss: 5.2210 - val_accuracy: 0.4089
Epoch 9/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2385 - accuracy: 0.9332
Epoch 9: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2385 - accuracy: 0.9332 - val_loss: 4.9118 - val_accuracy: 0.4042
Epoch 10/100
1020/1024 [============================>.] - ETA: 0s - loss: 0.2330 - accuracy: 0.9342
Epoch 10: val_loss did not improve from 4.37403
1024/1024 [==============================] - 15s 14ms/step - loss: 0.2326 - accuracy: 0.9343 - val_loss: 6.6977 - val_accuracy: 0.3193
Epoch 11/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1729 - accuracy: 0.9498
Epoch 11: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1730 - accuracy: 0.9497 - val_loss: 4.9370 - val_accuracy: 0.4512
Epoch 12/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.2010 - accuracy: 0.9420
Epoch 12: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.2009 - accuracy: 0.9420 - val_loss: 5.2270 - val_accuracy: 0.4319
Epoch 13/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.1908 - accuracy: 0.9446
Epoch 13: val_loss did not improve from 4.37403
1024/1024 [==============================] - 14s 14ms/step - loss: 0.1910 - accuracy: 0.9445 - val_loss: 4.6441 - val_accuracy: 0.4633
Epoch 14/100
1021/1024 [============================>.] - ETA: 0s - loss: 0.9264 - accuracy: 0.7941
18:20:18 05.30 INFO 3396354300.py:49]: keras model saved. KeyboardInterrupt

Evaluaton and Prediction visualization¶

In [22]:
with open('history_simple.pickle', 'rb') as f:
    history = pickle.load(f)

# Plot accuracy
plt.figure(figsize=(10, 7))
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.title('Model accuracy - Simple CNN', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')

# show the highest accuracy in the graph
max_acc = max(history['val_accuracy'])
plt.annotate(f'max accuracy: {max_acc:.4f}', xy=(np.argmax(history['val_accuracy']), max_acc),
                xytext=(np.argmax(history['val_accuracy']) - 10, max_acc + 0.1),
                arrowprops=dict(facecolor='black', shrink=0.03), fontsize=15)
max_acc_train = max(history['accuracy'])
plt.annotate(f'max accuracy: {max_acc_train:.4f}', xy=(np.argmax(history['accuracy']), max_acc_train),
                xytext=(np.argmax(history['accuracy']) - 30, max_acc_train - 0.1),
                arrowprops=dict(facecolor='black', shrink=0.03), fontsize=15)
plt.show()
In [23]:
# Plot loss
plt.figure(figsize=(10, 7))
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('Model loss - Simple CNN', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
In [24]:
# Load the best model
model = simple_net((TARGET_SIZE, TARGET_SIZE, 1), num_classes)
model.load_weights(tf.train.latest_checkpoint(os.path.dirname(ckpt_path)))
Out[24]:
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x25053771510>
In [25]:
# display the result
import random

def display_result(model, ds, characters):
    ds = ds.shuffle(100).map(preprocess).batch(32).repeat()
    f, axe = plt.subplots(3, 3, figsize=(15, 15))
    plt.rcParams["font.sans-serif"]=["SimHei"] #Set font to avoid garbled characters
    plt.rcParams["axes.unicode_minus"]=False # The same as above
    # take 9 samples
    for data in ds.take(1):
        images, labels = data
        predictions = model.predict(images)
        for i in range(9):
            axe[i//3, i%3].imshow(images[i].numpy().reshape(64, 64), cmap='gray')
            axe[i//3, i%3].axis('on')
            axe[i//3, i%3].set_title(f'pred: {characters[np.argmax(predictions[i])]}, label: {characters[labels[i]]}, Is Correct: {np.argmax(predictions[i]) == labels[i]}')
    plt.show()
In [26]:
test_ds = load_datasets(test_path)
display_result(model, test_ds, all_characters)
1/1 [==============================] - 0s 65ms/step

Refine the Model with EfficientNet¶


This model can only reach the validation accuracy to around 40%, which is not satisfying for us. We decided to use a far more complicated model called effcientnet to train the model. This model is a state-of-the-art model for image classification. We will use the B0 version of the model. For more information about the model, please read the paper and this webpage

In [27]:
# some basic parameters
IMG_SIZE = 224 # This size is fixed for EfficientNetB0
ckpt_path = './checkpoints/efficient_net/cn_ocr-{epoch}.ckpt'
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'
In [28]:
def preprocess_efficientnet(ds):
    """
    :param ds: dataset
    :return: image and label

    No normalization is needed here for EfficientNetB0
    """
    ds['image'] = tf.expand_dims(ds['image'], axis=-1)
    ds['image'] = tf.image.grayscale_to_rgb(ds['image'])
    ds['image'] = tf.image.resize(ds['image'], (IMG_SIZE, IMG_SIZE))
    return ds['image'], ds['label']
In [29]:
# model
def effcientnetB0_model():
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))  # EfficientNetB0 expects 3 channels
    base_model = EfficientNetB0(include_top=False, input_tensor=inputs, weights='imagenet')
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs=inputs, outputs=x)
    return model
In [30]:
# model summary of EfficientNetB0
model = effcientnetB0_model()
model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 224, 224, 3)  0           ['input_1[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 224, 224, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 rescaling_1 (Rescaling)        (None, 224, 224, 3)  0           ['normalization[0][0]']          
                                                                                                  
 stem_conv_pad (ZeroPadding2D)  (None, 225, 225, 3)  0           ['rescaling_1[0][0]']            
                                                                                                  
 stem_conv (Conv2D)             (None, 112, 112, 32  864         ['stem_conv_pad[0][0]']          
                                )                                                                 
                                                                                                  
 stem_bn (BatchNormalization)   (None, 112, 112, 32  128         ['stem_conv[0][0]']              
                                )                                                                 
                                                                                                  
 stem_activation (Activation)   (None, 112, 112, 32  0           ['stem_bn[0][0]']                
                                )                                                                 
                                                                                                  
 block1a_dwconv (DepthwiseConv2  (None, 112, 112, 32  288        ['stem_activation[0][0]']        
 D)                             )                                                                 
                                                                                                  
 block1a_bn (BatchNormalization  (None, 112, 112, 32  128        ['block1a_dwconv[0][0]']         
 )                              )                                                                 
                                                                                                  
 block1a_activation (Activation  (None, 112, 112, 32  0          ['block1a_bn[0][0]']             
 )                              )                                                                 
                                                                                                  
 block1a_se_squeeze (GlobalAver  (None, 32)          0           ['block1a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block1a_se_reshape (Reshape)   (None, 1, 1, 32)     0           ['block1a_se_squeeze[0][0]']     
                                                                                                  
 block1a_se_reduce (Conv2D)     (None, 1, 1, 8)      264         ['block1a_se_reshape[0][0]']     
                                                                                                  
 block1a_se_expand (Conv2D)     (None, 1, 1, 32)     288         ['block1a_se_reduce[0][0]']      
                                                                                                  
 block1a_se_excite (Multiply)   (None, 112, 112, 32  0           ['block1a_activation[0][0]',     
                                )                                 'block1a_se_expand[0][0]']      
                                                                                                  
 block1a_project_conv (Conv2D)  (None, 112, 112, 16  512         ['block1a_se_excite[0][0]']      
                                )                                                                 
                                                                                                  
 block1a_project_bn (BatchNorma  (None, 112, 112, 16  64         ['block1a_project_conv[0][0]']   
 lization)                      )                                                                 
                                                                                                  
 block2a_expand_conv (Conv2D)   (None, 112, 112, 96  1536        ['block1a_project_bn[0][0]']     
                                )                                                                 
                                                                                                  
 block2a_expand_bn (BatchNormal  (None, 112, 112, 96  384        ['block2a_expand_conv[0][0]']    
 ization)                       )                                                                 
                                                                                                  
 block2a_expand_activation (Act  (None, 112, 112, 96  0          ['block2a_expand_bn[0][0]']      
 ivation)                       )                                                                 
                                                                                                  
 block2a_dwconv_pad (ZeroPaddin  (None, 113, 113, 96  0          ['block2a_expand_activation[0][0]
 g2D)                           )                                ']                               
                                                                                                  
 block2a_dwconv (DepthwiseConv2  (None, 56, 56, 96)  864         ['block2a_dwconv_pad[0][0]']     
 D)                                                                                               
                                                                                                  
 block2a_bn (BatchNormalization  (None, 56, 56, 96)  384         ['block2a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block2a_activation (Activation  (None, 56, 56, 96)  0           ['block2a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block2a_se_squeeze (GlobalAver  (None, 96)          0           ['block2a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block2a_se_reshape (Reshape)   (None, 1, 1, 96)     0           ['block2a_se_squeeze[0][0]']     
                                                                                                  
 block2a_se_reduce (Conv2D)     (None, 1, 1, 4)      388         ['block2a_se_reshape[0][0]']     
                                                                                                  
 block2a_se_expand (Conv2D)     (None, 1, 1, 96)     480         ['block2a_se_reduce[0][0]']      
                                                                                                  
 block2a_se_excite (Multiply)   (None, 56, 56, 96)   0           ['block2a_activation[0][0]',     
                                                                  'block2a_se_expand[0][0]']      
                                                                                                  
 block2a_project_conv (Conv2D)  (None, 56, 56, 24)   2304        ['block2a_se_excite[0][0]']      
                                                                                                  
 block2a_project_bn (BatchNorma  (None, 56, 56, 24)  96          ['block2a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block2b_expand_conv (Conv2D)   (None, 56, 56, 144)  3456        ['block2a_project_bn[0][0]']     
                                                                                                  
 block2b_expand_bn (BatchNormal  (None, 56, 56, 144)  576        ['block2b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block2b_expand_activation (Act  (None, 56, 56, 144)  0          ['block2b_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block2b_dwconv (DepthwiseConv2  (None, 56, 56, 144)  1296       ['block2b_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block2b_bn (BatchNormalization  (None, 56, 56, 144)  576        ['block2b_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block2b_activation (Activation  (None, 56, 56, 144)  0          ['block2b_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block2b_se_squeeze (GlobalAver  (None, 144)         0           ['block2b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block2b_se_reshape (Reshape)   (None, 1, 1, 144)    0           ['block2b_se_squeeze[0][0]']     
                                                                                                  
 block2b_se_reduce (Conv2D)     (None, 1, 1, 6)      870         ['block2b_se_reshape[0][0]']     
                                                                                                  
 block2b_se_expand (Conv2D)     (None, 1, 1, 144)    1008        ['block2b_se_reduce[0][0]']      
                                                                                                  
 block2b_se_excite (Multiply)   (None, 56, 56, 144)  0           ['block2b_activation[0][0]',     
                                                                  'block2b_se_expand[0][0]']      
                                                                                                  
 block2b_project_conv (Conv2D)  (None, 56, 56, 24)   3456        ['block2b_se_excite[0][0]']      
                                                                                                  
 block2b_project_bn (BatchNorma  (None, 56, 56, 24)  96          ['block2b_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block2b_drop (Dropout)         (None, 56, 56, 24)   0           ['block2b_project_bn[0][0]']     
                                                                                                  
 block2b_add (Add)              (None, 56, 56, 24)   0           ['block2b_drop[0][0]',           
                                                                  'block2a_project_bn[0][0]']     
                                                                                                  
 block3a_expand_conv (Conv2D)   (None, 56, 56, 144)  3456        ['block2b_add[0][0]']            
                                                                                                  
 block3a_expand_bn (BatchNormal  (None, 56, 56, 144)  576        ['block3a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block3a_expand_activation (Act  (None, 56, 56, 144)  0          ['block3a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block3a_dwconv_pad (ZeroPaddin  (None, 59, 59, 144)  0          ['block3a_expand_activation[0][0]
 g2D)                                                            ']                               
                                                                                                  
 block3a_dwconv (DepthwiseConv2  (None, 28, 28, 144)  3600       ['block3a_dwconv_pad[0][0]']     
 D)                                                                                               
                                                                                                  
 block3a_bn (BatchNormalization  (None, 28, 28, 144)  576        ['block3a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block3a_activation (Activation  (None, 28, 28, 144)  0          ['block3a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block3a_se_squeeze (GlobalAver  (None, 144)         0           ['block3a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block3a_se_reshape (Reshape)   (None, 1, 1, 144)    0           ['block3a_se_squeeze[0][0]']     
                                                                                                  
 block3a_se_reduce (Conv2D)     (None, 1, 1, 6)      870         ['block3a_se_reshape[0][0]']     
                                                                                                  
 block3a_se_expand (Conv2D)     (None, 1, 1, 144)    1008        ['block3a_se_reduce[0][0]']      
                                                                                                  
 block3a_se_excite (Multiply)   (None, 28, 28, 144)  0           ['block3a_activation[0][0]',     
                                                                  'block3a_se_expand[0][0]']      
                                                                                                  
 block3a_project_conv (Conv2D)  (None, 28, 28, 40)   5760        ['block3a_se_excite[0][0]']      
                                                                                                  
 block3a_project_bn (BatchNorma  (None, 28, 28, 40)  160         ['block3a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block3b_expand_conv (Conv2D)   (None, 28, 28, 240)  9600        ['block3a_project_bn[0][0]']     
                                                                                                  
 block3b_expand_bn (BatchNormal  (None, 28, 28, 240)  960        ['block3b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block3b_expand_activation (Act  (None, 28, 28, 240)  0          ['block3b_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block3b_dwconv (DepthwiseConv2  (None, 28, 28, 240)  6000       ['block3b_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block3b_bn (BatchNormalization  (None, 28, 28, 240)  960        ['block3b_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block3b_activation (Activation  (None, 28, 28, 240)  0          ['block3b_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block3b_se_squeeze (GlobalAver  (None, 240)         0           ['block3b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block3b_se_reshape (Reshape)   (None, 1, 1, 240)    0           ['block3b_se_squeeze[0][0]']     
                                                                                                  
 block3b_se_reduce (Conv2D)     (None, 1, 1, 10)     2410        ['block3b_se_reshape[0][0]']     
                                                                                                  
 block3b_se_expand (Conv2D)     (None, 1, 1, 240)    2640        ['block3b_se_reduce[0][0]']      
                                                                                                  
 block3b_se_excite (Multiply)   (None, 28, 28, 240)  0           ['block3b_activation[0][0]',     
                                                                  'block3b_se_expand[0][0]']      
                                                                                                  
 block3b_project_conv (Conv2D)  (None, 28, 28, 40)   9600        ['block3b_se_excite[0][0]']      
                                                                                                  
 block3b_project_bn (BatchNorma  (None, 28, 28, 40)  160         ['block3b_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block3b_drop (Dropout)         (None, 28, 28, 40)   0           ['block3b_project_bn[0][0]']     
                                                                                                  
 block3b_add (Add)              (None, 28, 28, 40)   0           ['block3b_drop[0][0]',           
                                                                  'block3a_project_bn[0][0]']     
                                                                                                  
 block4a_expand_conv (Conv2D)   (None, 28, 28, 240)  9600        ['block3b_add[0][0]']            
                                                                                                  
 block4a_expand_bn (BatchNormal  (None, 28, 28, 240)  960        ['block4a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block4a_expand_activation (Act  (None, 28, 28, 240)  0          ['block4a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block4a_dwconv_pad (ZeroPaddin  (None, 29, 29, 240)  0          ['block4a_expand_activation[0][0]
 g2D)                                                            ']                               
                                                                                                  
 block4a_dwconv (DepthwiseConv2  (None, 14, 14, 240)  2160       ['block4a_dwconv_pad[0][0]']     
 D)                                                                                               
                                                                                                  
 block4a_bn (BatchNormalization  (None, 14, 14, 240)  960        ['block4a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block4a_activation (Activation  (None, 14, 14, 240)  0          ['block4a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block4a_se_squeeze (GlobalAver  (None, 240)         0           ['block4a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block4a_se_reshape (Reshape)   (None, 1, 1, 240)    0           ['block4a_se_squeeze[0][0]']     
                                                                                                  
 block4a_se_reduce (Conv2D)     (None, 1, 1, 10)     2410        ['block4a_se_reshape[0][0]']     
                                                                                                  
 block4a_se_expand (Conv2D)     (None, 1, 1, 240)    2640        ['block4a_se_reduce[0][0]']      
                                                                                                  
 block4a_se_excite (Multiply)   (None, 14, 14, 240)  0           ['block4a_activation[0][0]',     
                                                                  'block4a_se_expand[0][0]']      
                                                                                                  
 block4a_project_conv (Conv2D)  (None, 14, 14, 80)   19200       ['block4a_se_excite[0][0]']      
                                                                                                  
 block4a_project_bn (BatchNorma  (None, 14, 14, 80)  320         ['block4a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4b_expand_conv (Conv2D)   (None, 14, 14, 480)  38400       ['block4a_project_bn[0][0]']     
                                                                                                  
 block4b_expand_bn (BatchNormal  (None, 14, 14, 480)  1920       ['block4b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block4b_expand_activation (Act  (None, 14, 14, 480)  0          ['block4b_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block4b_dwconv (DepthwiseConv2  (None, 14, 14, 480)  4320       ['block4b_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block4b_bn (BatchNormalization  (None, 14, 14, 480)  1920       ['block4b_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block4b_activation (Activation  (None, 14, 14, 480)  0          ['block4b_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block4b_se_squeeze (GlobalAver  (None, 480)         0           ['block4b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block4b_se_reshape (Reshape)   (None, 1, 1, 480)    0           ['block4b_se_squeeze[0][0]']     
                                                                                                  
 block4b_se_reduce (Conv2D)     (None, 1, 1, 20)     9620        ['block4b_se_reshape[0][0]']     
                                                                                                  
 block4b_se_expand (Conv2D)     (None, 1, 1, 480)    10080       ['block4b_se_reduce[0][0]']      
                                                                                                  
 block4b_se_excite (Multiply)   (None, 14, 14, 480)  0           ['block4b_activation[0][0]',     
                                                                  'block4b_se_expand[0][0]']      
                                                                                                  
 block4b_project_conv (Conv2D)  (None, 14, 14, 80)   38400       ['block4b_se_excite[0][0]']      
                                                                                                  
 block4b_project_bn (BatchNorma  (None, 14, 14, 80)  320         ['block4b_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4b_drop (Dropout)         (None, 14, 14, 80)   0           ['block4b_project_bn[0][0]']     
                                                                                                  
 block4b_add (Add)              (None, 14, 14, 80)   0           ['block4b_drop[0][0]',           
                                                                  'block4a_project_bn[0][0]']     
                                                                                                  
 block4c_expand_conv (Conv2D)   (None, 14, 14, 480)  38400       ['block4b_add[0][0]']            
                                                                                                  
 block4c_expand_bn (BatchNormal  (None, 14, 14, 480)  1920       ['block4c_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block4c_expand_activation (Act  (None, 14, 14, 480)  0          ['block4c_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block4c_dwconv (DepthwiseConv2  (None, 14, 14, 480)  4320       ['block4c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block4c_bn (BatchNormalization  (None, 14, 14, 480)  1920       ['block4c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block4c_activation (Activation  (None, 14, 14, 480)  0          ['block4c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block4c_se_squeeze (GlobalAver  (None, 480)         0           ['block4c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block4c_se_reshape (Reshape)   (None, 1, 1, 480)    0           ['block4c_se_squeeze[0][0]']     
                                                                                                  
 block4c_se_reduce (Conv2D)     (None, 1, 1, 20)     9620        ['block4c_se_reshape[0][0]']     
                                                                                                  
 block4c_se_expand (Conv2D)     (None, 1, 1, 480)    10080       ['block4c_se_reduce[0][0]']      
                                                                                                  
 block4c_se_excite (Multiply)   (None, 14, 14, 480)  0           ['block4c_activation[0][0]',     
                                                                  'block4c_se_expand[0][0]']      
                                                                                                  
 block4c_project_conv (Conv2D)  (None, 14, 14, 80)   38400       ['block4c_se_excite[0][0]']      
                                                                                                  
 block4c_project_bn (BatchNorma  (None, 14, 14, 80)  320         ['block4c_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4c_drop (Dropout)         (None, 14, 14, 80)   0           ['block4c_project_bn[0][0]']     
                                                                                                  
 block4c_add (Add)              (None, 14, 14, 80)   0           ['block4c_drop[0][0]',           
                                                                  'block4b_add[0][0]']            
                                                                                                  
 block5a_expand_conv (Conv2D)   (None, 14, 14, 480)  38400       ['block4c_add[0][0]']            
                                                                                                  
 block5a_expand_bn (BatchNormal  (None, 14, 14, 480)  1920       ['block5a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5a_expand_activation (Act  (None, 14, 14, 480)  0          ['block5a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5a_dwconv (DepthwiseConv2  (None, 14, 14, 480)  12000      ['block5a_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5a_bn (BatchNormalization  (None, 14, 14, 480)  1920       ['block5a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5a_activation (Activation  (None, 14, 14, 480)  0          ['block5a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5a_se_squeeze (GlobalAver  (None, 480)         0           ['block5a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5a_se_reshape (Reshape)   (None, 1, 1, 480)    0           ['block5a_se_squeeze[0][0]']     
                                                                                                  
 block5a_se_reduce (Conv2D)     (None, 1, 1, 20)     9620        ['block5a_se_reshape[0][0]']     
                                                                                                  
 block5a_se_expand (Conv2D)     (None, 1, 1, 480)    10080       ['block5a_se_reduce[0][0]']      
                                                                                                  
 block5a_se_excite (Multiply)   (None, 14, 14, 480)  0           ['block5a_activation[0][0]',     
                                                                  'block5a_se_expand[0][0]']      
                                                                                                  
 block5a_project_conv (Conv2D)  (None, 14, 14, 112)  53760       ['block5a_se_excite[0][0]']      
                                                                                                  
 block5a_project_bn (BatchNorma  (None, 14, 14, 112)  448        ['block5a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block5b_expand_conv (Conv2D)   (None, 14, 14, 672)  75264       ['block5a_project_bn[0][0]']     
                                                                                                  
 block5b_expand_bn (BatchNormal  (None, 14, 14, 672)  2688       ['block5b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5b_expand_activation (Act  (None, 14, 14, 672)  0          ['block5b_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5b_dwconv (DepthwiseConv2  (None, 14, 14, 672)  16800      ['block5b_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5b_bn (BatchNormalization  (None, 14, 14, 672)  2688       ['block5b_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5b_activation (Activation  (None, 14, 14, 672)  0          ['block5b_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5b_se_squeeze (GlobalAver  (None, 672)         0           ['block5b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5b_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block5b_se_squeeze[0][0]']     
                                                                                                  
 block5b_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block5b_se_reshape[0][0]']     
                                                                                                  
 block5b_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block5b_se_reduce[0][0]']      
                                                                                                  
 block5b_se_excite (Multiply)   (None, 14, 14, 672)  0           ['block5b_activation[0][0]',     
                                                                  'block5b_se_expand[0][0]']      
                                                                                                  
 block5b_project_conv (Conv2D)  (None, 14, 14, 112)  75264       ['block5b_se_excite[0][0]']      
                                                                                                  
 block5b_project_bn (BatchNorma  (None, 14, 14, 112)  448        ['block5b_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block5b_drop (Dropout)         (None, 14, 14, 112)  0           ['block5b_project_bn[0][0]']     
                                                                                                  
 block5b_add (Add)              (None, 14, 14, 112)  0           ['block5b_drop[0][0]',           
                                                                  'block5a_project_bn[0][0]']     
                                                                                                  
 block5c_expand_conv (Conv2D)   (None, 14, 14, 672)  75264       ['block5b_add[0][0]']            
                                                                                                  
 block5c_expand_bn (BatchNormal  (None, 14, 14, 672)  2688       ['block5c_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5c_expand_activation (Act  (None, 14, 14, 672)  0          ['block5c_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5c_dwconv (DepthwiseConv2  (None, 14, 14, 672)  16800      ['block5c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5c_bn (BatchNormalization  (None, 14, 14, 672)  2688       ['block5c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5c_activation (Activation  (None, 14, 14, 672)  0          ['block5c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5c_se_squeeze (GlobalAver  (None, 672)         0           ['block5c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5c_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block5c_se_squeeze[0][0]']     
                                                                                                  
 block5c_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block5c_se_reshape[0][0]']     
                                                                                                  
 block5c_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block5c_se_reduce[0][0]']      
                                                                                                  
 block5c_se_excite (Multiply)   (None, 14, 14, 672)  0           ['block5c_activation[0][0]',     
                                                                  'block5c_se_expand[0][0]']      
                                                                                                  
 block5c_project_conv (Conv2D)  (None, 14, 14, 112)  75264       ['block5c_se_excite[0][0]']      
                                                                                                  
 block5c_project_bn (BatchNorma  (None, 14, 14, 112)  448        ['block5c_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block5c_drop (Dropout)         (None, 14, 14, 112)  0           ['block5c_project_bn[0][0]']     
                                                                                                  
 block5c_add (Add)              (None, 14, 14, 112)  0           ['block5c_drop[0][0]',           
                                                                  'block5b_add[0][0]']            
                                                                                                  
 block6a_expand_conv (Conv2D)   (None, 14, 14, 672)  75264       ['block5c_add[0][0]']            
                                                                                                  
 block6a_expand_bn (BatchNormal  (None, 14, 14, 672)  2688       ['block6a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block6a_expand_activation (Act  (None, 14, 14, 672)  0          ['block6a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block6a_dwconv_pad (ZeroPaddin  (None, 17, 17, 672)  0          ['block6a_expand_activation[0][0]
 g2D)                                                            ']                               
                                                                                                  
 block6a_dwconv (DepthwiseConv2  (None, 7, 7, 672)   16800       ['block6a_dwconv_pad[0][0]']     
 D)                                                                                               
                                                                                                  
 block6a_bn (BatchNormalization  (None, 7, 7, 672)   2688        ['block6a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block6a_activation (Activation  (None, 7, 7, 672)   0           ['block6a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block6a_se_squeeze (GlobalAver  (None, 672)         0           ['block6a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6a_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block6a_se_squeeze[0][0]']     
                                                                                                  
 block6a_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block6a_se_reshape[0][0]']     
                                                                                                  
 block6a_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block6a_se_reduce[0][0]']      
                                                                                                  
 block6a_se_excite (Multiply)   (None, 7, 7, 672)    0           ['block6a_activation[0][0]',     
                                                                  'block6a_se_expand[0][0]']      
                                                                                                  
 block6a_project_conv (Conv2D)  (None, 7, 7, 192)    129024      ['block6a_se_excite[0][0]']      
                                                                                                  
 block6a_project_bn (BatchNorma  (None, 7, 7, 192)   768         ['block6a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6b_expand_conv (Conv2D)   (None, 7, 7, 1152)   221184      ['block6a_project_bn[0][0]']     
                                                                                                  
 block6b_expand_bn (BatchNormal  (None, 7, 7, 1152)  4608        ['block6b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block6b_expand_activation (Act  (None, 7, 7, 1152)  0           ['block6b_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block6b_dwconv (DepthwiseConv2  (None, 7, 7, 1152)  28800       ['block6b_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block6b_bn (BatchNormalization  (None, 7, 7, 1152)  4608        ['block6b_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block6b_activation (Activation  (None, 7, 7, 1152)  0           ['block6b_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block6b_se_squeeze (GlobalAver  (None, 1152)        0           ['block6b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6b_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block6b_se_squeeze[0][0]']     
                                                                                                  
 block6b_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block6b_se_reshape[0][0]']     
                                                                                                  
 block6b_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block6b_se_reduce[0][0]']      
                                                                                                  
 block6b_se_excite (Multiply)   (None, 7, 7, 1152)   0           ['block6b_activation[0][0]',     
                                                                  'block6b_se_expand[0][0]']      
                                                                                                  
 block6b_project_conv (Conv2D)  (None, 7, 7, 192)    221184      ['block6b_se_excite[0][0]']      
                                                                                                  
 block6b_project_bn (BatchNorma  (None, 7, 7, 192)   768         ['block6b_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6b_drop (Dropout)         (None, 7, 7, 192)    0           ['block6b_project_bn[0][0]']     
                                                                                                  
 block6b_add (Add)              (None, 7, 7, 192)    0           ['block6b_drop[0][0]',           
                                                                  'block6a_project_bn[0][0]']     
                                                                                                  
 block6c_expand_conv (Conv2D)   (None, 7, 7, 1152)   221184      ['block6b_add[0][0]']            
                                                                                                  
 block6c_expand_bn (BatchNormal  (None, 7, 7, 1152)  4608        ['block6c_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block6c_expand_activation (Act  (None, 7, 7, 1152)  0           ['block6c_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block6c_dwconv (DepthwiseConv2  (None, 7, 7, 1152)  28800       ['block6c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block6c_bn (BatchNormalization  (None, 7, 7, 1152)  4608        ['block6c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block6c_activation (Activation  (None, 7, 7, 1152)  0           ['block6c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block6c_se_squeeze (GlobalAver  (None, 1152)        0           ['block6c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6c_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block6c_se_squeeze[0][0]']     
                                                                                                  
 block6c_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block6c_se_reshape[0][0]']     
                                                                                                  
 block6c_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block6c_se_reduce[0][0]']      
                                                                                                  
 block6c_se_excite (Multiply)   (None, 7, 7, 1152)   0           ['block6c_activation[0][0]',     
                                                                  'block6c_se_expand[0][0]']      
                                                                                                  
 block6c_project_conv (Conv2D)  (None, 7, 7, 192)    221184      ['block6c_se_excite[0][0]']      
                                                                                                  
 block6c_project_bn (BatchNorma  (None, 7, 7, 192)   768         ['block6c_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6c_drop (Dropout)         (None, 7, 7, 192)    0           ['block6c_project_bn[0][0]']     
                                                                                                  
 block6c_add (Add)              (None, 7, 7, 192)    0           ['block6c_drop[0][0]',           
                                                                  'block6b_add[0][0]']            
                                                                                                  
 block6d_expand_conv (Conv2D)   (None, 7, 7, 1152)   221184      ['block6c_add[0][0]']            
                                                                                                  
 block6d_expand_bn (BatchNormal  (None, 7, 7, 1152)  4608        ['block6d_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block6d_expand_activation (Act  (None, 7, 7, 1152)  0           ['block6d_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block6d_dwconv (DepthwiseConv2  (None, 7, 7, 1152)  28800       ['block6d_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block6d_bn (BatchNormalization  (None, 7, 7, 1152)  4608        ['block6d_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block6d_activation (Activation  (None, 7, 7, 1152)  0           ['block6d_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block6d_se_squeeze (GlobalAver  (None, 1152)        0           ['block6d_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6d_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block6d_se_squeeze[0][0]']     
                                                                                                  
 block6d_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block6d_se_reshape[0][0]']     
                                                                                                  
 block6d_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block6d_se_reduce[0][0]']      
                                                                                                  
 block6d_se_excite (Multiply)   (None, 7, 7, 1152)   0           ['block6d_activation[0][0]',     
                                                                  'block6d_se_expand[0][0]']      
                                                                                                  
 block6d_project_conv (Conv2D)  (None, 7, 7, 192)    221184      ['block6d_se_excite[0][0]']      
                                                                                                  
 block6d_project_bn (BatchNorma  (None, 7, 7, 192)   768         ['block6d_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6d_drop (Dropout)         (None, 7, 7, 192)    0           ['block6d_project_bn[0][0]']     
                                                                                                  
 block6d_add (Add)              (None, 7, 7, 192)    0           ['block6d_drop[0][0]',           
                                                                  'block6c_add[0][0]']            
                                                                                                  
 block7a_expand_conv (Conv2D)   (None, 7, 7, 1152)   221184      ['block6d_add[0][0]']            
                                                                                                  
 block7a_expand_bn (BatchNormal  (None, 7, 7, 1152)  4608        ['block7a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block7a_expand_activation (Act  (None, 7, 7, 1152)  0           ['block7a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block7a_dwconv (DepthwiseConv2  (None, 7, 7, 1152)  10368       ['block7a_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block7a_bn (BatchNormalization  (None, 7, 7, 1152)  4608        ['block7a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block7a_activation (Activation  (None, 7, 7, 1152)  0           ['block7a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block7a_se_squeeze (GlobalAver  (None, 1152)        0           ['block7a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block7a_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block7a_se_squeeze[0][0]']     
                                                                                                  
 block7a_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block7a_se_reshape[0][0]']     
                                                                                                  
 block7a_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block7a_se_reduce[0][0]']      
                                                                                                  
 block7a_se_excite (Multiply)   (None, 7, 7, 1152)   0           ['block7a_activation[0][0]',     
                                                                  'block7a_se_expand[0][0]']      
                                                                                                  
 block7a_project_conv (Conv2D)  (None, 7, 7, 320)    368640      ['block7a_se_excite[0][0]']      
                                                                                                  
 block7a_project_bn (BatchNorma  (None, 7, 7, 320)   1280        ['block7a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 top_conv (Conv2D)              (None, 7, 7, 1280)   409600      ['block7a_project_bn[0][0]']     
                                                                                                  
 top_bn (BatchNormalization)    (None, 7, 7, 1280)   5120        ['top_conv[0][0]']               
                                                                                                  
 top_activation (Activation)    (None, 7, 7, 1280)   0           ['top_bn[0][0]']                 
                                                                                                  
 global_average_pooling2d (Glob  (None, 1280)        0           ['top_activation[0][0]']         
 alAveragePooling2D)                                                                              
                                                                                                  
 dense_5 (Dense)                (None, 3755)         4810155     ['global_average_pooling2d[0][0]'
                                                                 ]                                
                                                                                                  
==================================================================================================
Total params: 8,859,726
Trainable params: 8,817,703
Non-trainable params: 42,023
__________________________________________________________________________________________________
In [31]:
def train_efficientnet():
    print(f'number of classes: {num_classes}')

    history_path = 'history_efficientnet.pickle'
    save_history_callback = SaveHistoryCallback(history_path)

    train_dataset = load_datasets(train_path)
    test_dataset = load_datasets(test_path)

    train_dataset = train_dataset.map(preprocess_efficientnet).shuffle(100).batch(32).repeat()
    test_dataset = test_dataset.shuffle(100).map(preprocess_efficientnet).batch(32).repeat()

    print(f'train dataset: {train_dataset}')

    # build model
    model = effcientnetB0_model()
    # model.summary() # too long to display

    # latest checkpoints
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        print(f'model resumed from: {latest_ckpt}')
        model.load_weights(latest_ckpt)
    else:
        print('training from scratch')
        
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                            save_weights_only=True,
                                            verbose=1,
                                            save_freq='epoch',
                                            save_best_only=True),
        save_history_callback
    ]
    try:
        model.fit(
            train_dataset,
            validation_data=test_dataset,
            validation_steps=1000,
            epochs=100,
            steps_per_epoch=1024,
            callbacks=callbacks,
            use_multiprocessing=True)
    except KeyboardInterrupt:
        logging.info('keras model saved. KeyboardInterrupt')
        save_history_callback.on_train_end()
        return model 
    
    model.save_weights(ckpt_path.format(epoch=0))
    model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr_eff.h5'))
    logging.info('All epoch finished. keras model saved.')
    return model
In [38]:
# train the model (computation heavy)
model = train_efficientnet()
number of classes: 3755
train dataset: <RepeatDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>
model resumed from: ./checkpoints/efficient_net\cn_ocr-1.ckpt
Epoch 1/100
   6/1024 [..............................] - ETA: 2:09 - loss: 0.0037 - accuracy: 1.0000WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0558s vs `on_train_batch_end` time: 0.0590s). Check your callbacks.
1024/1024 [==============================] - ETA: 0s - loss: 0.0130 - accuracy: 0.9972
Epoch 1: val_loss improved from inf to 0.54211, saving model to ./checkpoints/efficient_net\cn_ocr-1.ckpt
1024/1024 [==============================] - 163s 152ms/step - loss: 0.0130 - accuracy: 0.9972 - val_loss: 0.5421 - val_accuracy: 0.8860
Epoch 2/100
1024/1024 [==============================] - ETA: 0s - loss: 0.1858 - accuracy: 0.9505
Epoch 2: val_loss did not improve from 0.54211
1024/1024 [==============================] - 155s 152ms/step - loss: 0.1858 - accuracy: 0.9505 - val_loss: 0.7122 - val_accuracy: 0.8370
Epoch 3/100
1024/1024 [==============================] - ETA: 0s - loss: 0.1169 - accuracy: 0.9684
Epoch 3: val_loss did not improve from 0.54211
1024/1024 [==============================] - 154s 150ms/step - loss: 0.1169 - accuracy: 0.9684 - val_loss: 0.5773 - val_accuracy: 0.8697
Epoch 4/100
1024/1024 [==============================] - ETA: 0s - loss: 0.1093 - accuracy: 0.9713
Epoch 4: val_loss did not improve from 0.54211
1024/1024 [==============================] - 154s 150ms/step - loss: 0.1093 - accuracy: 0.9713 - val_loss: 0.5893 - val_accuracy: 0.8614
Epoch 5/100
1024/1024 [==============================] - ETA: 0s - loss: 0.1342 - accuracy: 0.9635
Epoch 5: val_loss did not improve from 0.54211
1024/1024 [==============================] - 154s 150ms/step - loss: 0.1342 - accuracy: 0.9635 - val_loss: 0.5836 - val_accuracy: 0.8711
Epoch 6/100
1024/1024 [==============================] - ETA: 0s - loss: 0.2706 - accuracy: 0.9327
Epoch 6: val_loss did not improve from 0.54211
1024/1024 [==============================] - 154s 150ms/step - loss: 0.2706 - accuracy: 0.9327 - val_loss: 0.6092 - val_accuracy: 0.8580
Epoch 7/100
1024/1024 [==============================] - ETA: 0s - loss: 0.1408 - accuracy: 0.9673
Epoch 7: val_loss did not improve from 0.54211
1024/1024 [==============================] - 154s 151ms/step - loss: 0.1408 - accuracy: 0.9673 - val_loss: 0.9957 - val_accuracy: 0.7847
Epoch 8/100
  40/1024 [>.............................] - ETA: 2:02 - loss: 0.0888 - accuracy: 0.9766
18:38:56 05.30 INFO 2887911528.py:50]: keras model saved. KeyboardInterrupt

Evaluaton and Prediction visualization¶

In [39]:
with open('history_efficientnet.pickle', 'rb') as f:
    history = pickle.load(f)

# Plot accuracy
plt.figure(figsize=(10, 7))
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.title('Model accuracy - EfficientNetB0', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')

# show the highest accuracy in the graph
max_acc = max(history['val_accuracy'])
plt.annotate(f'max accuracy: {max_acc:.4f}', xy=(np.argmax(history['val_accuracy']), max_acc),
                xytext=(np.argmax(history['val_accuracy']) - 20, max_acc - 0.2),
                arrowprops=dict(facecolor='black', shrink=0.001), fontsize=15)
max_acc_train = max(history['accuracy'])
plt.annotate(f'max accuracy: {max_acc_train:.4f}', xy=(np.argmax(history['accuracy']), max_acc_train),
                xytext=(np.argmax(history['accuracy']) - 40, max_acc_train + 0.01),
                arrowprops=dict(facecolor='black', shrink=0.001), fontsize=15)
plt.show()
In [40]:
# Plot loss
plt.figure(figsize=(10, 7))
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('Model loss - EfficientNetB0', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xlabel('Epoch', fontsize=15)
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
In [41]:
# Load the best model
model = effcientnetB0_model()
model.load_weights(tf.train.latest_checkpoint(os.path.dirname(ckpt_path)))
Out[41]:
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x2504d884130>
In [42]:
def display_result(model, ds, characters):
    ds = ds.shuffle(100).map(preprocess_efficientnet).batch(32).repeat()
    _, axe = plt.subplots(3, 3, figsize=(15, 15))
    plt.rcParams["font.sans-serif"]=["SimHei"] #设置字体
    plt.rcParams["axes.unicode_minus"]=False #该语句解决图像中的“-”负号的乱码问题
    for _, data in enumerate(ds.take(1)):
        images, labels = data
        predictions = model.predict(images)
        for i in range(9):
            axe[i//3, i%3].imshow(images[i].numpy().astype('uint8'), cmap='gray')
            axe[i//3, i%3].axis('on')
            axe[i//3, i%3].set_title(f'pred: {characters[np.argmax(predictions[i])]}, label: {characters[labels[i]]}, Is correct: {np.argmax(predictions[i]) == labels[i]}')
    plt.show()
In [43]:
test_ds = load_datasets(test_path)
display_result(model, test_ds, all_characters)
1/1 [==============================] - 1s 949ms/step

We can see that for the model with EfficientNet, the validation accuracy can reach 85% with about 80 epoches, which is a huge improvement compared with the simple CNN model.