TF的三种模型的保存与加载方法
当我们训练完一波模型,准备把模型应用于生产环境时,我们需要在生产环境部署预测的model,此时就涉及到如何把我们训练好的model重新加载以及处于效率的考虑,用什么方法更好的部署模型。因此我们来谈谈常见的三种模型的保存与加载方法。
- Checkpoint: 知道模型结构,单纯保存变量
- SavedModel: 不知道模型结构,保存模型和变量
- Freeze pb: 不需要再改变量,只要常量化的模型(“冻结”)
checkpoint
第一种必然是大家最为熟悉的checkpoint的方式,通常在模型训练时,我们通过saver=tf.train.saver()定义saver,在session中通过saver.save()保存模型中的变量。此时我们只是单纯保存了变量而没有对模型本身做任何保存,此时恢复模型需要有模型对应的源代码,因此当我们需要在C++中恢复模型,只能用C++把模型的代码复写一遍。
保存:1
2
3
4
5
6
7
8
9
10
11
12
13
14import tensorflow as tf
a = tf.placeholder(dtype=tf.float32, shape=[2,2], name='a')
b = tf.placeholder(dtype=tf.float32, shape=[2,2], name='b')
w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
c = tf.add(tf.add(a, b), w)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result = sess.run(c, feed_dict={a:[[1,2],[2,3]], b:[[2,3],[4,5]]})
saver.save(sess, './saved_model/model.ckpt')
print(result)
1 | [[4. 6.] |
加载:1
2
3
4
5
6
7
8
9
10
11
12
13import tensorflow as tf
a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
b = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='b')
w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
c = tf.add(tf.add(a, b), w)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './saved_model/model.ckpt')
result = sess.run(c, feed_dict={a: [[1, 2], [2, 3]], b: [[2, 3], [4, 5]]})
print(result)
1 | [[4. 6.] |
可以看到,通过restore我们可以得到同样的结果。
SavedModel
TF官网介绍说SavedModel是一种独立于语言而且可以恢复的序列化格式,使较高级别的系统和工具可以创建,使用和转换Tensorflow模型。常见的两种与SavedModel交互的方式包括tf.saved_model API 和tf.estimator.Estimator。
tf.saved_model API
保存方法1:使用tf.saved_model.simple_save
1 | import tensorflow as tf |
此时,我们可以在saved_model/pb下看到如下文件结构:
-- saved_model
|-- pb
|-- 1
|-- saved_model.pb
|-- variables
|-- variables.data-00000-of-00001
|-- variables.index
保存方法2:通过SavedModelBuilder构建
1 | import tensorflow as tf |
通过saved_model_cli命令查看SavedModel
1 | saved_model_cli show --dir /Users/xxx/Documents/pycharm_workspace/test_python/saved_model/pb/1 --all |
可以得到如下结果:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['a'] tensor_info:
dtype: DT_FLOAT
shape: (2, 2)
name: a:0
inputs['b'] tensor_info:
dtype: DT_FLOAT
shape: (2, 2)
name: b:0
The given SavedModel SignatureDef contains the following output(s):
outputs['c'] tensor_info:
dtype: DT_FLOAT
shape: (2, 2)
name: Add_1:0
Method name is: tensorflow/serving/predict
加载方法1: 通过tf.saved_model.loader.load加载
1 |
|
加载方法2: 通过tf.contrib.predictor.from_saved_model
1 | import tensorflow as tf |
结合tf.estimator.Estimator使用
保存方法
1 | def serving_input_receiver_fn(): |
通过tf-serving部署服务访问
部署server端:
- 基于Dockerfile 创建镜像:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35FROM ubuntu:18.04
# Install general packages
RUN apt-get update && apt-get install -y \
curl \
libcurl3-dev \
unzip \
wget \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Previous Installation of tensorflow-model-server (BROKEN RECENTLY)
#RUN echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list \
# && curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add - \
# && apt-get update && apt-get install tensorflow-model-server
# New installation of tensorflow-model-server
RUN TEMP_DEB="$(mktemp)" \
&& wget -O "$TEMP_DEB" 'http://storage.googleapis.com/tensorflow-serving-apt/pool/tensorflow-model-server-1.12.0/t/tensorflow-model-server/tensorflow-model-server_1.12.0_all.deb' \
&& dpkg -i "$TEMP_DEB" \
&& rm -f "$TEMP_DEB"
# gRPC port
EXPOSE 8500
# REST API port
EXPOSE 8501
# Serve the model when the container starts
CMD tensorflow_model_server \
--port=8500 \
--rest_api_port=8501 \
--model_name="$MODEL_NAME" \
--model_base_path="$MODEL_PATH"
运行如下命令创建镜像:
1
docker build --rm -f Dockerfile -t tensorflow-serving-example:0.1 .
创建临时目录保存savedModel
1
2mkdir -p ./saved_model/dkt/1
cp -R ./saved_model/pb/* ./saved_model/test/1启动容器
1
docker run --rm -it -v /home/xxx/tf_serving/saved_model/:/models -e MODEL_NAME=test -e MODEL_PATH=/models/test -p 8500:8500 -p 8501:8501 --name tensorflow-serving-example tensorflow-serving-example:0.1
至此,server已启动,运行client进行测试
- 基于Dockerfile 创建镜像:
client
1 | import argparse |
Freeze pb
当不再需要改变变量,只要常量化当模型时,我们可以采用freeze pb的方式。可以用在不同语言部署的场景下,好处是除了可以冻结模型外,还可以指定剔除某些多余的节点。
冻结
1 | input_checkpoint = './saved_model' |
运行上述程序,可在./saved_model/pb/4/下看到saved_model.pb文件。
加载pb
1 | import tensorflow as tf |
以上就是常见的TF的模型保存及对应的加载方法~