Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor tt for easier maintainance #160

Merged
merged 1 commit into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 2 additions & 304 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#![allow(rustc::usage_of_ty_tykind)]
#![allow(unused_imports)]

use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree};
use rustc_target::abi::FieldsShape;

pub use self::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable};
Expand Down Expand Up @@ -75,6 +74,7 @@ pub use rustc_type_ir::ConstKind::{
};
pub use rustc_type_ir::*;

pub use self::typetree::*;
pub use self::binding::BindingMode;
pub use self::binding::BindingMode::*;
pub use self::closure::{
Expand Down Expand Up @@ -127,6 +127,7 @@ pub mod util;
pub mod visit;
pub mod vtable;
pub mod walk;
pub mod typetree;

mod adt;
mod assoc;
Expand Down Expand Up @@ -2721,306 +2722,3 @@ mod size_asserts {
// tidy-alphabetical-end
}

pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
let mut visited = vec![];
let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None);
let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty };
return TypeTree(vec![tt]);
}

use rustc_ast::expand::autodiff_attrs::DiffActivity;

// This function combines three tasks. To avoid traversing each type 3x, we combine them.
// 1. Create a TypeTree from a Ty. This is the main task.
// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM
// lowering. E.g. fat ptr are going to introduce an extra int.
// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an
// autodiff macro on top). Here we want to make sure that shadows are mutable internally.
// We know the outermost ref/ptr indirection is mutability - we generate it like that.
// We now have to make sure that inner ptr/ref are mutable too, or issue a warning.
// Not an error, becaues it only causes issues if they are actually read, which we don't check
// yet. We should add such analysis to relibably either issue an error or accept without warning.
// If there only were some reasearch to do that...
pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>, span: Option<Span>) -> FncTree {
if !fn_ty.is_fn() {
return FncTree { args: vec![], ret: TypeTree::new() };
}
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);

// If rustc compiles the unmodified primal, we know that this copy of the function
// also has correct lifetimes. We know that Enzyme won't free the shadow too early
// (or actually at all), so let's strip lifetimes when computing the layout.
// Recommended by compiler-errors:
// https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751
let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);

let mut new_activities = vec![];
let mut new_positions = vec![];
let mut visited = vec![];
let mut args = vec![];
for (i, ty) in x.inputs().iter().enumerate() {
// We care about safety checks, if an argument get's duplicated and we write into the
// shadow. That's equivalent to Duplicated or DuplicatedOnly.
let safety = if !da.is_empty() {
assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len());
// If we have Activities, we also have spans
assert!(span.is_some());
match da[i] {
DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true,
_ => false,
}
} else {
false
};

visited.clear();
if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
if ty.is_fn_ptr() {
unimplemented!("what to do whith fn ptr?");
}
let inner_ty = ty.builtin_deref(true).unwrap().ty;
if inner_ty.is_slice() {
// We know that the lenght will be passed as extra arg.
let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
args.push(TypeTree(vec![tt]));
let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
args.push(TypeTree(vec![i64_tt]));
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
let activity = match da[i] {
DiffActivity::DualOnly | DiffActivity::Dual |
DiffActivity::DuplicatedOnly | DiffActivity::Duplicated
=> DiffActivity::FakeActivitySize,
DiffActivity::Const => DiffActivity::Const,
_ => panic!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
trace!("ABI MATCHING!");
continue;
}
}
let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span);
args.push(arg_tt);
}

// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
for _ in 0..new_activities.len() {
let pos = new_positions.pop().unwrap();
let activity = new_activities.pop().unwrap();
da.insert(pos, activity);
}

visited.clear();
let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span);

FncTree { args, ret }
}

fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec<Ty<'a>>, span: Option<Span>) -> TypeTree {
if depth > 20 {
trace!("depth > 20 for ty: {}", &ty);
}
if visited.contains(&ty) {
// recursive type
trace!("recursive type: {}", &ty);
return TypeTree::new();
}
visited.push(ty);

if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
if ty.is_fn_ptr() {
unimplemented!("what to do whith fn ptr?");
}

