您的当前位置：首页 > TensorFlow > 关于怎样解决 Estimator.predict 总是重新加载模型的问题

TensorFlow：如何解决 Estimator.predict 每次调用都重新加载模型的问题

来源:锐游网

问题:

使用 Estimator.predict 时，每次调用都会重新构建计算图、创建会话并从 checkpoint 恢复模型参数，也就是把模型重新加载一遍。这样单次预测的延迟非常高，在工程业务中根本无法使用。

解决方案:

代码

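下面的代码针对 BERT 模型，具体功能请根据自己的需求修改。核心思路是：只调用一次 estimator.predict，并用一个 Python 生成器作为输入，该生成器每次都从 self.text 读取当前要预测的文本；之后每次预测只需更新 self.text，再对 estimator.predict 返回的生成器调用 next()，这样模型只在第一次预测时加载，之后一直复用同一个会话。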

from tokenization import FullTokenizer, validate_case_matches_checkpoint
from conv_example import convert_single_example
from process import InputExample
from modeling import BertConfig
from model_func import model_fn_builder
from config import FLAGS, TFConfig
import tensorflow as tf
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.run_config import RunConfig


class Fast(object):
    """Reuse one ``Estimator.predict`` session across many prediction calls.

    ``estimator.predict`` is called only once. It is fed by a Python generator
    (``create_generator``) that reads whatever batch is currently stored in
    ``self.text``. Each ``predict`` call stores a new batch in ``self.text``
    and pulls exactly one result from the long-lived prediction generator, so
    the model is loaded on the first call only.

    NOTE(review): calls communicate through shared mutable state
    (``self.text``), so one instance should not be used from several threads
    at the same time.
    """

    def __init__(self, label):
        """Build the tokenizer and the Estimator.

        Args:
            label: Sequence of class labels. Output index ``i`` of the model is
                mapped to ``label[i]``. ``label[0]`` is also used as the
                placeholder label attached to every input example.
        """
        self.label = label
        self.closed = False  # set by close(); makes create_generator stop
        self.first_run = True  # the prediction generator is created lazily
        self.tokenizer = FullTokenizer(
            vocab_file=FLAGS.vocab_file,
            do_lower_case=True)
        self.init_checkpoint = FLAGS.init_checkpoint
        self.seq_length = FLAGS.max_seq_length
        self.text = None  # batch of texts consumed by the next generator step
        self.num_examples = None  # size of the most recently generated batch
        self.predictions = None  # long-lived generator from estimator.predict
        self.estimator = self.get_estimator()

    def get_estimator(self):
        """Create the BERT classification Estimator used for prediction.

        Returns:
            An ``Estimator`` configured from ``FLAGS``, warm-started from
            ``self.init_checkpoint``.

        Raises:
            ValueError: if ``self.seq_length`` exceeds the BERT config's
                ``max_position_embeddings``.
        """
        # Sanity check (from the tokenization module) that lower-casing,
        # which the tokenizer also uses, fits the checkpoint being loaded.
        validate_case_matches_checkpoint(True, self.init_checkpoint)
        bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)  # load the BERT model config
        # Inputs longer than the position-embedding table cannot be encoded.
        if self.seq_length > bert_config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the BERT pre_model "
                "was only trained up to sequence length %d" %
                (self.seq_length, bert_config.max_position_embeddings))

        run_config = RunConfig(
            model_dir=FLAGS.output_dir,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            session_config=TFConfig.cpu()
        )
        # model_fn for the Estimator (built by model_fn_builder); training
        # step counts are irrelevant for prediction, hence None.
        model_fn = model_fn_builder(
            bert_config=bert_config,
            num_labels=len(self.label),
            init_checkpoint=self.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=None,
            num_warmup_steps=None,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)

        estimator = Estimator(
            model_fn=model_fn,
            config=run_config,
            warm_start_from=self.init_checkpoint  # initialise weights from the checkpoint
        )
        return estimator

    def get_feature(self, index, text):
        """Convert one text into model input features.

        Args:
            index: Position of the text in the batch (used for the example id).
            text: The raw input string.

        Returns:
            Tuple ``(input_ids, input_mask, segment_ids, label_id)``.
        """
        # The label is unknown at prediction time; label[0] is a placeholder.
        example = InputExample(f"text_{index}", text, None, self.label[0])
        feature = convert_single_example(index, example, self.label, self.seq_length, self.tokenizer)
        return feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id

    def create_generator(self):
        """Yield one feature batch built from ``self.text`` per request.

        Runs until ``close()`` sets ``self.closed``. Each yielded dict maps
        every feature name to a tuple holding that feature for all texts.
        """
        while not self.closed:
            # Re-read self.text on every step: predict() swaps it per call.
            self.num_examples = len(self.text)
            features = (self.get_feature(*f) for f in enumerate(self.text))
            # Transpose per-example tuples into per-feature columns.
            yield dict(zip(("input_ids", "input_mask", "segment_ids", "label_ids"), zip(*features)))

    def input_fn_builder(self):
        """Build the prediction Dataset on top of ``create_generator``.

        Data comes from memory (``self.text``), not from files.
        """
        dataset = tf.data.Dataset.from_generator(
            self.create_generator,
            output_types={'input_ids': tf.int32,
                          'input_mask': tf.int32,
                          'segment_ids': tf.int32,
                          'label_ids': tf.int32},
            output_shapes={
                # (None,) is a 1-D batch; a bare (None) would be plain None,
                # i.e. "unknown rank".
                'label_ids': (None,),
                'input_ids': (None, None),
                'input_mask': (None, None),
                'segment_ids': (None, None)}
        )
        return dataset

    def predict(self, text):
        """Classify a batch of texts.

        Args:
            text: Iterable of input strings forming one batch.

        Returns:
            A list holding one predicted label (an element of ``self.label``)
            per input text, or an empty list for an empty batch.
        """
        # Materialise the batch: create_generator calls len() on it and only
        # runs when next() below pulls from the pipeline.
        self.text = list(text)
        if not self.text:
            # An empty batch would make create_generator yield {}, which does
            # not match the declared output structure. The resulting TF error
            # would end the shared prediction generator for all later calls.
            return []
        if self.first_run:
            # Create the prediction generator once. The model is loaded on the
            # first next() and the session is reused for every later call.
            self.predictions = self.estimator.predict(
                input_fn=self.input_fn_builder, yield_single_examples=False)
            self.first_run = False
        probabilities = next(self.predictions)
        return [self.label[i] for i in probabilities["probabilities"].argmax(axis=1)]

    def close(self):
        """Stop the input generator and release the prediction generator."""
        self.closed = True
        if self.predictions is not None:
            # generator.close() raises GeneratorExit inside estimator.predict,
            # which runs its cleanup (including its session context) now. The
            # alternative is waiting for another next() that would notice
            # ``closed`` and end the input stream.
            self.predictions.close()

因篇幅问题不能全部显示,请点此查看更多更全内容

Top