update to v1.2
This commit is contained in:
parent
5497ae3a96
commit
adf34e91e9
10
Readme.md
10
Readme.md
|
@ -26,7 +26,13 @@
|
|||
python3需要再安装这些库,使用pip安装就可以:
|
||||
|
||||
```
|
||||
pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels
|
||||
pip install torch flask openai transformers diffusers accelerate
|
||||
```
|
||||
|
||||
如果你要使用ChatGLM,注意torch的版本应该>=2.0,transformers的版本为4.30.2,并且还要安装
|
||||
|
||||
```
|
||||
pip install gradio mdtex2html sentencepiece cpm_kernels
|
||||
```
|
||||
|
||||
当然如果使用cuda加速建议按照<a href="https://pytorch.org">pytorch官网</a>提供的方法安装支持cuda加速的torch版本。
|
||||
|
@ -73,7 +79,7 @@ go run wechat_client.go
|
|||
### 注意
|
||||
<p id="ch14"> </p>
|
||||
|
||||
第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆。
|
||||
第一次运行时候会弹出网页扫码登录微信,登陆一次之后之后再登陆不需要扫码,但仍然需要在手机上点击确认登陆(这时候go程序会卡住没有任何提示,注意掏出手机确认登录微信)。
|
||||
|
||||
第一次运行需要下载Diffusion模型,文件很大,并且从外网下载,需要有比较快速稳定的网络条件。
|
||||
|
||||
|
|
41
bot.py
41
bot.py
|
@ -7,6 +7,7 @@ import torch
|
|||
import argparse
|
||||
import flask
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
ps = argparse.ArgumentParser()
|
||||
ps.add_argument("--config", default="config.json", help="Configuration file")
|
||||
|
@ -63,7 +64,7 @@ sd_args = {
|
|||
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
|
||||
if config_json["NoNSFWChecker"]:
|
||||
if config_json["Diffusion"]["NoNSFWChecker"]:
|
||||
setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
|
||||
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
|
@ -88,16 +89,21 @@ def CallOpenAIGPT(prompts : typing.List[str]):
|
|||
except openai.InvalidRequestError as e:
|
||||
return (GPT_ERROR, e)
|
||||
except Exception as e:
|
||||
return (GPT_ERROR, e)
|
||||
traceback.print_exception(e)
|
||||
return (GPT_ERROR, str(e))
|
||||
|
||||
def CallChatGLM(msg, history : typing.List[str]):
|
||||
try:
|
||||
resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history)
|
||||
resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
|
||||
if isinstance(resp, tuple):
|
||||
resp = resp[0]
|
||||
return (GPT_SUCCESS, resp)
|
||||
except Exception as e:
|
||||
pass
|
||||
return (GPT_ERROR, str(e))
|
||||
|
||||
def add_context(uid : str, is_user : bool, msg : str):
|
||||
if not uid in GlobalData.context_for_users:
|
||||
GlobalData.context_for_users[uid] = []
|
||||
if USE_OPENAIGPT:
|
||||
GlobalData.context_for_users[uid].append({
|
||||
"role" : "system",
|
||||
|
@ -106,6 +112,11 @@ def add_context(uid : str, is_user : bool, msg : str):
|
|||
)
|
||||
elif USE_CHATGLM:
|
||||
GlobalData.context_for_users[uid].append(msg)
|
||||
|
||||
def get_context(uid : str):
|
||||
if not uid in GlobalData.context_for_users:
|
||||
GlobalData.context_for_users[uid] = []
|
||||
return GlobalData.context_for_users[uid]
|
||||
|
||||
|
||||
@app.route("/chat_clear", methods=['POST'])
|
||||
|
@ -126,20 +137,24 @@ def app_chat():
|
|||
|
||||
if USE_OPENAIGPT:
|
||||
add_context(uid, True, data["text"])
|
||||
prompt = GlobalData.context_for_users[uid]
|
||||
#prompt = GlobalData.context_for_users[uid]
|
||||
prompt = get_context(uid)
|
||||
resp = CallOpenAIGPT(prompt=prompt)
|
||||
#GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
|
||||
add_context(uid, False, resp)
|
||||
print(f"Prompt = {prompt}\nResponse = {resp}")
|
||||
add_context(uid, False, resp[1])
|
||||
#print(f"Prompt = {prompt}\nResponse = {resp[1]}")
|
||||
elif USE_CHATGLM:
|
||||
prompt = GlobalData.context_for_users[uid]
|
||||
resp, _ = CallChatGLM(msg=data["text"], history=prompt)
|
||||
add_context(uid, True, data["text"])
|
||||
add_context(uid, False, resp)
|
||||
#prompt = GlobalData.context_for_users[uid]
|
||||
prompt = get_context(uid)
|
||||
resp = CallChatGLM(msg=data["text"], history=prompt)
|
||||
add_context(uid, True, (data["text"], resp[1]))
|
||||
else:
|
||||
pass
|
||||
|
||||
return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
|
||||
|
||||
if resp[0] == GPT_SUCCESS:
|
||||
return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
|
||||
else:
|
||||
return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
|
||||
|
||||
@app.route("/draw", methods=['POST'])
|
||||
def app_draw():
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
"GPT-Model" : "THUDM/chatglm2-6b"
|
||||
},
|
||||
"Diffusion" : {
|
||||
"DiffusionModel" : "stabilityai/stable-diffusion-2-1",
|
||||
"Diffusion-Model" : "stabilityai/stable-diffusion-2-1",
|
||||
"NoNSFWChecker" : true,
|
||||
"UseFP16" : true
|
||||
}
|
||||
|
|
|
@ -23,6 +23,13 @@ type SendTextRequest struct {
|
|||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type SendTextResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
Text string `json:"text"`
|
||||
HasError bool `json:"error"`
|
||||
ErrorMessage string `json:"error_msg"`
|
||||
}
|
||||
|
||||
type SendImageRequest struct {
|
||||
UserName string `json:"user_name"`
|
||||
FileNames []string `json:"filenames"`
|
||||
|
@ -135,7 +142,7 @@ func main() {
|
|||
}
|
||||
|
||||
} else {
|
||||
// 调用ChatGPT
|
||||
// 调用GPT
|
||||
|
||||
sender, _ := msg.Sender()
|
||||
//var group openwechat.Group{} = nil
|
||||
|
@ -155,7 +162,7 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
resp := SendTextRequest{}
|
||||
resp := SendTextResponse{}
|
||||
resp_raw := []byte("")
|
||||
|
||||
if !msg.IsSendByGroup() {
|
||||
|
@ -173,12 +180,22 @@ func main() {
|
|||
if len(resp.Text) == 0 {
|
||||
msg.ReplyText("GPT对此没有什么想说的,换个话题吧。")
|
||||
} else {
|
||||
if msg.IsSendByGroup() {
|
||||
sender_in_group, _ := msg.SenderInGroup()
|
||||
nickname := sender_in_group.NickName
|
||||
msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
|
||||
if resp.HasError {
|
||||
if msg.IsSendByGroup() {
|
||||
sender_in_group, _ := msg.SenderInGroup()
|
||||
nickname := sender_in_group.NickName
|
||||
msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage))
|
||||
} else {
|
||||
msg.ReplyText(resp.ErrorMessage)
|
||||
}
|
||||
} else {
|
||||
msg.ReplyText(resp.Text)
|
||||
if msg.IsSendByGroup() {
|
||||
sender_in_group, _ := msg.SenderInGroup()
|
||||
nickname := sender_in_group.NickName
|
||||
msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
|
||||
} else {
|
||||
msg.ReplyText(resp.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue