이전 mnist 사용

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

def get_mnist_model():
    inputs = keras.Input(shape=(28*28,))
    features = layers.Dense(512, activation="relu")(inputs)
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model

(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28*28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28*28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]

fit( ) 메서드에서 callbacks 매개변수 사용하기

# fit( ) 메서드의 callback 매개변수에 전달할 콜백 리스트 만들기
callbacks_list = [
    keras.callbacks.EarlyStopping(   # 성능 향상이 멈추면 훈련을 중지
        monitor="val_accuracy",   # 모델의 검증 정확도를 모니터링
        patience=2,   # 두번의 에포크 동안 정확도가 향상되지 않으면 훈련을 중지
    ),
    keras.callbacks.ModelCheckpoint(   # 매 에포크 끝에서 가중치를 저장
        filepath="checkpoint_path.keras",   # 모델 파일의 저장 경로와 파일명
        monitor="val_loss",   # val_loss가 좋아지지 않으면 모델 파일을 덮어쓰지 않고,
        save_best_only=True,   # 훈련하는 동안 가장 좋은 모델이 저장된다.
    )
]
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])   # 정확도를 모니터링하므로 accuracy가 모델 지표
model.fit(train_images, train_labels,
          epochs=20,
          **callbacks=callbacks_list**,
          validation_data=(val_images, val_labels))   # 콜백이 검증 손실과 검증 정확도를 모니터링하기 때문에 validation_data 매개변수로 검증 데이터를 전달해야 함

Epoch 1/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.2914 - accuracy: 0.9145 - val_loss: 0.1467 - val_accuracy: 0.9573 Epoch 2/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.1644 - accuracy: 0.9538 - val_loss: 0.1174 - val_accuracy: 0.9683 Epoch 3/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.1363 - accuracy: 0.9632 - val_loss: 0.1154 - val_accuracy: 0.9705 Epoch 4/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.1247 - accuracy: 0.9678 - val_loss: 0.1100 - val_accuracy: 0.9720 Epoch 5/20 1563/1563 [==============================] - 11s 7ms/step - loss: 0.1167 - accuracy: 0.9709 - val_loss: 0.1198 - val_accuracy: 0.9730 Epoch 6/20 1563/1563 [==============================] - 11s 7ms/step - loss: 0.1074 - accuracy: 0.9739 - val_loss: 0.1229 - val_accuracy: 0.9727 Epoch 7/20 1563/1563 [==============================] - 10s 7ms/step - loss: 0.1046 - accuracy: 0.9742 - val_loss: 0.1228 - val_accuracy: 0.9757 Epoch 8/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.0976 - accuracy: 0.9762 - val_loss: 0.1085 - val_accuracy: 0.9781 Epoch 9/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.0984 - accuracy: 0.9777 - val_loss: 0.1165 - val_accuracy: 0.9781 Epoch 10/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.0931 - accuracy: 0.9795 - val_loss: 0.1261 - val_accuracy: 0.9778 <keras.callbacks.History at 0x7f9dae027310>

10회째에서 중단됨

※ 저장된 모델을 로드하려면

model = keras.model.load_model(”checkpoint_path.keras)