【自然语言处理】手撕 FastText 源码(01)分类器的预测过程

栏目: 编程工具 · 发布时间: 5年前

内容简介:Last updated on 2019年7月1日作者:LogM本文原载于

Last updated on 2019年7月1日

作者:LogM

本文原载于 https://segmentfault.com/u/logm/articles

,不允许转载~

1. 源码来源

FastText 源码: https://github.com/facebookre…

本文对应的源码版本:Commits on Jun 27 2019, 979d8a9ac99c731d653843890c2364ade0f7d9d3

FastText 论文:

[1] P. Bojanowski, E. Grave, A. Joulin, T. Mikolov,

Enriching Word Vectors with Subword Information

[2] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov,

Bag of Tricks for Efficient Text Classification

2. 概述

FastText 的论文写的比较简单,有些细节不明白,网上也查不到,所幸直接撕源码。

FastText 的”分类器”功能是用的最多的,所以先从”分类器的predict”开始挖。

3. 开撕

先看程序入口的 main
函数,ok,是调用了 predict
函数。

// 文件:src/main.cc
// 行数:403
int main(int argc, char** argv) {
  std::vector<std::string> args(argv, argv + argc);
  if (args.size() < 2) {
    printUsage();
    exit(EXIT_FAILURE);
  }
  std::string command(args[1]);
  if (command == "skipgram" || command == "cbow" || command == "supervised") {
    train(args);           
  } else if (command == "test" || command == "test-label") {
    test(args);
  } else if (command == "quantize") {
    quantize(args);
  } else if (command == "print-word-vectors") {
    printWordVectors(args);
  } else if (command == "print-sentence-vectors") {
    printSentenceVectors(args);
  } else if (command == "print-ngrams") {
    printNgrams(args);
  } else if (command == "nn") {
    nn(args);
  } else if (command == "analogies") {
    analogies(args);
  } else if (command == "predict" || command == "predict-prob") {
    predict(args);       // 这句是我们想要的
  } else if (command == "dump") {
    dump(args);
  } else {
    printUsage();
    exit(EXIT_FAILURE);
  }
  return 0;
}

再看 predict
函数,预处理的代码不用管,直接看 predict 的那行,调用了 FastText::predictLine
。这里注意下,这是个 while
循环,所以

FastText::predictLine
这个函数每次只处理一行。

// 文件:src/main.cc
// 行数:205
void predict(const std::vector<std::string>& args) {
  if (args.size() < 4 || args.size() > 6) {
    printPredictUsage();
    exit(EXIT_FAILURE);
  }
  int32_t k = 1;
  real threshold = 0.0;
  if (args.size() > 4) {
    k = std::stoi(args[4]);
    if (args.size() == 6) {
      threshold = std::stof(args[5]);
    }
  }
  bool printProb = args[1] == "predict-prob";
  FastText fasttext;
  fasttext.loadModel(std::string(args[2]));
  std::ifstream ifs;
  std::string infile(args[3]);
  bool inputIsStdIn = infile == "-";
  if (!inputIsStdIn) {
    ifs.open(infile);
    if (!inputIsStdIn && !ifs.is_open()) {
      std::cerr << "Input file cannot be opened!" << std::endl;
      exit(EXIT_FAILURE);
    }
  }
  std::istream& in = inputIsStdIn ? std::cin : ifs;
  std::vector<std::pair<real, std::string>> predictions;
  while (fasttext.predictLine(in, predictions, k, threshold)) {     // 这句是重点
    printPredictions(predictions, printProb, false);
  }
  if (ifs.is_open()) {
    ifs.close();
  }
  exit(0);
}

再看 FastText::predictLine
,注意这边有两个重点。

// 文件:src/fasttext.cc
// 行数:451
bool FastText::predictLine(
    std::istream& in,
    std::vector<std::pair<real, std::string>>& predictions,
    int32_t k,
    real threshold) const {
  predictions.clear();
  if (in.peek() == EOF) {
    return false;
  }
  std::vector<int32_t> words, labels;
  dict_->getLine(in, words, labels);                // 这句是第一个重点
  Predictions linePredictions;
  predict(k, words, linePredictions, threshold);    // 这句是第二个重点
  for (const auto& p : linePredictions) {
    predictions.push_back(
        std::make_pair(std::exp(p.first), dict_->getLabel(p.second)));
  }
  return true;
}

先看第一个重点, getLine
函数其实是 Dictionary::getLine
,定义在 src/dictionary.cc

