博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow -- 模型保存与读取
阅读量:5795 次
发布时间:2019-06-18

本文共 980 字,大约阅读时间需要 3 分钟。

最近学习Google的深度学习框架TensorFlow,CNN模型训练什么的都是OK的,官方也有代码,中文详解请参照:

但是在实际使用的时候不可能每次预测都训练一遍模型,这样太浪费时间,所以需要我们在训练完成的时候保存模型,并且在需要预测的时候加载。官方提供的例子和解释不够具体,让我踩了很多的坑,所以写个笔记分享一下,希望帮助大家跳过或者少踩这些坑。

模型保存:

①首先对于需要保存的变量进行定义,记得variable和placeholder保存用变量名的定义一定不能忘了

定义的形式大体上如下:

var_name_1= tf.Variable(........,name='var_name_1_store')var_name_2= tf.argmax(var_name_1,name='var_name_2_store')var_name_3=tf.placeholder(........,name='var_name_3_store')var_name_4=tf.matmul(var_name_1,var_name_3,name='var_name_4_store')

②其次就是保存处理

需要利用 tf.train.Saver来保存模型,其中global_step不定义的情况下,默认为0

saver = tf.train.Saver()saver.save(sess,'./data.chkp',global_step=XX)

模型加载:

①首先读取刚刚保存的meta文件,然后全局变量初始化,需要用到tf.train.import_meta_graph

saver = tf.train.import_meta_graph("./data.chkp.meta")sess.run(tf.global_variables_initializer())

②其次加载我们需要的变量,并预测,这里用到var_name_3_store,这就是为什么前面placeholder定义的时候一定要定义name

predict = tf.get_default_graph().get_tensor_by_name("var_name_4_store:0") predict.eval(feed_dist={'var_name_3_store':XXXXX})

 

转载地址:http://hddfx.baihongyu.com/

你可能感兴趣的文章
[Google Guava] 2.1-不可变集合
查看>>
三种数据分析法提升电商运营
查看>>
哪个线程执行 CompletableFuture’s tasks 和 callbacks?
查看>>
《数据科学与大数据分析——数据的发现 分析 可视化与表示》一2.10 练习
查看>>
Oracle ASM 翻译系列第六弹:高级知识 如何映射asmlib管理的盘到它对应的设备名...
查看>>
多线程之volatile关键字
查看>>
如何判断webview是不是滑到底部
查看>>
Raptor实践2——控制结构
查看>>
Smartisan OS一步之自定义拖拽内容
查看>>
《JavaScript权威指南第六版》学习笔记-对象
查看>>
开发者论坛一周精粹(第四期):Windows系统 SMB/RDP远程命令执行漏洞
查看>>
Kafka 0.10 常用运维命令
查看>>
常见的浏览器端数据存储方案
查看>>
Nodejs核心模块之net和http
查看>>
Spark+Hbase 亿级流量分析实战(数据结构设计)
查看>>
普通程序员,三年成为年薪100w架构师,只因做到了这些
查看>>
Dockly:从终端管理 Docker 容器
查看>>
Java并发实践(七)取消与关闭
查看>>
利用selenium爬取重定向内容
查看>>
iOS教程 免费使用SMSSDK语音验证的方法
查看>>