Fun with Deep Learning | 24 Sentence Similarity



A Siamese Network is a network that contains two or more identical subnetworks. It is commonly applied to tasks such as sentence similarity computation, face matching, and signature verification:

  • Sentence similarity: given two sentences, decide whether they mean the same thing
  • Face matching: given two face images, decide whether they show the same person
  • Signature verification: given two signatures, decide whether they were written by the same person

Taking sentence similarity computation as an example, the two subnetworks are identical all the way from the Embedding layer to the LSTM layer. The whole model is called MaLSTM (Manhattan LSTM).

(Figure: MaLSTM model architecture)

The final output of the LSTM layer yields a fixed-length representation for each sentence; the similarity of the two sentences, a value between 0 and 1, is then computed with the formula below.
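The formula, reconstructed here from the exponent_neg_manhattan_distance function in the implementation below, is the exponent of the negative Manhattan (L1) distance between the two LSTM outputs:

$$\mathrm{sim}\left(h^{(l)}, h^{(r)}\right) = \exp\left(-\left\lVert h^{(l)} - h^{(r)}\right\rVert_1\right) \in (0, 1]$$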

Data

We use the Quora question pairs dataset from Kaggle (Quora is roughly the foreign counterpart of Zhihu): www.kaggle.com/c/quora-que…

The training and test sets contain 404290 and 3563475 records respectively. Each record includes the following fields, although the test set lacks the is_duplicate field:

  • id: id of the question pair
  • qid1: id of question 1
  • qid2: id of question 2
  • question1: text of question 1
  • question2: text of question 2
  • is_duplicate: whether the two questions mean the same thing, 0 or 1

Implementation

Load the libraries

# -*- coding: utf-8 -*-

from keras.preprocessing.sequence import pad_sequences
from keras.models import Model
from keras.layers import Input, Embedding, LSTM, Lambda
import keras.backend as K
from keras.optimizers import Adam

import pandas as pd
import numpy as np
from gensim.models import KeyedVectors
from nltk.corpus import stopwords
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
import re
from tqdm import tqdm
import pickle

Load the training and test sets

train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')
print(len(train_df), len(test_df))
train_df.head()
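Before going further, it can help to glance at the class balance. A quick inspection, not in the original article:

# Fraction of duplicate pairs in the training set (roughly 0.37 for this dataset)
print(train_df['is_duplicate'].mean())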

Load the stopwords from nltk (Natural Language Toolkit) and define a text-preprocessing function

# If nltk complains that stopwords are missing, download them first
# import nltk
# nltk.download('stopwords')

stops = set(stopwords.words('english'))

def preprocess(text):
    # input:  'Hello, are you ok?'
    # output: ['hello', 'are', 'you', 'ok']
    text = str(text)
    text = text.lower()
    
    text = re.sub(r"[^A-Za-z0-9^,!.\/'+-=]", " ", text)  # drop all other symbols
    text = re.sub(r"what's", "what is ", text)           # expand contraction
    text = re.sub(r"\'s", " is ", text)                  # expand contraction
    text = re.sub(r"\'ve", " have ", text)               # expand contraction
    text = re.sub(r"can't", "cannot ", text)             # expand contraction
    text = re.sub(r"n't", " not ", text)                 # expand contraction
    text = re.sub(r"i'm", "i am ", text)                 # expand contraction
    text = re.sub(r"\'re", " are ", text)                # expand contraction
    text = re.sub(r"\'d", " would ", text)               # expand contraction
    text = re.sub(r"\'ll", " will ", text)               # expand contraction
    text = re.sub(r",", " ", text)                       # drop commas
    text = re.sub(r"\.", " ", text)                      # drop periods
    text = re.sub(r"!", " ! ", text)                     # keep exclamation marks as tokens
    text = re.sub(r"\/", " ", text)                      # drop slashes
    text = re.sub(r"\^", " ^ ", text)                    # other symbols
    text = re.sub(r"\+", " + ", text)                    # other symbols
    text = re.sub(r"\-", " - ", text)                    # other symbols
    text = re.sub(r"\=", " = ", text)                    # other symbols
    text = re.sub(r"\'", " ", text)                      # drop single quotes
    text = re.sub(r"(\d+)(k)", r"\g<1>000", text)        # expand 30k to 30000, etc.
    text = re.sub(r":", " : ", text)                     # other symbols
    text = re.sub(r" e g ", " eg ", text)                # special tokens
    text = re.sub(r" b g ", " bg ", text)                # special tokens
    text = re.sub(r" u s ", " american ", text)          # special tokens
    text = re.sub(r"\0s", "0", text)                     # special tokens
    text = re.sub(r" 9 11 ", " 911 ", text)              # special tokens
    text = re.sub(r"e - mail", "email", text)            # special tokens
    text = re.sub(r"j k", "jk", text)                    # special tokens
    text = re.sub(r"\s{2,}", " ", text)                  # collapse whitespace into a single space

    return text.split()
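A quick check of what preprocess returns, on a made-up example sentence:

print(preprocess("What's a 30k salary?"))
# ['what', 'is', 'a', '30000', 'salary']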

Load Google's pretrained 300-dimensional word vectors

word2vec = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
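Note that word2vec.vocab and word2vec.word_vec(...) used below belong to the pre-4.0 gensim API (in gensim >= 4.0 the equivalents are key_to_index and plain indexing). A quick check that the vectors loaded correctly:

# each word maps to a 300-dimensional vector
print(word2vec['hello'].shape)  # (300,)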

Build the vocabulary, 58564 words in total, replace each text with a sequence of integer ids, and construct the word-embedding matrix

vocabulary = []
word2id = {}
id2word = {}

for df in [train_df, test_df]:
    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        for column in ['question1', 'question2']:
            q2n = []
            for word in preprocess(row[column]):
                # skip stopwords and words without a pretrained vector
                if word in stops or word not in word2vec.vocab:
                    continue
                
                if word not in word2id:  # dict lookup instead of a linear scan of the list
                    word2id[word] = len(vocabulary) + 1  # id 0 is reserved for padding
                    id2word[len(vocabulary) + 1] = word
                    vocabulary.append(word)
                q2n.append(word2id[word])
            
            df.at[i, column] = q2n

embedding_dim = 300
embeddings = np.random.randn(len(vocabulary) + 1, embedding_dim)
embeddings[0] = 0  # embedding for the zero-padding id

for index, word in enumerate(vocabulary):
    embeddings[index + 1] = word2vec.word_vec(word)  # ids start at 1, so shift by one

del word2vec
print(len(vocabulary))
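A quick sanity check, not in the original, that the ids and the embedding matrix line up:

# id 0 is reserved for padding, so the first word gets id 1
w = vocabulary[0]
assert word2id[w] == 1 and id2word[1] == w
print(embeddings.shape)  # (len(vocabulary) + 1, 300)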

Split off a validation set from the training data, and pad the integer sequences to a uniform length

maxlen = max(train_df.question1.map(lambda x: len(x)).max(),
             train_df.question2.map(lambda x: len(x)).max(),
             test_df.question1.map(lambda x: len(x)).max(),
             test_df.question2.map(lambda x: len(x)).max())

valid_size = 40000
train_size = len(train_df) - valid_size

X = train_df[['question1', 'question2']]
Y = train_df['is_duplicate']

X_train, X_valid, Y_train, Y_valid = train_test_split(X, Y, test_size=valid_size)
X_train = {'left': X_train.question1.values, 'right': X_train.question2.values}
X_valid = {'left': X_valid.question1.values, 'right': X_valid.question2.values}
Y_train = np.expand_dims(Y_train.values, axis=-1)
Y_valid = np.expand_dims(Y_valid.values, axis=-1)

# pad or truncate at the front (pre-padding)
X_train['left'] = np.array(pad_sequences(X_train['left'], maxlen=maxlen))
X_train['right'] = np.array(pad_sequences(X_train['right'], maxlen=maxlen))
X_valid['left'] = np.array(pad_sequences(X_valid['left'], maxlen=maxlen))
X_valid['right'] = np.array(pad_sequences(X_valid['right'], maxlen=maxlen))

print(X_train['left'].shape, X_train['right'].shape)
print(X_valid['left'].shape, X_valid['right'].shape)
print(Y_train.shape, Y_valid.shape)
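pad_sequences pads and truncates at the front by default (padding='pre', truncating='pre'), so short questions end up with leading zeros. A minimal illustration:

print(pad_sequences([[1, 2], [3, 4, 5]], maxlen=4))
# [[0 0 1 2]
#  [0 3 4 5]]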

Define the model and train it

hidden_size = 128
gradient_clipping_norm = 1.25
batch_size = 64
epochs = 20

def exponent_neg_manhattan_distance(args):
    # similarity = exp(-L1 distance), which lies in (0, 1]
    left, right = args
    return K.exp(-K.sum(K.abs(left - right), axis=1, keepdims=True))

left_input = Input(shape=(None,), dtype='int32')
right_input = Input(shape=(None,), dtype='int32')

# one Embedding layer shared by both branches, frozen to the pretrained vectors
embedding_layer = Embedding(len(embeddings), embedding_dim, weights=[embeddings], input_length=maxlen, trainable=False)

embedded_left = embedding_layer(left_input)
embedded_right = embedding_layer(right_input)

# the two branches also share a single LSTM, so they are identical subnetworks
shared_lstm = LSTM(hidden_size)

left_output = shared_lstm(embedded_left)
right_output = shared_lstm(embedded_right)

malstm_distance = Lambda(exponent_neg_manhattan_distance, output_shape=(1,))([left_output, right_output])

malstm = Model([left_input, right_input], malstm_distance)

optimizer = Adam(clipnorm=gradient_clipping_norm)
malstm.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])

history = malstm.fit([X_train['left'], X_train['right']], Y_train, batch_size=batch_size, epochs=epochs,
                     validation_data=([X_valid['left'], X_valid['right']], Y_valid))
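As a sanity check of the Manhattan-distance similarity (a quick test, not in the original), feeding the same batch to both branches should yield a similarity of exactly 1, since the branches share all weights:

# identical inputs => L1 distance 0 => exp(0) = 1
x = X_valid['left'][:3]
print(malstm.predict([x, x]))  # each prediction should be 1.0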

Plot the accuracy and loss curves over training

# Plot Accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

# Plot Loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()

The training loss keeps decreasing while the validation loss levels off, which suggests the model does not yet generalize well

(Figure: training and validation loss curves)

The training accuracy climbs above 86%, while the validation accuracy hovers around 80%, so there is still room for improvement

(Figure: training and validation accuracy curves)
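One standard way to attack the overfitting seen above (a sketch, not part of the original article) is to stop training once the validation loss stops improving:

from keras.callbacks import EarlyStopping

# stop if val_loss has not improved for 3 consecutive epochs
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
history = malstm.fit([X_train['left'], X_train['right']], Y_train, batch_size=batch_size, epochs=epochs,
                     validation_data=([X_valid['left'], X_valid['right']], Y_valid),
                     callbacks=[early_stopping])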

Save the model and the id mappings for later use

malstm.save('malstm.h5')
with open('data.pkl', 'wb') as fw:
    pickle.dump({'word2id': word2id, 'id2word': id2word}, fw)

Finally, run a simple test with the trained model on a single machine: draw a few random samples from the training set and check whether the model's predictions agree with the labels. The point is mainly to get familiar with how to run inference with the model

# -*- coding: utf-8 -*-

from keras.preprocessing.sequence import pad_sequences
from keras.models import Model, load_model
import pandas as pd
import numpy as np
from nltk.corpus import stopwords
import re
import pickle

with open('data.pkl', 'rb') as fr:
    data = pickle.load(fr)
    word2id = data['word2id']
    id2word = data['id2word']

train_df = pd.read_csv('train.csv')

stops = set(stopwords.words('english'))
def preprocess(text):
    # input:  'Hello, are you ok?'
    # output: ['hello', 'are', 'you', 'ok']
    text = str(text)
    text = text.lower()
    
    text = re.sub(r"[^A-Za-z0-9^,!.\/'+-=]", " ", text)  # drop all other symbols
    text = re.sub(r"what's", "what is ", text)           # expand contraction
    text = re.sub(r"\'s", " is ", text)                  # expand contraction
    text = re.sub(r"\'ve", " have ", text)               # expand contraction
    text = re.sub(r"can't", "cannot ", text)             # expand contraction
    text = re.sub(r"n't", " not ", text)                 # expand contraction
    text = re.sub(r"i'm", "i am ", text)                 # expand contraction
    text = re.sub(r"\'re", " are ", text)                # expand contraction
    text = re.sub(r"\'d", " would ", text)               # expand contraction
    text = re.sub(r"\'ll", " will ", text)               # expand contraction
    text = re.sub(r",", " ", text)                       # drop commas
    text = re.sub(r"\.", " ", text)                      # drop periods
    text = re.sub(r"!", " ! ", text)                     # keep exclamation marks as tokens
    text = re.sub(r"\/", " ", text)                      # drop slashes
    text = re.sub(r"\^", " ^ ", text)                    # other symbols
    text = re.sub(r"\+", " + ", text)                    # other symbols
    text = re.sub(r"\-", " - ", text)                    # other symbols
    text = re.sub(r"\=", " = ", text)                    # other symbols
    text = re.sub(r"\'", " ", text)                      # drop single quotes
    text = re.sub(r"(\d+)(k)", r"\g<1>000", text)        # expand 30k to 30000, etc.
    text = re.sub(r":", " : ", text)                     # other symbols
    text = re.sub(r" e g ", " eg ", text)                # special tokens
    text = re.sub(r" b g ", " bg ", text)                # special tokens
    text = re.sub(r" u s ", " american ", text)          # special tokens
    text = re.sub(r"\0s", "0", text)                     # special tokens
    text = re.sub(r" 9 11 ", " 911 ", text)              # special tokens
    text = re.sub(r"e - mail", "email", text)            # special tokens
    text = re.sub(r"j k", "jk", text)                    # special tokens
    text = re.sub(r"\s{2,}", " ", text)                  # collapse whitespace into a single space

    return text.split()

malstm = load_model('malstm.h5')
correct = 0
for i in range(5):
    print('Testing Case:', i + 1)
    random_sample = dict(train_df.iloc[np.random.randint(len(train_df))])
    left = random_sample['question1']
    right = random_sample['question2']
    print('Origin Questions...')
    print('==', left)
    print('==', right)

    left = preprocess(left)
    right = preprocess(right)
    print('Preprocessing...')
    print('==', left)
    print('==', right)

    left = [word2id[w] for w in left if w in word2id]
    right = [word2id[w] for w in right if w in word2id]
    print('To ids...')
    print('==', left, [id2word[i] for i in left])
    print('==', right, [id2word[i] for i in right])

    left = np.expand_dims(left, 0)
    right = np.expand_dims(right, 0)
    maxlen = max(left.shape[-1], right.shape[-1])
    left = pad_sequences(left, maxlen=maxlen)
    right = pad_sequences(right, maxlen=maxlen)

    print('Padding...')
    print('==', left.shape)
    print('==', right.shape)

    pred = malstm.predict([left, right])
    pred = 1 if pred[0][0] > 0.5 else 0
    print('True:', random_sample['is_duplicate'])
    print('Pred:', pred)
    if pred == random_sample['is_duplicate']:
        correct += 1
print(correct / 5)
