博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
集成Netty|tensorflow实现 聊天AI--PigPig养成记(2)
阅读量:5789 次
发布时间:2019-06-18

本文共 5109 字,大约阅读时间需要 17 分钟。

集成Netty

通过我们已经可以训练得到一只傲娇的聊天AI_PigPig了。

图片描述

本章将介绍项目关于Netty的集成问题,在集成Netty之后,我们的AI_PigPig可以通过web应用与大家日常互撩。

由于只是一个小测试,所以不考虑性能方面的问题,在下一章我们将重点处理效率难关,集成Redis。

关于Netty的学习大家可以看,本节中关于Netty部分的代码改编自该文章中的,文章中会有详细的讲解。

Python代码改动

首先对测试训练结果的代码进行改动,将输入输出流重定向自作为中间媒介的测试文件中。

with tf.Session() as sess:#打开作为一次会话    # 恢复前一次训练    ckpt = tf.train.get_checkpoint_state('.')#从检查点文件中返回一个状态(ckpt)    #如果ckpt存在,输出模型路径    if ckpt != None:        print(ckpt.model_checkpoint_path)        model.saver.restore(sess, ckpt.model_checkpoint_path)#储存模型参数    else:        print("没找到模型")    #测试该模型的能力    while True:        #从文件中进行读取        #input_string = input('me > ')        #测试文件输入格式为"[内容]:[名字]"        #eg.你好:AI【表示AI的回复】        #你好:user【表示用户的输入】        with open('./temp.txt','r+',encoding='ANSI') as myf:            #从文件中读取用户的输入            line=myf.read()            list1=line.split(':')            #长度为一,表明不符合输入格式,设置为"no",则不进行测试处理            if len(list1)==1:                input_string='no'            else:                #符合输入格式,证明是用户输入的                #input_string为用户输入的内容                input_string=list1[0]                myf.seek(0)                #清空文件                myf.truncate()                #写入"no",若读到"no",则不进行测试处理                myf.write('no')                         # 退出        if input_string == 'quit':           exit()        #若读到"no",则不进行测试处理        if input_string != 'no':            input_string_vec = []#输入字符串向量化            for words in input_string.strip():                input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函数:如果words在词表中,返回索引号;否则,返回UNK_ID                bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大于输入的bucket的id                encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)                #get_batch(A,B):两个参数,A为大小为len(buckets)的元组,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)                #得到其输出                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的预测范围列表                if EOS_ID in outputs:#如果EOS_ID在输出内部,则输出列表为[,,,,:End]                    outputs = outputs[:outputs.index(EOS_ID)]                             response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#转为解码词汇分别添加到回复中                print('AI-PigPig > ' + response)#输出回复                #将AI的回复以要求的格式进行写入,方便Netty程序读取                with open('./temp1.txt','w',encoding='ANSI') as myf1:                    myf1.write(response+':AI')

Netty程序

完整代码参见netty包下。

在原本的ChatHandler类中添加了从文件中读取数据的方法readFromFile,以及向文件中覆盖地写入数据的方法writeToFile。

//从文件中读取数据    private static String readFromFile(String filePath) {        File file=new File(filePath);        String line=null;        String name=null;        String content=null;        try {            //以content:name的形式写入            BufferedReader br=new BufferedReader(new FileReader(file));            line=br.readLine();            String [] arr=line.split(":");            if(arr.length==1) {                name=null;                content=null;            }else {                content=arr[0];                name=arr[1];            }            br.close();        } catch (FileNotFoundException e) {            e.printStackTrace();        } catch (IOException e) {            e.printStackTrace();        }        return content;    }        //向文件中覆盖地写入    private static void writeToFile(String filePath,String content) {        File file =new File(filePath);        try {            FileWriter fileWriter=new FileWriter(file);            fileWriter.write("");            fileWriter.flush();            fileWriter.write(content);            fileWriter.close();        } catch (IOException e) {            e.printStackTrace();        }            }

对原来的channelRead0方法进行修改,将输入输出流重定向到临时文件中。

@Override    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {        System.out.println("channelRead0");        //得到用户输入的消息,需要写入文件/缓存中,让AI进行读取        String content=msg.text();        if(content==null||content=="") {            System.out.println("content 为null");            return ;        }        System.out.println("接收到的消息:"+content);        //写入        writeToFile(writeFilePath, content+":user");        //给AI回复与写入的时间,后期会增对性能方面进行改进        Thread.sleep(1000);        //读取AI返回的内容        String AIsay=readFromFile(readFilePath);        //读取后马上写入        writeToFile(readFilePath,"no");        //没有说,或者还没说        if(AIsay==null||AIsay==""||AIsay=="no") {            System.out.println("AIsay为空或no");            return;        }        System.out.println("AI说:"+AIsay);                clients.writeAndFlush(                new TextWebSocketFrame(                        "AI_PigPig在"+LocalDateTime.now()                        +"说:"+AIsay));    }

客户端代码

            
发送消息:
接受消息:

测试结果

客户端发送消息

图片描述

用户与AI日常互撩

图片描述

转载地址:http://nrmyx.baihongyu.com/

你可能感兴趣的文章
C++多态、继承的简单分析
查看>>
库克称未来苹果用户可自己决定是否降频 网友:你是在搞笑吗?
查看>>
6倍性能差100TB容量,阿里云POLARDB咋实现?
查看>>
linux 安装 MySQLdb for python
查看>>
Sublime Text 2 技巧
查看>>
使用fscanf()函数从磁盘文件读取格式化数据
查看>>
参加婚礼
查看>>
h5 audio相关手册
查看>>
刚毕业从事java开发需要掌握的技术
查看>>
CSS Custom Properties 自定义属性
查看>>
vim
查看>>
MVVM计算器(下)
查看>>
C++中指针和引用的区别
查看>>
簡單分稀 iptables 記錄 udp 微軟 138 端口
查看>>
Java重写equals方法和hashCode方法
查看>>
Spark API编程动手实战-07-join操作深入实战
查看>>
H3C-路由策略
查看>>
centos 修改字符界面分辨率
查看>>
LNMP之Mysql主从复制(四)
查看>>
阅读Spring源代码(1)
查看>>