3. Model 서브클래싱
- Model 클래스를 상속하여 모델 구축을 한다.
- _ _ init _ _( ) 메서드에서 모델이 사용할 층을 정의한다.
- call( ) 메서드에서 앞서 만든 층을 사용하여 모델의 정방향 패스를 정의한다.
- 서브클래스의 객체를 만들고 데이터와 함께 호출하여 가중치를 만든다.
- 층이 연결되는 방식이 call( ) 메서드안에 감추어지기 때문에, summary( ) 메서드나 plot_model( ) 함수로 모델의 구조를 나타낼 수 없다.
- 새로운 파이썬 객체를 만드는 것이므로 기능적으로 유연하나 오류 가능성이 있다.
이전 예제를 서브클래싱 모델로 다시 만들기
간단한 서브클래싱 모델 정의
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
vocabulary_size = 10000
num_tags = 100
num_departments =4
num_samples = 1280
# 더미 입력 데이터
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))
# 더미 타킷 데이터
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(0, 2, size=(num_samples, num_departments))
class CustomerTicketModel(keras.Model):
def __init__(self, num_departments):
super().__init__() # 부모 클래스(keras.Model)의 생성자를 호출
# 생성자에서 층을 생성한다
self.concat_layer = layers.Concatenate()
self.mixing_layer = layers.Dense(64, activation="relu")
self.priority_scorer = layers.Dense(1, activation="sigmoid")
self.department_classifier = layers.Dense(num_departments, activation="softmax")
def call(self, inputs): # call()메서드에서 정방향 패스를 정의한다
title = inputs["title"]
text_body = inputs["text_body"]
tags = inputs["tags"]
features = self.concat_layer([title, text_body, tags])
features = self.mixing_layer(features)
priority = self.priority_scorer(features)
department = self.department_classifier(features)
return priority, department
클래스의 객체 만들기
model = CustomerTicketModel(num_departments=4)
priority, department = model(
{"title": title_data, "text_body": text_body_data, "tags": tags_data})
컴파일, 훈련
# 손실(loss)과 측정(metrics) 지표로 전달하는 값은 call() 메서드가 반환하는 것과 정확하게
# 일치해야 한다.(여기서는 2개의 원소를 가진 리스트)
model.compile(optimizer="rmsprop",
loss=["mean_squared_error", "categorical_crossentropy"],
metrics=[["mean_absolute_error"],["accuracy"]])
# 입력 데이터와 타깃 데이터의 구조는 call() 메서드가 기대하는 것과 정확하게 일치해야 한다.
# (입력은 3개의 키와 값을 가진 딕셔너리, 출력은 2개의 원소를 가진 리스트)
model.fit({"title": title_data, "text_body": text_body_data, "tags": tags_data},
[priority_data, department_data],
epochs=1)
model.evaluate({"title": title_data, "text_body": text_body_data, "tags": tags_data},
[priority_data, department_data])
priority_preds, department_preds = model.predict(
{"title": title_data, "text_body": text_body_data, "tags": tags_data})