
[DRAFT] Feat (llm/bias): Proposal for multi-GPU bias correction #1212

Draft: wants to merge 1 commit into base: dev
40 changes: 35 additions & 5 deletions src/brevitas/graph/calibrate.py
@@ -8,6 +8,7 @@

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F

from brevitas.nn import QuantHardTanh
@@ -19,6 +20,7 @@
from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.llm.llm_quant.dist_utils import TensorBucket

from .base import Transform

@@ -105,9 +107,10 @@ def __exit__(self, type, value, traceback):

class bias_correction_mode:

def __init__(self, model, enabled=True, skip_if_no_bias=False):
def __init__(self, model, enabled=True, skip_if_no_bias=False, batch_size=1):
self.model = model
self.bias_correction = _BiasCorrection(skip_if_no_bias=skip_if_no_bias)
self.bias_correction = _BiasCorrection(
skip_if_no_bias=skip_if_no_bias, batch_size=batch_size)
self.enabled = enabled
self.hooks = []
self.output_quant_modules = []
@@ -255,7 +258,7 @@ class _BiasCorrection(DisableEnableQuantization):

LAYERS = (QuantWBIOL,)

def __init__(self, layers=LAYERS, skip_if_no_bias=False):
def __init__(self, layers=LAYERS, skip_if_no_bias=False, batch_size=1):
super(_BiasCorrection, self).__init__()
self.layers = layers
self.iterations = {}
@@ -264,10 +267,14 @@ def __init__(self, layers=LAYERS, skip_if_no_bias=False):
self.collect_float_mean_hooks = []
self.correct_bias_hooks = []
self.skip_if_no_bias = skip_if_no_bias
self.batch_size = batch_size

def compute_mean(self, inp, transpose_dim):
inp = inp.transpose(0, transpose_dim)
return inp.reshape(inp.shape[0], -1).mean(dim=1).detach()
# TODO: Validate
if len(inp.shape) == 2:
# Use reshape rather than view: inp is no longer contiguous after the transpose above
inp = inp.reshape(inp.shape[0], -1, self.batch_size)
return inp.mean(dim=1).sum(dim=-1).detach()

def channel_dim(self, inp, module):
if len(inp.shape) == 3 and isinstance(module, QuantLinear):
@@ -289,10 +296,33 @@ def update_correction(self, name, error):
else:
self.correction_map[name] += error

def _maybe_synchronize_correction_maps(self, bucketize: bool = False) -> None:
if dist.is_initialized():
if bucketize:
names, tensors = zip(*self.correction_map.items())
i = 0
for tensor_bucket in TensorBucket.bucketize_tensors(tensors):
# Synchronize the correction maps
dist.all_reduce(tensor_bucket.flattened_tensor, op=dist.ReduceOp.SUM)
# Reassign correction maps
for tensor in tensor_bucket.debucketize_tensors():
self.correction_map[names[i]] = tensor
i += 1
else:
for name in self.correction_map:
# Synchronize the correction maps
dist.all_reduce(self.correction_map[name], op=dist.ReduceOp.SUM)

def _get_correction_map_reduce_size(self, name: str) -> int:
world_size = dist.get_world_size() if dist.is_initialized() else 1
return world_size * self.iterations[name] * self.batch_size

def apply_correction(self, model):
# Maybe synchronize correction maps if multiple processes are being run
self._maybe_synchronize_correction_maps()
for name, module in model.named_modules():
if name in self.correction_map.keys():
correction = self.correction_map[name] / self.iterations[name]
correction = self.correction_map[name] / self._get_correction_map_reduce_size(name)
# When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter.
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
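As a sanity check on the normalization above, here is a minimal, self-contained sketch (hypothetical values, no real process group): each rank accumulates a per-channel sum of output errors, the sums are all-reduced, and dividing by `world_size * iterations * batch_size` recovers exactly the global mean error that the single-process path computed as `correction_map[name] / iterations[name]`.

import torch

# Hypothetical setup: 2 ranks, 4 calibration steps of 8 samples each, 16 channels
world_size, iterations, batch_size, channels = 2, 4, 8, 16

# Per-sample float-vs-quant output errors each rank would observe
rank_errors = [torch.randn(iterations, batch_size, channels) for _ in range(world_size)]

# What update_correction accumulates locally on each rank: a per-channel sum
local_sums = [e.sum(dim=(0, 1)) for e in rank_errors]

# dist.all_reduce(..., op=ReduceOp.SUM) leaves this global sum on every rank
global_sum = torch.stack(local_sums).sum(dim=0)

# _get_correction_map_reduce_size: normalize by the total number of samples
correction = global_sum / (world_size * iterations * batch_size)

# Reference: the plain mean over every sample from every rank
reference = torch.cat(rank_errors).mean(dim=(0, 1))
assert torch.allclose(correction, reference, atol=1e-5)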
7 changes: 6 additions & 1 deletion src/brevitas_examples/common/accelerate_utils/accelerate.py
@@ -18,6 +18,7 @@
from accelerate.utils.modeling import named_module_tensors
from psutil import virtual_memory
import torch
import torch.distributed as dist

import brevitas.config as config
from brevitas.graph.utils import get_module
@@ -369,9 +370,13 @@ def find_all_devices(data):
def calc_gpu_device_map(absolute_mem_margin: float = 2.0 * 1e9,
relative_mem_margin: float = 0.3) -> Dict[int, float]:
torch.cuda.empty_cache()
# Ensure each process claims a disjoint slice of the visible GPUs
rank, world_size = (dist.get_rank(), dist.get_world_size()) if dist.is_initialized() else (0, 1)
gpu_device_map = {
i: (torch.cuda.mem_get_info(i)[0] - absolute_mem_margin) * (1.0 - relative_mem_margin)
for i in range(torch.cuda.device_count())}
for i in range(rank * (torch.cuda.device_count() // world_size),
(rank + 1) * (torch.cuda.device_count() // world_size))}
return gpu_device_map


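The device-map change above gives each process a contiguous, non-overlapping range of GPU indices. A standalone sketch of the arithmetic (device counts are hypothetical; note that when the device count is not divisible by world_size, the remainder GPUs are left unassigned):

def gpus_for_rank(rank: int, world_size: int, device_count: int) -> range:
    # Each rank claims a disjoint, contiguous slice of the visible GPUs
    per_rank = device_count // world_size
    return range(rank * per_rank, (rank + 1) * per_rank)

# e.g. 8 GPUs split across 2 processes
assert list(gpus_for_rank(0, 2, 8)) == [0, 1, 2, 3]
assert list(gpus_for_rank(1, 2, 8)) == [4, 5, 6, 7]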
38 changes: 36 additions & 2 deletions src/brevitas_examples/llm/llm_quant/bias_corr.py
@@ -4,13 +4,47 @@
"""

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm

from brevitas.graph.calibrate import bias_correction_mode
from brevitas_examples.llm.llm_quant.data_utils import DatasetToDevice


# Collate a list of kwargs dicts into a single batched kwargs dict
def collate_fn(kwargs_list, return_tensors="pt"):
kwargs = {}
for curr_dict in kwargs_list:
for key, value in curr_dict.items():
if isinstance(value, torch.Tensor):
if key not in kwargs:
kwargs[key] = []
kwargs[key].append(value)
else:
if key not in kwargs:
kwargs[key] = value
for key, value in kwargs.items():
if isinstance(value, list) and len(value) > 0:
kwargs[key] = torch.cat(kwargs[key], dim=0)
return kwargs


def _maybe_partition_dataloader(dataloader: DatasetToDevice) -> DatasetToDevice:
# If multiple processes are running simultaneously, each receives a different partition
if dist.is_initialized():
rank = dist.get_rank()
partition_size = len(dataloader) // dist.get_world_size()
dataloader = DatasetToDevice(
dataloader.data[rank * partition_size:(rank + 1) * partition_size], dataloader.device)
return dataloader


@torch.no_grad()
def apply_bias_correction(model, dataloader):
with bias_correction_mode(model):
def apply_bias_correction(model, dataloader, batch_size=1):
dataloader = _maybe_partition_dataloader(dataloader)
bias_correction_dataloader = DataLoader(
dataloader, collate_fn=collate_fn, batch_size=batch_size)
with bias_correction_mode(model, batch_size=batch_size):
# Iterate the batched DataLoader so that collate_fn and batch_size take effect
for inps in tqdm(bias_correction_dataloader):
model(**inps)
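A quick usage sketch of collate_fn (the key names below are illustrative, not taken from the PR): tensor entries are concatenated along dim 0, while non-tensor entries keep the value from the first sample.

import torch

from brevitas_examples.llm.llm_quant.bias_corr import collate_fn

sample_a = {"input_ids": torch.ones(1, 4, dtype=torch.long), "use_cache": False}
sample_b = {"input_ids": torch.zeros(1, 4, dtype=torch.long), "use_cache": False}

batch = collate_fn([sample_a, sample_b])
assert batch["input_ids"].shape == (2, 4)  # concatenated along the batch dim
assert batch["use_cache"] is False  # non-tensor kwarg passed through once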
46 changes: 46 additions & 0 deletions src/brevitas_examples/llm/llm_quant/dist_utils.py
@@ -0,0 +1,46 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from functools import reduce
from typing import Iterator, List

import torch
import torch.distributed as dist


class TensorBucket:

def __init__(self, flattened_tensors: List[torch.Tensor], shapes: List[torch.Size]) -> None:
self.flattened_tensor = torch.cat(flattened_tensors)
self.shapes = shapes

@staticmethod
def bucketize_tensors(tensors: List[torch.Tensor],
bucket_size: int = int(1e8)) -> Iterator["TensorBucket"]:
flattened_tensors_bucket, shapes_bucket = [], []
curr_bucket_size = 0
for tensor in tensors:
# Flush the current bucket if it is non-empty and the tensor does not fit
if flattened_tensors_bucket and curr_bucket_size + tensor.numel() * tensor.element_size() > bucket_size:
yield TensorBucket(flattened_tensors=flattened_tensors_bucket, shapes=shapes_bucket)
curr_bucket_size = 0
flattened_tensors_bucket, shapes_bucket = [], []
# Add the tensor to the current bucket, tracking the bucket size in bytes
flattened_tensors_bucket.append(tensor.view(-1))
shapes_bucket.append(tensor.shape)
curr_bucket_size += tensor.numel() * tensor.element_size()
# Yield the last, partially filled bucket
if flattened_tensors_bucket:
yield TensorBucket(flattened_tensors=flattened_tensors_bucket, shapes=shapes_bucket)

def debucketize_tensors(self) -> Iterator[torch.Tensor]:
offset = 0
for shape in self.shapes:
n_element = reduce(lambda x, y: x * y, shape)
yield self.flattened_tensor[offset:offset + n_element].view(shape)
offset += n_element


def init_process_group(backend: str = "nccl") -> None:
# Check whether the script was launched via torchelastic (e.g. torchrun)
if dist.is_torchelastic_launched():
# If that is the case, initialize the default process group
dist.init_process_group(backend=backend)
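A round-trip sketch for TensorBucket under the fixes above: several correction tensors are coalesced into large flat buffers (so the multi-process path issues one all_reduce per bucket rather than one per tensor), then recovered with their original shapes.

import torch

from brevitas_examples.llm.llm_quant.dist_utils import TensorBucket

tensors = [torch.randn(3, 5), torch.randn(7), torch.randn(2, 2, 2)]

recovered = []
for bucket in TensorBucket.bucketize_tensors(tensors, bucket_size=int(1e8)):
    # In _maybe_synchronize_correction_maps, dist.all_reduce(bucket.flattened_tensor)
    # would run here, before the tensors are unpacked again
    recovered.extend(bucket.debucketize_tensors())

assert all(torch.equal(r, t) for r, t in zip(recovered, tensors))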
3 changes: 3 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -37,6 +37,7 @@
from brevitas_examples.common.generative.quantize import generate_quantizers
from brevitas_examples.llm.llm_args import create_llm_args_parser
from brevitas_examples.llm.llm_args import validate
from brevitas_examples.llm.llm_quant import dist_utils
from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction
from brevitas_examples.llm.llm_quant.calibrate import apply_calibration
from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model
@@ -634,6 +635,8 @@ def parse_args(args, override_defaults={}):


def main():
# Initialize default distributed group if script is launched with torchrun
dist_utils.init_process_group(backend="nccl")
overrides = override_defaults(sys.argv[1:])
args, extra_args = parse_args(sys.argv[1:], override_defaults=overrides)
quantize_llm(args, extra_args)
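With these pieces in place, the example should run unchanged as a single process, or across several processes with, for example, `torchrun --nproc-per-node=2 main.py ...`: torchelastic then sets the environment that `dist.is_torchelastic_launched()` checks, the default NCCL process group is created, and each rank calibrates its own dataset partition on its own GPU slice before the correction maps are all-reduced in `apply_correction`.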