Skip to content

Commit 79dd756

Browse files
authored
Merge pull request #265 from ConnectAI-E/support_vision
feat: 支持gpt4v 「WIP」
2 parents 20f1ea2 + c1d811e commit 79dd756

14 files changed

+521
-24
lines changed

code/handlers/card_common_action.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"start-feishubot/logger"
8-
97
larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
108
)
119

@@ -20,11 +18,13 @@ func NewCardHandler(m MessageHandler) CardHandlerFunc {
2018
handlers := []CardHandlerMeta{
2119
NewClearCardHandler,
2220
NewPicResolutionHandler,
21+
NewVisionResolutionHandler,
2322
NewPicTextMoreHandler,
2423
NewPicModeChangeHandler,
2524
NewRoleTagCardHandler,
2625
NewRoleCardHandler,
2726
NewAIModeCardHandler,
27+
NewVisionModeChangeHandler,
2828
}
2929

3030
return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
@@ -35,7 +35,7 @@ func NewCardHandler(m MessageHandler) CardHandlerFunc {
3535
return nil, err
3636
}
3737
//pp.Println(cardMsg)
38-
logger.Debug("cardMsg ", cardMsg)
38+
//logger.Debug("cardMsg ", cardMsg)
3939
for _, handler := range handlers {
4040
h := handler(cardMsg, m)
4141
i, err := h(ctx, cardAction)

code/handlers/card_vision_action.go

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package handlers
2+
3+
import (
4+
"context"
5+
"fmt"
6+
larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
7+
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
8+
"start-feishubot/services"
9+
)
10+
11+
func NewVisionResolutionHandler(cardMsg CardMsg,
12+
m MessageHandler) CardHandlerFunc {
13+
return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
14+
if cardMsg.Kind == VisionStyleKind {
15+
CommonProcessVisionStyle(cardMsg, cardAction, m.sessionCache)
16+
return nil, nil
17+
}
18+
return nil, ErrNextHandler
19+
}
20+
}
21+
func NewVisionModeChangeHandler(cardMsg CardMsg,
22+
m MessageHandler) CardHandlerFunc {
23+
return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
24+
if cardMsg.Kind == VisionModeChangeKind {
25+
newCard, err, done := CommonProcessVisionModeChange(cardMsg, m.sessionCache)
26+
if done {
27+
return newCard, err
28+
}
29+
return nil, nil
30+
}
31+
return nil, ErrNextHandler
32+
}
33+
}
34+
35+
func CommonProcessVisionStyle(msg CardMsg,
36+
cardAction *larkcard.CardAction,
37+
cache services.SessionServiceCacheInterface) {
38+
option := cardAction.Action.Option
39+
fmt.Println(larkcore.Prettify(msg))
40+
cache.SetVisionDetail(msg.SessionId, services.VisionDetail(option))
41+
//send text
42+
replyMsg(context.Background(), "图片解析度调整为:"+option,
43+
&msg.MsgId)
44+
}
45+
46+
func CommonProcessVisionModeChange(cardMsg CardMsg,
47+
session services.SessionServiceCacheInterface) (
48+
interface{}, error, bool) {
49+
if cardMsg.Value == "1" {
50+
51+
sessionId := cardMsg.SessionId
52+
session.Clear(sessionId)
53+
session.SetMode(sessionId,
54+
services.ModeVision)
55+
session.SetVisionDetail(sessionId,
56+
services.VisionDetailLow)
57+
58+
newCard, _ :=
59+
newSendCard(
60+
withHeader("🕵️️ 已进入图片推理模式", larkcard.TemplateBlue),
61+
withVisionDetailLevelBtn(&sessionId),
62+
withNote("提醒:回复图片,让LLM和你一起推理图片的内容。"))
63+
return newCard, nil, true
64+
}
65+
if cardMsg.Value == "0" {
66+
newCard, _ := newSendCard(
67+
withHeader("️🎒 机器人提醒", larkcard.TemplateGreen),
68+
withMainMd("依旧保留此话题的上下文信息"),
69+
withNote("我们可以继续探讨这个话题,期待和您聊天。如果您有其他问题或者想要讨论的话题,请告诉我哦"),
70+
)
71+
return newCard, nil, true
72+
}
73+
return nil, nil, false
74+
}

code/handlers/common.go

+27-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ func msgFilter(msg string) string {
1313
//replace @到下一个非空的字段 为 ''
1414
regex := regexp.MustCompile(`@[^ ]*`)
1515
return regex.ReplaceAllString(msg, "")
16-
1716
}
1817

1918
// Parse rich text json to text
@@ -47,6 +46,33 @@ func parsePostContent(content string) string {
4746
return msgFilter(text)
4847
}
4948

