前言
Tensorflow Serving是专门为生产环境设计的一套系统,近正好有时间就尝试去使用了一下,我只是简单的部署了一下,不会涉及到用法,在这里简单做一个记录。
正文
一、安装docker,搭建环境
我简单介绍一下docker,docker其实是一种虚拟化方式,类似于我们使用过的虚拟机,但是相比于虚拟机又更加轻便,灵活,具体区别请看图1。大家把自己配置好的一个系统环境提交到docker hub中,我们可以直接去把镜像pull下来使用,我们运行pull下来的镜像就会生成一个容器,这个容器就是一个操作系统,里边有别人配置好的东西,我们自己也可以去添加部署项目需要的东西,配置好之后可以commit,别人也可以去使用这个镜像。
下载docker就直接去官网下载吧,docker官网。
因为我们是使用TensorFlow Serving,所以搜索配置好的镜像进行下载。使用命令
docker search serving
如图2所示,可以搜索到很多镜像,我们选择个,这个是官方出品的。还有一点就是镜像的版本,我使用的是1.12.0-devel版本。
docker pull tensorflow/serving:1.12.0-devel
下载好之后运行这个镜像,生成一个容器。端口这里我简单说一下,-p后边8888是我本地的端口,9999的容器内的端口,这个参数的意思就是把我本地的8888端口映射到容器中的9999端口中去。其他具体参数大家参考[2]了解。
docker run -it -p 8888:9999 tensorflow/serving bash
到此我们的环境就算是搭建好了,非常简单,就像从github拉取了一个代码一样,直接可以跑起来。
二、将训练好的模型导成TensorFlow Serving格式,启动服务
有两种方式,一种是训练的过程中直接将模型导出成TensorFlow Serving所需格式,另外一种是将图恢复再导出,我这里使用的是第二种,直接上代码。
import tensorflow as tf
import textCNN.cnn_model_modify_minshi as cnn_model
model_path = './checkpoints/paragraph_minshi/best_validation'
model_version = '3'
export_model_dir = './serving_model/' + model_version
config = cnn_model.TCNNConfig_minshi()
model = cnn_model.TextCNN_minshi(config)
with tf.get_default_graph().as_default():
saver = tf.train.Saver(tf.global_variables())
# 定义你的输入输出以及计算图
x = model.input_x
keep_prob = model.keep_prob
y_pred_cls = model.y_pred_cls
# 导入你已经训练好的模型.ckpt文件
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
print('Restore from {}'.format(model_path))
saver.restore(sess, model_path)
# 定义导出模型的各项参数
# 定义导出地址
print('Exporting trained model to', export_model_dir)
builder = tf.saved_model.builder.SavedModelBuilder(export_model_dir)
# 定义Input tensor info
inputs = {
'x': tf.saved_model.utils.build_tensor_info(x),
'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)
}
# 定义Output tensor info
tensor_info_output = tf.saved_model.utils.build_tensor_info(y_pred_cls)
# 创建预测签名
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs={'predict': tensor_info_output},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'predict_cls': prediction_signature})
# 导出模型
builder.save()
print('Done exporting!')
将导出的模型导入到运行的容器中
docker cp 模型路径 容器ID:路径
在容器中启动服务,我们服务的端口就是8500,调用的时候记得这个端口
tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=saved_model --model_base_path=/tensorflow-serving/checkpoint/serving_model
三、编写服务端部署代码
接下来我们就要编写服务端代码,调用启动的模型服务,使用web框架将我们部署好的服务暴露给外边。直接上代码。
import json
import os
import sys
import time
import grpc
import tensorflow as tf
currentUrl = os.path.dirname(__file__)
parentUrl = os.path.abspath(os.path.join(currentUrl, os.pardir))
sys.path.append(parentUrl)
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import tensorflow.contrib.keras as kr
from textCNN.data.cnews_loader_minshi import read_category_minshi, read_vocab_minshi
from textCNN.cnn_model_modify_minshi import TCNNConfig_minshi, TextCNN_minshi
from flask import Flask, request
base_dir_minshi = '../textCNN/data/paragraph_minshi'
vocab_dir_minshi = os.path.join(base_dir_minshi, 'vocab.txt')
app = Flask(__name__)
words, word_to_id = read_vocab_minshi(vocab_dir_minshi)
categories, cat_to_id = read_category_minshi()
config = TCNNConfig_minshi()
config.vocab_size = len(words)
model = TextCNN_minshi(config)
def request_server(x, keep_prob):
# Request.
channel = grpc.insecure_channel('0.0.0.0:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = "saved_model" # 模型名称
request.model_spec.signature_name = "predict_cls" # 签名名称
# "x"是你导出模型时设置的输入名称
request.inputs["x"].CopyFrom(
tf.contrib.util.make_tensor_proto(x))
request.inputs["keep_prob"].CopyFrom(tf.contrib.util.make_tensor_proto(keep_prob))
response = stub.Predict(request, 20.0) # 20 secs timeout
return tf.contrib.util.make_ndarray(response.outputs["predict"])[]
@app.route('/get_type_minshi', methods=['POST'])
def getType():
param = json.loads(request.get_data().decode('utf-8'))
content = param['content']
data = [word_to_id[x] for x in content if x in word_to_id]
x = kr.preprocessing.sequence.pad_sequences([data], config.seq_length)
type_id = request_server(x, 1.0)
data = {'type': categories[type_id]}
data = json.dumps(data)
return data
if __name__ == '__main__':
app.run(host='0.0.0.0', port=9999)
我这里使用flask将程序封装成HTTP接口,看后一行可以看到,我这里指定的端口是9999,这正好对应好了我启动容器时候的端口,服务启动成功之后,我在本地调用8888端口就直接映射到容器内的9999端口,也就能访问到内部的程序了。
这篇文章写得比较仓促,比较乱,主要目的还是记录一下自己使用的过程,如果能帮助到大家就更好了,接下来如果在真正项目中的使用的话,我会重新将这篇文章写一遍。