tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") saver = tf.train.Saver() with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) saver.save(sess,"checkpoint/model_test",global_step=1)
当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") saver = tf.train.Saver() with tf.Session() as sess: # init_op = tf.global_variables_initializer() # sess.run(init_op) saver.restore(sess, "checkpoint/model_test-1") # saver.save(sess,"checkpoint/model_test",global_step=1)
神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.
之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") v3= tf.Variable(tf.zeros([100]), name="v3") #saver = tf.train.Saver() saver1 = tf.train.Saver([v1]) saver2 = tf.train.Saver([v2]+[v3]) with tf.Session() as sess: # init_op = tf.global_variables_initializer() # sess.run(init_op) saver1.restore(sess, "checkpoint/model_test-1") saver2.restore(sess, "checkpoint/model_test-1") # saver.save(sess,"checkpoint/model_test",global_step=1)
以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
标签:
tensorflow,加载,变量
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件!
如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
狼山资源网 Copyright www.pvsay.com
暂无“tensorflow 加载部分变量的实例讲解”评论...
RTX 5090要首发 性能要翻倍!三星展示GDDR7显存
三星在GTC上展示了专为下一代游戏GPU设计的GDDR7内存。
首次推出的GDDR7内存模块密度为16GB,每个模块容量为2GB。其速度预设为32 Gbps(PAM3),但也可以降至28 Gbps,以提高产量和初始阶段的整体性能和成本效益。
据三星表示,GDDR7内存的能效将提高20%,同时工作电压仅为1.1V,低于标准的1.2V。通过采用更新的封装材料和优化的电路设计,使得在高速运行时的发热量降低,GDDR7的热阻比GDDR6降低了70%。