인공지능/텐서플로우
텐서플로우의 콜백클래스를 이용해서, 원하는 조건이 되면 학습을 멈추게 하기
iminu
2022. 6. 13. 17:37
def build_model():
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, 'relu'))
model.add(tf.keras.layers.Dense(10, 'softmax'))
model.compile('adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
return model
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs = []):
if logs.get('val_accuracy') > 0.87:
print('\n밸리데이션 정확도가 87% 넘으므로, 학습을 멈추게 합니다.')
self.model.stop_training = True
my_callback = myCallback()
model = build_model()
epoch_history = model.fit(training_images, training_labels, epochs = 30, validation_split = 0.2, callbacks = [my_callback, early_stop])
먼저 모델링 함수를 만들고, 콜백클래스를 상속받는 myCallback클래스를 만든다. on_epoch_end는 epoch가 한번 끝날 때 마다 실행되는 함수 인데 val_accuracy가 0.87이 넘으면 부모 클래스의 stop_training변수를 True로 만들게 한다.
이 클래스의 객체를 생성하고 모델을 만들고, fit의 callbacks에다 배열안에 객체를 대입해준다.