common update
commit 5497ae3a96 (parent 0b9c7f1724)

@@ -3,3 +3,5 @@ storage.json
 myconf.json
 latest*.png
 .DS_Store
+anything-v4.0
+chatglm2-6b

Readme.md (23 lines changed)

@@ -26,7 +26,7 @@
 In addition, python3 needs the following libraries, which can be installed with pip:
 
 ```
-pip install torch flask openai transformers diffusers accelerate
+pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels
 ```
 
 Of course, if you use CUDA acceleration, it is recommended to install a CUDA-enabled torch build following the instructions on the <a href="https://pytorch.org">pytorch official site</a>.

@@ -36,12 +36,20 @@ Apple Silicon MacBooks can use the mps backend for acceleration; when developing I used
 ### Modifying the configuration
 <p id="ch12"> </p>
 
-You need an OpenAI account; write your API Key into the OpenAI-API-Key field of config.json, then save.
+Open config.json to make your changes.
 
-The remaining options usually work with the defaults, or you can check the OpenAI website for other available GPT models, or browse huggingface for other available Stable Diffusion models.
+If you use OpenAI's GPT:
+
+The default configuration file already uses OpenAI's GPT. You need an OpenAI account; write your API Key into the OpenAI-API-Key field of config.json, then save. The remaining options usually work with the defaults, or you can check the OpenAI website for other available GPT models.
+
+If you use ChatGLM:
+
+Set Enable under OpenAI-GPT to false, then set Enable under ChatGLM to true (see the sketch after this hunk).
 
 Out of laziness, some parameters are hard-coded in the source (bad practice), but they can still be adjusted; for example, the timeout can be changed in wechat_client.go, which leaves more time when generating high-resolution images with very many iterations. In short, the code is simple enough to tweak yourself (the author is lazy). Likewise, running bot.py under a WSGI server only takes two extra lines. (lazy +1)
 
 About the Diffusion model:
 
 UseFP16 can usually just be left on; it significantly reduces VRAM requirements with almost no impact on image quality.
 
 NoNSFWChecker can usually just be left on; it disables the checker that filters generated images containing NSFW content (e.g. lewd images), which would otherwise be replaced with pure black.
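
As a minimal sketch of the switch described above (not repo code; the key names are taken from the config.json diff further down), the flags could also be flipped with a few lines of Python:

```
import json

# Load the existing configuration.
with open("config.json") as f:
    cfg = json.load(f)

# Disable OpenAI GPT and enable ChatGLM, as described above.
cfg["OpenAI-GPT"]["Enable"] = False
cfg["ChatGLM"]["Enable"] = True

# Write the configuration back, keeping non-ASCII text readable.
with open("config.json", "w") as f:
    json.dump(cfg, f, indent=4, ensure_ascii=False)
```
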
@@ -74,7 +82,7 @@ go run wechat_client.go
 Recommended Diffusion models:
 
 ```
-andite/anything-v4.0 : very heavy anime flavor, decent at drawing people
+andite/anything-v4.0 : very heavy anime flavor, decent at drawing people. (This model has now been deleted from huggingface; it can still be found if you have a local cache.)
 stabilityai/stable-diffusion-2-1 : fairly general-purpose; can produce all kinds of images, both anime-style and realistic, but it is bad at drawing people, often producing mangled hands, missing arms or legs, and similar problems...
 ```
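
For orientation, here is a minimal, self-contained sketch of how such a model id is consumed by diffusers; it mirrors the StableDiffusionPipeline and DPMSolverMultistepScheduler setup visible in the bot.py diff below, but the prompt and output filename are made up for illustration:

```
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"

# fp16 roughly halves VRAM usage (the UseFP16 option); fall back to fp32 on CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# 20 iterations matches the default iteration count mentioned in the README.
image = pipe("a mountain lake at dawn", num_inference_steps=20).images[0]
image.save("latest.png")
```
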
@@ -137,6 +145,13 @@ low quality, dark, fuzzy, normal quality, ugly, twisted face, scary eyes, sexual
 <table>
 <tr> <th>Version</th> <th>Date</th> <th>Notes</th> </tr>
 
+<tr>
+<td> v1.2 </td>
+<td> 2023.07.02 </td>
+<td> Followed OpenAI's updates: use the openai.ChatCompletion chat API instead of the text-completion API <br> Support Tsinghua University's open-source <a href="https://github.com/THUDM/ChatGLM2-6B">ChatGLM2</a> as the GPT model. <br> Since the anything-v4.0 model was removed from huggingface, the default Diffusion model is now stabilityai/stable-diffusion-2-1 <br> Updated the openwechat dependency to v1.4.3 </td>
+
+</tr>
+
 <tr>
 <td> v1.1 </td>
 <td> 2023.02.07 </td>

bot.py (118 lines changed)

@@ -2,9 +2,11 @@ import json
 import openai
 import re
 from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
+from transformers import AutoTokenizer, AutoModel
 import torch
 import argparse
 import flask
+import typing
 
 ps = argparse.ArgumentParser()
 ps.add_argument("--config", default="config.json", help="Configuration file")

@@ -15,10 +17,12 @@ with open(args.config) as f:
 
 class GlobalData:
     # OPENAI_ORGID = config_json[""]
-    OPENAI_APIKEY = config_json["OpenAI-API-Key"]
-    OPENAI_MODEL = config_json["GPT-Model"]
-    OPENAI_MODEL_TEMPERATURE = 0.66
-    OPENAI_MODEL_MAXTOKENS = 2048
+    OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
+    OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"]
+    OPENAI_MODEL_TEMPERATURE = int(config_json["OpenAI-GPT"]["Temperature"])
+    OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"]))
+
+    CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"]
 
     context_for_users = {}
     context_for_groups = {}

@@ -27,17 +31,33 @@ class GlobalData:
     GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
     GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
     GENERATE_PICTURE_MAX_ITS = 200 # maximum number of iterations
 
+USE_OPENAIGPT = False
+USE_CHATGLM = False
+
+if config_json["OpenAI-GPT"]["Enable"]:
+    print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).")
+    USE_OPENAIGPT = True
+elif config_json["ChatGLM"]["Enable"]:
+    print(f"Use ChatGLM({GlobalData.CHATGLM_MODEL}) as GPT-Model.")
+    chatglm_tokenizer = AutoTokenizer.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
+    chatglm_model = AutoModel.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        chatglm_model = chatglm_model.to('mps')
+    elif torch.cuda.is_available():
+        chatglm_model = chatglm_model.to('cuda')
+    chatglm_model = chatglm_model.eval()
+    USE_CHATGLM = True
 
 app = flask.Flask(__name__)
 
 
 # This waves any generated image through, replacing the default NSFW checker; use with caution in public settings
 def run_safety_nochecker(image, device, dtype):
     print("警告:屏蔽了内容安全性检查,可能会产生有害内容")
     return image, None
 
 sd_args = {
-    "pretrained_model_name_or_path" : config_json["Diffusion-Model"],
+    "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
    "torch_dtype" : (torch.float16 if config_json.get("UseFP16", True) else torch.float32)
 }
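
The chatglm_model.chat call this relies on follows the interface documented in the THUDM/ChatGLM2-6B readme; a standalone sketch (not repo code, device handling simplified):

```
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
if torch.cuda.is_available():
    model = model.to("cuda")
model = model.eval()

# chat() returns the reply plus the updated history, which is passed back in
# on the next turn; this is the history list that add_context maintains.
history = []
response, history = model.chat(tokenizer, "你好", history=history)
print(response)
```
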
@@ -46,29 +66,52 @@ sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.co
 if config_json["NoNSFWChecker"]:
     setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
 
-if torch.backends.mps.is_available():
+if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
     sd_pipe = sd_pipe.to("mps")
 elif torch.cuda.is_available():
     sd_pipe = sd_pipe.to("cuda")
 
+GPT_SUCCESS = 0
+GPT_NORESULT = 1
+GPT_ERROR = 2
+
-def call_gpt(prompt : str):
+def CallOpenAIGPT(prompts : typing.List[str]):
     try:
-        res = openai.Completion.create(
-            model=GlobalData.OPENAI_MODEL,
-            prompt=prompt,
-            max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS,
-            temperature=GlobalData.OPENAI_MODEL_TEMPERATURE)
+        res = openai.ChatCompletion.create(
+            model=config_json["OpenAI-GPT"]["GPT-Model"],
+            messages=prompts
+        )
         if len(res["choices"]) > 0:
-            return res["choices"][0]["text"].strip()
+            return (GPT_SUCCESS, res["choices"][0]["message"]["content"].strip())
         else:
-            return ""
-    except:
-        return "上下文长度超出模型限制,请对我说\“重置上下文\",然后再试一次"
+            return (GPT_NORESULT, "")
+    except openai.InvalidRequestError as e:
+        return (GPT_ERROR, e)
+    except Exception as e:
+        return (GPT_ERROR, e)
+
+def CallChatGLM(msg, history : typing.List[str]):
+    try:
+        resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history)
+        return (GPT_SUCCESS, resp)
+    except Exception as e:
+        pass
+
+def add_context(uid : str, is_user : bool, msg : str):
+    if USE_OPENAIGPT:
+        GlobalData.context_for_users[uid].append({
+            "role" : "system",
+            "content" : msg
+        }
+        )
+    elif USE_CHATGLM:
+        GlobalData.context_for_users[uid].append(msg)
 
 
 @app.route("/chat_clear", methods=['POST'])
 def app_chat_clear():
     data = json.loads(flask.globals.request.get_data())
-    GlobalData.context_for_users[data["user_id"]] = ""
+    GlobalData.context_for_users[data["user_id"]] = []
     print(f"Cleared context for {data['user_id']}")
     return ""
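
CallOpenAIGPT forwards the per-user context list directly as messages. For reference, a sketch of the message shape accepted by openai.ChatCompletion.create in the 0.x openai SDK used here (the key is a placeholder; note that add_context above records every turn under the "system" role, while the API also accepts "user" and "assistant"):

```
import openai

openai.api_key = "sk-..."  # placeholder

# Each message is a dict with a role and the text content.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你好"},
]

res = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301", messages=messages)
# Same extraction path as CallOpenAIGPT above.
print(res["choices"][0]["message"]["content"].strip())
```
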
@@ -76,20 +119,25 @@ def app_chat_clear():
 def app_chat():
     data = json.loads(flask.globals.request.get_data())
     #print(data)
-    prompt = GlobalData.context_for_users.get(data["user_id"], "")
+    uid = data["user_id"]
 
     if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
         data["text"] += "。"
 
-    prompt += "\n" + data["text"]
-
-    if len(prompt) > 4000:
-        prompt = prompt[:4000]
-
-    resp = call_gpt(prompt=prompt)
-    GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
-
-    print(f"Prompt = {prompt}\nResponse = {resp}")
+    if USE_OPENAIGPT:
+        add_context(uid, True, data["text"])
+        prompt = GlobalData.context_for_users[uid]
+        resp = CallOpenAIGPT(prompt=prompt)
+        #GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
+        add_context(uid, False, resp)
+        print(f"Prompt = {prompt}\nResponse = {resp}")
+    elif USE_CHATGLM:
+        prompt = GlobalData.context_for_users[uid]
+        resp, _ = CallChatGLM(msg=data["text"], history=prompt)
+        add_context(uid, True, data["text"])
+        add_context(uid, False, resp)
+    else:
+        pass
 
     return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
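
The Go client exchanges JSON of exactly this shape with the Flask app. A hypothetical request against a locally running bot.py (the /chat route name is an assumption, since the @app.route decorator for app_chat falls outside the hunks shown; /chat_clear is visible above):

```
import json
import requests

BASE = "http://localhost:11111"  # the port bot.py listens on

# Assumed route name for app_chat; its decorator is not shown in this diff.
reply = requests.post(f"{BASE}/chat",
                      data=json.dumps({"user_id": "test-user", "text": "你好"}))
print(reply.json()["text"])

# Reset the stored context for this user (handled by app_chat_clear above).
requests.post(f"{BASE}/chat_clear", data=json.dumps({"user_id": "test-user"}))
```
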
@@ -167,14 +215,12 @@ def app_draw():
 def app_info():
     return "\n".join([f"GPT模型:{config_json['GPT-Model']}", f"Diffusion模型:{config_json['Diffusion-Model']}",
                       "默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20",
-                      f"使用半精度浮点数 : {'是' if config_json.get('UseFP16', True) else '否'}",
-                      f"屏蔽NSFW检查:{'是' if config_json['NoNSFWChecker'] else '否'}"])
+                      f"使用半精度浮点数 : {'是' if config_json['Diffusion'].get('UseFP16', True) else '否'}",
+                      f"屏蔽NSFW检查:{'是' if config_json['Diffusion']['NoNSFWChecker'] else '否'}"])
 
 if __name__ == "__main__":
     #openai.organization = GlobalData.OPENAI_ORGID
-    if len(GlobalData.OPENAI_APIKEY) == 0:
-        raise RuntimeError("Please set your OpenAI API Key in config.json")
-
-    openai.api_key = GlobalData.OPENAI_APIKEY
+    if USE_OPENAIGPT:
+        openai.api_key = GlobalData.OPENAI_APIKEY
 
     app.run(host="0.0.0.0", port=11111)

config.json (23 lines changed)

@@ -1,8 +1,19 @@
 {
-    "OpenAI-API-Key" : "",
-    "GPT-Model" : "text-davinci-003",
-    "Diffusion-Model" : "andite/anything-v4.0",
-    "DefaultDiffutionIterations" : 20,
-    "UseFP16" : true,
-    "NoNSFWChecker" : false
+    "OpenAI-GPT" : {
+        "Enable" : true,
+        "OpenAI-Key" : "Please write your openai api key",
+        "GPT-Model" : "gpt-3.5-turbo-0301",
+        "Temperature" : 0.7,
+        "MaxPromptTokens" : 2560,
+        "MaxTokens" : 4096
+    },
+    "ChatGLM" : {
+        "Enable" : false,
+        "GPT-Model" : "THUDM/chatglm2-6b"
+    },
+    "Diffusion" : {
+        "DiffusionModel" : "stabilityai/stable-diffusion-2-1",
+        "NoNSFWChecker" : true,
+        "UseFP16" : true
+    }
 }
go.mod (2 lines changed)

@@ -2,4 +2,4 @@ module SXYWechatBot
 
 go 1.19
 
-require github.com/eatmoreapple/openwechat v1.3.8
+require github.com/eatmoreapple/openwechat v1.4.3

go.sum (4 lines changed)

@@ -1,2 +1,2 @@
-github.com/eatmoreapple/openwechat v1.3.8 h1:dGQy/UeuSb7eVaCgJMoxM0fBG5DhboKfYk3g+FLZOWE=
-github.com/eatmoreapple/openwechat v1.3.8/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
+github.com/eatmoreapple/openwechat v1.4.3 h1:hpqR3M0c180GN5e6sfkqdTmna1+vnvohqv8LkS7MecI=
+github.com/eatmoreapple/openwechat v1.4.3/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ=

@@ -35,14 +35,6 @@ type GenerateImageRequest struct {
     Prompt string `json:"prompt"`
 }
 
-type GlobalConfig struct {
-    OpenAIKey string `json:"OpenAI-API-Key"`
-    GPTModel string `json:"GPT-Model"`
-    DiffusionModel string `json:"Diffusion-Model"`
-    DefaultDiffusionIteration int `json:"DefaultDiffutionIterations"`
-    UseFP16 bool `json:"UseFP16"`
-}
-
 func HttpPost(url string, data interface{}, timelim int) []byte {
     // timeout duration
     timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) // yes, there is a bug here, but it is precisely this bug that keeps things running normally!!!???