diff --git a/examples/pod_portforward.rs b/examples/pod_portforward.rs index da72195bd..b8720bf2f 100644 --- a/examples/pod_portforward.rs +++ b/examples/pod_portforward.rs @@ -37,8 +37,7 @@ async fn main() -> anyhow::Result<()> { let _ = tokio::time::timeout(std::time::Duration::from_secs(15), running).await?; let mut pf = pods.portforward("example", &[80]).await?; - let ports = pf.ports(); - let mut port = ports[0].stream().unwrap(); + let mut port = pf.take_stream(80).unwrap(); port.write_all(b"GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nAccept: */*\r\n\r\n") .await?; let mut rstream = tokio_util::io::ReaderStream::new(port); diff --git a/examples/pod_portforward_bind.rs b/examples/pod_portforward_bind.rs index 192fbdc88..6705426c6 100644 --- a/examples/pod_portforward_bind.rs +++ b/examples/pod_portforward_bind.rs @@ -47,7 +47,7 @@ async fn main() -> anyhow::Result<()> { // Get `Portforwarder` that handles the WebSocket connection. // There's no need to spawn a task to drive this, but it can be awaited to be notified on error. let mut forwarder = pods.portforward("example", &[80]).await?; - let port = forwarder.ports()[0].stream().unwrap(); + let port = forwarder.take_stream(80).unwrap(); // let hyper drive the HTTP state in our DuplexStream via a task let (sender, connection) = hyper::client::conn::handshake(port).await?; @@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> { // The following task is only used to show any error from the forwarder. // This example can be stopped with Ctrl-C if anything happens. tokio::spawn(async move { - if let Err(e) = forwarder.await { + if let Err(e) = forwarder.join().await { log::error!("forwarder errored: {}", e); } }); diff --git a/examples/pod_portforward_hyper_http.rs b/examples/pod_portforward_hyper_http.rs index 27df7e657..fdec1984a 100644 --- a/examples/pod_portforward_hyper_http.rs +++ b/examples/pod_portforward_hyper_http.rs @@ -36,8 +36,7 @@ async fn main() -> anyhow::Result<()> { let _ = tokio::time::timeout(std::time::Duration::from_secs(15), running).await?; let mut pf = pods.portforward("example", &[80]).await?; - let ports = pf.ports(); - let port = ports[0].stream().unwrap(); + let port = pf.take_stream(80).unwrap(); // let hyper drive the HTTP state in our DuplexStream via a task let (mut sender, connection) = hyper::client::conn::handshake(port).await?; diff --git a/kube-client/src/api/portforward.rs b/kube-client/src/api/portforward.rs index 901b06271..56f2d1bea 100644 --- a/kube-client/src/api/portforward.rs +++ b/kube-client/src/api/portforward.rs @@ -1,9 +1,4 @@ -use std::{ - future::Future, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll, Waker}, -}; +use std::{collections::HashMap, future::Future}; use bytes::{Buf, Bytes}; use futures::{ @@ -62,6 +57,9 @@ pub enum Error { /// Failed to receive a WebSocket message from the server. #[error("failed to receive a WebSocket message: {0}")] ReceiveWebSocketMessage(#[source] ws::Error), + + #[error("failed to complete the background task: {0}")] + Spawn(#[source] tokio::task::JoinError), } type ErrorReceiver = oneshot::Receiver; @@ -73,18 +71,15 @@ enum Message { ToPod(u8, Bytes), } -struct PortforwarderState { - waker: Option, - result: Option>, -} - -// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports. -// Error channel for each port is only written by the server when there's an exception and -// the port cannot be used (didn't initialize or can't be used anymore). -/// Manage port forwarding. +/// Manages port-forwarded streams. +/// +/// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports. Error +/// channel for each port is only written by the server when there's an exception and +//. the port cannot be used (didn't initialize or can't be used anymore). pub struct Portforwarder { - ports: Vec, - state: Arc>, + ports: HashMap, + errors: HashMap, + task: tokio::task::JoinHandle>, } impl Portforwarder { @@ -92,90 +87,59 @@ impl Portforwarder { where S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, { - let mut ports = Vec::new(); - let mut errors = Vec::new(); - let mut duplexes = Vec::new(); - for _ in port_nums.iter() { + let mut ports = HashMap::with_capacity(port_nums.len()); + let mut error_rxs = HashMap::with_capacity(port_nums.len()); + let mut error_txs = Vec::with_capacity(port_nums.len()); + let mut task_ios = Vec::with_capacity(port_nums.len()); + for port in port_nums.iter() { let (a, b) = tokio::io::duplex(1024 * 1024); - let (tx, rx) = oneshot::channel(); - ports.push(Port::new(a, rx)); - errors.push(Some(tx)); - duplexes.push(b); - } - - let state = Arc::new(Mutex::new(PortforwarderState { - waker: None, - result: None, - })); - let shared_state = state.clone(); - let port_nums = port_nums.to_owned(); - tokio::spawn(async move { - let result = start_message_loop(stream, port_nums, duplexes, errors).await; - - let mut shared = shared_state.lock().unwrap(); - shared.result = Some(result); - if let Some(waker) = shared.waker.take() { - waker.wake() - } - }); - Portforwarder { ports, state } - } + ports.insert(*port, a); + task_ios.push(b); - /// Get streams for forwarded ports. - pub fn ports(&mut self) -> &mut [Port] { - self.ports.as_mut_slice() - } -} - -impl Future for Portforwarder { - type Output = Result<(), Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut state = self.state.lock().unwrap(); - if let Some(result) = state.result.take() { - return Poll::Ready(result); - } - - if let Some(waker) = &state.waker { - if waker.will_wake(cx.waker()) { - return Poll::Pending; - } + let (tx, rx) = oneshot::channel(); + error_rxs.insert(*port, rx); + error_txs.push(Some(tx)); } - - state.waker = Some(cx.waker().clone()); - Poll::Pending - } -} - -pub struct Port { - // Data pipe. - stream: Option, - // Error channel. - error: Option, -} - -impl Port { - pub(crate) fn new(stream: DuplexStream, error: ErrorReceiver) -> Self { - Port { - stream: Some(stream), - error: Some(error), + let task = tokio::spawn(start_message_loop( + stream, + port_nums.to_vec(), + task_ios, + error_txs, + )); + + Portforwarder { + ports, + errors: error_rxs, + task, } } - /// Data pipe for sending to and receiving from the forwarded port. + /// Take a port stream by the port on the target resource. /// - /// This returns a `Some` on the first call, then a `None` on every subsequent call - pub fn stream(&mut self) -> Option { - self.stream.take() + /// A value is returned at most once per port. + #[inline] + pub fn take_stream(&mut self, port: u16) -> Option { + self.ports.remove(&port) } - /// Future that resolves with any error message or when the error sender is dropped. + /// Take a future that resolves with any error message or when the error sender is dropped. /// When the future resolves, the port should be considered no longer usable. /// - /// This returns a `Some` on the first call, then a `None` on every subsequent call - pub fn error(&mut self) -> Option>> { - // Ignore Cancellation error. - self.error.take().map(|recv| recv.map(|res| res.ok())) + /// A value is returned at most once per port. + #[inline] + pub fn take_error(&mut self, port: u16) -> Option>> { + self.errors.remove(&port).map(|recv| recv.map(|res| res.ok())) + } + + /// Abort the background task, causing port forwards to fail. + #[inline] + pub fn abort(&self) { + self.task.abort(); + } + + /// Waits for port forwarding task to complete. + pub async fn join(self) -> Result<(), Error> { + self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e))) } }