Skip to content

Commit

Permalink
Fix bug in post-training quantization evaluation caused by JIT trace.
Browse files Browse the repository at this point in the history
Signed-off-by: Ranganath Krishnan <ranganath.krishnan@intel.com>
  • Loading branch information
ranganathkrishnan committed Sep 27, 2023
1 parent 93cf0d3 commit f5c7126
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
type=int,
default=10,
)
parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq")
parser.add_argument("--mode", type=str, required=True, help="train | test | ptq")

parser.add_argument(
"--num_monte_carlo",
Expand Down Expand Up @@ -333,27 +333,30 @@ def main():
'''
model.load_state_dict(checkpoint['state_dict'])


# post-training quantization
model_int8 = quantize(model, calib_loader, args)
model_int8.eval()
model_int8.cpu()

for i, (data, target) in enumerate(calib_loader):
data = data.cpu()
print('Evaluating quantized INT8 model....')
evaluate(args, model_int8, val_loader)

with torch.no_grad():
traced_model = torch.jit.trace(model_int8, data)
traced_model = torch.jit.freeze(traced_model)
#for i, (data, target) in enumerate(calib_loader):
# data = data.cpu()

save_path = os.path.join(
args.save_dir,
'quantized_bayesian_{}_cifar.pth'.format(args.arch))
traced_model.save(save_path)
print('INT8 model checkpoint saved at ', save_path)
print('Evaluating quantized INT8 model....')
evaluate(args, traced_model, val_loader)
#with torch.no_grad():
# traced_model = torch.jit.trace(model_int8, data)
# traced_model = torch.jit.freeze(traced_model)

#save_path = os.path.join(
# args.save_dir,
# 'quantized_bayesian_{}_cifar.pth'.format(args.arch))
#traced_model.save(save_path)
#print('INT8 model checkpoint saved at ', save_path)
#print('Evaluating quantized INT8 model....')
#evaluate(args, traced_model, val_loader)

'''
elif args.mode =='test_ptq':
print('load model...')
if len(args.model_checkpoint) > 0:
Expand All @@ -366,7 +369,7 @@ def main():
model_int8 = torch.jit.freeze(model_int8)
print('Evaluating the INT8 model....')
evaluate(args, model_int8, val_loader)

'''

def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
batch_time = AverageMeter()
Expand Down

0 comments on commit f5c7126

Please sign in to comment.