20
20
import static org .apache .inlong .tubemq .corebase .utils .AddressUtils .getRemoteAddressIP ;
21
21
22
22
import io .netty .buffer .ByteBuf ;
23
+ import io .netty .buffer .Unpooled ;
23
24
import io .netty .channel .Channel ;
24
25
import io .netty .channel .ChannelHandlerContext ;
25
26
import io .netty .handler .codec .MessageToMessageDecoder ;
27
+ import io .netty .util .ReferenceCountUtil ;
26
28
import java .nio .ByteBuffer ;
27
29
import java .util .ArrayList ;
28
30
import java .util .List ;
@@ -43,50 +45,83 @@ public class NettyProtocolDecoder extends MessageToMessageDecoder<ByteBuf> {
43
45
new ConcurrentHashMap <>();
44
46
private static AtomicLong lastProtolTime = new AtomicLong (0 );
45
47
private static AtomicLong lastSizeTime = new AtomicLong (0 );
48
+ private boolean packHeaderRead = false ;
49
+ private int listSize ;
50
+ private List <RpcDataPack > rpcDataPackList = new ArrayList <>();
51
+ private RpcDataPack dataPack ;
52
+ private ByteBuf lastByteBuf ;
46
53
47
54
@ Override
48
55
protected void decode (ChannelHandlerContext ctx , ByteBuf buffer , List <Object > out ) throws Exception {
49
- if (buffer .readableBytes () < 12 ) {
50
- logger .warn ("Decode buffer.readableBytes() < 12 !" );
51
- return ;
52
- }
53
- int frameToken = buffer .readInt ();
54
- filterIllegalPkgToken (frameToken ,
55
- RpcConstants .RPC_PROTOCOL_BEGIN_TOKEN , ctx .channel ());
56
- int serialNo = buffer .readInt ();
57
- int tmpListSize = buffer .readInt ();
58
- filterIllegalPackageSize (true , tmpListSize ,
59
- RpcConstants .MAX_FRAME_MAX_LIST_SIZE , ctx .channel ());
60
- RpcDataPack dataPack = new RpcDataPack (serialNo , new ArrayList <ByteBuffer >());
61
- // get PackBody
62
- int i = 0 ;
63
- while (i < tmpListSize ) {
64
- i ++;
56
+ buffer = convertToNewBuf (buffer );
57
+ while (buffer .readableBytes () > 0 ) {
58
+ if (!packHeaderRead ) {
59
+ if (buffer .readableBytes () < 12 ) {
60
+ saveRemainedByteBuf (buffer );
61
+ break ;
62
+ }
63
+ int frameToken = buffer .readInt ();
64
+ filterIllegalPkgToken (frameToken , RpcConstants .RPC_PROTOCOL_BEGIN_TOKEN , ctx .channel ());
65
+ int serialNo = buffer .readInt ();
66
+ int tmpListSize = buffer .readInt ();
67
+ filterIllegalPackageSize (true , tmpListSize ,
68
+ RpcConstants .MAX_FRAME_MAX_LIST_SIZE , ctx .channel ());
69
+ this .listSize = tmpListSize ;
70
+ this .dataPack = new RpcDataPack (serialNo , new ArrayList <>(this .listSize ));
71
+ this .packHeaderRead = true ;
72
+ }
73
+ // get PackBody
65
74
if (buffer .readableBytes () < 4 ) {
66
- logger . warn ( "Decode buffer.readableBytes() < 4 !" );
75
+ saveRemainedByteBuf ( buffer );
67
76
break ;
68
77
}
69
78
buffer .markReaderIndex ();
70
79
int length = buffer .readInt ();
71
- filterIllegalPackageSize (false , length ,
72
- RpcConstants .RPC_MAX_BUFFER_SIZE , ctx .channel ());
80
+ if (buffer .readableBytes () < length ) {
81
+ buffer .resetReaderIndex ();
82
+ saveRemainedByteBuf (buffer );
83
+ break ;
84
+ }
73
85
ByteBuffer bb = ByteBuffer .allocate (length );
74
86
buffer .readBytes (bb );
75
87
bb .flip ();
76
88
dataPack .getDataLst ().add (bb );
89
+ if (dataPack .getDataLst ().size () == listSize ) {
90
+ packHeaderRead = false ;
91
+ rpcDataPackList .add (dataPack );
92
+ }
77
93
}
94
+ if (rpcDataPackList .size () > 0 ) {
95
+ out .addAll (rpcDataPackList );
96
+ rpcDataPackList .clear ();
97
+ }
98
+ }
99
+
100
+ private void saveRemainedByteBuf (ByteBuf byteBuf ) {
101
+ if (byteBuf != null && byteBuf .readableBytes () > 0 ) {
102
+ lastByteBuf = Unpooled .copiedBuffer (byteBuf );
103
+ }
104
+ }
78
105
79
- if (dataPack .getDataLst ().size () == tmpListSize ) {
80
- out .add (dataPack );
81
- } else {
82
- logger .warn ("Decode dataPack.getDataLst().size()[{}] != tmpListSize [{}] !" ,
83
- dataPack .getDataLst ().size (), tmpListSize );
84
- return ;
106
+ private ByteBuf convertToNewBuf (ByteBuf byteBuf ) {
107
+ ByteBuf newByteBuf = byteBuf ;
108
+ int totalReadBytes = byteBuf .readableBytes ();
109
+ if (lastByteBuf != null ) {
110
+ try {
111
+ totalReadBytes += lastByteBuf .readableBytes ();
112
+ newByteBuf = Unpooled .buffer (totalReadBytes );
113
+ newByteBuf .writeBytes (lastByteBuf );
114
+ newByteBuf .writeBytes (byteBuf );
115
+ } finally {
116
+ ReferenceCountUtil .release (lastByteBuf );
117
+ }
118
+ lastByteBuf = null ;
85
119
}
120
+ return newByteBuf ;
86
121
}
87
122
88
- private void filterIllegalPkgToken (int inParamValue ,
89
- int allowTokenVal , Channel channel ) throws UnknownProtocolException {
123
+ private void filterIllegalPkgToken (int inParamValue , int allowTokenVal ,
124
+ Channel channel ) throws UnknownProtocolException {
90
125
if (inParamValue != allowTokenVal ) {
91
126
String rmtaddrIp = getRemoteAddressIP (channel );
92
127
if (rmtaddrIp != null ) {
@@ -103,7 +138,11 @@ private void filterIllegalPkgToken(int inParamValue,
103
138
long curTime = System .currentTimeMillis ();
104
139
if (curTime - befTime > 180000 ) {
105
140
if (lastProtolTime .compareAndSet (befTime , System .currentTimeMillis ())) {
106
- logger .warn ("[Abnormal Visit] OSS Tube visit list is :" + errProtolAddrMap .toString ());
141
+ logger .warn ("[Abnormal Visit] OSS Tube [inParamValue = {} vs "
142
+ + "allowTokenVal = {}] visit "
143
+ + "list is : {}" ,
144
+ inParamValue , allowTokenVal ,
145
+ errProtolAddrMap .toString ());
107
146
errProtolAddrMap .clear ();
108
147
}
109
148
}
0 commit comments