foolyc

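tensorflow02: TensorFlow variable sharing and network sharing

When implementing certain algorithms in TensorFlow, we often need to share some variables between parts of a model. This post covers a few ways to do that.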

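Variable sharing

There are two common approaches: 1) predefine the parameters and store them in a dict; 2) use the scope mechanism. Each is described below.

Predefined variables

This approach is very simple: define all the variables up front, store them in a dictionary, and pass the parameters explicitly when building the network.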

import tensorflow as tf  # TensorFlow 1.x API

variables_dict = {
    "conv1_weights": tf.Variable(tf.random_normal([5, 5, 32, 32]),
                                 name="conv1_weights"),
    "conv1_biases": tf.Variable(tf.zeros([32]), name="conv1_biases"),
    # ... etc. (the filter below also expects "conv2_weights" and "conv2_biases") ...
}


def my_image_filter(input_images, variables_dict):
    # Every layer looks up its parameters in the shared dictionary.
    conv1 = tf.nn.conv2d(input_images, variables_dict["conv1_weights"],
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(conv1 + variables_dict["conv1_biases"])
    conv2 = tf.nn.conv2d(relu1, variables_dict["conv2_weights"],
                         strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv2 + variables_dict["conv2_biases"])


# Both calls to my_image_filter() now use the same variables.
# image1 and image2 are input image tensors defined elsewhere.
result1 = my_image_filter(image1, variables_dict)
result2 = my_image_filter(image2, variables_dict)

The scope mechanism

tf.variable_scope does two things: it sets a namespace for the variables defined inside it, and it enables variable sharing.

def conv_relu(input, kernel_shape, bias_shape):
    # Create variable named "weights".
    weights = tf.get_variable("weights", kernel_shape,
                              initializer=tf.random_normal_initializer())
    # Create variable named "biases".
    biases = tf.get_variable("biases", bias_shape,
                             initializer=tf.constant_initializer(0.0))
    conv = tf.nn.conv2d(input, weights,
                        strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv + biases)


def my_image_filter(input_images):
    with tf.variable_scope("conv1"):
        # Variables created here will be named "conv1/weights", "conv1/biases".
        relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
    with tf.variable_scope("conv2"):
        # Variables created here will be named "conv2/weights", "conv2/biases".
        return conv_relu(relu1, [5, 5, 32, 32], [32])


with tf.variable_scope("image_filters") as scope:
    result1 = my_image_filter(image1)
    scope.reuse_variables()
    result2 = my_image_filter(image2)

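Note that reuse has to be declared explicitly. Without scope.reuse_variables() (or reuse=True on the scope), the second call to my_image_filter would try to create variables that already exist, and tf.get_variable would raise a ValueError.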
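To make the naming and reuse rules concrete, here is a minimal pure-Python sketch of the idea. It is a toy model only, not TensorFlow's implementation: the class ToyVariableStore, the helper toy_filter, and the dict used as a "variable" are all hypothetical names invented for illustration.

```python
import contextlib


class ToyVariableStore:
    """Toy stand-in for a variable store (hypothetical, not TensorFlow code)."""

    def __init__(self):
        self._vars = {}      # full scoped name -> variable object
        self._scopes = []    # stack of active scope names
        self._reuse = False  # whether get_variable returns existing variables

    @contextlib.contextmanager
    def variable_scope(self, name, reuse=None):
        # Entering a scope pushes its name; reuse is inherited unless set to True.
        previous_reuse = self._reuse
        self._scopes.append(name)
        if reuse:
            self._reuse = True
        try:
            yield self
        finally:
            self._scopes.pop()
            self._reuse = previous_reuse

    def reuse_variables(self):
        # Plays the role of scope.reuse_variables() in the example above.
        self._reuse = True

    def get_variable(self, name, shape):
        full_name = "/".join(self._scopes + [name])
        if self._reuse:
            # Reuse mode: the variable must already exist.
            if full_name not in self._vars:
                raise ValueError("Variable %s does not exist" % full_name)
            return self._vars[full_name]
        # Creation mode: the name must still be free.
        if full_name in self._vars:
            raise ValueError("Variable %s already exists" % full_name)
        self._vars[full_name] = {"name": full_name, "shape": shape}
        return self._vars[full_name]


def toy_filter(store):
    with store.variable_scope("conv1"):
        w1 = store.get_variable("weights", [5, 5, 32, 32])
    with store.variable_scope("conv2"):
        w2 = store.get_variable("weights", [5, 5, 32, 32])
    return w1, w2


store = ToyVariableStore()
with store.variable_scope("image_filters") as scope:
    first = toy_filter(store)
    scope.reuse_variables()
    second = toy_filter(store)

print(first[0]["name"])       # image_filters/conv1/weights
print(first[0] is second[0])  # True
```

In this toy model the second call gets back the very same objects as the first, and calling toy_filter again under "image_filters" without enabling reuse raises ValueError, which mirrors the behaviour described above.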

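Network sharing

For repeated network structures, besides calling result = my_image_filter(image) several times as in the example above, the tf.map_fn function is recommended. It applies a function to each element obtained by unpacking the input along its first dimension.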

def map_fn(
    fn,          # the function to apply repeatedly
    elems,       # input tensor or tuple of tensors
    dtype=None,  # output type(s) of fn
    # other parameters
):
    ...

Here is a simple usage example:

import numpy as np
import tensorflow as tf

elems = np.array([1, 2, 3], dtype=np.int64)
alternates = tf.map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]
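The way the outputs are regrouped in the example above can be sketched in plain Python. This is only an illustration of the semantics on lists, under the assumption that fn returns either a single value or a flat tuple; toy_map_fn is a hypothetical helper, not part of TensorFlow.

```python
def toy_map_fn(fn, elems):
    """Apply fn to each item of elems and regroup tuple outputs by position."""
    # One call per item, like one call per slice along the first dimension.
    results = [fn(x) for x in elems]
    if results and isinstance(results[0], tuple):
        # fn returned a tuple: build one list per tuple position.
        return tuple(list(column) for column in zip(*results))
    return results


alternates = toy_map_fn(lambda x: (x, -x), [1, 2, 3])
print(alternates[0])  # [1, 2, 3]
print(alternates[1])  # [-1, -2, -3]
```

The point of the sketch is that a function returning a pair yields a pair of stacked results, not a list of pairs, which is why the dtype argument in the real call is also a pair.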

Other

For reference, common ways to collect variables:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
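
The scope argument filters the collection by name: only items whose name matches the given string from its start (in the sense of re.match) are returned. A rough pure-Python sketch of that filtering follows; toy_get_collection and the variable names in all_names are hypothetical and only meant to show the matching rule.

```python
import re


def toy_get_collection(names, scope=None):
    """Return the names that match scope from their start, in original order."""
    if scope is None:
        return list(names)
    return [n for n in names if re.match(scope, n)]


# Hypothetical variable names, in the style produced by nested variable scopes.
all_names = [
    "image_filters/conv1/weights:0",
    "image_filters/conv1/biases:0",
    "image_filters/conv2/weights:0",
    "other_net/conv1/weights:0",
]

conv1_vars = toy_get_collection(all_names, "image_filters/conv1")
print(conv1_vars)
# ['image_filters/conv1/weights:0', 'image_filters/conv1/biases:0']
```

Because the match is anchored at the start of the name, passing "conv1" alone selects nothing here; the outer scope has to be part of the filter string.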
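
This article was written and published by foolyc and is licensed under the BY-NC-SA international license.
Please credit the author and source when reposting. The author of this article is foolyc.
Article title: tensorflow02: TensorFlow variable sharing and network sharing
Article link: http://foolyc.com//2018/01/02/tensorflow02/.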