Subclass API

 

Import packages

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

import tensorflow as tf
from tensorflow import keras

1. Load dataset & preprocessing

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Download the California housing data and carve it into train / validation /
# test splits (train_test_split's default proportions, no fixed seed).
housing = fetch_california_housing()
X_train_full, X_test, y_train_full, y_test = train_test_split(
    housing.data, housing.target
)
X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full)

# Standardise every split with statistics learned from the training split only.
scaler = StandardScaler().fit(X_train)
X_train, X_valid, X_test = (
    scaler.transform(split) for split in (X_train, X_valid, X_test)
)

# Input A takes columns 0-4, input B takes columns 2 onward, so columns 2-4
# are fed to both inputs.
cols_A, cols_B = slice(None, 5), slice(2, None)
X_train_A, X_train_B = X_train[:, cols_A], X_train[:, cols_B]
X_valid_A, X_valid_B = X_valid[:, cols_A], X_valid[:, cols_B]
X_test_A, X_test_B = X_test[:, cols_A], X_test[:, cols_B]

# The first three test rows stand in for "new" samples at prediction time.
X_new_A, X_new_B = X_test_A[:3], X_test_B[:3]

2. Modeling

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Concatenate
from tensorflow_addons.metrics import RSquare

class WideAndDeepModel(Model):
    """Two-input, two-output wide-and-deep regression network.

    ``input_B`` is passed through two hidden ``Dense`` layers (the deep
    path). ``input_A`` bypasses them and is concatenated with the deep path's
    output (the wide path). The main head reads that concatenation; the
    auxiliary head reads the deep path's output alone.

    Args:
        units: Number of units in each of the two hidden layers.
        activation: Activation applied by both hidden layers.
        **kwargs: Forwarded unchanged to the ``Model`` base constructor.
    """

    def __init__(self, units=30, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = Dense(units, activation=activation)
        self.hidden2 = Dense(units, activation=activation)
        self.main_output = Dense(1)
        self.aux_output = Dense(1)
        # Created once here instead of inside call(), so a fresh layer object
        # is not allocated on every forward pass. Assigned last so the
        # positions of the Dense layers among the tracked layers are unchanged.
        self.concat = Concatenate()

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: A pair ``(input_A, input_B)``; ``input_A`` feeds the wide
                path and ``input_B`` feeds the deep path.

        Returns:
            A tuple ``(main_output, aux_output)``, each produced by a
            ``Dense(1)`` head.
        """
        input_A, input_B = inputs
        hidden1 = self.hidden1(input_B)
        hidden2 = self.hidden2(hidden1)
        concat = self.concat([input_A, hidden2])
        main_output = self.main_output(concat)
        aux_output = self.aux_output(hidden2)
        return main_output, aux_output

# Instantiate with the default hyper-parameters. The single 'mse' loss and the
# single metric list are given once and, per the training log, are reported
# for both outputs.
model = WideAndDeepModel()
model.compile(
    optimizer='sgd',
    loss='mse',
    metrics=[RSquare(y_shape=(1,))],
)
# model.summary()  # left disabled: the author notes summary() is unavailable
#                  # with the subclass API. NOTE(review): Keras can summarise
#                  # a subclassed model once it has been built (e.g. after
#                  # fit()) -- confirm and re-enable there if wanted.

3. Training

# Train for five epochs. Both outputs are fitted against the same target, so
# the label array is passed twice (once per output), for the training data and
# the validation data alike.
history = model.fit(
    x=[X_train_A, X_train_B],
    y=[y_train, y_train],
    epochs=5,
    validation_data=(
        [X_valid_A, X_valid_B],
        [y_valid, y_valid],
    ),
)

# One curve per logged loss/metric, indexed by epoch. The trailing semicolon
# keeps the notebook from echoing the returned Axes object.
history_df = pd.DataFrame(history.history)
history_df.plot(figsize=(20, 5), grid=True, xlabel='epoch', ylabel='score');
Epoch 1/5
363/363 [==============================] - 3s 5ms/step - loss: 1.7983 - output_1_loss: 0.7166 - output_2_loss: 1.0817 - output_1_r_square: 0.4640 - output_2_r_square: 0.1909 - val_loss: 1.1375 - val_output_1_loss: 0.4818 - val_output_2_loss: 0.6557 - val_output_1_r_square: 0.6405 - val_output_2_r_square: 0.5108
Epoch 2/5
363/363 [==============================] - 2s 4ms/step - loss: 1.0727 - output_1_loss: 0.4692 - output_2_loss: 0.6035 - output_1_r_square: 0.6491 - output_2_r_square: 0.5486 - val_loss: 1.0769 - val_output_1_loss: 0.4757 - val_output_2_loss: 0.6011 - val_output_1_r_square: 0.6450 - val_output_2_r_square: 0.5515
Epoch 3/5
363/363 [==============================] - 2s 4ms/step - loss: 1.0806 - output_1_loss: 0.5072 - output_2_loss: 0.5735 - output_1_r_square: 0.6207 - output_2_r_square: 0.5710 - val_loss: 1.0273 - val_output_1_loss: 0.4612 - val_output_2_loss: 0.5661 - val_output_1_r_square: 0.6559 - val_output_2_r_square: 0.5776
Epoch 4/5
363/363 [==============================] - 2s 4ms/step - loss: 1.2061 - output_1_loss: 0.6781 - output_2_loss: 0.5279 - output_1_r_square: 0.4928 - output_2_r_square: 0.6051 - val_loss: 0.9958 - val_output_1_loss: 0.4461 - val_output_2_loss: 0.5496 - val_output_1_r_square: 0.6671 - val_output_2_r_square: 0.5899
Epoch 5/5
363/363 [==============================] - 2s 4ms/step - loss: 1.2297 - output_1_loss: 0.5375 - output_2_loss: 0.6921 - output_1_r_square: 0.5979 - output_2_r_square: 0.4823 - val_loss: 0.9707 - val_output_1_loss: 0.4188 - val_output_2_loss: 0.5519 - val_output_1_r_square: 0.6875 - val_output_2_r_square: 0.5882

png

4. Evaluation

# Score the trained model on the test split; the same target is supplied for
# both outputs. Left as a bare expression so the notebook displays the result.
model.evaluate(x=[X_test_A, X_test_B], y=[y_test, y_test])
162/162 [==============================] - 0s 3ms/step - loss: 2.4178 - output_1_loss: 1.1732 - output_2_loss: 1.2447 - output_1_r_square: 0.1064 - output_2_r_square: 0.0519

[2.417813777923584,
 1.1731566190719604,
 1.2446565628051758,
 0.10637742280960083,
 0.051914215087890625]
# Predict for the three "new" samples; the result holds one array per output.
y_pred = model.predict([X_new_A, X_new_B])
y_pred
(array([[1.2351167],
        [1.9935203],
        [4.297159 ]], dtype=float32),
 array([[0.90587264],
        [2.0000696 ],
        [3.281421  ]], dtype=float32))