@@ -26,6 +26,7 @@ import org.opencoap.ssl.SslSession
26
26
import org.slf4j.LoggerFactory
27
27
import java.net.InetSocketAddress
28
28
import java.nio.ByteBuffer
29
+ import java.nio.ByteOrder
29
30
import java.time.Duration
30
31
import java.time.Instant
31
32
import java.util.concurrent.CompletableFuture
@@ -38,7 +39,8 @@ class DtlsServer(
38
39
private val expireAfter : Duration = Duration .ofSeconds(60),
39
40
private val storeSession : (cid: ByteArray , session: SessionWithContext ) -> Unit ,
40
41
private val lifecycleCallbacks : DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
41
- private val executor : ScheduledExecutorService
42
+ private val executor : ScheduledExecutorService ,
43
+ private val cidRequired : Boolean = false
42
44
) {
43
45
companion object {
44
46
private val EMPTY_BUFFER = ByteBuffer .allocate(0 )
@@ -49,11 +51,12 @@ class DtlsServer(
49
51
50
52
// note: non thread save, must be used only from same thread
51
53
private val sessions = mutableMapOf<InetSocketAddress , DtlsState >()
52
- private val cidSize = sslConfig.cidSupplier.next().size
54
+ private val cidSize = sslConfig.cidSupplier? .next()? .size ? : 0
53
55
val numberOfSessions get() = sessions.size
54
56
55
57
fun handleReceived (adr : InetSocketAddress , buf : ByteBuffer ): ReceiveResult {
56
58
val cid by lazy { SslContext .peekCID(cidSize, buf) }
59
+ val isValidHandshake by lazy { isValidHandshakeRequest(buf) }
57
60
val dtlsState = sessions[adr]
58
61
59
62
return when {
@@ -63,12 +66,19 @@ class DtlsServer(
63
66
// no session, but dtls packet contains CID
64
67
cid != null -> ReceiveResult .CidSessionMissing (cid!! )
65
68
66
- // new handshake
67
- else -> {
69
+ // start new handshake if datagram is valid
70
+ isValidHandshake -> {
68
71
val dtlsHandshake = DtlsHandshake (sslConfig.newContext(adr), adr)
69
72
sessions[adr] = dtlsHandshake
70
73
dtlsHandshake.step(buf)
71
74
}
75
+
76
+ // drop silently
77
+ else -> {
78
+ logger.warn(" [{}] Invalid DTLS session handshake." , adr)
79
+ reportMessageDrop(adr)
80
+ ReceiveResult .Handled
81
+ }
72
82
}
73
83
}
74
84
@@ -186,6 +196,7 @@ class DtlsServer(
186
196
when (ex) {
187
197
is SslException ->
188
198
logger.warn(" [{}] DTLS failed: {}" , peerAddress, ex.message)
199
+
189
200
else ->
190
201
logger.error(ex.toString(), ex)
191
202
}
@@ -305,4 +316,66 @@ class DtlsServer(
305
316
lifecycleCallbacks.sessionFinished(peerAddress, reason, err)
306
317
}
307
318
}
319
+
320
+ private fun isValidHandshakeRequest (buf : ByteBuffer ): Boolean {
321
+ val workingBuf = buf.slice().order(ByteOrder .BIG_ENDIAN )
322
+
323
+ // Check if the header is correct
324
+ val header = workingBuf.getLong(0 )
325
+ if (header != 0x16FEFD0000000000L ) {
326
+ logger.debug(" Bad DTLS header" )
327
+ return false
328
+ }
329
+
330
+ // Check if it is a ClientHello handshake
331
+ val handshakeType = workingBuf.get(13 ).toInt()
332
+ if (handshakeType != 1 ) {
333
+ logger.debug(" Bad handshake type" )
334
+ return false
335
+ }
336
+
337
+ // Check if CID is supported by the client in case if CID support is mandatory
338
+ if (cidRequired && ! supportsCid(workingBuf)) {
339
+ logger.debug(" No CID support" )
340
+ return false
341
+ }
342
+
343
+ return true
344
+ }
345
+
346
+ private fun supportsCid (buf : ByteBuffer ): Boolean {
347
+ val workingBuffer = buf.slice().order(ByteOrder .BIG_ENDIAN )
348
+
349
+ // Go to the start of extensions
350
+ workingBuffer
351
+ // Skip DTLSHeader(13) + HandshakeHeader(12) + CookieLengthOffset(35)
352
+ .seek(60 )
353
+ // Skip variable-length Cookie
354
+ .readByteAndSeek()
355
+ // Skip variable-length CipherSuites
356
+ .readShortAndSeek()
357
+ // Skip variable-length CompressionMethods
358
+ .readByteAndSeek()
359
+ // Limit buffer to the extensions length
360
+ .getShort().also {
361
+ workingBuffer.limit(workingBuffer.position() + it.toInt())
362
+ }
363
+
364
+ // Search for CID extension
365
+ while (workingBuffer.remaining() >= 4 ) {
366
+ val type = workingBuffer.getShort()
367
+ if (type == 0x36 .toShort()) {
368
+ return true
369
+ }
370
+
371
+ // Skip to the next extension
372
+ workingBuffer.readShortAndSeek()
373
+ }
374
+
375
+ return false
376
+ }
308
377
}
378
+
379
+ private fun ByteBuffer.seek (offset : Int ): ByteBuffer = this .position(this .position() + offset) as ByteBuffer
380
+ private fun ByteBuffer.readShortAndSeek (): ByteBuffer = this .getShort().let { this .seek(it.toInt()) }
381
+ private fun ByteBuffer.readByteAndSeek (): ByteBuffer = this .get().let { this .seek(it.toInt()) }
0 commit comments