概要
KerasのCallback関数を自作します。
使い方
from keras.callbacks import Callback
class MyCallback(Callback):
def __init__(self):
pass
def on_epoch_begin(self, epoch, logs=None):
pass
def on_epoch_end(self, epoch, logs=None):
pass
def on_batch_begin(self, batch, logs=None):
pass
def on_batch_end(self, batch, logs=None):
pass
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
関数 | 処理タイミング |
---|---|
on_epoch_begin | 全てのエポックの開始時 |
on_epoch_end | 全てのエポックの終了時 |
on_batch_begin | 全てのバッチの終了時 |
on_batch_end | 全てのバッチの終了時 |
on_train_begin | 訓練の開始時 |
on_train_end | 訓練の終了時 |
以上のコードのように書くことでCallback関数を自作することができます。あとは処理したいタイミングのとこの関数にコードを書いていけばOKです。もちろん全ての関数で処理する必要はなく、処理しない部分は関数ごと消してしまって大丈夫です。
サンプル
以下にサンプルコードを書いておきます。この関数では100 epochごとにモデルを保存していきます。
class MyModelCheckpoint(Callback):
def __init__(self, each_epoch=100, file_path="./weights-epoch{}.h5f"):
self.each_epoch = each_epoch
self.file_path = file_path
def on_epoch_end(self, epoch, logs=None):
if epoch % self.each_epoch == 0:
print("epoch{} : save weights".format(epoch))
self.model.save_weights(self.file_path.format(epoch), overwrite=True)
model_ckpt_callback = MyModelCheckpoint()