Skip to content

Commit

Permalink
Merge pull request torvalds#104 from Rust-for-Linux/condvar
Browse files Browse the repository at this point in the history
Add a condition variable implementation to the `sync` module.
  • Loading branch information
wedsonaf authored Mar 16, 2021
2 parents a116223 + 0098adc commit 58ab534
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 3 deletions.
34 changes: 31 additions & 3 deletions drivers/char/rust_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use alloc::boxed::Box;
use core::pin::Pin;
use kernel::prelude::*;
use kernel::{
chrdev, cstr,
chrdev, condvar_init, cstr,
file_operations::FileOperations,
miscdev, mutex_init, spinlock_init,
sync::{Mutex, SpinLock},
sync::{CondVar, Mutex, SpinLock},
};

module! {
Expand Down Expand Up @@ -86,6 +86,20 @@ impl KernelModule for RustExample {
mutex_init!(data.as_ref(), "RustExample::init::data1");
*data.lock() = 10;
println!("Value: {}", *data.lock());

// SAFETY: `init` is called below.
let cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?);
condvar_init!(cv.as_ref(), "RustExample::init::cv1");
{
let guard = data.lock();
#[allow(clippy::while_immutable_condition)]
while *guard != 10 {
cv.wait(&guard);
}
}
cv.notify_one();
cv.notify_all();
cv.free_waiters();
}

// Test spinlocks.
Expand All @@ -95,13 +109,27 @@ impl KernelModule for RustExample {
spinlock_init!(data.as_ref(), "RustExample::init::data2");
*data.lock() = 10;
println!("Value: {}", *data.lock());

// SAFETY: `init` is called below.
let cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?);
condvar_init!(cv.as_ref(), "RustExample::init::cv2");
{
let guard = data.lock();
#[allow(clippy::while_immutable_condition)]
while *guard != 10 {
cv.wait(&guard);
}
}
cv.notify_one();
cv.notify_all();
cv.free_waiters();
}

// Including this large variable on the stack will trigger
// stack probing on the supported archs.
// This will verify that stack probing does not lead to
// any errors if we need to link `__rust_probestack`.
let x: [u64; 1028] = core::hint::black_box([5; 1028]);
let x: [u64; 514] = core::hint::black_box([5; 514]);
println!("Large array has length: {}", x.len());

let mut chrdev_reg =
Expand Down
13 changes: 13 additions & 0 deletions rust/helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <linux/bug.h>
#include <linux/build_bug.h>
#include <linux/uaccess.h>
#include <linux/sched/signal.h>

void rust_helper_BUG(void)
{
Expand Down Expand Up @@ -47,6 +48,18 @@ void rust_helper_spin_unlock(spinlock_t *lock)
}
EXPORT_SYMBOL(rust_helper_spin_unlock);

void rust_helper_init_wait(struct wait_queue_entry *wq_entry)
{
init_wait(wq_entry);
}
EXPORT_SYMBOL(rust_helper_init_wait);

int rust_helper_signal_pending(void)
{
return signal_pending(current);
}
EXPORT_SYMBOL(rust_helper_signal_pending);

// See https://github.com/rust-lang/rust-bindgen/issues/1671
static_assert(__builtin_types_compatible_p(size_t, uintptr_t),
"size_t must match uintptr_t, what architecture is this??");
1 change: 1 addition & 0 deletions rust/kernel/bindings_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <linux/uaccess.h>
#include <linux/version.h>
#include <linux/miscdevice.h>
#include <linux/poll.h>

// `bindgen` gets confused at certain things
const gfp_t BINDINGS_GFP_KERNEL = GFP_KERNEL;
Expand Down
137 changes: 137 additions & 0 deletions rust/kernel/sync/condvar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// SPDX-License-Identifier: GPL-2.0

//! A condition variable.
//!
//! This module allows Rust code to use the kernel's [`struct wait_queue_head`] as a condition
//! variable.
use super::{Guard, Lock, NeedsLockClass};
use crate::{bindings, c_types, CStr};
use core::{cell::UnsafeCell, marker::PhantomPinned, mem::MaybeUninit, pin::Pin};

extern "C" {
fn rust_helper_init_wait(wq: *mut bindings::wait_queue_entry);
fn rust_helper_signal_pending() -> c_types::c_int;
}

/// Safely initialises a [`CondVar`] with the given name, generating a new lock class.
#[macro_export]
macro_rules! condvar_init {
($condvar:expr, $name:literal) => {
$crate::init_with_lockdep!($condvar, $name)
};
}

