표준 워크플로: compile(), fit(), evaluate(), predict()

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

def get_mnist_model():
    """Build a fresh, uncompiled MNIST classifier with the functional API.

    Wrapped in a function so the same architecture can be rebuilt later.
    The input is a flattened 28x28 image (784 features); the output is a
    softmax distribution over 10 classes. A 512-unit ReLU layer followed by
    50% dropout sits in between.
    """
    flat_pixels = keras.Input(shape=(28 * 28,))
    hidden = flat_pixels
    # Hidden stack, applied in order: Dense(512, relu) then Dropout(0.5).
    for layer in (layers.Dense(512, activation="relu"), layers.Dropout(0.5)):
        hidden = layer(hidden)
    class_probs = layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(flat_pixels, class_probs)

# Load MNIST, flatten every 28x28 image into a 784-long float32 vector scaled
# to [0, 1], and hold out the first 10,000 training samples for validation.
(images, labels), (test_images, test_labels) = mnist.load_data()
images, test_images = (
    arr.reshape((count, 28 * 28)).astype("float32") / 255
    for arr, count in ((images, 60000), (test_images, 10000))
)
val_images, train_images = images[:10000], images[10000:]
val_labels, train_labels = labels[:10000], labels[10000:]

# Standard workflow: build -> compile -> fit (with validation) -> evaluate
# on the test set -> predict class probabilities for the test set.
model = get_mnist_model()
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    x=train_images,
    y=train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)
test_metrics = model.evaluate(x=test_images, y=test_labels)
predictions = model.predict(x=test_images)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 1s 0us/step
Epoch 1/3
1563/1563 [==============================] - 16s 9ms/step - loss: 0.2976 - accuracy: 0.9126 - val_loss: 0.1586 - val_accuracy: 0.9544
Epoch 2/3
1563/1563 [==============================] - 14s 9ms/step - loss: 0.1642 - accuracy: 0.9548 - val_loss: 0.1172 - val_accuracy: 0.9688
Epoch 3/3
1563/1563 [==============================] - 13s 9ms/step - loss: 0.1415 - accuracy: 0.9618 - val_loss: 0.1307 - val_accuracy: 0.9671
313/313 [==============================] - 1s 3ms/step - loss: 0.1298 - accuracy: 0.9670
313/313 [==============================] - 1s 3ms/step

사용자 정의 지표 만들기

Metric 클래스를 상속하여 사용자 정의 지표 만들기

(평균 제곱근 오차(Root Mean Squared Error, RMSE)를 계산하는 지표)

Untitled

import tensorflow as tf

class RootMeanSquaredError(keras.metrics.Metric):
    """Streaming RMSE between one-hot targets and predicted class probabilities.

    State is accumulated across batches by ``update_state`` and read by
    ``result``::

        result = sqrt( sum_i ||one_hot(y_true_i) - y_pred_i||^2 / num_samples )

    For each sample, the squared error is summed over the class axis. That
    per-sample sum is then averaged over every sample seen since the last
    ``reset_state``.
    """

    def __init__(self, name="rmse", **kwargs):
        # State variables are created with add_weight so Keras tracks them.
        super().__init__(name=name, **kwargs)
        # Running sum of per-sample squared errors.
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        # Running count of samples seen.
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the squared error of one batch.

        Args:
            y_true: Integer class labels, shape ``(batch,)`` or ``(batch, 1)``.
            y_pred: Class probabilities, shape ``(batch, num_classes)``.
            sample_weight: Ignored; every sample counts equally.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        # Sparse labels may arrive as (batch, 1) so their rank matches y_pred.
        # One-hot encoding that would give (batch, 1, num_classes), which
        # broadcasts against (batch, num_classes) into
        # (batch, batch, num_classes) and compares every label with every
        # prediction. Flatten to (batch,) before encoding.
        labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # y_pred holds per-class probabilities, so encode the integer labels
        # to the same layout and dtype.
        y_true = tf.one_hot(labels, depth=tf.shape(y_pred)[1],
                            dtype=y_pred.dtype)
        squared_error = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(tf.cast(squared_error, self.mse_sum.dtype))
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

    def result(self):
        """Return the current metric value (0.0 before any update)."""
        # divide_no_nan avoids returning NaN (0/0) when no samples were seen.
        return tf.sqrt(tf.math.divide_no_nan(
            self.mse_sum, tf.cast(self.total_samples, self.mse_sum.dtype)))

    def reset_state(self):
        """Clear the accumulated state without re-creating the metric object."""
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)
# Train and evaluate the same MNIST model, this time also tracking the custom
# RMSE metric alongside accuracy.
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              # Metric instances go into the list directly, next to the
              # string names of built-in metrics.
              metrics=["accuracy", RootMeanSquaredError()])
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(val_images, val_labels))
test_metrics = model.evaluate(test_images, test_labels)

Epoch 1/3
1563/1563 [==============================] - 15s 9ms/step - loss: 0.2933 - accuracy: 0.9138 - rmse: 7.1807 - val_loss: 0.1453 - val_accuracy: 0.9593 - val_rmse: 7.3677
Epoch 2/3
1563/1563 [==============================] - 14s 9ms/step - loss: 0.1653 - accuracy: 0.9536 - rmse: 7.3547 - val_loss: 0.1218 - val_accuracy: 0.9683 - val_rmse: 7.4035
Epoch 3/3
1563/1563 [==============================] - 14s 9ms/step - loss: 0.1396 - accuracy: 0.9632 - rmse: 7.3910 - val_loss: 0.1180 - val_accuracy: 0.9703 - val_rmse: 7.4215
313/313 [==============================] - 1s 4ms/step - loss: 0.1073 - accuracy: 0.9718 - rmse: 7.4346

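※ fit() 메서드의 출력 내용에 rmse가 표시된다!

※ 주의: 원-핫 벡터와 확률 벡터 사이의 샘플당 제곱 오차 합은 최대 2이므로, 올바른 RMSE는 √2 ≈ 1.41을 넘을 수 없다. 위 출력의 rmse ≈ 7.4는 비정상적인 값이다. 원인은 다음과 같다. Keras가 정수 레이블 y_true를 (batch, 1) 형태로 넘기면 tf.one_hot 결과가 (batch, 1, 10)이 된다. 이것이 y_pred (batch, 10)과 브로드캐스팅되어 (batch, batch, 10) 텐서가 되고, 모든 레이블과 모든 예측이 서로 비교된다. 따라서 update_state에서 y_true를 먼저 1차원으로 펼친 뒤(tf.reshape(y_true, [-1])) 원-핫 인코딩해야 한다.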