Skip to content

Commit

Permalink
fix fwd void ret case
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Apr 3, 2024
1 parent ad11ec9 commit a8048a4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
29 changes: 9 additions & 20 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {

unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) {
dbg!("size_positions: {:?}", size_positions);

// first, remove all calls from fnc
let bb = LLVMGetFirstBasicBlock(tgt);
let br = LLVMRustGetTerminator(bb);
Expand Down Expand Up @@ -843,7 +843,6 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,

// Now clean up placeholder code.
LLVMRustEraseInstBefore(bb, last_inst);
//dbg!(&tgt);

let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src));
let void_type = LLVMVoidTypeInContext(llcx);
Expand All @@ -865,6 +864,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
let _fnc_ok =
LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction);
}

unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
// The names are mangled and their ending changes based on a hash, so just take whichever.
let mut f = LLVMGetFirstFunction(llmod);
Expand Down Expand Up @@ -922,21 +922,7 @@ unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::C
LLVMRustSetLinkage(global_var, Linkage::PrivateLinkage);
LLVMSetInitializer(global_var, struct_initializer);

//let msg_global_name = "ad_safety_msg".to_string();
//let cmsg_global_name = CString::new(msg_global_name).unwrap();
//let msg = "autodiff safety check failed!";
//let cmsg = CString::new(msg).unwrap();
//let msg_len = msg.len();
//let i8_array_type = llvm::LLVMRustArrayType(llvm::LLVMInt8TypeInContext(llcx), msg_len as u64);
//let global_type = llvm::LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0);
//let string_const_val = llvm::LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const c_char, msg_len as u32, 0);
//let initializer = llvm::LLVMConstStructInContext(llcx, [string_const_val].as_mut_ptr(), 1, 0);
//let global = llvm::LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const c_char);
//llvm::LLVMRustSetLinkage(global, llvm::Linkage::PrivateLinkage);
//llvm::LLVMSetInitializer(global, initializer);
//llvm::LLVMSetUnnamedAddress(global, llvm::UnnamedAddr::Global);

global_var
global_var
}

// As unsafe as it can be.
Expand Down Expand Up @@ -1027,6 +1013,10 @@ pub(crate) unsafe fn enzyme_ad(
DiffMode::ReverseFirst => DiffMode::Reverse,
_ => unreachable!(),
};

let void_type = LLVMVoidTypeInContext(llcx);
let return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc));
let void_ret = void_type == return_type;
let mut tmp = match mode {
DiffMode::Forward => enzyme_rust_forward_diff(
logic_ref,
Expand All @@ -1036,6 +1026,7 @@ pub(crate) unsafe fn enzyme_ad(
ret_activity,
input_tts,
output_tt,
void_ret,
),
DiffMode::Reverse => enzyme_rust_reverse_diff(
logic_ref,
Expand All @@ -1053,7 +1044,6 @@ pub(crate) unsafe fn enzyme_ad(

let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));

let void_type = LLVMVoidTypeInContext(llcx);
let rev_mode = item.attrs.mode == DiffMode::Reverse;
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions);
// TODO: implement drop for wrapper type?
Expand Down Expand Up @@ -1142,8 +1132,7 @@ pub(crate) unsafe fn differentiate(
assert!(res.is_ok());
}
for item in higher_order_items {
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt);
//let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref);
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref);
assert!(res.is_ok());
}

Expand Down
23 changes: 16 additions & 7 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
ret_diffactivity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
void_ret: bool,
) -> (&Value, Vec<usize>) {
let ret_activity = cdiffe_from(ret_diffactivity);
assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF);
Expand All @@ -864,12 +865,18 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
input_activity.push(act);
}

let ret_primary_ret = match ret_activity {
CDIFFE_TYPE::DFT_CONSTANT => true,
CDIFFE_TYPE::DFT_DUP_ARG => true,
CDIFFE_TYPE::DFT_DUP_NONEED => false,
_ => panic!("Implementation error in enzyme_rust_forward_diff."),
// if we have void ret, this must be false;
let ret_primary_ret = if void_ret {
false
} else {
match ret_activity {
CDIFFE_TYPE::DFT_CONSTANT => true,
CDIFFE_TYPE::DFT_DUP_ARG => true,
CDIFFE_TYPE::DFT_DUP_NONEED => false,
_ => panic!("Implementation error in enzyme_rust_forward_diff."),
}
};
trace!("ret_primary_ret: {}", &ret_primary_ret);

let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];
Expand Down Expand Up @@ -897,6 +904,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
for i in &input_activity {
trace!("input_activity i: {}", &i);
}
trace!("before calling Enzyme");
let res = EnzymeCreateForwardDiff(
logic_ref, // Logic
std::ptr::null(),
Expand All @@ -916,7 +924,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
args_uncacheable.len(), // uncacheable arguments
std::ptr::null_mut(), // write augmented function to this
);
dbg!(res);
trace!("after calling Enzyme");
(res, vec![])
}

Expand Down Expand Up @@ -981,6 +989,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
for i in &input_activity {
trace!("input_activity i: {}", &i);
}
trace!("before calling Enzyme");
let res = EnzymeCreatePrimalAndGradient(
logic_ref, // Logic
std::ptr::null(),
Expand All @@ -1003,7 +1012,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
std::ptr::null_mut(), // write augmented function to this
0,
);
dbg!(res);
trace!("after calling Enzyme");
(res, primal_sizes)
}

Expand Down

0 comments on commit a8048a4

Please sign in to comment.