Skip to content

Commit

Permalink
Get rid of constrain and solve steps
Browse files Browse the repository at this point in the history
  • Loading branch information
Veykril committed Dec 28, 2024
1 parent 0e50c3c commit d66a337
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 95 deletions.
8 changes: 0 additions & 8 deletions src/tools/rust-analyzer/crates/hir-ty/src/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,6 @@ impl Generics {
self.params.len()
}

pub(crate) fn len_self_lifetimes(&self) -> usize {
self.params.len_lifetimes()
}

pub(crate) fn has_trait_self(&self) -> bool {
self.params.trait_self_param().is_some()
}

/// (parent total, self param, type params, const params, impl trait list, lifetimes)
pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) {
let mut self_param = false;
Expand Down
131 changes: 44 additions & 87 deletions src/tools/rust-analyzer/crates/hir-ty/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,9 @@ pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Ar
if count == 0 {
return None;
}
let mut ctxt = Context {
def,
has_trait_self: generics.parent_generics().map_or(false, |it| it.has_trait_self()),
len_self: generics.len_self(),
len_self_lifetimes: generics.len_self_lifetimes(),
generics,
constraints: Vec::new(),
db,
};
let variances = Context { generics, variances: vec![Variance::Bivariant; count], db }.solve();

ctxt.build_constraints_for_item();
let res = ctxt.solve();
res.is_empty().not().then(|| Arc::from_iter(res))
variances.is_empty().not().then(|| Arc::from_iter(variances))
}

pub(crate) fn variances_of_cycle(
Expand Down Expand Up @@ -172,25 +162,14 @@ struct InferredIndex(usize);

struct Context<'db> {
db: &'db dyn HirDatabase,
def: GenericDefId,
has_trait_self: bool,
len_self: usize,
len_self_lifetimes: usize,
generics: Generics,
constraints: Vec<Constraint>,
}

/// Declares that the variable `decl_id` appears in a location with
/// variance `variance`.
#[derive(Clone)]
struct Constraint {
inferred: InferredIndex,
variance: Variance,
variances: Vec<Variance>,
}

impl Context<'_> {
fn build_constraints_for_item(&mut self) {
match self.def {
fn solve(mut self) -> Vec<Variance> {
tracing::debug!("solve(generics={:?})", self.generics);
match self.generics.def() {
GenericDefId::AdtId(adt) => {
let db = self.db;
let mut add_constraints_from_variant = |variant| {
Expand Down Expand Up @@ -225,6 +204,26 @@ impl Context<'_> {
}
_ => {}
}
let mut variances = self.variances;

// Const parameters are always invariant.
// Make all const parameters invariant.
for (idx, param) in self.generics.iter_id().enumerate() {
if let GenericParamId::ConstParamId(_) = param {
variances[idx] = Variance::Invariant;
}
}

// Functions are permitted to have unused generic parameters: make those invariant.
if let GenericDefId::FunctionId(_) = self.generics.def() {
for variance in &mut variances {
if *variance == Variance::Bivariant {
*variance = Variance::Invariant;
}
}
}

variances
}

fn contravariant(&mut self, variance: Variance) -> Variance {
Expand Down Expand Up @@ -353,14 +352,8 @@ impl Context<'_> {
// Chalk has no params, so use placeholders for now?
TyKind::Placeholder(index) => {
let idx = crate::from_placeholder_idx(self.db, *index);
let index = idx.local_id.into_raw().into_u32() as usize + self.len_self_lifetimes;
let inferred = if idx.parent == self.def {
InferredIndex(self.has_trait_self as usize + index)
} else {
InferredIndex(self.len_self + index)
};
tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance);
self.constraints.push(Constraint { inferred, variance });
let inferred = InferredIndex(self.generics.type_or_const_param_idx(idx).unwrap());
self.constrain(inferred, variance);
}
TyKind::Function(f) => {
self.add_constraints_from_sig(f, variance);
Expand Down Expand Up @@ -396,7 +389,7 @@ impl Context<'_> {
if args.is_empty() {
return;
}
if def_id == self.def {
if def_id == self.generics.def() {
// HACK: Workaround for the trivial cycle salsa case (see
// recursive_one_bivariant_more_non_bivariant_params test)
let variance_i = variance.xform(Variance::Bivariant);
Expand Down Expand Up @@ -463,18 +456,17 @@ impl Context<'_> {
/// Adds constraints appropriate for a region appearing in a
/// context with ambient variance `variance`
fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) {
tracing::debug!(
"add_constraints_from_region(region={:?}, variance={:?})",
region,
variance
);
match region.data(Interner) {
// FIXME: chalk has no params?
LifetimeData::Placeholder(index) => {
let idx = crate::lt_from_placeholder_idx(self.db, *index);
let index = idx.local_id.into_raw().into_u32() as usize;
let inferred = if idx.parent == self.def {
InferredIndex(index)
} else {
InferredIndex(self.has_trait_self as usize + self.len_self + index)
};
tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance);
self.constraints.push(Constraint { inferred, variance: variance.clone() });
let inferred = InferredIndex(self.generics.lifetime_idx(idx).unwrap());
self.constrain(inferred, variance);
}
LifetimeData::Static => {}

Expand Down Expand Up @@ -513,50 +505,15 @@ impl Context<'_> {
}
}
}
}

impl Context<'_> {
fn solve(self) -> Vec<Variance> {
let mut solutions = vec![Variance::Bivariant; self.generics.len()];
// Propagate constraints until a fixed point is reached. Note
// that the maximum number of iterations is 2C where C is the
// number of constraints (each variable can change values at most
// twice). Since number of constraints is linear in size of the
// input, so is the inference process.
let mut changed = true;
while changed {
changed = false;

for constraint in &self.constraints {
let &Constraint { inferred, variance } = constraint;
let InferredIndex(inferred) = inferred;
let old_value = solutions[inferred];
let new_value = variance.glb(old_value);
if old_value != new_value {
solutions[inferred] = new_value;
changed = true;
}
}
}

// Const parameters are always invariant.
// Make all const parameters invariant.
for (idx, param) in self.generics.iter_id().enumerate() {
if let GenericParamId::ConstParamId(_) = param {
solutions[idx] = Variance::Invariant;
}
}

// Functions are permitted to have unused generic parameters: make those invariant.
if let GenericDefId::FunctionId(_) = self.def {
for variance in &mut solutions {
if *variance == Variance::Bivariant {
*variance = Variance::Invariant;
}
}
}

solutions
fn constrain(&mut self, inferred: InferredIndex, variance: Variance) {
tracing::debug!(
"constrain(index={:?}, variance={:?}, to={:?})",
inferred,
self.variances[inferred.0],
variance
);
self.variances[inferred.0] = self.variances[inferred.0].glb(variance);
}
}

Expand Down

0 comments on commit d66a337

Please sign in to comment.