【Keras】Callback関数を自作する
Keras
Lastmod: 2023-10-09

概要

KerasのCallback関数を自作します。

使い方

from keras.callbacks import Callback

class MyCallback(Callback):
    def __init__(self):
        pass
        
    def on_epoch_begin(self, epoch, logs=None):
        pass
        
    def on_epoch_end(self, epoch, logs=None):
        pass
    
    def on_batch_begin(self, batch, logs=None):
        pass
    
    def on_batch_end(self, batch, logs=None):
        pass
    
    def on_train_begin(self, logs=None):
        pass
    
    def on_train_end(self, logs=None):
        pass
関数 処理タイミング
on_epoch_begin 全てのエポックの開始時
on_epoch_end 全てのエポックの終了時
on_batch_begin 全てのバッチの終了時
on_batch_end 全てのバッチの終了時
on_train_begin 訓練の開始時
on_train_end 訓練の終了時

以上のコードのように書くことでCallback関数を自作することができます。あとは処理したいタイミングのとこの関数にコードを書いていけばOKです。もちろん全ての関数で処理する必要はなく、処理しない部分は関数ごと消してしまって大丈夫です。

サンプル

以下にサンプルコードを書いておきます。この関数では100 epochごとにモデルを保存していきます。

class MyModelCheckpoint(Callback):
    def __init__(self, each_epoch=100, file_path="./weights-epoch{}.h5f"):
        self.each_epoch = each_epoch
        self.file_path = file_path
        
    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.each_epoch == 0:
            print("epoch{} : save weights".format(epoch))
            self.model.save_weights(self.file_path.format(epoch), overwrite=True)
            
model_ckpt_callback = MyModelCheckpoint()

参考