引入Redis|tensorflow实现 聊天AI--PigPig养成记(3)

栏目: 数据库 · 发布时间: 6年前

内容简介:在集成Netty之后,为了提高效率,我打算将消息存储在Redis缓存系统中,本节将介绍Redis在项目中的引入,以及前端界面的开发。引入Redis后,

引入Redis

项目github链接

在集成Netty之后,为了提高效率,我打算将消息存储在 Redis 缓存系统中,本节将介绍Redis在项目中的引入,以及前端界面的开发。

引入Redis后, 完整代码链接

想要直接得到训练了13000步的聊天机器人可以直接下载 链接

引入Redis|tensorflow实现 聊天AI--PigPig养成记(3)

这三个文件,以及词汇表文件

引入Redis|tensorflow实现 聊天AI--PigPig养成记(3)

然后直接运行连接中的py脚本进行测试即可。

引入Redis|tensorflow实现 聊天AI--PigPig养成记(3)

最终实现效果如下:

引入Redis|tensorflow实现 聊天AI--PigPig养成记(3)

在Netty中引入Redis

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.time.LocalDateTime;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import redis.clients.jedis.Jedis;

public class ChatHandler 
    extends SimpleChannelInboundHandler<TextWebSocketFrame>{
    private static ChannelGroup clients=
            new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
        System.out.println("channelRead0...");
        
        //连接redis
        Jedis jedis=new Jedis("localhost");
        System.out.println("连接成功...");
        System.out.println("服务正在运行:"+jedis.ping());
        
        //得到用户输入的消息,需要写入文件/缓存中,让AI进行读取
        String content=msg.text();
        if(content==null||content=="") {
            System.out.println("content 为null");
            return ;
        }
        System.out.println("接收到的消息:"+content);

        //写入缓存中
        jedis.set("user_say", content+":user");
        
        Thread.sleep(1000);
        //读取AI返回的内容
        String AIsay=null;
        while(AIsay=="no"||AIsay==null) {
            //从缓存中读取AI回复的内容
            AIsay=jedis.get("ai_say");
            String [] arr=AIsay.split(":");
            AIsay=arr[0];
        }

        //读取后马上向缓存中写入
        jedis.set("ai_say", "no");
        //没有说,或者还没说
        if(AIsay==null||AIsay=="") {
            System.out.println("AIsay==null||AIsay==\"\"");
            return;
        }
        System.out.println("AI说:"+AIsay);
        
        clients.writeAndFlush(
                new TextWebSocketFrame(
                        "AI_PigPig在"+LocalDateTime.now()
                        +"说:"+AIsay));
    }
    
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        System.out.println("add...");
        clients.add(ctx.channel());
    }
    
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        System.out.println("客户端断开,channel对应的长id为:"
                +ctx.channel().id().asLongText());
        System.out.println("客户端断开,channel对应的短id为:"
                +ctx.channel().id().asShortText());
    }
 
}

Python 中引入Redis

with tf.Session() as sess:#打开作为一次会话
    # 恢复前一次训练
    ckpt = tf.train.get_checkpoint_state('.')#从检查点文件中返回一个状态(ckpt)
    #如果ckpt存在,输出模型路径
    if ckpt != None:
        print(ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)#储存模型参数
    else:
        print("没找到模型")
    r.set('user_say','no')
    #测试该模型的能力
    while True:
        line='no'
        #从缓存中进行读取
        while line=='no':
            line=r.get('user_say').decode()
            #print(line)
        list1=line.split(':')
        if len(list1)==1:
            input_string='no'
        else:
            input_string=list1[0]
            r.set('user_say','no')
                          
        
    # 退出
        if input_string == 'quit':
           exit()
        if input_string != 'no':
            input_string_vec = []#输入字符串向量化
            for words in input_string.strip():
                input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函数:如果words在词表中,返回索引号;否则,返回UNK_ID
                bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大于输入的bucket的id
                encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)
                #get_batch(A,B):两个参数,A为大小为len(buckets)的元组,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                #得到其输出
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的预测范围列表
                if EOS_ID in outputs:#如果EOS_ID在输出内部,则输出列表为[,,,,:End]
                    outputs = outputs[:outputs.index(EOS_ID)]
             
                response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#转为解码词汇分别添加到回复中
                print('AI-PigPig > ' + response)#输出回复
                #向缓存中进行写入
                r.set('ai_say',response+':AI')

下一节将讲述通信规则的制定,以规范应用程序。


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Computer Age Statistical Inference

Computer Age Statistical Inference

Bradley Efron、Trevor Hastie / Cambridge University Press / 2016-7-21 / USD 74.99

The twenty-first century has seen a breathtaking expansion of statistical methodology, both in scope and in influence. 'Big data', 'data science', and 'machine learning' have become familiar terms in ......一起来看看 《Computer Age Statistical Inference》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

图片转BASE64编码
图片转BASE64编码

在线图片转Base64编码工具

Base64 编码/解码
Base64 编码/解码

Base64 编码/解码