使用BERT回归的代码

栏目: R语言 · 发布时间: 6年前

内容简介:我们在做句子相似度计算的时候需要的输出是一个0到1之间的实数值,用来表示句子的相似程度。BERT默认只提供了run_classifier.py,它可以用于Fine-Tuning文本分类、相似度分类、Entailment等任务。但是无法实现实数值的输出,因此我参照run_classifier.py实现了一个run_reg.py。需要的读者可以我如果读者是其它的版本,也可以去

我们在做句子相似度计算的时候需要的输出是一个0到1之间的实数值,用来表示句子的相似程度。BERT默认只提供了run_classifier.py,它可以用于Fine-Tuning文本分类、相似度分类、Entailment等任务。但是无法实现实数值的输出,因此我参照run_classifier.py实现了一个run_reg.py。

代码

需要的读者可以我 fork的版本 去clone代码,这个fork的版本是最新(昨天pull过)的,log为:

commit ffbda2a1aafe530525212d13194cc84d92ed0313
Merge: 0a0ea64 3084a39
Author: Jacob Devlin <44483550+jacobdevlin-google@users.noreply.github.com>
Date:   Tue Feb 12 14:17:23 2019 -0800

如果读者是其它的版本,也可以去 这个pr ,把它合并到自己的版本了。不了解怎么把pr合并到本地的读者可以参考 这里

用法

用法和run_classifier.py很像,它要求的输入是如下格式:

手机号码注销了,怎么换手机号吗?	如何修改手机号	1
支付宝怎么充值	微信怎么充值	0.5

也就是使用TAB分割的文件,每行三列,前两列是两个句子,最后一列是一个实数值。

使用的示例代码为:

python run_reg.py \
    --task_name=sim \
    --do_train=true \
    --do_eval=true \
    --data_dir=/path/to/your/data \
    --vocab_file=$BERT_BASE_DIR/vocab.txt \
    --bert_config_file=$BERT_BASE_DIR/bert_config.json \
    --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
    --max_seq_length=64 \
    --train_batch_size=8 \
    --learning_rate=5e-5 \
    --num_train_epochs=2 \
    --output_dir=/tmp/sim/

需要指定task_name为sim,如果回归的输出范围不是0和1之间,那么需要设置--use_sigmoid_act=false。

如果读者的输入格式不同,也可以自己参考SimProcessor类实现自己的数据处理。

实现细节

代码基本是拷贝的run_classifier.py,然后把输入从label_ids(int32)变成了vals(float32),把loss从交叉熵改成了MSE,同时修改了相应的Metrics。不感兴趣的读者可以跳过,并不影响使用。

create_model函数的修改

原来的代码:

with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    probabilities = tf.nn.softmax(logits, axis=-1)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits, probabilities)

新的代码:

with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    if FLAGS.use_sigmoid_act:
      output = tf.nn.sigmoid(logits)
    else:
      output = logits

    output = tf.squeeze(output, [1])
    loss = tf.losses.mean_squared_error(vals, output)

    return (loss, output)

file_based_convert_examples_to_features的修改

原代码:

features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["label_ids"] = create_int_feature([feature.label_id])
    features["is_real_example"] = create_int_feature(
        [int(feature.is_real_example)])

修改后的:

def create_float_feature(vals):
      return tf.train.Feature(float_list=tf.train.FloatList(value=vals))

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["vals"] = create_float_feature([feature.val])

model_fn_builder的修改

原来是分类,因此有eval_accuracy和eval_loss等指标,现在是回归,因此只有eval_loss这一个指标。

原代码:

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metrics = (metric_fn,
                      [per_example_loss, label_ids, logits, is_real_example])

修改后的:

def metric_fn(preds, vals):
        return {
            "eval_loss": tf.metrics.mean_squared_error(vals, preds),
        }

      eval_metrics = (metric_fn, [pred_vals, vals])

file_based_input_fn_builder

读取TFRecord文件时也有小的修改。从

name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.FixedLenFeature([], tf.int64),
      "is_real_example": tf.FixedLenFeature([], tf.int64),
  }

变成

name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "vals": tf.FixedLenFeature([], tf.float32),
  }

以上所述就是小编给大家介绍的《使用BERT回归的代码》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

云攻略

云攻略

马克·贝尼奥夫、卡莱尔·阿德勒 / 徐杰 / 海天出版社 / 2010年8月 / 36.00元

Apple、Google、甲骨文、腾讯 都已投入了云的怀抱, 你还在等什么? 快来加入我们! 最初,Salesforce.com 只是一间小小的租赁公寓 在短短10年内 它已成长为 世界上发展最快、最具创新力的 产业变革领导者 曾经,这是个软件为王的时代。 现在,这是个云计算的新时代。 NO SOFTWARE 抛弃软件的......一起来看看 《云攻略》 这本书的介绍吧!

html转js在线工具
html转js在线工具

html转js在线工具

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具