Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Destroy the process during coroutine cancellation #7

Merged
merged 3 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ dependencies {
implementation(kotlin("stdlib-jdk8"))
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.4.2")

testImplementation("org.amshove.kluent:kluent:1.61")
testImplementation("org.amshove.kluent:kluent:1.68")
val junit5 = "5.7.1"
testImplementation("org.junit.jupiter:junit-jupiter-api:$junit5")
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:$junit5")
Expand Down
43 changes: 32 additions & 11 deletions src/main/kotlin/com/github/pgreze/process/Process.kt
Original file line number Diff line number Diff line change
@@ -1,31 +1,43 @@
package com.github.pgreze.process

import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import java.io.File
import java.io.InputStream

private suspend fun <R> coroutineScopeIO(block: suspend CoroutineScope.() -> R) =
withContext(Dispatchers.IO) {
// Encapsulates all async calls in the current scope.
// https://elizarov.medium.com/structured-concurrency-722d765aa952
coroutineScope(block)
}

@ExperimentalCoroutinesApi
@Suppress("BlockingMethodInNonBlockingContext", "LongParameterList", "ComplexMethod")
suspend fun process(
vararg command: String,
stdin: InputSource? = null,
stdout: Redirect = Redirect.PRINT,
stderr: Redirect = Redirect.PRINT,
/** Allowing to append new environment variables during this process's invocation. */
/** Extend with new environment variables during this process's invocation. */
env: Map<String, String>? = null,
/** Override the process working directory. */
directory: File? = null,
/** Consume without delay all streams configured with [Redirect.CAPTURE] */
consumer: suspend (String) -> Unit = {},
): ProcessResult = withContext(Dispatchers.IO) {
): ProcessResult = coroutineScopeIO {
// Based on the fact that it's hardcore to achieve manually:
// https://stackoverflow.com/a/4959696
val captureAll = stdout == stderr && stderr == Redirect.CAPTURE
Expand Down Expand Up @@ -60,8 +72,12 @@ suspend fun process(
stderr == Redirect.CAPTURE ->
process.errorStream
else -> null
}?.lineFlow { f -> f.map { it.also { consumer(it) } }.toList() }
?: emptyList()
}?.lineFlow { f ->
f.map {
yield()
it.also { consumer(it) }
}.toList()
} ?: emptyList()
}

val input = async {
Expand All @@ -70,13 +86,18 @@ suspend fun process(
}
}

@Suppress("UNCHECKED_CAST")
return@withContext ProcessResult(
// Consume the output before waitFor,
// ensuring no content is skipped.
output = awaitAll(input, output).last() as List<String>,
resultCode = process.waitFor(),
)
try {
@Suppress("UNCHECKED_CAST")
ProcessResult(
// Consume the output before waitFor,
// ensuring no content is skipped.
output = awaitAll(input, output).last() as List<String>,
resultCode = runInterruptible { process.waitFor() },
)
} catch (e: CancellationException) {
process.destroy()
throw e
}
}

private suspend fun <T> InputStream.lineFlow(block: suspend (Flow<String>) -> T): T =
Expand Down
37 changes: 37 additions & 0 deletions src/test/kotlin/com/github/pgreze/process/ProcessKtTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@ import com.github.pgreze.process.Redirect.Consume
import com.github.pgreze.process.Redirect.PRINT
import com.github.pgreze.process.Redirect.SILENT
import com.github.pgreze.process.Redirect.ToFile
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.launch
import org.amshove.kluent.shouldBeEqualTo
import org.amshove.kluent.shouldContain
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.RepeatedTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Timeout
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.io.TempDir
import org.junit.jupiter.params.ParameterizedTest
Expand All @@ -21,6 +26,7 @@ import java.io.ByteArrayOutputStream
import java.io.File
import java.io.PrintStream
import java.nio.file.Path
import java.util.concurrent.TimeUnit
import kotlin.io.path.ExperimentalPathApi
import kotlin.io.path.absolutePathString
import kotlin.io.path.writeText
Expand Down Expand Up @@ -159,6 +165,37 @@ class ProcessKtTest {
stdout shouldBeEqualTo OUT.toList()
}

@Nested
@DisplayName("print to console or not")
inner class Cancellation {
@ParameterizedTest
@ValueSource(booleans = [true, false])
@Timeout(value = 3, unit = TimeUnit.SECONDS)
fun `job cancellation should destroy the process`(captureStdout: Boolean) = runSuspendTest {
var visitedCancelledBlock = false
val job = launch(Dispatchers.IO) {
try {
val ret = process(
"cat", // cat without args is an endless process.
stdout = if (captureStdout) CAPTURE else SILENT
)
throw AssertionError("Process completed despite being cancelled: $ret")
} catch (e: CancellationException) {
visitedCancelledBlock = true
}
}

// Introduce delays to be sure the job was started before being cancelled.
delay(500L)
job.cancel()
delay(500L)

job.isCancelled shouldBeEqualTo true
job.isCompleted shouldBeEqualTo true
visitedCancelledBlock shouldBeEqualTo true
}
}

@Nested
inner class Unwrap {
@Test
Expand Down