Skip to content

Commit

Permalink
simplify unmarshal (sashabaranov#191)
Browse files Browse the repository at this point in the history
* simplify unmarshal

* simplify unmarshalError

* rename errorAccumulate -> defaultErrorAccumulator

* update converage
  • Loading branch information
sashabaranov committed Mar 22, 2023
1 parent a5a945a commit eb68a72
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
28 changes: 14 additions & 14 deletions error_accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

type errorAccumulator interface {
write(p []byte) error
unmarshalError() (*ErrorResponse, error)
unmarshalError() *ErrorResponse
}

type errorBuffer interface {
Expand All @@ -17,35 +17,35 @@ type errorBuffer interface {
Bytes() []byte
}

type errorAccumulate struct {
type defaultErrorAccumulator struct {
buffer errorBuffer
unmarshaler unmarshaler
}

func newErrorAccumulator() errorAccumulator {
return &errorAccumulate{
return &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &jsonUnmarshaler{},
}
}

func (e *errorAccumulate) write(p []byte) error {
func (e *defaultErrorAccumulator) write(p []byte) error {
_, err := e.buffer.Write(p)
if err != nil {
return fmt.Errorf("error accumulator write error, %w", err)
}
return nil
}

func (e *errorAccumulate) unmarshalError() (*ErrorResponse, error) {
var err error
if e.buffer.Len() > 0 {
var errRes ErrorResponse
err = e.unmarshaler.unmarshal(e.buffer.Bytes(), &errRes)
if err != nil {
return nil, err
}
return &errRes, nil
func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
if e.buffer.Len() == 0 {
return
}
return nil, err

err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp)
if err != nil {
errResp = nil
}

return
}
18 changes: 12 additions & 6 deletions error_accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,29 @@ func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
}

func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
accumulator := &errorAccumulate{
accumulator := &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &failingUnMarshaller{},
}

respErr := accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}

err := accumulator.write([]byte("{"))
if err != nil {
t.Fatalf("%+v", err)
}
_, err = accumulator.unmarshalError()
if !errors.Is(err, errTestUnmarshalerFailed) {
t.Fatalf("Did not return error when unmarshaler failed: %v", err)

respErr = accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
}
}

func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &errorAccumulate{
accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
}
Expand All @@ -78,7 +84,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
if err != nil {
t.Fatal(err)
}
stream.errAccumulator = &errorAccumulate{
stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
}
Expand Down
5 changes: 3 additions & 2 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
if errRes, _ := stream.errAccumulator.unmarshalError(); errRes != nil {
err = fmt.Errorf("error, %w", errRes.Error)
respErr := stream.errAccumulator.unmarshalError()
if respErr != nil {
err = fmt.Errorf("error, %w", respErr.Error)
}
return
}
Expand Down

0 comments on commit eb68a72

Please sign in to comment.