
Hex conv tf2 #16

Merged · 3 commits · Feb 24, 2025

Conversation

@mhuen (Collaborator) commented Feb 24, 2025

The hexagonal convolution kernels could not properly propagate gradients to their kernel weights because those weights were modified (via tf.stack/tf.concat) prior to the layer's forward pass. As a result, GradientTape was unable to find any gradients for the kernel weights. This PR introduces separate classes for hexagonal kernel creation that move weight initialization into the class constructor and all transformations into the forward pass in __call__.

Additionally, names for created weights are now fully propagated through.

@mhuen mhuen merged commit aaaacb3 into master Feb 24, 2025
5 checks passed
@mhuen mhuen deleted the HexConvTF2 branch February 24, 2025 16:31
mhuen added a commit that referenced this pull request Feb 24, 2025