使用BERT回归的代码

栏目: R语言 · 发布时间: 6年前

内容简介:我们在做句子相似度计算的时候需要的输出是一个0到1之间的实数值,用来表示句子的相似程度。BERT默认只提供了run_classifier.py,它可以用于Fine-Tuning文本分类、相似度分类、Entailment等任务。但是无法实现实数值的输出,因此我参照run_classifier.py实现了一个run_reg.py。需要的读者可以我如果读者是其它的版本,也可以去

我们在做句子相似度计算的时候需要的输出是一个0到1之间的实数值,用来表示句子的相似程度。BERT默认只提供了run_classifier.py,它可以用于Fine-Tuning文本分类、相似度分类、Entailment等任务。但是无法实现实数值的输出,因此我参照run_classifier.py实现了一个run_reg.py。

代码

需要的读者可以我 fork的版本 去clone代码,这个fork的版本是最新(昨天pull过)的,log为:

commit ffbda2a1aafe530525212d13194cc84d92ed0313
Merge: 0a0ea64 3084a39
Author: Jacob Devlin <44483550+jacobdevlin-google@users.noreply.github.com>
Date:   Tue Feb 12 14:17:23 2019 -0800

如果读者是其它的版本,也可以去 这个pr ,把它合并到自己的版本了。不了解怎么把pr合并到本地的读者可以参考 这里

用法

用法和run_classifier.py很像,它要求的输入是如下格式:

手机号码注销了,怎么换手机号吗?	如何修改手机号	1
支付宝怎么充值	微信怎么充值	0.5

也就是使用TAB分割的文件,每行三列,前两列是两个句子,最后一列是一个实数值。

使用的示例代码为:

python run_reg.py \
    --task_name=sim \
    --do_train=true \
    --do_eval=true \
    --data_dir=/path/to/your/data \
    --vocab_file=$BERT_BASE_DIR/vocab.txt \
    --bert_config_file=$BERT_BASE_DIR/bert_config.json \
    --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
    --max_seq_length=64 \
    --train_batch_size=8 \
    --learning_rate=5e-5 \
    --num_train_epochs=2 \
    --output_dir=/tmp/sim/

需要指定task_name为sim,如果回归的输出范围不是0和1之间,那么需要设置--use_sigmoid_act=false。

如果读者的输入格式不同,也可以自己参考SimProcessor类实现自己的数据处理。

实现细节

代码基本是拷贝的run_classifier.py,然后把输入从label_ids(int32)变成了vals(float32),把loss从交叉熵改成了MSE,同时修改了相应的Metrics。不感兴趣的读者可以跳过,并不影响使用。

create_model函数的修改

原来的代码:

with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    probabilities = tf.nn.softmax(logits, axis=-1)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits, probabilities)

新的代码:

with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    if FLAGS.use_sigmoid_act:
      output = tf.nn.sigmoid(logits)
    else:
      output = logits

    output = tf.squeeze(output, [1])
    loss = tf.losses.mean_squared_error(vals, output)

    return (loss, output)

file_based_convert_examples_to_features的修改

原代码:

features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["label_ids"] = create_int_feature([feature.label_id])
    features["is_real_example"] = create_int_feature(
        [int(feature.is_real_example)])

修改后的:

def create_float_feature(vals):
      return tf.train.Feature(float_list=tf.train.FloatList(value=vals))

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["vals"] = create_float_feature([feature.val])

model_fn_builder的修改

原来是分类,因此有eval_accuracy和eval_loss等指标,现在是回归,因此只有eval_loss这一个指标。

原代码:

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metrics = (metric_fn,
                      [per_example_loss, label_ids, logits, is_real_example])

修改后的:

def metric_fn(preds, vals):
        return {
            "eval_loss": tf.metrics.mean_squared_error(vals, preds),
        }

      eval_metrics = (metric_fn, [pred_vals, vals])

file_based_input_fn_builder

读取TFRecord文件时也有小的修改。从

name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.FixedLenFeature([], tf.int64),
      "is_real_example": tf.FixedLenFeature([], tf.int64),
  }

变成

name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "vals": tf.FixedLenFeature([], tf.float32),
  }

以上所述就是小编给大家介绍的《使用BERT回归的代码》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Big Java Late Objects

Big Java Late Objects

Horstmann, Cay S. / 2012-2 / 896.00元

The introductory programming course is difficult. Many students fail to succeed or have trouble in the course because they don't understand the material and do not practice programming sufficiently. ......一起来看看 《Big Java Late Objects》 这本书的介绍吧!

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

html转js在线工具
html转js在线工具

html转js在线工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具