tensorflow自定义变量赋值及加载预训练模型
在使用tensorflow实现算法的时候,我们通常会有自定义给我们网络参数赋值的需求,例如自定义更新或者加载部分预训练模型的时候。
变量赋值操作
tensorflow中的tensor提供了assign操作
|
|
上面的例子是根据Numpy数据给tensorflow变量赋值,如果是根据tensor给变量赋值,与上相同,不用建立占位符,如下:
|
|
加载预训练模型
很多时候我们需要用到别人预先训练好的模型,如果同样的网络结构(包扩命名),并且别人的模型是ckpt文件,那么很简单,我们只需要建立saver并restore即可。
如果模型文件是caffemodel,可以用caffe加载后转存为hdf5格式,对于转化之后的hdf5格式权重文件或者keras训练的hdf5文件,我们都可以利用h5py打开,按照第一步所述,给我们自己的tensorflow网络初始化网络,这里有两种方式:
a) 参数定义时即用预训练的数据,例如weights[‘conv0’] = tf.Variable(pretrain_weight)), 这种适合自己定义网络参数且网络不庞大时使用
b) 参数初始化之后,按照第一部分赋值操作给所有或部分参数赋值
另外,如果别人的模型是ckpt文件,但是和我的网络结构不完全一致该怎么办呢?
如果我们要使用的网络参数与别人命名一致,我们可以使用from tensorflow.contrib.framework.python.ops.variables import assign_from_checkpoint_fn中的函数assign_from_checkpoint_fn
|
|
我们命名方式也不同,我们只能先拿到ckpt所有参数,再根据第一部分内容再来一次赋值:
获得ckpt中所有变量
然后根据所需获得要的tensor
然后根据第一步部分给参数赋值即可,简单示例:
需要加载参数原始命名name_ori, 自己网络中命名name_new, 而weight是网络参数列表
|
|