<aside> 🖥️ Callback hook methods:

- on_epoch_begin(epoch, logs): called at the start of each epoch (receives the epoch index)
- on_epoch_end(epoch, logs): called at the end of each epoch (receives the epoch index)
- on_batch_begin(batch, logs): called right before each batch is processed (receives the batch index)
- on_batch_end(batch, logs): called right after each batch is processed (receives the batch index)
- on_train_begin(logs): called when training starts
- on_train_end(logs): called when training ends

</aside>
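The order in which these hooks fire can be sketched with a plain-Python stand-in for the training loop. This is illustrative only, not Keras API: `RecordingCallback` and `run_fake_training` are hypothetical names, and the loop simply mimics how `fit()` drives the hooks (train → epochs → batches).

```python
class RecordingCallback:
    """Records the order in which the Keras-style hooks are invoked."""
    def __init__(self):
        self.events = []
    def on_train_begin(self, logs=None):
        self.events.append("train_begin")
    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")
    def on_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_begin:{batch}")
    def on_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")
    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")
    def on_train_end(self, logs=None):
        self.events.append("train_end")

def run_fake_training(callback, num_epochs=2, num_batches=3):
    # Mimics the nesting of fit(): one train span, epochs inside it,
    # batches inside each epoch.
    callback.on_train_begin()
    for epoch in range(num_epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(num_batches):
            callback.on_batch_begin(batch)
            callback.on_batch_end(batch, logs={"loss": 0.0})
        callback.on_epoch_end(epoch)
    callback.on_train_end()

cb = RecordingCallback()
run_fake_training(cb)
print(cb.events[:4])  # ['train_begin', 'epoch_begin:0', 'batch_begin:0', 'batch_end:0']
```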

Reusing the MNIST model and data from earlier:

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

def get_mnist_model():
    inputs = keras.Input(shape=(28*28,))
    features = layers.Dense(512, activation="relu")(inputs)
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model

(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28*28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28*28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]

Creating a custom callback by subclassing the Callback class

(During training, append each batch's loss value to a list; at the end of each epoch, save a plot of those values.)

from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Start training with an empty list of per-batch losses.
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        # Keras passes the batch metrics in `logs`; record the batch loss.
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs=None):
        # Plot the losses accumulated during this epoch and save the figure.
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)),
                 self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")
        # Reset the list so the next epoch starts fresh.
        self.per_batch_losses = []
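Because each hook just receives a `logs` dict, the bookkeeping part of such a callback can be exercised by hand, without running `fit()`. Below is a minimal sketch in pure Python (no Keras or matplotlib); `LossTracker` is a hypothetical stand-in that mirrors only the list-accumulation logic above.

```python
class LossTracker:
    """Mirrors LossHistory's bookkeeping, minus the plotting."""
    def on_train_begin(self, logs=None):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        # Keras puts the running metrics in `logs`; "loss" is the batch loss.
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs=None):
        # Hand back this epoch's losses and reset for the next epoch.
        losses = self.per_batch_losses
        self.per_batch_losses = []
        return losses

tracker = LossTracker()
tracker.on_train_begin()
for batch, loss in enumerate([0.9, 0.5, 0.3]):
    tracker.on_batch_end(batch, logs={"loss": loss})
epoch_losses = tracker.on_epoch_end(0)
print(epoch_losses)  # [0.9, 0.5, 0.3]
```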

Testing the callback

model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

Epoch 1/10
1563/1563 [==============================] - 14s 9ms/step - loss: 0.2934 - accuracy: 0.9126 - val_loss: 0.1515 - val_accuracy: 0.9554
Epoch 2/10
1563/1563 [==============================] - 13s 8ms/step - loss: 0.1602 - accuracy: 0.9554 - val_loss: 0.1324 - val_accuracy: 0.9644
Epoch 3/10
1563/1563 [==============================] - 14s 9ms/step - loss: 0.1365 - accuracy: 0.9621 - val_loss: 0.1091 - val_accuracy: 0.9732
Epoch 4/10
1563/1563 [==============================] - 13s 8ms/step - loss: 0.1249 - accuracy: 0.9677 - val_loss: 0.1042 - val_accuracy: 0.9741
Epoch 5/10
1563/1563 [==============================] - 13s 8ms/step - loss: 0.1132 - accuracy: 0.9708 - val_loss: 0.1073 - val_accuracy: 0.9761
Epoch 6/10
1563/1563 [==============================] - 13s 8ms/step - loss: 0.1093 - accuracy: 0.9734 - val_loss: 0.1032 - val_accuracy: 0.9776
Epoch 7/10
1563/1563 [==============================] - 13s 8ms/step - loss: 0.1065 - accuracy: 0.9753 - val_loss: 0.1073 - val_accuracy: 0.9770
Epoch 8/10
1563/1563 [==============================] - 16s 10ms/step - loss: 0.0990 - accuracy: 0.9770 - val_loss: 0.1110 - val_accuracy: 0.9789
Epoch 9/10
1563/1563 [==============================] - 14s 9ms/step - loss: 0.0967 - accuracy: 0.9780 - val_loss: 0.1234 - val_accuracy: 0.9771
Epoch 10/10
1563/1563 [==============================] - 13s 8ms/step - loss: 0.0934 - accuracy: 0.9784 - val_loss: 0.1128 - val_accuracy: 0.9796
<keras.callbacks.History at 0x7effccb4ed10>

※ 10 plots were saved, one per epoch.