Skip to content

Commit 3992ea8

Browse files
lfolgerneild
authored andcommitted
all: implement depth limit for unmarshaling
+ This change introduce a default and configurable depth limit for proto.Unmarshal. If a message is nested deeper than the limit, unmarshaling will fail. There are two ways to nest messages. Either by having fields which are message types itself or by using groups. + The default limit is 10,000 for now. This might change in the future to align it with other language implementation (C++ and Java use 100 as limit). + If pure groups (groups that don't contain message fields) are nested deeper than the default limit the unmarshaling fails with: proto: cannot parse invalid wire-format data + Note: the configured limit does not apply to pure groups. + This change is introduced to improve security and robustness. Because unmarshaling is implemented using recursion it can lead to stack overflows for certain inputs. The introduced limit protects against this. + A secondary motivation for this limit is the alignment with other languages. Protocol buffers are a language interoperability mechanism and thus either all implementations should accept the input or all implementation should reject the input. Change-Id: I14bdb44d06e4bd1aa90d6336c2cf6446003b2037 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/385854 Trust: Dmitri Shuralyov <[email protected]> Reviewed-by: Damien Neil <[email protected]> Trust: Damien Neil <[email protected]> Reviewed-by: Nicolas Hillegeer <[email protected]> Reviewed-by: Chressie Himpel <[email protected]>
1 parent e5db296 commit 3992ea8

File tree

6 files changed

+41
-7
lines changed

6 files changed

+41
-7
lines changed

encoding/protowire/wire.go

+14-5
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ import (
2121
type Number int32
2222

2323
const (
24-
MinValidNumber Number = 1
25-
FirstReservedNumber Number = 19000
26-
LastReservedNumber Number = 19999
27-
MaxValidNumber Number = 1<<29 - 1
24+
MinValidNumber Number = 1
25+
FirstReservedNumber Number = 19000
26+
LastReservedNumber Number = 19999
27+
MaxValidNumber Number = 1<<29 - 1
28+
DefaultRecursionLimit = 10000
2829
)
2930

3031
// IsValid reports whether the field number is semantically valid.
@@ -55,6 +56,7 @@ const (
5556
errCodeOverflow
5657
errCodeReserved
5758
errCodeEndGroup
59+
errCodeRecursionDepth
5860
)
5961

6062
var (
@@ -112,6 +114,10 @@ func ConsumeField(b []byte) (Number, Type, int) {
112114
// When parsing a group, the length includes the end group marker and
113115
// the end group is verified to match the starting field number.
114116
func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
117+
return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
118+
}
119+
120+
func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
115121
switch typ {
116122
case VarintType:
117123
_, n = ConsumeVarint(b)
@@ -126,6 +132,9 @@ func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
126132
_, n = ConsumeBytes(b)
127133
return n
128134
case StartGroupType:
135+
if depth < 0 {
136+
return errCodeRecursionDepth
137+
}
129138
n0 := len(b)
130139
for {
131140
num2, typ2, n := ConsumeTag(b)
@@ -140,7 +149,7 @@ func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
140149
return n0 - len(b)
141150
}
142151

143-
n = ConsumeFieldValue(num2, typ2, b)
152+
n = consumeFieldValueD(num2, typ2, b, depth-1)
144153
if n < 0 {
145154
return n // forward error code
146155
}

internal/fuzz/wirefuzz/fuzz.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func Fuzz(data []byte) (score int) {
4141
// Unmarshal, Validate, and CheckInitialized should agree about initialization.
4242
checkInit := proto.CheckInitialized(m1) == nil
4343
methods := m1.ProtoReflect().ProtoMethods()
44-
in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes}
44+
in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes, Depth: 10000}
4545
if checkInit {
4646
// If the message initialized, the both Unmarshal and Validate should
4747
// report it as such. False negatives are tolerated, but have a

internal/impl/decode.go

+8
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ import (
1818
)
1919

2020
var errDecode = errors.New("cannot parse invalid wire-format data")
21+
var errRecursionDepth = errors.New("exceeded maximum recursion depth")
2122

2223
type unmarshalOptions struct {
2324
flags protoiface.UnmarshalInputFlags
2425
resolver interface {
2526
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
2627
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
2728
}
29+
depth int
2830
}
2931

3032
func (o unmarshalOptions) Options() proto.UnmarshalOptions {
@@ -44,6 +46,7 @@ func (o unmarshalOptions) IsDefault() bool {
4446

4547
var lazyUnmarshalOptions = unmarshalOptions{
4648
resolver: preg.GlobalTypes,
49+
depth: protowire.DefaultRecursionLimit,
4750
}
4851

4952
type unmarshalOutput struct {
@@ -62,6 +65,7 @@ func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutp
6265
out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
6366
flags: in.Flags,
6467
resolver: in.Resolver,
68+
depth: in.Depth,
6569
})
6670
var flags piface.UnmarshalOutputFlags
6771
if out.initialized {
@@ -82,6 +86,10 @@ var errUnknown = errors.New("unknown")
8286

8387
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
8488
mi.init()
89+
opts.depth--
90+
if opts.depth < 0 {
91+
return out, errRecursionDepth
92+
}
8593
if flags.ProtoLegacy && mi.isMessageSet {
8694
return unmarshalMessageSet(mi, b, p, opts)
8795
}

proto/decode.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,25 @@ type UnmarshalOptions struct {
4242
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
4343
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
4444
}
45+
46+
// RecursionLimit limits how deeply messages may be nested.
47+
// If zero, a default limit is applied.
48+
RecursionLimit int
4549
}
4650

4751
// Unmarshal parses the wire-format message in b and places the result in m.
4852
// The provided message must be mutable (e.g., a non-nil pointer to a message).
4953
func Unmarshal(b []byte, m Message) error {
50-
_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
54+
_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
5155
return err
5256
}
5357

5458
// Unmarshal parses the wire-format message in b and places the result in m.
5559
// The provided message must be mutable (e.g., a non-nil pointer to a message).
5660
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
61+
if o.RecursionLimit == 0 {
62+
o.RecursionLimit = protowire.DefaultRecursionLimit
63+
}
5764
_, err := o.unmarshal(b, m.ProtoReflect())
5865
return err
5966
}
@@ -63,6 +70,9 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
6370
// This method permits fine-grained control over the unmarshaler.
6471
// Most users should use Unmarshal instead.
6572
func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
73+
if o.RecursionLimit == 0 {
74+
o.RecursionLimit = protowire.DefaultRecursionLimit
75+
}
6676
return o.unmarshal(in.Buf, in.Message)
6777
}
6878

@@ -86,12 +96,17 @@ func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out proto
8696
Message: m,
8797
Buf: b,
8898
Resolver: o.Resolver,
99+
Depth: o.RecursionLimit,
89100
}
90101
if o.DiscardUnknown {
91102
in.Flags |= protoiface.UnmarshalDiscardUnknown
92103
}
93104
out, err = methods.Unmarshal(in)
94105
} else {
106+
o.RecursionLimit--
107+
if o.RecursionLimit < 0 {
108+
return out, errors.New("exceeded max recursion depth")
109+
}
95110
err = o.unmarshalMessageSlow(b, m)
96111
}
97112
if err != nil {

reflect/protoreflect/methods.go

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type (
5353
FindExtensionByName(field FullName) (ExtensionType, error)
5454
FindExtensionByNumber(message FullName, field FieldNumber) (ExtensionType, error)
5555
}
56+
Depth int
5657
}
5758
unmarshalOutput = struct {
5859
pragma.NoUnkeyedLiterals

runtime/protoiface/methods.go

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ type UnmarshalInput = struct {
103103
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
104104
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
105105
}
106+
Depth int
106107
}
107108

108109
// UnmarshalOutput is output from the Unmarshal method.

0 commit comments

Comments
 (0)