Skip to content

Commit

Permalink
convert EmbeddingModel to string type (sashabaranov#629)
Browse files Browse the repository at this point in the history
This gives the user the ability to pass in models for embeddings that are not
already defined in the library. Also more closely matches how the completions
API works.
  • Loading branch information
jaffee committed Jan 15, 2024
1 parent 682b7ad commit e01a2d7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 118 deletions.
120 changes: 21 additions & 99 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch")

// EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors.
type EmbeddingModel int

// String implements the fmt.Stringer interface.
func (e EmbeddingModel) String() string {
return enumToString[e]
}

// MarshalText implements the encoding.TextMarshaler interface.
func (e EmbeddingModel) MarshalText() ([]byte, error) {
return []byte(e.String()), nil
}

// UnmarshalText implements the encoding.TextUnmarshaler interface.
// On unrecognized value, it sets |e| to Unknown.
func (e *EmbeddingModel) UnmarshalText(b []byte) error {
if val, ok := stringToEnum[(string(b))]; ok {
*e = val
return nil
}

*e = Unknown

return nil
}
type EmbeddingModel string

const (
Unknown EmbeddingModel = iota
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchText
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchText
AdaEmbeddingV2
// Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSimilarity EmbeddingModel = "text-similarity-ada-001"
BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001"
CurieSimilarity EmbeddingModel = "text-similarity-curie-001"
DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001"
AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001"
AdaSearchQuery EmbeddingModel = "text-search-ada-query-001"
BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001"
BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001"
CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001"
CurieSearchQuery EmbeddingModel = "text-search-curie-query-001"
DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001"
DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001"
AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001"
AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001"
BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001"
BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001"

AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
)

var enumToString = map[EmbeddingModel]string{
AdaSimilarity: "text-similarity-ada-001",
BabbageSimilarity: "text-similarity-babbage-001",
CurieSimilarity: "text-similarity-curie-001",
DavinciSimilarity: "text-similarity-davinci-001",
AdaSearchDocument: "text-search-ada-doc-001",
AdaSearchQuery: "text-search-ada-query-001",
BabbageSearchDocument: "text-search-babbage-doc-001",
BabbageSearchQuery: "text-search-babbage-query-001",
CurieSearchDocument: "text-search-curie-doc-001",
CurieSearchQuery: "text-search-curie-query-001",
DavinciSearchDocument: "text-search-davinci-doc-001",
DavinciSearchQuery: "text-search-davinci-query-001",
AdaCodeSearchCode: "code-search-ada-code-001",
AdaCodeSearchText: "code-search-ada-text-001",
BabbageCodeSearchCode: "code-search-babbage-code-001",
BabbageCodeSearchText: "code-search-babbage-text-001",
AdaEmbeddingV2: "text-embedding-ada-002",
}

var stringToEnum = map[string]EmbeddingModel{
"text-similarity-ada-001": AdaSimilarity,
"text-similarity-babbage-001": BabbageSimilarity,
"text-similarity-curie-001": CurieSimilarity,
"text-similarity-davinci-001": DavinciSimilarity,
"text-search-ada-doc-001": AdaSearchDocument,
"text-search-ada-query-001": AdaSearchQuery,
"text-search-babbage-doc-001": BabbageSearchDocument,
"text-search-babbage-query-001": BabbageSearchQuery,
"text-search-curie-doc-001": CurieSearchDocument,
"text-search-curie-query-001": CurieSearchQuery,
"text-search-davinci-doc-001": DavinciSearchDocument,
"text-search-davinci-query-001": DavinciSearchQuery,
"code-search-ada-code-001": AdaCodeSearchCode,
"code-search-ada-text-001": AdaCodeSearchText,
"code-search-babbage-code-001": BabbageCodeSearchCode,
"code-search-babbage-text-001": BabbageCodeSearchText,
"text-embedding-ada-002": AdaEmbeddingV2,
}

// Embedding is a special format of data representation that can be easily utilized by machine
// learning models and algorithms. The embedding is an information dense representation of the
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
Expand Down Expand Up @@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq))
if err != nil {
return
}
Expand Down
22 changes: 3 additions & 19 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

Expand All @@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqStrings)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}

Expand All @@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqTokens)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}

func TestEmbeddingModel(t *testing.T) {
var em openai.EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model")

if em != openai.AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity")
}

err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.Unknown {
t.Errorf("Model is not equal to Unknown")
}
}

func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
Expand Down

0 comments on commit e01a2d7

Please sign in to comment.