Skip to content

Commit

Permalink
add ImageNet example
Browse files Browse the repository at this point in the history
  • Loading branch information
junliang-lin committed Sep 21, 2023
1 parent 93cf0d3 commit 3e74914
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
11 changes: 10 additions & 1 deletion bayesian_torch/ao/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from typing import Any, List, Optional, Type, Union
from torch import Tensor
from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
from torch.nn import BatchNorm2d
# import copy

__all__ = [
Expand Down Expand Up @@ -140,6 +141,14 @@ def enable_prepare(m):
if callable(prepare):
m._modules[name].prepare()
m._modules[name].dnn_to_bnn_flag=True
elif "BatchNorm2dLayer" in m._modules[name].__class__.__name__: # replace BatchNorm2dLayer with BatchNorm2d in downsample
layer_fn = BatchNorm2d # Get QBNN layer
bn_layer = layer_fn(
num_features=m._modules[name].num_features
)
bn_layer.__dict__.update(m._modules[name].__dict__)
setattr(m, name, bn_layer)



def prepare(model):
Expand All @@ -149,7 +158,7 @@ def prepare(model):
3. run torch.quantize.prepare()
"""
qmodel = QuantizableResNet(QuantizableBottleneck, [3, 4, 6, 3])
qmodel.load_state_dict(model.state_dict())
qmodel.load_state_dict(model.module.state_dict())
qmodel.eval()
enable_prepare(qmodel)
qmodel.qconfig = torch.quantization.get_default_qconfig("onednn")
Expand Down
1 change: 1 addition & 0 deletions bayesian_torch/layers/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self,
affine=True,
track_running_stats=True):
super(BatchNorm2dLayer, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
Expand Down
8 changes: 8 additions & 0 deletions bayesian_torch/scripts/quantize_bayesian_imagenet.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash

model=resnet50
mode='test'
val_batch_size=1
num_monte_carlo=1

python examples/main_bayesian_imagenet_bnn2qbnn.py --mode=$mode --val_batch_size=$val_batch_size --num_monte_carlo=$num_monte_carlo ../../datasets

0 comments on commit 3e74914

Please sign in to comment.