Skip to content

Commit 8e3596c

Browse files
authored
cherry-pick #7557 to v1.66.x branch (#7564)
1 parent 62baa5f commit 8e3596c

9 files changed

+72
-133
lines changed

codec.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,11 @@ type baseCodec interface {
3939
// with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge
4040
// to turn it into an encoding.CodecV2. Returns nil otherwise.
4141
func getCodec(name string) encoding.CodecV2 {
42-
codecV2 := encoding.GetCodecV2(name)
43-
if codecV2 != nil {
44-
return codecV2
45-
}
46-
47-
codecV1 := encoding.GetCodec(name)
48-
if codecV1 != nil {
42+
if codecV1 := encoding.GetCodec(name); codecV1 != nil {
4943
return newCodecV1Bridge(codecV1)
5044
}
5145

52-
return nil
46+
return encoding.GetCodecV2(name)
5347
}
5448

5549
func newCodecV0Bridge(c Codec) baseCodec {

codec_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
)
2727

2828
func (s) TestGetCodecForProtoIsNotNil(t *testing.T) {
29-
if encoding.GetCodec(proto.Name) == nil {
29+
if encoding.GetCodecV2(proto.Name) == nil {
3030
t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name)
3131
}
3232
}

encoding/encoding.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ type Codec interface {
9494
Name() string
9595
}
9696

97-
var registeredCodecs = make(map[string]Codec)
97+
var registeredCodecs = make(map[string]any)
9898

9999
// RegisterCodec registers the provided Codec for use with all gRPC clients and
100100
// servers.
@@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) {
126126
//
127127
// The content-subtype is expected to be lowercase.
128128
func GetCodec(contentSubtype string) Codec {
129-
return registeredCodecs[contentSubtype]
129+
c, _ := registeredCodecs[contentSubtype].(Codec)
130+
return c
130131
}

encoding/encoding_test.go

+19-18
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"google.golang.org/grpc/internal/grpctest"
3737
"google.golang.org/grpc/internal/grpcutil"
3838
"google.golang.org/grpc/internal/stubserver"
39+
"google.golang.org/grpc/mem"
3940
"google.golang.org/grpc/metadata"
4041
"google.golang.org/grpc/status"
4142

@@ -90,18 +91,18 @@ type errProtoCodec struct {
9091
decodingErr error
9192
}
9293

93-
func (c *errProtoCodec) Marshal(v any) ([]byte, error) {
94+
func (c *errProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
9495
if c.encodingErr != nil {
9596
return nil, c.encodingErr
9697
}
97-
return encoding.GetCodec(proto.Name).Marshal(v)
98+
return encoding.GetCodecV2(proto.Name).Marshal(v)
9899
}
99100

100-
func (c *errProtoCodec) Unmarshal(data []byte, v any) error {
101+
func (c *errProtoCodec) Unmarshal(data mem.BufferSlice, v any) error {
101102
if c.decodingErr != nil {
102103
return c.decodingErr
103104
}
104-
return encoding.GetCodec(proto.Name).Unmarshal(data, v)
105+
return encoding.GetCodecV2(proto.Name).Unmarshal(data, v)
105106
}
106107

107108
func (c *errProtoCodec) Name() string {
@@ -118,7 +119,7 @@ func (s) TestEncodeDoesntPanicOnServer(t *testing.T) {
118119
ec := &errProtoCodec{name: t.Name(), encodingErr: encodingErr}
119120

120121
// Start a server with the above codec.
121-
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
122+
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec))
122123
defer backend.Stop()
123124

124125
// Create a channel to the above server.
@@ -154,7 +155,7 @@ func (s) TestDecodeDoesntPanicOnServer(t *testing.T) {
154155
ec := &errProtoCodec{name: t.Name(), decodingErr: decodingErr}
155156

156157
// Start a server with the above codec.
157-
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
158+
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec))
158159
defer backend.Stop()
159160

160161
// Create a channel to the above server. Since we do not specify any codec
@@ -206,15 +207,15 @@ func (s) TestEncodeDoesntPanicOnClient(t *testing.T) {
206207
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
207208
defer cancel()
208209
client := testgrpc.NewTestServiceClient(cc)
209-
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
210+
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec))
210211
if err == nil || !strings.Contains(err.Error(), encodingErr.Error()) {
211212
t.Fatalf("RPC failed with error: %v, want: %v", err, encodingErr)
212213
}
213214

214215
// Configure the codec on the client to not return errors anymore and expect
215216
// the RPC to succeed.
216217
ec.encodingErr = nil
217-
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
218+
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil {
218219
t.Fatalf("RPC failed with error: %v", err)
219220
}
220221
}
@@ -242,15 +243,15 @@ func (s) TestDecodeDoesntPanicOnClient(t *testing.T) {
242243
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
243244
defer cancel()
244245
client := testgrpc.NewTestServiceClient(cc)
245-
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
246+
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec))
246247
if err == nil || !strings.Contains(err.Error(), decodingErr.Error()) {
247248
t.Fatalf("RPC failed with error: %v, want: %v", err, decodingErr)
248249
}
249250

