diff --git a/.gitignore b/.gitignore
index 1c086d9..5ef6b92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ latest*.png
 .DS_Store
 anything-v4.0
 chatglm2-6b
+/.idea/
diff --git a/ai.py b/ai.py
new file mode 100644
index 0000000..8d18032
--- /dev/null
+++ b/ai.py
@@ -0,0 +1,47 @@
+import json
+
+from zhipuai import ZhipuAI
+import flask
+
+
+
+# client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+# response = client.chat.completions.create(
+#     model="glm-4",  # Fill in the name of the model to call
+#     messages=[
+#         {"role": "user", "content": "你好"},
+#     ],
+# )
+# print(response.choices[0].message)
+app = flask.Flask(__name__)
+
+
+@app.route("/chat", methods=['POST'])
+def app_chat():
+    data = json.loads(flask.globals.request.get_data())
+    # print(data)
+    uid = data["user_id"]
+
+    if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
+        data["text"] += "。"
+
+    # Call the model through the zhipuai SDK to generate a reply
+    client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+    response = client.chat.completions.create(
+        model="glm-4-flash",  # Fill in the name of the model to call
+        messages=[
+            {"role": "user", "content": data["text"]},
+        ],
+    )
+
+    # Extract the model's reply
+    resp = response.choices[0].message.content
+
+    if resp:
+        return json.dumps({"user_id": data["user_id"], "text": resp, "error": False, "error_msg": ""})
+    else:
+        return json.dumps({"user_id": data["user_id"], "text": "", "error": True, "error_msg": "模型未返回回复"})
+
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=11111)
diff --git a/bot.py b/bot.py
index d77f512..5e1527a 100644
--- a/bot.py
+++ b/bot.py
@@ -1,5 +1,6 @@
 import json
-import openai
+
+from zhipuai import ZhipuAI
 import re
 from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
 from transformers import AutoTokenizer, AutoModel
@@ -16,13 +17,14 @@ args = ps.parse_args()
 with open(args.config) as f:
     config_json = json.load(f)
 
+
 class GlobalData:
-    # OPENAI_ORGID = config_json[""]
+    # OPENAI_ORGID = config_json[""]
     OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
     OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"]
     OPENAI_MODEL_TEMPERATURE = int(config_json["OpenAI-GPT"]["Temperature"])
     OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"]))
-    
+
     CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"]
     context_for_users = {}
 
@@ -31,11 +33,12 @@ class GlobalData:
     GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
     GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
     GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
-    GENERATE_PICTURE_MAX_ITS = 200 #maximum number of iterations
-    
+    GENERATE_PICTURE_MAX_ITS = 200  # maximum number of iterations
+
+
 USE_OPENAIGPT = False
 USE_CHATGLM = False
-    
+
 if config_json["OpenAI-GPT"]["Enable"]:
     print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).")
     USE_OPENAIGPT = True
@@ -49,17 +52,19 @@ elif config_json["ChatGLM"]["Enable"]:
     chatglm_model = chatglm_model.to('cuda')
     chatglm_model = chatglm_model.eval()
     USE_CHATGLM = True
-    
+
 app = flask.Flask(__name__)
 
+
 # This waves through any generated image, replacing the default NSFW checker; use with caution in public settings
 def run_safety_nochecker(image, device, dtype):
     print("警告:屏蔽了内容安全性检查,可能会产生有害内容")
     return image, None
 
