common update

This commit is contained in:
HfCloud 2023-07-02 15:21:20 +08:00
parent 0b9c7f1724
commit 5497ae3a96
7 changed files with 123 additions and 57 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ storage.json
myconf.json
latest*.png
.DS_Store
anything-v4.0
chatglm2-6b

View File

@ -26,7 +26,7 @@
python3需要再安装这些库使用pip安装就可以
```
pip install torch flask openai transformers diffusers accelerate
pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels
```
当然如果使用cuda加速建议按照<a href="https://pytorch.org">pytorch官网</a>提供的方法安装支持cuda加速的torch版本。
@ -36,12 +36,20 @@ Apple Silicon的macbook上可以使用mps后端加速我开发的时候使用
### 修改配置
<p id="ch12"> </p>
你需要有一个OpenAI账号然后将API Key写到config.json的OpenAI-API-Key字段后然后保存。
打开config.json进行修改
其余的配置通常按照默认的就可以或者可以前往OpenAI官网查看其他可用的GPT模型或者到huggingface上查看其他可用的Stable Diffusion模型。
如果你使用OpenAI的GPT
默认的配置文件就是使用OpenAI的GPT的你需要有一个OpenAI账号然后将API Key写到config.json的OpenAI-API-Key字段后然后保存。其余的配置通常按照默认的就可以或者可以前往OpenAI官网查看其他可用的GPT模型
如果你使用ChatGLM:
先把OpenAI-GPT的Enable改为false然后再把ChatGLM的Enable改为true即可。
因为懒省事所以有一些参数是写死在代码里的坏文明也是可以调整的比如超时时间可以在wechat_client.go的代码中修改这样在生成高分辨率、迭代次数非常多的图片的时候留有更多的时间。总之就是代码太简单了自己看着改一下就行了就是作者懒。还有bot.py运行的时候用WSGI什么的也就是加两行代码。懒+1
关于Diffusion模型
UseFP16一般打开就可以了能显著减少显存需求并且对画质几乎没有影响。
NoNSFWChecker一般打开就可以了用于过滤生成含有NSFW内容的图片比如涩图被过滤的图片会变为纯黑色。
@ -74,7 +82,7 @@ go run wechat_client.go
Diffusion推荐使用的模型
```
andite/anything-v4.0 : 二次元浓度很高,画人的水平不错
andite/anything-v4.0 : 二次元浓度很高,画人的水平不错目前在huggingface上该模型已被删除如果有本地缓存的话还能找到
stabilityai/stable-diffusion-2-1 : 比较通用,能生成各种图片,二次元风格真实风格都可以,但是画人的能力很差,经常出现崩坏的手,缺胳膊少腿等问题。。。
```
@ -137,6 +145,13 @@ low quality, dark, fuzzy, normal quality, ugly, twisted face, scary eyes, sexual
<table>
<tr> <th>版本</th> <th>日期</th> <th>说明</th> </tr>
<tr>
<td> v1.2 </td>
<td> 2023.07.02 </td>
<td> 跟进OpenAI的更新使用openai.ChatCompletion对话API而不是文本补全API <br> 支持清华大学的开源的<a href="https://github.com/THUDM/ChatGLM2-6B">ChatGLM2</a>作为GPT模型。<br>由于anything-v4.0模型在hunggingface上被删除默认的Diffusion模型改为stabilityai/stable-diffusion-2-1 <br> 更新了所依赖的openwechat的版本到v1.4.3 </td>
</tr>
<tr>
<td> v1.1 </td>
<td> 2023.02.07 </td>

118
bot.py
View File

