인공지능/텐서플로우

텐서플로우의 콜백클래스를 이용해서, 원하는 조건이 되면 학습을 멈추게 하기

iminu 2022. 6. 13. 17:37
def build_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, 'relu'))
    model.add(tf.keras.layers.Dense(10, 'softmax'))
    model.compile('adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
    return model
    
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs = []):
        if logs.get('val_accuracy') > 0.87:
            print('\n밸리데이션 정확도가 87% 넘으므로, 학습을 멈추게 합니다.')
            self.model.stop_training = True

my_callback = myCallback()
model = build_model()
epoch_history = model.fit(training_images, training_labels, epochs = 30, validation_split = 0.2, callbacks = [my_callback, early_stop])

먼저 모델링 함수를 만들고, 콜백클래스를 상속받는 myCallback클래스를 만든다. on_epoch_end는 epoch가 한번 끝날 때 마다 실행되는 함수 인데 val_accuracy가 0.87이 넘으면 부모 클래스의 stop_training변수를 True로 만들게 한다.

이 클래스의 객체를 생성하고 모델을 만들고, fit의 callbacks에다 배열안에 객체를 대입해준다.