diff --git a/invoke.go b/invoke.go index 0db362c5..b5bae4b1 100644 --- a/invoke.go +++ b/invoke.go @@ -227,15 +227,19 @@ func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.Met return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) } - if respHeaders, err := str.Header(); err == nil { - handler.OnReceiveHeaders(respHeaders) + if str != nil { + if respHeaders, err := str.Header(); err == nil { + handler.OnReceiveHeaders(respHeaders) + } } if stat.Code() == codes.OK { handler.OnReceiveResponse(resp) } - handler.OnReceiveTrailers(stat, str.Trailer()) + if str != nil { + handler.OnReceiveTrailers(stat, str.Trailer()) + } return nil } @@ -334,8 +338,10 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr }() } - if respHeaders, err := str.Header(); err == nil { - handler.OnReceiveHeaders(respHeaders) + if str != nil { + if respHeaders, err := str.Header(); err == nil { + handler.OnReceiveHeaders(respHeaders) + } } // Download each response message @@ -362,7 +368,9 @@ func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescr return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) } - handler.OnReceiveTrailers(stat, str.Trailer()) + if str != nil { + handler.OnReceiveTrailers(stat, str.Trailer()) + } return nil }