forked from sashabaranov/go-openai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add whisper 1 support (sashabaranov#117)
* Add whisper 1 support * Resolve linting issues for audio source files
- Loading branch information
Showing
2 changed files
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package gogpt | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"fmt" | ||
"io" | ||
"mime/multipart" | ||
"net/http" | ||
"os" | ||
) | ||
|
||
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. | ||
const ( | ||
Whisper1 = "whisper-1" | ||
) | ||
|
||
// AudioRequest represents a request structure for audio API. | ||
type AudioRequest struct { | ||
Model string | ||
FilePath string | ||
} | ||
|
||
// AudioResponse represents a response structure for audio API. | ||
type AudioResponse struct { | ||
Text string `json:"text"` | ||
} | ||
|
||
// CreateTranscription — API call to create a transcription. Returns transcribed text. | ||
func (c *Client) CreateTranscription( | ||
ctx context.Context, | ||
request AudioRequest, | ||
) (response AudioResponse, err error) { | ||
response, err = c.callAudioAPI(ctx, request, "transcriptions") | ||
return | ||
} | ||
|
||
// CreateTranscription — API call to create a transcription. Returns transcribed text. | ||
func (c *Client) CreateTranslation( | ||
ctx context.Context, | ||
request AudioRequest, | ||
) (response AudioResponse, err error) { | ||
response, err = c.callAudioAPI(ctx, request, "translations") | ||
return | ||
} | ||
|
||
// callAudioAPI — API call to an audio endpoint. | ||
func (c *Client) callAudioAPI( | ||
ctx context.Context, | ||
request AudioRequest, | ||
endpointSuffix string, | ||
) (response AudioResponse, err error) { | ||
var formBody bytes.Buffer | ||
w := multipart.NewWriter(&formBody) | ||
|
||
if err = audioMultipartForm(request, w); err != nil { | ||
return | ||
} | ||
|
||
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) | ||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) | ||
if err != nil { | ||
return | ||
} | ||
req.Header.Add("Content-Type", w.FormDataContentType()) | ||
|
||
err = c.sendRequest(req, &response) | ||
return | ||
} | ||
|
||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for | ||
// audio processing. | ||
func audioMultipartForm(request AudioRequest, w *multipart.Writer) error { | ||
f, err := os.Open(request.FilePath) | ||
if err != nil { | ||
return fmt.Errorf("opening audio file: %w", err) | ||
} | ||
|
||
fw, err := w.CreateFormFile("file", f.Name()) | ||
if err != nil { | ||
return fmt.Errorf("creating form file: %w", err) | ||
} | ||
|
||
if _, err = io.Copy(fw, f); err != nil { | ||
return fmt.Errorf("reading from opened audio file: %w", err) | ||
} | ||
|
||
fw, err = w.CreateFormField("model") | ||
if err != nil { | ||
return fmt.Errorf("creating form field: %w", err) | ||
} | ||
|
||
modelName := bytes.NewReader([]byte(request.Model)) | ||
if _, err = io.Copy(fw, modelName); err != nil { | ||
return fmt.Errorf("writing model name: %w", err) | ||
} | ||
w.Close() | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
package gogpt_test | ||
|
||
import ( | ||
"bytes" | ||
"errors" | ||
"io" | ||
"mime" | ||
"mime/multipart" | ||
"net/http" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
|
||
. "github.com/sashabaranov/go-gpt3" | ||
"github.com/sashabaranov/go-gpt3/internal/test" | ||
|
||
"context" | ||
"testing" | ||
) | ||
|
||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. | ||
func TestAudio(t *testing.T) { | ||
server := test.NewTestServer() | ||
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) | ||
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) | ||
// create the test server | ||
var err error | ||
ts := server.OpenAITestServer() | ||
ts.Start() | ||
defer ts.Close() | ||
|
||
config := DefaultConfig(test.GetTestToken()) | ||
config.BaseURL = ts.URL + "/v1" | ||
client := NewClientWithConfig(config) | ||
|
||
testcases := []struct { | ||
name string | ||
createFn func(context.Context, AudioRequest) (AudioResponse, error) | ||
}{ | ||
{ | ||
"transcribe", | ||
client.CreateTranscription, | ||
}, | ||
{ | ||
"translate", | ||
client.CreateTranslation, | ||
}, | ||
} | ||
|
||
ctx := context.Background() | ||
|
||
dir, cleanup := createTestDirectory(t) | ||
defer cleanup() | ||
|
||
for _, tc := range testcases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
path := filepath.Join(dir, "fake.mp3") | ||
createTestFile(t, path) | ||
|
||
req := AudioRequest{ | ||
FilePath: path, | ||
Model: "whisper-3", | ||
} | ||
_, err = tc.createFn(ctx, req) | ||
if err != nil { | ||
t.Fatalf("audio API error: %v", err) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
// createTestFile creates a fake file with "hello" as the content. | ||
func createTestFile(t *testing.T, path string) { | ||
file, err := os.Create(path) | ||
if err != nil { | ||
t.Fatalf("failed to create file %v", err) | ||
} | ||
if _, err = file.WriteString("hello"); err != nil { | ||
t.Fatalf("failed to write to file %v", err) | ||
} | ||
file.Close() | ||
} | ||
|
||
// createTestDirectory creates a temporary folder which will be deleted when cleanup is called. | ||
func createTestDirectory(t *testing.T) (path string, cleanup func()) { | ||
t.Helper() | ||
|
||
path, err := os.MkdirTemp(os.TempDir(), "") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
return path, func() { os.RemoveAll(path) } | ||
} | ||
|
||
// handleAudioEndpoint Handles the completion endpoint by the test server. | ||
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { | ||
var err error | ||
|
||
// audio endpoints only accept POST requests | ||
if r.Method != "POST" { | ||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | ||
} | ||
|
||
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) | ||
if err != nil { | ||
http.Error(w, "failed to parse media type", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
if !strings.HasPrefix(mediaType, "multipart") { | ||
http.Error(w, "request is not multipart", http.StatusBadRequest) | ||
} | ||
|
||
boundary, ok := params["boundary"] | ||
if !ok { | ||
http.Error(w, "no boundary in params", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
fileData := &bytes.Buffer{} | ||
mr := multipart.NewReader(r.Body, boundary) | ||
part, err := mr.NextPart() | ||
if err != nil && errors.Is(err, io.EOF) { | ||
http.Error(w, "error accessing file", http.StatusBadRequest) | ||
return | ||
} | ||
if _, err = io.Copy(fileData, part); err != nil { | ||
http.Error(w, "failed to copy file", http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
if len(fileData.Bytes()) == 0 { | ||
w.WriteHeader(http.StatusInternalServerError) | ||
http.Error(w, "received empty file data", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { | ||
http.Error(w, "failed to write body", http.StatusInternalServerError) | ||
return | ||
} | ||
} |