@ -2,9 +2,11 @@ import json
import openai
import re
from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
from transformers import AutoTokenizer, AutoModel
import torch
import argparse
import flask
import typing
ps = argparse.ArgumentParser()
ps.add_argument("--config", default="config.json", help="Configuration file")
@ -15,10 +17,12 @@ with open(args.config) as f:
class GlobalData:
# OPENAI_ORGID = config_json[""]
OPENAI_APIKEY = config_json["OpenAI-API-Key"]
OPENAI_MODEL = config_json["GPT-Model"]
OPENAI_MODEL_TEMPERATURE = 0.66
OPENAI_MODEL_MAXTOKENS = 2048
OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"]
OPENAI_MODEL_TEMPERATURE = int(config_json["OpenAI-GPT"]["Temperature"])
OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"]))
CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"]
context_for_users = {}
context_for_groups = {}
@ -27,17 +31,33 @@ class GlobalData:
GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|)([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|)")
GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数
USE_OPENAIGPT = False
USE_CHATGLM = False
if config_json["OpenAI-GPT"]["Enable"]:
print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).")
USE_OPENAIGPT = True
elif config_json["ChatGLM"]["Enable"]:
print(f"Use ChatGLM({GlobalData.CHATGLM_MODEL}) as GPT-Model.")
chatglm_tokenizer = AutoTokenizer.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
chatglm_model = AutoModel.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
chatglm_model = chatglm_model.to('mps')
elif torch.cuda.is_available():
chatglm_model = chatglm_model.to('cuda')
chatglm_model = chatglm_model.eval()
USE_CHATGLM = True
app = flask.Flask(__name__)
# 这个用于放行生成的任何图片替换掉默认的NSFW检查器公共场合慎重使用
def run_safety_nochecker(image, device, dtype):
print("警告:屏蔽了内容安全性检查,可能会产生有害内容")
return image, None
sd_args = {
"pretrained_model_name_or_path" : config_json["Diffusion-Model"],
"pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
"torch_dtype" : (torch.float16 if config_json.get("UseFP16", True) else torch.float32)
}
@ -46,29 +66,52 @@ sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.co
if config_json["NoNSFWChecker"]:
setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
if torch.backends.mps.is_available():
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
sd_pipe = sd_pipe.to("mps")
elif torch.cuda.is_available():
sd_pipe = sd_pipe.to("cuda")
GPT_SUCCESS = 0
GPT_NORESULT = 1
GPT_ERROR = 2
def call_gpt(prompt : str):
def CallOpenAIGPT(prompts : typing.List[str]):
try:
res = openai.Completion.create(
model=GlobalData.OPENAI_MODEL,
prompt=prompt,
max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS,
temperature=GlobalData.OPENAI_MODEL_TEMPERATURE)
res = openai.ChatCompletion.create(
model=config_json["OpenAI-GPT"]["GPT-Model"],
messages=prompts
)
if len(res["choices"]) > 0:
return res["choices"][0]["text"].strip()
return (GPT_SUCCESS, res["choices"][0]["message"]["content"].strip())
else:
return ""
except:
return "上下文长度超出模型限制,请对我说\“重置上下文\",然后再试一次"
return (GPT_NORESULT, "")
except openai.InvalidRequestError as e:
return (GPT_ERROR, e)
except Exception as e:
return (GPT_ERROR, e)
def CallChatGLM(msg, history : typing.List[str]):
try:
resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history)
return (GPT_SUCCESS, resp)
except Exception as e:
pass
def add_context(uid : str, is_user : bool, msg : str):
if USE_OPENAIGPT:
GlobalData.context_for_users[uid].append({
"role" : "system",
"content" : msg
}
)
elif USE_CHATGLM:
GlobalData.context_for_users[uid].append(msg)
@app.route("/chat_clear", methods=['POST'])
def app_chat_clear():
data = json.loads(flask.globals.request.get_data())
GlobalData.context_for_users[data["user_id"]] = ""
GlobalData.context_for_users[data["user_id"]] = []
print(f"Cleared context for {data['user_id']}")
return ""
@ -76,20 +119,25 @@ def app_chat_clear():
def app_chat():
data = json.loads(flask.globals.request.get_data())
#print(data)
prompt = GlobalData.context_for_users.get(data["user_id"], "")
uid = data["user_id"]
if not data["text"][-1] in ['?', '', '.', '', ',', '', '!', '']:
data["text"] += ""
prompt += "\n" + data["text"]
if len(prompt) > 4000:
prompt = prompt[:4000]
resp = call_gpt(prompt=prompt)
GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
print(f"Prompt = {prompt}\nResponse = {resp}")
if USE_OPENAIGPT:
add_context(uid, True, data["text"])
prompt = GlobalData.context_for_users[uid]
resp = CallOpenAIGPT(prompt=prompt)
#GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
add_context(uid, False, resp)
print(f"Prompt = {prompt}\nResponse = {resp}")
elif USE_CHATGLM:
prompt = GlobalData.context_for_users[uid]
resp, _ = CallChatGLM(msg=data["text"], history=prompt)
add_context(uid, True, data["text"])
add_context(uid, False, resp)
else:
pass
return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
@ -167,14 +215,12 @@ def app_draw():
def app_info():
return "\n".join([f"GPT模型{config_json['GPT-Model']}", f"Diffusion模型{config_json['Diffusion-Model']}",
"默认图片规格768x768 RGB三通道", "Diffusion默认迭代轮数20",
f"使用半精度浮点数 : {'' if config_json.get('UseFP16', True) else ''}",
f"屏蔽NSFW检查{'' if config_json['NoNSFWChecker'] else ''}"])
f"使用半精度浮点数 : {'' if config_json['Diffusion'].get('UseFP16', True) else ''}",
f"屏蔽NSFW检查{'' if config_json['Diffusion']['NoNSFWChecker'] else ''}"])
if __name__ == "__main__":
#openai.organization = GlobalData.OPENAI_ORGID
if len(GlobalData.OPENAI_APIKEY) == 0:
raise RuntimeError("Please set your OpenAI API Key in config.json")
openai.api_key = GlobalData.OPENAI_APIKEY
if USE_OPENAIGPT:
openai.api_key = GlobalData.OPENAI_APIKEY
app.run(host="0.0.0.0", port=11111)

