From bc63bd57e4b833f41973f7a77666db8a863df1f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Thu, 24 Oct 2024 11:42:02 +0200 Subject: [PATCH 1/6] Add testcase validating constraints introduced by relationship between port sizes --- tests/compilation/test_compile.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/compilation/test_compile.py b/tests/compilation/test_compile.py index afb086a..7851d36 100644 --- a/tests/compilation/test_compile.py +++ b/tests/compilation/test_compile.py @@ -141,3 +141,31 @@ def test_compile_errors(routine, expected_error, backend): compile_routine( routine, preprocessing_stages=[introduce_port_variables], backend=backend, skip_verification=True ) + + +def test_compilation_introduces_constraints_stemming_from_relation_between_port_sizes(backend): + routine = RoutineV1( + name="root", + type="dummy", + children=[ + { + "name": "a", + "type": "dummy", + "ports": [ + {"name": "in_0", "direction": "input", "size": "N"}, + {"name": "in_1", "direction": "input", "size": "2 ** N"}, + ], + } + ], + ports=[ + {"name": "in_0", "size": "K", "direction": "input"}, + {"name": "in_1", "size": "K", "direction": "input"}, + ], + connections=["in_0 -> a.in_0", "in_1 -> a.in_1"], + ) + + compiled_routine = compile_routine(routine, backend=backend).routine + + constraint = compiled_routine.children["a"].constraints[0] + assert constraint.lhs == backend.as_expression("K") + assert constraint.rhs == backend.as_expression("2 ** K") From 6c66223772a6e0fdf1eb3e0600ce3cdb81548808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Thu, 24 Oct 2024 12:28:42 +0200 Subject: [PATCH 2/6] Implement initial logic for handling non-trivial port sizes dependent on other ports --- src/bartiq/compilation/preprocessing.py | 27 ++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/src/bartiq/compilation/preprocessing.py b/src/bartiq/compilation/preprocessing.py index 4880f26..0355999 100644 --- a/src/bartiq/compilation/preprocessing.py +++ b/src/bartiq/compilation/preprocessing.py @@ -2,7 +2,9 @@ from dataclasses import replace from typing import Callable, TypeVar -from .._routine import Constraint, PortDirection, Resource, ResourceType, Routine +from bartiq.errors import BartiqPrecompilationError + +from .._routine import Constraint, Port, PortDirection, Resource, ResourceType, Routine from ..symbolics.backend import SymbolicBackend, TExpr T = TypeVar("T") @@ -104,7 +106,14 @@ def _introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T]) additional_local_variables: dict[str, TExpr[T]] = {} new_input_params: list[str] = [] additional_constraints: list[Constraint[T]] = [] - for port in routine.ports.values(): + + # We sort ports so that single-parameter ones are preceeding non-trivial ones. + # This is because for ports with non-trivial sizes we need to be able to + # verify if all the free symbols are properly defined. + def _sort_key(port: Port[T]) -> tuple[bool, str]: + return (not backend.is_single_parameter(port.size), port.name) + + for port in sorted(routine.ports.values(), key=_sort_key): if port.direction == PortDirection.output: new_ports[port.name] = port else: @@ -117,6 +126,19 @@ def _introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T]) additional_constraints.append(Constraint(new_variable, additional_local_variables[size])) elif backend.is_constant_int(port.size): additional_constraints.append(Constraint(new_variable, port.size)) + elif not backend.is_single_parameter(port.size): + for symbol in backend.free_symbols_in(port.size): + if ( + symbol not in routine.input_params + and symbol not in routine.local_variables + and symbol not in additional_local_variables + ): + raise BartiqPrecompilationError( + f"Size of the port {port.name} depends on symbol {symbol} which is undefined." + ) + new_size = backend.substitute_all(port.size, additional_local_variables) + new_ports[port.name] = replace(port, size=new_size) + additional_constraints.append(Constraint(new_variable, new_size)) new_ports[port.name] = replace(port, size=new_variable) new_input_params.append(new_variable_name) return replace( @@ -138,7 +160,6 @@ def introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T]) - Returns: A routine with the extra variables representing port sizes. """ - return replace( routine, children={name: _introduce_port_variables(child, backend) for name, child in routine.children.items()} ) From d500deb59df30fbb83f8f12caddb79c8152d8b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Thu, 24 Oct 2024 12:51:35 +0200 Subject: [PATCH 3/6] Slightly simplify iteration --- src/bartiq/compilation/preprocessing.py | 58 ++++++++++++------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/bartiq/compilation/preprocessing.py b/src/bartiq/compilation/preprocessing.py index 0355999..1e5f16f 100644 --- a/src/bartiq/compilation/preprocessing.py +++ b/src/bartiq/compilation/preprocessing.py @@ -113,37 +113,37 @@ def _introduce_port_variables(routine: Routine[T], backend: SymbolicBackend[T]) def _sort_key(port: Port[T]) -> tuple[bool, str]: return (not backend.is_single_parameter(port.size), port.name) - for port in sorted(routine.ports.values(), key=_sort_key): - if port.direction == PortDirection.output: - new_ports[port.name] = port - else: - new_variable_name = f"#{port.name}" - new_variable = backend.as_expression(new_variable_name) - if (size := backend.serialize(port.size)) != new_variable_name and backend.is_single_parameter(port.size): - if size not in additional_local_variables: - additional_local_variables[size] = new_variable - else: - additional_constraints.append(Constraint(new_variable, additional_local_variables[size])) - elif backend.is_constant_int(port.size): - additional_constraints.append(Constraint(new_variable, port.size)) - elif not backend.is_single_parameter(port.size): - for symbol in backend.free_symbols_in(port.size): - if ( - symbol not in routine.input_params - and symbol not in routine.local_variables - and symbol not in additional_local_variables - ): - raise BartiqPrecompilationError( - f"Size of the port {port.name} depends on symbol {symbol} which is undefined." - ) - new_size = backend.substitute_all(port.size, additional_local_variables) - new_ports[port.name] = replace(port, size=new_size) - additional_constraints.append(Constraint(new_variable, new_size)) - new_ports[port.name] = replace(port, size=new_variable) - new_input_params.append(new_variable_name) + # We only process non-output ports, as only they can introduce new input params and local + # variables. + non_output_ports = routine.filter_ports((PortDirection.input, PortDirection.through)).values() + for port in sorted(non_output_ports, key=_sort_key): + new_variable_name = f"#{port.name}" + new_variable = backend.as_expression(new_variable_name) + if port.size != new_variable and backend.is_single_parameter(port.size): + if (size := backend.serialize(port.size)) not in additional_local_variables: + additional_local_variables[size] = new_variable + else: + additional_constraints.append(Constraint(new_variable, additional_local_variables[size])) + elif backend.is_constant_int(port.size): + additional_constraints.append(Constraint(new_variable, port.size)) + elif not backend.is_single_parameter(port.size): + for symbol in backend.free_symbols_in(port.size): + if ( + symbol not in routine.input_params + and symbol not in routine.local_variables + and symbol not in additional_local_variables + ): + raise BartiqPrecompilationError( + f"Size of the port {port.name} depends on symbol {symbol} which is undefined." + ) + new_size = backend.substitute_all(port.size, additional_local_variables) + new_ports[port.name] = replace(port, size=new_size) + additional_constraints.append(Constraint(new_variable, new_size)) + new_ports[port.name] = replace(port, size=new_variable) + new_input_params.append(new_variable_name) return replace( routine, - ports=new_ports, + ports={**new_ports, **routine.filter_ports((PortDirection.output,))}, input_params=tuple([*routine.input_params, *new_input_params]), local_variables={**routine.local_variables, **additional_local_variables}, constraints=tuple([*routine.constraints, *additional_constraints]), From 0da454c6a91fc81e2e79abdea4c2e40c18f7dce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Thu, 24 Oct 2024 13:01:24 +0200 Subject: [PATCH 4/6] Improve handling of missing symbols --- src/bartiq/compilation/preprocessing.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/bartiq/compilation/preprocessing.py b/src/bartiq/compilation/preprocessing.py index 1e5f16f..de434ba 100644 --- a/src/bartiq/compilation/preprocessing.py +++ b/src/bartiq/compilation/preprocessing.py @@ -127,15 +127,19 @@ def _sort_key(port: Port[T]) -> tuple[bool, str]: elif backend.is_constant_int(port.size): additional_constraints.append(Constraint(new_variable, port.size)) elif not backend.is_single_parameter(port.size): - for symbol in backend.free_symbols_in(port.size): + missing_symbols = [ + symbol + for symbol in backend.free_symbols_in(port.size) if ( symbol not in routine.input_params and symbol not in routine.local_variables and symbol not in additional_local_variables - ): - raise BartiqPrecompilationError( - f"Size of the port {port.name} depends on symbol {symbol} which is undefined." - ) + ) + ] + if missing_symbols: + raise BartiqPrecompilationError( + f"Size of the port {port.name} depends on symbols {missing_symbols} which are undefined." + ) new_size = backend.substitute_all(port.size, additional_local_variables) new_ports[port.name] = replace(port, size=new_size) additional_constraints.append(Constraint(new_variable, new_size)) From f411ef76b543f6310542f0ae5c75365c3ecb85bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Thu, 24 Oct 2024 15:25:05 +0200 Subject: [PATCH 5/6] Add missing testcase --- src/bartiq/compilation/preprocessing.py | 2 +- tests/compilation/test_compile.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/bartiq/compilation/preprocessing.py b/src/bartiq/compilation/preprocessing.py index de434ba..022fa8f 100644 --- a/src/bartiq/compilation/preprocessing.py +++ b/src/bartiq/compilation/preprocessing.py @@ -138,7 +138,7 @@ def _sort_key(port: Port[T]) -> tuple[bool, str]: ] if missing_symbols: raise BartiqPrecompilationError( - f"Size of the port {port.name} depends on symbols {missing_symbols} which are undefined." + f"Size of the port {port.name} depends on symbols {sorted(missing_symbols)} which are undefined." ) new_size = backend.substitute_all(port.size, additional_local_variables) new_ports[port.name] = replace(port, size=new_size) diff --git a/tests/compilation/test_compile.py b/tests/compilation/test_compile.py index 7851d36..3c6ae4c 100644 --- a/tests/compilation/test_compile.py +++ b/tests/compilation/test_compile.py @@ -22,7 +22,7 @@ from bartiq import compile_routine from bartiq.compilation.preprocessing import introduce_port_variables -from bartiq.errors import BartiqCompilationError +from bartiq.errors import BartiqCompilationError, BartiqPrecompilationError def load_compile_test_data(): @@ -169,3 +169,20 @@ def test_compilation_introduces_constraints_stemming_from_relation_between_port_ constraint = compiled_routine.children["a"].constraints[0] assert constraint.lhs == backend.as_expression("K") assert constraint.rhs == backend.as_expression("2 ** K") + + +def test_compilation_fails_if_input_ports_has_size_depending_on_undefined_variable(backend): + routine = { + "name": "root", + "type": "dummy", + "children": [ + {"name": "a", "type": "dummy", "ports": [{"name": "in_0", "direction": "input", "size": "N + M"}]} + ], + "ports": [{"name": "in_0", "direction": "input", "size": "K"}], + "connections": ["in_0 -> a.in_0"], + } + + with pytest.raises( + BartiqPrecompilationError, match=r"Size of the port in_0 depends on symbols \['M', 'N'\] which are undefined." + ): + compile_routine(routine, backend=backend) From f99f6082c03e17ebee3b048e6fd2d771b81105f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Thu, 24 Oct 2024 16:12:21 +0200 Subject: [PATCH 6/6] Add testcase requested by the reviewer --- tests/compilation/test_compile.py | 70 +++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/tests/compilation/test_compile.py b/tests/compilation/test_compile.py index 3c6ae4c..4863f98 100644 --- a/tests/compilation/test_compile.py +++ b/tests/compilation/test_compile.py @@ -143,32 +143,66 @@ def test_compile_errors(routine, expected_error, backend): ) -def test_compilation_introduces_constraints_stemming_from_relation_between_port_sizes(backend): - routine = RoutineV1( - name="root", - type="dummy", - children=[ +@pytest.mark.parametrize( + "routine, expected_lhs, expected_rhs", + [ + ( { - "name": "a", + "name": "root", "type": "dummy", + "children": [ + { + "name": "a", + "type": "dummy", + "ports": [ + {"name": "in_0", "direction": "input", "size": "N"}, + {"name": "in_1", "direction": "input", "size": "2 ** N"}, + ], + } + ], "ports": [ - {"name": "in_0", "direction": "input", "size": "N"}, - {"name": "in_1", "direction": "input", "size": "2 ** N"}, + {"name": "in_0", "size": "K", "direction": "input"}, + {"name": "in_1", "size": "K", "direction": "input"}, ], - } - ], - ports=[ - {"name": "in_0", "size": "K", "direction": "input"}, - {"name": "in_1", "size": "K", "direction": "input"}, - ], - connections=["in_0 -> a.in_0", "in_1 -> a.in_1"], - ) + "connections": ["in_0 -> a.in_0", "in_1 -> a.in_1"], + }, + "K", + "2 ** K", + ), + ( + { + "name": "root", + "type": "dummy", + "children": [ + { + "name": "a", + "type": "dummy", + "ports": [ + {"name": "in_0", "direction": "input", "size": "N"}, + {"name": "in_1", "direction": "input", "size": "f(g(N)) + N + 1"}, + ], + } + ], + "ports": [ + {"name": "in_0", "size": "K", "direction": "input"}, + {"name": "in_1", "size": "K", "direction": "input"}, + ], + "connections": ["in_0 -> a.in_0", "in_1 -> a.in_1"], + }, + "K", + "f(g(K)) + K + 1", + ), + ], +) +def test_compilation_introduces_constraints_stemming_from_relation_between_port_sizes( + routine, expected_lhs, expected_rhs, backend +): compiled_routine = compile_routine(routine, backend=backend).routine constraint = compiled_routine.children["a"].constraints[0] - assert constraint.lhs == backend.as_expression("K") - assert constraint.rhs == backend.as_expression("2 ** K") + assert constraint.lhs == backend.as_expression(expected_lhs) + assert constraint.rhs == backend.as_expression(expected_rhs) def test_compilation_fails_if_input_ports_has_size_depending_on_undefined_variable(backend):