common update
This commit is contained in:
parent
0b9c7f1724
commit
5497ae3a96
|
@ -3,3 +3,5 @@ storage.json
|
||||||
myconf.json
|
myconf.json
|
||||||
latest*.png
|
latest*.png
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
anything-v4.0
|
||||||
|
chatglm2-6b
|
||||||
|
|
23
Readme.md
23
Readme.md
|
@ -26,7 +26,7 @@
|
||||||
python3需要再安装这些库,使用pip安装就可以:
|
python3需要再安装这些库,使用pip安装就可以:
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install torch flask openai transformers diffusers accelerate
|
pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels
|
||||||
```
|
```
|
||||||
|
|
||||||
当然如果使用cuda加速建议按照<a href="https://pytorch.org">pytorch官网</a>提供的方法安装支持cuda加速的torch版本。
|
当然如果使用cuda加速建议按照<a href="https://pytorch.org">pytorch官网</a>提供的方法安装支持cuda加速的torch版本。
|
||||||
|
@ -36,12 +36,20 @@ Apple Silicon的macbook上可以使用mps后端加速,我开发的时候使用
|
||||||
### 修改配置
|
### 修改配置
|
||||||
<p id="ch12"> </p>
|
<p id="ch12"> </p>
|
||||||
|
|
||||||
你需要有一个OpenAI账号,然后将API Key写到config.json的OpenAI-API-Key字段后,然后保存。
|
打开config.json进行修改
|
||||||
|
|
||||||
其余的配置通常按照默认的就可以,或者可以前往OpenAI官网查看其他可用的GPT模型,或者到huggingface上查看其他可用的Stable Diffusion模型。
|
如果你使用OpenAI的GPT:
|
||||||
|
|
||||||
|
默认的配置文件就是使用OpenAI的GPT的,你需要有一个OpenAI账号,然后将API Key写到config.json的OpenAI-API-Key字段后,然后保存。其余的配置通常按照默认的就可以,或者可以前往OpenAI官网查看其他可用的GPT模型,
|
||||||
|
|
||||||
|
如果你使用ChatGLM:
|
||||||
|
|
||||||
|
先把OpenAI-GPT的Enable改为false,然后再把ChatGLM的Enable改为true即可。
|
||||||
|
|
||||||
因为懒省事所以有一些参数是写死在代码里的(坏文明),也是可以调整的,比如超时时间可以在wechat_client.go的代码中修改,这样在生成高分辨率、迭代次数非常多的图片的时候留有更多的时间。总之就是代码太简单了,自己看着改一下就行了(就是作者懒)。还有bot.py运行的时候用WSGI什么的,也就是加两行代码。(懒+1)
|
因为懒省事所以有一些参数是写死在代码里的(坏文明),也是可以调整的,比如超时时间可以在wechat_client.go的代码中修改,这样在生成高分辨率、迭代次数非常多的图片的时候留有更多的时间。总之就是代码太简单了,自己看着改一下就行了(就是作者懒)。还有bot.py运行的时候用WSGI什么的,也就是加两行代码。(懒+1)
|
||||||
|
|
||||||
|
关于Diffusion模型:
|
||||||
|
|
||||||
UseFP16一般打开就可以了,能显著减少显存需求,并且对画质几乎没有影响。
|
UseFP16一般打开就可以了,能显著减少显存需求,并且对画质几乎没有影响。
|
||||||
|
|
||||||
NoNSFWChecker一般打开就可以了,用于过滤生成含有NSFW内容的图片比如涩图,被过滤的图片会变为纯黑色。
|
NoNSFWChecker一般打开就可以了,用于过滤生成含有NSFW内容的图片比如涩图,被过滤的图片会变为纯黑色。
|
||||||
|
@ -74,7 +82,7 @@ go run wechat_client.go
|
||||||
Diffusion推荐使用的模型:
|
Diffusion推荐使用的模型:
|
||||||
|
|
||||||
```
|
```
|
||||||
andite/anything-v4.0 : 二次元浓度很高,画人的水平不错
|
andite/anything-v4.0 : 二次元浓度很高,画人的水平不错。(目前在huggingface上该模型已被删除,如果有本地缓存的话还能找到)
|
||||||
stabilityai/stable-diffusion-2-1 : 比较通用,能生成各种图片,二次元风格真实风格都可以,但是画人的能力很差,经常出现崩坏的手,缺胳膊少腿等问题。。。
|
stabilityai/stable-diffusion-2-1 : 比较通用,能生成各种图片,二次元风格真实风格都可以,但是画人的能力很差,经常出现崩坏的手,缺胳膊少腿等问题。。。
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -137,6 +145,13 @@ low quality, dark, fuzzy, normal quality, ugly, twisted face, scary eyes, sexual
|
||||||
<table>
|
<table>
|
||||||
<tr> <th>版本</th> <th>日期</th> <th>说明</th> </tr>
|
<tr> <th>版本</th> <th>日期</th> <th>说明</th> </tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td> v1.2 </td>
|
||||||
|
<td> 2023.07.02 </td>
|
||||||
|
<td> 跟进OpenAI的更新,使用openai.ChatCompletion对话API而不是文本补全API <br> 支持清华大学的开源的<a href="https://github.com/THUDM/ChatGLM2-6B">ChatGLM2</a>作为GPT模型。<br>由于anything-v4.0模型在hunggingface上被删除,默认的Diffusion模型改为stabilityai/stable-diffusion-2-1 <br> 更新了所依赖的openwechat的版本到v1.4.3 </td>
|
||||||
|
|
||||||
|
</tr>
|
||||||
|
|
||||||
<tr>
|
<tr>
|
||||||
<td> v1.1 </td>
|
<td> v1.1 </td>
|
||||||
<td> 2023.02.07 </td>
|
<td> 2023.02.07 </td>
|
||||||
|
|
114
bot.py
114
bot.py
|
@ -2,9 +2,11 @@ import json
|
||||||
import openai
|
import openai
|
||||||
import re
|
import re
|
||||||
from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
|
from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
|
||||||
|
from transformers import AutoTokenizer, AutoModel
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
import flask
|
import flask
|
||||||
|
import typing
|
||||||
|
|
||||||
ps = argparse.ArgumentParser()
|
ps = argparse.ArgumentParser()
|
||||||
ps.add_argument("--config", default="config.json", help="Configuration file")
|
ps.add_argument("--config", default="config.json", help="Configuration file")
|
||||||
|
@ -15,10 +17,12 @@ with open(args.config) as f:
|
||||||
|
|
||||||
class GlobalData:
|
class GlobalData:
|
||||||
# OPENAI_ORGID = config_json[""]
|
# OPENAI_ORGID = config_json[""]
|
||||||
OPENAI_APIKEY = config_json["OpenAI-API-Key"]
|
OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"]
|
||||||
OPENAI_MODEL = config_json["GPT-Model"]
|
OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"]
|
||||||
OPENAI_MODEL_TEMPERATURE = 0.66
|
OPENAI_MODEL_TEMPERATURE = int(config_json["OpenAI-GPT"]["Temperature"])
|
||||||
OPENAI_MODEL_MAXTOKENS = 2048
|
OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"]))
|
||||||
|
|
||||||
|
CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"]
|
||||||
|
|
||||||
context_for_users = {}
|
context_for_users = {}
|
||||||
context_for_groups = {}
|
context_for_groups = {}
|
||||||
|
@ -28,8 +32,24 @@ class GlobalData:
|
||||||
GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
|
GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+")
|
||||||
GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数
|
GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数
|
||||||
|
|
||||||
app = flask.Flask(__name__)
|
USE_OPENAIGPT = False
|
||||||
|
USE_CHATGLM = False
|
||||||
|
|
||||||
|
if config_json["OpenAI-GPT"]["Enable"]:
|
||||||
|
print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).")
|
||||||
|
USE_OPENAIGPT = True
|
||||||
|
elif config_json["ChatGLM"]["Enable"]:
|
||||||
|
print(f"Use ChatGLM({GlobalData.CHATGLM_MODEL}) as GPT-Model.")
|
||||||
|
chatglm_tokenizer = AutoTokenizer.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
|
||||||
|
chatglm_model = AutoModel.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True)
|
||||||
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
chatglm_model = chatglm_model.to('mps')
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
chatglm_model = chatglm_model.to('cuda')
|
||||||
|
chatglm_model = chatglm_model.eval()
|
||||||
|
USE_CHATGLM = True
|
||||||
|
|
||||||
|
app = flask.Flask(__name__)
|
||||||
|
|
||||||
# 这个用于放行生成的任何图片,替换掉默认的NSFW检查器,公共场合慎重使用
|
# 这个用于放行生成的任何图片,替换掉默认的NSFW检查器,公共场合慎重使用
|
||||||
def run_safety_nochecker(image, device, dtype):
|
def run_safety_nochecker(image, device, dtype):
|
||||||
|
@ -37,7 +57,7 @@ def run_safety_nochecker(image, device, dtype):
|
||||||
return image, None
|
return image, None
|
||||||
|
|
||||||
sd_args = {
|
sd_args = {
|
||||||
"pretrained_model_name_or_path" : config_json["Diffusion-Model"],
|
"pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"],
|
||||||
"torch_dtype" : (torch.float16 if config_json.get("UseFP16", True) else torch.float32)
|
"torch_dtype" : (torch.float16 if config_json.get("UseFP16", True) else torch.float32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,29 +66,52 @@ sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.co
|
||||||
if config_json["NoNSFWChecker"]:
|
if config_json["NoNSFWChecker"]:
|
||||||
setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
|
setattr(sd_pipe, "run_safety_checker", run_safety_nochecker)
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
sd_pipe = sd_pipe.to("mps")
|
sd_pipe = sd_pipe.to("mps")
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
sd_pipe = sd_pipe.to("cuda")
|
sd_pipe = sd_pipe.to("cuda")
|
||||||
|
|
||||||
def call_gpt(prompt : str):
|
GPT_SUCCESS = 0
|
||||||
|
GPT_NORESULT = 1
|
||||||
|
GPT_ERROR = 2
|
||||||
|
|
||||||
|
def CallOpenAIGPT(prompts : typing.List[str]):
|
||||||
try:
|
try:
|
||||||
res = openai.Completion.create(
|
res = openai.ChatCompletion.create(
|
||||||
model=GlobalData.OPENAI_MODEL,
|
model=config_json["OpenAI-GPT"]["GPT-Model"],
|
||||||
prompt=prompt,
|
messages=prompts
|
||||||
max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS,
|
)
|
||||||
temperature=GlobalData.OPENAI_MODEL_TEMPERATURE)
|
|
||||||
if len(res["choices"]) > 0:
|
if len(res["choices"]) > 0:
|
||||||
return res["choices"][0]["text"].strip()
|
return (GPT_SUCCESS, res["choices"][0]["message"]["content"].strip())
|
||||||
else:
|
else:
|
||||||
return ""
|
return (GPT_NORESULT, "")
|
||||||
except:
|
except openai.InvalidRequestError as e:
|
||||||
return "上下文长度超出模型限制,请对我说\“重置上下文\",然后再试一次"
|
return (GPT_ERROR, e)
|
||||||
|
except Exception as e:
|
||||||
|
return (GPT_ERROR, e)
|
||||||
|
|
||||||
|
def CallChatGLM(msg, history : typing.List[str]):
|
||||||
|
try:
|
||||||
|
resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history)
|
||||||
|
return (GPT_SUCCESS, resp)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_context(uid : str, is_user : bool, msg : str):
|
||||||
|
if USE_OPENAIGPT:
|
||||||
|
GlobalData.context_for_users[uid].append({
|
||||||
|
"role" : "system",
|
||||||
|
"content" : msg
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif USE_CHATGLM:
|
||||||
|
GlobalData.context_for_users[uid].append(msg)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/chat_clear", methods=['POST'])
|
@app.route("/chat_clear", methods=['POST'])
|
||||||
def app_chat_clear():
|
def app_chat_clear():
|
||||||
data = json.loads(flask.globals.request.get_data())
|
data = json.loads(flask.globals.request.get_data())
|
||||||
GlobalData.context_for_users[data["user_id"]] = ""
|
GlobalData.context_for_users[data["user_id"]] = []
|
||||||
print(f"Cleared context for {data['user_id']}")
|
print(f"Cleared context for {data['user_id']}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@ -76,20 +119,25 @@ def app_chat_clear():
|
||||||
def app_chat():
|
def app_chat():
|
||||||
data = json.loads(flask.globals.request.get_data())
|
data = json.loads(flask.globals.request.get_data())
|
||||||
#print(data)
|
#print(data)
|
||||||
prompt = GlobalData.context_for_users.get(data["user_id"], "")
|
uid = data["user_id"]
|
||||||
|
|
||||||
if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
|
if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']:
|
||||||
data["text"] += "。"
|
data["text"] += "。"
|
||||||
|
|
||||||
prompt += "\n" + data["text"]
|
if USE_OPENAIGPT:
|
||||||
|
add_context(uid, True, data["text"])
|
||||||
if len(prompt) > 4000:
|
prompt = GlobalData.context_for_users[uid]
|
||||||
prompt = prompt[:4000]
|
resp = CallOpenAIGPT(prompt=prompt)
|
||||||
|
#GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
|
||||||
resp = call_gpt(prompt=prompt)
|
add_context(uid, False, resp)
|
||||||
GlobalData.context_for_users[data["user_id"]] = (prompt + resp)
|
print(f"Prompt = {prompt}\nResponse = {resp}")
|
||||||
|
elif USE_CHATGLM:
|
||||||
print(f"Prompt = {prompt}\nResponse = {resp}")
|
prompt = GlobalData.context_for_users[uid]
|
||||||
|
resp, _ = CallChatGLM(msg=data["text"], history=prompt)
|
||||||
|
add_context(uid, True, data["text"])
|
||||||
|
add_context(uid, False, resp)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
|
return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
|
||||||
|
|
||||||
|
@ -167,14 +215,12 @@ def app_draw():
|
||||||
def app_info():
|
def app_info():
|
||||||
return "\n".join([f"GPT模型:{config_json['GPT-Model']}", f"Diffusion模型:{config_json['Diffusion-Model']}",
|
return "\n".join([f"GPT模型:{config_json['GPT-Model']}", f"Diffusion模型:{config_json['Diffusion-Model']}",
|
||||||
"默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20",
|
"默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20",
|
||||||
f"使用半精度浮点数 : {'是' if config_json.get('UseFP16', True) else '否'}",
|
f"使用半精度浮点数 : {'是' if config_json['Diffusion'].get('UseFP16', True) else '否'}",
|
||||||
f"屏蔽NSFW检查:{'是' if config_json['NoNSFWChecker'] else '否'}"])
|
f"屏蔽NSFW检查:{'是' if config_json['Diffusion']['NoNSFWChecker'] else '否'}"])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#openai.organization = GlobalData.OPENAI_ORGID
|
|
||||||
if len(GlobalData.OPENAI_APIKEY) == 0:
|
|
||||||
raise RuntimeError("Please set your OpenAI API Key in config.json")
|
|
||||||
|
|
||||||
openai.api_key = GlobalData.OPENAI_APIKEY
|
if USE_OPENAIGPT:
|
||||||
|
openai.api_key = GlobalData.OPENAI_APIKEY
|
||||||
|
|
||||||
app.run(host="0.0.0.0", port=11111)
|
app.run(host="0.0.0.0", port=11111)
|
23
config.json
23
config.json
|
@ -1,8 +1,19 @@
|
||||||
{
|
{
|
||||||
"OpenAI-API-Key" : "",
|
"OpenAI-GPT" : {
|
||||||
"GPT-Model" : "text-davinci-003",
|
"Enable" : true,
|
||||||
"Diffusion-Model" : "andite/anything-v4.0",
|
"OpenAI-Key" : "Please write your openai api key",
|
||||||
"DefaultDiffutionIterations" : 20,
|
"GPT-Model" : "gpt-3.5-turbo-0301",
|
||||||
"UseFP16" : true,
|
"Temperature" : 0.7,
|
||||||
"NoNSFWChecker" : false
|
"MaxPromptTokens" : 2560,
|
||||||
|
"MaxTokens" : 4096
|
||||||
|
},
|
||||||
|
"ChatGLM" : {
|
||||||
|
"Enable" : false,
|
||||||
|
"GPT-Model" : "THUDM/chatglm2-6b"
|
||||||
|
},
|
||||||
|
"Diffusion" : {
|
||||||
|
"DiffusionModel" : "stabilityai/stable-diffusion-2-1",
|
||||||
|
"NoNSFWChecker" : true,
|
||||||
|
"UseFP16" : true
|
||||||
|
}
|
||||||
}
|
}
|
2
go.mod
2
go.mod
|
@ -2,4 +2,4 @@ module SXYWechatBot
|
||||||
|
|
||||||
go 1.19
|
go 1.19
|
||||||
|
|
||||||
require github.com/eatmoreapple/openwechat v1.3.8
|
require github.com/eatmoreapple/openwechat v1.4.3
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -1,2 +1,2 @@
|
||||||
github.com/eatmoreapple/openwechat v1.3.8 h1:dGQy/UeuSb7eVaCgJMoxM0fBG5DhboKfYk3g+FLZOWE=
|
github.com/eatmoreapple/openwechat v1.4.3 h1:hpqR3M0c180GN5e6sfkqdTmna1+vnvohqv8LkS7MecI=
|
||||||
github.com/eatmoreapple/openwechat v1.3.8/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
github.com/eatmoreapple/openwechat v1.4.3/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ=
|
||||||
|
|
|
@ -35,14 +35,6 @@ type GenerateImageRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GlobalConfig struct {
|
|
||||||
OpenAIKey string `json:"OpenAI-API-Key"`
|
|
||||||
GPTModel string `json:"GPT-Model"`
|
|
||||||
DiffusionModel string `json:"Diffusion-Model"`
|
|
||||||
DefaultDiffusionIteration int `json:"DefaultDiffutionIterations"`
|
|
||||||
UseFP16 bool `json:"UseFP16"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func HttpPost(url string, data interface{}, timelim int) []byte {
|
func HttpPost(url string, data interface{}, timelim int) []byte {
|
||||||
// 超时时间
|
// 超时时间
|
||||||
timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) //是的,这里有个bug,但是这里就是靠这个bug正常运行的!!!???
|
timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) //是的,这里有个bug,但是这里就是靠这个bug正常运行的!!!???
|
||||||
|
|
Loading…
Reference in New Issue