commit 9a0f4e4a3e85c263469a8c32093e1d5f92de5be2 Author: HfCloud Date: Sun Feb 5 21:27:35 2023 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..492997d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +storage.json +*.png +.vscode +myconf.json +.DS_Store \ No newline at end of file diff --git a/bot.py b/bot.py new file mode 100644 index 0000000..7b79a8c --- /dev/null +++ b/bot.py @@ -0,0 +1,139 @@ +from distutils.command.config import config +import flask +import requests +import json +import openai +import re +from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler +import torch +import argparse + +ps = argparse.ArgumentParser() +ps.add_argument("--config", default="config.json", help="Configuration file") +args = ps.parse_args() + +with open(args.config) as f: + config_json = json.load(f) + +class GlobalData: + # OPENAI_ORGID = config_json[""] + OPENAI_APIKEY = config_json["OpenAI-API-Key"] + OPENAI_MODEL = config_json["GPT-Model"] + OPENAI_MODEL_TEMPERATURE = 0.66 + OPENAI_MODEL_MAXTOKENS = 2048 + + context_for_users = {} + context_for_groups = {} + + GENERATE_PICTURE_ARG_PAT = re.compile("(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))") + GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile("\n+") + GENERATE_PICTURE_MAX_ITS = 200 #最大迭代次数 + +app = flask.Flask(__name__) +sd_pipe = StableDiffusionPipeline.from_pretrained(config_json["Diffusion-Model"], torch_dtype=torch.float32) +sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) +if torch.backends.mps.is_available(): + sd_pipe = sd_pipe.to("mps") +elif torch.cuda.is_available(): + sd_pipe = sd_pipe.to("cuda") + +def send_text_to_user(user_id : str, text : str): + requests.post(url="http://localhost:11110/send_text", + data=json.dumps({"user_id" : user_id, "text" : text, "in_group" : False})) + +def call_gpt(prompt : str): + try: + res = openai.Completion.create( + model=GlobalData.OPENAI_MODEL, + prompt=prompt, + max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS, + temperature=GlobalData.OPENAI_MODEL_TEMPERATURE) + if len(res["choices"]) > 0: + return res["choices"][0]["text"].strip() + else: + return "" + except: + return "上下文长度超出模型限制,请对我说\“重置上下文\",然后再试一次" + +@app.route("/chat_clear", methods=['POST']) +def app_chat_clear(): + data = json.loads(flask.globals.request.get_data()) + GlobalData.context_for_users[data["user_id"]] = "" + print(f"Cleared context for {data['user_id']}") + return "" + +@app.route("/chat", methods=['POST']) +def app_chat(): + data = json.loads(flask.globals.request.get_data()) + #print(data) + prompt = GlobalData.context_for_users.get(data["user_id"], "") + + if not data["text"][-1] in ['?', '?', '.', '。', ',', ',', '!', '!']: + data["text"] += "。" + + prompt += "\n" + data["text"] + + if len(prompt) > 4000: + prompt = prompt[:4000] + + resp = call_gpt(prompt=prompt) + GlobalData.context_for_users[data["user_id"]] = (prompt + resp) + + print(f"Prompt = {prompt}\nResponse = {resp}") + + return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False}) + +@app.route("/draw", methods=['POST']) +def app_draw(): + data = json.loads(flask.globals.request.get_data()) + + prompt = data["prompt"] + + i = 0 + for i in range(len(prompt)): + if prompt[i] == ':' or prompt[i] == ':': + break + if i == len(prompt): + return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"}) + + + match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i]) + if not match_args is None: + W = int(match_args.group(2)) + H = int(match_args.group(3)) + ITS = int(match_args.group(4)) + + if W > 2000 or H > 2000: + return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"}) + + if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS: + return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"}) + else: + if len(prompt[:i].strip()) != 0: + return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"}) + else: + W = 768 + H = 768 + ITS = 50 + + prompt = prompt[(i+1):].strip() + + prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt) + prompt = prompts[0] + + neg_prompt = None + if len(prompts) > 1: + neg_prompt = prompts[1] + + print(f"Generating picture with prompt = {prompt} , negative prompt = {neg_prompt}") + + image = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS).images[0] + + image.save("latest.png") + return json.dumps({"user_name" : data["user_name"], "filename" : "latest.png", "error" : False, "error_msg" : ""}) + +if __name__ == "__main__": + #openai.organization = GlobalData.OPENAI_ORGID + openai.api_key = GlobalData.OPENAI_APIKEY + + app.run(host="0.0.0.0", port=11111) \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..7814332 --- /dev/null +++ b/config.json @@ -0,0 +1,5 @@ +{ + "OpenAI-API-Key" : "", + "GPT-Model" : "text-davinci-003", + "Diffusion-Model" : "stabilityai/stable-diffusion-2-1" +} \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4842a09 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module SXYWechatBot + +go 1.19 + +require github.com/eatmoreapple/openwechat v1.3.8 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5081503 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/eatmoreapple/openwechat v1.3.8 h1:dGQy/UeuSb7eVaCgJMoxM0fBG5DhboKfYk3g+FLZOWE= +github.com/eatmoreapple/openwechat v1.3.8/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= diff --git a/start.sh b/start.sh new file mode 100755 index 0000000..9597bee --- /dev/null +++ b/start.sh @@ -0,0 +1,2 @@ +go run wechat_client.go & python bot.py --config $1 + diff --git a/wechat_client.go b/wechat_client.go new file mode 100644 index 0000000..4f6f38a --- /dev/null +++ b/wechat_client.go @@ -0,0 +1,187 @@ +package main +import ( + "fmt" + "strings" + "net/http" + "io/ioutil" + "time" + "os" + "bytes" + "encoding/json" + "github.com/eatmoreapple/openwechat" +) + +func Use(vals ...interface{}) { + for _, val := range vals { + _ = val + } +} + +type SendTextRequest struct { + InGroup bool `json:"in_group"` //本来想用于区分在群聊和非群聊时的上下文记忆规则,但是最终没有实现... + UserID string `json:"user_id"` + Text string `json:"text"` +} + +type SendImageRequest struct { + UserName string `json:"user_name"` + FileName string `json:"filename"` + HasError bool `json:"error"` + ErrorMessage string `json:"error_msg"` +} + +type GenerateImageRequest struct { + UserName string `json:"user_name"` + Prompt string `json:"prompt"` +} + +func HttpPost(url string, data interface{}, timelim int) []byte { + // 超时时间 + timeout, _ := time.ParseDuration(fmt.Sprintf("%ss", timelim)) + + client := &http.Client{Timeout: timeout} + jsonStr, _ := json.Marshal(data) + resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonStr)) + if err != nil { + return []byte("") + } + defer resp.Body.Close() + + result, _ := ioutil.ReadAll(resp.Body) + return result + +// ——————————————— +// 版权声明:本文为CSDN博主「gaoluhua」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 +// 原文链接:https://blog.csdn.net/gaoluhua/article/details/124855716 +} + +func main() { + + bot := openwechat.DefaultBot(openwechat.Desktop) // 桌面模式,上面登录不上的可以尝试切换这种模式 + reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json") + defer reloadStorage.Close() + + err := bot.PushLogin(reloadStorage, openwechat.NewRetryLoginOption()) + if err != nil { + fmt.Println(err) + return + } + + // 获取登陆的用户 + self, err := bot.GetCurrentUser() + if err != nil { + fmt.Println(err) + return + } + + Use(self) + + // 注册消息处理函数 + bot.MessageHandler = func(msg *openwechat.Message) { + if msg.IsTickledMe() { + msg.ReplyText("别拍了,机器人是会被拍坏掉的。") + return + } + + if !msg.IsText() { + return + } + + content := msg.Content + if msg.IsSendByGroup() && !msg.IsAt() { + return + } + + if msg.IsSendByGroup() && msg.IsAt() { + atheader := fmt.Sprintf("@%s", self.NickName) + //fmt.Println(atheader) + if strings.HasPrefix(content, atheader) { + content = strings.TrimLeft(content[len(atheader):], " \t\n") + for { + if strings.HasPrefix(content, " ") { + content = content[1:] + } else { + break + } + } + } + } + //fmt.Println(content) + + if strings.HasPrefix(content, "生成图片") { + // 调用Stable Diffusion + // msg.ReplyText("这个功能还没有实现,可以先期待一下~") + sender, _ := msg.Sender() + + content = strings.TrimLeft(content[len("生成图片"):], " \t\n") + + resp_raw := HttpPost("http://localhost:11111/draw", GenerateImageRequest{UserName : sender.ID(), Prompt : content}, 120) + if len(resp_raw) == 0 { + msg.ReplyText("生成图片出错啦QwQ,或许可以再试一次") + return + } + + resp := SendImageRequest{} + json.Unmarshal(resp_raw, &resp) + //fmt.Println(resp.FileName) + if resp.HasError { + msg.ReplyText( fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage) ) + } else { + img, _ := os.Open(resp.FileName) + defer img.Close() + msg.ReplyImage(img) + } + + } else { + // 调用ChatGPT + + sender, _ := msg.Sender() + //var group openwechat.Group{} = nil + var group *openwechat.Group = nil + + if msg.IsSendByGroup() { + group = &openwechat.Group{User : sender} + } + + if content == "重置上下文" { + if !msg.IsSendByGroup() { + HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup : msg.IsSendByGroup(), UserID : sender.ID(), Text : ""}, 60) + } else { + HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup : msg.IsSendByGroup(), UserID : group.ID(), Text : ""}, 60) + } + msg.ReplyText("OK,我忘掉了之前的上下文。") + return + } + + resp := SendTextRequest{} + resp_raw := []byte("") + + if !msg.IsSendByGroup() { + resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup : false, UserID : sender.ID(), Text : msg.Content}, 60) + } else { + resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup : false, UserID : group.ID(), Text : msg.Content}, 60) + } + if len(resp_raw) == 0 { + msg.ReplyText("运算超时了QAQ,或许可以再试一次。") + return + } + + json.Unmarshal(resp_raw, &resp) + + if len(resp.Text) == 0 { + msg.ReplyText("GPT对此没有什么想说的,换个话题吧。") + } else { + if msg.IsSendByGroup() { + sender_in_group, _ := msg.SenderInGroup() + nickname := sender_in_group.NickName + msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text)) + } else { + msg.ReplyText(resp.Text) + } + } + + } + } + + bot.Block() +} \ No newline at end of file