# WeChatBot/bot.py

import json
import openai
import re
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from transformers import AutoTokenizer, AutoModel
import torch
import argparse
import flask
import typing
import traceback
ps = argparse.ArgumentParser()
ps.add_argument("--config", default="config.json", help="Configuration file")
args = ps.parse_args()
with open(args.config) as f:
    config_json = json.load(f)
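# A minimal config.json sketch, inferred from the keys this file reads; the
# model names and values below are placeholders, not the project's defaults:
# {
#     "OpenAI-GPT": {"Enable": true, "OpenAI-Key": "sk-...", "GPT-Model": "gpt-3.5-turbo",
#                    "Temperature": 1.0, "MaxTokens": 1024},
#     "ChatGLM": {"Enable": false, "GPT-Model": "THUDM/chatglm-6b"},
#     "Diffusion": {"Diffusion-Model": "stabilityai/stable-diffusion-2-1",
#                   "UseFP16": true, "NoNSFWChecker": false},
#     "DefaultDiffutionIterations": 20
# }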
class GlobalData:
    # OPENAI_ORGID = config_json[""]
    OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
    OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"]
    # Temperature is a float in the OpenAI API; int() would truncate values like 0.7.
    OPENAI_MODEL_TEMPERATURE = float(config_json["OpenAI-GPT"]["Temperature"])
    OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"]))
    CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"]
    context_for_users = {}
    context_for_groups = {}
    GENERATE_PICTURE_ARG_PAT = re.compile(r"(\(|)([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|)")
    GENERATE_PICTURE_ARG_PAT2 = re.compile(r"(\(|)([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|)")
    GENERATE_PICTURE_NEG_PROMPT_DELIMITER = re.compile(r"\n+")
    GENERATE_PICTURE_MAX_ITS = 200  # maximum number of inference iterations
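    # Examples of argument strings the two patterns accept (inferred from the
    # regexes themselves, not from project documentation):
    #   "512 512 30"      matches GENERATE_PICTURE_ARG_PAT  (width, height, iterations)
    #   "(512 512 30 4)"  matches GENERATE_PICTURE_ARG_PAT2 (plus max image count)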
USE_OPENAIGPT = False
USE_CHATGLM = False
if config_json["OpenAI-GPT"]["Enable"]:
print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).")
USE_OPENAIGPT = True
elif config_json["ChatGLM"]["Enable"]:
print(f"Use ChatGLM({GlobalData.CHATGLM_MODEL}) as GPT-Model.")
chatglm_tokenizer = AutoTokenizer.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
chatglm_model = AutoModel.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
chatglm_model = chatglm_model.to('mps')
elif torch.cuda.is_available():
chatglm_model = chatglm_model.to('cuda')
chatglm_model = chatglm_model.eval()
USE_CHATGLM = True
app = flask.Flask(__name__)
# Replacement safety checker that lets every generated image through, bypassing
# the default NSFW checker. Use with caution in public settings.
def run_safety_nochecker(image, device, dtype):
    print("Warning: the content safety checker is disabled; harmful content may be produced")
    return image, None
sd_args = {
    "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
    "torch_dtype" : (torch.float16 if config_json["Diffusion"].get("UseFP16", True) else torch.float32)
}
sd_pipe = StableDiffusionPipeline.from_pretrained(**sd_args)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
if config_json["Diffusion"]["NoNSFWChecker"]:
2023-02-07 20:52:45 +08:00
setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    sd_pipe = sd_pipe.to("mps")
elif torch.cuda.is_available():
    sd_pipe = sd_pipe.to("cuda")
GPT_SUCCESS = 0
GPT_NORESULT = 1
GPT_ERROR = 2
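# The two model-call helpers below return a (status, payload) tuple: payload is
# the reply text on GPT_SUCCESS and an error message on GPT_ERROR.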
2023-02-05 21:27:35 +08:00
def CallOpenAIGPT(prompts : typing.List[dict]):
    try:
        # Wire the configured temperature and token limit into the request;
        # they were parsed from config.json but previously unused.
        res = openai.ChatCompletion.create(
            model=config_json["OpenAI-GPT"]["GPT-Model"],
            messages=prompts,
            temperature=GlobalData.OPENAI_MODEL_TEMPERATURE,
            max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS
        )
        if len(res["choices"]) > 0:
            return (GPT_SUCCESS, res["choices"][0]["message"]["content"].strip())
        else:
            return (GPT_NORESULT, "")
    except openai.InvalidRequestError as e:
        return (GPT_ERROR, str(e))
    except Exception as e:
        traceback.print_exception(e)
        return (GPT_ERROR, str(e))
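# The prompts argument uses the OpenAI chat-completion message format, e.g.:
#   [{"role" : "user", "content" : "Hello"},
#    {"role" : "assistant", "content" : "Hi, how can I help?"}]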
def CallChatGLM(msg : str, history : typing.List[typing.Tuple[str, str]]):
    try:
        resp, hist = chatglm_model.chat(chatglm_tokenizer, msg, history=history)
        # Some ChatGLM versions return the response as a tuple; keep the text only.
        if isinstance(resp, tuple):
            resp = resp[0]
        return (GPT_SUCCESS, resp)
    except Exception as e:
        return (GPT_ERROR, str(e))
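# ChatGLM's chat() keeps history as a list of (query, response) pairs, e.g.:
#   [("Hello", "Hi, how can I help?")]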
def add_context(uid : str, is_user : bool, msg):
    if uid not in GlobalData.context_for_users:
        GlobalData.context_for_users[uid] = []
    if USE_OPENAIGPT:
        # OpenAI chat messages carry a role marking who produced the entry.
        GlobalData.context_for_users[uid].append({
            "role" : "user" if is_user else "assistant",
            "content" : msg
        })
    elif USE_CHATGLM:
        # ChatGLM history entries are (query, response) pairs, appended as-is.
        GlobalData.context_for_users[uid].append(msg)
def get_context(uid : str):
    if uid not in GlobalData.context_for_users:
        GlobalData.context_for_users[uid] = []
    return GlobalData.context_for_users[uid]
@app.route("/chat_clear", methods=['POST'])
def app_chat_clear():
data = json.loads(flask.globals.request.get_data())
2023-07-02 15:21:20 +08:00
GlobalData.context_for_users[data["user_id"]] = []
2023-02-05 21:27:35 +08:00
print(f"Cleared context for {data['user_id']}")
return ""
@app.route("/chat", methods=['POST'])
def app_chat():
data = json.loads(flask.globals.request.get_data())
#print(data)
2023-07-02 15:21:20 +08:00
uid = data["user_id"]
2023-02-05 21:27:35 +08:00
if not data["text"][-1] in ['?', '', '.', '', ',', '', '!', '']:
data["text"] += ""
2023-07-02 15:21:20 +08:00
if USE_OPENAIGPT:
add_context(uid, True, data["text"])
2023-07-02 16:43:10 +08:00
#prompt = GlobalData.context_for_users[uid]
prompt = get_context(uid)
2023-07-02 15:21:20 +08:00
resp = CallOpenAIGPT(prompt=prompt)
#GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
2023-07-02 16:43:10 +08:00
add_context(uid, False, resp[1])
#print(f"Prompt = {prompt}\nResponse = {resp[1]}")
2023-07-02 15:21:20 +08:00
elif USE_CHATGLM:
2023-07-02 16:43:10 +08:00
#prompt = GlobalData.context_for_users[uid]
prompt = get_context(uid)
resp = CallChatGLM(msg=data["text"], history=prompt)
add_context(uid, True, (data["text"], resp[1]))
2023-07-02 15:21:20 +08:00
else:
pass
2023-07-02 16:43:10 +08:00
if resp[0] == GPT_SUCCESS:
return json.dumps({"user_id" : data["user_id"], "text" : resp[1], "error" : False, "error_msg" : ""})
else:
return json.dumps({"user_id" : data["user_id"], "text" : "", "error" : True, "error_msg" : resp[1]})
@app.route("/draw", methods=['POST'])
def app_draw():
data = json.loads(flask.globals.request.get_data())
prompt = data["prompt"]
i = 0
for i in range(len(prompt)):
if prompt[i] == ':' or prompt[i] == '':
break
if i == len(prompt):
2023-02-06 01:28:27 +08:00
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对正确的格式是生成图片Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)])Prompt"})
    match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i])
    if match_args is not None:
        W = int(match_args.group(2))
        H = int(match_args.group(3))
        ITS = int(match_args.group(4))
        NUM_PIC = int(match_args.group(5))
    else:
        match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i])
        if match_args is not None:
            W = int(match_args.group(2))
            H = int(match_args.group(3))
            ITS = int(match_args.group(4))
            NUM_PIC = 1
        else:
            if len(prompt[:i].strip()) != 0:
                return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "Wrong format. The correct format is 生成图片:Prompt or 生成图片(width height iterations [max image count, default 1]):Prompt"})
            else:
                # No arguments given: fall back to the defaults.
                W = 768
                H = 768
                ITS = config_json.get('DefaultDiffutionIterations', 20)
                NUM_PIC = 1
    if W > 2500 or H > 2500:
        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "The image you asked for is too large; I refuse~"})
    if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"Too many iterations; do not exceed {GlobalData.GENERATE_PICTURE_MAX_ITS}"})
    # Everything after the colon is the prompt; a blank line separates the
    # optional negative prompt.
    prompt = prompt[(i+1):].strip()
    prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMITER, prompt)
    prompt = prompts[0]
    neg_prompt = None
    if len(prompts) > 1:
        neg_prompt = prompts[1]
    print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}")
    try:
        # Bug on Apple silicon: https://github.com/huggingface/diffusers/issues/363
        if NUM_PIC > 1 and torch.backends.mps.is_available():
            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True,
                               "error_msg" : "Generating multiple images from one prompt does not work on Apple silicon; see https://github.com/huggingface/diffusers/issues/363"})
        images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
        if len(images) == 0:
            return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "No images were generated"})
        filenames = []
        for i, img in enumerate(images):
            img.save(f"latest-{i}.png")
            filenames.append(f"latest-{i}.png")
        return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""})
    except Exception as e:
        return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)})
@app.route("/info", methods=['POST', 'GET'])
def app_info():
2023-07-02 17:29:36 +08:00
return "\n".join([f"GPT模型{config_json['OpenAI-GPT']['GPT-Model'] if USE_OPENAIGPT else config_json['ChatGLM']['GPT-Model']}", f"Diffusion模型{config_json['Diffusion']['Diffusion-Model']}",
2023-02-07 20:52:45 +08:00
"默认图片规格768x768 RGB三通道", "Diffusion默认迭代轮数20",
2023-07-02 15:21:20 +08:00
f"使用半精度浮点数 : {'' if config_json['Diffusion'].get('UseFP16', True) else ''}",
f"屏蔽NSFW检查{'' if config_json['Diffusion']['NoNSFWChecker'] else ''}"])
if __name__ == "__main__":
2023-02-06 02:11:48 +08:00
2023-07-02 15:21:20 +08:00
if USE_OPENAIGPT:
openai.api_key = GlobalData.OPENAI_APIKEY
2023-02-05 21:27:35 +08:00
app.run(host="0.0.0.0", port=11111)