当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

栏目: IT技术 · 发布时间: 4年前

内容简介:作者 | Fabian Deuser

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

作者 | Fabian Deuser

译者 | 天道酬勤 责编 | Carol 

出品 | AI科技大本营(ID:rgznai100)

有些人生来伟大,有些人成就伟大,而另一些人则拥有伟大。

—— 威廉·莎士比亚《第十二夜》

在几个月前,谷歌的研究人员介绍了机器学习领域的一颗新星——Flax。从那以后发生了很多事情,预发行版有了巨大的改进。作者自己在Flax上进行的CNNs实验已经取得了成果,与Tensorflow相比,它的灵活性仍然非常好。

今天作者将展示递归神经网络(RNNs)在Flax中的一个应用: 字符级语言模型

在许多学习任务中,我们不必考虑对先前输入的时间依赖性。

但是如果我们没有独立的固定大小的输入和输出向量,该怎么办呢?如果我们有向量序列呢? 解决方案是递归神经网络 。它们允许我们对下面描述的向量序列进行操作。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

递归神经网络

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...      

在上图中,你可以看到不同类型的输入输出结构:

  • 一对一是典型CNN或多层感知器,一个输入向量映射到一个输出向量。

  • 一对多是用于图像字幕的RNN体系结构。输入是图像,输出是描述图像的单词序列。

  • 多对多:第一种体系结构利用输入序列到输出序列进行机器翻译,如(德语译成英语)。第二个是适用于帧级别的视频字幕。

RNNs  的主要优点是它们不仅依赖于当前输入,而且还依赖于先前的输入。

RNN是一个具有内部隐藏状态h的单元,该状态根据隐藏的大小用零初始化。在每个时间步长t中,我们将输入x_t插入到RNN单元中,并更新隐藏状态。如今,在下一个时间步t +1中,隐藏状态不再用零初始化,而是使用先前的隐藏状态进行初始化。因此,RNN允许保留有关几个时间步长的信息并生成序列。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...      

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

字符级语言模型

有了这些新知识,我们现在需要为RNN构建第一个应用程序。字符级语言模型是许多任务的基础,例如图片字幕或文本生成。 RNN单元的输入是字符序列形式的大量文本。现在的训练任务是学习在给定先前字符序列的情况下如何预测下一个字符。 因此,我们在每个时间步长t生成一个字符,而我们先前的字符是x_t-1,x_t-2…。

举例来说,让我们以FUZZY一词作为训练序列,现在的词汇为{'f','u','z','y'}。由于RNN仅适用于向量,因此我们将所有字符转换为所谓的“单热向量”。单热向量由零组成,其中一个基于词表中的位置为一个,对于“Z”,转换后的向量为[0,0,1,0]。

在下图中,你可以看到给定输入“ FUZZ”的示例,我们希望预测单词“ UZZY”的结尾。神经元的隐藏大小为4,我们希望输出层中的绿色数字较高,而红色为较低。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...      

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

编程

作者在上一篇有关CNNs的文章中解释了Flax的一些基本概念。作为数据集,我们使用类似这样的对话组成莎士比亚的作品:

EDWARD:
Tis even so; yet you are Warwick still.
GLOUCESTER:
Come, Warwick, take the time; kneel down, kneel down: Nay, when? strike now, or else the iron cools.

我们再次使用Google Colab进行训练,因此我们必须再次安装必要的PIP-Packages:

pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
pip install -q git+https://github.com/google/flax.git@master

因为训练任务非常艰巨,你应该使用具有GPU支持的运行。你可以使用以下命令测试是否存在GPU支持:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

现在我们准备从头开始创建RNN:

class RNN(flax.nn.Module):
"""LSTM"""
def apply(self, carry, inputs):
carry1, outputs = jax_utils.scan_in_dim(
nn.LSTMCell.partial(name='lstm1'), carry[0], inputs, axis=1)
carry2, outputs = jax_utils.scan_in_dim(
nn.LSTMCell.partial(name='lstm2'), carry[1], outputs, axis=1)
carry3, outputs = jax_utils.scan_in_dim(
nn.LSTMCell.partial(name='lstm3'), carry[2], outputs, axis=1)
x = nn.Dense(outputs, features=params['vocab_length'], name='dense')
return [carry1, carry2, carry3], x

在这样的实际训练情况下,我们不使用普通的RNN单元,而是使用LSTM单元。这是更进一步的发展,可以更好地解决梯度消失的问题。为了获得更高的精度,我们使用了三个堆叠的LSTM单元。我们将第一个单元的输出传递给下一个单元,并用自己的隐藏状态初始化每个LSTM单元,这一点非常重要。否则,我们将无法追踪时间依赖性。

