Skip to content

Commit

Permalink
alter some errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Suchun-sv committed Aug 24, 2024
1 parent 5cbae03 commit 27b2f71
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 526 deletions.
17 changes: 3 additions & 14 deletions plugins/wasm-go/extensions/ai-cache/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,19 @@ type RedisConfig struct {
// @Title zh-CN 请求超时
// @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒
RedisTimeout uint32 `required:"false" yaml:"timeout" json:"timeout"`

RedisHost string `required:"false" yaml:"host" json:"host"`
}

func CreateProvider(cf RedisConfig, log wrapper.Log) (Provider, error) {
<<<<<<< HEAD
=======
log.Warnf("redis config: %v", cf)
>>>>>>> origin/feat/chroma
rp := redisProvider{
config: cf,
client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: cf.RedisServiceName,
Host: "redis",
Host: cf.RedisHost,
Port: int64(cf.RedisServicePort)}),
// client: wrapper.NewRedisClusterClient(wrapper.DnsCluster{
// ServiceName: cf.RedisServiceName,
// Port: int64(cf.RedisServicePort)}),
}
// FQDN := wrapper.FQDNCluster{
// FQDN: cf.RedisServiceName,
// Host: "redis",
// Port: int64(cf.RedisServicePort)}
// log.Debugf("test:%s", FQDN.ClusterName())
// log.Debugf("test:%d", cf.RedisServicePort)
// log.Debugf("test:%s", proxywasm.RedisInit(FQDN.ClusterName(), "", "", 100))
err := rp.Init(cf.RedisUsername, cf.RedisPassword, cf.RedisTimeout)
return &rp, err
}
Expand Down
28 changes: 14 additions & 14 deletions plugins/wasm-go/extensions/ai-cache/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package config
import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
Expand All @@ -16,11 +16,11 @@ type KVExtractor struct {
}

type PluginConfig struct {
EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"`
VectorDatabaseProviderConfig vectorDatabase.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"`
CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"`
CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"`
CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"`
EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"`
vectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"`
CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"`
CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"`
CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"`
// @Title zh-CN 返回 HTTP 响应的模版
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"`
Expand All @@ -39,14 +39,14 @@ type PluginConfig struct {

RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"`
// 现在只支持RedisClient作为cacheClient
redisProvider cache.Provider `yaml:"-"`
embeddingProvider embedding.Provider `yaml:"-"`
vectorDatabaseProvider vectorDatabase.Provider `yaml:"-"`
redisProvider cache.Provider `yaml:"-"`
embeddingProvider embedding.Provider `yaml:"-"`
vectorProvider vector.Provider `yaml:"-"`
}

func (c *PluginConfig) FromJson(json gjson.Result) {
c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider"))
c.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider"))
c.vectorProviderConfig.FromJson(json.Get("vectorProvider"))
c.RedisConfig.FromJson(json.Get("redis"))
if c.CacheKeyFrom.RequestBody == "" {
c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content"
Expand Down Expand Up @@ -84,7 +84,7 @@ func (c *PluginConfig) Validate() error {
if err := c.EmbeddingProviderConfig.Validate(); err != nil {
return err
}
if err := c.VectorDatabaseProviderConfig.Validate(); err != nil {
if err := c.vectorProviderConfig.Validate(); err != nil {
return err
}
return nil
Expand All @@ -96,7 +96,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
if err != nil {
return err
}
c.vectorDatabaseProvider, err = vectorDatabase.CreateProvider(c.VectorDatabaseProviderConfig)
c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig)
if err != nil {
return err
}
Expand All @@ -111,8 +111,8 @@ func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider {
return c.embeddingProvider
}

func (c *PluginConfig) GetVectorDatabaseProvider() vectorDatabase.Provider {
return c.vectorDatabaseProvider
func (c *PluginConfig) GetvectorProvider() vector.Provider {
return c.vectorProvider
}

func (c *PluginConfig) GetCacheProvider() cache.Provider {
Expand Down
102 changes: 102 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/core.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package main

import (
"fmt"
"net/http"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/resp"
)

func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) {
activeCacheProvider := config.GetCacheProvider()
log.Debugf("activeCacheProvider:%v", activeCacheProvider)
activeCacheProvider.Get(embedding.CacheKeyPrefix+key, func(response resp.Value) {
if err := response.Error(); err == nil && !response.IsNull() {
log.Warnf("cache hit, key:%s", key)
HandleCacheHit(key, response, stream, ctx, config, log)
} else {
log.Warnf("cache miss, key:%s", key)
if ifUseEmbedding {
HandleCacheMiss(key, err, response, ctx, config, log, key, stream)
} else {
proxywasm.ResumeHttpRequest()
return
}
}
})
}

func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) {
ctx.SetContext(embedding.CacheKeyContextKey, nil)
if !stream {
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "[Test, this is cache]"+response.String())), -1)
} else {
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, "[Test, this is cache]"+response.String())), -1)
}
}

func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) {
if err != nil {
log.Warnf("redis get key:%s failed, err:%v", key, err)
}
if response.IsNull() {
log.Warnf("cache miss, key:%s", key)
}
FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream)
}

func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) {
activeEmbeddingProvider := config.GetEmbeddingProvider()
activeEmbeddingProvider.GetEmbedding(queryString, ctx, log,
func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != 200 {
log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody))
} else {
log.Debugf("Successfully fetched embeddings for key: %s", key)
QueryVectorDB(key, emb, ctx, config, log, stream)
}
})
}

func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) {
log.Debugf("QueryVectorDB key: %s", key)
activeVectorDatabaseProvider := config.GetvectorProvider()
log.Debugf("activeVectorDatabaseProvider: %+v", activeVectorDatabaseProvider)
activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log,
func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) {
resp, err := activeVectorDatabaseProvider.ParseQueryResponse(responseBody, ctx, log)
if err != nil {
log.Errorf("Failed to query vector database, err: %v", err)
proxywasm.ResumeHttpRequest()
return
}

if len(resp.MostSimilarData) == 0 {
log.Warnf("Failed to query vector database, no most similar key found")
activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log,
func(ctx wrapper.HttpContext, log wrapper.Log) {
proxywasm.ResumeHttpRequest()
})
return
}

log.Infof("most similar key: %s", resp.MostSimilarData)
if resp.Score < activeVectorDatabaseProvider.GetThreshold() {
log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score)
// ctx.SetContext(embedding.CacheKeyContextKey, nil)
RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false)
} else {
log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score)
activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log,
func(ctx wrapper.HttpContext, log wrapper.Log) {
proxywasm.ResumeHttpRequest()
})
return
}
},
)
}
Loading

0 comments on commit 27b2f71

Please sign in to comment.