Complete the text-to-text AI interface, and switch message handling from synchronous to asynchronous
This commit is contained in:
parent 0f0a8d4da8
commit 6c67cb5f2a
@@ -5,3 +5,4 @@ latest*.png
 .DS_Store
 anything-v4.0
 chatglm2-6b
+/.idea/
@@ -0,0 +1,47 @@
+import json
+
+from zhipuai import ZhipuAI
+import flask
+
+
+# client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+# response = client.chat.completions.create(
+#     model="glm-4",  # Fill in the name of the model to call
+#     messages=[
+#         {"role": "user", "content": "你好"},
+#     ],
+# )
+# print(response.choices[0].message)
+app = flask.Flask(__name__)
+
+
+@app.route("/chat", methods=['POST'])
+def app_chat():
+    data = json.loads(flask.globals.request.get_data())
+    # print(data)
+    uid = data["user_id"]
+
+    if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
+        data["text"] += "。"
+
+    # Call the model through the ZhipuAI library to generate a reply
+    client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+    response = client.chat.completions.create(
+        model="glm-4-flash",  # Fill in the name of the model to call
+        messages=[
+            {"role": "user", "content": data["text"]},
+        ],
+    )
+
+    # Extract the model's reply
+    resp = response.choices[0].message.content
+
+    if resp:
+        return json.dumps({"user_id": data["user_id"], "text": resp, "error": False, "error_msg": ""})
+    else:
+        return json.dumps({"user_id": data["user_id"], "text": "", "error": True, "error_msg": "The model returned no reply"})
+
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=11111)
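For reference, a minimal way to exercise the new service once this file is running: a sketch assuming it is reachable on localhost at the port bound above, using the requests package (any HTTP client works, since the handler reads the raw request body rather than requiring a Content-Type header):

import json

import requests

# One chat turn against the new /chat endpoint; the field names follow the handler's JSON contract.
payload = {"user_id": "test-user", "text": "你好"}
r = requests.post("http://localhost:11111/chat", data=json.dumps(payload))

reply = json.loads(r.text)
if reply["error"]:
    print("error:", reply["error_msg"])
else:
    print("reply:", reply["text"])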
bot.py
@@ -1,5 +1,6 @@
 import json
 import openai
+from zhipuai import ZhipuAI
 import re
 from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
 from transformers import AutoTokenizer, AutoModel
@@ -16,6 +17,7 @@ args = ps.parse_args()
 with open(args.config) as f:
     config_json = json.load(f)
 
+
 class GlobalData:
     # OPENAI_ORGID = config_json[""]
     OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
@@ -31,7 +33,8 @@ class GlobalData:
     GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
     GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
     GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
-    GENERATE_PICTURE_MAX_ITS = 200 #maximum iteration count
+    GENERATE_PICTURE_MAX_ITS = 200  # maximum iteration count
 
+
 USE_OPENAIGPT = False
 USE_CHATGLM = False
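As a quick illustration (not part of the commit) of what these argument patterns accept: GENERATE_PICTURE_ARG_PAT parses a parenthesized "width height iterations" triple, in half- or full-width parentheses, and GENERATE_PICTURE_ARG_PAT2 adds a fourth image-count field; app_draw below uses both to parse the 生成图片 command:

import re

# Same pattern as GENERATE_PICTURE_ARG_PAT above (raw string used here for clarity).
ARG_PAT = re.compile(r"(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")

m = re.match(ARG_PAT, "(768 768 50)")
print(m.group(2), m.group(3), m.group(4))  # -> 768 768 50, i.e. width, height, iterations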
@@ -52,14 +55,16 @@ elif config_json["ChatGLM"]["Enable"]:
 
 app = flask.Flask(__name__)
 
+
 # This waves every generated image through by replacing the default NSFW checker; use with caution in public settings
 def run_safety_nochecker(image, device, dtype):
     print("Warning: the content safety check is disabled; harmful content may be produced")
     return image, None
 
+
 sd_args = {
-    "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
-    "torch_dtype" : (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
+    "pretrained_model_name_or_path": config_json["Diffusion"]["Diffusion-Model"],
+    "torch_dtype": (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
 }
 
 sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
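The diff does not show where run_safety_nochecker is hooked in, but its (image, device, dtype) signature mirrors diffusers' StableDiffusionPipeline.run_safety_checker, so the wiring elsewhere in the file is presumably along these lines; a sketch only, assuming the NoNSFWChecker flag referenced later in app_info:

# Hypothetical wiring (not shown in this hunk): replace the pipeline's NSFW
# check with the pass-through above when the config requests it.
if config_json["Diffusion"].get("NoNSFWChecker", False):
    sd_pipe.run_safety_checker = run_safety_nochecker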
@@ -76,7 +81,8 @@ GPT_SUCCESS = 0
 GPT_NORESULT = 1
 GPT_ERROR = 2
 
-def CallOpenAIGPT(prompts : typing.List[str]):
+
+def CallOpenAIGPT(prompts: typing.List[str]):
     try:
         res = openai.ChatCompletion.create(
             model=config_json["OpenAI-GPT"]["GPT-Model"],
@@ -92,7 +98,8 @@ def CallOpenAIGPT(prompts : typing.List[str]):
         traceback.print_exception(e)
         return (GPT_ERROR, str(e))
 
-def CallChatGLM(msg, history : typing.List[str]):
+
+def CallChatGLM(msg, history: typing.List[str]):
     try:
         resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
         if isinstance(resp, tuple):
@@ -101,19 +108,21 @@ def CallChatGLM(msg, history : typing.List[str]):
     except Exception as e:
         return (GPT_ERROR, str(e))
 
-def add_context(uid : str, is_user : bool, msg : str):
+
+def add_context(uid: str, is_user: bool, msg: str):
     if not uid in GlobalData.context_for_users:
         GlobalData.context_for_users[uid] = []
     if USE_OPENAIGPT:
         GlobalData.context_for_users[uid].append({
-            "role" : "system",
-            "content" : msg
+            "role": "system",
+            "content": msg
         }
         )
     elif USE_CHATGLM:
         GlobalData.context_for_users[uid].append(msg)
 
-def get_context(uid : str):
+
+def get_context(uid: str):
     if not uid in GlobalData.context_for_users:
         GlobalData.context_for_users[uid] = []
     return GlobalData.context_for_users[uid]
@@ -126,35 +135,33 @@ def app_chat_clear():
     print(f"Cleared context for {data['user_id']}")
     return ""
 
+
 @app.route("/chat", methods=['POST'])
 def app_chat():
     data = json.loads(flask.globals.request.get_data())
-    #print(data)
+    # print(data)
     uid = data["user_id"]
 
     if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
         data["text"] += "。"
 
-    if USE_OPENAIGPT:
-        add_context(uid, True, data["text"])
-        #prompt = GlobalData.context_for_users[uid]
-        prompt = get_context(uid)
-        resp = CallOpenAIGPT(prompt=prompt)
-        #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
-        add_context(uid, False, resp[1])
-        #print(f"Prompt = {prompt}\nResponse = {resp[1]}")
-    elif USE_CHATGLM:
-        #prompt = GlobalData.context_for_users[uid]
-        prompt = get_context(uid)
-        resp = CallChatGLM(msg=data["text"], history=prompt)
-        add_context(uid, True, (data["text"], resp[1]))
-    else:
-        pass
+    # Call the model through the ZhipuAI library to generate a reply
+    client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+    response = client.chat.completions.create(
+        model="glm-4",  # Fill in the name of the model to call
+        messages=[
+            {"role": "user", "content": data["text"]},
+        ],
+    )
 
-    if resp[0] == GPT_SUCCESS:
-        return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
+    # Extract the model's reply
+    resp = response.choices[0].message.content
+
+    if resp:
+        return json.dumps({"user_id": data["user_id"], "text": resp, "error": False, "error_msg": ""})
     else:
-        return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
+        return json.dumps({"user_id": data["user_id"], "text": "", "error": True, "error_msg": "The model returned no reply"})
 
 
 @app.route("/draw", methods=['POST'])
 def app_draw():
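One design note on the new handler: it constructs a fresh ZhipuAI client on every request and wraps the network call in no exception handling, unlike the removed CallOpenAIGPT/CallChatGLM paths, which reported failures as (GPT_ERROR, str(e)). A possible tightening, as a sketch only (module-level client; the ZhipuAI config entry is hypothetical, since the commit hard-codes the key inline):

# Sketch (not part of the commit): reuse one client and keep the old
# error-reporting convention instead of letting exceptions surface as HTTP 500s.
zhipu_client = ZhipuAI(api_key=config_json["ZhipuAI"]["API-Key"])  # hypothetical config entry


def call_zhipu(text: str):
    try:
        response = zhipu_client.chat.completions.create(
            model="glm-4",
            messages=[{"role": "user", "content": text}],
        )
        return (GPT_SUCCESS, response.choices[0].message.content)
    except Exception as e:
        return (GPT_ERROR, str(e))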
@@ -167,8 +174,8 @@ def app_draw():
         if prompt[i] == ':' or prompt[i] == ':':
             break
     if i == len(prompt):
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "Wrong format; the correct format is 生成图片:Prompt or 生成图片(width height iterations [max image count (default 1)]):Prompt"})
-
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": "Wrong format; the correct format is 生成图片:Prompt or 生成图片(width height iterations [max image count (default 1)]):Prompt"})
 
     match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i])
     if not match_args is None:
@@ -185,7 +192,8 @@ def app_draw():
             NUM_PIC = 1
         else:
             if len(prompt[:i].strip()) != 0:
-                return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "Wrong format; the correct format is 生成图片:Prompt or 生成图片(width height iterations [max image count (default 1)]):Prompt"})
+                return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                                   "error_msg": "Wrong format; the correct format is 生成图片:Prompt or 生成图片(width height iterations [max image count (default 1)]):Prompt"})
     else:
         W = 768
         H = 768
@@ -193,12 +201,14 @@ def app_draw():
         NUM_PIC = 1
 
     if W > 2500 or H > 2500:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "The image you asked for is too large; I refuse~"})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": "The image you asked for is too large; I refuse~"})
 
     if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"Too many iterations; do not exceed {GlobalData.GENERATE_PICTURE_MAX_ITS}"})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": f"Too many iterations; do not exceed {GlobalData.GENERATE_PICTURE_MAX_ITS}"})
 
-    prompt = prompt[(i+1):].strip()
+    prompt = prompt[(i + 1):].strip()
 
     prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt)
     prompt = prompts[0]
@@ -210,31 +220,37 @@ def app_draw():
     print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}")
 
     try:
-        if NUM_PIC > 1 and torch.backends.mps.is_available(): #Bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True,
-                "error_msg" : "Generating multiple images from a single prompt does not work on Apple silicon; see https://github.com/huggingface/diffusers/issues/363"})
+        if NUM_PIC > 1 and torch.backends.mps.is_available():  # Bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
+            return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                               "error_msg": "Generating multiple images from a single prompt does not work on Apple silicon; see https://github.com/huggingface/diffusers/issues/363"})
 
-        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
+        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS,
+                         num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
         if len(images) == 0:
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "No image was generated"})
+            return json.dumps(
+                {"user_name": data["user_name"], "filenames": [], "error": True, "error_msg": "No image was generated"})
         filenames = []
         for i, img in enumerate(images):
             img.save(f"latest-{i}.png")
             filenames.append(f"latest-{i}.png")
-        return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""})
+        return json.dumps({"user_name": data["user_name"], "filenames": filenames, "error": False, "error_msg": ""})
 
     except Exception as e:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True, "error_msg": str(e)})
 
 
 @app.route("/info", methods=['POST', 'GET'])
 def app_info():
-    return "\n".join([f"GPT model: {config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}", f"Diffusion model: {config_json['Diffusion']['Diffusion-Model']}",
+    return "\n".join(
+        [f"GPT model: {config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}",
+         f"Diffusion model: {config_json['Diffusion']['Diffusion-Model']}",
+         "Default image spec: 768x768, 3-channel RGB", "Default diffusion iterations: 20",
+         f"Half-precision floats: {'yes' if config_json['Diffusion'].get('UseFP16', True) else 'no'}",
+         f"NSFW check disabled: {'yes' if config_json['Diffusion']['NoNSFWChecker'] else 'no'}",
+         "Clear-context command: 重置上下文",
+         "Image command: 生成图片(width height iterations):positive prompt, with the negative prompt on a new line; both the (width height iterations) part and the negative prompt may be omitted"])
 
 
 if __name__ == "__main__":
     if USE_OPENAIGPT:
@@ -86,6 +86,13 @@ func main() {
 
 	// Register the message handler
 	bot.MessageHandler = func(msg *openwechat.Message) {
+		go sendMessage(msg, self)
+	}
+
+	bot.Block()
+}
+
+func sendMessage(msg *openwechat.Message, self *openwechat.Self) {
 	if msg.IsTickledMe() {
 		msg.ReplyText("Stop tickling; the bot will break if it keeps getting tickled.")
 		return
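This hunk is the synchronous-to-asynchronous half of the commit: the handler now returns immediately and each incoming message is processed in its own goroutine via go sendMessage(msg, self). One goroutine per message is unbounded; if load ever makes that a problem, a bounded variant of the same registration inside main could look like the following sketch (not part of the commit; bot, self, and sendMessage come from the surrounding code, and the cap of 8 is arbitrary):

	// Sketch: cap the number of in-flight messages with a semaphore channel.
	sem := make(chan struct{}, 8)
	bot.MessageHandler = func(msg *openwechat.Message) {
		sem <- struct{}{} // acquire a slot; blocks while 8 handlers are busy
		go func() {
			defer func() { <-sem }() // release the slot when done
			sendMessage(msg, self)
		}()
	}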
@@ -199,7 +206,4 @@ func main() {
 		}
 
 	}
 }
-
-	bot.Block()
-}