View File

@ -1,8 +1,19 @@
{
"OpenAI-API-Key" : "",
"GPT-Model" : "text-davinci-003",
"Diffusion-Model" : "andite/anything-v4.0",
"DefaultDiffutionIterations" : 20,
"UseFP16" : true,
"NoNSFWChecker" : false
"OpenAI-GPT" : {
"Enable" : true,
"OpenAI-Key" : "Please write your openai api key",
"GPT-Model" : "gpt-3.5-turbo-0301",
"Temperature" : 0.7,
"MaxPromptTokens" : 2560,
"MaxTokens" : 4096
},
"ChatGLM" : {
"Enable" : false,
"GPT-Model" : "THUDM/chatglm2-6b"
},
"Diffusion" : {
"DiffusionModel" : "stabilityai/stable-diffusion-2-1",
"NoNSFWChecker" : true,
"UseFP16" : true
}
}

2
go.mod
View File

@ -2,4 +2,4 @@ module SXYWechatBot
go 1.19
require github.com/eatmoreapple/openwechat v1.3.8
require github.com/eatmoreapple/openwechat v1.4.3

4
go.sum
View File

@ -1,2 +1,2 @@
github.com/eatmoreapple/openwechat v1.3.8 h1:dGQy/UeuSb7eVaCgJMoxM0fBG5DhboKfYk3g+FLZOWE=
github.com/eatmoreapple/openwechat v1.3.8/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
github.com/eatmoreapple/openwechat v1.4.3 h1:hpqR3M0c180GN5e6sfkqdTmna1+vnvohqv8LkS7MecI=
github.com/eatmoreapple/openwechat v1.4.3/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ=

View File

@ -35,14 +35,6 @@ type GenerateImageRequest struct {
Prompt string `json:"prompt"`
}
type GlobalConfig struct {
OpenAIKey string `json:"OpenAI-API-Key"`
GPTModel string `json:"GPT-Model"`
DiffusionModel string `json:"Diffusion-Model"`
DefaultDiffusionIteration int `json:"DefaultDiffutionIterations"`
UseFP16 bool `json:"UseFP16"`
}
func HttpPost(url string, data interface{}, timelim int) []byte {
// 超时时间
timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) //是的这里有个bug但是这里就是靠这个bug正常运行的