Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Unmarshal should fail if required fields are missing #16

Merged
merged 2 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ run:
timeout: 10m
skip-files:
- ".*\\.pb.*\\.go"
- ".*_test\\.go"

# output configuration options
output:
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ clean-generated:

.PHONY: lint
lint: check-golangci-lint-install
@golangci-lint run
@golangci-lint run ./...
@cd example && golangci-lint run ./...

.PHONY: check-golangci-lint-install
check-golangci-lint-install:
Expand Down
44 changes: 44 additions & 0 deletions cmd/protoc-gen-fastmarshal/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func addProtoFunctions(fm template.FuncMap, protoFile *protogen.File) template.F
fm["getAdditionalImports"] = getAdditionalImports(protoFile)
fm["getImportPrefix"] = getImportPrefix(protoFile)
fm["mapFieldGoType"] = mapFieldGoType(protoFile)
fm["hasRequiredFields"] = hasRequiredFields(protoFile)
return fm
}

Expand Down Expand Up @@ -405,3 +406,46 @@ func mapFieldGoType(protoFile *protogen.File) func(*protogen.Field) string {
return fmt.Sprintf("map[%s]%s", ktype, vtype)
}
}

// msgHasRequiredField returns true if the specified message has at least 1 field marked required and
// false if not
func msgHasRequiredField(m *protogen.Message) bool {
if m.Desc.Syntax() == protoreflect.Proto3 {
return false
}
for _, f := range m.Fields {
if f.Desc.Cardinality() == protoreflect.Required {
return true
}
}
return false
}

// hasRequiredFields returns true if at least one field in the specified message is marked required
// and false if not.
//
// If m is nil, this function returns true if *any* message in the Protobuf file has a required field
// and false if not.
func hasRequiredFields(protoFile *protogen.File) func(*protogen.Message) bool {
anyMessageHasRequiredFields := false
if protoFile.Desc.Syntax() == protoreflect.Proto2 {
for _, m := range allMessages(protoFile)() {
anyMessageHasRequiredFields = anyMessageHasRequiredFields || msgHasRequiredField(m)
}
}

return func(m *protogen.Message) bool {
if m == nil {
return anyMessageHasRequiredFields
}
if m.Desc.Syntax() == protoreflect.Proto3 {
return false
}
for _, f := range m.Fields {
if f.Desc.Cardinality() == protoreflect.Required {
return true
}
}
return false
}
}
35 changes: 34 additions & 1 deletion cmd/protoc-gen-fastmarshal/templates/permessage.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
package {{ .ProtoDesc.GoPackageName }}

import (
"fmt"
"fmt"{{if and (eq $protoSyntax "proto2") (hasRequiredFields nil)}}
"strings"{{end}}
"github.com/CrowdStrike/csproto"
{{range (.Message | getAdditionalImports)}}{{.}}
{{end}}
Expand Down Expand Up @@ -114,7 +115,39 @@ func (m *{{ .Message.GoIdent.GoName}}) Unmarshal(p []byte) error {
{{- end -}}
}
}
}{{if hasRequiredFields .Message }}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}
{{end}}
return nil
}
{{if hasRequiredFields .Message}}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *{{.Message.GoIdent.GoName}}) csprotoCheckRequiredFields() error {
var missingFields []string
{{ range .Message.Fields }}
{{ if and (eq (.Desc.Cardinality | string) "required") (eq (.Desc.Syntax | string) "proto2") -}}
if m.{{.GoName}} == nil {
missingFields = append(missingFields, "{{.GoName}}")
}
{{- end -}}
{{ end }}
if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}
{{ end }}
{{end}}
35 changes: 34 additions & 1 deletion cmd/protoc-gen-fastmarshal/templates/singlefile.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
package {{ .ProtoDesc.GoPackageName }}

import (
"fmt"
"fmt"{{if and (eq $protoSyntax "proto2") (hasRequiredFields nil)}}
"strings"{{end}}
"github.com/CrowdStrike/csproto"
{{range (allMessages | getAdditionalImports)}}{{.}}
{{end}}
Expand Down Expand Up @@ -117,8 +118,40 @@ func (m *{{ .GoIdent.GoName}}) Unmarshal(p []byte) error {
{{- end -}}
}
}
}{{if hasRequiredFields . }}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}
{{end}}
return nil
}
{{if hasRequiredFields .}}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *{{.GoIdent.GoName}}) csprotoCheckRequiredFields() error {
var missingFields []string
{{ range .Fields }}
{{ if and (eq (.Desc.Cardinality | string) "required") (eq (.Desc.Syntax | string) "proto2") -}}
if m.{{.GoName}} == nil {
missingFields = append(missingFields, "{{.GoName}}")
}
{{- end -}}
{{ end }}
if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}
{{ end }}
{{ end }}
{{end}}
149 changes: 149 additions & 0 deletions example/proto2/gogo/gogo_proto2_example.pb.fm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package gogo

import (
"fmt"
"strings"
"github.com/CrowdStrike/csproto"
)

Expand Down Expand Up @@ -172,6 +173,42 @@ func (m *BaseEvent) Unmarshal(p []byte) error {
}
}
}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}

return nil
}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *BaseEvent) csprotoCheckRequiredFields() error {
var missingFields []string

if m.EventID == nil {
missingFields = append(missingFields, "EventID")
}
if m.SourceID == nil {
missingFields = append(missingFields, "SourceID")
}
if m.Timestamp == nil {
missingFields = append(missingFields, "Timestamp")
}
if m.EventType == nil {
missingFields = append(missingFields, "EventType")
}
if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}

Expand Down Expand Up @@ -413,6 +450,34 @@ func (m *TestEvent) Unmarshal(p []byte) error {
}
}
}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}

return nil
}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *TestEvent) csprotoCheckRequiredFields() error {
var missingFields []string

if m.Embedded == nil {
missingFields = append(missingFields, "Embedded")
}

if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}

Expand Down Expand Up @@ -557,6 +622,34 @@ func (m *EmbeddedEvent) Unmarshal(p []byte) error {
}
}
}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}

return nil
}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *EmbeddedEvent) csprotoCheckRequiredFields() error {
var missingFields []string

if m.ID == nil {
missingFields = append(missingFields, "ID")
}

if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}

Expand Down Expand Up @@ -904,6 +997,34 @@ func (m *AllTheThings) Unmarshal(p []byte) error {
}
}
}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}

return nil
}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *AllTheThings) csprotoCheckRequiredFields() error {
var missingFields []string

if m.ID == nil {
missingFields = append(missingFields, "ID")
}

if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}

Expand Down Expand Up @@ -1333,6 +1454,34 @@ func (m *RepeatAllTheThings) Unmarshal(p []byte) error {
}
}
}
// verify required fields are assigned
if err := m.csprotoCheckRequiredFields(); err != nil {
return err
}

return nil
}

// csprotoCheckRequiredFields is called by Unmarshal() to ensure that all required fields have been
// populated.
func (m *RepeatAllTheThings) csprotoCheckRequiredFields() error {
var missingFields []string

if m.ID == nil {
missingFields = append(missingFields, "ID")
}

if len(missingFields) > 0 {
var sb strings.Builder
sb.WriteString("cannot unmarshal, one or more required fields missing: ")
for i, s := range missingFields {
if i > 0 {
sb.WriteRune(',')
}
sb.WriteString(s)
}
return fmt.Errorf(sb.String())
}
return nil
}

Expand Down
Loading