Complete the text-to-text AI endpoint, and switch message handling from synchronous to asynchronous

Cool 2024-11-08 18:15:08 +08:00
parent 0f0a8d4da8
commit 6c67cb5f2a
4 changed files with 249 additions and 181 deletions

.gitignore (vendored): 1 changed line

@@ -5,3 +5,4 @@ latest*.png
 .DS_Store
 anything-v4.0
 chatglm2-6b
+/.idea/

ai.py (new file): 47 lines

@@ -0,0 +1,47 @@
+import json
+from zhipuai import ZhipuAI
+import flask
+
+# client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+# response = client.chat.completions.create(
+#     model="glm-4",  # Fill in the name of the model to call
+#     messages=[
+#         {"role": "user", "content": "你好"},
+#     ],
+# )
+# print(response.choices[0].message)
+
+app = flask.Flask(__name__)
+
+
+@app.route("/chat", methods=['POST'])
+def app_chat():
+    data = json.loads(flask.globals.request.get_data())
+    # print(data)
+    uid = data["user_id"]
+    if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
+        data["text"] += "。"
+    # Call the model through the ZhipuAI client to generate a reply
+    client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+    response = client.chat.completions.create(
+        model="glm-4-flash",  # Fill in the name of the model to call
+        messages=[
+            {"role": "user", "content": data["text"]},
+        ],
+    )
+    # Extract the model's reply
+    resp = response.choices[0].message.content
+    if resp:
+        return json.dumps({"user_id": data["user_id"], "text": resp, "error": False, "error_msg": ""})
+    else:
+        return json.dumps({"user_id": data["user_id"], "text": "", "error": True, "error_msg": "模型未返回回复"})
+
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=11111)
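A quick way to smoke-test the new endpoint (a sketch, not part of the commit: it assumes the `requests` package is installed, ai.py is running locally on port 11111 as above, and the user ID is an arbitrary stand-in):

import json

import requests  # assumed to be available; not part of this repo

# Mirrors the JSON contract of ai.py's /chat route
payload = {"user_id": "test-user", "text": "你好"}
r = requests.post("http://127.0.0.1:11111/chat", data=json.dumps(payload))
reply = json.loads(r.text)
if reply["error"]:
    print("error:", reply["error_msg"])
else:
    print("reply:", reply["text"])

The route reads the raw request body via flask.globals.request.get_data(), so posting a plain JSON string (rather than form fields) matches what the handler expects.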

bot.py: 102 changed lines

@@ -1,5 +1,6 @@
 import json
 import openai
+from zhipuai import ZhipuAI
 import re
 from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
 from transformers import AutoTokenizer, AutoModel
@@ -16,6 +17,7 @@ args = ps.parse_args()
 with open(args.config) as f:
     config_json = json.load(f)
 
+
 class GlobalData:
     # OPENAI_ORGID = config_json[""]
     OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
@@ -31,7 +33,8 @@ class GlobalData:
     GENERATE_PICTURE_ARG_PAT = re.compile("(\(|)([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|)")
     GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|)([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|)")
     GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
-    GENERATE_PICTURE_MAX_ITS = 200 #maximum number of iterations
+    GENERATE_PICTURE_MAX_ITS = 200  # maximum number of iterations
+
     USE_OPENAIGPT = False
     USE_CHATGLM = False
@@ -52,14 +55,16 @@ elif config_json["ChatGLM"]["Enable"]:
 app = flask.Flask(__name__)
 
+
 # This waves through any generated image, replacing the default NSFW checker; use with caution in public settings
 def run_safety_nochecker(image, device, dtype):
     print("警告:屏蔽了内容安全性检查,可能会产生有害内容")
     return image, None
 
+
 sd_args = {
-    "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
-    "torch_dtype" : (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
+    "pretrained_model_name_or_path": config_json["Diffusion"]["Diffusion-Model"],
+    "torch_dtype": (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
 }
 sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
@@ -76,7 +81,8 @@ GPT_SUCCESS = 0
 GPT_NORESULT = 1
 GPT_ERROR = 2
 
-def CallOpenAIGPT(prompts : typing.List[str]):
+
+def CallOpenAIGPT(prompts: typing.List[str]):
     try:
         res = openai.ChatCompletion.create(
             model=config_json["OpenAI-GPT"]["GPT-Model"],
@@ -92,7 +98,8 @@ def CallOpenAIGPT(prompts : typing.List[str]):
         traceback.print_exception(e)
         return (GPT_ERROR, str(e))
 
-def CallChatGLM(msg, history : typing.List[str]):
+
+def CallChatGLM(msg, history: typing.List[str]):
     try:
         resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
         if isinstance(resp, tuple):
@@ -101,19 +108,21 @@ def CallChatGLM(msg, history : typing.List[str]):
     except Exception as e:
         return (GPT_ERROR, str(e))
 
-def add_context(uid : str, is_user : bool, msg : str):
+
+def add_context(uid: str, is_user: bool, msg: str):
     if not uid in GlobalData.context_for_users:
         GlobalData.context_for_users[uid] = []
     if USE_OPENAIGPT:
         GlobalData.context_for_users[uid].append({
-            "role" : "system",
-            "content" : msg
+            "role": "system",
+            "content": msg
         }
         )
     elif USE_CHATGLM:
         GlobalData.context_for_users[uid].append(msg)
 
-def get_context(uid : str):
+
+def get_context(uid: str):
     if not uid in GlobalData.context_for_users:
         GlobalData.context_for_users[uid] = []
     return GlobalData.context_for_users[uid]
@@ -126,35 +135,33 @@ def app_chat_clear():
     print(f"Cleared context for {data['user_id']}")
     return ""
 
+
 @app.route("/chat", methods=['POST'])
 def app_chat():
     data = json.loads(flask.globals.request.get_data())
-    #print(data)
+    # print(data)
     uid = data["user_id"]
     if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
         data["text"] += "。"
-    if USE_OPENAIGPT:
-        add_context(uid, True, data["text"])
-        #prompt = GlobalData.context_for_users[uid]
-        prompt = get_context(uid)
-        resp = CallOpenAIGPT(prompt=prompt)
-        #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
-        add_context(uid, False, resp[1])
-        #print(f"Prompt = {prompt}\nResponse = {resp[1]}")
-    elif USE_CHATGLM:
-        #prompt = GlobalData.context_for_users[uid]
-        prompt = get_context(uid)
-        resp = CallChatGLM(msg=data["text"], history=prompt)
-        add_context(uid, True, (data["text"], resp[1]))
-    else:
-        pass
-    if resp[0] == GPT_SUCCESS:
-        return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
+    # Call the model through the ZhipuAI client to generate a reply
+    client = ZhipuAI(api_key="73bdeed728677bc80efc6956478a2315.VerNWJMCwN9L5gTi")  # Fill in your own API key
+    response = client.chat.completions.create(
+        model="glm-4",  # Fill in the name of the model to call
+        messages=[
+            {"role": "user", "content": data["text"]},
+        ],
+    )
+    # Extract the model's reply
+    resp = response.choices[0].message.content
+    if resp:
+        return json.dumps({"user_id": data["user_id"], "text": resp, "error": False, "error_msg": ""})
     else:
-        return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
+        return json.dumps({"user_id": data["user_id"], "text": "", "error": True, "error_msg": "模型未返回回复"})
 
+
 @app.route("/draw", methods=['POST'])
 def app_draw():
@@ -167,8 +174,8 @@ def app_draw():
         if prompt[i] == ':' or prompt[i] == ':':
             break
     if i == len(prompt):
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对!正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": "格式不对!正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
 
     match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i])
     if not match_args is None:
@@ -185,7 +192,8 @@ def app_draw():
             NUM_PIC = 1
     else:
         if len(prompt[:i].strip()) != 0:
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对!正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
+            return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                               "error_msg": "格式不对!正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
         else:
             W = 768
             H = 768
@@ -193,12 +201,14 @@ def app_draw():
             NUM_PIC = 1
 
     if W > 2500 or H > 2500:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": "你要求的图片太大了,我不干了~"})
     if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}"})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                           "error_msg": f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}"})
 
-    prompt = prompt[(i+1):].strip()
+    prompt = prompt[(i + 1):].strip()
     prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt)
     prompt = prompts[0]
@@ -210,31 +220,37 @@ def app_draw():
     print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}")
     try:
-        if NUM_PIC > 1 and torch.backends.mps.is_available(): #bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True,
-                "error_msg" : "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考:https://github.com/huggingface/diffusers/issues/363"})
-        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
+        if NUM_PIC > 1 and torch.backends.mps.is_available():  # bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
+            return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True,
+                               "error_msg": "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考:https://github.com/huggingface/diffusers/issues/363"})
+        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS,
+                         num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
         if len(images) == 0:
-            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "没有产生任何图像"})
+            return json.dumps(
+                {"user_name": data["user_name"], "filenames": [], "error": True, "error_msg": "没有产生任何图像"})
         filenames = []
         for i, img in enumerate(images):
             img.save(f"latest-{i}.png")
             filenames.append(f"latest-{i}.png")
-        return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""})
+        return json.dumps({"user_name": data["user_name"], "filenames": filenames, "error": False, "error_msg": ""})
     except Exception as e:
-        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)})
+        return json.dumps({"user_name": data["user_name"], "filenames": [], "error": True, "error_msg": str(e)})
@app.route("/info", methods=['POST', 'GET']) @app.route("/info", methods=['POST', 'GET'])
def app_info(): def app_info():
return "\n".join([f"GPT模型{config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}", f"Diffusion模型{config_json['Diffusion']['Diffusion-Model']}", return "\n".join(
[f"GPT模型{config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}",
f"Diffusion模型{config_json['Diffusion']['Diffusion-Model']}",
"默认图片规格768x768 RGB三通道", "Diffusion默认迭代轮数20", "默认图片规格768x768 RGB三通道", "Diffusion默认迭代轮数20",
f"使用半精度浮点数 : {'' if config_json['Diffusion'].get('UseFP16', True) else ''}", f"使用半精度浮点数 : {'' if config_json['Diffusion'].get('UseFP16', True) else ''}",
f"屏蔽NSFW检查{'' if config_json['Diffusion']['NoNSFWChecker'] else ''}", f"屏蔽NSFW检查{'' if config_json['Diffusion']['NoNSFWChecker'] else ''}",
"清空上下文指令:重置上下文", "清空上下文指令:重置上下文",
"生成图片指令:生成图片(宽 高 迭代次数):正面提示 换行写负面提示,其中(宽 高 迭代次数)和换行写的负面提示都是可以省略的"]) "生成图片指令:生成图片(宽 高 迭代次数):正面提示 换行写负面提示,其中(宽 高 迭代次数)和换行写的负面提示都是可以省略的"])
if __name__ == "__main__": if __name__ == "__main__":
if USE_OPENAIGPT: if USE_OPENAIGPT:
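One note on the /draw parsing kept above: GENERATE_PICTURE_ARG_PAT2 accepts 宽 高 迭代次数 图片最大数量 with optional surrounding parentheses, and the numeric capture groups (2 through 5) carry those four values. A minimal sketch of a match against a hypothetical command header:

import re

# Same pattern as GlobalData.GENERATE_PICTURE_ARG_PAT2 in bot.py
ARG_PAT2 = re.compile("(\(|)([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|)")

header = "(512 768 30 2)"  # hypothetical: width, height, iterations, picture count
m = re.match(ARG_PAT2, header)
if m:
    W, H, ITS, NUM_PIC = (int(m.group(k)) for k in (2, 3, 4, 5))
    print(W, H, ITS, NUM_PIC)  # 512 768 30 2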

(Go source file; name not shown)

@@ -86,6 +86,13 @@ func main() {
 	// Register the message handler
 	bot.MessageHandler = func(msg *openwechat.Message) {
+		go sendMessage(msg, self)
+	}
+
+	bot.Block()
+}
+
+func sendMessage(msg *openwechat.Message, self *openwechat.Self) {
 	if msg.IsTickledMe() {
 		msg.ReplyText("别拍了,机器人是会被拍坏掉的。")
 		return
@@ -199,7 +206,4 @@ func main() {
 		}
 	}
-	}
-
-	bot.Block()
 }
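The asynchronous half of this commit is the `go sendMessage(msg, self)` dispatch above: each incoming message now runs on its own goroutine, so a slow model call no longer blocks the receive loop, at the cost of unbounded concurrency. For illustration, here is the same fire-and-forget pattern with a bounded worker pool, sketched in Python (every name below is a hypothetical stand-in, not code from this repo):

import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=8)  # the cap is an arbitrary choice

def call_model(text):
    time.sleep(1.0)  # stand-in for the slow /chat round trip
    return "echo: " + text

def handle_message(user_id, text):
    reply = call_model(text)  # runs off the receive loop, like sendMessage
    print(f"[{user_id}] {reply}")  # stand-in for msg.ReplyText(...)

# Receive loop: submit returns immediately, like `go sendMessage(msg, self)`
for i in range(3):
    pool.submit(handle_message, f"user-{i}", "你好")
pool.shutdown(wait=True)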