Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

portforward: Improve API and support background task cancelation #854

Merged
merged 2 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/pod_portforward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ async fn main() -> anyhow::Result<()> {
let _ = tokio::time::timeout(std::time::Duration::from_secs(15), running).await?;

let mut pf = pods.portforward("example", &[80]).await?;
let ports = pf.ports();
let mut port = ports[0].stream().unwrap();
let mut port = pf.take_stream(80).unwrap();
port.write_all(b"GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\nAccept: */*\r\n\r\n")
.await?;
let mut rstream = tokio_util::io::ReaderStream::new(port);
Expand Down
4 changes: 2 additions & 2 deletions examples/pod_portforward_bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async fn main() -> anyhow::Result<()> {
// Get `Portforwarder` that handles the WebSocket connection.
// There's no need to spawn a task to drive this, but it can be awaited to be notified on error.
let mut forwarder = pods.portforward("example", &[80]).await?;
let port = forwarder.ports()[0].stream().unwrap();
let port = forwarder.take_stream(80).unwrap();

// let hyper drive the HTTP state in our DuplexStream via a task
let (sender, connection) = hyper::client::conn::handshake(port).await?;
Expand All @@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> {
// The following task is only used to show any error from the forwarder.
// This example can be stopped with Ctrl-C if anything happens.
tokio::spawn(async move {
if let Err(e) = forwarder.await {
if let Err(e) = forwarder.join().await {
log::error!("forwarder errored: {}", e);
}
});
Expand Down
3 changes: 1 addition & 2 deletions examples/pod_portforward_hyper_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ async fn main() -> anyhow::Result<()> {
let _ = tokio::time::timeout(std::time::Duration::from_secs(15), running).await?;

let mut pf = pods.portforward("example", &[80]).await?;
let ports = pf.ports();
let port = ports[0].stream().unwrap();
let port = pf.take_stream(80).unwrap();

// let hyper drive the HTTP state in our DuplexStream via a task
let (mut sender, connection) = hyper::client::conn::handshake(port).await?;
Expand Down
144 changes: 54 additions & 90 deletions kube-client/src/api/portforward.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use std::{collections::HashMap, future::Future};

use bytes::{Buf, Bytes};
use futures::{
Expand Down Expand Up @@ -62,6 +57,9 @@ pub enum Error {
/// Failed to receive a WebSocket message from the server.
#[error("failed to receive a WebSocket message: {0}")]
ReceiveWebSocketMessage(#[source] ws::Error),

#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),
}

type ErrorReceiver = oneshot::Receiver<String>;
Expand All @@ -73,109 +71,75 @@ enum Message {
ToPod(u8, Bytes),
}

struct PortforwarderState {
waker: Option<Waker>,
result: Option<Result<(), Error>>,
}

// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports.
// Error channel for each port is only written by the server when there's an exception and
// the port cannot be used (didn't initialize or can't be used anymore).
/// Manage port forwarding.
/// Manages port-forwarded streams.
///
/// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports. Error
/// channel for each port is only written by the server when there's an exception and
//. the port cannot be used (didn't initialize or can't be used anymore).
pub struct Portforwarder {
ports: Vec<Port>,
state: Arc<Mutex<PortforwarderState>>,
ports: HashMap<u16, DuplexStream>,
errors: HashMap<u16, ErrorReceiver>,
task: tokio::task::JoinHandle<Result<(), Error>>,
}

impl Portforwarder {
pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
let mut ports = Vec::new();
let mut errors = Vec::new();
let mut duplexes = Vec::new();
for _ in port_nums.iter() {
let mut ports = HashMap::with_capacity(port_nums.len());
let mut error_rxs = HashMap::with_capacity(port_nums.len());
let mut error_txs = Vec::with_capacity(port_nums.len());
let mut task_ios = Vec::with_capacity(port_nums.len());
for port in port_nums.iter() {
let (a, b) = tokio::io::duplex(1024 * 1024);
let (tx, rx) = oneshot::channel();
ports.push(Port::new(a, rx));
errors.push(Some(tx));
duplexes.push(b);
}

let state = Arc::new(Mutex::new(PortforwarderState {
waker: None,
result: None,
}));
let shared_state = state.clone();
let port_nums = port_nums.to_owned();
tokio::spawn(async move {
let result = start_message_loop(stream, port_nums, duplexes, errors).await;

let mut shared = shared_state.lock().unwrap();
shared.result = Some(result);
if let Some(waker) = shared.waker.take() {
waker.wake()
}
});
Portforwarder { ports, state }
}
ports.insert(*port, a);
task_ios.push(b);

/// Get streams for forwarded ports.
pub fn ports(&mut self) -> &mut [Port] {
self.ports.as_mut_slice()
}
}

impl Future for Portforwarder {
type Output = Result<(), Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = self.state.lock().unwrap();
if let Some(result) = state.result.take() {
return Poll::Ready(result);
}

if let Some(waker) = &state.waker {
if waker.will_wake(cx.waker()) {
return Poll::Pending;
}
let (tx, rx) = oneshot::channel();
error_rxs.insert(*port, rx);
error_txs.push(Some(tx));
}

state.waker = Some(cx.waker().clone());
Poll::Pending
}
}

pub struct Port {
// Data pipe.
stream: Option<DuplexStream>,
// Error channel.
error: Option<ErrorReceiver>,
}

impl Port {
pub(crate) fn new(stream: DuplexStream, error: ErrorReceiver) -> Self {
Port {
stream: Some(stream),
error: Some(error),
let task = tokio::spawn(start_message_loop(
stream,
port_nums.to_vec(),
task_ios,
error_txs,
));

Portforwarder {
ports,
errors: error_rxs,
task,
}
}

/// Data pipe for sending to and receiving from the forwarded port.
/// Take a port stream by the port on the target resource.
///
/// This returns a `Some` on the first call, then a `None` on every subsequent call
pub fn stream(&mut self) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
self.stream.take()
/// A value is returned at most once per port.
#[inline]
pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin> {
self.ports.remove(&port)
}

/// Future that resolves with any error message or when the error sender is dropped.
/// Take a future that resolves with any error message or when the error sender is dropped.
/// When the future resolves, the port should be considered no longer usable.
///
/// This returns a `Some` on the first call, then a `None` on every subsequent call
pub fn error(&mut self) -> Option<impl Future<Output = Option<String>>> {
// Ignore Cancellation error.
self.error.take().map(|recv| recv.map(|res| res.ok()))
/// A value is returned at most once per port.
#[inline]
pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>>> {
self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
}

/// Abort the background task, causing port forwards to fail.
#[inline]
pub fn abort(&self) {
self.task.abort();
}

/// Waits for port forwarding task to complete.
pub async fn join(self) -> Result<(), Error> {
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
}

Expand Down