Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: rework extensions interface #119

Merged
merged 9 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ delegate = "0.12.0"
petgraph = "0.6.5"
lazy_static = "1.4.0"
downcast-rs= "1.2.1"
strum = "0.26.3"

[dev-dependencies]
insta = "1.39.0"
Expand Down
264 changes: 127 additions & 137 deletions src/custom.rs
Original file line number Diff line number Diff line change
@@ -1,173 +1,163 @@
use std::{
any::TypeId,
collections::{HashMap, HashSet},
rc::Rc,
};

//! Provides an interface for extending `hugr-llvm` to emit [CustomType]s,
//! [CustomConst]s, and [ExtensionOp]s.
//!
//! [CustomType]: hugr::types::CustomType
//! [CustomConst]: hugr::ops::constant::CustomConst
//! [ExtensionOp]: hugr::ops::ExtensionOp
use std::rc::Rc;

use self::extension_op::{ExtensionOpFn, ExtensionOpMap};
use hugr::{
extension::ExtensionId,
ops::{constant::CustomConst, ExtensionOp},
types::CustomType,
extension::{simple_op::MakeOpDef, ExtensionId},
ops::{constant::CustomConst, ExtensionOp, OpName},
HugrView,
};

use anyhow::{anyhow, Result};
use inkwell::{types::BasicTypeEnum, values::BasicValueEnum};
use strum::IntoEnumIterator;
use types::CustomTypeKey;

use self::load_constant::{LoadConstantFn, LoadConstantsMap};
use self::types::LLVMCustomTypeFn;
use anyhow::Result;

use crate::{
emit::{func::EmitFuncContext, EmitOpArgs},
types::TypingSession,
types::TypeConverter,
};

pub mod extension_op;
pub mod load_constant;
pub mod types;

// TODO move these extension implementations to crate::extension
// https://github.com/CQCL/hugr-llvm/issues/121
pub mod conversions;
pub mod float;
pub mod int;
pub mod logic;
pub mod prelude;
pub mod rotation;

/// The extension point for lowering HUGR Extensions to LLVM.
pub trait CodegenExtension<H> {
/// The [ExtensionId] for which this extension will lower `ExtensionOp`s and
/// [CustomType]s.
///
/// Note that a [CodegenExtsMap] will only delegate to a single
/// `CodegenExtension` per [ExtensionId].
fn extension(&self) -> ExtensionId;

/// The [TypeId]s for which [dyn CustomConst](CustomConst)s should be passed
/// to [Self::load_constant].
///
/// Defaults to an empty set.
fn supported_consts(&self) -> HashSet<TypeId> {
Default::default()
}

/// Return the type of the given [CustomType], which will have an extension
/// that matches `Self`.
fn llvm_type<'c>(
&self,
context: &TypingSession<'c, H>,
hugr_type: &CustomType,
) -> Result<BasicTypeEnum<'c>>;

/// Return an emitter that will be asked to emit `ExtensionOp`s that have an
/// extension that matches `Self.`
fn emit_extension_op<'c>(
&self,
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()>;

/// Emit instructions to materialise `konst`. `konst` will have a [TypeId]
/// that matches `self.supported_consts`.
///
/// If the result is `Ok(None)`, [CodegenExtsMap] may try other
/// `CodegenExtension`s.
fn load_constant<'c>(
&self,
#[allow(unused)] context: &mut EmitFuncContext<'c, H>,
#[allow(unused)] konst: &dyn CustomConst,
) -> Result<Option<BasicValueEnum<'c>>> {
Ok(None)
}
/// A helper to register codegen extensions.
///
/// Types that implement this trait can be registered with a [CodegenExtsBuilder]
/// via [CodegenExtsBuilder::add_extension].
///
/// See [prelude::PreludeCodegenExtension] for an example.
pub trait CodegenExtension {
/// Implementers should add each of their handlers to `builder` and return the
/// resulting [CodegenExtsBuilder].
fn add_extension<'a, H: HugrView + 'a>(
self,
builder: CodegenExtsBuilder<'a, H>,
) -> CodegenExtsBuilder<'a, H>
where
Self: 'a;
}

/// A collection of [CodegenExtension]s.
/// A container for a collection of codegen callbacks as they are being
/// assembled.
///
/// Provides methods to delegate operations to appropriate contained
/// [CodegenExtension]s.
pub struct CodegenExtsMap<'a, H> {
supported_consts: HashMap<TypeId, HashSet<ExtensionId>>,
extensions: HashMap<ExtensionId, Box<dyn 'a + CodegenExtension<H>>>,
/// The callbacks are registered against several keys:
/// - [CustomType]s, with [CodegenExtsBuilder::custom_type]
/// - [CustomConst]s, with [CodegenExtsBuilder::custom_const]
/// - [ExtensionOp]s, with [CodegenExtsBuilder::extension_op]
///
/// Each callback may hold references older than `'a`.
///
/// Registering any callback silently replaces any other callback registered for
/// that same key.
///
/// [CustomType]: hugr::types::CustomType
#[derive(Default)]
pub struct CodegenExtsBuilder<'a, H> {
load_constant_handlers: LoadConstantsMap<'a, H>,
extension_op_handlers: ExtensionOpMap<'a, H>,
type_converter: TypeConverter<'a>,
}