最后一个LSTM单元的输出提供给我们密集层。密集层的词汇量和我们词汇量相当。在前面的“模糊”示例中,神经元的数量为四个。如果将“ FUZZ”设置为RNN的输入,则神经元最多产生类似于[1.7,0.1,-1.0,3.1]这样的输出,因为此输出表明“ Y”是最可能的字符。

因为我们有两种不同的模式,所以针对不同的情况,我们将RNN包装在另一个模块中。

class charRNN(flax.nn.Module):
"""Char Generator"""
def apply(self, inputs, carry_pred=None, train=True):
batch_size = params['batch_size']
vocab_size = params['vocab_length']
hidden_size = 512
if train:
carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
carry = [carry1, carry2, carry3]
_, x = RNN(carry, inputs)
return x
else:
carry, x = RNN(carry_pred, inputs)
return carry, x

这种情况是:

  • 训练模型,我们要学习如何预测。

  • 预测模型,实际上在这里我们采样一些文本。

在训练模型之前,我们需要使用以下函数创建它:

def create_model(rng):
"""Creates a model."""
vocab_size = params['vocab_length']
_, initial_params = charRNN.init_by_shape(
rng, [((1, params['seq_length'], vocab_size), jnp.float32)])
model = nn.Model(charRNN, initial_params)
return model

我们每个序列长度为50个字符,词汇表包含65个不同的字符。

作为RNN的优化程序,为了避免初始权重过大,我们选择了初始学习率为0.002且权重衰减的Adam优化器。

def create_optimizer(model, learning_rate):
"""Creates an Adam optimizer for model."""
optimizer_def = optim.Adam(learning_rate=learning_rate, weight_decay=1e-1)
optimizer = optimizer_def.create(model)
return optimizer

训练模型

在训练模型下,我们将32个序列的批次输入到RNN中。每个序列均取自我们的数据集,并包含两个子序列,一个是子序列的字符从0到49,另一个子序列的字符从1到50。通过这种简单的拆分,我们的网络可以学习到最有可能的下一个字符。在每一批中,我们初始化隐藏状态,并将序列提供给我们的RNN。

@jax.jit
def train_step(optimizer, batch):
"""Train one step."""
def loss_fn(model):
"""Compute cross-entropy loss and predict logits of the current batch"""
logits = model(batch[0])
loss = jnp.mean(cross_entropy_loss(logits, batch[1])) / params['batch_size']
return loss, logits
def exponential_decay(steps):
"""Decrease the learning rate every 5 epochs"""
x_decay = (steps / params['step_decay']).astype('int32')
ret = params['learning_rate']* jax.lax.pow((params['learning_rate_decay']), x_decay.astype('float32'))
return jnp.asarray(ret, dtype=jnp.float32)
current_step = optimizer.state.step
new_lr = exponential_decay(current_step)
# calculate and apply the gradient
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grad = grad_fn(optimizer.target)
new_optimizer = optimizer.apply_gradient(grad, learning_rate=new_lr)
metrics = compute_metrics(logits, batch[1])
metrics['learning_rate'] = new_lr
return new_optimizer, metrics

在我们的训练方法中有两个子函数。loss_fn通过将被解释为向量的输出神经元与所需的单热向量进行比较来计算交叉熵损失。因此在“模糊”示例中,我们将有一个输出[1.7,0.1,-1.0,3.1]和一个热向量[0,0,0,1]。现在我们使用以下公式计算损失:

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...      

我们不得不从CNN示例中重写一些代码,因为我们现在使用的不是简单类的序列:

@jax.vmap
def cross_entropy_loss(logits, labels):
"""Returns cross-entropy loss."""
return -jnp.mean(jnp.sum(nn.log_softmax(logits) * labels))

训练步骤中的另一种方法是exponential_decay。我们使用的是Adam优化器,初始学习率为0.002。为了避免太强烈的振荡,我们想每五个周期降低学习率。在每五个周期之后,因子0.97乘以我们的初始学习率,x是多长时间我们达到五个时期。

你将再次看到Flax的优势,即以轻松灵活的方式集成自己的学习速率调度程序。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

预测模型

现在我们要评估学习模型,因此我们从词汇表中选择一个随机字符作为切入点。像在训练中一样,我们初始化隐藏状态,但是这次只是在采样开始时。现在子函数推断将一个字符作为输入。对于隐藏状态,我们在每个时间步长后输出,并在下一个时间步长中将它们输入到RNN中。因此,我们不会失去时间依赖性。