+
 sd_args = {
-    "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
-    "torch_dtype" : (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
+    "pretrained_model_name_or_path": config_json["Diffusion"]["Diffusion-Model"],
+    "torch_dtype": (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
 }
 
 sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
@@ -71,12 +76,13 @@ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
     sd_pipe = sd_pipe.to("mps")
 elif torch.cuda.is_available():
     sd_pipe = sd_pipe.to("cuda")
-    
+
 GPT_SUCCESS = 0
 GPT_NORESULT = 1
 GPT_ERROR = 2
 
-def CallOpenAIGPT(prompts : typing.List[str]):
+
+def CallOpenAIGPT(prompts: typing.List[str]):
     try:
         res = openai.ChatCompletion.create(
             model=config_json["OpenAI-GPT"]["GPT-Model"],
@@ -91,8 +97,9 @@ def CallOpenAIGPT(prompts: typing.List[str]):
     except Exception as e:
         traceback.print_exception(e)
         return (GPT_ERROR, str(e))
-    
-def CallChatGLM(msg, history : typing.List[str]):
+
+
+def CallChatGLM(msg, history: typing.List[str]):
     try:
         resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
         if isinstance(resp, tuple):
@@ -100,24 +107,26 @@ def CallChatGLM(msg, history: typing.List[str]):
         return (GPT_SUCCESS, resp)
     except Exception as e:
         return (GPT_ERROR, str(e))
-    
-def add_context(uid : str, is_user : bool, msg : str):
+
+
+def add_context(uid: str, is_user: bool, msg: str):
     if not uid in GlobalData.context_for_users:
         GlobalData.context_for_users[uid] = []
     if USE_OPENAIGPT:
         GlobalData.context_for_users[uid].append({
-            "role" : "system",
-            "content" : msg
+            "role": "system",
+            "content": msg
             }
         )
     elif USE_CHATGLM:
         GlobalData.context_for_users[uid].append(msg)
-    
-def get_context(uid : str):
+
+
+def get_context(uid: str):
     if not uid in GlobalData.context_for_users:
         GlobalData.context_for_users[uid] = []
     return GlobalData.context_for_users[uid]
-    
+
 
 @app.route("/chat_clear", methods=['POST'])
 def app_chat_clear():
@@ -126,40 +135,38 @@ def app_chat_clear():
     data = json.loads(flask.globals.request.get_data())
     GlobalData.context_for_users[data['user_id']] = []
     print(f"Cleared context for {data['user_id']}")
     return ""
 
+
 @app.route("/chat", methods=['POST'])
 def app_chat():
     data = json.loads(flask.globals.request.get_data())
-    #print(data)
+    # print(data)
    uid = data["user_id"]
 
     if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
         data["text"] += "。"
-    
-    if USE_OPENAIGPT:
-        add_context(uid, True, data["text"])
-        #prompt = GlobalData.context_for_users[uid]
-        prompt = get_context(uid)
-        resp = CallOpenAIGPT(prompt=prompt)
-        #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
-        add_context(uid, False, resp[1])
-        #print(f"Prompt = {prompt}\nResponse = {resp[1]}")
-    elif USE_CHATGLM:
-        #prompt = GlobalData.context_for_users[uid]
-        prompt = get_context(uid)
-        resp = CallChatGLM(msg=data["text"], history=prompt)
-        add_context(uid, True, (data["text"], resp[1]))
+
+    # Call the model through the zhipuai SDK to generate a reply
+    client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+    response = client.chat.completions.create(
+        model="glm-4",  # Fill in the name of the model to call
+        messages=[
+            {"role": "user", "content": data["text"]},
+        ],
+    )
+
+    # Extract the model's reply
+    resp = response.choices[0].message.content
+
+    if resp:
+        return json.dumps({"user_id": data["user_id"], "text": resp, "error": False, "error_msg": ""})
     else:
-        pass
-    
-    if resp[0] == GPT_SUCCESS:
-        return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
-    else:
-        return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
+        return json.dumps({"user_id": data["user_id"], "text": "", "error": True, "error_msg": "模型未返回回复"})
+
 
 @app.route("/draw", methods=['POST'])
 def app_draw():
     data = json.loads(flask.globals.request.get_data())
-    
+
     prompt = data["prompt"]
     i = 0
@@ -167,8 +174,8 @@ def app_draw():
         if prompt[i] == ':' or prompt[i] == ':':
             break
     if i == len(prompt):
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
-    
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
 
     match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i])
     if not match_args is None:
@@ -185,7 +192,8 @@ def app_draw():
             NUM_PIC = 1
     else:
         if len(prompt[:i].strip()) != 0:
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
+            return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                               "error_msg": "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
         else:
             W = 768
             H = 768
@@ -193,51 +201,59 @@ def app_draw():
             NUM_PIC = 1
 
     if W > 2500 or H > 2500:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"})
-    
-    if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": "你要求的图片太大了,我不干了~"})
 
-    prompt = prompt[(i+1):].strip()
+    if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"})
+
+    prompt = prompt[(i + 1):].strip()
     prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt)
     prompt = prompts[0]
