From 61814ff037944b7b5e16aff62deb5bd9ca72ed02 Mon Sep 17 00:00:00 2001 From: HfCloud Date: Mon, 6 Feb 2023 01:28:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E6=AC=A1=E7=94=9F=E6=88=90=E5=A4=9A?= =?UTF-8?q?=E5=BC=A0=E5=9B=BE=E7=89=87=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 ++-- bot.py | 59 ++++++++++++++++++++++++++++++++++-------------- wechat_client.go | 10 ++++---- 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index cac1a67..87d1efe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ storage.json .vscode myconf.json -latest.png -.DS_Store \ No newline at end of file +latest*.png +.DS_Store diff --git a/bot.py b/bot.py index a4c32df..7c74bb1 100644 --- a/bot.py +++ b/bot.py @@ -1,3 +1,5 @@ +from email.mime import image +from cv2 import grabCut import flask import requests import json @@ -25,6 +27,7 @@ class GlobalData: context_for_groups = {} GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))") + GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))") GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+") GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数 @@ -93,28 +96,37 @@ def app_draw(): if prompt[i] == ':' or prompt[i] == ':': break if i == len(prompt): - return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"}) + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"}) - match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i]) + match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i]) if not match_args is None: W = int(match_args.group(2)) H = int(match_args.group(3)) ITS = int(match_args.group(4)) - - if W > 2000 or H > 2000: - return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"}) - - if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS: - return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"}) + NUM_PIC = int(match_args.group(5)) else: - if len(prompt[:i].strip()) != 0: - return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"}) + match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i]) + if not match_args is None: + W = int(match_args.group(2)) + H = int(match_args.group(3)) + ITS = int(match_args.group(4)) + NUM_PIC = 1 else: - W = 768 - H = 768 - ITS = 50 + if len(prompt[:i].strip()) != 0: + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"}) + else: + W = 768 + H = 768 + ITS = 50 + NUM_PIC = 1 + + if W > 2500 or H > 2500: + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"}) + if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS: + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"}) + prompt = prompt[(i+1):].strip() prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt) @@ -124,12 +136,25 @@ def app_draw(): if len(prompts) > 1: neg_prompt = prompts[1] - print(f"Generating picture with prompt = {prompt} , negative prompt = {neg_prompt}") + print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}") + + try: + if NUM_PIC > 1 and torch.backends.mps.is_available(): #Apple silicon上的bug:https://github.com/huggingface/diffusers/issues/363 + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, + "error_msg" : "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考https://github.com/huggingface/diffusers/issues/363"}) - image = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS).images[0] + images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC] + if len(images) == 0: + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "没有产生任何图像"}) + filenames = [] + for i, img in enumerate(images): + img.save(f"latest-{i}.png") + filenames.append(f"latest-{i}.png") + return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""}) + + except Exception as e: + return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)}) - image.save("latest.png") - return json.dumps({"user_name" : data["user_name"], "filename" : "latest.png", "error" : False, "error_msg" : ""}) if __name__ == "__main__": #openai.organization = GlobalData.OPENAI_ORGID diff --git a/wechat_client.go b/wechat_client.go index b768243..226a823 100644 --- a/wechat_client.go +++ b/wechat_client.go @@ -25,7 +25,7 @@ type SendTextRequest struct { type SendImageRequest struct { UserName string `json:"user_name"` - FileName string `json:"filename"` + FileNames []string `json:"filenames"` HasError bool `json:"error"` ErrorMessage string `json:"error_msg"` } @@ -122,9 +122,11 @@ func main() { if resp.HasError { msg.ReplyText( fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage) ) } else { - img, _ := os.Open(resp.FileName) - defer img.Close() - msg.ReplyImage(img) + for i := 0; i < len(resp.FileNames); i++ { + img, _ := os.Open(resp.FileNames[i]) + defer img.Close() + msg.ReplyImage(img) + } } } else {