11. Variable

11.1. what is tf.Variable

Essentially, a variable is a TensorFlow operation.

In Python, tf.Variable is a class; each variable you create is an instance of it.
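A minimal sketch (TF 1.x; the printed names assume a fresh graph) showing that a Variable instance wraps several graph operations:

import tensorflow as tf  # TensorFlow 1.x

v = tf.Variable(tf.zeros([2, 2]), name='v')
print(type(v))            # a Variable (RefVariable in late TF 1.x) instance, not a Tensor
print(v.op.name)          # => v          (the variable op itself)
print(v.initializer.name) # => v/Assign   (the op that assigns the initial value)
print(v.value().name)     # => v/read:0   (a tensor that reads the current value)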

11.2. creating/using/sharing variables

There are two ways to create a variable operation:

  • tf.get_variable()
  • tf.Variable(), which always creates a new variable operation

However, for creating, using, and sharing variables alike, the recommended function is the same one: tf.get_variable(). Its key parameter is name; see the graph/operation/name section for background.

tf.get_variable() does two things, both related to tf.variable_scope():

  1. prefixes the name with the current variable scope and
  2. performs reuse checks (reuse is a parameter of tf.variable_scope())

import tensorflow as tf  # TensorFlow 1.x

#tf.get_variable() is typically used together with tf.variable_scope()
with tf.variable_scope('vs'):
  v = tf.get_variable('ws', [2,2,3,32])
print(v.name) #=>vs/ws:0
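For example, the reuse check means that a second tf.get_variable() call with the same name fails unless reuse is enabled; a small sketch (the scope name 'vs2' is made up):

import tensorflow as tf

with tf.variable_scope('vs2'):
  w = tf.get_variable('w', [2, 2])
  #a second call with the same name in a non-reusing scope fails the reuse check:
  #w2 = tf.get_variable('w', [2, 2])
  #=> ValueError: Variable vs2/w already exists, disallowed.
  #   Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?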

11.3. initialization

Without initialization, executing session.run(my_variable) directly raises an error (FailedPreconditionError: Attempting to use uninitialized value).

  1. The most obvious way to initialize:

#run the variable's initializer operation. For example:
session.run(my_variable.initializer)
  2. If many variables have been defined, a more efficient way, which relies on the collection concept, initializes them all at once:

#tf.global_variables_initializer(). This function returns a single operation responsible for initializing all variables in the tf.GraphKeys.GLOBAL_VARIABLES collection.
session.run(tf.global_variables_initializer())

  3. Most high-level frameworks such as tf.contrib.slim, tf.estimator.Estimator and Keras automatically initialize variables for you before training a model.

  4. To query which variables have not been initialized yet:

print(session.run(tf.report_uninitialized_variables()))
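Putting the pieces together, a minimal end-to-end sketch (the variable name is illustrative):

import tensorflow as tf

my_variable = tf.get_variable('my_variable', shape=[2], initializer=tf.zeros_initializer())

with tf.Session() as session:
  #before initialization the variable shows up as uninitialized
  print(session.run(tf.report_uninitialized_variables()))  # => [b'my_variable']
  session.run(tf.global_variables_initializer())
  print(session.run(tf.report_uninitialized_variables()))  # => [] (empty list)
  print(session.run(my_variable))                          # => [0. 0.]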

11.4. name

Within a given tf.Graph, two variables should never share the same name.

A Variable is essentially an operation, so tf.name_scope() also prefixes the name when tf.Variable() creates a variable operation. Why, then, is tf.variable_scope() provided as well? To support "sharing variables": tf.get_variable() does not take into account the prefix that tf.name_scope() adds to a variable operation's name, so tf.get_variable() can even share variable operations across different name scopes.

Using tf.variable_scope() to add a prefix to variable names avoids collisions, so passing the same name argument when creating variables is harmless (whether an identical name argument is actually allowed differs between tf.get_variable() and tf.Variable(), as the examples below show):

#the graph automatically uniquifies identical name arguments
v = tf.Variable( [2,2,3,32],name='weights')
v1 = tf.Variable( [2,2,3,32],name='weights')
print(v.name) #=>weights_4:0 (the suffix depends on how many 'weights' variables the graph already holds)
print(v1.name) #=>weights_5:0

#tf.Variable() does not check the reuse parameter of variable_scope()
with tf.variable_scope('vs'):
  v = tf.Variable( [2,2,3,32],name='weights')
  v1 = tf.Variable( [2,2,3,32],name='weights')
print(v.name) #=>vs/weights:0
print(v1.name) #=>vs/weights_1:0

#tf.name_scope() does take effect on tf.Variable()
with tf.name_scope('vs'):
  v = tf.Variable( [2,2,3,32],name='weights')
  v1 = tf.Variable( [2,2,3,32],name='weights')
print(v.name) #=>vs_1/weights:0 ('vs_1' because the variable scope above already used the name scope 'vs')
print(v1.name) #=>vs_1/weights_1:0

#for tf.get_variable() and the reuse parameter, see the example in section 11.5.2
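A quick sketch of the claim above that tf.get_variable() ignores tf.name_scope() (the scope and variable names are illustrative):

import tensorflow as tf

with tf.name_scope('ns'):
  v = tf.get_variable('w0', [2, 2])
  a = tf.add(v, v, name='add')
print(v.name)  # => w0:0     -- tf.get_variable() ignores the name scope
print(a.name)  # => ns/add:0 -- ordinary ops do get the prefix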

11.5. Sharing variables

Variables are told apart by their names; "sharing" simply means returning an already created variable that has the requested name.
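The typical use case is applying one set of weights to two different inputs; a sketch (the function and scope names are made up) using the mechanism described in section 11.5.2:

import tensorflow as tf

def linear(x):
  #resolves 'model/w' relative to the enclosing variable scope
  w = tf.get_variable('w', [2, 2])
  return tf.matmul(x, w)

x1 = tf.placeholder(tf.float32, [None, 2])
x2 = tf.placeholder(tf.float32, [None, 2])

with tf.variable_scope('model'):
  y1 = linear(x1)
with tf.variable_scope('model', reuse=True):
  y2 = linear(x2)  # reuses model/w instead of creating a second variable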

11.5.1. Using tf.Variable()

With tf.Variable(), even if the same name argument is passed in, TensorFlow automatically uniquifies the name and creates a new variable operation each time, so sharing is out of the question; the first two examples in section 11.4 demonstrate this.


11.5.2. Using tf.get_variable()

#with the reuse parameter set, tf.get_variable() can reuse an already created variable of the same name
with tf.variable_scope('vs', reuse=tf.AUTO_REUSE):
  v = tf.get_variable('ws3', [2,2,3,32])
  #a variable with the same name, vs/ws3:0, is found and the reuse check passes, so the v from the previous step is returned
  v1 = tf.get_variable('ws3', [2,2,3,32])
print(v.name) #=>vs/ws3:0
print(v1.name) #=>vs/ws3:0
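One difference worth noting (a sketch; the variable name 'new_ws' is made up): reuse=True only reuses, while tf.AUTO_REUSE creates the variable if it does not exist yet:

import tensorflow as tf

#with reuse=True, the variable must already exist
with tf.variable_scope('vs', reuse=True):
  #v = tf.get_variable('new_ws', [1])
  #=> ValueError: Variable vs/new_ws does not exist, or was not created with tf.get_variable()
  pass

#with tf.AUTO_REUSE, it is created on first use and reused afterwards
with tf.variable_scope('vs', reuse=tf.AUTO_REUSE):
  v = tf.get_variable('new_ws', [1])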

11.6. Saving

11.6.1. How to

The save and restore ops are added by the tf.train.Saver constructor to the graph, for all of the variables in the graph or a specified list of them. Each variable node is linked to a save node, which persists its data to stable storage every few training iterations. Likewise, each variable node is linked to a restore node, which is invoked on restart to recover the data. TensorFlow supports saving and restoring checkpoints:

sess = tf.Session()
# Add ops to save and restore all the variables
saver = tf.train.Saver(max_to_keep=0)
for step in range(MAX_STEP):
  ...
  #prefix must include the directory path, e.g. "data/MTCNN_model/PNet_landmark/PNet" in MTCNN
  #(epoch is assumed to be derived from step in the omitted code)
  saver.save(sess, prefix, global_step=epoch*2)

From the code above:

  • Constructing a saver object automatically attaches a save node and a restore node to each variable node; a Graph illustration of this process can be found in write event file.
  • The save() call runs separately, outside the backpropagation step; although session.run() is not called explicitly, the source code of def save() calls it, line1652.
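One more note on the saving code above: max_to_keep=0 (like None) makes the Saver keep every checkpoint file; the default is to retain only the 5 most recent.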

11.6.2. Result

When training PNet in MTCNN, each saver.save() call produces three files: PNet-8.meta, PNet-8.index, and PNet-8.data-00000-of-00001. In addition, a separate file named checkpoint is generated.

If the saver is sharded, this string (the path prefix used for the checkpoint files) ends with: '-?????-of-nnnnn' where 'nnnnn' is the number of shards created.

  • The protocol buffer file named checkpoint

TensorFlow saves variables in binary checkpoint files that, roughly speaking, map variable names to tensor values.

$more checkpoint
model_checkpoint_path: "PNet-30"
all_model_checkpoint_paths: "PNet-30"

A corresponding checkpoint_state.proto defines this file's format.
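To check which variable names and shapes a checkpoint actually stores, one option is tf.train.NewCheckpointReader (a sketch; the path reuses the MTCNN example above):

import tensorflow as tf

#the argument is the checkpoint path prefix (no .meta/.index/.data extension)
reader = tf.train.NewCheckpointReader('data/MTCNN_model/PNet_landmark/PNet-30')
for name, shape in reader.get_variable_to_shape_map().items():
  print(name, shape)  # variable names and shapes stored in the checkpoint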

11.7. Restoring

Restoring variables is the mirror image of saving: the tf.train.Saver constructor also adds restore ops to the graph, and saver.restore() runs them to assign the saved values back, so no separate initialization is required.
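A minimal restore sketch (the directory reuses the MTCNN example; the graph with the variables to restore must be rebuilt first):

import tensorflow as tf

#build the same graph (the variables to restore must already be defined), then:
sess = tf.Session()
saver = tf.train.Saver()

#find the newest checkpoint in the directory written by saver.save()
ckpt = tf.train.latest_checkpoint('data/MTCNN_model/PNet_landmark')
saver.restore(sess, ckpt)  # assigns the saved values; no initializer run is needed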