init
This commit is contained in:
commit
9a0f4e4a3e
|
@ -0,0 +1,5 @@
|
|||
storage.json
|
||||
*.png
|
||||
.vscode
|
||||
myconf.json
|
||||
.DS_Store
|
|
@ -0,0 +1,139 @@
|
|||
from distutils.command.config import config
|
||||
import flask
|
||||
import requests
|
||||
import json
|
||||
import openai
|
||||
import re
|
||||
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
# Parse command-line options: the only flag is --config, the path of the
# JSON configuration file.
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="config.json", help="Configuration file")
cli_args = parser.parse_args()

# Load the whole configuration once at startup.
with open(cli_args.config) as f:
    config_json = json.load(f)
|
||||
|
||||
class GlobalData:
    """Shared configuration constants and mutable conversation state."""

    # OPENAI_ORGID = config_json[""]
    OPENAI_APIKEY = config_json["OpenAI-API-Key"]
    OPENAI_MODEL = config_json["GPT-Model"]
    OPENAI_MODEL_TEMPERATURE = 0.66
    OPENAI_MODEL_MAXTOKENS = 2048

    # Accumulated prompt history, keyed by user / group id.
    # (Group context was planned but is not used by the endpoints below.)
    context_for_users = {}
    context_for_groups = {}

    # Matches an optional "(W H ITS)" size/iterations header, accepting both
    # ASCII and full-width parentheses.  Raw strings fix the invalid escape
    # sequences ("\(" etc.) the plain literals produced.
    GENERATE_PICTURE_ARG_PAT = re.compile(r"(\(|()([0-9]+)[ \n\t]+([0-9]+)[ \n\t]+([0-9]+)(\)|))")
    GENERATE_PICTURE_NEG_PROMPT_DELIMETER = re.compile(r"\n+")
    GENERATE_PICTURE_MAX_ITS = 200  # maximum diffusion iteration count
|
||||
|
||||
app = flask.Flask(__name__)

# Build the Stable Diffusion pipeline and move it to the fastest available
# backend (Apple MPS first, then CUDA; otherwise stay on CPU).
sd_pipe = StableDiffusionPipeline.from_pretrained(config_json["Diffusion-Model"], torch_dtype=torch.float32)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)

if torch.backends.mps.is_available():
    _device = "mps"
elif torch.cuda.is_available():
    _device = "cuda"
else:
    _device = None
if _device is not None:
    sd_pipe = sd_pipe.to(_device)
|
||||
|
||||
def send_text_to_user(user_id : str, text : str, timeout : float = 10.0):
    """Push a private text message to a user via the local Go WeChat client.

    Parameters
    ----------
    user_id : str
        Target user identifier understood by the WeChat client.
    text : str
        Message body.
    timeout : float
        Seconds to wait for the local client.  The original call had no
        timeout and could block forever if the client hung.
    """
    requests.post(url="http://localhost:11110/send_text",
                  data=json.dumps({"user_id" : user_id, "text" : text, "in_group" : False}),
                  timeout=timeout)
|
||||
|
||||
def call_gpt(prompt : str):
    """Query the OpenAI completion endpoint with the accumulated prompt.

    Returns the stripped completion text, "" when the API returned no
    choices, or a user-facing Chinese error message when the call fails
    (typically because the context exceeds the model's token limit).
    """
    try:
        res = openai.Completion.create(
            model=GlobalData.OPENAI_MODEL,
            prompt=prompt,
            max_tokens=GlobalData.OPENAI_MODEL_MAXTOKENS,
            temperature=GlobalData.OPENAI_MODEL_TEMPERATURE)
        choices = res["choices"]
        if choices:
            return choices[0]["text"].strip()
        return ""
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit.  The original literal also contained the invalid
        # escape "\“", leaking a stray backslash into the reply; the quotes
        # are now a matched full-width pair.
        return "上下文长度超出模型限制,请对我说“重置上下文”,然后再试一次"
|
||||
|
||||
@app.route("/chat_clear", methods=['POST'])
def app_chat_clear():
    """Drop the stored conversation context of the requesting user."""
    payload = json.loads(flask.globals.request.get_data())
    user_id = payload["user_id"]
    GlobalData.context_for_users[user_id] = ""
    print(f"Cleared context for {user_id}")
    return ""
|
||||
|
||||
@app.route("/chat", methods=['POST'])
def app_chat():
    """Append the incoming message to the user's context, query GPT, store
    the updated context, and return the reply as JSON."""
    data = json.loads(flask.globals.request.get_data())
    #print(data)
    text = data["text"]

    # Guard: the original indexed text[-1] and raised IndexError on "".
    if not text:
        return json.dumps({"user_id" : data["user_id"], "text" : "", "in_group" : False})

    # Ensure the utterance ends with a sentence delimiter so the model
    # treats it as complete.
    if text[-1] not in ['?', '?', '.', '。', ',', ',', '!', '!']:
        text += "。"

    prompt = GlobalData.context_for_users.get(data["user_id"], "") + "\n" + text

    # Keep the most RECENT 4000 characters.  The original used prompt[:4000],
    # which kept the oldest history and silently discarded the new message.
    if len(prompt) > 4000:
        prompt = prompt[-4000:]

    resp = call_gpt(prompt=prompt)
    GlobalData.context_for_users[data["user_id"]] = (prompt + resp)

    print(f"Prompt = {prompt}\nResponse = {resp}")

    return json.dumps({"user_id" : data["user_id"], "text" : resp, "in_group" : False})
|
||||
|
||||
@app.route("/draw", methods=['POST'])
def app_draw():
    """Generate a picture with Stable Diffusion.

    Expected prompt format (colon may be ASCII or full-width):
        "[(W H ITS)] : positive prompt [\\n negative prompt]"
    Returns JSON with the saved filename, or an error flag + message.
    """
    data = json.loads(flask.globals.request.get_data())

    prompt = data["prompt"]

    # Locate the options/prompt separator.  The original scanned with
    # `for i in range(len(prompt))` and then tested `i == len(prompt)`;
    # after a completed loop i is len(prompt) - 1, so the "no colon" error
    # never fired and malformed input fell through.
    i = -1
    for idx, ch in enumerate(prompt):
        if ch == ':' or ch == ':':
            i = idx
            break
    if i < 0:
        return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"})

    match_args = re.match(GlobalData.GENERATE_PICTURE_ARG_PAT, prompt[:i])
    if match_args is not None:
        # Explicit "(W H ITS)" header: group 1 is the opening parenthesis,
        # groups 2-4 are width, height and iteration count.
        W = int(match_args.group(2))
        H = int(match_args.group(3))
        ITS = int(match_args.group(4))

        if W > 2000 or H > 2000:
            return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "你要求的图片太大了,我不干了~"})

        if ITS > GlobalData.GENERATE_PICTURE_MAX_ITS:
            return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : f"迭代次数太多了,不要超过{GlobalData.GENERATE_PICTURE_MAX_ITS}次"})
    else:
        # No header allowed before the colon except whitespace.
        if len(prompt[:i].strip()) != 0:
            return json.dumps({"user_name" : data["user_name"], "filename" : "", "error" : True, "error_msg" : "格式不对,正确的格式是:生成图片:Prompt 或者 生成图片(宽 高 迭代次数):Prompt"})
        W = 768
        H = 768
        ITS = 50

    prompt = prompt[(i+1):].strip()

    # First line is the positive prompt; an optional second block (after a
    # blank-line run) is the negative prompt.
    prompts = re.split(GlobalData.GENERATE_PICTURE_NEG_PROMPT_DELIMETER, prompt)
    prompt = prompts[0]
    neg_prompt = prompts[1] if len(prompts) > 1 else None

    print(f"Generating picture with prompt = {prompt} , negative prompt = {neg_prompt}")

    image = sd_pipe(prompt=prompt, negative_prompt=neg_prompt, width=W, height=H, num_inference_steps=ITS).images[0]

    # NOTE(review): every request overwrites the same file; concurrent
    # requests would race — confirm the Flask app runs single-threaded.
    image.save("latest.png")
    return json.dumps({"user_name" : data["user_name"], "filename" : "latest.png", "error" : False, "error_msg" : ""})
