Skip to content

Commit

Permalink
Merge pull request #737 from davidbarsky/davidbarsky/remove-arc-swap
Browse files Browse the repository at this point in the history
internal: Replace `arc-swap` with manual `AtomicPtr`
  • Loading branch information
Veykril authored Feb 26, 2025
2 parents 26aeeec + da9a21c commit 99be5d9
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 149 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ rust-version = "1.80"
salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { version = "0.18.0", path = "components/salsa-macros" }

arc-swap = "1"
boxcar = "0.2.9"
crossbeam-queue = "0.3.11"
dashmap = { version = "6", features = ["raw-api"] }
Expand Down
17 changes: 11 additions & 6 deletions src/function.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{any::Any, fmt, mem::ManuallyDrop, sync::Arc};
use std::{any::Any, fmt, ptr::NonNull};

use crate::{
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
Expand Down Expand Up @@ -176,19 +176,25 @@ where
memo: memo::Memo<C::Output<'db>>,
memo_ingredient_index: MemoIngredientIndex,
) -> &'db memo::Memo<C::Output<'db>> {
let memo = Arc::new(memo);
// We convert to a `NonNull` here as soon as possible because we are going to alias
// into the `Box`, which is a `noalias` type.
let memo = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(memo))) };

// Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
// value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
let db_memo = unsafe { self.extend_memo_lifetime(&memo) };
let db_memo = unsafe { self.extend_memo_lifetime(memo.as_ref()) };

// Safety: We delay the drop of `old_value` until a new revision starts which ensures no
// references will exist for the memo contents.
if let Some(old_value) =
unsafe { self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index) }
{
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
self.deleted_entries
.push(ManuallyDrop::into_inner(old_value));
//
// SAFETY: Once the revision starts, there will be no oustanding borrows to the
// memo contents, and so it will be safe to free.
unsafe { self.deleted_entries.push(old_value) };
}
db_memo
}
Expand Down Expand Up @@ -254,7 +260,6 @@ where
let ingredient_index = table.ingredient_index(evict);
Self::evict_value_from_memo_for(
table.memos_mut(evict),
&self.deleted_entries,
self.memo_ingredient_indices.get(ingredient_index),
)
});
Expand Down
37 changes: 32 additions & 5 deletions src/function/delete.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
use std::ptr::NonNull;

use crossbeam_queue::SegQueue;

use super::{memo::ArcMemo, Configuration};
use super::memo::Memo;
use super::Configuration;

/// Stores the list of memos that have been deleted so they can be freed
/// once the next revision starts. See the comment on the field
/// `deleted_entries` of [`FunctionIngredient`][] for more details.
pub(super) struct DeletedEntries<C: Configuration> {
seg_queue: SegQueue<ArcMemo<'static, C>>,
seg_queue: SegQueue<SharedBox<Memo<C::Output<'static>>>>,
}

unsafe impl<T: Send> Send for SharedBox<T> {}
unsafe impl<T: Sync> Sync for SharedBox<T> {}