49+
// parsePostImageKeys extracts every image_key from a Lark "post"
// (rich text) message body. The payload has the shape
//
//	{"title": "...", "content": [[{"tag":"img","image_key":"..."}, ...], ...]}
//
// where "content" is a list of lines and each line is a list of elements;
// only elements tagged "img" contribute a key. It returns nil when the
// JSON cannot be parsed and degrades to an empty/partial result on an
// unexpected shape — the original unchecked type assertions would panic
// if "content" was not a list, an element was not a map, or "image_key"
// was missing or not a string.
func parsePostImageKeys(content string) []string {
	var contentMap map[string]interface{}
	if err := json.Unmarshal([]byte(content), &contentMap); err != nil {
		fmt.Println(err)
		return nil
	}

	var imageKeys []string

	// Comma-ok assertions: a malformed payload yields "no keys"
	// instead of a runtime panic.
	lines, ok := contentMap["content"].([]interface{})
	if !ok {
		return imageKeys
	}
	for _, line := range lines {
		elements, ok := line.([]interface{})
		if !ok {
			continue
		}
		for _, element := range elements {
			m, ok := element.(map[string]interface{})
			if !ok || m["tag"] != "img" {
				continue
			}
			if key, ok := m["image_key"].(string); ok {
				imageKeys = append(imageKeys, key)
			}
		}
	}

	return imageKeys
}
75+
5076
func parseContent(content, msgType string) string {
5177
//"{\"text\":\"@_user_1 hahaha\"}",
5278
//only get text content hahaha

code/handlers/event_common_action.go

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ type MsgInfo struct {
1919
qParsed string
2020
fileKey string
2121
imageKey string
22+
imageKeys []string // post 消息卡片中的图片组
2223
sessionId *string
2324
mention []*larkim.MentionEvent
2425
}

code/handlers/event_msg_action.go

+17
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@ func setDefaultPrompt(msg []openai.Messages) []openai.Messages {
2323
return msg
2424
}
2525

26+
//func setDefaultVisionPrompt(msg []openai.VisionMessages) []openai.VisionMessages {
27+
// if !hasSystemRole(msg) {
28+
// msg = append(msg, openai.VisionMessages{
29+
// Role: "system", Content: []openai.ContentType{
30+
// {Type: "text", Text: "You are ChatGPT4V, " +
31+
// "You are ChatGPT4V, " +
32+
// "a large language and picture model trained by" +
33+
// " OpenAI. " +
34+
// "Answer in user's language as concisely as" +
35+
// " possible. Knowledge cutoff: 20230601 " +
36+
// "Current date" + time.Now().Format("20060102"),
37+
// }},
38+
// })
39+
// }
40+
// return msg
41+
//}
42+
2643
type MessageAction struct { /*消息*/
2744
}
2845

code/handlers/event_vision_action.go

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package handlers
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"start-feishubot/initialization"
8+
"start-feishubot/services"
9+
"start-feishubot/services/openai"
10+
"start-feishubot/utils"
11+
12+
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
13+
)
14+
15+
// VisionAction routes incoming messages into the GPT-4V image-reasoning
// ("vision") flow. The struct itself is stateless; per-conversation state
// lives in the handler's session cache.
type VisionAction struct { /*image reasoning*/
}
17+
18+
func (va *VisionAction) Execute(a *ActionInfo) bool {
19+
if !AzureModeCheck(a) {
20+
return true
21+
}
22+
23+
if isVisionCommand(a) {
24+
initializeVisionMode(a)
25+
sendVisionInstructionCard(*a.ctx, a.info.sessionId, a.info.msgId)
26+
return false
27+
}
28+
29+
mode := a.handler.sessionCache.GetMode(*a.info.sessionId)
30+
31+
if a.info.msgType == "image" {
32+
if mode != services.ModeVision {
33+
sendVisionModeCheckCard(*a.ctx, a.info.sessionId, a.info.msgId)
34+
return false
35+
}
36+
37+
return va.handleVisionImage(a)
38+
}
39+
40+
if a.info.msgType == "post" && mode == services.ModeVision {
41+
return va.handleVisionPost(a)
42+
}
43+
44+
return true
45+
}
46+
47+
func isVisionCommand(a *ActionInfo) bool {
48+
_, foundPic := utils.EitherTrimEqual(a.info.qParsed, "/vision", "图片推理")
49+
return foundPic
50+
}
51+
52+
func initializeVisionMode(a *ActionInfo) {
53+
a.handler.sessionCache.Clear(*a.info.sessionId)
54+
a.handler.sessionCache.SetMode(*a.info.sessionId, services.ModeVision)
55+
a.handler.sessionCache.SetVisionDetail(*a.info.sessionId, services.VisionDetailHigh)
56+
}
57+
58+
func (va *VisionAction) handleVisionImage(a *ActionInfo) bool {
59+
detail := a.handler.sessionCache.GetVisionDetail(*a.info.sessionId)
60+
base64, err := downloadAndEncodeImage(a.info.imageKey, a.info.msgId)
61+
if err != nil {
62+
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
63+
return false
64+
}
65+
66+
return va.processImageAndReply(a, base64, detail)
67+
}
68+
69+
func (va *VisionAction) handleVisionPost(a *ActionInfo) bool {
70+
detail := a.handler.sessionCache.GetVisionDetail(*a.info.sessionId)
71+
var base64s []string
72+
73+
for _, imageKey := range a.info.imageKeys {
74+
if imageKey == "" {
75+
continue
76+
}
77+
base64, err := downloadAndEncodeImage(imageKey, a.info.msgId)
78+
if err != nil {
79+
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
80+
return false
81+
}
82+
base64s = append(base64s, base64)
83+
}
84+
85+
if len(base64s) == 0 {
86+
replyMsg(*a.ctx, "🤖️:请发送一张图片", a.info.msgId)
87+
return false
88+
}
89+
90+
return va.processMultipleImagesAndReply(a, base64s, detail)
91+
}
92+
93+
func downloadAndEncodeImage(imageKey string, msgId *string) (string, error) {
94+
f := fmt.Sprintf("%s.png", imageKey)
95+
defer os.Remove(f)
96+
97+
req := larkim.NewGetMessageResourceReqBuilder().MessageId(*msgId).FileKey(imageKey).Type("image").Build()
98+
resp, err := initialization.GetLarkClient().Im.MessageResource.Get(context.Background(), req)
99+
if err != nil {
100+
return "", err
101+
}
102+
103+
resp.WriteFile(f)
104+
return openai.GetBase64FromImage(f)
105+
}
106+
107+
func replyWithErrorMsg(ctx context.Context, err error, msgId *string) {
108+
replyMsg(ctx, fmt.Sprintf("🤖️:图片下载失败,请稍后再试~\n 错误信息: %v", err), msgId)
109+
}
110+
111+
func (va *VisionAction) processImageAndReply(a *ActionInfo, base64 string, detail string) bool {
112+
msg := createVisionMessages("解释这个图片", base64, detail)
113+
completions, err := a.handler.gpt.GetVisionInfo(msg)
114+
if err != nil {
115+
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
116+
return false
117+
}
118+
sendVisionTopicCard(*a.ctx, a.info.sessionId, a.info.msgId, completions.Content)
119+
return false
120+
}
121+
122+
func (va *VisionAction) processMultipleImagesAndReply(a *ActionInfo, base64s []string, detail string) bool {
123+
msg := createMultipleVisionMessages(a.info.qParsed, base64s, detail)
124+
completions, err := a.handler.gpt.GetVisionInfo(msg)
125+
if err != nil {
126+
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
127+
return false
128+
}
129+
sendVisionTopicCard(*a.ctx, a.info.sessionId, a.info.msgId, completions.Content)
130+
return false
131+
}
132+
133+
func createVisionMessages(query, base64Image, detail string) []openai.VisionMessages {
134+
return []openai.VisionMessages{
135+
{
136+
Role: "user",
137+
Content: []openai.ContentType{
138+
{Type: "text", Text: query},
139+
{Type: "image_url", ImageURL: &openai.ImageURL{
140+
URL: "data:image/jpeg;base64," + base64Image,
141+
Detail: detail,
142+
}},
143+
},
144+
},
145+
}
146+
}
147+
148+
func createMultipleVisionMessages(query string, base64Images []string, detail string) []openai.VisionMessages {
149+
content := []openai.ContentType{{Type: "text", Text: query}}
150+
for _, base64Image := range base64Images {
151+
content = append(content, openai.ContentType{
152+
Type: "image_url",
153+
ImageURL: &openai.ImageURL{
154+
URL: "data:image/jpeg;base64," + base64Image,
155+
Detail: detail,
156+
},
157+
})
158+
}
159+
return []openai.VisionMessages{{Role: "user", Content: content}}
160+
}

code/handlers/handler.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
8282
qParsed: strings.Trim(parseContent(*content, msgType), " "),
8383
fileKey: parseFileKey(*content),
8484
imageKey: parseImageKey(*content),
85+
imageKeys: parsePostImageKeys(*content),
8586
sessionId: sessionId,
8687
mention: mention,
8788
}
@@ -94,17 +95,17 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
9495
&ProcessedUniqueAction{}, //避免重复处理
9596
&ProcessMentionAction{}, //判断机器人是否应该被调用
9697
&AudioAction{}, //语音处理
97-
&EmptyAction{}, //空消息处理
9898
&ClearAction{}, //清除消息处理
99+
&VisionAction{}, //图片推理处理
99100
&PicAction{}, //图片处理
100101
&AIModeAction{}, //模式切换处理
101102
&RoleListAction{}, //角色列表处理
102103
&HelpAction{}, //帮助处理
103104
&BalanceAction{}, //余额处理
104105
&RolePlayAction{}, //角色扮演处理
105106
&MessageAction{}, //消息处理
107+
&EmptyAction{}, //空消息处理
106108
&StreamMessageAction{}, //流式消息处理
107-
108109
}
109110
chain(data, actions...)
110111
return nil

0 commit comments

Comments
 (0)