Skip to content

Commit b914539

Browse files
committed
add TTL koroutine intergration demo
1 parent bd3a4c2 commit b914539

File tree

3 files changed

+226
-0
lines changed

3 files changed

+226
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.alibaba.demo.coroutine.ttl_intergration
2+
3+
import com.alibaba.ttl.TransmittableThreadLocal.Transmitter.*
4+
import com.alibaba.ttl.threadpool.agent.TtlAgent
5+
import kotlinx.coroutines.ThreadContextElement
6+
import kotlin.coroutines.CoroutineContext
7+
import kotlin.coroutines.EmptyCoroutineContext
8+
9+
/**
10+
* @see [kotlinx.coroutines.asContextElement]
11+
*/
12+
fun ttlContext(): CoroutineContext =
13+
// if (TtlAgent.isTtlAgentLoaded()) // FIXME Open the if when implement TtlAgent for koroutine
14+
// EmptyCoroutineContext
15+
// else
16+
TtlElement()
17+
18+
/**
19+
* @see [kotlinx.coroutines.internal.ThreadLocalElement]
20+
*/
21+
internal class TtlElement : ThreadContextElement<Any> {
22+
companion object Key : CoroutineContext.Key<TtlElement>
23+
24+
override val key: CoroutineContext.Key<*> get() = Key
25+
26+
private var captured: Any =
27+
capture()
28+
29+
override fun updateThreadContext(context: CoroutineContext): Any =
30+
replay(captured)
31+
32+
override fun restoreThreadContext(context: CoroutineContext, oldState: Any) {
33+
captured = capture() // FIXME This capture operation is a MUST, WHY? This operation is too expensive?!
34+
restore(oldState)
35+
}
36+
37+
// this method is overridden to perform value comparison (==) on key
38+
override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
39+
if (Key == key) EmptyCoroutineContext else this
40+
41+
// this method is overridden to perform value comparison (==) on key
42+
override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? =
43+
@Suppress("UNCHECKED_CAST")
44+
if (Key == key) this as E else null
45+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.alibaba.demo.coroutine.ttl_intergration.usage
2+
3+
import com.alibaba.demo.coroutine.ttl_intergration.ttlContext
4+
import com.alibaba.ttl.TransmittableThreadLocal
5+
import kotlinx.coroutines.*
6+
7+
private val threadLocal = TransmittableThreadLocal<String?>() // declare thread-local variable
8+
9+
/**
10+
* [Thread-local data - Coroutine Context and Dispatchers - Kotlin Programming Language](https://kotlinlang.org/docs/reference/coroutines/coroutine-context-and-dispatchers.html#thread-local-data)
11+
*/
12+
fun main(): Unit = runBlocking {
13+
val block: suspend CoroutineScope.() -> Unit = {
14+
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
15+
threadLocal.set("!reset!")
16+
println("After reset, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
17+
delay(5)
18+
println("After yield, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
19+
}
20+
21+
threadLocal.set("main")
22+
println("======================\nEmpty Coroutine Context\n======================")
23+
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
24+
launch(block = block).join()
25+
println("Post-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
26+
27+
threadLocal.set("main")
28+
println()
29+
println("======================\nTTL Coroutine Context\n======================")
30+
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
31+
launch(ttlContext(), block = block).join()
32+
println("Post-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
33+
34+
threadLocal.set("main")
35+
println()
36+
println("======================\nDispatchers.Default Coroutine Context\n======================")
37+
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
38+
launch(Dispatchers.Default, block = block).join()
39+
println("Post-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
40+
41+
threadLocal.set("main")
42+
println()
43+
println("======================\nDispatchers.Default + TTL Coroutine Context\n======================")
44+
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
45+
launch(Dispatchers.Default + ttlContext(), block = block).join()
46+
println("Post-main, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package com.alibaba.demo.coroutine.ttl_intergration.usage
2+
3+
import com.alibaba.demo.coroutine.ttl_intergration.ttlContext
4+
import com.alibaba.ttl.TransmittableThreadLocal
5+
import kotlinx.coroutines.Dispatchers
6+
import kotlinx.coroutines.delay
7+
import kotlinx.coroutines.launch
8+
import kotlinx.coroutines.runBlocking
9+
import org.junit.Assert.assertEquals
10+
import org.junit.Assert.assertNotEquals
11+
import org.junit.Test
12+
13+
class TtlCoroutineContextTest {
14+
@Test
15+
fun threadContextElement_passByValue(): Unit = runBlocking {
16+
val mainValue = "main-${System.currentTimeMillis()}"
17+
val testThread = Thread.currentThread()
18+
19+
// String ThreadLocal, String is immutable value, can only be passed by value
20+
val threadLocal = TransmittableThreadLocal<String?>()
21+
threadLocal.set(mainValue)
22+
println("test thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
23+
24+
val job = launch(Dispatchers.Default + ttlContext()) {
25+
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
26+
assertEquals(mainValue, threadLocal.get())
27+
assertNotEquals(testThread, Thread.currentThread())
28+
29+
delay(5)
30+
31+
println("After delay, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
32+
assertEquals(mainValue, threadLocal.get())
33+
assertNotEquals(testThread, Thread.currentThread())
34+
35+
val reset = "job-reset-${threadLocal.get()}"
36+
threadLocal.set(reset)
37+
assertEquals(reset, threadLocal.get())
38+
39+
delay(5)
40+
41+
println("After delay set reset, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
42+
assertEquals(reset, threadLocal.get())
43+
assertNotEquals(testThread, Thread.currentThread())
44+
}
45+
job.join()
46+
47+
println("after launch, test thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
48+
assertEquals(mainValue, threadLocal.get())
49+
}
50+
51+
@Test
52+
fun threadContextElement_passByReference(): Unit = runBlocking {
53+
data class Reference(var data: Int = 42)
54+
55+
val mainValue = Reference()
56+
val testThread = Thread.currentThread()
57+
58+
// Reference ThreadLocal, mutable value, pass by reference
59+
val threadLocal = TransmittableThreadLocal<Reference>() // declare thread-local variable
60+
threadLocal.set(mainValue)
61+
println("test thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
62+
63+
val job = launch(Dispatchers.Default + ttlContext()) {
64+
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
65+
assertEquals(mainValue, threadLocal.get())
66+
assertNotEquals(testThread, Thread.currentThread())
67+
68+
delay(5)
69+
70+
println("After delay, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
71+
assertEquals(mainValue, threadLocal.get())
72+
assertNotEquals(testThread, Thread.currentThread())
73+
74+
val reset = -42
75+
threadLocal.get().data = reset
76+
77+
delay(5)
78+
79+
println("After delay set reset, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
80+
assertEquals(Reference(reset), threadLocal.get())
81+
assertNotEquals(testThread, Thread.currentThread())
82+
}
83+
job.join()
84+
85+
println("after launch, test thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()}")
86+
assertEquals(mainValue, threadLocal.get())
87+
}
88+
89+
@Test
90+
fun twoThreadContextElement(): Unit = runBlocking {
91+
val mainValue = "main-a-${System.currentTimeMillis()}"
92+
val anotherMainValue = "main-another-${System.currentTimeMillis()}"
93+
val testThread = Thread.currentThread()
94+
95+
val threadLocal = TransmittableThreadLocal<String?>() // declare thread-local variable
96+
val anotherThreadLocal = TransmittableThreadLocal<String?>() // declare thread-local variable
97+
98+
threadLocal.set(mainValue)
99+
anotherThreadLocal.set(anotherMainValue)
100+
println("test thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()} | ${anotherThreadLocal.get()}")
101+
102+
println()
103+
launch(Dispatchers.Default + ttlContext()) {
104+
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()} | ${anotherThreadLocal.get()}")
105+
assertEquals(mainValue, threadLocal.get())
106+
assertEquals(anotherMainValue, anotherThreadLocal.get())
107+
assertNotEquals(testThread, Thread.currentThread())
108+
109+
delay(5)
110+
111+
println("After delay, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()} | ${anotherThreadLocal.get()}")
112+
assertEquals(mainValue, threadLocal.get())
113+
assertEquals(anotherMainValue, anotherThreadLocal.get())
114+
assertNotEquals(testThread, Thread.currentThread())
115+
116+
val resetA = "job-reset-${threadLocal.get()}"
117+
threadLocal.set(resetA)
118+
val resetAnother = "job-reset-${anotherThreadLocal.get()}"
119+
anotherThreadLocal.set(resetAnother)
120+
println("Before delay set reset, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()} | ${anotherThreadLocal.get()}")
121+
122+
delay(5)
123+
124+
println("After delay set reset, current thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()} | ${anotherThreadLocal.get()}")
125+
assertEquals(resetA, threadLocal.get())
126+
assertEquals(resetAnother, anotherThreadLocal.get())
127+
assertNotEquals(testThread, Thread.currentThread())
128+
}.join()
129+
130+
println("after launch2, test thread: ${Thread.currentThread()}, thread local value: ${threadLocal.get()} | ${anotherThreadLocal.get()}")
131+
assertEquals(mainValue, threadLocal.get())
132+
assertEquals(anotherMainValue, anotherThreadLocal.get())
133+
}
134+
}

0 commit comments

Comments
 (0)