|
||||
|
||||
if __name__ == "__main__":
    #openai.organization = GlobalData.OPENAI_ORGID
    # Install the API key globally for the openai client, then serve the
    # bot's HTTP API on all interfaces at port 11111 (the Go client posts
    # to this port).
    openai.api_key = GlobalData.OPENAI_APIKEY

    app.run(host="0.0.0.0", port=11111)
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"OpenAI-API-Key" : "",
|
||||
"GPT-Model" : "text-davinci-003",
|
||||
"Diffusion-Model" : "stabilityai/stable-diffusion-2-1"
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
module SXYWechatBot
|
||||
|
||||
go 1.19
|
||||
|
||||
require github.com/eatmoreapple/openwechat v1.3.8
|
|
@ -0,0 +1,2 @@
|
|||
github.com/eatmoreapple/openwechat v1.3.8 h1:dGQy/UeuSb7eVaCgJMoxM0fBG5DhboKfYk3g+FLZOWE=
|
||||
github.com/eatmoreapple/openwechat v1.3.8/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
|
@ -0,0 +1,2 @@
|
|||
# Launch the Go WeChat client and the Python bot side by side.
# Usage: ./run.sh <config.json>
# "$1" is quoted so config paths containing spaces survive word splitting.
go run wechat_client.go & python bot.py --config "$1"
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
package main
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"net/http"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
"os"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/eatmoreapple/openwechat"
|
||||
)
|
||||
|
||||
// Use silences "declared and not used" compiler errors for the given values.
func Use(vals ...interface{}) {
	for i := range vals {
		_ = vals[i]
	}
}
|
||||
|
||||
// SendTextRequest is the JSON payload exchanged with bot.py's /chat and
// /chat_clear endpoints.
type SendTextRequest struct {
	// Originally meant to distinguish group vs. private context rules,
	// but that feature was never implemented.
	InGroup bool `json:"in_group"`
	UserID string `json:"user_id"`
	Text string `json:"text"`
}

// SendImageRequest mirrors the JSON response of bot.py's /draw endpoint.
type SendImageRequest struct {
	UserName string `json:"user_name"`
	FileName string `json:"filename"`
	HasError bool `json:"error"`
	ErrorMessage string `json:"error_msg"`
}

