Skip to content

Commit

Permalink
fix: use bytes in gRPC proto instead of strings (mudler#813)
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  • Loading branch information
mudler committed Jul 27, 2023
1 parent 0af0df7 commit b96e30e
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 16 deletions.
8 changes: 4 additions & 4 deletions api/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
opts.Prompt = s
if tokenCallback != nil {
ss := ""
err := inferenceModel.PredictStream(ctx, opts, func(s string) {
tokenCallback(s)
ss += s
err := inferenceModel.PredictStream(ctx, opts, func(s []byte) {
tokenCallback(string(s))
ss += string(s)
})
return ss, err
} else {
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return "", err
}
return reply.Message, err
return string(reply.Message), err
}
}

Expand Down
2 changes: 1 addition & 1 deletion extra/grpc/huggingface/backend_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion extra/grpc/huggingface/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
return backend_pb2.Reply(message="OK")
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
model_name = request.Model
model_name = os.path.basename(model_name)
Expand Down
4 changes: 2 additions & 2 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
return false
}

if res.Message == "OK" {
if string(res.Message) == "OK" {
return true
}
return false
Expand Down Expand Up @@ -80,7 +80,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
return client.LoadModel(ctx, in, opts...)
}

func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s string), opts ...grpc.CallOption) error {
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
Expand Down
4 changes: 4 additions & 0 deletions pkg/grpc/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ type LLM interface {
AudioTranscription(*pb.TranscriptRequest) (api.Result, error)
TTS(*pb.TTSRequest) error
}

// newReply builds a pb.Reply whose Message field carries s encoded as
// raw bytes (the proto field is `bytes`, so the string must be converted).
func newReply(s string) *pb.Reply {
	payload := []byte(s)
	reply := &pb.Reply{Message: payload}
	return reply
}
8 changes: 4 additions & 4 deletions pkg/grpc/proto/backend.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/grpc/proto/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ message PredictOptions {

// The response message containing the result
message Reply {
string message = 1;
bytes message = 1;
}

message ModelOptions {
Expand Down
6 changes: 3 additions & 3 deletions pkg/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type server struct {
}

func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) {
return &pb.Reply{Message: "OK"}, nil
return newReply("OK"), nil
}

func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
Expand All @@ -48,7 +48,7 @@ func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result

func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
result, err := s.llm.Predict(in)
return &pb.Reply{Message: result}, err
return newReply(result), err
}

func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
Expand Down Expand Up @@ -99,7 +99,7 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS
done := make(chan bool)
go func() {
for result := range resultChan {
stream.Send(&pb.Reply{Message: result})
stream.Send(newReply(result))
}
done <- true
}()
Expand Down

0 comments on commit b96e30e

Please sign in to comment.