def build_model():
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, 'relu'))
model.add(tf.keras.layers.Dense(10, 'softmax'))
model.compile('adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
return model
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs = []):
if logs.get('val_accuracy') > 0.87:
print('\n밸리데이션 정확도가 87% 넘으므로, 학습을 멈추게 합니다.')
self.model.stop_training = True
my_callback = myCallback()
model = build_model()
epoch_history = model.fit(training_images, training_labels, epochs = 30, validation_split = 0.2, callbacks = [my_callback, early_stop])
먼저 모델링 함수를 만들고, 콜백클래스를 상속받는 myCallback클래스를 만든다. on_epoch_end는 epoch가 한번 끝날 때 마다 실행되는 함수 인데 val_accuracy가 0.87이 넘으면 부모 클래스의 stop_training변수를 True로 만들게 한다.
이 클래스의 객체를 생성하고 모델을 만들고, fit의 callbacks에다 배열안에 객체를 대입해준다.
'인공지능 > 텐서플로우' 카테고리의 다른 글
텐서플로우의 모델을 저장하고 불러오는 방법 (0) | 2022.06.14 |
---|---|
validation_data 파라미터 사용법 (0) | 2022.06.14 |
에포크, 학습데이터/밸리데이션데이터와 오버피팅 (0) | 2022.06.13 |
softmax로 나온 결과를 레이블 인코딩으로 바꾸는 방법 (0) | 2022.06.13 |
분류의 문제에서 loss 셋팅하는 방법 (0) | 2022.06.13 |