这段代码的干货度还是很高的,里面有两个重点, Dictionary::addSubwords
Dictionary::addWordNgrams
,以后会讲。这边只要知道整个函数把读到的这一行的每个Id(包括词语的id,SubWords的Id,WordNgram的Id),存到了数组 words
中。

// 文件:src/dictionary.cc
// 行数:378
int32_t Dictionary::getLine(
    std::istream& in,
    std::vector<int32_t>& words,
    std::vector<int32_t>& labels) const {
  std::vector<int32_t> word_hashes;
  std::string token;
  int32_t ntokens = 0;
  reset(in);
  words.clear();
  labels.clear();
  while (readWord(in, token)) {     // `token` 是读到的一个词语,如果读到一行的行尾,则返回`EOF`
    uint32_t h = hash(token);       // 找到这个词语位于哪个hash桶
    int32_t wid = getId(token, h);      // 在hash桶中找到这个词语的Id,如果负数就是没找到对应的Id
    entry_type type = wid < 0 ? getType(token) : getType(wid);   // 如果没找到对应Id,则有可能是label,`getType`里会处理
    ntokens++;
    if (type == entry_type::word) {
      addSubwords(words, token, wid);   // 重点1,以后会讲
      word_hashes.push_back(h);
    } else if (type == entry_type::label && wid >= 0) {
      labels.push_back(wid - nwords_);
    }
    if (token == EOS) {
      break;
    }
  }
  addWordNgrams(words, word_hashes, args_->wordNgrams);  // 重点2,以后会讲
  return ntokens;
}

再来看第二个重点, FastText::predict
函数,重点是 Model::predict
函数。

// 文件:src/fasttext.cc
// 行数:437
void FastText::predict(
    int32_t k,
    const std::vector<int32_t>& words,
    Predictions& predictions,
    real threshold) const {
  if (words.empty()) {
    return;
  }
  Model::State state(args_->dim, dict_->nlabels(), 0);
  if (args_->model != model_name::sup) {
    throw std::invalid_argument("Model needs to be supervised for prediction!");
  }
  model_->predict(words, k, threshold, predictions, state);       // 这句是重点
}

来到 Model::predict
,有两个重点.

其中 Loss::predict
是将 hidden 层的输出结果进行 softmax 后得到最终概率最大的k个类别,”分类器的predict” 用的是经典的softmax,所以代码也比较简单。而如果是”分类器的train” 则涉及到 Hierarchical SoftmaxLoss
NegativeSamplingLoss
等一些加速手段,比较复杂,以后有机会再讲。

// 文件:src/model.cc
// 行数:53
void Model::predict(
    const std::vector<int32_t>& input,
    int32_t k,
    real threshold,
    Predictions& heap,
    State& state) const {
  if (k == Model::kUnlimitedPredictions) {
    k = wo_->size(0); // output size
  } else if (k <= 0) {
    throw std::invalid_argument("k needs to be 1 or higher!");
  }
  heap.reserve(k + 1);
  computeHidden(input, state);      // 重点1
  loss_->predict(k, threshold, heap, state);    // 重点2,以后再讲
}

我们再来看另一个重点, Model::computeHidden
函数。

Model::computeHidden
函数理解起来比较简单,注意这里的 input
就是前面的 words
,是一系列id组成的数组(包括词语的id,SubWords的Id,WordNgram的Id),把这些求和,然后取平均。

当然有些小伙伴可能有点疑问, Vector::addRow
为什幺是求和,这个以后再讲吧。

// 文件:src/model.cc
// 行数:43
void Model::computeHidden(const std::vector<int32_t>& input, State& state)
    const {
  Vector& hidden = state.hidden;
  hidden.zero();
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    hidden.addRow(*wi_, *it);           // 求和
  }
  hidden.mul(1.0 / input.size());       // 然后取平均
}

4. 总结

至此,FastText里面的”分类器的predict”的大致流程讲完了,其他的,如”分类器的train”和”词向量”的源码也是类似的方法来阅读。

这里面有几段代码没有详细叙述: Dictionary::addSubwords
Dictionary::addWordNgrams
Vector::addRow
以及 训练时softmax的加速
,先把坑留着,以后有时间再填。


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Data Structures and Algorithms in Java

Data Structures and Algorithms in Java

Robert Lafore / Sams / 2002-11-06 / USD 64.99

Data Structures and Algorithms in Java, Second Edition is designed to be easy to read and understand although the topic itself is complicated. Algorithms are the procedures that software programs use......一起来看看 《Data Structures and Algorithms in Java》 这本书的介绍吧!

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试