diff --git a/example-project/src/main/kotlin/krotoplus/example/StandServiceCoroutineImpl.kt b/example-project/src/main/kotlin/krotoplus/example/StandServiceCoroutineImpl.kt index 12c22d2..79ad1bc 100644 --- a/example-project/src/main/kotlin/krotoplus/example/StandServiceCoroutineImpl.kt +++ b/example-project/src/main/kotlin/krotoplus/example/StandServiceCoroutineImpl.kt @@ -5,9 +5,13 @@ import jojo.bizarre.adventure.character.CharacterProto import jojo.bizarre.adventure.stand.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlin.coroutines.CoroutineContext class StandServiceCoroutineImpl : StandServiceCoroutineGrpc.StandServiceImplBase(){ + override val initialContext: CoroutineContext + get() = Dispatchers.Unconfined + override suspend fun getStandByName( request: StandServiceProto.GetStandByNameRequest ): StandProto.Stand = coroutineScope { diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt index 604e676..ba4acc1 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt @@ -18,12 +18,9 @@ package com.github.marcoferrer.krotoplus.coroutines import com.github.marcoferrer.krotoplus.coroutines.call.newProducerScope import com.github.marcoferrer.krotoplus.coroutines.call.toRpcException -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.Job +import kotlinx.coroutines.* import kotlinx.coroutines.channels.ProducerScope import kotlinx.coroutines.channels.SendChannel -import kotlinx.coroutines.launch import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext @@ -42,9 +39,10 @@ import kotlin.coroutines.EmptyCoroutineContext public fun CoroutineScope.launchProducerJob( channel: SendChannel, context: CoroutineContext = EmptyCoroutineContext, + start: CoroutineStart = CoroutineStart.DEFAULT, block: suspend ProducerScope.()->Unit ): Job = - launch(context) { + launch(context, start) { newProducerScope(channel).block() }.apply { invokeOnCompletion { diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt index d22387a..a89faa4 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt @@ -26,65 +26,30 @@ import io.grpc.stub.StreamObserver import io.grpc.ClientCall import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import kotlinx.coroutines.channels.Channel -import java.util.concurrent.atomic.AtomicBoolean import kotlin.coroutines.CoroutineContext -internal fun CoroutineScope.newSendChannelFromObserver( - observer: StreamObserver, - capacity: Int = 1 -): SendChannel = - actor( - context = observer.exceptionHandler + Dispatchers.Unconfined, - capacity = capacity, - start = CoroutineStart.LAZY - ) { - try { - consumeEach { observer.onNext(it) } - channel.close() - }catch (e:Throwable){ - channel.close(e) - } - }.apply{ - invokeOnClose(observer.completionHandler) - } - - -internal fun CoroutineScope.newManagedServerResponseChannel( - responseObserver: ServerCallStreamObserver, - isMessagePreloaded: AtomicBoolean, - requestChannel: Channel = Channel(capacity = 1) -): SendChannel { - - val responseChannel = newSendChannelFromObserver(responseObserver) - - responseObserver.enableManualFlowControl(requestChannel,isMessagePreloaded) - - return responseChannel -} - -internal fun CoroutineScope.bindToClientCancellation(observer: ServerCallStreamObserver<*>){ +internal fun CoroutineScope.bindToClientCancellation(observer: ServerCallStreamObserver<*>) { observer.setOnCancelHandler { this@bindToClientCancellation.cancel() } } -internal fun CoroutineScope.bindScopeCancellationToCall(call: ClientCall<*, *>){ +internal fun CoroutineScope.bindScopeCancellationToCall(call: ClientCall<*, *>) { val job = coroutineContext[Job] ?: error("Unable to bind cancellation to call because scope does not have a job: $this") job.apply { invokeOnCompletion { - if(isCancelled){ - call.cancel(it?.message,it?.cause ?: it) + if (isCancelled) { + call.cancel(it?.message, it?.cause ?: it) } } } } -internal fun StreamObserver<*>.completeSafely(error: Throwable? = null){ +internal fun StreamObserver<*>.completeSafely(error: Throwable? = null) { // If the call was cancelled already // the stream observer will throw kotlin.runCatching { @@ -94,26 +59,24 @@ internal fun StreamObserver<*>.completeSafely(error: Throwable? = null){ } } -internal val StreamObserver<*>.exceptionHandler: CoroutineExceptionHandler - get() = CoroutineExceptionHandler { _, e -> - completeSafely(e) - } - -internal val StreamObserver<*>.completionHandler: CompletionHandler - get() = { completeSafely(it) } - internal fun Throwable.toRpcException(): Throwable = when (this) { is StatusException, is StatusRuntimeException -> this else -> { - val error = Status.fromThrowable(this) - .asRuntimeException(Status.trailersFromThrowable(this)) - - if(error.status.code == Status.Code.UNKNOWN && this is CancellationException) + val statusFromThrowable = Status.fromThrowable(this) + val status = if ( + statusFromThrowable.code == Status.UNKNOWN.code && + this is CancellationException + ) { Status.CANCELLED - .withDescription(this.message) - .asRuntimeException() else error + } else { + statusFromThrowable + } + + status + .withDescription(this.message) + .asRuntimeException(Status.trailersFromThrowable(this)) } } @@ -130,7 +93,6 @@ internal fun newRpcScope( methodDescriptor.getCoroutineName() ) -@ExperimentalCoroutinesApi internal fun CoroutineScope.newProducerScope(channel: SendChannel): ProducerScope = object : ProducerScope, CoroutineScope by this, @@ -138,35 +100,4 @@ internal fun CoroutineScope.newProducerScope(channel: SendChannel): Produ override val channel: SendChannel get() = channel - } - -internal inline fun StreamObserver.handleUnaryRpc(block: ()->T){ - try{ - onNext(block()) - onCompleted() - }catch (e: Throwable){ - completeSafely(e) - } -} - -internal inline fun SendChannel.handleStreamingRpc(block: (SendChannel)->Unit){ - try{ - block(this) - close() - }catch (e: Throwable){ - close(e.toRpcException()) - } -} - -internal inline fun handleBidiStreamingRpc( - requestChannel: ReceiveChannel, - responseChannel: SendChannel, - block: (ReceiveChannel, SendChannel) -> Unit -) { - try{ - block(requestChannel,responseChannel) - responseChannel.close() - }catch (e:Throwable){ - responseChannel.close(e.toRpcException()) - } -} \ No newline at end of file + } \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CompletableDeferredExts.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CompletableDeferredExts.kt deleted file mode 100644 index e316cda..0000000 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CompletableDeferredExts.kt +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2019 Kroto+ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.github.marcoferrer.krotoplus.coroutines.call - -import io.grpc.stub.StreamObserver -import kotlinx.coroutines.CompletableDeferred - -internal fun CompletableDeferred.toStreamObserver(): StreamObserver = - object : StreamObserver { - - /** - * Since [CompletableDeferred] is a single value coroutine primitive, - * once [onNext] has been called we can be sure that we have completed - * our stream. - * - */ - override fun onNext(value: T) { - complete(value) - } - - override fun onError(t: Throwable) { - completeExceptionally(t) - } - - /** - * This method is intentionally left blank. - * - * Since this stream represents a single value, completion is marked by - * the first invocation of [onNext] - */ - override fun onCompleted() { - // NOOP - } - } - diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControl.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControl.kt index c8ef007..294d58d 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControl.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControl.kt @@ -20,64 +20,59 @@ import io.grpc.stub.CallStreamObserver import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger -internal interface FlowControlledObserver { - @ExperimentalCoroutinesApi - fun CoroutineScope.nextValueWithBackPressure( - value: T, - channel: Channel, - callStreamObserver: CallStreamObserver, - isMessagePreloaded: AtomicBoolean - ) { - try { - when { - !channel.isClosedForSend && channel.offer(value) -> callStreamObserver.request(1) - - !channel.isClosedForSend -> { - // We are setting isMessagePreloaded to true to prevent the - // onReadyHandler from requesting a new message while we have - // a message preloaded. - isMessagePreloaded.set(true) - - // Using [CoroutineStart.UNDISPATCHED] ensures that - // values are sent in the proper order (FIFO). - // This also prevents a race between [StreamObserver.onNext] and - // [StreamObserver.onComplete] by making sure all preloaded messages - // have been submitted before invoking [Channel.close] - launch(start = CoroutineStart.UNDISPATCHED) { - try { - channel.send(value) - callStreamObserver.request(1) - - // Allow the onReadyHandler to begin requesting messages again. - isMessagePreloaded.set(false) - }catch (e: Throwable){ - channel.close(e) - } - } - } - } - } catch (e: Throwable) { - channel.close(e) - } - } -} - -@ExperimentalCoroutinesApi -internal fun CallStreamObserver.enableManualFlowControl( - targetChannel: Channel, - isMessagePreloaded: AtomicBoolean +internal fun CallStreamObserver<*>.applyInboundFlowControl( + targetChannel: Channel, + transientInboundMessageCount: AtomicInteger ) { disableAutoInboundFlowControl() setOnReadyHandler { if ( isReady && - !targetChannel.isFull && - !targetChannel.isClosedForSend && - isMessagePreloaded.compareAndSet(false, true) + !targetChannel.isClosedForReceive && + transientInboundMessageCount.get() == 0 ) { request(1) } } } + +internal fun CoroutineScope.applyOutboundFlowControl( + streamObserver: CallStreamObserver, + targetChannel: Channel +){ + val isOutboundJobRunning = AtomicBoolean() + val channelIterator = targetChannel.iterator() + streamObserver.setOnReadyHandler { + if(targetChannel.isClosedForReceive){ + streamObserver.completeSafely() + }else if( + streamObserver.isReady && + !targetChannel.isClosedForReceive && + isOutboundJobRunning.compareAndSet(false, true) + ){ + launch(Dispatchers.Unconfined + CoroutineExceptionHandler { _, e -> + streamObserver.completeSafely(e) + targetChannel.close(e) + }) { + try{ + while( + streamObserver.isReady && + !targetChannel.isClosedForReceive && + channelIterator.hasNext() + ){ + val value = channelIterator.next() + streamObserver.onNext(value) + } + if(targetChannel.isClosedForReceive){ + streamObserver.onCompleted() + } + } finally { + isOutboundJobRunning.set(false) + } + } + } + } +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControlledInboundStreamObserver.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControlledInboundStreamObserver.kt new file mode 100644 index 0000000..7b71252 --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControlledInboundStreamObserver.kt @@ -0,0 +1,81 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.call + +import io.grpc.stub.CallStreamObserver +import io.grpc.stub.StreamObserver +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +interface FlowControlledInboundStreamObserver : StreamObserver, CoroutineScope { + + val inboundChannel: Channel + + val isInboundCompleted: AtomicBoolean + + val transientInboundMessageCount: AtomicInteger + + val callStreamObserver: CallStreamObserver<*> + + val isChannelReadyForClose: Boolean + get() = isInboundCompleted.get() && transientInboundMessageCount.get() == 0 + + fun onNextWithBackPressure(value: T) { + transientInboundMessageCount.incrementAndGet() + when { + !inboundChannel.isClosedForSend && inboundChannel.offer(value) -> { + transientInboundMessageCount.decrementAndGet() + requestNextOrClose() + } + !inboundChannel.isClosedForSend -> { + launch(context = Dispatchers.Unconfined) { + try { + inboundChannel.send(value) + } catch (e: Throwable) { + inboundChannel.close(e) + } + }.invokeOnCompletion { + transientInboundMessageCount.decrementAndGet() + if (!inboundChannel.isClosedForReceive) { + requestNextOrClose() + } + } + } + else -> { + transientInboundMessageCount.decrementAndGet() + error("Received value but inbound channel is closed for send") + } + } + } + + fun requestNextOrClose() { + if (isChannelReadyForClose) + inboundChannel.close() else + callStreamObserver.request(1) + } + + override fun onCompleted() { + isInboundCompleted.set(true) + if (isChannelReadyForClose) { + inboundChannel.close() + } + } +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientChannels.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientBidiCallChannel.kt similarity index 61% rename from kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientChannels.kt rename to kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientBidiCallChannel.kt index ed51d59..69dcb3d 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientChannels.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientBidiCallChannel.kt @@ -16,16 +16,15 @@ package com.github.marcoferrer.krotoplus.coroutines.client -import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledObserver -import com.github.marcoferrer.krotoplus.coroutines.call.enableManualFlowControl +import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledInboundStreamObserver +import com.github.marcoferrer.krotoplus.coroutines.call.applyOutboundFlowControl import io.grpc.stub.ClientCallStreamObserver import io.grpc.stub.ClientResponseObserver -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Deferred import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger import kotlin.coroutines.CoroutineContext /** @@ -65,7 +64,7 @@ import kotlin.coroutines.CoroutineContext * and response handling into separate channels. * */ -public interface ClientBidiCallChannel : SendChannel, ReceiveChannel{ +public interface ClientBidiCallChannel : SendChannel, ReceiveChannel { public val requestChannel: SendChannel @@ -76,64 +75,39 @@ public interface ClientBidiCallChannel : SendChannel, Receive public operator fun component2(): ReceiveChannel = responseChannel } -internal class ClientBidiCallChannelImpl( - public override val requestChannel: SendChannel, - public override val responseChannel: ReceiveChannel -) : ClientBidiCallChannel, - SendChannel by requestChannel, - ReceiveChannel by responseChannel - -/** - * - */ -public interface ClientStreamingCallChannel : SendChannel { - - public val requestChannel: SendChannel - - public val response: Deferred - - public operator fun component1(): SendChannel = requestChannel - - public operator fun component2(): Deferred = response -} - -internal class ClientStreamingCallChannelImpl( - public override val requestChannel: SendChannel, - public override val response: Deferred -) : ClientStreamingCallChannel, - SendChannel by requestChannel +internal class ClientBidiCallChannelImpl( + override val coroutineContext: CoroutineContext, + override val inboundChannel: Channel = Channel(), + private val outboundChannel: Channel = Channel() +) : FlowControlledInboundStreamObserver, + ClientResponseObserver, + ClientBidiCallChannel, + SendChannel by outboundChannel, + ReceiveChannel by inboundChannel +{ + override val requestChannel: SendChannel + get() = outboundChannel + override val responseChannel: ReceiveChannel + get() = inboundChannel -internal class ClientResponseObserverChannel( - override val coroutineContext: CoroutineContext, - private val responseChannelDelegate: Channel = Channel(capacity = 1) -) : ClientResponseObserver, - FlowControlledObserver, - ReceiveChannel by responseChannelDelegate, - CoroutineScope { + override val isInboundCompleted = AtomicBoolean() - private val isMessagePreloaded = AtomicBoolean() + override val transientInboundMessageCount: AtomicInteger = AtomicInteger() - private lateinit var requestStream: ClientCallStreamObserver + override lateinit var callStreamObserver: ClientCallStreamObserver override fun beforeStart(requestStream: ClientCallStreamObserver) { - this.requestStream = requestStream.apply { - enableManualFlowControl(responseChannelDelegate,isMessagePreloaded) - } + callStreamObserver = requestStream + applyOutboundFlowControl(requestStream,outboundChannel) } - override fun onNext(value: RespT) = nextValueWithBackPressure( - value = value, - channel = responseChannelDelegate, - callStreamObserver = requestStream, - isMessagePreloaded = isMessagePreloaded - ) + override fun onNext(value: RespT): Unit = onNextWithBackPressure(value) override fun onError(t: Throwable) { - responseChannelDelegate.close(t) + outboundChannel.close(t) + outboundChannel.cancel() + inboundChannel.close(t) } +} - override fun onCompleted() { - responseChannelDelegate.close() - } -} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt index b0a0a17..9e344b5 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt @@ -84,7 +84,7 @@ public fun > T.clientCallServerStreaming( with(newRpcScope(coroutineContext, method)) { val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)) - val responseObserverChannel = ClientResponseObserverChannel(coroutineContext) + val responseObserverChannel = ClientResponseStreamChannel(coroutineContext) asyncServerStreamingCall( call, request, @@ -100,17 +100,13 @@ public fun > T.clientCallBidiStreaming( ): ClientBidiCallChannel { with(newRpcScope(coroutineContext, method)) { + val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)) - val responseDelegate = Channel(capacity = 1) - val responseChannel = ClientResponseObserverChannel(coroutineContext, responseDelegate) - val requestObserver = asyncBidiStreamingCall( - call, responseChannel - ) + val callChannel = ClientBidiCallChannelImpl(coroutineContext) + asyncBidiStreamingCall(call, callChannel) bindScopeCancellationToCall(call) - val requestChannel = newSendChannelFromObserver(requestObserver) - responseDelegate.invokeOnClose { requestChannel.close(it) } - return ClientBidiCallChannelImpl(requestChannel, responseChannel) + return callChannel } } @@ -119,17 +115,12 @@ public fun > T.clientCallClientStreaming( ): ClientStreamingCallChannel { with(newRpcScope(coroutineContext, method)) { - val completableResponse = CompletableDeferred(parent = coroutineContext[Job]) val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)) - val requestObserver = asyncClientStreamingCall( - call, completableResponse.toStreamObserver() - ) + val callChannel = ClientStreamingCallChannelImpl(coroutineContext) + asyncClientStreamingCall(call, callChannel) bindScopeCancellationToCall(call) - val requestChannel = newSendChannelFromObserver(requestObserver) - return ClientStreamingCallChannelImpl( - requestChannel, - completableResponse - ) + + return callChannel } } diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseStreamChannel.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseStreamChannel.kt new file mode 100644 index 0000000..c23b9bb --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseStreamChannel.kt @@ -0,0 +1,56 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.client + +import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledInboundStreamObserver +import com.github.marcoferrer.krotoplus.coroutines.call.applyInboundFlowControl +import io.grpc.stub.ClientCallStreamObserver +import io.grpc.stub.ClientResponseObserver +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kotlin.coroutines.CoroutineContext + + +internal class ClientResponseStreamChannel( + override val coroutineContext: CoroutineContext, + override val inboundChannel: Channel = Channel() +) : ClientResponseObserver, + FlowControlledInboundStreamObserver, + ReceiveChannel by inboundChannel, + CoroutineScope { + + override val isInboundCompleted: AtomicBoolean = AtomicBoolean() + + override val transientInboundMessageCount: AtomicInteger = AtomicInteger() + + override lateinit var callStreamObserver: ClientCallStreamObserver + + override fun beforeStart(requestStream: ClientCallStreamObserver) { + callStreamObserver = requestStream.apply { + applyInboundFlowControl(inboundChannel,transientInboundMessageCount) + } + } + + override fun onNext(value: RespT): Unit = onNextWithBackPressure(value) + + override fun onError(t: Throwable) { + inboundChannel.close(t) + } +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt new file mode 100644 index 0000000..0f5e7b6 --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt @@ -0,0 +1,84 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.client + +import com.github.marcoferrer.krotoplus.coroutines.call.applyOutboundFlowControl +import io.grpc.stub.ClientCallStreamObserver +import io.grpc.stub.ClientResponseObserver +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.SendChannel +import kotlin.coroutines.CoroutineContext + +/** + * + */ +public interface ClientStreamingCallChannel : SendChannel { + + public val requestChannel: SendChannel + + public val response: Deferred + + public operator fun component1(): SendChannel = requestChannel + + public operator fun component2(): Deferred = response +} + + +internal class ClientStreamingCallChannelImpl( + + override val coroutineContext: CoroutineContext, + + private val outboundChannel: Channel = Channel(), + + private val completableResponse: CompletableDeferred = CompletableDeferred(parent = coroutineContext[Job]) + +) : ClientResponseObserver, + ClientStreamingCallChannel, + SendChannel by outboundChannel, + CoroutineScope { + + override val requestChannel: SendChannel + get() = outboundChannel + + override val response: Deferred + get() = completableResponse + + private lateinit var callStreamObserver: ClientCallStreamObserver + + override fun beforeStart(requestStream: ClientCallStreamObserver) { + callStreamObserver = requestStream + applyOutboundFlowControl(requestStream, outboundChannel) + } + + override fun onNext(value: RespT) { + completableResponse.complete(value) + } + + override fun onError(t: Throwable) { + outboundChannel.close(t) + outboundChannel.cancel() + completableResponse.completeExceptionally(t) + } + + override fun onCompleted() { + require(completableResponse.isCompleted){ + "Stream was completed before onNext was called" + } + } + +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt index 0ac8be0..c16b3fb 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt @@ -24,7 +24,7 @@ import io.grpc.stub.ServerCallStreamObserver import io.grpc.stub.StreamObserver import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger public fun ServiceScope.serverCallUnary( @@ -35,7 +35,12 @@ public fun ServiceScope.serverCallUnary( with(newRpcScope(initialContext, methodDescriptor)) rpcScope@ { bindToClientCancellation(responseObserver as ServerCallStreamObserver<*>) launch { - responseObserver.handleUnaryRpc { block() } + try{ + responseObserver.onNext(block()) + responseObserver.onCompleted() + }catch (e: Throwable){ + responseObserver.completeSafely(e) + } } } } @@ -45,16 +50,23 @@ public fun ServiceScope.serverCallServerStreaming( responseObserver: StreamObserver, block: suspend (SendChannel) -> Unit ) { + val responseChannel = Channel() val serverCallObserver = responseObserver as ServerCallStreamObserver - with(newRpcScope(initialContext, methodDescriptor)) rpcScope@ { bindToClientCancellation(serverCallObserver) - - val responseChannel = newSendChannelFromObserver(responseObserver, capacity = 0) - + applyOutboundFlowControl(serverCallObserver,responseChannel) launch { - responseChannel.handleStreamingRpc { block(it) } + try{ + block(responseChannel) + responseChannel.close() + }catch (e: Throwable){ + val rpcError = e.toRpcException() + serverCallObserver.completeSafely(rpcError) + responseChannel.close(rpcError) + } } + + bindScopeCompletionToObserver(serverCallObserver) } } @@ -65,19 +77,19 @@ public fun ServiceScope.serverCallClientStreaming( block: suspend (ReceiveChannel) -> RespT ): StreamObserver { - val isMessagePreloaded = AtomicBoolean(false) - val requestChannelDelegate = Channel(capacity = 1) + val activeInboundJobCount = AtomicInteger() + val inboundChannel = Channel() val serverCallObserver = (responseObserver as ServerCallStreamObserver) - .apply { enableManualFlowControl(requestChannelDelegate, isMessagePreloaded) } + .apply { applyInboundFlowControl(inboundChannel, activeInboundJobCount) } with(newRpcScope(initialContext, methodDescriptor)) rpcScope@ { bindToClientCancellation(serverCallObserver) val requestChannel = ServerRequestStreamChannel( coroutineContext = coroutineContext, - delegateChannel = requestChannelDelegate, + inboundChannel = inboundChannel, + transientInboundMessageCount = activeInboundJobCount, callStreamObserver = serverCallObserver, - isMessagePreloaded = isMessagePreloaded, onErrorHandler = { // Call cancellation already cancels the coroutine scope // and closes the response stream. So we dont need to @@ -90,9 +102,12 @@ public fun ServiceScope.serverCallClientStreaming( ) launch { - responseObserver.handleUnaryRpc { block(requestChannel) } - // If the request channel was abandoned but we completed successfully - // close it and clear its contents. + try{ + responseObserver.onNext(block(requestChannel)) + responseObserver.onCompleted() + }catch (e: Throwable){ + responseObserver.completeSafely(e) + } if(!requestChannel.isClosedForReceive){ requestChannel.cancel() } @@ -110,23 +125,14 @@ public fun ServiceScope.serverCallBidiStreaming( block: suspend (ReceiveChannel, SendChannel) -> Unit ): StreamObserver { - val isMessagePreloaded = AtomicBoolean(false) - val requestChannelDelegate = Channel(capacity = 1) + val responseChannel = Channel() val serverCallObserver = (responseObserver as ServerCallStreamObserver) - with(newRpcScope(initialContext, methodDescriptor)) rpcScope@ { bindToClientCancellation(serverCallObserver) - - val responseChannel = newManagedServerResponseChannel( - responseObserver = serverCallObserver, - requestChannel = requestChannelDelegate, - isMessagePreloaded = isMessagePreloaded - ) - val requestChannel = ServerRequestStreamChannel( + applyOutboundFlowControl(serverCallObserver,responseChannel) + val requestChannel = ServerRequestStreamChannel( coroutineContext = coroutineContext, - delegateChannel = requestChannelDelegate, callStreamObserver = serverCallObserver, - isMessagePreloaded = isMessagePreloaded, onErrorHandler = { // Call cancellation already cancels the coroutine scope // and closes the response stream. So we dont need to @@ -135,6 +141,7 @@ public fun ServiceScope.serverCallBidiStreaming( // In the event of a request error, we // need to close the responseChannel before // cancelling the rpcScope. + responseObserver.completeSafely(it) responseChannel.close(it) this@rpcScope.cancel() } @@ -142,14 +149,21 @@ public fun ServiceScope.serverCallBidiStreaming( ) launch { - handleBidiStreamingRpc(requestChannel, responseChannel){ req, resp -> block(req,resp) } - // If the request channel was abandoned but we completed successfully - // close it and clear its contents. + serverCallObserver.request(1) + try{ + block(requestChannel,responseChannel) + responseChannel.close() + }catch (e:Throwable){ + val rpcError = e.toRpcException() + serverCallObserver.completeSafely(rpcError) + responseChannel.close(rpcError) + } if(!requestChannel.isClosedForReceive){ requestChannel.cancel() } } + bindScopeCompletionToObserver(serverCallObserver) return requestChannel } } @@ -165,3 +179,17 @@ private fun MethodDescriptor<*, *>.getUnimplementedException(): StatusRuntimeExc Status.UNIMPLEMENTED .withDescription("Method $fullMethodName is unimplemented") .asRuntimeException() + +/** + * Binds the completion of the coroutine context job to the outbound stream observer. + * + * This is used in server call handlers with outbound streams to ensure completion of any scheduled outbound producers + * before invoking `onComplete` and closing the call stream. + * + */ +private fun CoroutineScope.bindScopeCompletionToObserver(streamObserver: StreamObserver<*>) { + + coroutineContext[Job]?.invokeOnCompletion { + streamObserver.completeSafely(it) + } +} diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerChannels.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerChannels.kt index afe4313..ca7dc5f 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerChannels.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerChannels.kt @@ -16,40 +16,31 @@ package com.github.marcoferrer.krotoplus.coroutines.server -import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledObserver -import io.grpc.stub.CallStreamObserver -import io.grpc.stub.StreamObserver -import kotlinx.coroutines.* +import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledInboundStreamObserver +import io.grpc.stub.ServerCallStreamObserver +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger import kotlin.coroutines.CoroutineContext -internal class ServerRequestStreamChannel( +internal class ServerRequestStreamChannel( override val coroutineContext: CoroutineContext, - private val delegateChannel: Channel, - private val isMessagePreloaded: AtomicBoolean, - private val callStreamObserver: CallStreamObserver, + override val inboundChannel: Channel = Channel(), + override val transientInboundMessageCount: AtomicInteger = AtomicInteger(), + override val callStreamObserver: ServerCallStreamObserver<*>, private val onErrorHandler: ((Throwable) -> Unit)? = null -) : ReceiveChannel by delegateChannel, - FlowControlledObserver, - StreamObserver, +) : ReceiveChannel by inboundChannel, + FlowControlledInboundStreamObserver, CoroutineScope { - @ExperimentalCoroutinesApi - override fun onNext(value: ReqT) = nextValueWithBackPressure( - value, - delegateChannel, - callStreamObserver, - isMessagePreloaded - ) + override val isInboundCompleted: AtomicBoolean = AtomicBoolean() + + override fun onNext(value: ReqT) = onNextWithBackPressure(value) override fun onError(t: Throwable) { - delegateChannel.close(t) + inboundChannel.close(t) onErrorHandler?.invoke(t) } - - override fun onCompleted() { - delegateChannel.close() - } } \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExtsTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExtsTests.kt index 6707be1..608c32f 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExtsTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExtsTests.kt @@ -39,203 +39,6 @@ import kotlin.test.assertTrue import kotlin.test.fail -class NewSendChannelFromObserverTests { - - @Test - fun `Channel send to observer success`() = runBlocking { - - val observer = mockk>().apply { - every { onNext(allAny()) } just Runs - every { onCompleted() } just Runs - } - - GlobalScope.newSendChannelFromObserver(observer).apply { - repeat(3) { send(it) } - close() - } - - verify(exactly = 3) { observer.onNext(allAny()) } - verify(exactly = 1) { observer.onCompleted() } - } - - @Test - fun `Channel close with error`() = runBlocking { - - val statusException = Status.INVALID_ARGUMENT.asException() - val observer = mockk>().apply { - every { onNext(allAny()) } just Runs - every { onError(statusException) } just Runs - } - - val channel = GlobalScope.newSendChannelFromObserver(observer).apply { - send("") - close(statusException) - } - - assert(channel.isClosedForSend){ "Channel should be closed for send" } - verify(exactly = 1) { observer.onNext(allAny()) } - verify(exactly = 1) { observer.onError(statusException) } - verify(exactly = 0) { observer.onCompleted() } - } - - - @Test - fun `Channel is closed when scope is cancelled normally`() { - - val observer = mockk>().apply { - every { onNext(allAny()) } just Runs - every { onError(any()) } just Runs - } - - lateinit var channel: SendChannel - runBlocking { - launch { - launch(start = CoroutineStart.UNDISPATCHED) { - channel = newSendChannelFromObserver(observer).apply { - send("") - } - } - cancel() - } - } - - assert(channel.isClosedForSend){ "Channel should be closed for send" } - verify(exactly = 1) { observer.onNext(allAny()) } - verify(exactly = 1) { observer.onError(any()) } - verify(exactly = 0) { observer.onCompleted() } - } - - @Test - fun `Channel is closed when scope is cancelled exceptionally`() { - - val observer = mockk>().apply { - every { onNext(allAny()) } just Runs - every { onError(any()) } just Runs - } - - lateinit var channel: SendChannel - assertFails("cancel"){ - runBlocking { - launch { - channel = newSendChannelFromObserver(observer).apply { - send("") - } - } - launch { - error("cancel") - } - } - } - - assert(channel.isClosedForSend){ "Channel should be closed for send" } - verify(exactly = 1) { observer.onNext(allAny()) } - verify(exactly = 1) { observer.onError(any()) } - verify(exactly = 0) { observer.onCompleted() } - } - - @Test - fun `Channel close when observer onNext error `() { - - val statusException = Status.INVALID_ARGUMENT.asRuntimeException() - val observer = mockk>().apply { - every { onNext(any()) } throws statusException - every { onError(statusException) } just Runs - } - - lateinit var channel: SendChannel - assertFailsWithStatus(Status.INVALID_ARGUMENT) { - runBlocking { - - channel = newSendChannelFromObserver(observer).apply { - - val send1Result = runCatching { send("") } - assertTrue(send1Result.isSuccess, "Error during observer.onNext should not fail channel.send") - assertTrue(isClosedForSend, "Channel should be closed after onNext error") - - val send2Result = runCatching { send("") } - assertTrue(send2Result.isFailure, "Expecting error after sending a value to failed channel") - assertEquals(statusException, send2Result.exceptionOrNull()) - } - } - } - - assert(channel.isClosedForSend) { "Channel should be closed for send" } - verify(exactly = 1) { observer.onNext(allAny()) } - verify(exactly = 1) { observer.onError(statusException) } - verify(exactly = 0) { observer.onCompleted() } - } -} - -class NewManagedServerResponseChannelTests { - - lateinit var observer: ServerCallStreamObserver - - @BeforeTest - fun setup(){ - observer = mockk>().apply { - every { disableAutoInboundFlowControl() } just Runs - every { setOnReadyHandler(any()) } just Runs - } - } - - //TODO: Verify number of requests being made to test back pressure - - - @Test - fun `Test manual flow control is enabled`() { - GlobalScope.newManagedServerResponseChannel(observer,AtomicBoolean()).close() - verify(exactly = 1) { observer.disableAutoInboundFlowControl() } - verify(exactly = 1) { observer.setOnReadyHandler(any()) } - } - - @Test - fun `Test channel propagates values to observer onNext`() = runBlocking { - observer.apply { - every { onNext(Unit) } just Runs - every { onCompleted() } just Runs - } - - with(newManagedServerResponseChannel(observer,AtomicBoolean())){ - repeat(3){ - send(Unit) - } - close() - } - - verify(exactly = 3) { observer.onNext(Unit) } - verify(exactly = 1) { observer.onCompleted() } - } - - @Test - fun `Test channel propagates errors to observer onError`(){ - - observer.apply { - every { onError(matchStatus(Status.UNKNOWN)) } just Runs - every { onNext(Unit) } just Runs - every { onCompleted() } just Runs - } - - val error = IllegalArgumentException("error") - lateinit var channel: SendChannel - runBlocking { - channel = newManagedServerResponseChannel(observer, AtomicBoolean()).apply { - send(Unit) - close(error) - } - - assertFails(error.message) { - channel.send(Unit) - } - } - - assert(channel.isClosedForSend){ "Channel should be closed" } - verify(exactly = 1) { observer.onNext(Unit) } - verify(exactly = 1) { observer.onError(any()) } - verify(exactly = 0) { observer.onCompleted() } - } - -} - class MethodDescriptorExtTests { @Test @@ -249,134 +52,6 @@ class MethodDescriptorExtTests { } } -class HandleUnaryRpcBlockTests { - - @Test - fun `Test block completed successfully`(){ - val observer = mockk>().apply { - every { onNext(Unit) } just Runs - every { onCompleted() } just Runs - } - - observer.handleUnaryRpc { Unit } - - verify(exactly = 1) { observer.onNext(Unit) } - verify(exactly = 1) { observer.onCompleted() } - } - - @Test - fun `Test block completed exceptionally`(){ - val observer = mockk>().apply { - every { - val matcher = match { - it.status.code == Status.UNKNOWN.code - } - onError(matcher) - } just Runs - } - - observer.handleUnaryRpc { error("failed") } - - verify(exactly = 1) { observer.onError(any()) } - } -} - -class HandleStreamingRpcTests { - - @Test - fun `Test block completed successfully`(){ - - val channel = mockk>().apply { - coEvery { send(Unit) } just Runs - every { close() } returns true - } - - runBlocking { - channel.handleStreamingRpc { - it.send(Unit) - assertEquals(channel, it) - } - } - - coVerify(exactly = 1) { channel.send(Unit) } - verify(exactly = 1) { channel.close() } - } - - @Test - fun `Test block completed exceptionally`(){ - val channel = mockk>().apply { - every { - val matcher = match { - it.status.code == Status.UNKNOWN.code - } - close(matcher) - } returns true - } - - runBlocking { - channel.handleStreamingRpc { error("failed") } - } - - verify(exactly = 1) { channel.close(any()) } - } -} - -class HandleBidiStreamingRpcTests { - - @Test - fun `Test block completed successfully`(){ - - val reqChannel = mockk>().apply { - coEvery { receive() } returns "request" - } - val respChannel = mockk>().apply { - coEvery { send("response") } just Runs - every { close() } returns true - } - - runBlocking { - handleBidiStreamingRpc(reqChannel,respChannel) { req, resp -> - req.receive() - respChannel.send("response") - assertEquals(reqChannel, req) - assertEquals(respChannel, resp) - } - } - - coVerify(exactly = 1) { reqChannel.receive() } - coVerify(exactly = 1) { respChannel.send("response") } - verify(exactly = 1) { respChannel.close() } - } - - @Test - fun `Test block completed exceptionally`(){ - val reqChannel = mockk>().apply { - coEvery { receive() } throws ClosedReceiveChannelException("closed") - } - val respChannel = mockk>().apply { - every { - val matcher = match { - it.status.code == Status.UNKNOWN.code - } - close(matcher) - } returns true - } - - runBlocking { - handleBidiStreamingRpc(reqChannel,respChannel) { req, resp -> - assertEquals(reqChannel, req) - assertEquals(respChannel, resp) - resp.send("response") - req.receive() - resp.send("response") - } - } - - coVerify(exactly = 1) { respChannel.send("response") } - verify(exactly = 1) { respChannel.close(any()) } - } -} - class BindToClientCancellationTests { @Test diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CompletableDeferredExtsTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CompletableDeferredExtsTests.kt deleted file mode 100644 index aa1e523..0000000 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CompletableDeferredExtsTests.kt +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2019 Kroto+ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.github.marcoferrer.krotoplus.coroutines.call - -import io.mockk.spyk -import io.mockk.verify -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.runBlocking -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull - -class CompletableDeferredExtsTests { - - @Test - fun `Deferred completed successfully on next value`(){ - val value = "test_value" - val result = spyk(CompletableDeferred()) - - with(result.toStreamObserver()){ - onNext(value) - onCompleted() - } - - runBlocking { - assertEquals(value,result.await()) - } - - verify(exactly = 1) { result.complete(value) } - } - - @Test - fun `Deferred completed exceptionally on next value`(){ - val errorMessage = "error_message" - val error = IllegalArgumentException(errorMessage) - val result = spyk(CompletableDeferred()) - - with(result.toStreamObserver()){ - onError(error) - } - - val throwable = runBlocking { - result.runCatching { await() }.exceptionOrNull() - } - - assertNotNull(throwable) - assertEquals(errorMessage,throwable.message) - assert(throwable is IllegalArgumentException) - - verify(exactly = 0) { result.complete(any()) } - verify(exactly = 1) { result.completeExceptionally(any()) } - } - - @Test - fun `Deferred result doesn't change on repeated completion`(){ - val value = "test_value" - val result = spyk(CompletableDeferred()) - - with(result.toStreamObserver()){ - onNext(value) - onNext("extra_value") - onCompleted() - } - - runBlocking { - assertEquals(value,result.await()) - } - } - - @Test - fun `Deferred result doesn't change on excessive exceptional completion`(){ - val value = "test_value" - val result = CompletableDeferred() - val errorMessage = "error_message" - val error = IllegalArgumentException(errorMessage) - - with(result.toStreamObserver()){ - onNext(value) - onCompleted() - onError(error) - } - - runBlocking { - assertEquals(value,result.await()) - } - } - - @Test - fun `Deferred exception doesn't change on excessive exceptional completion`(){ - val errorMessage = "error_message" - val error = IllegalArgumentException(errorMessage) - val result = CompletableDeferred() - - - with(result.toStreamObserver()){ - onError(error) - onError(IndexOutOfBoundsException()) - } - - val throwable = runBlocking { - result.runCatching { await() }.exceptionOrNull() - } - - assertNotNull(throwable) - assertEquals(errorMessage,throwable.message) - assert(throwable is IllegalArgumentException) - } -} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/EnableManualFlowControlTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/EnableManualFlowControlTests.kt deleted file mode 100644 index 873da0a..0000000 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/EnableManualFlowControlTests.kt +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright 2019 Kroto+ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.github.marcoferrer.krotoplus.coroutines.call - -import io.grpc.stub.CallStreamObserver -import io.mockk.* -import kotlinx.coroutines.channels.Channel -import org.junit.Test -import java.util.concurrent.atomic.AtomicBoolean - - -class EnableManualFlowControlTests { - - @Test - fun `Test observer not ready`(){ - - val onReadyHandler = slot() - val targetChannel = mockk>() - val observer = mockk>().apply { - every { isReady } returns false - every { setOnReadyHandler(capture(onReadyHandler)) } just Runs - every { disableAutoInboundFlowControl() } just Runs - } - - observer.enableManualFlowControl(targetChannel, AtomicBoolean()) - - onReadyHandler.captured.run() - - verify(exactly = 1) { observer.isReady } - verify(exactly = 1) { observer.disableAutoInboundFlowControl() } - verify(exactly = 1) { observer.setOnReadyHandler(any()) } - verify(inverse = true) { observer.request(any()) } - } - - @Test - fun `Test channel is full`(){ - - val onReadyHandler = slot() - val targetChannel = mockk>().apply { - every { isFull } returns true - } - val observer = mockk>().apply { - every { isReady } returns true - every { setOnReadyHandler(capture(onReadyHandler)) } just Runs - every { disableAutoInboundFlowControl() } just Runs - } - - observer.enableManualFlowControl(targetChannel, AtomicBoolean()) - - onReadyHandler.captured.run() - - verify(exactly = 1) { observer.isReady } - verify(exactly = 1) { observer.disableAutoInboundFlowControl() } - verify(exactly = 1) { observer.setOnReadyHandler(any()) } - verify(inverse = true) { observer.request(any()) } - - verify(exactly = 1) { targetChannel.isFull } - verify(inverse = true) { targetChannel.isClosedForSend } - } - - - @Test - fun `Test channel is closed for send`(){ - - val onReadyHandler = slot() - val isMessagePreloaded = mockk() - val targetChannel = mockk>().apply { - every { isFull } returns false - every { isClosedForSend } returns true - } - val observer = mockk>().apply { - every { isReady } returns true - every { setOnReadyHandler(capture(onReadyHandler)) } just Runs - every { disableAutoInboundFlowControl() } just Runs - } - - observer.enableManualFlowControl(targetChannel,isMessagePreloaded) - - onReadyHandler.captured.run() - - verify(exactly = 1) { observer.isReady } - verify(exactly = 1) { observer.disableAutoInboundFlowControl() } - verify(exactly = 1) { observer.setOnReadyHandler(any()) } - verify(inverse = true) { observer.request(any()) } - - verify(exactly = 1) { targetChannel.isFull } - verify(exactly = 1) { targetChannel.isClosedForSend } - - verify(inverse = true) { isMessagePreloaded.compareAndSet(any(),any()) } - } - - - @Test - fun `Test message is preloaded for target channel`(){ - - val onReadyHandler = slot() - val isMessagePreloaded = mockk().apply { - every { compareAndSet(false,true) } returns false - } - val targetChannel = mockk>().apply { - every { isFull } returns false - every { isClosedForSend } returns false - } - val observer = mockk>().apply { - every { isReady } returns true - every { setOnReadyHandler(capture(onReadyHandler)) } just Runs - every { disableAutoInboundFlowControl() } just Runs - } - - observer.enableManualFlowControl(targetChannel,isMessagePreloaded) - - onReadyHandler.captured.run() - - verify(exactly = 1) { observer.isReady } - verify(exactly = 1) { observer.disableAutoInboundFlowControl() } - verify(exactly = 1) { observer.setOnReadyHandler(any()) } - verify(inverse = true) { observer.request(any()) } - - verify(exactly = 1) { targetChannel.isFull } - verify(exactly = 1) { targetChannel.isClosedForSend } - - verify(exactly = 1) { isMessagePreloaded.compareAndSet(false,true) } - } - - - @Test - fun `Test ready to request new message from observer`(){ - - val onReadyHandler = slot() - val isMessagePreloaded = mockk().apply { - every { compareAndSet(false,true) } returns true - } - val targetChannel = mockk>().apply { - every { isFull } returns false - every { isClosedForSend } returns false - } - val observer = mockk>().apply { - every { isReady } returns true - every { setOnReadyHandler(capture(onReadyHandler)) } just Runs - every { disableAutoInboundFlowControl() } just Runs - every { request(1) } just Runs - } - - observer.enableManualFlowControl(targetChannel,isMessagePreloaded) - - onReadyHandler.captured.run() - - verify(exactly = 1) { observer.isReady } - verify(exactly = 1) { observer.disableAutoInboundFlowControl() } - verify(exactly = 1) { observer.setOnReadyHandler(any()) } - verify(exactly = 1) { observer.request(1) } - - verify(exactly = 1) { targetChannel.isFull } - verify(exactly = 1) { targetChannel.isClosedForSend } - - verify(exactly = 1) { isMessagePreloaded.compareAndSet(false,true) } - } -} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControlledObserverTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControlledObserverTests.kt deleted file mode 100644 index fd3051d..0000000 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/FlowControlledObserverTests.kt +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2019 Kroto+ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.github.marcoferrer.krotoplus.coroutines.call - -import io.grpc.stub.CallStreamObserver -import io.mockk.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.Channel -import org.junit.Test -import java.util.concurrent.atomic.AtomicBoolean - - -class NextValueWithBackPressureTests { - - private val mockObserver = object: FlowControlledObserver {} - - @Test - fun `Test channel is closed`() { - - val observer = mockk>() - val isMessagePreloaded = mockk() - val mockScope = mockk() - val channel = mockk>().apply { - every { isClosedForSend } returns true - } - - with(mockObserver){ - mockScope.nextValueWithBackPressure(1,channel, observer, isMessagePreloaded) - } - - verify { isMessagePreloaded wasNot Called } - verify(atLeast = 1) { channel.isClosedForSend } - verify(inverse = true) { channel.offer(any()) } - coVerify(inverse = true) { channel.send(any()) } - } - - @Test - fun `Test channel accepts value`() { - - val observer = mockk>().apply { - every { request(1) } just Runs - } - val mockScope = mockk() - val isMessagePreloaded = mockk() - val channel = mockk>().apply { - every { isClosedForSend } returns false - every { offer(1) } returns true - } - - with(mockObserver){ - mockScope.nextValueWithBackPressure(1,channel, observer, isMessagePreloaded) - } - - verify { isMessagePreloaded wasNot Called } - verify(atLeast = 1) { channel.isClosedForSend } - verify(exactly = 1) { channel.offer(any()) } - verify(exactly = 1) { observer.request(1) } - coVerify(inverse = true) { channel.send(any()) } - } - - @Test - fun `Test channel buffer is full and preload value`() { - - val observer = mockk>().apply { - every { request(1) } just Runs - } - val isMessagePreloaded = mockk().apply { - every { set(allAny()) } just Runs - } - val channel = spyk(Channel(capacity = 1)) - - runBlocking { - channel.offer(0) - assert(channel.isFull){ "Target channel will not cause preload" } - with(mockObserver) { - nextValueWithBackPressure(1, channel, observer, isMessagePreloaded) - } - channel.receive() - } - - verifyOrder { - isMessagePreloaded.set(true) - isMessagePreloaded.set(false) - } - verify { channel.offer(any()) } - verify(atLeast = 1) { channel.isClosedForSend } - verify(exactly = 1) { observer.request(1) } - coVerify(exactly = 1) { channel.send(any()) } - } -} - - diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt index df1694d..1bad791 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt @@ -31,6 +31,8 @@ import io.mockk.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.map +import kotlinx.coroutines.channels.toList import org.junit.Rule import org.junit.Test import kotlin.test.BeforeTest @@ -126,7 +128,7 @@ class ClientCallBidiStreamingTests { val (requestChannel, responseChannel) = stub .clientCallBidiStreaming(methodDescriptor) - runBlocking(Dispatchers.Default) { + val result = runBlocking(Dispatchers.Default) { launch { repeat(3){ requestChannel.send( @@ -136,13 +138,14 @@ class ClientCallBidiStreamingTests { } requestChannel.close() } - launch{ - repeat(3){ - assertEquals("Req:#$it/Resp:#$it",responseChannel.receive().message) - } - } + + responseChannel.map { it.message }.toList() } + assertEquals(3,result.size) + result.forEachIndexed { index, message -> + assertEquals("Req:#$index/Resp:#$index",message) + } verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } @@ -173,7 +176,7 @@ class ClientCallBidiStreamingTests { } } } - launch{ + launch { repeat(3) { assertEquals("Req:#$it/Resp:#$it", responseChannel.receive().message) } @@ -183,12 +186,7 @@ class ClientCallBidiStreamingTests { } } - verify(exactly = 1) { - rpcSpy.call.cancel( - "Cancelled by client with StreamObserver.onError()", - matchStatus(Status.INVALID_ARGUMENT) - ) - } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } @@ -230,9 +228,7 @@ class ClientCallBidiStreamingTests { } } - // First invocation comes from the requestChannel being closed and calling `onError` - // Second invocation comes from the scope cancellation handler - verify(exactly = 2) { rpcSpy.call.cancel(any(), any()) } + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt index 44f5e57..ef52fcf 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt @@ -228,9 +228,7 @@ class ClientCallClientStreamingTests { } } - // First invocation comes from the requestChannel being closed and calling `onError` - // Second invocation comes from the scope cancellation handler - verify(exactly = 2) { rpcSpy.call.cancel(any(), any()) } + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt index a4dbbfe..f5b006d 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt @@ -34,6 +34,7 @@ import io.grpc.testing.GrpcServerRule import io.mockk.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import org.junit.Ignore import org.junit.Rule import org.junit.Test import kotlin.coroutines.CoroutineContext diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt index aa4c06f..763a955 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt @@ -35,6 +35,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.channels.toList +import org.junit.Ignore import org.junit.Rule import org.junit.Test import kotlin.coroutines.CoroutineContext @@ -173,6 +174,9 @@ class ServerCallClientStreamingTests { .sayHelloClientStreaming(responseObserver) requestObserver.sendRequests(3) + + // We sleep to ensure the server had time to send all responses + Thread.sleep(10L) verify(exactly = 0) { responseObserver.onError(any()) } verify(exactly = 1) { responseObserver.onNext(expectedResponse) } verify(exactly = 1) { responseObserver.onCompleted() } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt index bb0f21b..e559e06 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt @@ -48,9 +48,15 @@ class ServerCallServerStreamingTests { private val request = HelloRequest.newBuilder().setName("abc").build() private val expectedResponse = HelloReply.newBuilder().setMessage("reply").build() private val responseObserver = spyk>(object: StreamObserver{ - override fun onNext(value: HelloReply?) {} - override fun onError(t: Throwable?) {} - override fun onCompleted() {} + override fun onNext(value: HelloReply?) { +// println("client:onNext:$value") + } + override fun onError(t: Throwable?) { +// println("client:onError:$t") + } + override fun onCompleted() { +// println("client:onComplete") + } }) private fun newCall(): ClientCall { @@ -114,6 +120,8 @@ class ServerCallServerStreamingTests { fun `Server responds with cancellation when scope cancelled normally`(){ lateinit var respChannel: SendChannel grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + // We're using `Dispatchers.Unconfined` so that we can make sure the response was returned + // before verifying the result. override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloServerStreaming( request: HelloRequest, @@ -152,10 +160,11 @@ class ServerCallServerStreamingTests { ) { respChannel = responseChannel coroutineScope { - launch { + launch(start = CoroutineStart.UNDISPATCHED) { error("unexpected cancellation") } repeat(3){ + yield() responseChannel.send(expectedResponse) } } diff --git a/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGeneratorTests.kt b/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGeneratorTests.kt index b763d2d..67f11ee 100644 --- a/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGeneratorTests.kt +++ b/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGeneratorTests.kt @@ -20,15 +20,15 @@ import com.github.marcoferrer.krotoplus.coroutines.launchProducerJob import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext import io.grpc.examples.helloworld.* import io.grpc.testing.GrpcServerRule -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.ObsoleteCoroutinesApi +import kotlinx.coroutines.* import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.channels.toList -import kotlinx.coroutines.runBlocking +import org.junit.Ignore import org.junit.Rule import org.junit.Test +import kotlin.coroutines.CoroutineContext import kotlin.test.assertEquals import kotlin.test.assertNull import kotlin.test.BeforeTest @@ -45,6 +45,9 @@ class GrpcCoroutinesGeneratorTests { fun setupService(){ grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + override val initialContext: CoroutineContext + get() = Dispatchers.Default + override suspend fun sayHello(request: HelloRequest): HelloReply { return HelloReply { message = expectedMessage } } @@ -141,13 +144,16 @@ class GrpcCoroutinesGeneratorTests { val (requestChannel, responseChannel) = stub.sayHelloStreaming() - launchProducerJob(requestChannel) { +// launchProducerJob(requestChannel) { + launch(Dispatchers.Default) { repeat(3) { - send { name = "name $it" } + requestChannel.send { name = "name $it" } } + requestChannel.close() } val results = responseChannel.toList() + println(results) assertEquals(9, results.size) val expected = "name 0|name 0|name 0" + diff --git a/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcStubExtsGeneratorTests.kt b/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcStubExtsGeneratorTests.kt index a24dbe7..3a34734 100644 --- a/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcStubExtsGeneratorTests.kt +++ b/protoc-gen-kroto-plus/generator-tests/src/test/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcStubExtsGeneratorTests.kt @@ -24,6 +24,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import org.junit.Rule import org.junit.Test +import kotlin.coroutines.CoroutineContext import kotlin.test.BeforeTest import kotlin.test.assertEquals import kotlin.test.assertNull @@ -40,6 +41,9 @@ class GrpcStubExtsGeneratorTests { fun setupService(){ grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + override val initialContext: CoroutineContext + get() = Dispatchers.Unconfined + override suspend fun sayHello(request: HelloRequest): HelloReply { return HelloReply { message = expectedMessage } }