TensorFlow:如何解决 Estimator.predict 每次调用都重新加载模型的问题
来源:锐游网
问题:
每次调用 Estimator.predict 都会把模型重新加载一遍,这在实际工程业务中根本无法使用。
解决方案:
代码
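以下代码针对 BERT 模型,具体功能请根据需要自行修改。核心思路是:只调用一次 estimator.predict 并保存它返回的生成器;输入数据由 tf.data.Dataset.from_generator 包装的 Python 生成器提供,该生成器每次被读取时都会取当前待预测的文本,因此之后每次预测只需更新文本并对保存的生成器调用 next(),无需重新构图和加载模型。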
from tokenization import FullTokenizer, validate_case_matches_checkpoint
from conv_example import convert_single_example
from process import InputExample
from modeling import BertConfig
from model_func import model_fn_builder
from config import FLAGS, TFConfig
import tensorflow as tf
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.run_config import RunConfig
class Fast(object):
    """Reuse a single ``Estimator.predict()`` generator so the model loads once.

    Calling ``Estimator.predict`` anew for every request rebuilds the model
    each time. This class calls it only on the first :meth:`predict` and keeps
    the returned generator in ``self.predictions``. The input pipeline is a
    ``tf.data.Dataset.from_generator`` backed by :meth:`create_generator`, which
    re-reads ``self.text`` every time TensorFlow pulls a new element. Each later
    :meth:`predict` therefore only swaps ``self.text`` and calls ``next()`` on
    the stored generator.

    Call :meth:`close` when finished. It stops the input generator and closes
    the stored prediction generator instead of leaving it for the garbage
    collector.
    """

    def __init__(self, label):
        """Build the tokenizer and the Estimator.

        Args:
            label: Sequence of class labels. ``len(label)`` is the number of
                output classes, and predicted class indices are mapped back
                through it.
        """
        self.label = label
        self.closed = False  # once True, create_generator stops yielding
        self.first_run = True  # True until estimator.predict() has been called
        self.tokenizer = FullTokenizer(
            vocab_file=FLAGS.vocab_file,
            do_lower_case=True)
        self.init_checkpoint = FLAGS.init_checkpoint
        self.seq_length = FLAGS.max_seq_length
        self.text = None  # texts of the current predict() call, read by create_generator
        self.num_examples = None  # number of texts in the most recent batch
        self.predictions = None  # generator returned by estimator.predict()
        self.estimator = self.get_estimator()

    def get_estimator(self):
        """Create the Estimator used for prediction.

        Returns:
            An ``Estimator`` wrapping the model_fn built by ``model_fn_builder``.

        Raises:
            ValueError: if ``self.seq_length`` exceeds the BERT config's
                ``max_position_embeddings``.
        """
        validate_case_matches_checkpoint(True, self.init_checkpoint)
        bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)  # load the BERT model config
        # Sequences longer than the trained position embeddings cannot be encoded.
        if self.seq_length > bert_config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the BERT pre_model "
                "was only trained up to sequence length %d" %
                (self.seq_length, bert_config.max_position_embeddings))
        run_config = RunConfig(
            model_dir=FLAGS.output_dir,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            session_config=TFConfig.cpu()
        )
        # model_fn factory: builds the model_fn the Estimator uses
        # (per the original author, it returns an EstimatorSpec internally).
        model_fn = model_fn_builder(
            bert_config=bert_config,
            num_labels=len(self.label),
            init_checkpoint=self.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=None,
            num_warmup_steps=None,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)
        estimator = Estimator(
            model_fn=model_fn,
            config=run_config,
            # NOTE(review): warm_start_from is documented for training; confirm it
            # has any effect on predict(). Weights are presumably restored via
            # init_checkpoint inside model_fn.
            warm_start_from=self.init_checkpoint
        )
        return estimator

    def get_feature(self, index, text):
        """Convert one text into model input features.

        ``self.label[0]`` is passed as a placeholder gold label because the
        true label is not known at prediction time.

        Args:
            index: Position of ``text`` in the batch, used to build a unique id.
            text: The raw text to encode.

        Returns:
            Tuple ``(input_ids, input_mask, segment_ids, label_id)``.
        """
        example = InputExample(f"text_{index}", text, None, self.label[0])
        feature = convert_single_example(index, example, self.label, self.seq_length, self.tokenizer)
        return feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id

    def create_generator(self):
        """Yield the current ``self.text`` as one batch per step until closed.

        Every step re-reads ``self.text``, so one dataset can serve every
        predict() call. Each yielded dict maps a feature name to a tuple that
        holds that feature for every text in the batch.
        """
        while not self.closed:
            self.num_examples = len(self.text)
            features = (self.get_feature(*f) for f in enumerate(self.text))
            yield dict(zip(("input_ids", "input_mask", "segment_ids", "label_ids"), zip(*features)))

    def input_fn_builder(self):
        """Build the prediction dataset from in-memory texts instead of files."""
        dataset = tf.data.Dataset.from_generator(
            self.create_generator,
            output_types={'input_ids': tf.int32,
                          'input_mask': tf.int32,
                          'segment_ids': tf.int32,
                          'label_ids': tf.int32},
            output_shapes={
                # (None,) is a 1-D shape of unknown length; plain (None) is
                # just None, i.e. an unknown rank.
                'label_ids': (None,),
                'input_ids': (None, None),
                'input_mask': (None, None),
                'segment_ids': (None, None)}
        )
        return dataset

    def predict(self, text):
        """Classify a batch of texts, reusing the already-loaded model.

        Args:
            text: Sized sequence of strings to classify.

        Returns:
            List with one label from ``self.label`` per input text, or ``[]``
            for an empty input.

        Raises:
            RuntimeError: if called after :meth:`close`.
        """
        if self.closed:
            raise RuntimeError("Fast predictor has been closed")
        if len(text) == 0:
            # An empty batch would yield an empty feature dict that does not
            # match the dataset's declared output_types; skip the model.
            return []
        self.text = text
        if self.first_run:
            # Created once; later calls just pull the next batch from it.
            self.predictions = self.estimator.predict(
                input_fn=self.input_fn_builder, yield_single_examples=False)
            self.first_run = False
        probabilities = next(self.predictions)
        # Assumes the model_fn's predictions include "probabilities" shaped
        # (batch, num_labels) -- argmax over axis 1 picks each row's class.
        return [self.label[i] for i in probabilities["probabilities"].argmax(axis=1)]

    def close(self):
        """Stop the input generator and close the stored prediction generator.

        Safe to call more than once.
        """
        self.closed = True
        if self.predictions is not None:
            # Close the Estimator.predict() generator now so its cleanup runs
            # immediately instead of whenever it is garbage-collected.
            self.predictions.close()
            self.predictions = None
因篇幅问题不能全部显示,请点此查看更多更全内容