Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix nontrivial input port size support #133

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions src/bartiq/compilation/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from dataclasses import replace
from typing import Callable, TypeVar

from .._routine import Constraint, PortDirection, Resource, ResourceType, Routine
from bartiq.errors import BartiqPrecompilationError

from .._routine import Constraint, Port, PortDirection, Resource, ResourceType, Routine
from ..symbolics.backend import SymbolicBackend, TExpr

T = TypeVar("T")
Expand Down Expand Up @@ -104,24 +106,48 @@ def _introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T])
additional_local_variables: dict[str, TExpr[T]] = {}
new_input_params: list[str] = []
additional_constraints: list[Constraint[T]] = []
for port in routine.ports.values():
if port.direction == PortDirection.output:
new_ports[port.name] = port
else:
new_variable_name = f"#{port.name}"
new_variable = backend.as_expression(new_variable_name)
if (size := backend.serialize(port.size)) != new_variable_name and backend.is_single_parameter(port.size):
if size not in additional_local_variables:
additional_local_variables[size] = new_variable
else:
additional_constraints.append(Constraint(new_variable, additional_local_variables[size]))
elif backend.is_constant_int(port.size):
additional_constraints.append(Constraint(new_variable, port.size))
new_ports[port.name] = replace(port, size=new_variable)
new_input_params.append(new_variable_name)

# We sort ports so that single-parameter ones are preceeding non-trivial ones.
# This is because for ports with non-trivial sizes we need to be able to
# verify if all the free symbols are properly defined.
def _sort_key(port: Port[T]) -> tuple[bool, str]:
return (not backend.is_single_parameter(port.size), port.name)

# We only process non-output ports, as only they can introduce new input params and local
# variables.
non_output_ports = routine.filter_ports((PortDirection.input, PortDirection.through)).values()
for port in sorted(non_output_ports, key=_sort_key):
new_variable_name = f"#{port.name}"
new_variable = backend.as_expression(new_variable_name)
if port.size != new_variable and backend.is_single_parameter(port.size):
if (size := backend.serialize(port.size)) not in additional_local_variables:
additional_local_variables[size] = new_variable
else:
additional_constraints.append(Constraint(new_variable, additional_local_variables[size]))
elif backend.is_constant_int(port.size):
additional_constraints.append(Constraint(new_variable, port.size))
elif not backend.is_single_parameter(port.size):
missing_symbols = [
symbol
for symbol in backend.free_symbols_in(port.size)
if (
symbol not in routine.input_params
and symbol not in routine.local_variables
and symbol not in additional_local_variables
)
]
if missing_symbols:
raise BartiqPrecompilationError(
f"Size of the port {port.name} depends on symbols {sorted(missing_symbols)} which are undefined."
)
new_size = backend.substitute_all(port.size, additional_local_variables)
new_ports[port.name] = replace(port, size=new_size)
additional_constraints.append(Constraint(new_variable, new_size))
new_ports[port.name] = replace(port, size=new_variable)
new_input_params.append(new_variable_name)
return replace(
routine,
ports=new_ports,
ports={**new_ports, **routine.filter_ports((PortDirection.output,))},
input_params=tuple([*routine.input_params, *new_input_params]),
local_variables={**routine.local_variables, **additional_local_variables},
constraints=tuple([*routine.constraints, *additional_constraints]),
Expand All @@ -138,7 +164,6 @@ def introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T]) -
Returns:
A routine with the extra variables representing port sizes.
"""

return replace(
routine, children={name: _introduce_port_variables(child, backend) for name, child in routine.children.items()}
)
Expand Down
81 changes: 80 additions & 1 deletion tests/compilation/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from bartiq import compile_routine
from bartiq.compilation.preprocessing import introduce_port_variables
from bartiq.errors import BartiqCompilationError
from bartiq.errors import BartiqCompilationError, BartiqPrecompilationError


def load_compile_test_data():
Expand Down Expand Up @@ -141,3 +141,82 @@ def test_compile_errors(routine, expected_error, backend):
compile_routine(
routine, preprocessing_stages=[introduce_port_variables], backend=backend, skip_verification=True
)


@pytest.mark.parametrize(
"routine, expected_lhs, expected_rhs",
[
(
{
"name": "root",
"type": "dummy",
"children": [
{
"name": "a",
"type": "dummy",
"ports": [
{"name": "in_0", "direction": "input", "size": "N"},
{"name": "in_1", "direction": "input", "size": "2 ** N"},
],
}
],
"ports": [
{"name": "in_0", "size": "K", "direction": "input"},
{"name": "in_1", "size": "K", "direction": "input"},
],
"connections": ["in_0 -> a.in_0", "in_1 -> a.in_1"],
},
"K",
"2 ** K",
),
(
{
"name": "root",
"type": "dummy",
"children": [
{
"name": "a",
"type": "dummy",
"ports": [
{"name": "in_0", "direction": "input", "size": "N"},
{"name": "in_1", "direction": "input", "size": "f(g(N)) + N + 1"},
],
}
],
"ports": [
{"name": "in_0", "size": "K", "direction": "input"},
{"name": "in_1", "size": "K", "direction": "input"},
],
"connections": ["in_0 -> a.in_0", "in_1 -> a.in_1"],
},
"K",
"f(g(K)) + K + 1",
),
],
)
def test_compilation_introduces_constraints_stemming_from_relation_between_port_sizes(
routine, expected_lhs, expected_rhs, backend
):

compiled_routine = compile_routine(routine, backend=backend).routine

constraint = compiled_routine.children["a"].constraints[0]
assert constraint.lhs == backend.as_expression(expected_lhs)
assert constraint.rhs == backend.as_expression(expected_rhs)


def test_compilation_fails_if_input_ports_has_size_depending_on_undefined_variable(backend):
routine = {
"name": "root",
"type": "dummy",
"children": [
{"name": "a", "type": "dummy", "ports": [{"name": "in_0", "direction": "input", "size": "N + M"}]}
],
"ports": [{"name": "in_0", "direction": "input", "size": "K"}],
"connections": ["in_0 -> a.in_0"],
}

with pytest.raises(
BartiqPrecompilationError, match=r"Size of the port in_0 depends on symbols \['M', 'N'\] which are undefined."
):
compile_routine(routine, backend=backend)
Loading