250251
// Configure the codec on the client to not return errors anymore and expect
251252
// the RPC to succeed.
252253
ec.decodingErr = nil
253-
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
254+
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil {
254255
t.Fatalf("RPC failed with error: %v", err)
255256
}
256257
}
@@ -265,14 +266,14 @@ type countingProtoCodec struct {
265266
unmarshalCount int32
266267
}
267268

268-
func (p *countingProtoCodec) Marshal(v any) ([]byte, error) {
269+
func (p *countingProtoCodec) Marshal(v any) (mem.BufferSlice, error) {
269270
atomic.AddInt32(&p.marshalCount, 1)
270-
return encoding.GetCodec(proto.Name).Marshal(v)
271+
return encoding.GetCodecV2(proto.Name).Marshal(v)
271272
}
272273

273-
func (p *countingProtoCodec) Unmarshal(data []byte, v any) error {
274+
func (p *countingProtoCodec) Unmarshal(data mem.BufferSlice, v any) error {
274275
atomic.AddInt32(&p.unmarshalCount, 1)
275-
return encoding.GetCodec(proto.Name).Unmarshal(data, v)
276+
return encoding.GetCodecV2(proto.Name).Unmarshal(data, v)
276277
}
277278

278279
func (p *countingProtoCodec) Name() string {
@@ -284,7 +285,7 @@ func (p *countingProtoCodec) Name() string {
284285
func (s) TestForceServerCodec(t *testing.T) {
285286
// Create an server with the counting proto codec.
286287
codec := &countingProtoCodec{name: t.Name()}
287-
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(codec))
288+
backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(codec))
288289
defer backend.Stop()
289290

290291
// Create a channel to the above server.
@@ -317,7 +318,7 @@ func (s) TestForceServerCodec(t *testing.T) {
317318

318319
// renameProtoCodec wraps the proto codec and allows customizing the Name().
319320
type renameProtoCodec struct {
320-
encoding.Codec
321+
encoding.CodecV2
321322
name string
322323
}
323324

@@ -356,9 +357,9 @@ func (s) TestForceCodecName(t *testing.T) {
356357

357358
// Force the use of the custom codec on the client with the ForceCodec call
358359
// option. Confirm the name is converted to lowercase before transmitting.
359-
codec := &renameProtoCodec{Codec: encoding.GetCodec(proto.Name), name: t.Name()}
360+
codec := &renameProtoCodec{CodecV2: encoding.GetCodecV2(proto.Name), name: t.Name()}
360361
wantContentTypeCh <- []string{fmt.Sprintf("application/grpc+%s", strings.ToLower(t.Name()))}
361-
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil {
362+
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(codec)); err != nil {
362363
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
363364
}
364365
}

encoding/encoding_v2.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ type CodecV2 interface {
4343
Name() string
4444
}
4545

46-
var registeredV2Codecs = make(map[string]CodecV2)
47-
4846
// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
4947
// servers.
5048
//
@@ -70,13 +68,14 @@ func RegisterCodecV2(codec CodecV2) {
7068
panic("cannot register CodecV2 with empty string result for Name()")
7169
}
7270
contentSubtype := strings.ToLower(codec.Name())
73-
registeredV2Codecs[contentSubtype] = codec
71+
registeredCodecs[contentSubtype] = codec
7472
}
7573

7674
// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
7775
// registered for the content-subtype.
7876
//
7977
// The content-subtype is expected to be lowercase.
8078
func GetCodecV2(contentSubtype string) CodecV2 {
81-
return registeredV2Codecs[contentSubtype]
79+
c, _ := registeredCodecs[contentSubtype].(CodecV2)
80+
return c
8281
}

encoding/proto/proto.go

+34-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*
22
*
3-
* Copyright 2018 gRPC authors.
3+
* Copyright 2024 gRPC authors.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import (
2424
"fmt"
2525

2626
"google.golang.org/grpc/encoding"
27+
"google.golang.org/grpc/mem"
2728
"google.golang.org/protobuf/proto"
2829
"google.golang.org/protobuf/protoadapt"
2930
)
@@ -32,28 +33,51 @@ import (
3233
const Name = "proto"
3334

3435
func init() {
35-
encoding.RegisterCodec(codec{})
36+
encoding.RegisterCodecV2(&codecV2{})
3637
}
3738

38-
// codec is a Codec implementation with protobuf. It is the default codec for gRPC.
39-
type codec struct{}
39+
// codec is a CodecV2 implementation with protobuf. It is the default codec for
40+
// gRPC.
41+
type codecV2 struct{}
4042

41-
func (codec) Marshal(v any) ([]byte, error) {
43+
func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
4244
vv := messageV2Of(v)
4345
if vv == nil {
44-
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
46+
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
4547
}
4648

47-
return proto.Marshal(vv)
49+
size := proto.Size(vv)
50+
if mem.IsBelowBufferPoolingThreshold(size) {
51+
buf, err := proto.Marshal(vv)
52+
if err != nil {
53+
return nil, err
54+
}
55+
data = append(data, mem.SliceBuffer(buf))
56+
} else {
57+
pool := mem.DefaultBufferPool()
58+
buf := pool.Get(size)
59+
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
60+
pool.Put(buf)
61+
return nil, err
62+
}
63+
data = append(data, mem.NewBuffer(buf, pool))
64+
}
65+
66+
return data, nil
4867
}
4968

50-
func (codec) Unmarshal(data []byte, v any) error {
69+
func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
5170
vv := messageV2Of(v)
5271
if vv == nil {
5372
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
5473
}
5574

56-
return proto.Unmarshal(data, vv)
75+
buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
76+
defer buf.Free()
77+
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
78+
// really possible without a major overhaul of the proto package, but the
79+
// vtprotobuf library may be able to support this.
80+
return proto.Unmarshal(buf.ReadOnlyData(), vv)
5781
}
5882

5983
func messageV2Of(v any) proto.Message {
@@ -67,6 +91,6 @@ func messageV2Of(v any) proto.Message {
6791
return nil
6892
}
6993

70-
func (codec) Name() string {
94+
func (c *codecV2) Name() string {
7195
return Name
7296
}

encoding/proto/proto_benchmark_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func BenchmarkProtoCodec(b *testing.B) {
6868
protoStructs := setupBenchmarkProtoCodecInputs(s)
6969
name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", s, p)
7070
b.Run(name, func(b *testing.B) {
71-
codec := &codec{}
71+
codec := &codecV2{}
7272
b.SetParallelism(p)
7373
b.RunParallel(func(pb *testing.PB) {
7474
benchmarkProtoCodec(codec, protoStructs, pb, b)
@@ -78,7 +78,7 @@ func BenchmarkProtoCodec(b *testing.B) {
7878
}
7979
}
8080

81-
func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
81+
func benchmarkProtoCodec(codec *codecV2, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
8282
counter := 0
8383
for pb.Next() {
8484
counter++
@@ -87,7 +87,7 @@ func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing
8787
}
8888
}
8989

90-
func fastMarshalAndUnmarshal(codec encoding.Codec, protoStruct proto.Message, b *testing.B) {
90+
func fastMarshalAndUnmarshal(codec encoding.CodecV2, protoStruct proto.Message, b *testing.B) {
9191
marshaledBytes, err := codec.Marshal(protoStruct)
9292
if err != nil {
9393
b.Errorf("codec.Marshal(_) returned an error")

encoding/proto/proto_test.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ import (
2525

2626
"google.golang.org/grpc/encoding"
2727
"google.golang.org/grpc/internal/grpctest"
28+
"google.golang.org/grpc/mem"
2829
pb "google.golang.org/grpc/test/codec_perf"
2930
)
3031

31-
func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte) {
32+
func marshalAndUnmarshal(t *testing.T, codec encoding.CodecV2, expectedBody []byte) {
3233
p := &pb.Buffer{}
3334
p.Body = expectedBody
3435

@@ -55,7 +56,7 @@ func Test(t *testing.T) {
5556
}
5657

5758
func (s) TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) {
58-
marshalAndUnmarshal(t, codec{}, []byte{1, 2, 3})
59+
marshalAndUnmarshal(t, &codecV2{}, []byte{1, 2, 3})
5960
}
6061

6162
// Try to catch possible race conditions around use of pools
@@ -75,7 +76,7 @@ func (s) TestConcurrentUsage(t *testing.T) {
7576
}
7677

7778
var wg sync.WaitGroup
78-
codec := codec{}
79+
codec := &codecV2{}
7980

8081
for i := 0; i < numGoRoutines; i++ {
8182
wg.Add(1)
@@ -93,16 +94,16 @@ func (s) TestConcurrentUsage(t *testing.T) {
9394
// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get
9495
// stomped on during reuse of a proto.Buffer.
9596
func (s) TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
96-
codec1 := codec{}
97-
codec2 := codec{}
97+
codec1 := &codecV2{}
98+
codec2 := &codecV2{}
9899

99100
expectedBody1 := []byte{1, 2, 3}
100101
expectedBody2 := []byte{4, 5, 6}
101102

102103
proto1 := pb.Buffer{Body: expectedBody1}
103104
proto2 := pb.Buffer{Body: expectedBody2}
104105

105-
var m1, m2 []byte
106+
var m1, m2 mem.BufferSlice
106107
var err error
107108

108109
if m1, err = codec1.Marshal(&proto1); err != nil {

0 commit comments

Comments
 (0)