impl<C: Configuration> Default for DeletedEntries<C> {
fn default() -> Self {
Self {
Expand All @@ -18,8 +24,29 @@ impl<C: Configuration> Default for DeletedEntries<C> {
}

impl<C: Configuration> DeletedEntries<C> {
pub(super) fn push(&self, memo: ArcMemo<'_, C>) {
let memo = unsafe { std::mem::transmute::<ArcMemo<'_, C>, ArcMemo<'static, C>>(memo) };
self.seg_queue.push(memo);
/// # Safety
///
/// The memo must be valid and safe to free when the `DeletedEntries` list is dropped.
pub(super) unsafe fn push(&self, memo: NonNull<Memo<C::Output<'_>>>) {
let memo = unsafe {
std::mem::transmute::<NonNull<Memo<C::Output<'_>>>, NonNull<Memo<C::Output<'static>>>>(
memo,
)
};

self.seg_queue.push(SharedBox(memo));
}
}

/// A wrapper around `NonNull` that frees the allocation when it is dropped.
///
/// `crossbeam::SegQueue` does not expose mutable accessors so we have to create
/// a wrapper to run code during `Drop`.
struct SharedBox<T>(NonNull<T>);

impl<T> Drop for SharedBox<T> {
fn drop(&mut self) {
// SAFETY: Guaranteed by the caller of `DeletedEntries::push`.
unsafe { drop(Box::from_raw(self.0.as_ptr())) };
}
}
4 changes: 1 addition & 3 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::sync::Arc;

use crate::{
zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind,
};
Expand All @@ -23,7 +21,7 @@ where
&'db self,
db: &'db C::DbView,
active_query: ActiveQueryGuard<'_>,
opt_old_memo: Option<Arc<Memo<C::Output<'_>>>>,
opt_old_memo: Option<&Memo<C::Output<'_>>>,
) -> &'db Memo<C::Output<'db>> {
let zalsa = db.zalsa();
let revision_now = zalsa.current_revision();
Expand Down
3 changes: 1 addition & 2 deletions src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ where
MaybeChangedAfter::No(memo.revisions.accumulated_inputs.load())
};
}
drop(memo_guard); // release the arc-swap guard before cold path
if let Some(mcs) =
self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index)
{
Expand Down Expand Up @@ -86,7 +85,7 @@ where
);

// Check if the inputs are still valid. We can just compare `changed_at`.
if self.deep_verify_memo(db, zalsa, &old_memo, &active_query) {
if self.deep_verify_memo(db, zalsa, old_memo, &active_query) {
return Some(if old_memo.revisions.changed_at > revision {
MaybeChangedAfter::Yes
} else {
Expand Down
79 changes: 30 additions & 49 deletions src/function/memo.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::any::Any;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::mem::ManuallyDrop;
use std::sync::Arc;
use std::ptr::NonNull;

use crate::accumulator::accumulated_map::InputAccumulatedValues;
use crate::function::DeletedEntries;
use crate::revision::AtomicRevision;
use crate::table::memo::MemoTable;
use crate::zalsa::MemoIngredientIndex;
Expand All @@ -17,21 +15,33 @@ use crate::{

use super::{Configuration, IngredientImpl};

#[allow(type_alias_bounds)]
pub(super) type ArcMemo<'lt, C: Configuration> = Arc<Memo<<C as Configuration>::Output<'lt>>>;

impl<C: Configuration> IngredientImpl<C> {
/// Memos have to be stored internally using `'static` as the database lifetime.
/// This (unsafe) function call converts from something tied to self to static.
/// Values transmuted this way have to be transmuted back to being tied to self
/// when they are returned to the user.
unsafe fn to_static<'db>(&'db self, memo: ArcMemo<'db, C>) -> ArcMemo<'static, C> {
unsafe { std::mem::transmute(memo) }
unsafe fn to_static<'db>(
&'db self,
memo: NonNull<Memo<C::Output<'db>>>,
) -> NonNull<Memo<C::Output<'static>>> {
memo.cast()
}

/// Convert from an internal memo (which uses `'static`) to one tied to self
/// so it can be publicly released.
unsafe fn to_self<'db>(&'db self, memo: ArcMemo<'static, C>) -> ArcMemo<'db, C> {
unsafe fn to_self<'db>(
&'db self,
memo: NonNull<Memo<C::Output<'static>>>,
) -> NonNull<Memo<C::Output<'db>>> {
memo.cast()
}

/// Convert from an internal memo (which uses `'static`) to one tied to self
/// so it can be publicly released.
unsafe fn to_self_ref<'db>(
&'db self,
memo: &'db Memo<C::Output<'static>>,
) -> &'db Memo<C::Output<'db>> {
unsafe { std::mem::transmute(memo) }
}

Expand All @@ -45,17 +55,16 @@ impl<C: Configuration> IngredientImpl<C> {
&'db self,
zalsa: &'db Zalsa,
id: Id,
memo: ArcMemo<'db, C>,
memo: NonNull<Memo<C::Output<'db>>>,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<ManuallyDrop<ArcMemo<'db, C>>> {
) -> Option<NonNull<Memo<C::Output<'db>>>> {
let static_memo = unsafe { self.to_static(memo) };
let old_static_memo = unsafe {
zalsa
.memo_table_for(id)
.insert(memo_ingredient_index, static_memo)
}?;
let old_static_memo = ManuallyDrop::into_inner(old_static_memo);
Some(ManuallyDrop::new(unsafe { self.to_self(old_static_memo) }))
Some(unsafe { self.to_self(old_static_memo) })
}

/// Loads the current memo for `key_index`. This does not hold any sort of
Expand All @@ -66,20 +75,20 @@ impl<C: Configuration> IngredientImpl<C> {
zalsa: &'db Zalsa,
id: Id,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<ArcMemo<'db, C>> {
) -> Option<&'db Memo<C::Output<'db>>> {
let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?;
unsafe { Some(self.to_self(static_memo)) }

unsafe { Some(self.to_self_ref(static_memo)) }
}

/// Evicts the existing memo for the given key, replacing it
/// with an equivalent memo that has no value. If the memo is untracked, BaseInput,
/// or has values assigned as output of another query, this has no effect.
pub(super) fn evict_value_from_memo_for(
table: &mut MemoTable,
deleted_entries: &DeletedEntries<C>,
memo_ingredient_index: MemoIngredientIndex,
) {
let map = |memo: ArcMemo<'static, C>| -> ArcMemo<'static, C> {
let map = |memo: &mut Memo<C::Output<'static>>| {
match &memo.revisions.origin {
QueryOrigin::Assigned(_)
| QueryOrigin::DerivedUntracked(_)
Expand All @@ -88,43 +97,15 @@ impl<C: Configuration> IngredientImpl<C> {
// assigned as output of another query
// or those with untracked inputs
// as their values cannot be reconstructed.
memo
}
QueryOrigin::Derived(_) => {
// Note that we cannot use `Arc::get_mut` here as the use of `ArcSwap` makes it
// impossible to get unique access to the interior Arc
// QueryRevisions: !Clone to discourage cloning, we need it here though
let &QueryRevisions {
changed_at,
durability,
ref origin,
ref tracked_struct_ids,
ref accumulated,
ref accumulated_inputs,
} = &memo.revisions;
// Re-assemble the memo but with the value set to `None`
Arc::new(Memo::new(
None,
memo.verified_at.load(),
QueryRevisions {
changed_at,
durability,
origin: origin.clone(),
tracked_struct_ids: tracked_struct_ids.clone(),
accumulated: accumulated.clone(),
accumulated_inputs: accumulated_inputs.clone(),
},
))
// Set the memo value to `None`.
memo.value = None;
}
}
};
// SAFETY: We queue the old value for deletion, delaying its drop until the next revision bump.
let old = unsafe { table.map_memo(memo_ingredient_index, map) };
if let Some(old) = old {
// In case there is a reference to the old memo out there, we have to store it
// in the deleted entries. This will get cleared when a new revision starts.
deleted_entries.push(ManuallyDrop::into_inner(old));
}

table.map_memo(memo_ingredient_index, map)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/function/specify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ where

let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);
if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) {
self.backdate_if_appropriate(&old_memo, &mut revisions, &value);
self.diff_outputs(zalsa, db, database_key_index, &old_memo, &mut revisions);
self.backdate_if_appropriate(old_memo, &mut revisions, &value);
self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions);
}

let memo = Memo {
Expand Down
Loading

0 comments on commit 99be5d9

Please sign in to comment.