-    neg_prompt = None 
+    neg_prompt = None
     if len(prompts) > 1:
         neg_prompt = prompts[1]
 
     print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}")
-    
-    try:
-        if NUM_PIC > 1 and torch.backends.mps.is_available(): #Bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, 
-                "error_msg" : "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考https://github.com/huggingface/diffusers/issues/363"})
-        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
+
+    try:
+        if NUM_PIC > 1 and torch.backends.mps.is_available():  # Bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
+            return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                               "error_msg": "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考https://github.com/huggingface/diffusers/issues/363"})
+
+        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS,
+                         num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
         if len(images) == 0:
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "没有产生任何图像"})
+            return json.dumps(
+                {"user_name": data["user_name"], "filenames": [], "error": True, "error_msg": "没有产生任何图像"})
 
         filenames = []
         for i, img in enumerate(images):
             img.save(f"latest-{i}.png")
             filenames.append(f"latest-{i}.png")
-        return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""})
+        return json.dumps({"user_name": data["user_name"], "filenames": filenames, "error": False, "error_msg": ""})
+
+    except Exception as e:
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True, "error_msg": str(e)})
 
-    except Exception as e:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)})
 
 @app.route("/info", methods=['POST', 'GET'])
 def app_info():
-    return "\n".join([f"GPT模型:{config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}", f"Diffusion模型:{config_json['Diffusion']['Diffusion-Model']}",
-        "默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20",
-        f"使用半精度浮点数 : {'是' if config_json['Diffusion'].get('UseFP16', True) else '否'}",
-        f"屏蔽NSFW检查:{'是' if config_json['Diffusion']['NoNSFWChecker'] else '否'}",
-        "清空上下文指令:重置上下文",
-        "生成图片指令:生成图片(宽 高 迭代次数):正面提示 换行写负面提示,其中(宽 高 迭代次数)和换行写的负面提示都是可以省略的"])
+    return "\n".join(
+        [f"GPT模型:{config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}",
+         f"Diffusion模型:{config_json['Diffusion']['Diffusion-Model']}",
+         "默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20",
+         f"使用半精度浮点数 : {'是' if config_json['Diffusion'].get('UseFP16', True) else '否'}",
+         f"屏蔽NSFW检查:{'是' if config_json['Diffusion']['NoNSFWChecker'] else '否'}",
+         "清空上下文指令:重置上下文",
+         "生成图片指令:生成图片(宽 高 迭代次数):正面提示 换行写负面提示,其中(宽 高 迭代次数)和换行写的负面提示都是可以省略的"])
+
 
 if __name__ == "__main__":
     if USE_OPENAIGPT:
         openai.api_key = GlobalData.OPENAI_APIKEY
