# TensorFlow中一种融合多个模型的方法

#### 2.如何实现

``` 1 def train_model1():
2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(), trainable=True)
4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
5     a1 = tf.matmul(x, w1)
6     input = np.random.rand(3200, 3)
7     sess = tf.InteractiveSession()
8     sess.run(tf.global_variables_initializer())
9     saver1 = tf.train.Saver([w1,w2])
10     for i in range(0, 1):
11         w1_var,w2_var = sess.run([w1,w2], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w1_var
13         print w2_var
14         print ‘=‘ * 30
15     saver1.save(sess, ‘save1-exp‘)```

``` 1 def train_model2():
2     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
3     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
5     a2 = tf.matmul(x, w2 * w3)
6     input = np.random.rand(3200, 3)
7     sess = tf.InteractiveSession()
8     sess.run(tf.global_variables_initializer())
9     saver2 = tf.train.Saver([w2,w3])
10     for i in range(0, 1):
11         w2_var, w3_var = sess.run([w2, w3], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w2_var
13         print w3_var
14         print ‘=‘ * 30
15     saver2.save(sess, ‘save2-exp‘)```

``` 1 def restore_model():
2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
4     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
5     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
6     a1 = tf.matmul(x, w1)
7     a2 = tf.matmul(x, w2 * w3)
8     loss = tf.reduce_mean(tf.square(a1 - a2))
9     sess = tf.InteractiveSession()
10     sess.run(tf.global_variables_initializer())
11     saver1 = tf.train.Saver([w1,w2])
12     saver1.restore(sess, ‘save1-exp‘)
13     saver2 = tf.train.Saver([w2, w3])
14     saver2.restore(sess, ‘save2-exp‘)
15     saver3 = tf.train.Saver(tf.trainable_variables())
16     input = np.random.rand(3200, 3)
17     w1_var, w2_var, w3_var = sess.run([w1, w2, w3], feed_dict={x: input[0:32]})
18     print w1_var
19     print w2_var
20     print w3_var
21     print ‘=‘ * 30
22     saver3.save(sess, ‘save3-exp‘)```

