update to v1.2

HfCloud 2023-07-02 16:43:10 +08:00
parent 5497ae3a96
commit adf34e91e9
4 changed files with 61 additions and 23 deletions

README.md

@@ -26,7 +26,13 @@
Python 3 additionally needs the following libraries, which can be installed with pip:
```
pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels
pip install torch flask openai transformers diffusers accelerate
```
If you want to use ChatGLM, note that torch must be version >= 2.0 and transformers must be version 4.30.2, and you also need to install:
```
pip install gradio mdtex2html sentencepiece cpm_kernels
```
If you want CUDA acceleration, it is recommended to install a CUDA-enabled torch build following the instructions on the <a href="https://pytorch.org">PyTorch website</a>.
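For reference, a CUDA-enabled install is usually a one-liner of this shape (the index URL below assumes CUDA 11.8; check <a href="https://pytorch.org">pytorch.org</a> for the command matching your CUDA version):
```
pip install torch --index-url https://download.pytorch.org/whl/cu118
```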
@@ -73,7 +79,7 @@ go run wechat_client.go
### Note
<p id="ch14"> </p>
On the first run, a web page pops up for logging in to WeChat by scanning a QR code. After you have logged in once, later logins do not need the QR scan, but you still have to confirm the login on your phone.
On the first run, a web page pops up for logging in to WeChat by scanning a QR code. After you have logged in once, later logins do not need the QR scan, but you still have to confirm the login on your phone. At this point the Go program blocks without printing anything, so remember to take out your phone and confirm the WeChat login.
The first run also needs to download the Diffusion model. The files are large and are fetched from servers abroad, so a reasonably fast and stable network connection is required.
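If your network is slow, one option is to pre-download the model so the first run can use the local cache. A minimal sketch, assuming the default stabilityai/stable-diffusion-2-1 model from config.json:
```
from diffusers import StableDiffusionPipeline

# Downloads the model weights into the local Hugging Face cache;
# bot.py will reuse the cached files on its first run.
StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
```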

bot.py

@@ -7,6 +7,7 @@ import torch
import argparse
import flask
import typing
import traceback
ps = argparse.ArgumentParser()
ps.add_argument("--config", default="config.json", help="Configuration file")
@@ -63,7 +64,7 @@ sd_args = {
sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
if config_json["NoNSFWChecker"]:
if config_json["Diffusion"]["NoNSFWChecker"]:
    setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@@ -88,16 +89,21 @@ def CallOpenAIGPT(prompts : typing.List[str]):
    except openai.InvalidRequestError as e:
        return (GPT_ERROR, e)
    except Exception as e:
        return (GPT_ERROR, e)
        traceback.print_exception(e)
        return (GPT_ERROR, str(e))
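# Both CallOpenAIGPT and CallChatGLM return a (status, payload) tuple:
# (GPT_SUCCESS, reply_text) on success, (GPT_ERROR, error_message) on failure.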
def CallChatGLM(msg, history : typing.List[str]):
    try:
        resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history)
        resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
        if isinstance(resp, tuple):
            resp = resp[0]
        return (GPT_SUCCESS, resp)
    except Exception as e:
        pass
        return (GPT_ERROR, str(e))
def add_context(uid : str, is_user : bool, msg : str):
    if not uid in GlobalData.context_for_users:
        GlobalData.context_for_users[uid] = []
    if USE_OPENAIGPT:
        GlobalData.context_for_users[uid].append({
            "role" : "system",
@@ -106,6 +112,11 @@ def add_context(uid : str, is_user : bool, msg : str):
        )
    elif USE_CHATGLM:
        GlobalData.context_for_users[uid].append(msg)
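# get_context lazily creates an empty history for first-time users,
# so callers never need to guard against a missing uid key.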
def get_context(uid : str):
    if not uid in GlobalData.context_for_users:
        GlobalData.context_for_users[uid] = []
    return GlobalData.context_for_users[uid]
@app.route("/chat_clear", methods=['POST'])
@@ -126,20 +137,24 @@ def app_chat():
    if USE_OPENAIGPT:
        add_context(uid, True, data["text"])
        prompt = GlobalData.context_for_users[uid]
        #prompt = GlobalData.context_for_users[uid]
        prompt = get_context(uid)
        resp = CallOpenAIGPT(prompt)
        #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
        add_context(uid, False, resp)
        print(f"Prompt = {prompt}\nResponse = {resp}")
        add_context(uid, False, resp[1])
        #print(f"Prompt = {prompt}\nResponse = {resp[1]}")
    elif USE_CHATGLM:
        prompt = GlobalData.context_for_users[uid]
        resp, _ = CallChatGLM(msg=data["text"], history=prompt)
        add_context(uid, True, data["text"])
        add_context(uid, False, resp)
        #prompt = GlobalData.context_for_users[uid]
        prompt = get_context(uid)
        resp = CallChatGLM(msg=data["text"], history=prompt)
        add_context(uid, True, (data["text"], resp[1]))
    else:
        pass
    return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
    if resp[0] == GPT_SUCCESS:
        return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
    else:
        return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
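# Example JSON reply produced by app_chat and mirrored by the Go client's
# SendTextResponse struct:
#   {"user_id": "...", "text": "...", "error": false, "error_msg": ""}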
@app.route("/draw", methods=['POST'])
def app_draw():

config.json

@@ -12,7 +12,7 @@
"GPT-Model" : "THUDM/chatglm2-6b"
},
"Diffusion" : {
"DiffusionModel" : "stabilityai/stable-diffusion-2-1",
"Diffusion-Model" : "stabilityai/stable-diffusion-2-1",
"NoNSFWChecker" : true,
"UseFP16" : true
}

wechat_client.go

@@ -23,6 +23,13 @@ type SendTextRequest struct {
    Text string `json:"text"`
}
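// SendTextResponse mirrors the JSON reply produced by app_chat in bot.py
// (fields user_id / text / error / error_msg).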
type SendTextResponse struct {
    UserID       string `json:"user_id"`
    Text         string `json:"text"`
    HasError     bool   `json:"error"`
    ErrorMessage string `json:"error_msg"`
}
type SendImageRequest struct {
    UserName  string   `json:"user_name"`
    FileNames []string `json:"filenames"`
@@ -135,7 +142,7 @@ func main() {
    }
} else {
    // Call ChatGPT
    // Call GPT
    sender, _ := msg.Sender()
    //var group openwechat.Group{} = nil
@@ -155,7 +162,7 @@
    return
}
resp := SendTextRequest{}
resp := SendTextResponse{}
resp_raw := []byte("")
if !msg.IsSendByGroup() {
@@ -173,12 +180,22 @@
if len(resp.Text) == 0 {
    msg.ReplyText("GPT has nothing to say about this; try another topic.")
} else {
    if msg.IsSendByGroup() {
        sender_in_group, _ := msg.SenderInGroup()
        nickname := sender_in_group.NickName
        msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
    if resp.HasError {
        if msg.IsSendByGroup() {
            sender_in_group, _ := msg.SenderInGroup()
            nickname := sender_in_group.NickName
            msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage))
        } else {
            msg.ReplyText(resp.ErrorMessage)
        }
    } else {
        msg.ReplyText(resp.Text)
        if msg.IsSendByGroup() {
            sender_in_group, _ := msg.SenderInGroup()
            nickname := sender_in_group.NickName
            msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
        } else {
            msg.ReplyText(resp.Text)
        }
    }
}