impl<'c, H> CodegenExtsMap<'c, H> {
/// Create a new, empty, `CodegenExtsMap`.
pub fn new() -> Self {
Self {
supported_consts: Default::default(),
extensions: Default::default(),
}
impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
/// Forwards to [CodegenExtension::add_extension].
///
/// ```
/// use hugr_llvm::custom::{prelude::{PreludeCodegenExtension, DefaultPreludeCodegen}, CodegenExtsBuilder};
/// let ext = PreludeCodegenExtension::from(DefaultPreludeCodegen);
/// CodegenExtsBuilder::<hugr::Hugr>::default().add_extension(ext);
/// ```
pub fn add_extension(self, ext: impl CodegenExtension + 'a) -> Self {
ext.add_extension(self)
}

/// Consumes a `CodegenExtsMap` and returns a new one, with `ext`
/// incorporated.
pub fn add_cge(mut self, ext: impl 'c + CodegenExtension<H>) -> Self {
let extension = ext.extension();
for k in ext.supported_consts() {
self.supported_consts
.entry(k)
.or_default()
.insert(extension.clone());
}
self.extensions.insert(extension, Box::new(ext));
/// Register a callback to map a [CustomType] to a [BasicTypeEnum].
///
/// [CustomType]: hugr::types::CustomType
/// [BasicTypeEnum]: inkwell::types::BasicTypeEnum
pub fn custom_type(
mut self,
custom_type: CustomTypeKey,
handler: impl LLVMCustomTypeFn<'a>,
) -> Self {
self.type_converter.custom_type(custom_type, handler);
self
}

/// Returns the matching inner [CodegenExtension] if it exists.
pub fn get(&self, extension: &ExtensionId) -> Result<&dyn CodegenExtension<H>> {
let b = self
.extensions
.get(extension)
.ok_or(anyhow!("CodegenExtsMap: Unknown extension: {}", extension))?;
Ok(b.as_ref())
/// Register a callback to emit a [ExtensionOp], keyed by fully
/// qualified [OpName].
pub fn extension_op(
mut self,
extension: ExtensionId,
op: OpName,
handler: impl ExtensionOpFn<'a, H>,
) -> Self {
self.extension_op_handlers
.extension_op(extension, op, handler);
self
}

/// Return the type of the given [CustomType] by delegating to the
/// appropriate inner [CodegenExtension].
pub fn llvm_type(
&self,
ts: &TypingSession<'c, H>,
hugr_type: &CustomType,
) -> Result<BasicTypeEnum<'c>> {
self.get(hugr_type.extension())?.llvm_type(ts, hugr_type)
/// Register callbacks to emit [ExtensionOp]s that match the
/// definitions generated by `Op`s impl of [strum::IntoEnumIterator]>
pub fn simple_extension_op<Op: MakeOpDef + IntoEnumIterator>(
mut self,
handler: impl 'a
+ for<'c> Fn(
&mut EmitFuncContext<'c, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
Op,
) -> Result<()>,
) -> Self {
self.extension_op_handlers
.simple_extension_op::<Op>(handler);
self
}

/// Emit instructions for `args` by delegating to the appropriate inner
/// [CodegenExtension].
pub fn emit<'hugr>(
self: Rc<Self>,
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, 'hugr, ExtensionOp, H>,
) -> Result<()>
where
H: HugrView,
{
self.get(args.node().def().extension())?
.emit_extension_op(context, args)
/// Register a callback to materialise a constant implemented by `CC`.
pub fn custom_const<CC: CustomConst>(
mut self,
handler: impl LoadConstantFn<'a, H, CC>,
) -> Self {
self.load_constant_handlers.custom_const(handler);
self
}

/// Emit instructions to materialise `konst` by delegating to the
/// appropriate inner [CodegenExtension]s.
pub fn load_constant(
&self,
context: &mut EmitFuncContext<'c, H>,
konst: &dyn CustomConst,
) -> Result<BasicValueEnum<'c>> {
let type_id = konst.type_id();
self.supported_consts
.get(&type_id)
.into_iter()
.flatten()
.filter_map(|ext| {
let cge = self.extensions.get(ext).unwrap();
match cge.load_constant(context, konst) {
Err(e) => Some(Err(e)),
Ok(None) => None,
Ok(Some(v)) => Some(Ok(v)),
}
})
.next()
.unwrap_or(Err(anyhow!(
"No extension could load constant name: {} type_id: {type_id:?}",
konst.name()
)))
/// Consume `self` to return collections of callbacks for each of the
/// supported keys.`
pub fn finish(self) -> CodegenExtsMap<'a, H> {
CodegenExtsMap {
load_constant_handlers: Rc::new(self.load_constant_handlers),
extension_op_handlers: Rc::new(self.extension_op_handlers),
type_converter: Rc::new(self.type_converter),
}
}
}

impl<'c, H: HugrView> Default for CodegenExtsMap<'c, H> {
fn default() -> Self {
Self::new()
}
/// The result of [CodegenExtsBuilder::finish]. Users are expected to
/// deconstruct this type, and consume the fields independently.
/// We expect to add further collections at a later date, and so this type is
/// marked `non_exhaustive`
#[derive(Default)]
#[non_exhaustive]
pub struct CodegenExtsMap<'a, H> {
pub load_constant_handlers: Rc<LoadConstantsMap<'a, H>>,
pub extension_op_handlers: Rc<ExtensionOpMap<'a, H>>,
pub type_converter: Rc<TypeConverter<'a>>,
}
Loading