【译】如何在每次训练中都得到相同的word2vec/doc2vec/Paragraph Vectors

栏目: 编程工具 · 发布时间: 6年前

内容简介:本文翻译自作者在medium发布的一篇推文,这里是原文链接本文是我会尽力不把各位读者引导到一大堆冗长而又无法让人真正理解的教程中,最后以放弃告终(相信我,我也是网上诸多教程的受害者)。我想我们可以一起

本文翻译自作者在medium发布的一篇推文,这里是原文链接

本文是 Word Embedding 系列的第一篇。本文适合中级以上的读者或者训练过 word2vec/doc2vec/Paragraph Vectors 的读者阅读,但别担心,我将在接下来的推文中介绍理论以及背景知识,并 联系论文讲解代码是如何实现 的。

我会尽力不把各位读者引导到一大堆冗长而又无法让人真正理解的教程中,最后以放弃告终(相信我,我也是网上诸多教程的受害者)。我想我们可以一起 从代码层面来了解word2vec, 这样我们可以知道 如何设计并实现我们自己的word embedding 和language model.

如果您曾经自己训练过word vectors,会发现尽管使用相同的数据进行训练,但每次训练得到的模型和词向量表示都不一样。这是因为在训练过程中引入了随机性所致。让我们一起来从代码中找到这些随机性是如何引入的,以及如何消除这种随机性。我将用DL4j的 Paragraph Vectors实现 来展示代码。如果您想看其他包的实现,可以看gensim的doc2vec,它有相同的实现方法。

随机性从哪里来

模型权重和词向量的初始化

我们知道在训练最初,模型各参数和词向量表示会随机初始化,这里的随机性是由seed控制实现的。因此,当我们把seed设为0,我们在每次训练中会得到完全相同的初始化。 这里 来看seed是如何影响初始化的,syn0是模型权重。

// Nd4j 设置有关生成随机数的seed
Nd4j.getRandom().setSeed(configuration.getSeed());
// Nd4j 为 syn0 初始化一个随机矩阵
syn0 = Nd4j.rand(new int[] {vocab.numWords(), vectorLength}, rng).subi(0.5).divi(vectorLength);复制代码

PV-DBOW 算法

如果我们使用PV-DBOW算法训练Paragraph Vectors,在训练迭代中,单词会从窗口中随机取得并计算、更新模型。但是这里的随机在 代码实现 中并不是真正的随机。

// nextRandom 是一个 AtomicLong,并被threadId初始化
this.nextRandom = new AtomicLong(this.threadId);复制代码

nextRandom在 trainSequence(sequence, nextRandom, alpha); 被用到,在 trainSequence 中, nextRandom.set(nextRandom.get() * 25214903917L + 11); 如果我们更加深入到每个训练的步骤,我们会发现nextRandom产生于相同的步骤及方法,即进行固定的数学运算(到这里和这里了解为什么这样做),所以 nextRandom 是依赖于 threadId 的数字,而 threadId 是0,1,2,3,...所以这里我们实际上不再有随机性。

并行tokenization

因为对文本的处理是一项耗时的工作,所以进行并行tokenization可以提高性能,但训练的一致性将不能得到保证。并行处理下,提供给每个thread进行训练的数据将出现随机性。从 代码 中可以看到,如果我们将 allowParallelBuilder 设为 false ,进行tokenization的 runnable 将阻塞其他thread直到tokenization结束,从而保持输入训练数据的一致性。

if (!allowParallelBuilder) {
    try {
        runnable.awaitDone();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}复制代码

为各个thread提供训练数据的队列

该队列是一个 LinkedBlockingQueue ,这个队列从迭代器中取出训练文本,然后提供给各个线程进行训练。因为各个线程请求数据的时间可以是任意的,所以在每次训练中,每个线程得到的数据也是不一样的。请看这里的 代码具体实现

// 初始化一个 sequencer 来提供数据给每个线程
val sequencer = new AsyncSequencer(this.iterator, this.stopWords);
// 每个线程使用同一个 sequencer
// worker是我们设置的进行训练的线程数
for (int x = 0; x < workers; x++) {
    threads.add(x, new VectorCalculationsThread(x, ..., sequencer);                
    threads.get(x).start();            
}
// 在sequencer中 初始化一个 LinkedBlockingQueue buffer
// 同时保持该buffer的size在[limitLower, limitUpper]
private final LinkedBlockingQueue<Sequence<T>> buffer;
limitLower = workers * batchSize;
limitUpper = workers * batchSize * 2;
// 线程从buffer中读取数据
buffer.poll(3L, TimeUnit.SECONDS);复制代码

所以,如果我们将 worker 设为1,即采用单线程进行训练,那么每次训练我们将得到相同顺序的数据。这里需要注意的是,如果采用单线程,训练的速度将会大幅降低。

总结

为了将随机性排除,我们需要做以下:

  1. seed 设为0;
  2. allowParallelTokenization 设为 false ;
  3. worker 设为1。

这样在使用相同数据训练,我们将会得到完全相同的模型参数和向量表示。

最终,我们的训练代码将会像:

ParagraphVectors vec = new ParagraphVectors.Builder()
                .minWordFrequency(1)
                .labels(labelsArray)
                .layerSize(100)
                .stopWords(new ArrayList<String>())
                .windowSize(5)
                .iterate(iter)
                .allowParallelTokenization(false)
                .workers(1)
                .seed(0)
                .tokenizerFactory(t)
                .build();

vec.fit();复制代码

如果您觉得对上述内容不理解,那么别担心,我将在之后的推文中联系代码和论文,详细解释word embedding以及language model的技术。

参考

  1. Deeplearning4j, ND4J, DataVec and more - deep learning & linear algebra for Java/Scala with GPUs + Spark - From Skymind http://deeplearning4j.org https://github.com/deeplearning4j/deeplearning4j
  2. Java™ Platform, Standard Edition 8 API Specification https://docs.oracle.com/javase/8/docs/api/

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

概率编程实战

概率编程实战

[美]艾维·费弗 (Avi Pfeffer) / 姚军 / 人民邮电出版社 / 2017-4 / 89

概率推理是不确定性条件下做出决策的重要方法,在许多领域都已经得到了广泛的应用。概率编程充分结合了概率推理模型和现代计算机编程语言,使这一方法的实施更加简便,现已在许多领域(包括炙手可热的机器学习)中崭露头角,各种概率编程系统也如雨后春笋般出现。本书的作者Avi Pfeffer正是主流概率编程系统Figaro的首席开发者,他以详尽的实例、清晰易懂的解说引领读者进入这一过去令人望而生畏的领域。通读本书......一起来看看 《概率编程实战》 这本书的介绍吧!

XML 在线格式化
XML 在线格式化

在线 XML 格式化压缩工具

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具