Callbacks

Various tools that help with training can be used in the form of a Callback.
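
A callback is written by subclassing keras.callbacks.Callback and overriding any of its hooks (on_train_begin, on_epoch_end, on_batch_end, ...), which Keras calls automatically during model.fit. A minimal sketch; the class name and the printed messages here are illustrative only:

from tensorflow.keras.callbacks import Callback

class ProgressPrinter(Callback):
    def on_train_begin(self, logs=None):
        print('training started')
    def on_epoch_end(self, epoch, logs=None):
        print(f'epoch {epoch}:', logs)  # logs is a dict with this epoch's metrics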

Import packages

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

import tensorflow as tf
from tensorflow import keras

1. Load dataset

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

housing = fetch_california_housing()
X_train_full, X_test, y_train_full, y_test = train_test_split(housing.data, housing.target)
X_train, X_valid, y_train, y_valid         = train_test_split(X_train_full, y_train_full)

2. Preprocessing

scaler  = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_valid = scaler.transform(X_valid)
X_test  = scaler.transform(X_test)
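
Note that the scaler is fit on the training set only; the validation and test sets are transformed with the training statistics so that no information from them leaks into the model.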

3. Modeling

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow_addons.metrics import RSquare

model = Sequential([
    Dense(30, input_shape=X_train[0].shape, kernel_initializer='he_normal', activation='relu'),
    Dense(30, kernel_initializer='he_normal', activation='relu'),
    Dense(1)
])
model.compile(loss='mse', optimizer='sgd', metrics=[RSquare(y_shape=(1,))])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 30)                270
_________________________________________________________________
dense_1 (Dense)              (None, 30)                930
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 31
=================================================================
Total params: 1,231
Trainable params: 1,231
Non-trainable params: 0
_________________________________________________________________
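
Each Dense layer holds (inputs × units) + units parameters: the dataset has 8 features, so the counts above are 8·30 + 30 = 270, 30·30 + 30 = 930, and 30·1 + 1 = 31, i.e. 1,231 in total.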

4. Training

4.1 Prepare Callbacks

import os

os.makedirs('ckpt', exist_ok=True)
os.makedirs('result', exist_ok=True)
os.makedirs('tensorboard', exist_ok=True)
from tensorflow.keras.callbacks import Callback

class LossPlot(Callback):
    """Plot every logged metric after each epoch and save the figure to disk."""
    def __init__(self, dir_path, figsize=(15, 8)):
        super().__init__()
        self.dir_path = dir_path
        self.figsize  = figsize
        os.makedirs(self.dir_path, exist_ok=True)
        self.metrics    = pd.DataFrame()
        self.best_epoch = -1
    def on_epoch_end(self, epoch, logs=None):
        # Append this epoch's logs as a new row (DataFrame.append was removed in pandas 2.0).
        self.metrics    = pd.concat([self.metrics, pd.DataFrame([logs])], ignore_index=True)
        self.best_epoch = np.argmin(self.metrics['val_loss'])
        self._save_fig(epoch)
    def _save_fig(self, epoch):
        fig, ax = plt.subplots(figsize=self.figsize)
        for col in self.metrics:
            # Solid lines for losses, dashed for other metrics; red for validation, blue for training.
            self.metrics[col].plot(linestyle='-' if 'loss' in col else '--', color='r' if col.startswith('val') else 'b', ax=ax)
        ax.axvline(self.best_epoch, color='k', ls=':', lw=3)  # mark the best val_loss epoch so far
        ax.set_xlabel('epoch');  ax.set_ylabel('metric')
        ax.set_yscale('log')
        ax.legend(loc='upper right')
        fig.tight_layout()
        fig.savefig(os.path.join(self.dir_path, 'metrics.png'))
        plt.close(fig)
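
LossPlot overwrites result/metrics.png after every epoch, so the figure can be watched while training runs; the dotted vertical line marks the epoch with the lowest val_loss seen so far.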
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau, TensorBoard

callbacks = [
    ModelCheckpoint('ckpt/[{epoch:03d} epoch] val_loss: {val_loss:.4f}.h5', save_best_only=True),
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    CSVLogger('result/logs.csv'),
    ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=5, verbose=1),
    TensorBoard('tensorboard', write_graph=True, write_images=True),
    LossPlot('result')
]
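
ModelCheckpoint saves a new .h5 file whenever val_loss improves, EarlyStopping stops training after 10 epochs without improvement and restores the best weights, CSVLogger records each epoch's metrics to result/logs.csv, and ReduceLROnPlateau multiplies the learning rate by 0.9 after 5 stagnant epochs. Starting from SGD's default learning rate of 0.01, that schedule gives 0.009, 0.0081, 0.00729, ..., which is exactly what the log below prints.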

4.2 Training

history = model.fit(X_train, y_train, epochs=100, validation_data=(X_valid, y_valid), callbacks=callbacks, use_multiprocessing=True)
Epoch 1/100
  1/363 [..............................] - ETA: 3:11 - loss: 14.5933 - r_square: -7.8717WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0015s vs `on_train_batch_end` time: 0.0047s). Check your callbacks.
363/363 [==============================] - 1s 2ms/step - loss: 1.0000 - r_square: 0.2456 - val_loss: 0.5235 - val_r_square: 0.6158
Epoch 2/100
363/363 [==============================] - 1s 2ms/step - loss: 1.8754 - r_square: -0.4148 - val_loss: 0.5246 - val_r_square: 0.6150
Epoch 3/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4900 - r_square: 0.6304 - val_loss: 0.4827 - val_r_square: 0.6458
Epoch 4/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4541 - r_square: 0.6574 - val_loss: 0.4621 - val_r_square: 0.6609
Epoch 5/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4350 - r_square: 0.6718 - val_loss: 0.4426 - val_r_square: 0.6752
Epoch 6/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4269 - r_square: 0.6780 - val_loss: 0.4535 - val_r_square: 0.6672
Epoch 7/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4133 - r_square: 0.6882 - val_loss: 0.4336 - val_r_square: 0.6818
Epoch 8/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4129 - r_square: 0.6885 - val_loss: 0.4310 - val_r_square: 0.6837
Epoch 9/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4005 - r_square: 0.6979 - val_loss: 0.4155 - val_r_square: 0.6951
Epoch 10/100
363/363 [==============================] - 1s 2ms/step - loss: 0.4270 - r_square: 0.6779 - val_loss: 0.4259 - val_r_square: 0.6874
Epoch 11/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3995 - r_square: 0.6986 - val_loss: 0.4047 - val_r_square: 0.7030
Epoch 12/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3810 - r_square: 0.7126 - val_loss: 0.4062 - val_r_square: 0.7019
Epoch 13/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3785 - r_square: 0.7145 - val_loss: 0.3947 - val_r_square: 0.7104
Epoch 14/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3708 - r_square: 0.7203 - val_loss: 0.3936 - val_r_square: 0.7112
Epoch 15/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3668 - r_square: 0.7233 - val_loss: 0.3924 - val_r_square: 0.7120
Epoch 16/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3644 - r_square: 0.7251 - val_loss: 0.3837 - val_r_square: 0.7184
Epoch 17/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3587 - r_square: 0.7294 - val_loss: 0.3829 - val_r_square: 0.7190
Epoch 18/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3556 - r_square: 0.7318 - val_loss: 0.3867 - val_r_square: 0.7163
Epoch 19/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3534 - r_square: 0.7334 - val_loss: 0.3687 - val_r_square: 0.7294
Epoch 20/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3476 - r_square: 0.7378 - val_loss: 0.3658 - val_r_square: 0.7315
Epoch 21/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3449 - r_square: 0.7398 - val_loss: 0.3701 - val_r_square: 0.7284
Epoch 22/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3431 - r_square: 0.7411 - val_loss: 0.3882 - val_r_square: 0.7151
Epoch 23/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3391 - r_square: 0.7442 - val_loss: 0.3668 - val_r_square: 0.7308
Epoch 24/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3379 - r_square: 0.7451 - val_loss: 0.3720 - val_r_square: 0.7270
Epoch 25/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3337 - r_square: 0.7483 - val_loss: 0.3542 - val_r_square: 0.7401
Epoch 26/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3328 - r_square: 0.7490 - val_loss: 0.3630 - val_r_square: 0.7336
Epoch 27/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3302 - r_square: 0.7509 - val_loss: 0.3510 - val_r_square: 0.7425
Epoch 28/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3289 - r_square: 0.7519 - val_loss: 0.3510 - val_r_square: 0.7424
Epoch 29/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3276 - r_square: 0.7528 - val_loss: 0.3640 - val_r_square: 0.7329
Epoch 30/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3243 - r_square: 0.7553 - val_loss: 0.3564 - val_r_square: 0.7384
Epoch 31/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3226 - r_square: 0.7567 - val_loss: 0.3566 - val_r_square: 0.7383
Epoch 32/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3208 - r_square: 0.7580 - val_loss: 0.3590 - val_r_square: 0.7366

Epoch 00032: ReduceLROnPlateau reducing learning rate to 0.008999999798834325.
Epoch 33/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3173 - r_square: 0.7606 - val_loss: 0.3526 - val_r_square: 0.7413
Epoch 34/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3152 - r_square: 0.7622 - val_loss: 0.3538 - val_r_square: 0.7404
Epoch 35/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3136 - r_square: 0.7635 - val_loss: 0.3512 - val_r_square: 0.7423
Epoch 36/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3121 - r_square: 0.7645 - val_loss: 0.3504 - val_r_square: 0.7429
Epoch 37/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3119 - r_square: 0.7647 - val_loss: 0.3304 - val_r_square: 0.7576
Epoch 38/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3112 - r_square: 0.7653 - val_loss: 0.3497 - val_r_square: 0.7434
Epoch 39/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3082 - r_square: 0.7675 - val_loss: 0.3335 - val_r_square: 0.7553
Epoch 40/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3096 - r_square: 0.7664 - val_loss: 0.3332 - val_r_square: 0.7555
Epoch 41/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3090 - r_square: 0.7669 - val_loss: 0.3276 - val_r_square: 0.7596
Epoch 42/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3076 - r_square: 0.7680 - val_loss: 0.3297 - val_r_square: 0.7581
Epoch 43/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3054 - r_square: 0.7696 - val_loss: 0.3339 - val_r_square: 0.7550
Epoch 44/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3055 - r_square: 0.7695 - val_loss: 0.3702 - val_r_square: 0.7283
Epoch 45/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3057 - r_square: 0.7694 - val_loss: 0.3333 - val_r_square: 0.7554
Epoch 46/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3042 - r_square: 0.7705 - val_loss: 0.3329 - val_r_square: 0.7557

Epoch 00046: ReduceLROnPlateau reducing learning rate to 0.008099999651312828.
Epoch 47/100
363/363 [==============================] - 1s 2ms/step - loss: 0.3015 - r_square: 0.7726 - val_loss: 0.3366 - val_r_square: 0.7530
Epoch 48/100
363/363 [==============================] - 1s 2ms/step - loss: 0.2990 - r_square: 0.7744 - val_loss: 0.3514 - val_r_square: 0.7421
Epoch 49/100
363/363 [==============================] - 1s 2ms/step - loss: 0.2993 - r_square: 0.7742 - val_loss: 0.3492 - val_r_square: 0.7438
Epoch 50/100
363/363 [==============================] - 1s 2ms/step - loss: 0.2995 - r_square: 0.7741 - val_loss: 0.3326 - val_r_square: 0.7559
Epoch 51/100
363/363 [==============================] - 1s 2ms/step - loss: 0.2986 - r_square: 0.7748 - val_loss: 0.3434 - val_r_square: 0.7480

Epoch 00051: ReduceLROnPlateau reducing learning rate to 0.007289999350905419.

fig, ax = plt.subplots(figsize=(20, 10))
ax_twin = ax.twinx()
pd.DataFrame(history.history, columns=['loss', 'val_loss']).plot(xlabel='epoch', ylabel='loss', ax=ax)
pd.DataFrame(history.history, columns=['r_square', 'val_r_square']).plot(linestyle='--', ylabel='r_square', ax=ax_twin);

[Figure: training and validation loss (solid, left axis) with r_square and val_r_square (dashed, right axis) per epoch]

5. Evaluation

model.evaluate(X_test, y_test)
162/162 [==============================] - 0s 1ms/step - loss: 0.3388 - r_square: 0.7435

[0.3387697637081146, 0.7434945702552795]
X_new = X_test[:3]
y_pred = model.predict(X_new);  y_pred
array([[3.197331 ],
       [2.6825705],
       [1.7492278]], dtype=float32)
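
Because EarlyStopping was given restore_best_weights=True, the evaluated model already holds the weights from the best epoch (epoch 41, val_loss 0.3276). The checkpoints written by ModelCheckpoint can also be reloaded later. A minimal sketch, assuming the best file saved during the run above; the custom RSquare metric must be supplied via custom_objects (or loading done with compile=False):

from tensorflow.keras.models import load_model
from tensorflow_addons.metrics import RSquare

best = load_model('ckpt/[041 epoch] val_loss: 0.3276.h5',
                  custom_objects={'RSquare': RSquare})
best.evaluate(X_test, y_test)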