From 5eb99ccbabb1504eafb0a636f2e7a33f83bba0a2 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 21 Feb 2025 07:30:37 +0100 Subject: [PATCH] Remove some `ZalsaDatabase::zalsa` calls Instead pass around `&Zalsa` to callers more to reduce dynamic dispatch, in most of these cases the functions are only called once so the compiler should have enough knowledge to make the extra argument passing virtually free --- src/database.rs | 8 +++---- src/function/diff_outputs.rs | 16 +++++++++----- src/function/execute.rs | 2 +- src/function/fetch.rs | 34 +++++++++++++++++++++-------- src/function/maybe_changed_after.rs | 22 +++++++++++++------ src/function/memo.rs | 3 ++- src/function/specify.rs | 2 +- src/key.rs | 19 +++++++++++----- src/table/sync.rs | 2 +- src/tracked_struct.rs | 2 +- 10 files changed, 74 insertions(+), 36 deletions(-) diff --git a/src/database.rs b/src/database.rs index 4e1eb5dae..87aace8fe 100644 --- a/src/database.rs +++ b/src/database.rs @@ -49,9 +49,8 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { /// Queries which report untracked reads will be re-executed in the next /// revision. fn report_untracked_read(&self) { - let db = self.as_dyn_database(); - let zalsa_local = db.zalsa_local(); - zalsa_local.report_untracked_read(db.zalsa().current_revision()) + let (zalsa, zalsa_local) = self.zalsas(); + zalsa_local.report_untracked_read(zalsa.current_revision()) } /// Return the "debug name" (i.e., the struct name, etc) for an "ingredient", @@ -81,8 +80,7 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { /// used instead. fn unwind_if_revision_cancelled(&self) { let db = self.as_dyn_database(); - let zalsa_local = db.zalsa_local(); - zalsa_local.unwind_if_revision_cancelled(db); + self.zalsa_local().unwind_if_revision_cancelled(db); } /// Execute `op` with the database in thread-local storage for debug print-outs. diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index a727d7720..8ea8fe750 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -1,7 +1,7 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::{ - hash::FxHashSet, key::OutputDependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _, - DatabaseKeyIndex, Event, EventKind, + hash::FxHashSet, key::OutputDependencyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, + AsDynDatabase as _, DatabaseKeyIndex, Event, EventKind, }; impl IngredientImpl @@ -15,6 +15,7 @@ where /// that no longer exist in this revision from [`QueryRevisions::tracked_struct_ids`]. pub(super) fn diff_outputs( &self, + zalsa: &Zalsa, db: &C::DbView, key: DatabaseKeyIndex, old_memo: &Memo>, @@ -38,11 +39,16 @@ where } for old_output in old_outputs { - Self::report_stale_output(db, key, old_output); + Self::report_stale_output(zalsa, db, key, old_output); } } - fn report_stale_output(db: &C::DbView, key: DatabaseKeyIndex, output: OutputDependencyIndex) { + fn report_stale_output( + zalsa: &Zalsa, + db: &C::DbView, + key: DatabaseKeyIndex, + output: OutputDependencyIndex, + ) { let db = db.as_dyn_database(); db.salsa_event(&|| { @@ -52,6 +58,6 @@ where }) }); - output.remove_stale_output(db, key); + output.remove_stale_output(zalsa, db, key); } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 3adbe4b08..185bea33b 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -79,7 +79,7 @@ where // old value. if let Some(old_memo) = &opt_old_memo { self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, old_memo, &mut revisions); + self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); } tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 7828f33b9..211147e3c 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,9 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::{ - accumulator::accumulated_map::InputAccumulatedValues, runtime::StampedValue, - zalsa::ZalsaDatabase, AsDynDatabase as _, Id, + accumulator::accumulated_map::InputAccumulatedValues, + runtime::StampedValue, + zalsa::{Zalsa, ZalsaDatabase}, + AsDynDatabase as _, Id, }; impl IngredientImpl @@ -40,16 +42,24 @@ where db: &'db C::DbView, id: Id, ) -> &'db Memo> { + let zalsa = db.zalsa(); loop { - if let Some(memo) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) { + if let Some(memo) = self + .fetch_hot(zalsa, db, id) + .or_else(|| self.fetch_cold(zalsa, db, id)) + { return memo; } } } #[inline] - fn fetch_hot<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo>> { - let zalsa = db.zalsa(); + fn fetch_hot<'db>( + &'db self, + zalsa: &'db Zalsa, + db: &'db C::DbView, + id: Id, + ) -> Option<&'db Memo>> { let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if memo.value.is_some() @@ -63,13 +73,19 @@ where None } - fn fetch_cold<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo>> { - let (zalsa, zalsa_local) = db.zalsas(); + fn fetch_cold<'db>( + &'db self, + zalsa: &'db Zalsa, + db: &'db C::DbView, + id: Id, + ) -> Option<&'db Memo>> { + let zalsa_local = db.zalsa_local(); let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. let _claim_guard = zalsa.sync_table_for(id).claim( db.as_dyn_database(), + zalsa, zalsa_local, database_key_index, self.memo_ingredient_index, @@ -79,10 +95,10 @@ where let active_query = zalsa_local.push_query(database_key_index); // Now that we've claimed the item, check again to see if there's a "hot" value. - let zalsa = db.zalsa(); let opt_old_memo = self.get_memo_from_table_for(zalsa, id); if let Some(old_memo) = &opt_old_memo { - if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { + if old_memo.value.is_some() && self.deep_verify_memo(db, zalsa, old_memo, &active_query) + { // Unsafety invariant: memo is present in memo_map and we have verified that it is // still valid for the current revision. return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index c37b1de68..214d3eb33 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,6 +2,7 @@ use crate::{ accumulator::accumulated_map::InputAccumulatedValues, ingredient::MaybeChangedAfter, key::DatabaseKeyIndex, + plumbing::ZalsaLocal, zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin}, AsDynDatabase as _, Id, Revision, @@ -38,7 +39,9 @@ where }; } drop(memo_guard); // release the arc-swap guard before cold path - if let Some(mcs) = self.maybe_changed_after_cold(db, id, revision) { + if let Some(mcs) = + self.maybe_changed_after_cold(zalsa, zalsa_local, db, id, revision) + { return mcs; } else { // We failed to claim, have to retry. @@ -52,15 +55,17 @@ where fn maybe_changed_after_cold<'db>( &'db self, + zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, db: &'db C::DbView, key_index: Id, revision: Revision, ) -> Option { - let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); let _claim_guard = zalsa.sync_table_for(key_index).claim( db.as_dyn_database(), + zalsa, zalsa_local, database_key_index, self.memo_ingredient_index, @@ -79,7 +84,7 @@ where ); // Check if the inputs are still valid. We can just compare `changed_at`. - if self.deep_verify_memo(db, &old_memo, &active_query) { + if self.deep_verify_memo(db, zalsa, &old_memo, &active_query) { return Some(if old_memo.revisions.changed_at > revision { MaybeChangedAfter::Yes } else { @@ -141,7 +146,7 @@ where database_key_index, memo.revisions.accumulated_inputs.load(), ); - memo.mark_outputs_as_verified(db, database_key_index); + memo.mark_outputs_as_verified(zalsa, db, database_key_index); return true; } @@ -159,10 +164,10 @@ where pub(super) fn deep_verify_memo( &self, db: &C::DbView, + zalsa: &Zalsa, old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> bool { - let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; tracing::debug!( @@ -237,8 +242,11 @@ where // by this function cannot be read until this function is marked green, // so even if we mark them as valid here, the function will re-execute // and overwrite the contents. - dependency_index - .mark_validated_output(db.as_dyn_database(), database_key_index); + dependency_index.mark_validated_output( + zalsa, + db.as_dyn_database(), + database_key_index, + ); } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 8f688702d..ae2a20997 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -186,11 +186,12 @@ impl Memo { pub(super) fn mark_outputs_as_verified( &self, + zalsa: &Zalsa, db: &dyn crate::Database, database_key_index: DatabaseKeyIndex, ) { for output in self.revisions.origin.outputs() { - output.mark_validated_output(db, database_key_index); + output.mark_validated_output(zalsa, db, database_key_index); } } diff --git a/src/function/specify.rs b/src/function/specify.rs index 5a9187c70..85e721f56 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -75,7 +75,7 @@ where if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { self.backdate_if_appropriate(&old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, &old_memo, &mut revisions); + self.diff_outputs(zalsa, db, database_key_index, &old_memo, &mut revisions); } let memo = Memo { diff --git a/src/key.rs b/src/key.rs index 8fd159520..2476a5e88 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,8 +1,11 @@ use core::fmt; use crate::{ - accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - ingredient::MaybeChangedAfter, zalsa::IngredientIndex, Database, Id, + accumulator::accumulated_map::InputAccumulatedValues, + cycle::CycleRecoveryStrategy, + ingredient::MaybeChangedAfter, + zalsa::{IngredientIndex, Zalsa}, + Database, Id, }; /// An integer that uniquely identifies a particular query instance within the @@ -33,18 +36,24 @@ impl OutputDependencyIndex { } } - pub(crate) fn remove_stale_output(&self, db: &dyn Database, executor: DatabaseKeyIndex) { - db.zalsa() + pub(crate) fn remove_stale_output( + &self, + zalsa: &Zalsa, + db: &dyn Database, + executor: DatabaseKeyIndex, + ) { + zalsa .lookup_ingredient(self.ingredient_index) .remove_stale_output(db, executor, self.key_index) } pub(crate) fn mark_validated_output( &self, + zalsa: &Zalsa, db: &dyn Database, database_key_index: DatabaseKeyIndex, ) { - db.zalsa() + zalsa .lookup_ingredient(self.ingredient_index) .mark_validated_output(db, database_key_index, self.key_index) } diff --git a/src/table/sync.rs b/src/table/sync.rs index 14fb6a69e..3f9179067 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -34,12 +34,12 @@ impl SyncTable { pub(crate) fn claim<'me>( &'me self, db: &'me dyn Database, + zalsa: &'me Zalsa, zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, ) -> Option> { let mut syncs = self.syncs.write(); - let zalsa = db.zalsa(); let thread_id = std::thread::current().id(); util::ensure_vec_len(&mut syncs, memo_ingredient_index.as_usize() + 1); diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index a82249196..77f72d436 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -608,7 +608,7 @@ where db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(db, executor); + stale_output.remove_stale_output(zalsa, db, executor); } }