Three Ways to Save and Load TensorFlow Models

Once we have trained a model and want to use it in production, we need to deploy it for inference: reload the trained model in the serving environment and, for efficiency, pick the right deployment format. This post walks through the three common ways to save and load TensorFlow models:

  • Checkpoint: you know the model structure; only the variables are saved
  • SavedModel: you don't need the model-building code; both the graph and the variables are saved
  • Freeze pb: the variables no longer need to change; the model is constant-folded ("frozen")

Checkpoint

The first, and the one everybody knows best, is the checkpoint. During training we create a saver with saver = tf.train.Saver() and call saver.save() inside the session to save the model's variables. Only the variables are saved, not the graph itself, so restoring requires the original model-building code; if you need to restore the model in C++, you have to rewrite the model definition in C++.
Saving:

import tensorflow as tf

a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
b = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='b')

w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
c = tf.add(tf.add(a, b), w)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(c, feed_dict={a: [[1, 2], [2, 3]], b: [[2, 3], [4, 5]]})
    saver.save(sess, './saved_model/model.ckpt')

print(result)

[[4. 6.]
 [7. 9.]]
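
Alongside the result, saver.save writes these files under ./saved_model/ (the .meta file holds the serialized graph, .data and .index hold the variable values, and checkpoint records the latest checkpoint path):

 --  saved_model
    |-- checkpoint
    |-- model.ckpt.data-00000-of-00001
    |-- model.ckpt.index
    |-- model.ckpt.meta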

Loading:

import tensorflow as tf

a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
b = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='b')

w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
c = tf.add(tf.add(a, b), w)

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './saved_model/model.ckpt')
    result = sess.run(c, feed_dict={a: [[1, 2], [2, 3]], b: [[2, 3], [4, 5]]})

print(result)

[[4. 6.]
 [7. 9.]]

As you can see, restore gives exactly the same result.
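
Within Python, strictly speaking, you don't always need the model-building code: since the .meta file contains the graph, tf.train.import_meta_graph can rebuild it for you. A minimal sketch, assuming the checkpoint written above:

import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the graph from the .meta file instead of re-defining it in Python
    saver = tf.train.import_meta_graph('./saved_model/model.ckpt.meta')
    # Then restore the variable values
    saver.restore(sess, './saved_model/model.ckpt')
    a = sess.graph.get_tensor_by_name('a:0')
    b = sess.graph.get_tensor_by_name('b:0')
    c = sess.graph.get_tensor_by_name('Add_1:0')
    print(sess.run(c, feed_dict={a: [[1, 2], [2, 3]], b: [[2, 3], [4, 5]]}))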

SavedModel

The TF documentation describes SavedModel as a language-neutral, recoverable serialization format that lets higher-level systems and tools create, consume, and transform TensorFlow models. The two common ways to interact with SavedModel are the tf.saved_model API and tf.estimator.Estimator.

tf.saved_model API

Saving, method 1: tf.saved_model.simple_save

import tensorflow as tf

a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
b = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='b')

w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
c = tf.add(tf.add(a, b), w)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    tf.saved_model.simple_save(session=sess,
                               export_dir='./saved_model/pb/1',
                               inputs={'a': a, 'b': b},
                               outputs={'c': c})

After running this, the following structure appears under saved_model/pb (saved_model.pb stores the serialized graph and signatures; variables/ stores the weights):

 --  saved_model
    |--  pb
        |-- 1
            |-- saved_model.pb
            |-- variables
                |-- variables.data-00000-of-00001
                |-- variables.index

Saving, method 2: SavedModelBuilder

import tensorflow as tf

a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
b = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='b')

w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
c = tf.add(tf.add(a, b), w)

saver = tf.train.Saver()
with tf.Session() as sess:
    # Reuse the variables saved in the checkpoint above
    saver.restore(sess, './saved_model/model.ckpt')

    builder = tf.saved_model.builder.SavedModelBuilder('./saved_model/pb/2')
    inputs = {'a': tf.saved_model.utils.build_tensor_info(a),
              'b': tf.saved_model.utils.build_tensor_info(b)}
    outputs = {'c': tf.saved_model.utils.build_tensor_info(c)}

    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature}
    )
    builder.save()

Inspecting a SavedModel with the saved_model_cli command

saved_model_cli show --dir /Users/xxx/Documents/pycharm_workspace/test_python/saved_model/pb/1 --all

which prints:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['a'] tensor_info:
        dtype: DT_FLOAT
        shape: (2, 2)
        name: a:0
    inputs['b'] tensor_info:
        dtype: DT_FLOAT
        shape: (2, 2)
        name: b:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['c'] tensor_info:
        dtype: DT_FLOAT
        shape: (2, 2)
        name: Add_1:0
  Method name is: tensorflow/serving/predict
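
saved_model_cli can also execute the SavedModel directly, which makes for a quick sanity check. A sketch (to my understanding, --input_exprs takes semicolon-separated Python expressions, one per signature input):

saved_model_cli run --dir ./saved_model/pb/1 --tag_set serve --signature_def serving_default \
    --input_exprs 'a=[[1,2],[2,3]];b=[[2,3],[4,5]]'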

Loading, method 1: tf.saved_model.loader.load

import tensorflow as tf

export_dir = './saved_model/pb/2'
with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, ['serve'], export_dir)
    a = sess.graph.get_tensor_by_name('a:0')
    b = sess.graph.get_tensor_by_name('b:0')
    c = sess.graph.get_tensor_by_name('Add_1:0')
    print(sess.run(c, feed_dict={a: [[1, 2], [2, 3]], b: [[2, 3], [4, 5]]}))

Loading, method 2: tf.contrib.predictor.from_saved_model

import tensorflow as tf

export_dir = './saved_model/pb/2'
predictor_fn = tf.contrib.predictor.from_saved_model(
    export_dir=export_dir,
    signature_def_key='serving_default'
)

output = predictor_fn({'a': [[1, 2], [2, 3]],
                       'b': [[2, 3], [4, 5]]})
print(output)

Using SavedModel with tf.estimator.Estimator

Saving

import tensorflow as tf


def serving_input_receiver_fn():
    feature_spec = {'a': tf.FixedLenFeature([2, 2], tf.float32),
                    'b': tf.FixedLenFeature([2, 2], tf.float32)}

    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[1],
                                           name='input_example_tensor')
    receiver_tensors = {'examples': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)

    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)


# `estimator` is a trained tf.estimator.Estimator (see the sketch below);
# exporting reads the latest checkpoint from its model_dir
estimator.export_savedmodel('saved_model', serving_input_receiver_fn)
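
The snippet above assumes `estimator` already exists. A minimal sketch of what that could look like; the model_fn below is hypothetical and mirrors the a + b + w graph used throughout, and export_savedmodel restores from the latest checkpoint in model_dir, so the estimator must have been trained (or at least checkpointed) first:

import tensorflow as tf

def model_fn(features, labels, mode):
    # Same toy computation as above: c = a + b + w.
    # Only the PREDICT branch matters for export_savedmodel.
    w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
    c = tf.add(tf.add(features['a'], features['b']), w)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={'predict': c},
        export_outputs={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.estimator.export.PredictOutput({'predict': c})
        })

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='./estimator_ckpt')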

Deploying and querying the model with TF Serving

  • Deploy the server:

    • Build an image from a Dockerfile:
      FROM ubuntu:18.04

      # Install general packages
      RUN apt-get update && apt-get install -y \
          curl \
          libcurl3-dev \
          unzip \
          wget \
          && \
          apt-get clean && \
          rm -rf /var/lib/apt/lists/*

      # Previous installation of tensorflow-model-server (BROKEN RECENTLY)
      #RUN echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list \
      #    && curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add - \
      #    && apt-get update && apt-get install tensorflow-model-server

      # New installation of tensorflow-model-server
      RUN TEMP_DEB="$(mktemp)" \
          && wget -O "$TEMP_DEB" 'http://storage.googleapis.com/tensorflow-serving-apt/pool/tensorflow-model-server-1.12.0/t/tensorflow-model-server/tensorflow-model-server_1.12.0_all.deb' \
          && dpkg -i "$TEMP_DEB" \
          && rm -f "$TEMP_DEB"

      # gRPC port
      EXPOSE 8500
      # REST API port
      EXPOSE 8501

      # Serve the model when the container starts
      CMD tensorflow_model_server \
          --port=8500 \
          --rest_api_port=8501 \
          --model_name="$MODEL_NAME" \
          --model_base_path="$MODEL_PATH"

    Build the image with:

    docker build --rm -f Dockerfile -t tensorflow-serving-example:0.1 .

    • Create a version directory and copy the SavedModel into it (TF Serving expects numbered version subdirectories):

      mkdir -p ./saved_model/test/1
      cp -R ./saved_model/pb/1/* ./saved_model/test/1
    • Start the container:

      docker run --rm -it -v /home/xxx/tf_serving/saved_model/:/models -e MODEL_NAME=test -e MODEL_PATH=/models/test -p 8500:8500 -p 8501:8501 --name tensorflow-serving-example tensorflow-serving-example:0.1

      The server is now running; see the REST check and the gRPC client below to test it. Note that the gRPC client feeds a serialized tf.Example through the 'examples' input, which matches an estimator-exported model; for the plain a/b SavedModel above you would feed the 'a' and 'b' tensors directly.
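
Since port 8501 is exposed, you can also do a quick check over the REST API. A sketch, assuming the served version is the plain a/b SavedModel (the "inputs" key maps each signature input to a tensor value):

import json
import requests

# POST to TF Serving's REST predict endpoint for model "test"
resp = requests.post(
    'http://localhost:8501/v1/models/test:predict',
    data=json.dumps({'inputs': {'a': [[1, 2], [2, 3]],
                                'b': [[2, 3], [4, 5]]}}))
print(resp.json())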

  • Client:

import argparse
import time

import grpc
import tensorflow as tf
from tensorflow.contrib.util import make_tensor_proto
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


def run(host, port, model, signature_name):
    channel = grpc.insecure_channel('{host}:{port}'.format(host=host, port=port))
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    start = time.time()

    # Build the prediction request
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model
    request.model_spec.signature_name = signature_name

    # FloatList takes a flat list of floats; the [2, 2] shape is recovered
    # server-side by the FixedLenFeature spec in serving_input_receiver_fn
    feature = {
        'a': tf.train.Feature(float_list=tf.train.FloatList(value=[1, 2, 3, 4])),
        'b': tf.train.Feature(float_list=tf.train.FloatList(value=[2, 3, 4, 5])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    request.inputs['examples'].CopyFrom(
        make_tensor_proto([example.SerializeToString()], shape=[1]))
    result = stub.Predict(request, 10.0)  # 10 secs timeout

    time_diff = time.time() - start

    # How to access nested values:
    # https://stackoverflow.com/questions/44785847/how-to-retrieve-float-val-from-a-predictresponse-object
    # The output key must match the exported SignatureDef ('predict' here)
    result = result.outputs['predict'].float_val
    print(result)
    print('predict shape {}'.format(len(result)))
    print('time elapsed: {}'.format(time_diff))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', help='Tensorflow server host name', default='10.8.8.71', type=str)
    parser.add_argument('--port', help='Tensorflow server port number', default=8500, type=int)
    parser.add_argument('--model', help='model name', default='test', type=str)
    parser.add_argument('--signature_name', help='Signature name of saved TF model',
                        default='serving_default', type=str)

    args = parser.parse_args()
    run(args.host, args.port, args.model, args.signature_name)

Freeze pb

When the variables no longer need to change and all you want is a fully constant model, you can freeze the graph into a pb. This suits deployment from other languages; besides turning variables into constants, freezing also strips any nodes that are not needed to compute the specified outputs.

Freezing

import tensorflow as tf

input_checkpoint = './saved_model'
output_graph = './saved_model/pb/4/saved_model.pb'
# Names of the output nodes; they must exist in the original graph
output_node_names = "Add_1"

graph = tf.Graph()
with graph.as_default() as g:
    a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
    b = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='b')

    w = tf.get_variable(name='w', shape=[2, 2], initializer=tf.ones_initializer())
    c = tf.add(tf.add(a, b), w)
    # Define saver & load checkpoint
    saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        ckpt = tf.train.get_checkpoint_state(input_checkpoint)
        if ckpt and ckpt.model_checkpoint_path:
            print('restore True...')
            saver.restore(sess, ckpt.model_checkpoint_path)  # restore variable values

        input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph
        # Fold the variable values into constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # same as sess.graph_def
            output_node_names=output_node_names.split(","))  # comma-separated if several outputs

        tf.gfile.MakeDirs('./saved_model/pb/4')  # make sure the output directory exists
        with tf.gfile.GFile(output_graph, "wb") as f:  # write the frozen model
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

After running this, saved_model.pb appears under ./saved_model/pb/4/.

Loading the frozen pb

import tensorflow as tf

model_path = './saved_model/pb/4/saved_model.pb'
with tf.Session() as sess:
    print("load graph")
    with tf.gfile.FastGFile(model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import once, mapping the frozen graph into the default graph with no
    # name prefix, and fetch the tensors we need
    a, b, c = tf.import_graph_def(graph_def,
                                  return_elements=['a:0', 'b:0', 'Add_1:0'],
                                  name='')

    feed_dict_testing = {a: [[1, 2], [3, 4]],
                         b: [[2, 3], [4, 5]]}

    result = sess.run(c, feed_dict=feed_dict_testing)
    print(result)

Those are the common ways to save TF models, together with the corresponding ways to load them.

References:

  1. TensorFlow model archiving, saving, freezing, and optimization
  2. Introduction to RESTful API with Tensorflow Serving
  3. tensorflow-serving-example
  4. TensorFlow official documentation