一次生成多张图片的功能
This commit is contained in:
parent
356295e051
commit
61814ff037
|
@ -1,5 +1,5 @@
|
||||||
storage.json
|
storage.json
|
||||||
.vscode
|
.vscode
|
||||||
myconf.json
|
myconf.json
|
||||||
latest.png
|
latest*.png
|
||||||
.DS_Store
|
.DS_Store
|
59
bot.py
59
bot.py
|
@ -1,3 +1,5 @@
|
||||||
|
from email.mime import image
|
||||||
|
from cv2 import grabCut
|
||||||
import flask
|
import flask
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
|
@ -25,6 +27,7 @@ class GlobalData:
|
||||||
context_for_groups = {}
|
context_for_groups = {}
|
||||||
|
|
||||||
GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
|
GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
|
||||||
|
GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
|
||||||
GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
|
GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
|
||||||
GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数
|
GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数
|
||||||
|
|
||||||
|
@ -93,27 +96,36 @@ def app_draw():
|
||||||
if prompt[i] == ':' or prompt[i] == ':':
|
if prompt[i] == ':' or prompt[i] == ':':
|
||||||
break
|
break
|
||||||
if i == len(prompt):
|
if i == len(prompt):
|
||||||
return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"})
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
|
||||||
|
|
||||||
|
|
||||||
match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i])
|
match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT2, prompt[:i])
|
||||||
if not match_args is None:
|
if not match_args is None:
|
||||||
W = int(match_args.group(2))
|
W = int(match_args.group(2))
|
||||||
H = int(match_args.group(3))
|
H = int(match_args.group(3))
|
||||||
ITS = int(match_args.group(4))
|
ITS = int(match_args.group(4))
|
||||||
|
NUM_PIC = int(match_args.group(5))
|
||||||
if W > 2000 or H > 2000:
|
|
||||||
return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"})
|
|
||||||
|
|
||||||
if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
|
|
||||||
return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"})
|
|
||||||
else:
|
else:
|
||||||
if len(prompt[:i].strip()) != 0:
|
match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i])
|
||||||
return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"})
|
if not match_args is None:
|
||||||
|
W = int(match_args.group(2))
|
||||||
|
H = int(match_args.group(3))
|
||||||
|
ITS = int(match_args.group(4))
|
||||||
|
NUM_PIC = 1
|
||||||
else:
|
else:
|
||||||
W = 768
|
if len(prompt[:i].strip()) != 0:
|
||||||
H = 768
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数 [图片最大数量(缺省1)]):Prompt"})
|
||||||
ITS = 50
|
else:
|
||||||
|
W = 768
|
||||||
|
H = 768
|
||||||
|
ITS = 50
|
||||||
|
NUM_PIC = 1
|
||||||
|
|
||||||
|
if W > 2500 or H > 2500:
|
||||||
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"})
|
||||||
|
|
||||||
|
if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
|
||||||
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"})
|
||||||
|
|
||||||
prompt = prompt[(i+1):].strip()
|
prompt = prompt[(i+1):].strip()
|
||||||
|
|
||||||
|
@ -124,12 +136,25 @@ def app_draw():
|
||||||
if len(prompts) > 1:
|
if len(prompts) > 1:
|
||||||
neg_prompt = prompts[1]
|
neg_prompt = prompts[1]
|
||||||
|
|
||||||
print(f"Generating picture with prompt = {prompt} , negative prompt = {neg_prompt}")
|
print(f"Generating {NUM_PIC} picture(s) with prompt = {prompt} , negative prompt = {neg_prompt}")
|
||||||
|
|
||||||
image = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS).images[0]
|
try:
|
||||||
|
if NUM_PIC > 1 and torch.backends.mps.is_available(): #Apple silicon上的bug:https://github.com/huggingface/diffusers/issues/363
|
||||||
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True,
|
||||||
|
"error_msg" : "单prompt生成多张图像在Apple silicon上无法实现,相关讨论参考https://github.com/huggingface/diffusers/issues/363"})
|
||||||
|
|
||||||
|
images = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS, num_images_per_prompt=NUM_PIC).images[:NUM_PIC]
|
||||||
|
if len(images) == 0:
|
||||||
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : "没有产生任何图像"})
|
||||||
|
filenames = []
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
img.save(f"latest-{i}.png")
|
||||||
|
filenames.append(f"latest-{i}.png")
|
||||||
|
return json.dumps({"user_name" : data["user_name"], "filenames" : filenames, "error" : False, "error_msg" : ""})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return json.dumps({"user_name" : data["user_name"], "filenames" : [], "error" : True, "error_msg" : str(e)})
|
||||||
|
|
||||||
image.save("latest.png")
|
|
||||||
return json.dumps({"user_name" : data["user_name"], "filename" : "latest.png", "error" : False, "error_msg" : ""})
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#openai.organization = GlobalData.OPENAI_ORGID
|
#openai.organization = GlobalData.OPENAI_ORGID
|
||||||
|
|
|
@ -25,7 +25,7 @@ type SendTextRequest struct {
|
||||||
|
|
||||||
type SendImageRequest struct {
|
type SendImageRequest struct {
|
||||||
UserName string `json:"user_name"`
|
UserName string `json:"user_name"`
|
||||||
FileName string `json:"filename"`
|
FileNames []string `json:"filenames"`
|
||||||
HasError bool `json:"error"`
|
HasError bool `json:"error"`
|
||||||
ErrorMessage string `json:"error_msg"`
|
ErrorMessage string `json:"error_msg"`
|
||||||
}
|
}
|
||||||
|
@ -122,9 +122,11 @@ func main() {
|
||||||
if resp.HasError {
|
if resp.HasError {
|
||||||
msg.ReplyText( fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage) )
|
msg.ReplyText( fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage) )
|
||||||
} else {
|
} else {
|
||||||
img, _ := os.Open(resp.FileName)
|
for i := 0; i < len(resp.FileNames); i++ {
|
||||||
defer img.Close()
|
img, _ := os.Open(resp.FileNames[i])
|
||||||
msg.ReplyImage(img)
|
defer img.Close()
|
||||||
|
msg.ReplyImage(img)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue