Skip to content

Commit

Permalink
Add support for multi-part chat messages
Browse files Browse the repository at this point in the history
OpenAI has recently introduced a new model called gpt-4-vision-preview,
which now supports images as input. The chat completion endpoint accepts
multi-part chat messages, where the content can be an array of structs
in addition to the usual string format.

This commit introduces new structures and constants to represent
different types of content parts. It also implements the json.Marshaler
and json.Unmarshaler interfaces on ChatCompletionMessage.
  • Loading branch information
rkintzi committed Nov 14, 2023
1 parent 515de02 commit b468820
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 2 deletions.
86 changes: 84 additions & 2 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"context"
"encoding/json"
"errors"
"net/http"
)
Expand Down Expand Up @@ -51,9 +52,32 @@ type PromptAnnotation struct {
ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"`
}

// Values for the Detail field of ChatMessageImageURL.
const (
ImageURLDetailHigh = "high"
ImageURLDetailLow = "low"
ImageURLDetailAuto = "auto"
)

// ChatMessageImageURL is the payload carried by an image_url content part
// (see ChatMessagePart.ImageURL). Both fields are left out of the encoded
// JSON when they are empty.
type ChatMessageImageURL struct {
// URL is encoded under the "url" key.
URL string `json:"url,omitempty"`
// Detail is encoded under the "detail" key; the ImageURLDetail* constants
// provide values for it.
Detail string `json:"detail,omitempty"`
}

// Values for the Type field of ChatMessagePart, naming the kind of content
// a part carries.
const (
ChatMessagePartTypeText string = "text"
ChatMessagePartTypeImageURL string = "image_url"
)

// ChatMessagePart is one element of a multi-part message content array
// (ChatCompletionMessage.MultiContent). Every field is omitted from the
// encoded JSON when empty.
type ChatMessagePart struct {
// Type names the kind of part; see the ChatMessagePartType* constants.
Type string `json:"type,omitempty"`
// Text is encoded under the "text" key.
Text string `json:"text,omitempty"`
// ImageURL is encoded under the "image_url" key; a nil pointer is omitted.
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
}

type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart

// This property isn't in the official documentation, but it's in
// the documentation for the official library for python:
Expand All @@ -70,6 +94,64 @@ type ChatCompletionMessage struct {
ToolCallID string `json:"tool_call_id,omitempty"`
}

// MarshalJSON implements json.Marshaler. The "content" key of the encoded
// message is written either as a plain string (taken from Content) or as an
// array of parts (taken from MultiContent), depending on which field is
// populated.
//
// It returns an error when Content is non-empty and MultiContent is non-nil
// at the same time. A MultiContent slice that has no elements is treated as
// absent, so the message is still encoded with a string "content" key
// instead of losing the key altogether.
func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
	if m.Content != "" && m.MultiContent != nil {
		return nil, errors.New("can't use both Content and MultiContent fields simultaneously")
	}
	if len(m.MultiContent) > 0 {
		// Convert to an anonymous struct with the same fields but different
		// tags: MultiContent takes over the "content" key. The anonymous type
		// has no methods, so json.Marshal does not call back into this method.
		msg := struct {
			Role         string            `json:"role"`
			Content      string            `json:"-"`
			MultiContent []ChatMessagePart `json:"content,omitempty"`
			Name         string            `json:"name,omitempty"`
			FunctionCall *FunctionCall     `json:"function_call,omitempty"`
			ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
			ToolCallID   string            `json:"tool_call_id,omitempty"`
		}(m)
		return json.Marshal(msg)
	}
	// Plain-string form: "content" is always written, even when empty.
	msg := struct {
		Role         string            `json:"role"`
		Content      string            `json:"content"`
		MultiContent []ChatMessagePart `json:"-"`
		Name         string            `json:"name,omitempty"`
		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
		ToolCallID   string            `json:"tool_call_id,omitempty"`
	}(m)
	return json.Marshal(msg)
}