// TODO: `bindgen` is not generating this constant. Figure out why.
const POLLFREE: u32 = 0x4000;

/// Exposes the kernel's [`struct wait_queue_head`] as a condition variable. It allows the caller to
/// atomically release the given lock and go to sleep. It reacquires the lock when it wakes up. And
/// it wakes up when notified by another thread (via [`CondVar::notify_one`] or
/// [`CondVar::notify_all`]) or because the thread received a signal.
///
/// [`struct wait_queue_head`]: ../../../include/linux/wait.h
pub struct CondVar {
pub(crate) wait_list: UnsafeCell<bindings::wait_queue_head>,

/// A condvar needs to be pinned because it contains a [`struct list_head`] that is
/// self-referential, so it cannot be safely moved once it is initialised.
_pin: PhantomPinned,
}

// SAFETY: `CondVar` only uses a `struct wait_queue_head`, which is safe to use on any thread.
unsafe impl Send for CondVar {}

// SAFETY: `CondVar` only uses a `struct wait_queue_head`, which is safe to use on multiple threads
// concurrently.
unsafe impl Sync for CondVar {}

impl CondVar {
/// Constructs a new conditional variable.
///
/// # Safety
///
/// The caller must call `CondVar::init` before using the conditional variable.
pub unsafe fn new() -> Self {
Self {
wait_list: UnsafeCell::new(bindings::wait_queue_head::default()),
_pin: PhantomPinned,
}
}

/// Atomically releases the given lock (whose ownership is proven by the guard) and puts the
/// thread to sleep. It wakes up when notified by [`CondVar::notify_one`] or
/// [`CondVar::notify_all`], or when the thread receives a signal.
///
/// Returns whether there is a signal pending.
pub fn wait<L: Lock>(&self, g: &Guard<L>) -> bool {
let l = g.lock;
let mut wait = MaybeUninit::<bindings::wait_queue_entry>::uninit();

// SAFETY: `wait` points to valid memory.
unsafe { rust_helper_init_wait(wait.as_mut_ptr()) };

// SAFETY: Both `wait` and `wait_list` point to valid memory.
unsafe {
bindings::prepare_to_wait_exclusive(
self.wait_list.get(),
wait.as_mut_ptr(),
bindings::TASK_INTERRUPTIBLE as _,
);
}

// SAFETY: The guard is evidence that the caller owns the lock.
unsafe { l.unlock() };

// SAFETY: No arguments, switches to another thread.
unsafe { bindings::schedule() };

l.lock_noguard();

// SAFETY: Both `wait` and `wait_list` point to valid memory.
unsafe { bindings::finish_wait(self.wait_list.get(), wait.as_mut_ptr()) };

// SAFETY: No arguments, just checks `current` for pending signals.
unsafe { rust_helper_signal_pending() != 0 }
}

/// Calls the kernel function to notify the appropriate number of threads with the given flags.
fn notify(&self, count: i32, flags: u32) {
// SAFETY: `wait_list` points to valid memory.
unsafe {
bindings::__wake_up(
self.wait_list.get(),
bindings::TASK_NORMAL,
count,
flags as _,
)
};
}

/// Wakes a single waiter up, if any. This is not 'sticky' in the sense that if no thread is
/// waiting, the notification is lost completely (as opposed to automatically waking up the
/// next waiter).
pub fn notify_one(&self) {
self.notify(1, 0);
}

/// Wakes all waiters up, if any. This is not 'sticky' in the sense that if no thread is
/// waiting, the notification is lost completely (as opposed to automatically waking up the
/// next waiter).
pub fn notify_all(&self) {
self.notify(0, 0);
}

/// Wakes all waiters up. If they were added by `epoll`, they are also removed from the list of
/// waiters. This is useful when cleaning up a condition variable that may be waited on by
/// threads that use `epoll`.
pub fn free_waiters(&self) {
self.notify(1, bindings::POLLHUP | POLLFREE);
}
}

impl NeedsLockClass for CondVar {
unsafe fn init(self: Pin<&Self>, name: CStr<'static>, key: *mut bindings::lock_class_key) {
bindings::__init_waitqueue_head(self.wait_list.get(), name.as_ptr() as _, key);
}
}
2 changes: 2 additions & 0 deletions rust/kernel/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
use crate::{bindings, CStr};
use core::pin::Pin;

mod condvar;
mod guard;
mod mutex;
mod spinlock;

pub use condvar::CondVar;
pub use guard::{Guard, Lock};
pub use mutex::Mutex;
pub use spinlock::SpinLock;
Expand Down

0 comments on commit 58ab534

Please sign in to comment.