Skip to content

Commit

Permalink
portforward: Improve API and support background task cancelation (#854)
Browse files Browse the repository at this point in the history
  • Loading branch information
olix0r authored Mar 22, 2022
1 parent f7ac702 commit 116a970
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 96 deletions.
3 changes: 1 addition & 2 deletions examples/pod_portforward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ async fn main() -> anyhow::Result<()> {
let _ = tokio::time::timeout(std::time::Duration::from_secs(15), running).await?;

let mut pf = pods.portforward("example", &[80]).await?;
let ports = pf.ports();
let mut port = ports[0].stream().unwrap();
let mut port = pf.take_stream(80).unwrap();
port.write_all(b"GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nAccept: */*\r\n\r\n")
.await?;
let mut rstream = tokio_util::io::ReaderStream::new(port);
Expand Down
4 changes: 2 additions & 2 deletions examples/pod_portforward_bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async fn main() -> anyhow::Result<()> {
// Get `Portforwarder` that handles the WebSocket connection.
// There's no need to spawn a task to drive this, but it can be awaited to be notified on error.
let mut forwarder = pods.portforward("example", &[80]).await?;
let port = forwarder.ports()[0].stream().unwrap();
let port = forwarder.take_stream(80).unwrap();

// let hyper drive the HTTP state in our DuplexStream via a task
let (sender, connection) = hyper::client::conn::handshake(port).await?;
Expand All @@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> {
// The following task is only used to show any error from the forwarder.
// This example can be stopped with Ctrl-C if anything happens.
tokio::spawn(async move {
if let Err(e) = forwarder.await {
if let Err(e) = forwarder.join().await {
log::error!("forwarder errored: {}", e);
}
});
Expand Down
3 changes: 1 addition & 2 deletions examples/pod_portforward_hyper_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ async fn main() -> anyhow::Result<()> {
let _ = tokio::time::timeout(std::time::Duration::from_secs(15), running).await?;

let mut pf = pods.portforward("example", &[80]).await?;
let ports = pf.ports();
let port = ports[0].stream().unwrap();
let port = pf.take_stream(80).unwrap();

// let hyper drive the HTTP state in our DuplexStream via a task
let (mut sender, connection) = hyper::client::conn::handshake(port).await?;
Expand Down
144 changes: 54 additions & 90 deletions kube-client/src/api/portforward.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use std::{collections::HashMap, future::Future};

use bytes::{Buf, Bytes};
use futures::{
Expand Down Expand Up @@ -62,6 +57,9 @@ pub enum Error {
/// Failed to receive a WebSocket message from the server.
#[error("failed to receive a WebSocket message: {0}")]
ReceiveWebSocketMessage(#[source] ws::Error),

#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),
}

type ErrorReceiver = oneshot::Receiver<String>;
Expand All @@ -73,109 +71,75 @@ enum Message {
ToPod(u8, Bytes),
}

struct PortforwarderState {
waker: Option<Waker>,
result: Option<Result<(), Error>>,
}

// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports.
// Error channel for each port is only written by the server when there's an exception and
// the port cannot be used (didn't initialize or can't be used anymore).
/// Manage port forwarding.
/// Manages port-forwarded streams.
///
/// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports. Error
/// channel for each port is only written by the server when there's an exception and
//. the port cannot be used (didn't initialize or can't be used anymore).
pub struct Portforwarder {
ports: Vec<Port>,
state: Arc<Mutex<PortforwarderState>>,
ports: HashMap<u16, DuplexStream>,
errors: HashMap<u16, ErrorReceiver>,
task: tokio::task::JoinHandle<Result<(), Error>>,
}

impl Portforwarder {
pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
let mut ports = Vec::new();
let mut errors = Vec::new();
let mut duplexes = Vec::new();
for _ in port_nums.iter() {
let mut ports = HashMap::with_capacity(port_nums.len());
let mut error_rxs = HashMap::with_capacity(port_nums.len());
let mut error_txs = Vec::with_capacity(port_nums.len());
let mut task_ios = Vec::with_capacity(port_nums.len());
for port in port_nums.iter() {
let (a, b) = tokio::io::duplex(1024 * 1024);
let (tx, rx) = oneshot::channel();
ports.push(Port::new(a, rx));
errors.push(Some(tx));
duplexes.push(b);
}

let state = Arc::new(Mutex::new(PortforwarderState {
waker: None,
result: None,
}));
let shared_state = state.clone();
let port_nums = port_nums.to_owned();
tokio::spawn(async move {
let result = start_message_loop(stream, port_nums, duplexes, errors).await;

let mut shared = shared_state.lock().unwrap();
shared.result = Some(result);
if let Some(waker) = shared.waker.take() {
waker.wake()
}
});
Portforwarder { ports, state }
}
ports.insert(*port, a);
task_ios.push(b);

/// Get streams for forwarded ports.
pub fn ports(&mut self) -> &mut [Port] {
self.ports.as_mut_slice()
}
}

impl Future for Portforwarder {
type Output = Result<(), Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = self.state.lock().unwrap();
if let Some(result) = state.result.take() {
return Poll::Ready(result);
}

if let Some(waker) = &state.waker {
if waker.will_wake(cx.waker()) {
return Poll::Pending;
}
let (tx, rx) = oneshot::channel();
error_rxs.insert(*port, rx);
error_txs.push(Some(tx));
}

state.waker = Some(cx.waker().clone());
Poll::Pending
}
}

pub struct Port {
// Data pipe.
stream: Option<DuplexStream>,
// Error channel.
error: Option<ErrorReceiver>,
}

impl Port {
pub(crate) fn new(stream: DuplexStream, error: ErrorReceiver) -> Self {
Port {
stream: Some(stream),
error: Some(error),
let task = tokio::spawn(start_message_loop(
stream,
port_nums.to_vec(),
task_ios,
error_txs,
));

Portforwarder {
ports,
errors: error_rxs,
task,
}
}

/// Data pipe for sending to and receiving from the forwarded port.
/// Take a port stream by the port on the target resource.
///
/// This returns a `Some` on the first call, then a `None` on every subsequent call
pub fn stream(&mut self) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
self.stream.take()
/// A value is returned at most once per port.
#[inline]
pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
self.ports.remove(&port)
}

/// Future that resolves with any error message or when the error sender is dropped.
/// Take a future that resolves with any error message or when the error sender is dropped.
/// When the future resolves, the port should be considered no longer usable.
///
/// This returns a `Some` on the first call, then a `None` on every subsequent call
pub fn error(&mut self) -> Option<impl Future<Output = Option<String>>> {
// Ignore Cancellation error.
self.error.take().map(|recv| recv.map(|res| res.ok()))
/// A value is returned at most once per port.
#[inline]
pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>>> {
self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
}

/// Abort the background task, causing port forwards to fail.
#[inline]
pub fn abort(&self) {
self.task.abort();
}

/// Waits for port forwarding task to complete.
pub async fn join(self) -> Result<(), Error> {
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
}

Expand Down

0 comments on commit 116a970

Please sign in to comment.