tensorflow中一种融合多个模型的方法

1.使用场景

假设我们有训练好的模型A,B,C,我们希望使用A,B,C中的部分或者全部变量,合成为一个模型D,用于初始化或其他目的,就需要融合多个模型的方法

2.如何实现

我们可以先声明模型D,再创建多个Saver实例,分别从模型A,B,C的保存文件(checkpoint文件)中读取所需的变量值,来达成这一目的,下面是示例代码:

首先创建一个只包含w1,w2两个变量的模型,初始化后保存:

 1 def train_model1():
 2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(), trainable=True)
 4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 5     a1 = tf.matmul(x, w1)
 6     input = np.random.rand(3200, 3)
 7     sess = tf.InteractiveSession()
 8     sess.run(tf.global_variables_initializer())
 9     saver1 = tf.train.Saver([w1,w2])
10     for i in range(0, 1):
11         w1_var,w2_var = sess.run([w1,w2], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w1_var
13         print w2_var
14         print ‘=‘ * 30
15     saver1.save(sess, ‘save1-exp‘)

然后再创建一个只包含w2,w3两个变量的模型,也是初始化后保存:

 1 def train_model2():
 2     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 4     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 5     a2 = tf.matmul(x, w2 * w3)
 6     input = np.random.rand(3200, 3)
 7     sess = tf.InteractiveSession()
 8     sess.run(tf.global_variables_initializer())
 9     saver2 = tf.train.Saver([w2,w3])
10     for i in range(0, 1):
11         w2_var, w3_var = sess.run([w2, w3], feed_dict={x: input[i * 32:(i + 1) * 32]})
12         print w2_var
13         print w3_var
14         print ‘=‘ * 30
15     saver2.save(sess, ‘save2-exp‘)

最后我们创建一个包含w1,w2,w3变量的模型,从上面两个保存的ckp文件中恢复:

 1 def restore_model():
 2     w1 = tf.get_variable("w1", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 3     w2 = tf.get_variable("w2", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 4     w3 = tf.get_variable("w3", shape=[3, 1], initializer=tf.truncated_normal_initializer(),trainable=True)
 5     x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
 6     a1 = tf.matmul(x, w1)
 7     a2 = tf.matmul(x, w2 * w3)
 8     loss = tf.reduce_mean(tf.square(a1 - a2))
 9     sess = tf.InteractiveSession()
10     sess.run(tf.global_variables_initializer())
11     saver1 = tf.train.Saver([w1,w2])
12     saver1.restore(sess, ‘save1-exp‘)
13     saver2 = tf.train.Saver([w2, w3])
14     saver2.restore(sess, ‘save2-exp‘)
15     saver3 = tf.train.Saver(tf.trainable_variables())
16     input = np.random.rand(3200, 3)
17     w1_var, w2_var, w3_var = sess.run([w1, w2, w3], feed_dict={x: input[0:32]})
18     print w1_var
19     print w2_var
20     print w3_var
21     print ‘=‘ * 30
22     saver3.save(sess, ‘save3-exp‘)

然后保存,即完成了我们的目标

3.注意事项

3.1 取的模型中有同名变量

假设同名变量为a,这种情况下,从不同模型中恢复的a是按照读取顺序覆盖到a中的,如果希望只读取特定ckpt保存的变量值,在创建读取其他ckpt的saver时,不要把a加入到var_list中

3.2 模型D中有部分变量不在A,B,C中

这种情况,恢复时会报错,需要指定var_list,只恢复当前cpkt中保存的变量

原文地址:https://www.cnblogs.com/hrlnw/p/10466145.html

时间: 03-05

tensorflow中一种融合多个模型的方法的相关文章

unity中三种调用其他脚本函数的方法

第一种,被调用脚本函数为static类型,调用时直接用  脚本名.函数名().很不实用-- 第二种,GameObject.Find("脚本所在物体名").SendMessage("函数名");  此种方法可以调用public和private类型函数 第三种,GameObject.Find("脚本所在物体名").GetComponent<脚本名>().函数名();此种方法只可以调用public类型函数 unity中三种调用其他脚本函数的

tensorflow中四种不同交叉熵函数tf.nn.softmax_cross_entropy_with_logits()

Tensorflow中的交叉熵函数tensorflow中自带四种交叉熵函数,可以轻松的实现交叉熵的计算. tf.nn.softmax_cross_entropy_with_logits() tf.nn.sparse_softmax_cross_entropy_with_logits() tf.nn.sigmoid_cross_entropy_with_logits() tf.nn.weighted_cross_entropy_with_logits()注意:tensorflow交叉熵计算函数输入

Android中8种异步处理与计算的方法

注:该文章翻译自Ali Muzaffar的文章<8 ways to do asynchronous processing in Android and counting>  Android提供了许多API来支持异步处理的功能,结合着Java提供的方法和你手上拥有的,估计目前已经有数十种进行异步任务的方法. 目前的趋势是仅使用Java的threads或者Android的AsyncTask来处理各种问题.虽然上述两种方法拥有较高的知名度,但是并非所有的API都适合,为你的需求选择最合适的方法能够使

比较C#中几种常见的复制字节数组方法的效率[转]

[原文链接] 在日常编程过程中,我们可能经常需要Copy各种数组,一般来说有以下几种常见的方法:Array.Copy,IList<T>.Copy,BinaryReader.ReadBytes,Buffer.BlockCopy,以及System.Buffer.memcpyimpl,由于最后一种需要使用指针,所以本文不引入该方法. 本次测试,使用以上前4种方法,各运行1000万次,观察结果. using System; using System.Collections.Generic; using

分享php中四种webservice实现的简单架构方法及实例

一:PHP本身的SOAP所有的webservice都包括服务端(server)和客户端(client).要使用php本身的soap首先要把该拓展安装好并且启用.下面看具体的code首先这是服务端实现: PHP Code复制内容到剪贴板 <?php class test { function show() { return 'the data you request!'; } } function getUserInfo($name) { return 'fbbin'; } //实例化的参数手册上

lua中,两种json和table互转方法的效率比较

lua中json和table的互转,是我们在平时开发过程中经常用到的.比如: 在用lua编写的服务器中,如果客户端发送json格式的数据,那么在lua处理业务逻辑的时候,必然需要转换成lua自己的数据结构,如table.此时,就会用到table和json格式的互转. 在用lua编写的服务器中,如果我们通过redis来存储数据,由于redis中不存在table这种数据结构,因此,我们可以选择将table转换成json字符串来进行存储.在数据的存取过程中,也会用到table和json格式的互转. 以

Python中几种数据的常用内置方法

1. int bit_lenth:二进制的长度 2.str capitalize():首字母大写,其他小写. upper():全部转换为大写,lower()相反;casefold()功能类似于lower,但是更强大,不常用 title():每个被特殊字符隔开的单词的首字母大写,其中中文属于特殊字符; strip():去除左边和有右边的空格,对中间的空格无能为力,也可以去掉两边的指定的字符串 replace(a, b):将a替换为b split(a):用a作为切割线进行切割,返回值为一个list

delphi 中一种好用的数组定义方法以及函数嵌套的使用源代码

type TByteBuff;= array of integer; function abc(a:integer):TByteBuff;; var tempArr:TByteBuff; begin setlength(tempArr,2); tempArr[0] := a; tempArr[1] := a + 10; result:=tempArr; end; procedure TForm1.Button1Click(Sender: TObject); var a : integer; be

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess