Skip to content

Commit

Permalink
Add DotProduct Method and README Example for Embedding Similarity Sea…
Browse files Browse the repository at this point in the history
…rch (sashabaranov#492)

* Add DotProduct Method and README Example for Embedding Similarity Search

- Implement a DotProduct() method for the Embedding struct to calculate the dot product between two embeddings.
- Add a custom error type for vector length mismatch.
- Update README.md with a complete example demonstrating how to perform an embedding similarity search for user queries.
- Add unit tests to validate the new DotProduct() method and error handling.

* Update README to focus on Embedding Semantic Similarity
  • Loading branch information
ealvar3z committed Oct 2, 2023
1 parent 0d5256f commit 84f77a0
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
56 changes: 56 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,62 @@ func main() {
```
</details>

<details>
<summary>Embedding Semantic Similarity</summary>

```go
package main

import (
"context"
"log"
openai "github.com/sashabaranov/go-openai"

)

// main embeds a user query and a target text with the OpenAI embeddings API,
// then logs the dot product of the two resulting vectors as a similarity score.
func main() {
	client := openai.NewClient("your-token")
	ctx := context.Background()

	// Create an EmbeddingRequest for the user query
	queryReq := openai.EmbeddingRequest{
		Input: []string{"How many chucks would a woodchuck chuck"},
		Model: openai.AdaEmbeddingv2,
	}

	// Create an embedding for the user query
	queryResponse, err := client.CreateEmbeddings(ctx, queryReq)
	if err != nil {
		// Fatalf with an explicit verb: log.Fatal would print the message and
		// the error with no space between them.
		log.Fatalf("Error creating query embedding: %v", err)
	}

	// Create an EmbeddingRequest for the target text
	targetReq := openai.EmbeddingRequest{
		Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
		Model: openai.AdaEmbeddingv2,
	}

	// Create an embedding for the target text
	targetResponse, err := client.CreateEmbeddings(ctx, targetReq)
	if err != nil {
		log.Fatalf("Error creating target embedding: %v", err)
	}

	// Guard the indexing below: Data[0] would panic if either response came
	// back without any embeddings.
	if len(queryResponse.Data) == 0 || len(targetResponse.Data) == 0 {
		log.Fatal("Error: embedding response contained no data")
	}

	// Now that we have the embeddings for the user query and the target text, we
	// can calculate their similarity.
	queryEmbedding := queryResponse.Data[0]
	targetEmbedding := targetResponse.Data[0]

	similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
	if err != nil {
		log.Fatalf("Error calculating dot product: %v", err)
	}

	log.Printf("The similarity score between the query and the target is %f", similarity)
}

```
</details>

<details>
<summary>Azure OpenAI Embeddings</summary>

Expand Down
20 changes: 20 additions & 0 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"context"
"encoding/base64"
"encoding/binary"
"errors"
"math"
"net/http"
)

var ErrVectorLengthMismatch = errors.New("vector length mismatch")

// EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors.
type EmbeddingModel int
Expand Down Expand Up @@ -124,6 +127,23 @@ type Embedding struct {
Index int `json:"index"`
}

// DotProduct returns the dot product of e's vector and other's vector as a
// float32. The two vectors must contain the same number of components; when
// they do not, the result is 0 and the error is ErrVectorLengthMismatch.
func (e *Embedding) DotProduct(other *Embedding) (float32, error) {
	lhs, rhs := e.Embedding, other.Embedding
	if len(lhs) != len(rhs) {
		return 0, ErrVectorLengthMismatch
	}

	// Accumulate the component-wise products in index order.
	var sum float32
	for i, x := range lhs {
		sum += x * rhs[i]
	}

	return sum, nil
}

// EmbeddingResponse is the response from a Create embeddings request.
type EmbeddingResponse struct {
Object string `json:"object"`
Expand Down
38 changes: 38 additions & 0 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"reflect"
"testing"
Expand Down Expand Up @@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
})
}
}

// TestDotProduct checks DotProduct on equal-length vectors against known
// results and verifies that vectors of different lengths are rejected with
// ErrVectorLengthMismatch.
func TestDotProduct(t *testing.T) {
	// Equal-length vector pairs together with their expected dot product.
	cases := []struct {
		lhs, rhs []float32
		want     float32
	}{
		{lhs: []float32{1, 2, 3}, rhs: []float32{2, 4, 6}, want: 28.0},
		{lhs: []float32{1, 0, 0}, rhs: []float32{0, 1, 0}, want: 0.0},
	}

	for _, tc := range cases {
		a := &Embedding{Embedding: tc.lhs}
		b := &Embedding{Embedding: tc.rhs}

		got, err := a.DotProduct(b)
		if err != nil {
			t.Errorf("Unexpected error: %v", err)
		}

		if delta := math.Abs(float64(got - tc.want)); delta > 1e-12 {
			t.Errorf("Unexpected result. Expected: %v, but got %v", tc.want, got)
		}
	}

	// A three-component vector against a two-component vector must fail with
	// ErrVectorLengthMismatch.
	long := &Embedding{Embedding: []float32{1, 0, 0}}
	short := &Embedding{Embedding: []float32{0, 1}}
	if _, err := long.DotProduct(short); !errors.Is(err, ErrVectorLengthMismatch) {
		t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
	}
}

0 comments on commit 84f77a0

Please sign in to comment.