集成Netty
通过我们已经可以训练得到一只傲娇的聊天AI_PigPig了。
本章将介绍项目关于Netty的集成问题,在集成Netty之后,我们的AI_PigPig可以通过web应用与大家日常互撩。
由于只是一个小测试,所以不考虑性能方面的问题,在下一章我们将重点处理效率难关,集成Redis。关于Netty的学习大家可以看,本节中关于Netty部分的代码改编自该文章中的,文章中会有详细的讲解。
Python代码改动
首先对测试训练结果的代码进行改动,将输入输出流重定向自作为中间媒介的测试文件中。
with tf.Session() as sess:#打开作为一次会话 # 恢复前一次训练 ckpt = tf.train.get_checkpoint_state('.')#从检查点文件中返回一个状态(ckpt) #如果ckpt存在,输出模型路径 if ckpt != None: print(ckpt.model_checkpoint_path) model.saver.restore(sess, ckpt.model_checkpoint_path)#储存模型参数 else: print("没找到模型") #测试该模型的能力 while True: #从文件中进行读取 #input_string = input('me > ') #测试文件输入格式为"[内容]:[名字]" #eg.你好:AI【表示AI的回复】 #你好:user【表示用户的输入】 with open('./temp.txt','r+',encoding='ANSI') as myf: #从文件中读取用户的输入 line=myf.read() list1=line.split(':') #长度为一,表明不符合输入格式,设置为"no",则不进行测试处理 if len(list1)==1: input_string='no' else: #符合输入格式,证明是用户输入的 #input_string为用户输入的内容 input_string=list1[0] myf.seek(0) #清空文件 myf.truncate() #写入"no",若读到"no",则不进行测试处理 myf.write('no') # 退出 if input_string == 'quit': exit() #若读到"no",则不进行测试处理 if input_string != 'no': input_string_vec = []#输入字符串向量化 for words in input_string.strip(): input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函数:如果words在词表中,返回索引号;否则,返回UNK_ID bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大于输入的bucket的id encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id) #get_batch(A,B):两个参数,A为大小为len(buckets)的元组,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True) #得到其输出 outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的预测范围列表 if EOS_ID in outputs:#如果EOS_ID在输出内部,则输出列表为[,,,,:End] outputs = outputs[:outputs.index(EOS_ID)] response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#转为解码词汇分别添加到回复中 print('AI-PigPig > ' + response)#输出回复 #将AI的回复以要求的格式进行写入,方便Netty程序读取 with open('./temp1.txt','w',encoding='ANSI') as myf1: myf1.write(response+':AI')
Netty程序
完整代码参见netty包下。
在原本的ChatHandler类中添加了从文件中读取数据的方法readFromFile,以及向文件中覆盖地写入数据的方法writeToFile。
//从文件中读取数据 private static String readFromFile(String filePath) { File file=new File(filePath); String line=null; String name=null; String content=null; try { //以content:name的形式写入 BufferedReader br=new BufferedReader(new FileReader(file)); line=br.readLine(); String [] arr=line.split(":"); if(arr.length==1) { name=null; content=null; }else { content=arr[0]; name=arr[1]; } br.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } return content; } //向文件中覆盖地写入 private static void writeToFile(String filePath,String content) { File file =new File(filePath); try { FileWriter fileWriter=new FileWriter(file); fileWriter.write(""); fileWriter.flush(); fileWriter.write(content); fileWriter.close(); } catch (IOException e) { e.printStackTrace(); } }
对原来的channelRead0方法进行修改,将输入输出流重定向到临时文件中。
@Override protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception { System.out.println("channelRead0"); //得到用户输入的消息,需要写入文件/缓存中,让AI进行读取 String content=msg.text(); if(content==null||content=="") { System.out.println("content 为null"); return ; } System.out.println("接收到的消息:"+content); //写入 writeToFile(writeFilePath, content+":user"); //给AI回复与写入的时间,后期会增对性能方面进行改进 Thread.sleep(1000); //读取AI返回的内容 String AIsay=readFromFile(readFilePath); //读取后马上写入 writeToFile(readFilePath,"no"); //没有说,或者还没说 if(AIsay==null||AIsay==""||AIsay=="no") { System.out.println("AIsay为空或no"); return; } System.out.println("AI说:"+AIsay); clients.writeAndFlush( new TextWebSocketFrame( "AI_PigPig在"+LocalDateTime.now() +"说:"+AIsay)); }
客户端代码
发送消息:接受消息:
测试结果
客户端发送消息
用户与AI日常互撩