Skip to content

Commit b0d3fde

Browse files
authored
Provide a way to require CID support from a peer (#53)
1 parent f431eb8 commit b0d3fde

File tree

7 files changed

+131
-20
lines changed

7 files changed

+131
-20
lines changed

kotlin-mbedtls-metrics/src/test/kotlin/org/opencoap/ssl/transport/metrics/micrometer/DtlsServerMetricsCallbacksTest.kt

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class DtlsServerMetricsCallbacksTest {
117117
}
118118

119119
@Test
120+
@Disabled("After implementation of invalid handshake datagrams dropping it's hard to simulate wrong handshake")
120121
fun `should report DTLS server metrics for handshake errors`() {
121122
server = DtlsServerTransport.create(conf, lifecycleCallbacks = metricsCallbacks).listen(echoHandler)
122123
val cliChannel: DatagramChannel = DatagramChannel.open()

kotlin-mbedtls-netty/src/main/kotlin/org/opencoap/ssl/netty/DtlsChannelHandler.kt

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class DtlsChannelHandler @JvmOverloads constructor(
3636
private val sslConfig: SslConfig,
3737
private val expireAfter: Duration = Duration.ofSeconds(60),
3838
private val sessionStore: SessionStore = NoOpsSessionStore,
39-
private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}
39+
private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
40+
private val cidRequired: Boolean = false
4041
) : ChannelDuplexHandler() {
4142
private lateinit var ctx: ChannelHandlerContext
4243
lateinit var dtlsServer: DtlsServer
@@ -52,7 +53,7 @@ class DtlsChannelHandler @JvmOverloads constructor(
5253

5354
override fun handlerAdded(ctx: ChannelHandlerContext) {
5455
this.ctx = ctx
55-
this.dtlsServer = DtlsServer(::write, sslConfig, expireAfter, sessionStore::write, lifecycleCallbacks, ctx.executor())
56+
this.dtlsServer = DtlsServer(::write, sslConfig, expireAfter, sessionStore::write, lifecycleCallbacks, ctx.executor(), cidRequired)
5657
}
5758

5859
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {

kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt

+10-8
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ import java.time.Duration.ofSeconds
6363

6464
class SslConfig(
6565
private val conf: Memory,
66-
val cidSupplier: CidSupplier,
66+
val cidSupplier: CidSupplier?,
6767
private val mtu: Int,
6868
private val close: Closeable
6969
) : Closeable by close {
@@ -75,8 +75,10 @@ class SslConfig(
7575
mbedtls_ssl_setup(sslContext, conf).verify()
7676
mbedtls_ssl_set_timer_cb(sslContext, Pointer.NULL, NoOpsSetDelayCallback, NoOpsGetDelayCallback)
7777

78-
val cid = cidSupplier.next()
79-
mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify()
78+
val cid = cidSupplier?.next()
79+
if (cid != null) {
80+
mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify()
81+
}
8082
mbedtls_ssl_set_mtu(sslContext, mtu)
8183

8284
val clientId = peerAddress.toString()
@@ -103,25 +105,25 @@ class SslConfig(
103105

104106
@JvmStatic
105107
@JvmOverloads
106-
fun client(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
108+
fun client(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier? = EmptyCidSupplier, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
107109
return create(false, auth, cipherSuites, cidSupplier, reqAuthentication, 0, retransmitMin, retransmitMax)
108110
}
109111

110112
@JvmStatic
111113
@JvmOverloads
112-
fun server(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, mtu: Int = 0, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
114+
fun server(auth: AuthConfig, cipherSuites: List<String> = emptyList(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier? = EmptyCidSupplier, mtu: Int = 0, retransmitMin: Duration = ofSeconds(1), retransmitMax: Duration = ofSeconds(60)): SslConfig {
113115
return create(true, auth, cipherSuites, cidSupplier, reqAuthentication, mtu, retransmitMin, retransmitMax)
114116
}
115117

116118
private fun create(
117119
isServer: Boolean,
118120
authConfig: AuthConfig,
119121
cipherSuites: List<String>,
120-
cidSupplier: CidSupplier,
122+
cidSupplier: CidSupplier?,
121123
requiredAuthMode: Boolean = true,
122124
mtu: Int,
123125
retransmitMin: Duration,
124-
retransmitMax: Duration,
126+
retransmitMax: Duration
125127
): SslConfig {
126128
val sslConfig = Memory(MbedtlsSizeOf.mbedtls_ssl_config).also(MbedtlsApi::mbedtls_ssl_config_init)
127129
val entropy = Memory(MbedtlsSizeOf.mbedtls_entropy_context).also(MbedtlsApi.Crypto::mbedtls_entropy_init)
@@ -154,7 +156,7 @@ class SslConfig(
154156
mbedtls_ssl_conf_ciphersuites(sslConfig, cipherSuiteIds)
155157
}
156158

157-
if (cidSupplier != EmptyCidSupplier) {
159+
if (cidSupplier != null && cidSupplier != EmptyCidSupplier) {
158160
mbedtls_ssl_conf_cid(sslConfig, cidSupplier.next().size, 0)
159161
}
160162

kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt

+77-4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.opencoap.ssl.SslSession
2626
import org.slf4j.LoggerFactory
2727
import java.net.InetSocketAddress
2828
import java.nio.ByteBuffer
29+
import java.nio.ByteOrder
2930
import java.time.Duration
3031
import java.time.Instant
3132
import java.util.concurrent.CompletableFuture
@@ -38,7 +39,8 @@ class DtlsServer(
3839
private val expireAfter: Duration = Duration.ofSeconds(60),
3940
private val storeSession: (cid: ByteArray, session: SessionWithContext) -> Unit,
4041
private val lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
41-
private val executor: ScheduledExecutorService
42+
private val executor: ScheduledExecutorService,
43+
private val cidRequired: Boolean = false
4244
) {
4345
companion object {
4446
private val EMPTY_BUFFER = ByteBuffer.allocate(0)
@@ -49,11 +51,12 @@ class DtlsServer(
4951

5052
// note: non thread save, must be used only from same thread
5153
private val sessions = mutableMapOf<InetSocketAddress, DtlsState>()
52-
private val cidSize = sslConfig.cidSupplier.next().size
54+
private val cidSize = sslConfig.cidSupplier?.next()?.size ?: 0
5355
val numberOfSessions get() = sessions.size
5456

5557
fun handleReceived(adr: InetSocketAddress, buf: ByteBuffer): ReceiveResult {
5658
val cid by lazy { SslContext.peekCID(cidSize, buf) }
59+
val isValidHandshake by lazy { isValidHandshakeRequest(buf) }
5760
val dtlsState = sessions[adr]
5861

5962
return when {
@@ -63,12 +66,19 @@ class DtlsServer(
6366
// no session, but dtls packet contains CID
6467
cid != null -> ReceiveResult.CidSessionMissing(cid!!)
6568

66-
// new handshake
67-
else -> {
69+
// start new handshake if datagram is valid
70+
isValidHandshake -> {
6871
val dtlsHandshake = DtlsHandshake(sslConfig.newContext(adr), adr)
6972
sessions[adr] = dtlsHandshake
7073
dtlsHandshake.step(buf)
7174
}
75+
76+
// drop silently
77+
else -> {
78+
logger.warn("[{}] Invalid DTLS session handshake.", adr)
79+
reportMessageDrop(adr)
80+
ReceiveResult.Handled
81+
}
7282
}
7383
}
7484

@@ -186,6 +196,7 @@ class DtlsServer(
186196
when (ex) {
187197
is SslException ->
188198
logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message)
199+
189200
else ->
190201
logger.error(ex.toString(), ex)
191202
}
@@ -305,4 +316,66 @@ class DtlsServer(
305316
lifecycleCallbacks.sessionFinished(peerAddress, reason, err)
306317
}
307318
}
319+
320+
private fun isValidHandshakeRequest(buf: ByteBuffer): Boolean {
321+
val workingBuf = buf.slice().order(ByteOrder.BIG_ENDIAN)
322+
323+
// Check if the header is correct
324+
val header = workingBuf.getLong(0)
325+
if (header != 0x16FEFD0000000000L) {
326+
logger.debug("Bad DTLS header")
327+
return false
328+
}
329+
330+
// Check if it is a ClientHello handshake
331+
val handshakeType = workingBuf.get(13).toInt()
332+
if (handshakeType != 1) {
333+
logger.debug("Bad handshake type")
334+
return false
335+
}
336+
337+
// Check if CID is supported by the client in case if CID support is mandatory
338+
if (cidRequired && !supportsCid(workingBuf)) {
339+
logger.debug("No CID support")
340+
return false
341+
}
342+
343+
return true
344+
}
345+
346+
private fun supportsCid(buf: ByteBuffer): Boolean {
347+
val workingBuffer = buf.slice().order(ByteOrder.BIG_ENDIAN)
348+
349+
// Go to the start of extensions
350+
workingBuffer
351+
// Skip DTLSHeader(13) + HandshakeHeader(12) + CookieLengthOffset(35)
352+
.seek(60)
353+
// Skip variable-length Cookie
354+
.readByteAndSeek()
355+
// Skip variable-length CipherSuites
356+
.readShortAndSeek()
357+
// Skip variable-length CompressionMethods
358+
.readByteAndSeek()
359+
// Limit buffer to the extensions length
360+
.getShort().also {
361+
workingBuffer.limit(workingBuffer.position() + it.toInt())
362+
}
363+
364+
// Search for CID extension
365+
while (workingBuffer.remaining() >= 4) {
366+
val type = workingBuffer.getShort()
367+
if (type == 0x36.toShort()) {
368+
return true
369+
}
370+
371+
// Skip to the next extension
372+
workingBuffer.readShortAndSeek()
373+
}
374+
375+
return false
376+
}
308377
}
378+
379+
private fun ByteBuffer.seek(offset: Int): ByteBuffer = this.position(this.position() + offset) as ByteBuffer
380+
private fun ByteBuffer.readShortAndSeek(): ByteBuffer = this.getShort().let { this.seek(it.toInt()) }
381+
private fun ByteBuffer.readByteAndSeek(): ByteBuffer = this.get().let { this.seek(it.toInt()) }

kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServerTransport.kt

+3-2
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ class DtlsServerTransport private constructor(
4545
expireAfter: Duration = Duration.ofSeconds(60),
4646
sessionStore: SessionStore = NoOpsSessionStore,
4747
transport: Transport<ByteBufferPacket> = DatagramChannelAdapter.open(listenPort),
48-
lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {}
48+
lifecycleCallbacks: DtlsSessionLifecycleCallbacks = object : DtlsSessionLifecycleCallbacks {},
49+
cidRequired: Boolean = false
4950
): DtlsServerTransport {
5051
val executor = SingleThreadExecutor.create("dtls-srv-")
51-
val dtlsServer = DtlsServer(transport, config, expireAfter, sessionStore::write, lifecycleCallbacks, executor)
52+
val dtlsServer = DtlsServer(transport, config, expireAfter, sessionStore::write, lifecycleCallbacks, executor, cidRequired)
5253
return DtlsServerTransport(transport, dtlsServer, sessionStore, executor)
5354
}
5455
}

kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt

+37
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import java.util.concurrent.CompletableFuture.completedFuture
4747
class DtlsServerTest {
4848
val serverConf = SslConfig.server(CertificateAuth(Certs.serverChain, Certs.server.privateKey), listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"), false, RandomCidSupplier(16))
4949
val clientConf = SslConfig.client(CertificateAuth.trusted(Certs.root.asX509()), cipherSuites = listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"))
50+
val clientConfNoCid = SslConfig.client(CertificateAuth.trusted(Certs.root.asX509()), cipherSuites = listOf("TLS-ECDHE-ECDSA-WITH-AES-128-GCM-SHA256"), cidSupplier = null)
5051

5152
private val sessionStore = HashMapSessionStore()
5253
private lateinit var dtlsServer: DtlsServer
@@ -122,6 +123,42 @@ class DtlsServerTest {
122123
clientSession.close()
123124
}
124125

126+
@Test
127+
fun `should handshake when CID is required`() {
128+
dtlsServer = DtlsServer(::outboundTransport, serverConf, 100.millis, sessionStore::write, executor = SingleThreadExecutor.create("dtls-srv-"), cidRequired = true)
129+
130+
// when
131+
val clientSession = clientHandshake()
132+
133+
// then
134+
val dtlsPacket = clientSession.encrypt("terve".toByteBuffer()).order(ByteOrder.BIG_ENDIAN)
135+
val dtlsPacketIn = (dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) as ReceiveResult.Decrypted).packet
136+
assertEquals("terve", dtlsPacketIn.buffer.decodeToString())
137+
assertEquals(1, dtlsServer.numberOfSessions)
138+
assertNotNull(dtlsPacketIn.sessionContext.sessionStartTimestamp)
139+
140+
await.untilAsserted {
141+
assertTrue(serverOutboundQueue.isEmpty())
142+
}
143+
144+
clientSession.close()
145+
}
146+
147+
@Test
148+
fun `should fail handshake when CID is required and client doesn't provide it`() {
149+
dtlsServer = DtlsServer(::outboundTransport, serverConf, 100.millis, sessionStore::write, executor = SingleThreadExecutor.create("dtls-srv-"), cidRequired = true)
150+
val send: (ByteBuffer) -> Unit = { dtlsServer.handleReceived(localAddress(2_5684), it) }
151+
val cliHandshake = clientConfNoCid.newContext(localAddress(5684))
152+
153+
// when
154+
cliHandshake.step(send)
155+
156+
// then
157+
await.untilAsserted {
158+
assertTrue(serverOutboundQueue.isEmpty())
159+
}
160+
}
161+
125162
@Test
126163
fun `should handshake with replaying records`() {
127164
lateinit var sendingBuffer: ByteBuffer

kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTransportTest.kt

-4
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,6 @@ class DtlsServerTransportTest {
234234
assertEquals(0, cliChannel.read("aaa".toByteBuffer()))
235235
cliChannel.close()
236236

237-
verify(atMost = 100) {
238-
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(SslException::class))
239-
}
240-
241237
verify(exactly = 0) {
242238
sslLifecycleCallbacks.handshakeFinished(any(), any(), any(), DtlsSessionLifecycleCallbacks.Reason.FAILED, ofType(HelloVerifyRequired::class))
243239
}

0 commit comments

Comments
 (0)