Refactor fusing #1386
base: main
 framework_attr={},
 input_shape=nodes[0].input_shape,
 output_shape=nodes[-1].output_shape,
-weights={},
+weights={},  # TODO: update with weights of all nodes
Is the todo here planned for this PR?
Is it actually necessary? You can always retrieve the original weights from the original graph.
No, it is planned for the PR that handles mixed precision (MP).
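As a rough illustration of the point raised in this thread (that a fused node's weights can be looked up from the original, unfused nodes), here is a minimal sketch. The helper `collect_fused_weights` and the plain-dict data layout are hypothetical and are not part of the PR or of the toolkit's API.

```python
from typing import Any, Dict, Tuple


def collect_fused_weights(
    fused_op_id: str,
    fused_ops: Dict[str, Tuple[str, ...]],
    original_weights: Dict[str, Dict[str, Any]],
) -> Dict[Tuple[str, str], Any]:
    """Gather the weights of every original node covered by one fused op.

    fused_ops maps a fused-op id to the names of the original nodes it covers.
    original_weights maps an original node name to its {attribute: value} dict.
    The result is keyed by (node name, attribute) so that attributes with the
    same name on different nodes do not collide.
    """
    return {
        (node_name, attr): value
        for node_name in fused_ops[fused_op_id]
        # Nodes without weights (e.g. activations) contribute nothing.
        for attr, value in original_weights.get(node_name, {}).items()
    }
```

With a lookup of this kind, the fused node can keep `weights={}` and the original values remain reachable through the fusion metadata.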
@@ -82,6 +83,11 @@ def search_bit_width(graph_to_search_cfg: Graph,

 # Set graph for MP search
 graph = copy.deepcopy(graph_to_search_cfg)  # Copy graph before searching
+# TODO: The handle of mixed precision with the fused graph will be in a separate PR. Currently, the bit-width
This also needs to be integrated with the BOPs virtual graph.
If I understand correctly, the integration with the BOPs simply means building the virtual graph from the fused graph.
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py
Consider adding e2e tests to verify that none of our substitutions breaks the fusing metadata.
tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py
tests_pytest/_fw_tests_common_base/base_graph_with_fusion_metadata_test.py
This fixture defines allowed operations and fusing patterns for testing.
"""
return schema.TargetPlatformCapabilities(
default_qco=default_quant_cfg_options,
Maybe also check more complicated scenarios, such as operators with mixed precision?
I wanted it to be minimal and check only the fusing. Maybe in the e2e tests?
I think it will be complicated to test these specific attributes on a model in an e2e manner.
In these tests you verify the graph, which is the more appropriate way to verify such edge cases, IMO.
You can leave this extension to the next PR (the e2e tests) and extend the integration tests then as well.
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py
tests_pytest/_fw_tests_common_base/fusing/base_fusing_info_generator_test.py
Pull Request Description:
This PR introduces graph-fusion handling by encapsulating fusion metadata within a new wrapper class for the graph. After every access to the graph, this class runs a validation check to verify that the fusion information remains consistent and that no modification has introduced inconsistencies. Additionally, the fusion-related logic has been refactored into a new class called GraphFuser, which takes the graph along with its fusion metadata and creates a new graph in which fused operations are represented as single nodes.
Checklist before requesting a review:
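The design outlined in the PR description (a wrapper that re-validates fusion metadata after each graph access, plus a fuser that collapses fused operations into single nodes) can be sketched roughly as follows. This is a toy illustration only: `ToyGraph`, `ToyFusingInfo`, `ToyGraphWithFusingMetadata`, and `fuse_toy_graph` are hypothetical names, the graph is assumed to be a linear chain of uniquely named nodes, and none of this reflects the actual classes or signatures in the PR.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple, TypeVar

T = TypeVar("T")


@dataclass
class ToyGraph:
    """Hypothetical stand-in for the real graph: a linear chain of node names."""
    nodes: List[str] = field(default_factory=list)


@dataclass
class ToyFusingInfo:
    """Maps a fused-op id to the ordered node names it covers (assumed layout)."""
    fused_ops: Dict[str, Tuple[str, ...]] = field(default_factory=dict)

    def validate(self, graph: ToyGraph) -> None:
        for op_id, members in self.fused_ops.items():
            # Every member of a fused op must still exist in the graph.
            missing = [n for n in members if n not in graph.nodes]
            if missing:
                raise ValueError(f"{op_id}: nodes {missing} are no longer in the graph")
            # Members must still be consecutive, in order, in the chain.
            start = graph.nodes.index(members[0])
            if tuple(graph.nodes[start:start + len(members)]) != members:
                raise ValueError(f"{op_id}: nodes are no longer consecutive")


class ToyGraphWithFusingMetadata:
    """Wrapper that re-validates the fusion metadata after each graph access."""

    def __init__(self, graph: ToyGraph, fusing_info: ToyFusingInfo) -> None:
        self._graph = graph
        self._fusing_info = fusing_info
        self._fusing_info.validate(self._graph)

    def apply(self, fn: Callable[[ToyGraph], T]) -> T:
        """Run fn on the wrapped graph, then check the metadata is still consistent."""
        result = fn(self._graph)
        self._fusing_info.validate(self._graph)
        return result


def fuse_toy_graph(graph: ToyGraph, fusing_info: ToyFusingInfo) -> ToyGraph:
    """Build a new graph in which each fused op appears as a single node."""
    fusing_info.validate(graph)
    by_first = {members[0]: (op_id, members) for op_id, members in fusing_info.fused_ops.items()}
    fused_nodes: List[str] = []
    i = 0
    while i < len(graph.nodes):
        name = graph.nodes[i]
        if name in by_first:
            op_id, members = by_first[name]
            fused_nodes.append(op_id)  # One node replaces the whole fused group.
            i += len(members)
        else:
            fused_nodes.append(name)
            i += 1
    return ToyGraph(fused_nodes)
```

In this sketch, a substitution that removes or reorders a node belonging to a fused group fails at the point of modification, because `apply` re-validates before returning, instead of surfacing later when the fused graph is built.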