Skip to content

Commit cd98cc0

Browse files
jhumpcarmo-evan
andauthored
desc: avoid panic if descriptor proto has invalid type references (#525)
Co-authored-by: Evan do Carmo <carmo.evan@gmail.com>
1 parent ab4615d commit cd98cc0

File tree

1 file changed

+58
-11
lines changed

1 file changed

+58
-11
lines changed

desc/descriptor.go

+58-11
Original file line numberDiff line numberDiff line change
@@ -549,31 +549,69 @@ func createFieldDescriptor(fd *FileDescriptor, parent Descriptor, enclosing stri
549549
return ret, fldName
550550
}
551551

552+
func descriptorType(d Descriptor) string {
553+
switch d := d.(type) {
554+
case *FileDescriptor:
555+
return "a file"
556+
case *MessageDescriptor:
557+
return "a message"
558+
case *FieldDescriptor:
559+
if d.IsExtension() {
560+
return "an extension"
561+
}
562+
return "a field"
563+
case *OneOfDescriptor:
564+
return "a oneof"
565+
case *EnumDescriptor:
566+
return "an enum"
567+
case *EnumValueDescriptor:
568+
return "an enum value"
569+
case *ServiceDescriptor:
570+
return "a service"
571+
case *MethodDescriptor:
572+
return "a method"
573+
default:
574+
return fmt.Sprintf("a %T", d)
575+
}
576+
}
577+
552578
func (fd *FieldDescriptor) resolve(path []int32, scopes []scope) error {
553579
if fd.proto.OneofIndex != nil && fd.oneOf == nil {
554580
return fmt.Errorf("could not link field %s to one-of index %d", fd.fqn, *fd.proto.OneofIndex)
555581
}
556582
fd.sourceInfoPath = append([]int32(nil), path...) // defensive copy
557583
if fd.proto.GetType() == dpb.FieldDescriptorProto_TYPE_ENUM {
558-
if desc, err := resolve(fd.file, fd.proto.GetTypeName(), scopes); err != nil {
584+
desc, err := resolve(fd.file, fd.proto.GetTypeName(), scopes)
585+
if err != nil {
559586
return err
560-
} else {
561-
fd.enumType = desc.(*EnumDescriptor)
562587
}
588+
enumType, ok := desc.(*EnumDescriptor)
589+
if !ok {
590+
return fmt.Errorf("field %v indicates a type of enum, but references %q which is %s", fd.fqn, fd.proto.GetTypeName(), descriptorType(desc))
591+
}
592+
fd.enumType = enumType
563593
}
564594
if fd.proto.GetType() == dpb.FieldDescriptorProto_TYPE_MESSAGE || fd.proto.GetType() == dpb.FieldDescriptorProto_TYPE_GROUP {
565-
if desc, err := resolve(fd.file, fd.proto.GetTypeName(), scopes); err != nil {
595+
desc, err := resolve(fd.file, fd.proto.GetTypeName(), scopes)
596+
if err != nil {
566597
return err
567-
} else {
568-
fd.msgType = desc.(*MessageDescriptor)
569598
}
599+
msgType, ok := desc.(*MessageDescriptor)
600+
if !ok {
601+
return fmt.Errorf("field %v indicates a type of message, but references %q which is %s", fd.fqn, fd.proto.GetTypeName(), descriptorType(desc))
602+
}
603+
fd.msgType = msgType
570604
}
571605
if fd.proto.GetExtendee() != "" {
572-
if desc, err := resolve(fd.file, fd.proto.GetExtendee(), scopes); err != nil {
606+
desc, err := resolve(fd.file, fd.proto.GetExtendee(), scopes)
607+
if err != nil {
573608
return err
574-
} else {
575-
fd.owner = desc.(*MessageDescriptor)
576609
}
610+
msgType, ok := desc.(*MessageDescriptor)
611+
if !ok {
612+
return fmt.Errorf("field %v extends %q which should be a message but is %s", fd.fqn, fd.proto.GetExtendee(), descriptorType(desc))
613+
}
614+
fd.owner = msgType
577615
}
578616
fd.file.registerField(fd)
579617
fd.isMap = fd.proto.GetLabel() == dpb.FieldDescriptorProto_LABEL_REPEATED &&
@@ -1119,6 +1157,7 @@ func (sv sortedValues) Less(i, j int) bool {
11191157

11201158
func (sv sortedValues) Swap(i, j int) {
11211159
sv[i], sv[j] = sv[j], sv[i]
1160+
11221161
}
11231162

11241163
func (ed *EnumDescriptor) resolve(path []int32) {
@@ -1441,12 +1480,20 @@ func (md *MethodDescriptor) resolve(path []int32, scopes []scope) error {
14411480
if desc, err := resolve(md.file, md.proto.GetInputType(), scopes); err != nil {
14421481
return err
14431482
} else {
1444-
md.inType = desc.(*MessageDescriptor)
1483+
msgType, ok := desc.(*MessageDescriptor)
1484+
if !ok {
1485+
return fmt.Errorf("method %v has request type %q which should be a message but is %s", md.fqn, md.proto.GetInputType(), descriptorType(desc))
1486+
}
1487+
md.inType = msgType
14451488
}
14461489
if desc, err := resolve(md.file, md.proto.GetOutputType(), scopes); err != nil {
14471490
return err
14481491
} else {
1449-
md.outType = desc.(*MessageDescriptor)
1492+
msgType, ok := desc.(*MessageDescriptor)
1493+
if !ok {
1494+
return fmt.Errorf("method %v has response type %q which should be a message but is %s", md.fqn, md.proto.GetOutputType(), descriptorType(desc))
1495+
}
1496+
md.outType = msgType
14501497
}
14511498
return nil
14521499
}

0 commit comments

Comments
 (0)