update to v1.2

This commit is contained in:
HfCloud 2023-07-02 16:43:10 +08:00
parent 5497ae3a96
commit adf34e91e9
4 changed files with 61 additions and 23 deletions

View File

@ -26,7 +26,13 @@
python3需要再安装这些库使用pip安装就可以 python3需要再安装这些库使用pip安装就可以
``` ```
pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels pip install torch flask openai transformers diffusers accelerate
```
如果你要使用ChatGLM注意torch的版本应该>=2.0transformers的版本为4.30.2,并且还要安装
```
pip install gradio mdtex2html sentencepiece cpm_kernels
``` ```
当然如果使用cuda加速建议按照<a href="https://pytorch.org">pytorch官网</a>提供的方法安装支持cuda加速的torch版本。 当然如果使用cuda加速建议按照<a href="https://pytorch.org">pytorch官网</a>提供的方法安装支持cuda加速的torch版本。
@ -73,7 +79,7 @@ go run wechat_client.go
### 注意 ### 注意
<p id="ch14"> </p> <p id="ch14"> </p>
第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆。 第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆这时候go程序会卡住没有任何提示注意掏出手机确认登录微信
第一次运行需要下载Diffusion模型文件很大并且从外网下载需要有比较快速稳定的网络条件。 第一次运行需要下载Diffusion模型文件很大并且从外网下载需要有比较快速稳定的网络条件。

39
bot.py
View File

@ -7,6 +7,7 @@ import torch
import argparse import argparse
import flask import flask
import typing import typing
import traceback
ps = argparse.ArgumentParser() ps = argparse.ArgumentParser()
ps.add_argument("--config", default="config.json", help="Configuration file") ps.add_argument("--config", default="config.json", help="Configuration file")
@ -63,7 +64,7 @@ sd_args = {
sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args) sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
if config_json["NoNSFWChecker"]: if config_json["Diffusion"]["NoNSFWChecker"]:
setattr(sd_pipe, "run_safety_checker", run_safety_nochecker) setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@ -88,16 +89,21 @@ def CallOpenAIGPT(prompts : typing.List[str]):
except openai.InvalidRequestError as e: except openai.InvalidRequestError as e:
return (GPT_ERROR, e) return (GPT_ERROR, e)
except Exception as e: except Exception as e:
return (GPT_ERROR, e) traceback.print_exception(e)
return (GPT_ERROR, str(e))
def CallChatGLM(msg, history : typing.List[str]): def CallChatGLM(msg, history : typing.List[str]):
try: try:
resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history) resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
if isinstance(resp, tuple):
resp = resp[0]
return (GPT_SUCCESS, resp) return (GPT_SUCCESS, resp)
except Exception as e: except Exception as e:
pass return (GPT_ERROR, str(e))
def add_context(uid : str, is_user : bool, msg : str): def add_context(uid : str, is_user : bool, msg : str):
if not uid in GlobalData.context_for_users:
GlobalData.context_for_users[uid] = []
if USE_OPENAIGPT: if USE_OPENAIGPT:
GlobalData.context_for_users[uid].append({ GlobalData.context_for_users[uid].append({
"role" : "system", "role" : "system",
@ -107,6 +113,11 @@ def add_context(uid : str, is_user : bool, msg : str):
elif USE_CHATGLM: elif USE_CHATGLM:
GlobalData.context_for_users[uid].append(msg) GlobalData.context_for_users[uid].append(msg)
def get_context(uid : str):
if not uid in GlobalData.context_for_users:
GlobalData.context_for_users[uid] = []
return GlobalData.context_for_users[uid]
@app.route("/chat_clear", methods=['POST']) @app.route("/chat_clear", methods=['POST'])
def app_chat_clear(): def app_chat_clear():
@ -126,20 +137,24 @@ def app_chat():
if USE_OPENAIGPT: if USE_OPENAIGPT:
add_context(uid, True, data["text"]) add_context(uid, True, data["text"])
prompt = GlobalData.context_for_users[uid] #prompt = GlobalData.context_for_users[uid]
prompt = get_context(uid)
resp = CallOpenAIGPT(prompt=prompt) resp = CallOpenAIGPT(prompt=prompt)
#GlobalData.context_for_users[data["user_id"]] = (prompt + resp) #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
add_context(uid, False, resp) add_context(uid, False, resp[1])
print(f"Prompt = {prompt}\nResponse = {resp}") #print(f"Prompt = {prompt}\nResponse = {resp[1]}")
elif USE_CHATGLM: elif USE_CHATGLM:
prompt = GlobalData.context_for_users[uid] #prompt = GlobalData.context_for_users[uid]
resp, _ = CallChatGLM(msg=data["text"], history=prompt) prompt = get_context(uid)
add_context(uid, True, data["text"]) resp = CallChatGLM(msg=data["text"], history=prompt)
add_context(uid, False, resp) add_context(uid, True, (data["text"], resp[1]))
else: else:
pass pass
return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False}) if resp[0] == GPT_SUCCESS:
return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
else:
return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
@app.route("/draw", methods=['POST']) @app.route("/draw", methods=['POST'])
def app_draw(): def app_draw():

View File

@ -12,7 +12,7 @@
"GPT-Model" : "THUDM/chatglm2-6b" "GPT-Model" : "THUDM/chatglm2-6b"
}, },
"Diffusion" : { "Diffusion" : {
"DiffusionModel" : "stabilityai/stable-diffusion-2-1", "Diffusion-Model" : "stabilityai/stable-diffusion-2-1",
"NoNSFWChecker" : true, "NoNSFWChecker" : true,
"UseFP16" : true "UseFP16" : true
} }

View File

@ -23,6 +23,13 @@ type SendTextRequest struct {
Text string `json:"text"` Text string `json:"text"`
} }
type SendTextResponse struct {
UserID string `json:"user_id"`
Text string `json:"text"`
HasError bool `json:"error"`
ErrorMessage string `json:"error_msg"`
}
type SendImageRequest struct { type SendImageRequest struct {
UserName string `json:"user_name"` UserName string `json:"user_name"`
FileNames []string `json:"filenames"` FileNames []string `json:"filenames"`
@ -135,7 +142,7 @@ func main() {
} }
} else { } else {
// 调用ChatGPT // 调用GPT
sender, _ := msg.Sender() sender, _ := msg.Sender()
//var group openwechat.Group{} = nil //var group openwechat.Group{} = nil
@ -155,7 +162,7 @@ func main() {
return return
} }
resp := SendTextRequest{} resp := SendTextResponse{}
resp_raw := []byte("") resp_raw := []byte("")
if !msg.IsSendByGroup() { if !msg.IsSendByGroup() {
@ -172,6 +179,15 @@ func main() {
if len(resp.Text) == 0 { if len(resp.Text) == 0 {
msg.ReplyText("GPT对此没有什么想说的换个话题吧。") msg.ReplyText("GPT对此没有什么想说的换个话题吧。")
} else {
if resp.HasError {
if msg.IsSendByGroup() {
sender_in_group, _ := msg.SenderInGroup()
nickname := sender_in_group.NickName
msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage))
} else {
msg.ReplyText(resp.ErrorMessage)
}
} else { } else {
if msg.IsSendByGroup() { if msg.IsSendByGroup() {
sender_in_group, _ := msg.SenderInGroup() sender_in_group, _ := msg.SenderInGroup()
@ -181,6 +197,7 @@ func main() {
msg.ReplyText(resp.Text) msg.ReplyText(resp.Text)
} }
} }
}
} }
} }