diff --git a/.gitignore b/.gitignore index 87d1efe..1c086d9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ storage.json myconf.json latest*.png .DS_Store +anything-v4.0 +chatglm2-6b diff --git a/Readme.md b/Readme.md index 56767ad..5ed5abc 100644 --- a/Readme.md +++ b/Readme.md @@ -26,7 +26,7 @@ python3需要再安装这些库,使用pip安装就可以: ``` -pip install torch flask openai transformers diffusers accelerate +pip install torch flask openai transformers diffusers accelerate sentencepiece cpm_kernels ``` 当然如果使用cuda加速建议按照pytorch官网提供的方法安装支持cuda加速的torch版本。 @@ -36,12 +36,20 @@ Apple Silicon的macbook上可以使用mps后端加速,我开发的时候使用 ### 修改配置

-你需要有一个OpenAI账号,然后将API Key写到config.json的OpenAI-API-Key字段后,然后保存。 +打开config.json进行修改 -其余的配置通常按照默认的就可以,或者可以前往OpenAI官网查看其他可用的GPT模型,或者到huggingface上查看其他可用的Stable Diffusion模型。 +如果你使用OpenAI的GPT: + +默认的配置文件就是使用OpenAI的GPT的,你需要有一个OpenAI账号,然后将API Key写到config.json的OpenAI-API-Key字段后,然后保存。其余的配置通常按照默认的就可以,或者可以前往OpenAI官网查看其他可用的GPT模型, + +如果你使用ChatGLM: + +先把OpenAI-GPT的Enable改为false,然后再把ChatGLM的Enable改为true即可。 因为懒省事所以有一些参数是写死在代码里的(坏文明),也是可以调整的,比如超时时间可以在wechat_client.go的代码中修改,这样在生成高分辨率、迭代次数非常多的图片的时候留有更多的时间。总之就是代码太简单了,自己看着改一下就行了(就是作者懒)。还有bot.py运行的时候用WSGI什么的,也就是加两行代码。(懒+1) +关于Diffusion模型: + UseFP16一般打开就可以了,能显著减少显存需求,并且对画质几乎没有影响。 NoNSFWChecker一般打开就可以了,用于过滤生成含有NSFW内容的图片比如涩图,被过滤的图片会变为纯黑色。 @@ -74,7 +82,7 @@ go run wechat_client.go Diffusion推荐使用的模型: ``` -andite/anything-v4.0 : 二次元浓度很高,画人的水平不错 +andite/anything-v4.0 : 二次元浓度很高,画人的水平不错。(目前在huggingface上该模型已被删除,如果有本地缓存的话还能找到) stabilityai/stable-diffusion-2-1 : 比较通用,能生成各种图片,二次元风格真实风格都可以,但是画人的能力很差,经常出现崩坏的手,缺胳膊少腿等问题。。。 ``` @@ -137,6 +145,13 @@ low quality, dark, fuzzy, normal quality, ugly, twisted face, scary eyes, sexual + + + + + + + diff --git a/bot.py b/bot.py index 16e8982..95a7b52 100644 --- a/bot.py +++ b/bot.py @@ -2,9 +2,11 @@ import json import openai import re from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler +from transformers import AutoTokenizer, AutoModel import torch import argparse import flask +import typing ps = argparse.ArgumentParser() ps.add_argument("--config", default="config.json", help="Configuration file") @@ -15,10 +17,12 @@ with open(args.config) as f: class GlobalData: # OPENAI_ORGID = config_json[""] - OPENAI_APIKEY = config_json["OpenAI-API-Key"] - OPENAI_MODEL = config_json["GPT-Model"] - OPENAI_MODEL_TEMPERATURE = 0.66 - OPENAI_MODEL_MAXTOKENS = 2048 + OPENAI_APIKEY = config_json["OpenAI-GPT"]["OpenAI-Key"] + OPENAI_MODEL = config_json["OpenAI-GPT"]["GPT-Model"] + OPENAI_MODEL_TEMPERATURE = int(config_json["OpenAI-GPT"]["Temperature"]) + OPENAI_MODEL_MAXTOKENS = min(2048, int(config_json["OpenAI-GPT"]["MaxTokens"])) + + CHATGLM_MODEL = config_json["ChatGLM"]["GPT-Model"] context_for_users = {} context_for_groups = {} @@ -27,17 +31,33 @@ class GlobalData: GENERATE_PICTURE_ARG_PAT2 = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))") GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+") GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数 - + +USE_OPENAIGPT = False +USE_CHATGLM = False + +if config_json["OpenAI-GPT"]["Enable"]: + print(f"Use OpenAI GPT Model({GlobalData.OPENAI_MODEL}).") + USE_OPENAIGPT = True +elif config_json["ChatGLM"]["Enable"]: + print(f"Use ChatGLM({GlobalData.CHATGLM_MODEL}) as GPT-Model.") + chatglm_tokenizer = AutoTokenizer.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True) + chatglm_model = AutoModel.from_pretrained(GlobalData.CHATGLM_MODEL, trust_remote_code=True) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + chatglm_model = chatglm_model.to('mps') + elif torch.cuda.is_available(): + chatglm_model = chatglm_model.to('cuda') + chatglm_model = chatglm_model.eval() + USE_CHATGLM = True + app = flask.Flask(__name__) - # 这个用于放行生成的任何图片,替换掉默认的NSFW检查器,公共场合慎重使用 def run_safety_nochecker(image, device, dtype): print("警告:屏蔽了内容安全性检查,可能会产生有害内容") return image, None sd_args = { - "pretrained_model_name_or_path" : config_json["Diffusion-Model"], + "pretrained_model_name_or_path" : config_json["Diffusion"]["Diffusion-Model"], "torch_dtype" : (torch.float16 if config_json.get("UseFP16", True) else torch.float32) } @@ -46,29 +66,52 @@ sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.co if config_json["NoNSFWChecker"]: setattr(sd_pipe, "run_safety_checker", run_safety_nochecker) -if torch.backends.mps.is_available(): +if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): sd_pipe = sd_pipe.to("mps") elif torch.cuda.is_available(): sd_pipe = sd_pipe.to("cuda") + +GPT_SUCCESS = 0 +GPT_NORESULT = 1 +GPT_ERROR = 2 -def call_gpt(prompt : str): +def CallOpenAIGPT(prompts : typing.List[str]): try: - res = openai.Completion.create( - model=GlobalData.OPENAI_MODEL, - prompt=prompt, - max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS, - temperature=GlobalData.OPENAI_MODEL_TEMPERATURE) + res = openai.ChatCompletion.create( + model=config_json["OpenAI-GPT"]["GPT-Model"], + messages=prompts + ) if len(res["choices"]) > 0: - return res["choices"][0]["text"].strip() + return (GPT_SUCCESS, res["choices"][0]["message"]["content"].strip()) else: - return "" - except: - return "上下文长度超出模型限制,请对我说\“重置上下文\",然后再试一次" + return (GPT_NORESULT, "") + except openai.InvalidRequestError as e: + return (GPT_ERROR, e) + except Exception as e: + return (GPT_ERROR, e) + +def CallChatGLM(msg, history : typing.List[str]): + try: + resp, _ = chatglm_model.chat(chatglm_tokenizer, msg, history) + return (GPT_SUCCESS, resp) + except Exception as e: + pass + +def add_context(uid : str, is_user : bool, msg : str): + if USE_OPENAIGPT: + GlobalData.context_for_users[uid].append({ + "role" : "system", + "content" : msg + } + ) + elif USE_CHATGLM: + GlobalData.context_for_users[uid].append(msg) + @app.route("/chat_clear", methods=['POST']) def app_chat_clear(): data = json.loads(flask.globals.request.get_data()) - GlobalData.context_for_users[data["user_id"]] = "" + GlobalData.context_for_users[data["user_id"]] = [] print(f"Cleared context for {data['user_id']}") return "" @@ -76,20 +119,25 @@ def app_chat_clear(): def app_chat(): data = json.loads(flask.globals.request.get_data()) #print(data) - prompt = GlobalData.context_for_users.get(data["user_id"], "") + uid = data["user_id"] if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']: data["text"] += "。" - - prompt += "\n" + data["text"] - - if len(prompt) > 4000: - prompt = prompt[:4000] - - resp = call_gpt(prompt=prompt) - GlobalData.context_for_users[data["user_id"]] = (prompt + resp) - - print(f"Prompt = {prompt}\nResponse = {resp}") + + if USE_OPENAIGPT: + add_context(uid, True, data["text"]) + prompt = GlobalData.context_for_users[uid] + resp = CallOpenAIGPT(prompt=prompt) + #GlobalData.context_for_users[data["user_id"]] = (prompt + resp) + add_context(uid, False, resp) + print(f"Prompt = {prompt}\nResponse = {resp}") + elif USE_CHATGLM: + prompt = GlobalData.context_for_users[uid] + resp, _ = CallChatGLM(msg=data["text"], history=prompt) + add_context(uid, True, data["text"]) + add_context(uid, False, resp) + else: + pass return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False}) @@ -167,14 +215,12 @@ def app_draw(): def app_info(): return "\n".join([f"GPT模型:{config_json['GPT-Model']}", f"Diffusion模型:{config_json['Diffusion-Model']}", "默认图片规格:768x768 RGB三通道", "Diffusion默认迭代轮数:20", - f"使用半精度浮点数 : {'是' if config_json.get('UseFP16', True) else '否'}", - f"屏蔽NSFW检查:{'是' if config_json['NoNSFWChecker'] else '否'}"]) + f"使用半精度浮点数 : {'是' if config_json['Diffusion'].get('UseFP16', True) else '否'}", + f"屏蔽NSFW检查:{'是' if config_json['Diffusion']['NoNSFWChecker'] else '否'}"]) if __name__ == "__main__": - #openai.organization = GlobalData.OPENAI_ORGID - if len(GlobalData.OPENAI_APIKEY) == 0: - raise RuntimeError("Please set your OpenAI API Key in config.json") - openai.api_key = GlobalData.OPENAI_APIKEY + if USE_OPENAIGPT: + openai.api_key = GlobalData.OPENAI_APIKEY app.run(host="0.0.0.0", port=11111) \ No newline at end of file diff --git a/config.json b/config.json index 1e67a8b..e324732 100644 --- a/config.json +++ b/config.json @@ -1,8 +1,19 @@ { - "OpenAI-API-Key" : "", - "GPT-Model" : "text-davinci-003", - "Diffusion-Model" : "andite/anything-v4.0", - "DefaultDiffutionIterations" : 20, - "UseFP16" : true, - "NoNSFWChecker" : false + "OpenAI-GPT" : { + "Enable" : true, + "OpenAI-Key" : "Please write your openai api key", + "GPT-Model" : "gpt-3.5-turbo-0301", + "Temperature" : 0.7, + "MaxPromptTokens" : 2560, + "MaxTokens" : 4096 + }, + "ChatGLM" : { + "Enable" : false, + "GPT-Model" : "THUDM/chatglm2-6b" + }, + "Diffusion" : { + "DiffusionModel" : "stabilityai/stable-diffusion-2-1", + "NoNSFWChecker" : true, + "UseFP16" : true + } } \ No newline at end of file diff --git a/go.mod b/go.mod index 4842a09..58038e8 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module SXYWechatBot go 1.19 -require github.com/eatmoreapple/openwechat v1.3.8 +require github.com/eatmoreapple/openwechat v1.4.3 diff --git a/go.sum b/go.sum index 5081503..4f3f891 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/eatmoreapple/openwechat v1.3.8 h1:dGQy/UeuSb7eVaCgJMoxM0fBG5DhboKfYk3g+FLZOWE= -github.com/eatmoreapple/openwechat v1.3.8/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= +github.com/eatmoreapple/openwechat v1.4.3 h1:hpqR3M0c180GN5e6sfkqdTmna1+vnvohqv8LkS7MecI= +github.com/eatmoreapple/openwechat v1.4.3/go.mod h1:ZxMcq7IpVWVU9JG7ERjExnm5M8/AQ6yZTtX30K3rwRQ= diff --git a/wechat_client.go b/wechat_client.go index e4738ba..096938b 100644 --- a/wechat_client.go +++ b/wechat_client.go @@ -35,14 +35,6 @@ type GenerateImageRequest struct { Prompt string `json:"prompt"` } -type GlobalConfig struct { - OpenAIKey string `json:"OpenAI-API-Key"` - GPTModel string `json:"GPT-Model"` - DiffusionModel string `json:"Diffusion-Model"` - DefaultDiffusionIteration int `json:"DefaultDiffutionIterations"` - UseFP16 bool `json:"UseFP16"` -} - func HttpPost(url string, data interface{}, timelim int) []byte { // 超时时间 timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) //是的,这里有个bug,但是这里就是靠这个bug正常运行的!!!???
版本 日期 说明
v1.2 2023.07.02 跟进OpenAI的更新,使用openai.ChatCompletion对话API而不是文本补全API
支持清华大学的开源的ChatGLM2作为GPT模型。
由于anything-v4.0模型在hunggingface上被删除,默认的Diffusion模型改为stabilityai/stable-diffusion-2-1
更新了所依赖的openwechat的版本到v1.4.3
v1.1 2023.02.07