let inner_ty_and_mut = ty.builtin_deref(true).unwrap();
let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut;
let inner_ty = inner_ty_and_mut.ty;

// Now account for inner mutability.
if !is_mut && depth > 0 && safety {
let ptr_ty: String = if ty.is_ref() {
"ref"
} else if ty.is_unsafe_ptr() {
"ptr"
} else {
assert!(ty.is_box());
"box"
}.to_string();

// If we have mutability, we also have a span
assert!(span.is_some());
let span = span.unwrap();

tcx.sess
.dcx()
.emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty});
}

//visited.push(inner_ty);
let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
visited.pop();
return TypeTree(vec![tt]);
}


if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() {
visited.pop();
return TypeTree::new();
}

if ty.is_scalar() {
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
} else if ty.is_floating_point() {
match ty {
x if x == tcx.types.f32 => (Kind::Float, 4),
x if x == tcx.types.f64 => (Kind::Double, 8),
_ => panic!("floatTy scalar that is neither f32 nor f64"),
}
} else {
panic!("scalar that is neither integral nor floating point");
};
visited.pop();
return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]);
}

let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty };

let layout = tcx.layout_of(param_env_and);
assert!(layout.is_ok());

let layout = layout.unwrap().layout;
let fields = layout.fields();
let max_size = layout.size();



if ty.is_adt() && !ty.is_simd() {
let adt_def = ty.ty_adt_def().unwrap();

if adt_def.is_struct() {
let (offsets, _memory_index) = match fields {
// Manuel TODO:
FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m),
FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later
FieldsShape::Union(_) => {return TypeTree::new();},
FieldsShape::Primitive => {return TypeTree::new();},
};

let substs = match ty.kind() {
Adt(_, subst_ref) => subst_ref,
_ => panic!(""),
};

let fields = adt_def.all_fields();
let fields = fields
.into_iter()
.zip(offsets.into_iter())
.filter_map(|(field, offset)| {
let field_ty: Ty<'_> = field.ty(tcx, substs);
let field_ty: Ty<'_> =
tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty);

if field_ty.is_phantom_data() {
return None;
}

//visited.push(field_ty);
let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0;

for c in &mut child {
if c.offset == -1 {
c.offset = offset.bytes() as isize
} else {
c.offset += offset.bytes() as isize;
}
}

Some(child)
})
.flatten()
.collect::<Vec<Type>>();

visited.pop();
let ret_tt = TypeTree(fields);
return ret_tt;
} else if adt_def.is_enum() {
// Enzyme can't represent enums, so let it figure it out itself, without seeeding
// typetree
//unimplemented!("adt that is an enum");
} else {
//let ty_name = tcx.def_path_debug_str(adt_def.did());
//tcx.sess.emit_fatal(UnsupportedUnion { ty_name });
}
}

if ty.is_simd() {
trace!("simd");
let (_size, inner_ty) = ty.simd_size_and_type(tcx);
//visited.push(inner_ty);
let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
//let tt = TypeTree(
// std::iter::repeat(subtt)
// .take(*count as usize)
// .enumerate()
// .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
// .flatten()
// .collect(),
//);
// TODO
visited.pop();
return TypeTree::new();
}

if ty.is_array() {
let (stride, count) = match fields {
FieldsShape::Array { stride: s, count: c } => (s, c),
_ => panic!(""),
};
let byte_stride = stride.bytes_usize();
let byte_max_size = max_size.bytes_usize();

assert!(byte_stride * *count as usize == byte_max_size);
if (*count as usize) == 0 {
return TypeTree::new();
}
let sub_ty = ty.builtin_index().unwrap();
//visited.push(sub_ty);
let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);

// calculate size of subtree
let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty };
let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize;
let tt = TypeTree(
std::iter::repeat(subtt)
.take(*count as usize)
.enumerate()
.map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
.flatten()
.collect(),
);

visited.pop();
return tt;
}

if ty.is_slice() {
let sub_ty = ty.builtin_index().unwrap();
//visited.push(sub_ty);
let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);

visited.pop();
return subtt;
}

visited.pop();
TypeTree::new()
}
Loading
Loading