Skip to content

Commit a7d4e94

Browse files
authored
Merge pull request #10547 from deannagarcia/3.16.x
Apply patch
2 parents 152d7bf + 55815e4 commit a7d4e94

File tree

4 files changed: +152 -37 lines changed

Diff for: src/google/protobuf/extension_set_inl.h

+18 -9
Original file line number | Diff line number | Diff line change
@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
206206
const char* ptr, const Msg* containing_type,
207207
internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
208208
std::string payload;
209-
uint32 type_id = 0;
210-
bool payload_read = false;
209+
uint32 type_id;
210+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
211+
State state = State::kNoTag;
212+
211213
while (!ctx->Done(&ptr)) {
212214
uint32 tag = static_cast<uint8>(*ptr++);
213215
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
214216
uint64 tmp;
215217
ptr = ParseBigVarint(ptr, &tmp);
216218
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
217-
type_id = tmp;
218-
if (payload_read) {
219+
if (state == State::kNoTag) {
220+
type_id = tmp;
221+
state = State::kHasType;
222+
} else if (state == State::kHasPayload) {
223+
type_id = tmp;
219224
ExtensionInfo extension;
220225
bool was_packed_on_wire;
221226
if (!FindExtension(2, type_id, containing_type, ctx, &extension,
@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
241246
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
242247
tmp_ctx.EndedAtLimit());
243248
}
244-
type_id = 0;
249+
state = State::kDone;
245250
}
246251
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
247-
if (type_id != 0) {
252+
if (state == State::kHasType) {
248253
ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
249254
containing_type, metadata, ctx);
250255
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
251-
type_id = 0;
256+
state = State::kDone;
252257
} else {
258+
std::string tmp;
253259
int32 size = ReadSize(&ptr);
254260
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
255-
ptr = ctx->ReadString(ptr, size, &payload);
261+
ptr = ctx->ReadString(ptr, size, &tmp);
256262
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
257-
payload_read = true;
263+
if (state == State::kNoTag) {
264+
payload = std::move(tmp);
265+
state = State::kHasPayload;
266+
}
258267
}
259268
} else {
260269
ptr = ReadTag(ptr - 1, &tag);

Diff for: src/google/protobuf/wire_format.cc

+18 -8
Original file line number | Diff line number | Diff line change
@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
657657
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
658658
// Parse a MessageSetItem
659659
auto metadata = reflection->MutableInternalMetadata(msg);
660+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
661+
State state = State::kNoTag;
662+
660663
std::string payload;
661664
uint32 type_id = 0;
662-
bool payload_read = false;
663665
while (!ctx->Done(&ptr)) {
664666
// We use 64 bit tags in order to allow typeid's that span the whole
665667
// range of 32 bit numbers.
@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
668670
uint64 tmp;
669671
ptr = ParseBigVarint(ptr, &tmp);
670672
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
671-
type_id = tmp;
672-
if (payload_read) {
673+
if (state == State::kNoTag) {
674+
type_id = tmp;
675+
state = State::kHasType;
676+
} else if (state == State::kHasPayload) {
677+
type_id = tmp;
673678
const FieldDescriptor* field;
674679
if (ctx->data().pool == nullptr) {
675680
field = reflection->FindKnownExtensionByNumber(type_id);
@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
696701
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
697702
tmp_ctx.EndedAtLimit());
698703
}
699-
type_id = 0;
704+
state = State::kDone;
700705
}
701706
continue;
702707
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
703-
if (type_id == 0) {
708+
if (state == State::kNoTag) {
704709
int32 size = ReadSize(&ptr);
705710
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
706711
ptr = ctx->ReadString(ptr, size, &payload);
707712
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
708-
payload_read = true;
709-
} else {
713+
state = State::kHasPayload;
714+
} else if (state == State::kHasType) {
710715
// We're now parsing the payload
711716
const FieldDescriptor* field = nullptr;
712717
if (descriptor->IsExtensionNumber(type_id)) {
@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
720725
ptr = WireFormat::_InternalParseAndMergeField(
721726
msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection,
722727
field);
723-
type_id = 0;
728+
state = State::kDone;
729+
} else {
730+
int32 size = ReadSize(&ptr);
731+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
732+
ptr = ctx->Skip(ptr, size);
733+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
724734
}
725735
} else {
726736
// An unknown field in MessageSetItem.

Diff for: src/google/protobuf/wire_format_lite.h

+18 -9
Original file line number | Diff line number | Diff line change
@@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
17981798
// we can parse it later.
17991799
std::string message_data;
18001800

1801+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
1802+
State state = State::kNoTag;
1803+
18011804
while (true) {
18021805
const uint32 tag = input->ReadTagNoLastTag();
18031806
if (tag == 0) return false;
@@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18061809
case WireFormatLite::kMessageSetTypeIdTag: {
18071810
uint32 type_id;
18081811
if (!input->ReadVarint32(&type_id)) return false;
1809-
last_type_id = type_id;
1810-
1811-
if (!message_data.empty()) {
1812+
if (state == State::kNoTag) {
1813+
last_type_id = type_id;
1814+
state = State::kHasType;
1815+
} else if (state == State::kHasPayload) {
18121816
// We saw some message data before the type_id. Have to parse it
18131817
// now.
18141818
io::CodedInputStream sub_input(
18151819
reinterpret_cast<const uint8*>(message_data.data()),
18161820
static_cast<int>(message_data.size()));
18171821
sub_input.SetRecursionLimit(input->RecursionBudget());
1818-
if (!ms.ParseField(last_type_id, &sub_input)) {
1822+
if (!ms.ParseField(type_id, &sub_input)) {
18191823
return false;
18201824
}
18211825
message_data.clear();
1826+
state = State::kDone;
18221827
}
18231828

18241829
break;
18251830
}
18261831

18271832
case WireFormatLite::kMessageSetMessageTag: {
1828-
if (last_type_id == 0) {
1833+
if (state == State::kHasType) {
1834+
// Already saw type_id, so we can parse this directly.
1835+
if (!ms.ParseField(last_type_id, input)) {
1836+
return false;
1837+
}
1838+
state = State::kDone;
1839+
} else if (state == State::kNoTag) {
18291840
// We haven't seen a type_id yet. Append this data to message_data.
18301841
uint32 length;
18311842
if (!input->ReadVarint32(&length)) return false;
@@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18361847
auto ptr = reinterpret_cast<uint8*>(&message_data[0]);
18371848
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
18381849
if (!input->ReadRaw(ptr, length)) return false;
1850+
state = State::kHasPayload;
18391851
} else {
1840-
// Already saw type_id, so we can parse this directly.
1841-
if (!ms.ParseField(last_type_id, input)) {
1842-
return false;
1843-
}
1852+
if (!ms.SkipField(tag, input)) return false;
18441853
}
18451854

18461855
break;

Diff for: src/google/protobuf/wire_format_unittest.cc

+98 -11
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
#include <google/protobuf/io/zero_copy_stream_impl.h>
4747
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
4848
#include <google/protobuf/descriptor.h>
49+
#include <google/protobuf/dynamic_message.h>
4950
#include <google/protobuf/wire_format_lite.h>
5051
#include <google/protobuf/testing/googletest.h>
5152
#include <google/protobuf/stubs/logging.h>
@@ -585,41 +586,72 @@ TEST(WireFormatTest, ParseMessageSet) {
585586
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
586587
}
587588

588-
TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
589+
namespace {
590+
std::string BuildMessageSetItemStart() {
589591
std::string data;
590592
{
591-
unittest::TestMessageSetExtension1 message;
592-
message.set_i(123);
593-
// Build a MessageSet manually with its message content put before its
594-
// type_id.
595593
io::StringOutputStream output_stream(&data);
596594
io::CodedOutputStream coded_output(&output_stream);
597595
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
596+
}
597+
return data;
598+
}
599+
std::string BuildMessageSetItemEnd() {
600+
std::string data;
601+
{
602+
io::StringOutputStream output_stream(&data);
603+
io::CodedOutputStream coded_output(&output_stream);
604+
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
605+
}
606+
return data;
607+
}
608+
std::string BuildMessageSetTestExtension1(int value = 123) {
609+
std::string data;
610+
{
611+
unittest::TestMessageSetExtension1 message;
612+
message.set_i(value);
613+
io::StringOutputStream output_stream(&data);
614+
io::CodedOutputStream coded_output(&output_stream);
598615
// Write the message content first.
599616
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
600617
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
601618
&coded_output);
602619
coded_output.WriteVarint32(message.ByteSizeLong());
603620
message.SerializeWithCachedSizes(&coded_output);
604-
// Write the type id.
605-
uint32 type_id = message.GetDescriptor()->extension(0)->number();
621+
}
622+
return data;
623+
}
624+
std::string BuildMessageSetItemTypeId(int extension_number) {
625+
std::string data;
626+
{
627+
io::StringOutputStream output_stream(&data);
628+
io::CodedOutputStream coded_output(&output_stream);
606629
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
607-
type_id, &coded_output);
608-
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
630+
extension_number, &coded_output);
609631
}
632+
return data;
633+
}
634+
void ValidateTestMessageSet(const std::string& test_case,
635+
const std::string& data) {
636+
SCOPED_TRACE(test_case);
610637
{
611-
proto2_wireformat_unittest::TestMessageSet message_set;
638+
::proto2_wireformat_unittest::TestMessageSet message_set;
612639
ASSERT_TRUE(message_set.ParseFromString(data));
613640

614641
EXPECT_EQ(123,
615642
message_set
616643
.GetExtension(
617644
unittest::TestMessageSetExtension1::message_set_extension)
618645
.i());
646+
647+
// Make sure it does not contain anything else.
648+
message_set.ClearExtension(
649+
unittest::TestMessageSetExtension1::message_set_extension);
650+
EXPECT_EQ(message_set.SerializeAsString(), "");
619651
}
620652
{
621653
// Test parse the message via Reflection.
622-
proto2_wireformat_unittest::TestMessageSet message_set;
654+
::proto2_wireformat_unittest::TestMessageSet message_set;
623655
io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()),
624656
data.size());
625657
EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set));
@@ -631,6 +663,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
631663
unittest::TestMessageSetExtension1::message_set_extension)
632664
.i());
633665
}
666+
{
667+
// Test parse the message via DynamicMessage.
668+
DynamicMessageFactory factory;
669+
std::unique_ptr<Message> msg(
670+
factory
671+
.GetPrototype(
672+
::proto2_wireformat_unittest::TestMessageSet::descriptor())
673+
->New());
674+
msg->ParseFromString(data);
675+
auto* reflection = msg->GetReflection();
676+
std::vector<const FieldDescriptor*> fields;
677+
reflection->ListFields(*msg, &fields);
678+
ASSERT_EQ(fields.size(), 1);
679+
const auto& sub = reflection->GetMessage(*msg, fields[0]);
680+
reflection = sub.GetReflection();
681+
EXPECT_EQ(123, reflection->GetInt32(
682+
sub, sub.GetDescriptor()->FindFieldByName("i")));
683+
}
684+
}
685+
} // namespace
686+
687+
TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
688+
std::string start = BuildMessageSetItemStart();
689+
std::string end = BuildMessageSetItemEnd();
690+
std::string id = BuildMessageSetItemTypeId(
691+
unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
692+
std::string message = BuildMessageSetTestExtension1();
693+
694+
ValidateTestMessageSet("id + message", start + id + message + end);
695+
ValidateTestMessageSet("message + id", start + message + id + end);
696+
}
697+
698+
TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
699+
std::string start = BuildMessageSetItemStart();
700+
std::string end = BuildMessageSetItemEnd();
701+
std::string id = BuildMessageSetItemTypeId(
702+
unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
703+
std::string other_id = BuildMessageSetItemTypeId(123456);
704+
std::string message = BuildMessageSetTestExtension1();
705+
std::string other_message = BuildMessageSetTestExtension1(321);
706+
707+
// Double id
708+
ValidateTestMessageSet("id + other_id + message",
709+
start + id + other_id + message + end);
710+
ValidateTestMessageSet("id + message + other_id",
711+
start + id + message + other_id + end);
712+
ValidateTestMessageSet("message + id + other_id",
713+
start + message + id + other_id + end);
714+
// Double message
715+
ValidateTestMessageSet("id + message + other_message",
716+
start + id + message + other_message + end);
717+
ValidateTestMessageSet("message + id + other_message",
718+
start + message + id + other_message + end);
719+
ValidateTestMessageSet("message + other_message + id",
720+
start + message + other_message + id + end);
634721
}
635722

636723
void SerializeReverseOrder(

0 commit comments

Comments
 (0)