@jax.jit
def sample(inputs, optimizer):
    next_inputs = inputs
    output = []
    batch_size = 1
    carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)
    carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)
    carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)
    carry = [carry1, carry2, carry3]


    def inference(model, carry):
        carry, rnn_output = model(inputs=next_inputs, train=False, carry_pred=carry)
        return carry, rnn_output
    for i in range(200):
        carry, rnn_output = inference(optimizer.target, carry)
        output.append(jnp.argmax(rnn_output, axis=-1))
        # Select the argmax as the next input.
        next_inputs = jnp.expand_dims(common_utils.onehot(jnp.argmax(rnn_output), params['vocab_length']), axis=0)
    return output

这种方法称为“贪婪采样”,因为我们总是取输出向量中概率最大的字符。还有更好的采样方法,比如波束搜索,在此就不做介绍。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

训练和样本循环

至少我们可以在训练和样本循环中调用所有编写的函数。

def train_model():
    """Train and inference """
    rng = jax.random.PRNGKey(0)
    model = create_model(rng)
    optimizer = create_optimizer(model, params['learning_rate'])
    del model
    for epoch in range(100):
        for text in tfds.as_numpy(ds):
            optimizer, metrics = train_step(optimizer, text)
        print('epoch: %d, loss: %.4f, accuracy: %.2f, LR: %.8f' % (epoch+1,metrics['loss'], metrics['accuracy'] * 100, metrics['learning_rate']))
        test = test_ds(params['vocab_length'])
        sampled_text = ""
        if ((epoch+1)%10 == 0):
            for i in test:
                sampled_text += vocab[int(jnp.argmax(i.numpy(),-1))]
                start = np.expand_dims(i, axis=0)
                text = sample(start, optimizer)
            for i in text:
                sampled_text += vocab[int(i)]
            print(sampled_text)

每隔10个周期后,我们会生成一个文本示例,并且在开始时看起来非常重复:

peak the mariners all the merchant of the meaning of the meaning of the meaning of the meaning of the meaning of the meaning…

但是我们变得越来越好,经过100个周期的训练,莎士比亚的作品似乎还活着,并在写新的文字!

This is a shift respected woman to the king's forth,

To this most dangerous soldier there and fortune.

ANTONIO:

If she would concount a sight on honour

Of the moon, why,...

100个周期训练准确性为86.10%,我们的学习率降至0.00112123。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

结论

字符级语言模型的基础是一个能够完成文本的强大工具,可以用作自动补全。可以用作自动补全。 也可以利用这个概念来学习一篇文章的观点。但是,生成完整的新文本是一项非常困难的任务。

我们的模型输出的句子看起来像莎士比亚的文本,但它缺乏意义。大家也可以尝试用这种模型并根据有意义的输入创建更有意义的句子。

Flax功能强大且 工具 众多,但仍处于开发的初期阶段,但它们在开发我喜欢的框架方面处于良好的发展状态。真正巧妙的是,我们只需要稍微更改一下“旧” CNN代码即可在现有基础上使用RNN。

但是Flax仍然缺少它自己的输入管道,因此作者已经用Tensorflow编写了它。如果你想尝试使用作者的代码,你可以在Github Repo中找到用于数据集创建和完整RNN的代码 (https://github.com/Skyy93/CharacterLevelModelFlax/)

原文:https://hackernoon.com/shakespeare-meets-googles-flax-8m1r34q9

本文为 AI 科技大本营翻译,转载请经授权。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

今日福利

遇见陆奇

同样作为“百万人学 AI”的重要组成部分,2020 AIProCon 开发者万人大会将于 7 月 3 日至 4 日通过线上直播形式,让开发者们一站式学习了解当下 AI 的前沿技术研究、核心技术与应用以及企业案例的实践经验,同时还可以在线参加精彩多样的开发者沙龙与编程项目。参与前瞻系列活动、在线直播互动,不仅可以与上万名开发者们一起交流,还有机会赢取直播专属好礼,与技术大咖连麦。

门票限量大放送!今日起点击阅读原文报名「2020 AI开发者万人大会」,使用优惠码“AIP211”,即可免费获得价值299元的大会在线直播门票一张。 限量100张,先到先得! 快来动动手指,免费获取入会资格吧!

点击阅读原文 ,直达大会官网。

当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

你点的每个“在看”,我都认真当成了AI


以上所述就是小编给大家介绍的《当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Microformats

Microformats

John Allsopp / friends of ED / March 26, 2007 / $34.99

In this book, noted web developer and long time WaSP member John Allsop teaches all you need to know about the technology: what Microformats are currently available and how to use them; the general pr......一起来看看 《Microformats》 这本书的介绍吧!

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具