Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle #133429

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4234,6 +4234,7 @@ name = "rustc_monomorphize"
version = "0.0.0"
dependencies = [
"rustc_abi",
"rustc_ast",
"rustc_attr_parsing",
"rustc_data_structures",
"rustc_errors",
Expand All @@ -4243,6 +4244,7 @@ dependencies = [
"rustc_middle",
"rustc_session",
"rustc_span",
"rustc_symbol_mangling",
"rustc_target",
"serde",
"serde_json",
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub struct AutoDiffItem {
pub target: String,
pub attrs: AutoDiffAttrs,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
Expand Down Expand Up @@ -231,7 +232,7 @@ impl AutoDiffAttrs {
self.ret_activity == DiffActivity::ActiveOnly
}

pub fn error() -> Self {
pub const fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
ret_activity: DiffActivity::None,
Expand Down
14 changes: 7 additions & 7 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ fn generate_enzyme_call<'ll>(
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
let name = llvm::get_value_name(outer_fn);
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
ad_name.push_str(outer_fn_name.to_string().as_str());
let outer_fn_name = std::str::from_utf8(name).unwrap();
ad_name.push_str(outer_fn_name);

// Let us assume the user wrote the following function square:
//
Expand Down Expand Up @@ -255,14 +255,14 @@ fn generate_enzyme_call<'ll>(
// have no debug info to copy, which would then be ok.
trace!("no dbg info");
}

// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstBefore(entry, last_inst);
llvm::LLVMRustEraseInstFromParent(last_inst);
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);

if cx.val_ty(outer_fn) != cx.type_void() {
builder.ret(call);
} else {
if cx.val_ty(call) == cx.type_void() {
builder.ret_void();
} else {
builder.ret(call);
}

// Let's crash in case that we messed something up above and generated invalid IR.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ struct UsageSets<'tcx> {
/// Prepare sets of definitions that are relevant to deciding whether something
/// is an "unused function" for coverage purposes.
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
let MonoItemPartitions { all_mono_items, codegen_units } =
let MonoItemPartitions { all_mono_items, codegen_units, .. } =
tcx.collect_and_partition_mono_items(());

// Obtain a MIR body for each function participating in codegen, via an
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ use crate::llvm::Bool;
extern "C" {
// Enzyme
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value);
pub fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
pub fn LLVMRustEraseInstFromParent(V: &Value);
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
pub fn LLVMDumpModule(M: &Module);
pub fn LLVMDumpValue(V: &Value);
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;

pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$erro

codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering

codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto

codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty

codegen_ssa_cgu_not_recorded =
Expand Down
38 changes: 33 additions & 5 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
use std::{fs, io, mem, str, thread};

use rustc_ast::attr;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
use rustc_data_structures::jobserver::{self, Acquired};
use rustc_data_structures::memmap::Mmap;
Expand Down Expand Up @@ -40,7 +41,7 @@ use tracing::debug;
use super::link::{self, ensure_removed};
use super::lto::{self, SerializedModule};
use super::symbol_export::symbol_name_for_instance_in_crate;
use crate::errors::ErrorCreatingRemarkDir;
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
use crate::traits::*;
use crate::{
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
Expand Down Expand Up @@ -118,6 +119,7 @@ pub struct ModuleConfig {
pub merge_functions: bool,
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<config::AutoDiff>,
}

impl ModuleConfig {
Expand Down Expand Up @@ -266,6 +268,7 @@ impl ModuleConfig {

emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
}
}

Expand Down Expand Up @@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {

fn generate_lto_work<B: ExtraBackendMethods>(
cgcx: &CodegenContext<B>,
autodiff: Vec<AutoDiffItem>,
needs_fat_lto: Vec<FatLtoInput<B>>,
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
Expand All @@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(

if !needs_fat_lto.is_empty() {
assert!(needs_thin_lto.is_empty());
let module =
let mut module =
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
if cgcx.lto == Lto::Fat {
let config = cgcx.config(ModuleKind::Regular);
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
}
// We are adding a single work item, so the cost doesn't matter.
vec![(WorkItem::LTO(module), 0)]
} else {
if !autodiff.is_empty() {
let dcx = cgcx.create_dcx();
dcx.handle().emit_fatal(AutodiffWithoutLto {});
}
assert!(needs_fat_lto.is_empty());
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
.unwrap_or_else(|e| e.raise());
Expand Down Expand Up @@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
/// Sent from a backend worker thread.
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },

/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
AddAutoDiffItems(Vec<AutoDiffItem>),

/// The frontend has finished generating something (backend IR or a
/// post-LTO artifact) for a codegen unit, and it should be passed to the
/// backend. Sent from the main thread.
Expand Down Expand Up @@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(

// This is where we collect codegen units that have gone all the way
// through codegen and LLVM.
let mut autodiff_items = Vec::new();
let mut compiled_modules = vec![];
let mut compiled_allocator_module = None;
let mut needs_link = Vec::new();
Expand Down Expand Up @@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
let needs_thin_lto = mem::take(&mut needs_thin_lto);
let import_only_modules = mem::take(&mut lto_import_only_modules);

for (work, cost) in
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
{
for (work, cost) in generate_lto_work(
&cgcx,
autodiff_items.clone(),
needs_fat_lto,
needs_thin_lto,
import_only_modules,
) {
let insertion_index = work_items
.binary_search_by_key(&cost, |&(_, cost)| cost)
.unwrap_or_else(|e| e);
Expand Down Expand Up @@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
main_thread_state = MainThreadState::Idle;
}

Message::AddAutoDiffItems(mut items) => {
autodiff_items.append(&mut items);
}

Message::CodegenComplete => {
if codegen_state != Aborted {
codegen_state = Completed;
Expand Down Expand Up @@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
}

pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
}

pub(crate) fn check_for_errors(&self, sess: &Session) {
self.shared_emitter_main.check(sess, false);
}
Expand Down
10 changes: 8 additions & 2 deletions compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger
use rustc_middle::middle::exported_symbols::SymbolExportKind;
use rustc_middle::middle::{exported_symbols, lang_items};
use rustc_middle::mir::BinOp;
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem};
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
use rustc_middle::query::Providers;
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
Expand Down Expand Up @@ -619,7 +619,9 @@ pub fn codegen_crate<B: ExtraBackendMethods>(

// Run the monomorphization collector and partition the collected items into
// codegen units.
let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units;
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
tcx.collect_and_partition_mono_items(());
let autodiff_fncs = autodiff_items.to_vec();

// Force all codegen_unit queries so they are already either red or green
// when compile_codegen_unit accesses them. We are not able to re-execute
Expand Down Expand Up @@ -690,6 +692,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
);
}

if !autodiff_fncs.is_empty() {
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
}

// For better throughput during parallel processing by LLVM, we used to sort
// CGUs largest to smallest. This would lead to better thread utilization
// by, for example, preventing a large CGU from being processed last and
Expand Down
118 changes: 117 additions & 1 deletion compiler/rustc_codegen_ssa/src/codegen_attrs.rs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remark: this would really benefit from the attr handling rework cc #131808.

Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::str::FromStr;

use rustc_ast::attr::list_contains_name;
use rustc_ast::{MetaItemInner, attr};
use rustc_ast::expand::autodiff_attrs::{
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
};
use rustc_ast::{MetaItem, MetaItemInner, attr};
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::codes::*;
Expand All @@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
};
use rustc_middle::mir::mono::Linkage;
use rustc_middle::query::Providers;
use rustc_middle::span_bug;
use rustc_middle::ty::{self as ty, TyCtxt};
use rustc_session::parse::feature_err;
use rustc_session::{Session, lint};
Expand Down Expand Up @@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
}

// If our rustc version supports autodiff/enzyme, then we call our handler
// to check for any `#[rustc_autodiff(...)]` attributes.
if cfg!(llvm_enzyme) {
let ad = autodiff_attrs(tcx, did.into());
codegen_fn_attrs.autodiff_item = ad;
}

// When `no_builtins` is applied at the crate level, we should add the
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
Expand Down Expand Up @@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
}
}

/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);

let attrs =
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();

// check for exactly one autodiff attribute on placeholder functions.
// There should only be one, since we generate a new placeholder per ad macro.
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
// looks strange e.g. under cargo-expand.
let attr = match &attrs[..] {
[] => return None,
[attr] => attr,
// These two attributes are the same and unfortunately duplicated due to a previous bug.
[attr, _attr2] => attr,
_ => {
//FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
//branch above.
span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
}
};

let list = attr.meta_item_list().unwrap_or_default();

// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
if list.is_empty() {
return Some(AutoDiffAttrs::source());
}

let [mode, input_activities @ .., ret_activity] = &list[..] else {
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
};
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
p1.segments.first().unwrap().ident
} else {
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
};

// parse mode
let mode = match mode.as_str() {
"Forward" => DiffMode::Forward,
"Reverse" => DiffMode::Reverse,
"ForwardFirst" => DiffMode::ForwardFirst,
"ReverseFirst" => DiffMode::ReverseFirst,
_ => {
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
}
};

// First read the ret symbol from the attribute
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
p1.segments.first().unwrap().ident
} else {
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
};

// Then parse it into an actual DiffActivity
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
span_bug!(ret_symbol.span, "invalid return activity");
};

// Now parse all the intermediate (input) activities
let mut arg_activities: Vec<DiffActivity> = vec![];
for arg in input_activities {
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
match p2.segments.first() {
Some(x) => x.ident,
None => {
span_bug!(
arg.span(),
"rustc_autodiff attribute must contain the input activity"
);
}
}
} else {
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
};

match DiffActivity::from_str(arg_symbol.as_str()) {
Ok(arg_activity) => arg_activities.push(arg_activity),
Err(_) => {
span_bug!(arg_symbol.span, "invalid input activity");
}
}
}

for &input in &arg_activities {
if !valid_input_activity(mode, input) {
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
}
}
if !valid_ret_activity(mode, ret_activity) {
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
}

Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
}

pub(crate) fn provide(providers: &mut Providers) {
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
}
Loading
Loading