func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
multiMsg := struct {
Role string `json:"role"`
Content string
MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}{}
if err := json.Unmarshal(bs, &multiMsg); err == nil {
*m = ChatCompletionMessage(multiMsg)
return nil
}
msg := struct {
Role string `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}{}
if err := json.Unmarshal(bs, &msg); err != nil {
return err
}

Check warning on line 150 in chat.go

View check run for this annotation

Codecov / codecov/patch

chat.go#L149-L150

Added lines #L149 - L150 were not covered by tests
*m = ChatCompletionMessage(msg)
return nil
}

type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
Expand Down
84 changes: 84 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,90 @@ func TestAzureChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateAzureChatCompletion error")
}

// TestMultipartChatCompletions sends a chat completion request whose single
// user message carries a text part and an image_url part, and expects the
// call against the test server to succeed.
func TestMultipartChatCompletions(t *testing.T) {
	client, server, teardown := setupAzureTestServer()
	defer teardown()
	server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)

	textPart := openai.ChatMessagePart{
		Type: openai.ChatMessagePartTypeText,
		Text: "Hello!",
	}
	imagePart := openai.ChatMessagePart{
		Type: openai.ChatMessagePartTypeImageURL,
		ImageURL: &openai.ChatMessageImageURL{
			URL:    "URL",
			Detail: openai.ImageURLDetailLow,
		},
	}
	userMessage := openai.ChatCompletionMessage{
		Role:         openai.ChatMessageRoleUser,
		MultiContent: []openai.ChatMessagePart{textPart, imagePart},
	}
	request := openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages:  []openai.ChatCompletionMessage{userMessage},
	}

	_, err := client.CreateChatCompletion(context.Background(), request)
	checks.NoError(t, err, "CreateAzureChatCompletion error")
}

// TestMultipartChatMessageSerialization checks the custom JSON handling of
// ChatCompletionMessage: a list holding one string-content message and one
// multi-part message must decode into the expected fields, encode back to the
// same JSON text, and a message that sets both Content and MultiContent must
// fail to encode.
func TestMultipartChatMessageSerialization(t *testing.T) {
	jsonText := `[{"role":"system","content":"system-message"},` +
		`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
		`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`

	var msgs []openai.ChatCompletionMessage
	if err := json.Unmarshal([]byte(jsonText), &msgs); err != nil {
		t.Fatalf("Expected no error: %s", err)
	}
	// Fatal rather than Error: the checks below index into msgs.
	if len(msgs) != 2 {
		t.Fatalf("unexpected number of messages: %d", len(msgs))
	}
	if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
		t.Errorf("invalid system message: %v", msgs[0])
	}
	// Fatal rather than Error: the checks below index into the parts.
	if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
		t.Fatalf("invalid user message: %v", msgs[1])
	}
	parts := msgs[1].MultiContent
	if parts[0].Type != "text" || parts[0].Text != "nice-text" {
		t.Errorf("invalid text part: %v", parts[0])
	}
	// Guard against a nil ImageURL so a decoding bug fails the test instead
	// of panicking.
	img := parts[1].ImageURL
	if parts[1].Type != "image_url" || img == nil || img.URL != "URL" || img.Detail != "high" {
		t.Errorf("invalid image_url part: %v", parts[1])
	}

	s, err := json.Marshal(msgs)
	if err != nil {
		t.Fatalf("Expected no error: %s", err)
	}
	res := strings.ReplaceAll(string(s), " ", "")
	if res != jsonText {
		t.Fatalf("invalid message: %s", string(s))
	}

	// Content and MultiContent are mutually exclusive; encoding must fail.
	conflicting := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "some-text",
			MultiContent: []openai.ChatMessagePart{
				{
					Type: "text",
					Text: "nice-text",
				},
			},
		},
	}
	if _, err = json.Marshal(conflicting); err == nil {
		t.Fatalf("Expected error")
	}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down

0 comments on commit b468820

Please sign in to comment.