Skip to content

Commit

Permalink
correcting the order of group and dilation parameters in Conv transpo…
Browse files Browse the repository at this point in the history
…se layers.

Fix issue #21

Signed-off-by: Ranganath Krishnan <ranganath.krishnan@intel.com>
  • Loading branch information
ranganathkrishnan committed Jan 2, 2024
1 parent 1180b87 commit 97ba16a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 27 deletions.
24 changes: 12 additions & 12 deletions bayesian_torch/layers/flipout_layers/conv_flipout.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation,
groups=self.groups)
groups=self.groups,
dilation=self.dilation)

# sampling perturbation signs
sign_input = x.clone().uniform_(-1, 1).sign()
Expand Down Expand Up @@ -803,8 +803,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation,
groups=self.groups)
groups=self.groups,
dilation=self.dilation)
perturbed_outputs = perturbed_outputs_tmp * sign_output
out = outputs + perturbed_outputs

Expand Down Expand Up @@ -968,8 +968,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation,
groups=self.groups)
groups=self.groups,
dilation=self.dilation)

# sampling perturbation signs
sign_input = x.clone().uniform_(-1, 1).sign()
Expand Down Expand Up @@ -1002,8 +1002,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation,
groups=self.groups)
groups=self.groups,
dilation=self.dilation)
perturbed_outputs = perturbed_outputs_tmp * sign_output
out = outputs + perturbed_outputs

Expand Down Expand Up @@ -1167,8 +1167,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation,
groups=self.groups)
groups=self.groups,
dilation=self.dilation)

# sampling perturbation signs
sign_input = x.clone().uniform_(-1, 1).sign()
Expand Down Expand Up @@ -1200,8 +1200,8 @@ def forward(self, x, return_kl=True):
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=self.dilation,
groups=self.groups)
groups=self.groups,
dilation=self.dilation)
perturbed_outputs = perturbed_outputs_tmp * sign_output
out = outputs + perturbed_outputs

Expand Down
12 changes: 6 additions & 6 deletions bayesian_torch/layers/flipout_layers/quantized_conv_flipout.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(self.quantized_mu_weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

Expand All @@ -923,7 +923,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(delta_kernel, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)
perturbed_outputs = torch.ops.quantized.conv_transpose1d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
Expand Down Expand Up @@ -1106,7 +1106,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(self.quantized_mu_weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

Expand All @@ -1131,7 +1131,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(delta_kernel, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)
perturbed_outputs = torch.ops.quantized.conv_transpose2d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
Expand Down Expand Up @@ -1314,7 +1314,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(self.quantized_mu_weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

Expand All @@ -1339,7 +1339,7 @@ def forward(self, x, normal_scale=6/255, default_scale=0.1, default_zero_point=1

self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(delta_kernel, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)
perturbed_outputs = torch.ops.quantized.conv_transpose3d(x, self._packed_params, scale=default_scale, zero_point=default_zero_point)

perturbed_outputs = torch.ops.quantized.mul(perturbed_outputs, sign_output, default_scale, default_zero_point)
Expand Down
6 changes: 3 additions & 3 deletions bayesian_torch/layers/variational_layers/conv_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def forward(self, input, return_kl=True):

out = F.conv_transpose1d(input, weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

if self.quant_prepare:
# quint8 quantstub
Expand Down Expand Up @@ -894,7 +894,7 @@ def forward(self, input, return_kl=True):

out = F.conv_transpose2d(input, weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

if self.quant_prepare:
# quint8 quantstub
Expand Down Expand Up @@ -1070,7 +1070,7 @@ def forward(self, input, return_kl=True):

out = F.conv_transpose3d(input, weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

if self.quant_prepare:
# quint8 quantstub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

out = F.conv_transpose1d(input, weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

else:
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
Expand All @@ -1019,7 +1019,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

out = torch.ops.quantized.conv_transpose1d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)

Expand Down Expand Up @@ -1227,7 +1227,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

out = F.conv_transpose2d(input, weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

else:
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
Expand All @@ -1250,7 +1250,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

out = torch.ops.quantized.conv_transpose2d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)

Expand Down Expand Up @@ -1458,7 +1458,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

out = F.conv_transpose3d(input, weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

else:
eps_kernel = torch.quantize_per_tensor(self.eps_kernel.data.normal_(), normal_scale, 0, torch.qint8) # Quantize a tensor from normal distribution. 99.7% values will lie within 3 standard deviations, so the original range is set as 6.
Expand All @@ -1481,7 +1481,7 @@ def forward(self, input, enable_int8_compute=True, normal_scale=6/255, default_s

self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(weight, bias, self.stride,
self.padding, self.output_padding,
self.dilation, self.groups)
self.groups, self.dilation)

out = torch.ops.quantized.conv_transpose3d(input, self._packed_params, scale=default_scale, zero_point=default_zero_point)

Expand Down

0 comments on commit 97ba16a

Please sign in to comment.