-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Zak acr #128
Zak acr #128
Conversation
Should now be ready with all the good changes |
# sparsifying filter-bank (SFB) module | ||
class SFB(LIONmodel.LIONmodel): | ||
def __init__(self, model_parameters): | ||
|
||
super().__init__(model_parameters) | ||
# FoE kernels | ||
self.penalty = nn.Parameter((-12.0) * torch.ones(1)) | ||
self.n_kernels = model_parameters.n_kernels | ||
self.conv = nn.ModuleList( | ||
[ | ||
nn.Conv2d( | ||
1, | ||
model_parameters.n_filters, | ||
kernel_size=7, | ||
stride=1, | ||
padding=3, | ||
bias=False, | ||
) | ||
for i in range(self.n_kernels) | ||
] | ||
) | ||
if model_parameters.L2net: | ||
self.L2net = L2net() | ||
|
||
@staticmethod | ||
def default_parameters(): | ||
param = LIONParameter() | ||
param.n_kernels = 10 | ||
param.n_filters = 32 | ||
param.L2net = True | ||
return param | ||
|
||
def forward(self, x): | ||
# compute the output of the FoE part | ||
total_out = 0.0 | ||
for kernel_idx in range(self.n_kernels): | ||
x_out = torch.abs(self.conv[kernel_idx](x)) | ||
x_out_flat = x_out.view(x.size(0), -1) | ||
total_out += torch.sum(x_out_flat, dim=1) | ||
|
||
total_out = total_out.view(x.size(0), -1) | ||
out = (torch.nn.functional.softplus(self.penalty)) * total_out | ||
if self.model_parameters.L2net: | ||
out = out + self.L2net(x) | ||
return out | ||
|
||
|
||
class ACR(LIONmodel.LIONmodel): | ||
def __init__(self, model_parameters: LIONParameter = None): | ||
|
||
super().__init__(model_parameters) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrote this from Subhos code. I can't see it anywere (maybe Im blind) here. @Zakobian did you use this SBF in your code? why does your version of ACR not have it? or is it somewhere else that I am missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SFB was in the ACR just for consistency check - not really part of the method
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Zakobian so when you train it etc, you don't use it?
What do you mean "consistency check"?
merge conflicts resolved