
Qnnpack accuracy very poor on unet model #73

Open
Amitdedhia6 opened this issue Aug 19, 2020 · 3 comments

@Amitdedhia6

I am using a UNet model for semantic segmentation. I pass a batch of images to the model. The model is expected to output 0 or 1 for each pixel of the image (depending upon whether the pixel is part of the person object or not). 0 is for background, and 1 is for foreground.

I am trying to quantize the UNet model with the PyTorch quantization APIs for the ARM architecture. I chose QNNPACK as the quantization configuration. However, the model accuracy is very poor for both post-training static quantization and QAT. The output is always a completely black image, i.e. it contains only background, no foreground for the person object. The model outputs b x 2 x 224 x 224, i.e. batch_size x 2 channels (one for foreground and one for background) x height x width.

The following are the output values for the center pixels of the images, with the original model and with the quantized model.
[image: center-pixel output values, original model vs. quantized model]

As seen, the original model has varying output values. Hence when we apply softmax on dim=1 (i.e. the channel dimension), we get some pixels as 0 and some as 1. This is as per expectation. However, the quantized model always outputs a high positive value for the background channel and a high negative value for the foreground channel. After applying softmax, all the pixels are background pixels, and the output is a black image.
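
For reference, this is roughly how I go from the logits to the per-pixel mask (a minimal sketch; the names are illustrative rather than the exact repo code):

import torch.nn.functional as F

# logits: (batch, 2, 224, 224) raw model output
probs = F.softmax(logits, dim=1)   # per-pixel probabilities over {background, foreground}
mask = probs.argmax(dim=1)         # (batch, 224, 224); 0 = background, 1 = person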

I need some help to find out why this is happening. Is it a bug in the QNNPACK quantization routine?

The model source code is available here. I used the pretrained version of UNet with MobileNetV2 as the backbone (check the benchmark section in the readme from the source code link) - see here.

I first tried the FBGEMM configuration, and it worked fine in terms of accuracy - no major loss. However, when I tried QNNPACK, I faced the above issues. The following is my code for QAT (a sketch of the backend/qconfig setup I use appears after the training function).

# imports assumed by the snippets below (module aliases inferred from usage)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization as Q
import torch.nn.intrinsic.qat as nn_intrinsic_qat
from torch.optim import lr_scheduler
from torch.quantization import QuantStub, DeQuantStub

use_sigmoid = False

def poly_lr_scheduler(optimizer, init_lr, curr_iter, max_iter, power=0.9):
    for g in optimizer.param_groups:
        g['lr'] = init_lr * (1 - curr_iter/max_iter)**power


def dice_loss(logits, targets, smooth=1.0):
    """
    logits: (torch.float32)  shape (N, C, H, W)
    targets: (torch.float32) shape (N, H, W), value {0,1,...,C-1}
    """

    if not use_sigmoid:
        outputs = F.softmax(logits, dim=1)
        targets = torch.unsqueeze(targets, dim=1)
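        # build one-hot targets of shape (N, C, H, W) so they align with the softmax output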
        # targets = torch.zeros_like(logits).scatter_(dim=1, index=targets.type(torch.int64), src=torch.tensor(1.0))
        targets = torch.zeros_like(logits).scatter_(dim=1, index=targets.type(torch.int64),
                                                    src=torch.ones(targets.type(torch.int64).shape))

        inter = outputs * targets
        dice = 1 - ((2 * inter.sum(dim=(2, 3)) + smooth) / (outputs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3)) + smooth))
        return dice.mean()
    else:
        outputs = logits[:,1,:,:]
        outputs = torch.sigmoid(outputs)
        inter = outputs * targets
        dice = 1 - ((2*inter.sum(dim=(1,2)) + smooth) / (outputs.sum(dim=(1,2))+targets.sum(dim=(1,2)) + smooth))
        return dice.mean()

def train_model_for_qat(model, data_loader, num_epochs, batch_size):
    device = torch.device('cpu')
    model_params = [p for p in model.parameters() if p.requires_grad]
    SGD_params = {
        "lr": 1e-2,
        "momentum": 0.9,
        "weight_decay": 1e-8
    }

    optimizer = torch.optim.SGD(model_params, lr=SGD_params["lr"], momentum=SGD_params["momentum"],
                                nesterov=True, weight_decay=SGD_params["weight_decay"])
    init_lr = optimizer.param_groups[0]['lr']
    scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)

    max_iter = num_epochs * np.ceil(len(data_loader.dataset) / batch_size) + 5
    curr_batch_num = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch} in progress")
        train_data_length = 0
        total_train_loss = 0

        model.train()
        for batch in data_loader:
            # get the next batch
            data, target = batch
            data, target = data.to(device), target.to(device)
            train_data_length += len(batch[0])

            optimizer.zero_grad()
            outputs = model(data)
            loss = dice_loss(outputs, target)
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()
            curr_batch_num += 1
            poly_lr_scheduler(optimizer, init_lr, curr_batch_num, max_iter, power=0.9)

        total_train_loss = total_train_loss / train_data_length
        scheduler.step()

        if epoch > 10:
            # Freeze quantizer parameters
            model.apply(Q.disable_observer)
        if epoch > 20:
            # Freeze batch norm mean and variance estimates
            model.apply(nn_intrinsic_qat.freeze_bn_stats)

        quantized_model = Q.convert(model.eval(), inplace=False)
        accuracy = eval_model_for_quantization(quantized_model, device)
        print(f"...Accuacy at the end of epoch {epoch} : {accuracy}")
        if (accuracy > 99) and (epoch >= 10):
            print("...GUESS we are done with training now...")
            break

    return total_train_loss, model
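
For completeness, the backend/qconfig setup I do before calling train_model_for_qat looks roughly like this (a sketch under my assumptions; the fusion list is model-specific and omitted, and batch_size=8 is just an illustrative value):

torch.backends.quantized.engine = 'qnnpack'            # select the QNNPACK backend
model.qconfig = Q.get_default_qat_qconfig('qnnpack')   # fake-quant config for QAT
# fuse Conv+BN+ReLU modules here (model-specific), then prepare for QAT
model = Q.prepare_qat(model, inplace=False)
total_train_loss, model = train_model_for_qat(model, data_loader,
                                              num_epochs=30, batch_size=8)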


Am I missing anything?

One issue that we did encounter is that the upsampling layers of UNet use nn.ConvTranspose2d, which is not supported for quantization. Hence, before this layer, we need to dequantize the tensors, apply nn.ConvTranspose2d, and then requantize for the subsequent layers. Can this be the reason for the lower accuracy?

#------------------------------------------------------------------------------
#   Decoder block
#------------------------------------------------------------------------------
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, block_unit):
        super(DecoderBlock, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, padding=1, stride=2)
        self.block_unit = block_unit
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, input, shortcut):
        # self.deconv = nn.ConvTranspose2d not supported for FBGEMM and QNNPACK quantization
        input = self.dequant(input)
        x = self.deconv(input)
        x = self.quant(x)
        x = torch.cat([x, shortcut], dim=1)
        x = self.block_unit(x)
        return x

The following is the model after QAT training is completed for 30 epochs:

UNet(
  (backbone): MobileNetV2(
    (features): Sequential(
      (0): Sequential(
        (0): QuantizedConvReLU2d(3, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.012562132440507412, zero_point=0, padding=(1, 1))
        (1): Identity()
        (2): Identity()
      )
      (1): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.046556487679481506, zero_point=0, padding=(1, 1), groups=32)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.043083205819129944, zero_point=96)
          (4): Identity()
        )
      )
      (2): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(16, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.05470738932490349, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(96, 96, kernel_size=(3, 3), stride=(2, 2), scale=0.05578919127583504, zero_point=0, padding=(1, 1), groups=96)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), scale=0.08143805712461472, zero_point=131)
          (7): Identity()
        )
      )
      (3): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.10905726253986359, zero_point=133
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=0.021390624344348907, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(144, 144, kernel_size=(3, 3), stride=(1, 1), scale=0.03496978059411049, zero_point=0, padding=(1, 1), groups=144)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), scale=0.07988038659095764, zero_point=166)
          (7): Identity()
        )
      )
      (4): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(24, 144, kernel_size=(1, 1), stride=(1, 1), scale=0.016173962503671646, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(144, 144, kernel_size=(3, 3), stride=(2, 2), scale=0.05084317922592163, zero_point=0, padding=(1, 1), groups=144)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.08057469874620438, zero_point=133)
          (7): Identity()
        )
      )
      (5): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.07931466400623322, zero_point=141
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.015451926738023758, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=0.01901066116988659, zero_point=0, padding=(1, 1), groups=192)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.03396213427186012, zero_point=137)
          (7): Identity()
        )
      )
      (6): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.10119215399026871, zero_point=149
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.009366143494844437, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=0.03307875618338585, zero_point=0, padding=(1, 1), groups=192)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.045690517872571945, zero_point=152)
          (7): Identity()
        )
      )
      (7): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.013529903255403042, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(2, 2), scale=0.030076880007982254, zero_point=0, padding=(1, 1), groups=192)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.05553155764937401, zero_point=128)
          (7): Identity()
        )
      )
      (8): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.057563915848731995, zero_point=132
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=0.008955957368016243, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=0.01566135324537754, zero_point=0, padding=(1, 1), groups=384)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.02868938073515892, zero_point=162)
          (7): Identity()
        )
      )
      (9): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.05936211720108986, zero_point=140
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=0.011350379325449467, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=0.013551343232393265, zero_point=0, padding=(1, 1), groups=384)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.02829933725297451, zero_point=124)
          (7): Identity()
        )
      )
      (10): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.056326691061258316, zero_point=121
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=0.009888351894915104, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=0.00840410403907299, zero_point=0, padding=(1, 1), groups=384)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.02762036770582199, zero_point=130)
          (7): Identity()
        )
      )
      (11): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=0.010262548923492432, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=0.020638082176446915, zero_point=0, padding=(1, 1), groups=384)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.03133825585246086, zero_point=114)
          (7): Identity()
        )
      )
      (12): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.049823448061943054, zero_point=106
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=0.007199177984148264, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=0.017748937010765076, zero_point=0, padding=(1, 1), groups=576)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.045204587280750275, zero_point=94)
          (7): Identity()
        )
      )
      (13): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.06418105959892273, zero_point=125
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=0.008789398707449436, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(576, 576, kernel_size=(3, 3), stride=(1, 1), scale=0.019841214641928673, zero_point=0, padding=(1, 1), groups=576)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.06256742030382156, zero_point=129)
          (7): Identity()
        )
      )
      (14): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(96, 576, kernel_size=(1, 1), stride=(1, 1), scale=0.011278725229203701, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(576, 576, kernel_size=(3, 3), stride=(2, 2), scale=0.028320688754320145, zero_point=0, padding=(1, 1), groups=576)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), scale=0.06365438550710678, zero_point=132)
          (7): Identity()
        )
      )
      (15): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.08667448908090591, zero_point=127
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(160, 960, kernel_size=(1, 1), stride=(1, 1), scale=0.011708680540323257, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(960, 960, kernel_size=(3, 3), stride=(1, 1), scale=0.026726122945547104, zero_point=0, padding=(1, 1), groups=960)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), scale=0.04459201171994209, zero_point=116)
          (7): Identity()
        )
      )
      (16): InvertedResidual(
        (skip_add): QFunctional(
          scale=0.1879616528749466, zero_point=126
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(160, 960, kernel_size=(1, 1), stride=(1, 1), scale=0.015011523850262165, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(960, 960, kernel_size=(3, 3), stride=(1, 1), scale=0.025075148791074753, zero_point=0, padding=(1, 1), groups=960)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), scale=0.1145012229681015, zero_point=119)
          (7): Identity()
        )
      )
      (17): InvertedResidual(
        (skip_add): QFunctional(
          scale=1.0, zero_point=0
          (activation_post_process): Identity()
        )
        (conv): Sequential(
          (0): QuantizedConvReLU2d(160, 960, kernel_size=(1, 1), stride=(1, 1), scale=0.006071502808481455, zero_point=0)
          (1): Identity()
          (2): Identity()
          (3): QuantizedConvReLU2d(960, 960, kernel_size=(3, 3), stride=(1, 1), scale=0.01608050987124443, zero_point=0, padding=(1, 1), groups=960)
          (4): Identity()
          (5): Identity()
          (6): QuantizedConv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), scale=0.02348274178802967, zero_point=127)
          (7): Identity()
        )
      )
      (18): Sequential(
        (0): QuantizedConvReLU2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), scale=0.08627913892269135, zero_point=0)
        (1): Identity()
        (2): Identity()
      )
    )
  )
  (decoder1): DecoderBlock(
    (deconv): ConvTranspose2d(1280, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (block_unit): InvertedResidual(
      (skip_add): QFunctional(
        scale=1.0, zero_point=0
        (activation_post_process): Identity()
      )
      (conv): Sequential(
        (0): QuantizedConvReLU2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), scale=0.038529299199581146, zero_point=0)
        (1): Identity()
        (2): Identity()
        (3): QuantizedConvReLU2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), scale=0.06666766107082367, zero_point=0, padding=(1, 1), groups=1152)
        (4): Identity()
        (5): Identity()
        (6): QuantizedConv2d(1152, 96, kernel_size=(1, 1), stride=(1, 1), scale=0.0864308550953865, zero_point=117)
        (7): Identity()
      )
    )
    (quant): Quantize(scale=tensor([0.0979]), zero_point=tensor([128]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (decoder2): DecoderBlock(
    (deconv): ConvTranspose2d(96, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (block_unit): InvertedResidual(
      (skip_add): QFunctional(
        scale=1.0, zero_point=0
        (activation_post_process): Identity()
      )
      (conv): Sequential(
        (0): QuantizedConvReLU2d(64, 384, kernel_size=(1, 1), stride=(1, 1), scale=0.06379921734333038, zero_point=0)
        (1): Identity()
        (2): Identity()
        (3): QuantizedConvReLU2d(384, 384, kernel_size=(3, 3), stride=(1, 1), scale=0.28728926181793213, zero_point=0, padding=(1, 1), groups=384)
        (4): Identity()
        (5): Identity()
        (6): QuantizedConv2d(384, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.2210002988576889, zero_point=126)
        (7): Identity()
      )
    )
    (quant): Quantize(scale=tensor([0.0561]), zero_point=tensor([120]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (decoder3): DecoderBlock(
    (deconv): ConvTranspose2d(32, 24, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (block_unit): InvertedResidual(
      (skip_add): QFunctional(
        scale=1.0, zero_point=0
        (activation_post_process): Identity()
      )
      (conv): Sequential(
        (0): QuantizedConvReLU2d(48, 288, kernel_size=(1, 1), stride=(1, 1), scale=0.11421681195497513, zero_point=0)
        (1): Identity()
        (2): Identity()
        (3): QuantizedConvReLU2d(288, 288, kernel_size=(3, 3), stride=(1, 1), scale=0.20177718997001648, zero_point=0, padding=(1, 1), groups=288)
        (4): Identity()
        (5): Identity()
        (6): QuantizedConv2d(288, 24, kernel_size=(1, 1), stride=(1, 1), scale=0.21056368947029114, zero_point=114)
        (7): Identity()
      )
    )
    (quant): Quantize(scale=tensor([0.1462]), zero_point=tensor([123]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (decoder4): DecoderBlock(
    (deconv): ConvTranspose2d(24, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (block_unit): InvertedResidual(
      (skip_add): QFunctional(
        scale=1.0, zero_point=0
        (activation_post_process): Identity()
      )
      (conv): Sequential(
        (0): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.1248798817396164, zero_point=0)
        (1): Identity()
        (2): Identity()
        (3): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(1, 1), scale=0.1945924311876297, zero_point=0, padding=(1, 1), groups=192)
        (4): Identity()
        (5): Identity()
        (6): QuantizedConv2d(192, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.3521292805671692, zero_point=135)
        (7): Identity()
      )
    )
    (quant): Quantize(scale=tensor([0.1450]), zero_point=tensor([140]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (conv_last): Sequential(
    (0): QuantizedConv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), scale=1.0138226747512817, zero_point=131, padding=(1, 1))
    (1): QuantizedConv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), scale=2.995656728744507, zero_point=128, padding=(1, 1))
  )
  (quant): Quantize(scale=tensor([0.0186]), zero_point=tensor([114]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

Any help is highly appreciated - thanks.

@Amitdedhia6 (Author)

Question: typically, how many epochs are good enough for QAT fine-tuning? And do we need to supply a large number of images for training (say 5000), or can a few suffice (say 100)?

@Amitdedhia6 (Author) commented Aug 20, 2020

Question: suppose I performed QAT for, say, n epochs, then generated model output (before calling torch.quantization.convert), then called torch.quantization.convert and generated output again. Will the two outputs be the same? How much of a gap is expected?
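
For concreteness, the kind of comparison I mean (an illustrative sketch; sample_batch is just some held-out input):

model.eval()
with torch.no_grad():
    fake_quant_out = model(sample_batch)          # QAT model, fake-quant ops still in float
    int8_model = Q.convert(model, inplace=False)  # swap in real int8 kernels
    int8_out = int8_model(sample_batch)
print((fake_quant_out - int8_out).abs().max())    # worst-case element-wise gap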

@Amitdedhia6 (Author)

I raised this issue on the discussion forum as well. Please see here for more details on the issue.
