Refactor layout implementation #491

Merged · 1 commit · Jul 16, 2024
11 changes: 10 additions & 1 deletion torchao/dtypes/__init__.py
@@ -1,12 +1,21 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized,
LayoutType,
PlainLayoutType,
TensorCoreTiledLayoutType,
)

__all__ = [
"NF4Tensor",
"to_nf4",
"UInt4Tensor"
"AffineQuantizedTensor",
"to_affine_quantized",
"LayoutType",
"PlainLayoutType",
"TensorCoreTiledLayoutType",
]
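
A minimal usage sketch of the new exports above, assuming this branch of torchao is installed; the import path follows the updated __init__.py, and the construction below is illustrative rather than taken from the PR.

from torchao.dtypes import (
    AffineQuantizedTensor,
    to_affine_quantized,
    LayoutType,
    PlainLayoutType,
    TensorCoreTiledLayoutType,
)

# Layout types are small config objects; the tensor-core-tiled one carries
# its packing parameter instead of a loose inner_k_tiles argument.
plain = PlainLayoutType()
tiled = TensorCoreTiledLayoutType(inner_k_tiles=8)
print(plain, tiled)
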
137 changes: 88 additions & 49 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -20,10 +20,35 @@
_ATEN_OP_OR_TORCH_FN_TABLE,
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
)
from typing import ClassVar
from dataclasses import dataclass

aten = torch.ops.aten

@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
msaroufim (Member), Jul 15, 2024:
Add a comment or an error that this shouldn't be instantiated directly.

jerryzh168 (Contributor, Author):
This can be instantiated, I think. Are you talking about LayoutType?

msaroufim (Member), Jul 15, 2024:
I see. I guess I'm a bit thrown off because a dataclass's primary goal is to store data, whereas this class stores nothing and is really just a name.

Member:
I would have instead done an enum like this:

from enum import Enum
class Operations(Enum):
    ADD = (1,)
    SUBTRACT = (2,)
    MULTIPLY = (3,)
    DIVIDE = (4, 'precision')

Enums are also classes, so you can override __init__ and define a function that only applies to DIVIDE, for example.

jerryzh168 (Contributor, Author):
Sorry, I don't follow the DIVIDE part, can you elaborate a bit? Is this about how to support TensorCoreTiledLayoutType, which has an inner_k_tiles argument?

pass
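
For reference on the enum-vs-dataclass discussion above, a standalone sketch (Demo names are illustrative, not part of this PR) of why a frozen dataclass fits layouts that carry configuration: instances compare and hash by value, and each layout can declare its own fields, whereas an enum member is a fixed singleton.

from dataclasses import dataclass

@dataclass(frozen=True)
class DemoLayoutType:
    """Base marker; subclasses add their own configuration fields."""

@dataclass(frozen=True)
class DemoPlainLayout(DemoLayoutType):
    pass

@dataclass(frozen=True)
class DemoTensorCoreTiledLayout(DemoLayoutType):
    inner_k_tiles: int = 8

# Value semantics: equal configuration compares equal and hashes equal,
# so layout objects can be passed around or used as dict keys.
assert DemoTensorCoreTiledLayout(4) == DemoTensorCoreTiledLayout(inner_k_tiles=4)
assert DemoTensorCoreTiledLayout(4) != DemoTensorCoreTiledLayout(8)
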

@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8
jerryzh168 (Contributor, Author):
@msaroufim see here: we have extra configurable arguments, so it's not just a name, and I'm not sure how an enum would work here.


def pre_process(self, input: torch.Tensor) -> torch.Tensor:
orig_out_features, orig_in_features = input.shape
in_features = find_multiple(orig_in_features, 1024)
Member:
Where are the in and out numbers coming from? I thought constants like this were a function of the dtype as well.

jerryzh168 (Contributor, Author), Jul 15, 2024:
This comes from the tinygemm kernel, I think; this layout only applies to the uint4 dtype.

out_features = find_multiple(orig_out_features, 8)
input = torch.nn.functional.pad(
input,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
return input

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
"""
Base class for the layout tensor for `AffineQuantizedTensor`
"""
# this should be set for each layout class during registration
extended_layout: Optional[str] = None
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass

def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_layout_type(self) -> LayoutType:
pass

@classmethod
@@ -64,9 +89,15 @@ def from_plain(
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
pass

def __repr__(self):
int_data, scale, zero_point = self.get_plain()
layout_type = self.get_layout_type()
return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
extended_layout: str = "plain",
# TODO: this is only for "tensor_core_tiled", need to figure out
# the proper API for this arg
inner_k_tiles: Optional[int] = None,
layout_type: LayoutType = PlainLayoutType(),
):
original_shape = input_float.shape
if extended_layout == "tensor_core_tiled":
orig_out_features, orig_in_features = input_float.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input_float = torch.nn.functional.pad(
input_float,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
input_float = layout_type.pre_process(input_float)

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
int_data = layout_type.post_process(int_data)

layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
# TODO: this is temporary, need to come up with the proper UX
if extended_layout == "tensor_core_tiled":
layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
else:
layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
return cls(
layout_tensor,
block_size,
@@ -229,8 +247,8 @@ def from_float(
)
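
A standalone sketch of the refactored from_float control flow shown above: the layout object owns its pre/post-processing, so from_float no longer branches on a layout string. All Demo names and the stand-in quantization step are illustrative, not torchao APIs.

import torch

class DemoPlainLayout:
    def pre_process(self, x): return x
    def post_process(self, x): return x

class DemoTiledLayout(DemoPlainLayout):
    # pads to kernel-friendly shapes, mirroring TensorCoreTiledLayoutType.pre_process
    def pre_process(self, x):
        pad_in = (-x.shape[1]) % 1024
        pad_out = (-x.shape[0]) % 8
        return torch.nn.functional.pad(x, (0, pad_in, 0, pad_out))

def demo_from_float(w, layout_type=DemoPlainLayout()):
    w = layout_type.pre_process(w)
    # stand-in for choose_qparams_affine + quantize_affine
    int_data = torch.clamp(w.round(), -128, 127).to(torch.int8)
    return layout_type.post_process(int_data)

print(demo_from_float(torch.randn(4095, 11000), DemoTiledLayout()).shape)  # (4096, 11264)
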

@property
def extended_layout(self) -> str:
return self.layout_tensor.extended_layout
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def implements(aten_ops_or_torch_fn):
return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

def register_layout_cls(extended_layout: str):
return _register_layout_cls(AffineQuantizedTensor, extended_layout)
def register_layout_cls(layout_type_class: type(LayoutType)):
return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

def get_layout_tensor_constructor(extended_layout: str):
return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)
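
For context, a minimal sketch of what a registry keyed by LayoutType subclass (rather than by string) presumably looks like; the real _register_layout_cls and _get_layout_tensor_constructor are defined elsewhere in torchao and are not shown in this diff, so the Demo names below are illustrative only.

_DEMO_REGISTRY = {}

def demo_register_layout_cls(layout_type_class: type):
    def decorator(layout_tensor_class: type):
        # key by the LayoutType subclass itself instead of a magic string
        _DEMO_REGISTRY[layout_type_class] = layout_tensor_class
        return layout_tensor_class
    return decorator

def demo_get_layout_tensor_constructor(layout_type_class: type) -> type:
    return _DEMO_REGISTRY[layout_type_class]

class DemoPlainLayoutType: ...

@demo_register_layout_cls(DemoPlainLayoutType)
class DemoPlainLayoutTensor: ...

assert demo_get_layout_tensor_constructor(DemoPlainLayoutType) is DemoPlainLayoutTensor
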

@register_layout_cls("plain")
@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
"""
Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
@@ -330,6 +348,7 @@ def __new__(
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
kwargs = {}
kwargs["device"] = int_data.device
@@ -346,34 +365,39 @@ def __init__(
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], []
return ["int_data", "scale", "zero_point"], [self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
return cls(int_data, scale, zero_point)
layout_type, = tensor_attributes
return cls(int_data, scale, zero_point, layout_type)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.layout_type,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
self.layout_type,
)

@classmethod
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.int_data, self.scale, self.zero_point

def get_layout_type(self) -> LayoutType:
return self.layout_type

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
return cls(int_data, scale, zero_point)
assert isinstance(layout_type, PlainLayoutType)
return cls(int_data, scale, zero_point, layout_type)

@register_layout_cls("tensor_core_tiled")
@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
"""
Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +456,7 @@ def __new__(
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
kwargs = {}
kwargs["device"] = packed_weight.device
@@ -443,31 +473,40 @@ def __init__(
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = transposed
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["packed_weight", "scale_and_zero"], [self.transposed]
return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
transposed, = tensor_attributes
return cls(packed_weight, scale_and_zero, transposed)
transposed, layout_type, = tensor_attributes
return cls(packed_weight, scale_and_zero, transposed, layout_type)

@classmethod
def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType
):
assert isinstance(layout_type, TensorCoreTiledLayoutType)
# assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
# packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
return cls(packed_weight, scale_and_zero, False)
return cls(packed_weight, scale_and_zero, False, layout_type)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -477,18 +516,15 @@ def to(self, *args, **kwargs):
return self.__class__(
self.packed_weight.to(device),
self.scale_and_zero.to(device),
self.transposed
self.transposed,
self.layout_type,
)

def _apply_fn_to_data(self, fn):
self.packed_weight = fn(self.packed_weight)
self.scale_and_zero = fn(self.scale_and_zero)
return self

def __repr__(self):
int_data, scale, zero_point = self.get_plain()
return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
@@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
return int_data, scale, zero

def get_layout_type(self) -> LayoutType:
return self.layout_type
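
get_plain above recovers plain int data by dequantizing the packed weight and then re-quantizing it with the same parameters; below is a toy sketch of that round-trip idea using ordinary symmetric quantization (not the tinygemm packed format), with all names illustrative.

import torch

def toy_quantize(x, scale):
    return torch.clamp((x / scale).round(), -8, 7).to(torch.int8)

def toy_dequantize(q, scale):
    return q.to(torch.float32) * scale

x = torch.randn(4, 4)
scale = x.abs().max() / 7
packed = toy_quantize(x, scale)                              # stand-in for the packed form
recovered = toy_quantize(toy_dequantize(packed, scale), scale)
assert torch.equal(recovered, packed)                        # round trip recovers the int data
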

def _quantized_linear_op(input_tensor, weight_qtensor, bias):
"""
Quantized version of F.linear operator
@@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
is_cuda and
input_is_int8 and
input_tensor.dtype == weight_qtensor.dtype and
input_tensor.extended_layout == "plain" and
weight_qtensor.extended_layout == "plain"
isinstance(input_tensor.layout_type, PlainLayoutType) and
isinstance(weight_qtensor.layout_type, PlainLayoutType)
):
#
# 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.extended_layout == "tensor_core_tiled"
isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
):
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
weight_qtensor.extended_layout == "plain"
isinstance(weight_qtensor.layout_type, PlainLayoutType)
):
# TODO: enable cpu and mps efficient path
# per channel int8 weight only quantized mm
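
A compact sketch of the string-compare to isinstance change in the dispatch conditions above; the Demo classes are illustrative stand-ins and the returned path names are descriptive only.

class DemoLayoutType: pass
class DemoPlainLayoutType(DemoLayoutType): pass
class DemoTensorCoreTiledLayoutType(DemoLayoutType):
    def __init__(self, inner_k_tiles: int = 8):
        self.inner_k_tiles = inner_k_tiles

def demo_pick_path(layout_type: DemoLayoutType) -> str:
    # old: weight_qtensor.extended_layout == "tensor_core_tiled"
    # new: isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
    if isinstance(layout_type, DemoTensorCoreTiledLayoutType):
        return f"int4 tinygemm path (inner_k_tiles={layout_type.inner_k_tiles})"
    if isinstance(layout_type, DemoPlainLayoutType):
        return "plain int8 path"
    raise NotImplementedError(f"unsupported layout: {layout_type!r}")

print(demo_pick_path(DemoTensorCoreTiledLayoutType(inner_k_tiles=4)))
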