from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
def get_mnist_model():
inputs = keras.Input(shape=(28*28,))
features = layers.Dense(512, activation="relu")(inputs)
features = layers.Dropout(0.5)(features)
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs, outputs)
return model
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28*28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28*28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]
# fit( ) 메서드의 callback 매개변수에 전달할 콜백 리스트 만들기
callbacks_list = [
keras.callbacks.EarlyStopping( # 성능 향상이 멈추면 훈련을 중지
monitor="val_accuracy", # 모델의 검증 정확도를 모니터링
patience=2, # 두번의 에포크 동안 정확도가 향상되지 않으면 훈련을 중지
),
keras.callbacks.ModelCheckpoint( # 매 에포크 끝에서 가중치를 저장
filepath="checkpoint_path.keras", # 모델 파일의 저장 경로와 파일명
monitor="val_loss", # val_loss가 좋아지지 않으면 모델 파일을 덮어쓰지 않고,
save_best_only=True, # 훈련하는 동안 가장 좋은 모델이 저장된다.
)
]
model = get_mnist_model()
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]) # 정확도를 모니터링하므로 accuracy가 모델 지표
model.fit(train_images, train_labels,
epochs=20,
**callbacks=callbacks_list**,
validation_data=(val_images, val_labels)) # 콜백이 검증 손실과 검증 정확도를 모니터링하기 때문에 validation_data 매개변수로 검증 데이터를 전달해야 함
Epoch 1/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.2914 - accuracy: 0.9145 - val_loss: 0.1467 - val_accuracy: 0.9573 Epoch 2/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.1644 - accuracy: 0.9538 - val_loss: 0.1174 - val_accuracy: 0.9683 Epoch 3/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.1363 - accuracy: 0.9632 - val_loss: 0.1154 - val_accuracy: 0.9705 Epoch 4/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.1247 - accuracy: 0.9678 - val_loss: 0.1100 - val_accuracy: 0.9720 Epoch 5/20 1563/1563 [==============================] - 11s 7ms/step - loss: 0.1167 - accuracy: 0.9709 - val_loss: 0.1198 - val_accuracy: 0.9730 Epoch 6/20 1563/1563 [==============================] - 11s 7ms/step - loss: 0.1074 - accuracy: 0.9739 - val_loss: 0.1229 - val_accuracy: 0.9727 Epoch 7/20 1563/1563 [==============================] - 10s 7ms/step - loss: 0.1046 - accuracy: 0.9742 - val_loss: 0.1228 - val_accuracy: 0.9757 Epoch 8/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.0976 - accuracy: 0.9762 - val_loss: 0.1085 - val_accuracy: 0.9781 Epoch 9/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.0984 - accuracy: 0.9777 - val_loss: 0.1165 - val_accuracy: 0.9781 Epoch 10/20 1563/1563 [==============================] - 10s 6ms/step - loss: 0.0931 - accuracy: 0.9795 - val_loss: 0.1261 - val_accuracy: 0.9778 <keras.callbacks.History at 0x7f9dae027310>
10회째에서 중단됨
※ 저장된 모델을 로드하려면
model = keras.model.load_model(”checkpoint_path.keras)