1 回调函数的定义

当程序跑起来时,一般情况下,应用程序(application program)会时常通过API调用库里所预先备好的函数。但是有些库函数(library function)却要求应用先传给它一个函数,好在合适的时候调用,以完成目标任务。这个被传入的、后又被调用的函数就称为回调函数(callback function)。

回调函数的例子

### 回调函数
#回调函数1
#生成一个2k形式的偶数
def double(x):
    return x * 2
    
#回调函数2
#生成一个4k形式的偶数
def quadruple(x):
    return x * 4

## 使用回调函数的中间函数,也就是回调函数的使用者
#中间函数
#接受一个生成偶数的函数作为参数
#返回一个奇数
def getOddNumber(k, getEvenNumber):
    return 1 + getEvenNumber(k)
    
#起始函数,这里是程序的主函数
def main():    
    k = 1
    #当需要生成一个2k+1形式的奇数时
    i = getOddNumber(k, double)
    print(i)
    #当需要一个4k+1形式的奇数时
    i = getOddNumber(k, quadruple)
    print(i)
    #当需要一个8k+1形式的奇数时
    i = getOddNumber(k, lambda x: x * 8)
    print(i)
    
if __name__ == "__main__":
    main()

可以发现,其实我们只需要知道回调函数是传入什么参数,有什么功能即可,而不能自己去定义传入参数,所以我们去写回调函数的时候要按照中间函数传入的参数去写

2 keras中的回调函数

这里我们不讲那些keras中已经定义的回调函数,这里说一下如何创建自己的回调函数。
keras中定义了一个回调函数的抽象类,这个类包含多个回调函数(即一个类里面有很多方法,这些方法是不同的时期被中间函数调用的)。我们继承这个类,然后重新其中的回调函数即可。
下面看一下这个基类:
https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L275

class Callback(object):
			 """Abstract base class used to build new callbacks.
    # Properties
        params: dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: instance of `keras.models.Model`.
            Reference of the model being trained.
    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch.
    Currently, the `.fit()` method of the `Sequential` model class
    will include the following quantities in the `logs` that
    it passes to its callbacks:
        on_epoch_end: logs include `acc` and `loss`, and
            optionally include `val_loss`
            (if validation is enabled in `fit`), and `val_acc`
            (if validation and accuracy monitoring are enabled).
        on_batch_begin: logs include `size`,
            the number of samples in the current batch.
        on_batch_end: logs include `loss`, and optionally `acc`
            (if accuracy monitoring is enabled).
    """
		
     def __init__(self):
        self.validation_data = None
        self.model = None
		
		## 注意:set params和 set_model是已经定义好的,也就是说继承这个类以后回调函数本身就会有self.params 和self.model,不需要我们去关心有没有或者怎么得到这两个变量。
		# 既然是回调函数,那么params这个参数是中间函数传给它的,不需要我们去传。
    def set_params(self, params): 
        self.params = params

    def set_model(self, model):
        self.model = model
		
		## 什么时候会调用,直接可以看方法名。
		# Arguments
    #       logs: 具体可以看上面链接中给出的注释,每个都不一样。
	def on_batch_begin(self, batch, logs=None)
	def on_batch_end(self, batch, logs=None)
	def on_epoch_begin(self, epoch, logs=None)
	def on_epoch_end(self, epoch, logs=None)
	def on_train_batch_begin(self, batch, logs=None)
	def on_train_batch_end(self, batch, logs=None)
	def on_test_batch_begin(self, batch, logs=None)
	def on_test_batch_end(self, batch, logs=None)
	def on_predict_batch_begin(self, batch, logs=None)
	def on_predict_batch_end(self, batch, logs=None)
	def on_train_begin(self, logs=None)
	def on_train_end(self, logs=None)
	def on_test_begin(self, logs=None)
	def on_test_end(self, logs=None)
	def on_predict_begin(self, logs=None)
	def on_predict_end(self, logs=None)	

在回调函数中可以使用这两个参数。

  •     self.params = params: 字典。训练参数, (例如,verbosity, batch size, number of epochs...)。可以打印出来看看。
    
  •     self.model = model:keras.models.Model 的实例。 指代被训练模型。
    

通过类的属性 self.model,回调函数可以获得它所联系的模型。

keras自定义回调函数的例子

回调函数使用以后,还可以通过类实例来访问实例变量。

class LossHistory(keras.callbacks.Callback):
    def __init__(self):
				super(LossHistory, self).__init__()
				self.losses = []
				
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

history = LossHistory()
model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])

print(history.losses)

引用

  • https://github.com/chensvm/Keras-Callback-F1/blob/master/roc_auc_score_Metrics
  • https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L275