-    app.run(host="0.0.0.0", port=11111)
\ No newline at end of file
+    app.run(host="0.0.0.0", port=11111)
diff --git a/wechat_client.go b/wechat_client.go
index a539939..635460c 100644
--- a/wechat_client.go
+++ b/wechat_client.go
@@ -86,120 +86,124 @@ func main() {
 	// Register the message handler
 	bot.MessageHandler = func(msg *openwechat.Message) {
-		if msg.IsTickledMe() {
-			msg.ReplyText("别拍了,机器人是会被拍坏掉的。")
-			return
-		}
-
-		if !msg.IsText() {
-			return
-		}
-
-		// fmt.Println(msg.Content)
-
-		content := msg.Content
-		if msg.IsSendByGroup() && !msg.IsAt() {
-			return
-		}
-
-		if msg.IsSendByGroup() && msg.IsAt() {
-			atheader := fmt.Sprintf("@%s", self.NickName)
-			//fmt.Println(atheader)
-			if strings.HasPrefix(content, atheader) {
-				content = strings.TrimLeft(content[len(atheader):], "  \t\n")
-			}
-		}
-		//fmt.Println(content)
-		content = strings.TrimRight(content, "  \t\n")
-		if content == "查看机器人信息" {
-			info := HttpPost("http://localhost:11111/info", nil, 20)
-			msg.ReplyText(string(info))
-
-		} else if strings.HasPrefix(content, "生成图片") {
-			// Call Stable Diffusion
-			// msg.ReplyText("这个功能还没有实现,可以先期待一下~")
-			sender, _ := msg.Sender()
-			content = strings.TrimLeft(content[len("生成图片"):], " \t\n")
-
-			resp_raw := HttpPost("http://localhost:11111/draw", GenerateImageRequest{UserName: sender.ID(), Prompt: content}, 120)
-			if len(resp_raw) == 0 {
-				msg.ReplyText("生成图片出错啦QwQ,或许可以再试一次")
-				return
-			}
-
-			resp := SendImageRequest{}
-			json.Unmarshal(resp_raw, &resp)
-			//fmt.Println(resp.FileName)
-			if resp.HasError {
-				msg.ReplyText(fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage))
-			} else {
-				for i := 0; i < len(resp.FileNames); i++ {
-					img, _ := os.Open(resp.FileNames[i])
-					defer img.Close()
-					msg.ReplyImage(img)
-				}
-			}
-
-		} else {
-			// Call GPT
-
-			sender, _ := msg.Sender()
-			//var group openwechat.Group{} = nil
-			var group *openwechat.Group = nil
-
-			if msg.IsSendByGroup() {
-				group = &openwechat.Group{User: sender}
-			}
-
-			if content == "重置上下文" {
-				if !msg.IsSendByGroup() {
-					HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup: msg.IsSendByGroup(), UserID: sender.ID(), Text: ""}, 60)
-				} else {
-					HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup: msg.IsSendByGroup(), UserID: group.ID(), Text: ""}, 60)
-				}
-				msg.ReplyText("OK,我忘掉了之前的上下文。")
-				return
-			}
-
-			resp := SendTextResponse{}
-			resp_raw := []byte("")
-
-			if !msg.IsSendByGroup() {
-				resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup: false, UserID: sender.ID(), Text: msg.Content}, 60)
-			} else {
-				resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup: false, UserID: group.ID(), Text: msg.Content}, 60)
-			}
-			if len(resp_raw) == 0 {
-				msg.ReplyText("运算超时了QAQ,或许可以再试一次。")
-				return
-			}
-
-			json.Unmarshal(resp_raw, &resp)
-
-			if len(resp.Text) == 0 {
-				msg.ReplyText("GPT对此没有什么想说的,换个话题吧。")
-			} else {
-				if resp.HasError {
-					if msg.IsSendByGroup() {
-						sender_in_group, _ := msg.SenderInGroup()
-						nickname := sender_in_group.NickName
-						msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage))
-					} else {
-						msg.ReplyText(resp.ErrorMessage)
-					}
-				} else {
-					if msg.IsSendByGroup() {
-						sender_in_group, _ := msg.SenderInGroup()
-						nickname := sender_in_group.NickName
-						msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
-					} else {
-						msg.ReplyText(resp.Text)
-					}
-				}
-			}
-
-		}
+		go sendMessage(msg, self)
 	}
 
 	bot.Block()
 }
+
+func sendMessage(msg *openwechat.Message, self *openwechat.Self) {
+	if msg.IsTickledMe() {
+		msg.ReplyText("别拍了,机器人是会被拍坏掉的。")
+		return
+	}
+
+	if !msg.IsText() {
+		return
+	}
+
+	// fmt.Println(msg.Content)
+
+	content := msg.Content
+	if msg.IsSendByGroup() && !msg.IsAt() {
+		return
+	}
+
+	if msg.IsSendByGroup() && msg.IsAt() {
+		atheader := fmt.Sprintf("@%s", self.NickName)
+		//fmt.Println(atheader)
+		if strings.HasPrefix(content, atheader) {
+			content = strings.TrimLeft(content[len(atheader):], "  \t\n")
+		}
+	}
+	//fmt.Println(content)
+	content = strings.TrimRight(content, "  \t\n")
+	if content == "查看机器人信息" {
+		info := HttpPost("http://localhost:11111/info", nil, 20)
+		msg.ReplyText(string(info))
+
+	} else if strings.HasPrefix(content, "生成图片") {
+		// Call Stable Diffusion
+		// msg.ReplyText("这个功能还没有实现,可以先期待一下~")
+		sender, _ := msg.Sender()
+		content = strings.TrimLeft(content[len("生成图片"):], " \t\n")
+
+		resp_raw := HttpPost("http://localhost:11111/draw", GenerateImageRequest{UserName: sender.ID(), Prompt: content}, 120)
+		if len(resp_raw) == 0 {
+			msg.ReplyText("生成图片出错啦QwQ,或许可以再试一次")
+			return
+		}
+
+		resp := SendImageRequest{}
+		json.Unmarshal(resp_raw, &resp)
+		//fmt.Println(resp.FileName)
+		if resp.HasError {
+			msg.ReplyText(fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage))
+		} else {
+			for i := 0; i < len(resp.FileNames); i++ {
+				img, _ := os.Open(resp.FileNames[i])
+				defer img.Close()
+				msg.ReplyImage(img)
+			}
+		}
+
+	} else {
+		// Call GPT
+
+		sender, _ := msg.Sender()
+		//var group openwechat.Group{} = nil
+		var group *openwechat.Group = nil
+
+		if msg.IsSendByGroup() {
+			group = &openwechat.Group{User: sender}
+		}
+
+		if content == "重置上下文" {
+			if !msg.IsSendByGroup() {
+				HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup: msg.IsSendByGroup(), UserID: sender.ID(), Text: ""}, 60)
+			} else {
+				HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup: msg.IsSendByGroup(), UserID: group.ID(), Text: ""}, 60)
+			}
+			msg.ReplyText("OK,我忘掉了之前的上下文。")
+			return
+		}
+
+		resp := SendTextResponse{}
+		resp_raw := []byte("")
+
+		if !msg.IsSendByGroup() {
+			resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup: false, UserID: sender.ID(), Text: msg.Content}, 60)
+		} else {
+			resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup: false, UserID: group.ID(), Text: msg.Content}, 60)
+		}
+		if len(resp_raw) == 0 {
+			msg.ReplyText("运算超时了QAQ,或许可以再试一次。")
+			return
+		}
+
+		json.Unmarshal(resp_raw, &resp)
+
+		if len(resp.Text) == 0 {
+			msg.ReplyText("GPT对此没有什么想说的,换个话题吧。")
+		} else {
+			if resp.HasError {
+				if msg.IsSendByGroup() {
+					sender_in_group, _ := msg.SenderInGroup()
+					nickname := sender_in_group.NickName
+					msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage))
+				} else {
+					msg.ReplyText(resp.ErrorMessage)
+				}
+			} else {
+				if msg.IsSendByGroup() {
+					sender_in_group, _ := msg.SenderInGroup()
+					nickname := sender_in_group.NickName
+					msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
+				} else {
+					msg.ReplyText(resp.Text)
+				}
+			}
+		}
+
+	}
+}