1 回调函数的定义
当程序跑起来时,一般情况下,应用程序(application program)会时常通过API调用库里所预先备好的函数。但是有些库函数(library function)却要求应用先传给它一个函数,好在合适的时候调用,以完成目标任务。这个被传入的、后又被调用的函数就称为回调函数(callback function)。
回调函数的例子
### 回调函数
#回调函数1
#生成一个2k形式的偶数
def double(x):
return x * 2
#回调函数2
#生成一个4k形式的偶数
def quadruple(x):
return x * 4
## 使用回调函数的中间函数,也就是回调函数的使用者
#中间函数
#接受一个生成偶数的函数作为参数
#返回一个奇数
def getOddNumber(k, getEvenNumber):
return 1 + getEvenNumber(k)
#起始函数,这里是程序的主函数
def main():
k = 1
#当需要生成一个2k+1形式的奇数时
i = getOddNumber(k, double)
print(i)
#当需要一个4k+1形式的奇数时
i = getOddNumber(k, quadruple)
print(i)
#当需要一个8k+1形式的奇数时
i = getOddNumber(k, lambda x: x * 8)
print(i)
if __name__ == "__main__":
main()
可以发现,其实我们只需要知道回调函数是传入什么参数,有什么功能即可,而不能自己去定义传入参数,所以我们去写回调函数的时候要按照中间函数传入的参数去写。
2 keras中的回调函数
这里我们不讲那些keras中已经定义的回调函数,这里说一下如何创建自己的回调函数。
keras中定义了一个回调函数的抽象类,这个类包含多个回调函数(即一个类里面有很多方法,这些方法是不同的时期被中间函数调用的)。我们继承这个类,然后重新其中的回调函数即可。
下面看一下这个基类:
https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L275
class Callback(object):
"""Abstract base class used to build new callbacks.
# Properties
params: dict. Training parameters
(eg. verbosity, batch size, number of epochs...).
model: instance of `keras.models.Model`.
Reference of the model being trained.
The `logs` dictionary that callback methods
take as argument will contain keys for quantities relevant to
the current batch or epoch.
Currently, the `.fit()` method of the `Sequential` model class
will include the following quantities in the `logs` that
it passes to its callbacks:
on_epoch_end: logs include `acc` and `loss`, and
optionally include `val_loss`
(if validation is enabled in `fit`), and `val_acc`
(if validation and accuracy monitoring are enabled).
on_batch_begin: logs include `size`,
the number of samples in the current batch.
on_batch_end: logs include `loss`, and optionally `acc`
(if accuracy monitoring is enabled).
"""
def __init__(self):
self.validation_data = None
self.model = None
## 注意:set params和 set_model是已经定义好的,也就是说继承这个类以后回调函数本身就会有self.params 和self.model,不需要我们去关心有没有或者怎么得到这两个变量。
# 既然是回调函数,那么params这个参数是中间函数传给它的,不需要我们去传。
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
## 什么时候会调用,直接可以看方法名。
# Arguments
# logs: 具体可以看上面链接中给出的注释,每个都不一样。
def on_batch_begin(self, batch, logs=None)
def on_batch_end(self, batch, logs=None)
def on_epoch_begin(self, epoch, logs=None)
def on_epoch_end(self, epoch, logs=None)
def on_train_batch_begin(self, batch, logs=None)
def on_train_batch_end(self, batch, logs=None)
def on_test_batch_begin(self, batch, logs=None)
def on_test_batch_end(self, batch, logs=None)
def on_predict_batch_begin(self, batch, logs=None)
def on_predict_batch_end(self, batch, logs=None)
def on_train_begin(self, logs=None)
def on_train_end(self, logs=None)
def on_test_begin(self, logs=None)
def on_test_end(self, logs=None)
def on_predict_begin(self, logs=None)
def on_predict_end(self, logs=None)
在回调函数中可以使用这两个参数。
-
self.params = params: 字典。训练参数, (例如,verbosity, batch size, number of epochs...)。可以打印出来看看。
-
self.model = model:keras.models.Model 的实例。 指代被训练模型。
通过类的属性 self.model,回调函数可以获得它所联系的模型。
keras自定义回调函数的例子
回调函数使用以后,还可以通过类实例来访问实例变量。
class LossHistory(keras.callbacks.Callback):
def __init__(self):
super(LossHistory, self).__init__()
self.losses = []
def on_train_begin(self, logs={}):
self.losses = []
def on_batch_end(self, batch, logs={}):
self.losses.append(logs.get('loss'))
model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
history = LossHistory()
model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])
print(history.losses)
引用
- https://github.com/chensvm/Keras-Callback-F1/blob/master/roc_auc_score_Metrics
- https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L275