update to v1.2

parent 5497ae3a96
commit adf34e91e9

Readme.md (10 changed lines)
@@ -26,7 +26,13 @@
 python3 also needs the following libraries, which can be installed with pip:
 
 ```
-pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels
+pip install torch flask openai transformers diffusers accelerate
+```
+
+If you want to use ChatGLM, note that the torch version should be >= 2.0 and the transformers version should be 4.30.2, and you also need to install
+
+```
+pip install gradio mdtex2html sentencepiece cpm_kernels
 ```
 
 Of course, if you want CUDA acceleration, it is recommended to install a torch build with CUDA support following the method provided on the <a href="https://pytorch.org">pytorch website</a>.
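The ChatGLM note above pins specific versions (torch >= 2.0, transformers 4.30.2). A quick, purely illustrative way to check that an environment meets them (not part of the repository):

```
# Sanity check for the ChatGLM prerequisites mentioned in the Readme.
import torch
import transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)

major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (2, 0), "ChatGLM needs torch >= 2.0"
assert transformers.__version__ == "4.30.2", "the Readme pins transformers to 4.30.2"
```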
@@ -73,7 +79,7 @@ go run wechat_client.go
 ### Notes
 <p id="ch14"> </p>
 
-The first run opens a web page with a QR code for logging in to WeChat. After logging in once, later runs do not need the QR scan, but you still have to confirm the login on your phone.
+The first run opens a web page with a QR code for logging in to WeChat. After logging in once, later runs do not need the QR scan, but you still have to confirm the login on your phone (at this point the go program blocks without any prompt, so remember to pick up your phone and confirm the WeChat login).
 
 The first run also downloads the Diffusion model. The files are large and are fetched from servers abroad, so a reasonably fast and stable network connection is needed.
 
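Since the first run blocks on downloading the Diffusion weights, one option is to pre-fetch them once before starting the bot. A minimal sketch, assuming the default diffusers cache and the model id from the config.json diff further down:

```
# Pre-download the Stable Diffusion weights so the first bot run does not stall.
from diffusers import StableDiffusionPipeline

# Model id taken from config.json below; the call populates the local Hugging Face cache.
StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
```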
bot.py (39 changed lines)
@@ -7,6 +7,7 @@ import torch
 import argparse
 import flask
 import typing
+import traceback
 
 ps = argparse.ArgumentParser()
 ps.add_argument("--config", default="config.json", help="Configuration file")
@@ -63,7 +64,7 @@ sd_args = {
 
 sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
 sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
-if config_json["NoNSFWChecker"]:
+if config_json["Diffusion"]["NoNSFWChecker"]:
     setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
 
 if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
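The changed condition assumes the NSFW-checker switch now lives in the nested "Diffusion" section of the config (see the config.json diff further down). A minimal sketch of the lookup, with the config-loading part assumed rather than taken from this hunk:

```
# Assumed shape of the config read by bot.py: the diffusion options sit under "Diffusion".
import json

with open("config.json") as f:
    config_json = json.load(f)

if config_json["Diffusion"]["NoNSFWChecker"]:
    # bot.py replaces the pipeline's safety checker in this case
    print("NSFW checker disabled")
```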
@@ -88,16 +89,21 @@ def CallOpenAIGPT(prompts : typing.List[str]):
     except openai.InvalidRequestError as e:
         return (GPT_ERROR, e)
     except Exception as e:
-        return (GPT_ERROR, e)
+        traceback.print_exception(e)
+        return (GPT_ERROR, str(e))
 
 def CallChatGLM(msg, history : typing.List[str]):
     try:
-        resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history)
+        resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
+        if isinstance(resp, tuple):
+            resp = resp[0]
         return (GPT_SUCCESS, resp)
     except Exception as e:
-        pass
+        return (GPT_ERROR, str(e))
 
 def add_context(uid : str, is_user : bool, msg : str):
+    if not uid in GlobalData.context_for_users:
+        GlobalData.context_for_users[uid] = []
     if USE_OPENAIGPT:
         GlobalData.context_for_users[uid].append({
             "role" : "system",
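After this hunk, both helpers follow the same convention: they return a (status, text) tuple, GPT_SUCCESS with the reply on success and GPT_ERROR with the stringified exception on failure. A hypothetical caller, using the names defined in bot.py, would branch like this:

```
# Illustrative only: how a caller consumes the (status, text) tuple from the helpers.
status, text = CallChatGLM("hello", history=[])
if status == GPT_SUCCESS:
    print("model reply:", text)
else:
    print("model call failed:", text)
```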
@@ -107,6 +113,11 @@ def add_context(uid : str, is_user : bool, msg : str):
     elif USE_CHATGLM:
         GlobalData.context_for_users[uid].append(msg)
 
+def get_context(uid : str):
+    if not uid in GlobalData.context_for_users:
+        GlobalData.context_for_users[uid] = []
+    return GlobalData.context_for_users[uid]
+
 
 @app.route("/chat_clear", methods=['POST'])
 def app_chat_clear():
@@ -126,20 +137,24 @@ def app_chat():
 
     if USE_OPENAIGPT:
         add_context(uid, True, data["text"])
-        prompt = GlobalData.context_for_users[uid]
+        #prompt = GlobalData.context_for_users[uid]
+        prompt = get_context(uid)
         resp = CallOpenAIGPT(prompt=prompt)
         #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
-        add_context(uid, False, resp)
-        print(f"Prompt = {prompt}\nResponse = {resp}")
+        add_context(uid, False, resp[1])
+        #print(f"Prompt = {prompt}\nResponse = {resp[1]}")
     elif USE_CHATGLM:
-        prompt = GlobalData.context_for_users[uid]
-        resp, _ = CallChatGLM(msg=data["text"], history=prompt)
-        add_context(uid, True, data["text"])
-        add_context(uid, False, resp)
+        #prompt = GlobalData.context_for_users[uid]
+        prompt = get_context(uid)
+        resp = CallChatGLM(msg=data["text"], history=prompt)
+        add_context(uid, True, (data["text"], resp[1]))
     else:
         pass
 
-    return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
+    if resp[0] == GPT_SUCCESS:
+        return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
+    else:
+        return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
 
 @app.route("/draw", methods=['POST'])
 def app_draw():
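With this change, /chat no longer returns the old "in_group" field; success and failure are reported through "error" and "error_msg" instead. A hedged example of calling the endpoint with only the standard library (the host and port are assumptions, not taken from this commit):

```
# Illustrative /chat client for the new response shape; adjust the address to your setup.
import json
import urllib.request

req = urllib.request.Request(
    "http://127.0.0.1:8080/chat",  # assumed address, not specified in this commit
    data=json.dumps({"user_id": "u1", "text": "hello"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as r:
    reply = json.loads(r.read())

if reply["error"]:
    print("bot error:", reply["error_msg"])
else:
    print("bot reply:", reply["text"])
```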
config.json
@@ -12,7 +12,7 @@
         "GPT-Model" : "THUDM/chatglm2-6b"
     },
     "Diffusion" : {
-        "DiffusionModel" : "stabilityai/stable-diffusion-2-1",
+        "Diffusion-Model" : "stabilityai/stable-diffusion-2-1",
         "NoNSFWChecker" : true,
         "UseFP16" : true
     }
wechat_client.go
@@ -23,6 +23,13 @@ type SendTextRequest struct {
     Text string `json:"text"`
 }
 
+type SendTextResponse struct {
+    UserID string `json:"user_id"`
+    Text string `json:"text"`
+    HasError bool `json:"error"`
+    ErrorMessage string `json:"error_msg"`
+}
+
 type SendImageRequest struct {
     UserName string `json:"user_name"`
     FileNames []string `json:"filenames"`
@@ -135,7 +142,7 @@ func main() {
     }
 
 } else {
-    // Call ChatGPT
+    // Call GPT
 
     sender, _ := msg.Sender()
     //var group openwechat.Group{} = nil
@@ -155,7 +162,7 @@ func main() {
     return
 }
 
-resp := SendTextRequest{}
+resp := SendTextResponse{}
 resp_raw := []byte("")
 
 if !msg.IsSendByGroup() {
@@ -172,6 +179,15 @@ func main() {
 
 if len(resp.Text) == 0 {
     msg.ReplyText("GPT对此没有什么想说的,换个话题吧。")
+} else {
+    if resp.HasError {
+        if msg.IsSendByGroup() {
+            sender_in_group, _ := msg.SenderInGroup()
+            nickname := sender_in_group.NickName
+            msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.ErrorMessage))
+        } else {
+            msg.ReplyText(resp.ErrorMessage)
+        }
     } else {
         if msg.IsSendByGroup() {
             sender_in_group, _ := msg.SenderInGroup()
@@ -181,6 +197,7 @@ func main() {
             msg.ReplyText(resp.Text)
         }
     }
+    }
 
 }
 }