diff --git a/Readme.md b/Readme.md index 5ed5abc..e07c2ce 100644 --- a/Readme.md +++ b/Readme.md @@ -26,7 +26,13 @@ python3需要再安装这些库,使用pip安装就可以: ``` -pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels +pip install torch flask openai transformers diffusers accelerate +``` + +如果你要使用ChatGLM,注意torch的版本应该>=2.0,transformers的版本为4.30.2,并且还要安装 + +``` +pip install gradio mdtex2html sentencepiece cpm_kernels ``` 当然如果使用cuda加速建议按照pytorch官网提供的方法安装支持cuda加速的torch版本。 @@ -73,7 +79,7 @@ go run wechat_client.go ### 注意

-第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆。 +第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆(这时候go程序会卡住没有任何提示,注意掏出手机确认登录微信)。 第一次运行需要下载Diffusion模型,文件很大,并且从外网下载,需要有比较快速稳定的网络条件。 diff --git a/bot.py b/bot.py index 95a7b52..ef7f642 100644 --- a/bot.py +++ b/bot.py @@ -7,6 +7,7 @@ import torch import argparse import flask import typing +import traceback ps = argparse.ArgumentParser() ps.add_argument("--config", default="config.json", help="Configuration file") @@ -63,7 +64,7 @@ sd_args = { sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args) sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) -if config_json["NoNSFWChecker"]: +if config_json["Diffusion"]["NoNSFWChecker"]: setattr(sd_pipe, "run_safety_checker", run_safety_nochecker) if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -88,16 +89,21 @@ def CallOpenAIGPT(prompts : typing.List[str]): except openai.InvalidRequestError as e: return (GPT_ERROR, e) except Exception as e: - return (GPT_ERROR, e) + traceback.print_exception(e) + return (GPT_ERROR, str(e)) def CallChatGLM(msg, history : typing.List[str]): try: - resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history) + resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history) + if isinstance(resp, tuple): + resp = resp[0] return (GPT_SUCCESS, resp) except Exception as e: - pass + return (GPT_ERROR, str(e)) def add_context(uid : str, is_user : bool, msg : str): + if not uid in GlobalData.context_for_users: + GlobalData.context_for_users[uid] = [] if USE_OPENAIGPT: GlobalData.context_for_users[uid].append({ "role" : "system", @@ -106,6 +112,11 @@ def add_context(uid : str, is_user : bool, msg : str): ) elif USE_CHATGLM: GlobalData.context_for_users[uid].append(msg) + +def get_context(uid : str): + if not uid in GlobalData.context_for_users: + GlobalData.context_for_users[uid] = [] + return GlobalData.context_for_users[uid] @app.route("/chat_clear", methods=['POST']) @@ -126,20 +137,24 @@ def app_chat(): if USE_OPENAIGPT: add_context(uid, True, data["text"]) - prompt = GlobalData.context_for_users[uid] + #prompt = GlobalData.context_for_users[uid] + prompt = get_context(uid) resp = CallOpenAIGPT(prompt=prompt) #GlobalData.context_for_users[data["user_id"]] = (prompt + resp) - add_context(uid, False, resp) - print(f"Prompt = {prompt}\nResponse = {resp}") + add_context(uid, False, resp[1]) + #print(f"Prompt = {prompt}\nResponse = {resp[1]}") elif USE_CHATGLM: - prompt = GlobalData.context_for_users[uid] - resp, _ = CallChatGLM(msg=data["text"], history=prompt) - add_context(uid, True, data["text"]) - add_context(uid, False, resp) + #prompt = GlobalData.context_for_users[uid] + prompt = get_context(uid) + resp = CallChatGLM(msg=data["text"], history=prompt) + add_context(uid, True, (data["text"], resp[1])) else: pass - - return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False}) + + if resp[0] == GPT_SUCCESS: + return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""}) + else: + return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]}) @app.route("/draw", methods=['POST']) def app_draw(): diff --git a/config.json b/config.json index e324732..607a208 100644 --- a/config.json +++ b/config.json @@ -12,7 +12,7 @@ "GPT-Model" : "THUDM/chatglm2-6b" }, "Diffusion" : { - "DiffusionModel" : "stabilityai/stable-diffusion-2-1", + "Diffusion-Model" : "stabilityai/stable-diffusion-2-1", "NoNSFWChecker" : true, "UseFP16" : true } diff --git a/wechat_client.go b/wechat_client.go index 096938b..e0b2be0 100644 --- a/wechat_client.go +++ b/wechat_client.go @@ -23,6 +23,13 @@ type SendTextRequest struct { Text string `json:"text"` } +type SendTextResponse struct { + UserID string `json:"user_id"` + Text string `json:"text"` + HasError bool `json:"error"` + ErrorMessage string `json:"error_msg"` +} + type SendImageRequest struct { UserName string `json:"user_name"` FileNames []string `json:"filenames"` @@ -135,7 +142,7 @@ func main() { } } else { - // 调用ChatGPT + // 调用GPT sender, _ := msg.Sender() //var group openwechat.Group{} = nil @@ -155,7 +162,7 @@ func main() { return } - resp := SendTextRequest{} + resp := SendTextResponse{} resp_raw := []byte("") if !msg.IsSendByGroup() { @@ -173,12 +180,22 @@ func main() { if len(resp.Text) == 0 { msg.ReplyText("GPT对此没有什么想说的,换个话题吧。") } else { - if msg.IsSendByGroup() { - sender_in_group, _ := msg.SenderInGroup() - nickname := sender_in_group.NickName - msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text)) + if resp.HasError { + if msg.IsSendByGroup() { + sender_in_group, _ := msg.SenderInGroup() + nickname := sender_in_group.NickName + msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage)) + } else { + msg.ReplyText(resp.ErrorMessage) + } } else { - msg.ReplyText(resp.Text) + if msg.IsSendByGroup() { + sender_in_group, _ := msg.SenderInGroup() + nickname := sender_in_group.NickName + msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text)) + } else { + msg.ReplyText(resp.Text) + } } }