// GenerateImageRequest is the JSON request body sent to bot.py's /draw endpoint.
type GenerateImageRequest struct {
	UserName string `json:"user_name"`
	Prompt string `json:"prompt"`
}
|
||||
|
||||
// HttpPost sends data as a JSON body to url and returns the raw response
// bytes. timelim is the request timeout in seconds. On any transport error
// an empty slice is returned (callers treat len == 0 as failure).
func HttpPost(url string, data interface{}, timelim int) []byte {
	// BUG FIX: the original built the timeout with
	// fmt.Sprintf("%ss", timelim); %s applied to an int yields
	// "%!s(int=120)s", time.ParseDuration failed, the error was ignored,
	// and the zero value meant NO timeout at all.
	timeout := time.Duration(timelim) * time.Second

	client := &http.Client{Timeout: timeout}
	jsonStr, _ := json.Marshal(data)
	resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonStr))
	if err != nil {
		return []byte("")
	}
	defer resp.Body.Close()

	result, _ := ioutil.ReadAll(resp.Body)
	return result

	// ———————————————
	// Adapted from CSDN author "gaoluhua" (CC 4.0 BY-SA):
	// https://blog.csdn.net/gaoluhua/article/details/124855716
}
|
||||
|
||||
// main logs into WeChat (desktop protocol) with hot-reload credentials,
// then routes every incoming text message either to the Stable Diffusion
// /draw endpoint or to the ChatGPT /chat endpoint, both served by bot.py
// on localhost:11111.
func main() {

	bot := openwechat.DefaultBot(openwechat.Desktop) // desktop mode; try this if the normal login fails
	// Persist login state so restarts can re-use the session via PushLogin.
	reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json")
	defer reloadStorage.Close()

	err := bot.PushLogin(reloadStorage, openwechat.NewRetryLoginOption())
	if err != nil {
		fmt.Println(err)
		return
	}

	// Fetch the logged-in user (needed for the @-mention prefix below).
	self, err := bot.GetCurrentUser()
	if err != nil {
		fmt.Println(err)
		return
	}

	Use(self)

	// Register the message handler.
	bot.MessageHandler = func(msg *openwechat.Message) {
		if msg.IsTickledMe() {
			msg.ReplyText("别拍了,机器人是会被拍坏掉的。")
			return
		}

		// Only plain text messages are handled.
		if !msg.IsText() {
			return
		}

		content := msg.Content
		// In group chats, only react when the bot is @-mentioned.
		if msg.IsSendByGroup() && !msg.IsAt() {
			return
		}

		if msg.IsSendByGroup() && msg.IsAt() {
			// Strip the leading "@<bot nickname>" plus following whitespace
			// (including non-breaking spaces WeChat inserts after a mention).
			atheader := fmt.Sprintf("@%s", self.NickName)
			//fmt.Println(atheader)
			if strings.HasPrefix(content, atheader) {
				content = strings.TrimLeft(content[len(atheader):], " \t\n")
				for {
					if strings.HasPrefix(content, " ") {
						content = content[1:]
					} else {
						break
					}
				}
			}
		}
		//fmt.Println(content)

		if strings.HasPrefix(content, "生成图片") {
			// Picture request ("生成图片" = "generate picture"):
			// call the Stable Diffusion endpoint with a long timeout.
			// msg.ReplyText("这个功能还没有实现,可以先期待一下~")
			sender, _ := msg.Sender()

			content = strings.TrimLeft(content[len("生成图片"):], " \t\n")

			resp_raw := HttpPost("http://localhost:11111/draw", GenerateImageRequest{UserName : sender.ID(), Prompt : content}, 120)
			if len(resp_raw) == 0 {
				msg.ReplyText("生成图片出错啦QwQ,或许可以再试一次")
				return
			}

			resp := SendImageRequest{}
			json.Unmarshal(resp_raw, &resp)
			//fmt.Println(resp.FileName)
			if resp.HasError {
				msg.ReplyText( fmt.Sprintf("生成图片出错啦QwQ,错误信息是:%s", resp.ErrorMessage) )
			} else {
				// NOTE(review): os.Open error is ignored — a missing file
				// would pass a nil reader to ReplyImage.
				img, _ := os.Open(resp.FileName)
				defer img.Close()
				msg.ReplyImage(img)
			}

		} else {
			// Chat request: forward to the ChatGPT endpoint.

			sender, _ := msg.Sender()
			//var group openwechat.Group{} = nil
			var group *openwechat.Group = nil

			// For group messages, Sender() is the group itself; wrap it so
			// group.ID() can be used as the context key.
			if msg.IsSendByGroup() {
				group = &openwechat.Group{User : sender}
			}

			// "重置上下文" = "reset context": clear the stored history.
			if content == "重置上下文" {
				if !msg.IsSendByGroup() {
					HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup : msg.IsSendByGroup(), UserID : sender.ID(), Text : ""}, 60)
				} else {
					HttpPost("http://localhost:11111/chat_clear", SendTextRequest{InGroup : msg.IsSendByGroup(), UserID : group.ID(), Text : ""}, 60)
				}
				msg.ReplyText("OK,我忘掉了之前的上下文。")
				return
			}

			resp := SendTextRequest{}
			resp_raw := []byte("")

			// NOTE(review): InGroup is hard-coded false in both branches —
			// consistent with the struct comment that per-group context
			// handling was never implemented.
			if !msg.IsSendByGroup() {
				resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup : false, UserID : sender.ID(), Text : msg.Content}, 60)
			} else {
				resp_raw = HttpPost("http://localhost:11111/chat", SendTextRequest{InGroup : false, UserID : group.ID(), Text : msg.Content}, 60)
			}
			if len(resp_raw) == 0 {
				msg.ReplyText("运算超时了QAQ,或许可以再试一次。")
				return
			}

			json.Unmarshal(resp_raw, &resp)

			if len(resp.Text) == 0 {
				msg.ReplyText("GPT对此没有什么想说的,换个话题吧。")
			} else {
				if msg.IsSendByGroup() {
					// Echo the question and @-mention the asker so group
					// members can tell which reply belongs to whom.
					sender_in_group, _ := msg.SenderInGroup()
					nickname := sender_in_group.NickName
					msg.ReplyText(fmt.Sprintf("@%s\n%s\n-------------------\n%s", nickname, content, resp.Text))
				} else {
					msg.ReplyText(resp.Text)
				}
			}

		}
	}

	// Block until the bot is logged out or the connection drops.
	bot.Block()
}
|
Loading…
Reference in New Issue