3. Model 서브클래싱

이전 예제를 서브클래싱 모델로 다시 만들기

간단한 서브클래싱 모델 정의

from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

vocabulary_size = 10000
num_tags = 100
num_departments =4
num_samples = 1280

# 더미 입력 데이터
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))

# 더미 타킷 데이터
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(0, 2, size=(num_samples, num_departments))

class CustomerTicketModel(keras.Model):

    def __init__(self, num_departments):
        super().__init__()   # 부모 클래스(keras.Model)의 생성자를 호출
        # 생성자에서 층을 생성한다
        self.concat_layer = layers.Concatenate()
        self.mixing_layer = layers.Dense(64, activation="relu")
        self.priority_scorer = layers.Dense(1, activation="sigmoid")
        self.department_classifier = layers.Dense(num_departments, activation="softmax")

    def call(self, inputs):    # call()메서드에서 정방향 패스를 정의한다
        title = inputs["title"]
        text_body = inputs["text_body"]
        tags = inputs["tags"]
        features = self.concat_layer([title, text_body, tags])
        features = self.mixing_layer(features)
        priority = self.priority_scorer(features)
        department = self.department_classifier(features)
        return priority, department

클래스의 객체 만들기

model = CustomerTicketModel(num_departments=4)
priority, department = model(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data})

컴파일, 훈련

# 손실(loss)과 측정(metrics) 지표로 전달하는 값은 call() 메서드가 반환하는 것과 정확하게
# 일치해야 한다.(여기서는 2개의 원소를 가진 리스트)
model.compile(optimizer="rmsprop",
              loss=["mean_squared_error", "categorical_crossentropy"],
              metrics=[["mean_absolute_error"],["accuracy"]])

# 입력 데이터와 타깃 데이터의 구조는 call() 메서드가 기대하는 것과 정확하게 일치해야 한다.
# (입력은 3개의 키와 값을 가진 딕셔너리, 출력은 2개의 원소를 가진 리스트) 
model.fit({"title": title_data, "text_body": text_body_data, "tags": tags_data},
          [priority_data, department_data],
          epochs=1)

model.evaluate({"title": title_data, "text_body": text_body_data, "tags": tags_data},
               [priority_data, department_data])

priority_preds, department_preds = model.predict(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data})