[sharktank] Add common model config, export and compile
We don't have a standard way to configure, export and compile sharktank
models.

This change introduces such a mechanism; as a demonstration, the CLIP text
model is refactored to use the new approach.

`config.json`:
```
{
    "model_type": "MyModel",
    "mlir_path": "model.mlir",
    "parameters_path": "model.irpa",
    "iree_module_path": "model.vmfb",
    "compile_args": ["--iree-hal-target-device=local"],
    "export_functions": [
        {
            "function": "forward",
            "batch_sizes": [1, 2, 3]
        }
    ]
}
```

Usage:
```
model = create_model("config.json")
model.export()
model.compile()
```
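
For context, here is a sketch of what a model must provide to participate in
this flow. `MyModel` and its tensor shapes are hypothetical, not part of this
change:
```python
from collections import OrderedDict

import torch

from sharktank.layers import BaseLayer


class MyModel(BaseLayer):
    # Subclassing BaseLayer auto-registers the class in model_registry
    # (via BaseLayerMetaClass), so create_model can dispatch on the
    # config's "model_type" field.
    def __init__(self, config=None):
        super().__init__(config=config)
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

    def sample_inputs(self, batch_size=1, function=None):
        # Sample inputs drive torch.export for each exported function.
        return tuple(), OrderedDict([("x", torch.empty(batch_size, 8))])
```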
sogartar committed Feb 27, 2025
1 parent 3273c83 commit f2850be
Showing 18 changed files with 822 additions and 219 deletions.
14 changes: 11 additions & 3 deletions sharktank/sharktank/export.py
@@ -11,7 +11,8 @@
from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
from .types.tensors import ShardedTensor
from .layers import BaseLayer
from .layers import BaseLayer, ThetaLayer
from .types.theta import mark_export_external_theta
from torch.utils._pytree import PyTree, _is_leaf
import functools

@@ -177,9 +178,10 @@ def flat_fn(*args, **kwargs):
assert False, "TODO: implement the case when not using an FxProgramsBuilder"


def export_static_model_mlir(
def export_model_mlir(
model: BaseLayer,
output_path: PathLike,
*,
function_batch_size_pairs: Optional[dict[Optional[str], list[int]]] = None,
batch_sizes: Optional[list[int]] = None,
):
@@ -199,6 +201,9 @@ def export_static_model_mlir(

assert not (function_batch_size_pairs is not None and batch_sizes is not None)

if isinstance(model, ThetaLayer):
mark_export_external_theta(model.theta)

if batch_sizes is not None:
function_batch_size_pairs = {None: batch_sizes}

@@ -210,12 +215,15 @@
for function, batch_sizes in function_batch_size_pairs.items():
for batch_size in batch_sizes:
args, kwargs = model.sample_inputs(batch_size, function)
dynamic_shapes = model.dynamic_shapes_for_export(
batch_size=batch_size, function=function
)

@fxb.export_program(
name=f"{function or 'forward'}_bs{batch_size}",
args=args,
kwargs=kwargs,
dynamic_shapes=None,
dynamic_shapes=dynamic_shapes,
strict=False,
)
def _(model, **kwargs):
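The reworked entry point can also be driven directly. A minimal sketch, assuming `model` is a `BaseLayer` that implements `sample_inputs`; the function names and batch sizes here are illustrative:
```python
from sharktank.export import export_model_mlir

# Arguments after the new `*` marker are keyword-only; a function key of
# None falls back to "forward" when naming the exported program.
export_model_mlir(
    model,
    "model.mlir",
    function_batch_size_pairs={"forward": [1, 4], "encode": [2]},
)
```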
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .base import BaseLayer, ThetaLayer
from .base import *
from .conv import Conv2DLayer
from .kv_cache import PagedKVCache
from .causal_llm import BaseCausalLMModel
236 changes: 226 additions & 10 deletions sharktank/sharktank/layers/base.py
@@ -4,19 +4,23 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict, Optional
from typing import Any, Dict, Optional
from collections import OrderedDict
from collections.abc import Mapping
from abc import ABCMeta
import torch
import torch.nn as nn
from os import PathLike

from ..types import InferenceTensor, Theta, AnyTensor
from ..types import InferenceTensor, Theta, AnyTensor, Dataset
from ..utils import debugging
from .. import ops
from .configs import ModelConfig, ExportFunctionConfig, DynamicBatchSize

__all__ = [
"BaseLayer",
"ThetaLayer",
"create_model",
"get_model_type_id",
"model_registry",
]


@@ -33,12 +37,80 @@ def _set_recursively_submodules_default_trace_tensor_key_prefix(
)


class BaseLayer(nn.Module):
def get_model_type_id(model_type: type["BaseLayer"]) -> str:
"""Get a string representation of the model type."""
return f"{model_type.__module__}.{model_type.__name__}"


def create_model(config: PathLike | ModelConfig, /) -> "BaseLayer":
"""
Create a model from a configuration.
Example
config.json:
```
{
"model_type": "MyModel",
"mlir_path": "model.mlir",
"parameters_path": "model.irpa",
"iree_module_path": "model.vmfb",
"compile_args": ["--iree-hal-target-device=local"],
"export_functions": [
{
"function": "forward",
"batch_sizes": [1, 2, 3]
}
]
}
```
Usage:
```
model = create_model("config.json")
model.export()
model.compile()
```
"""
if not isinstance(config, ModelConfig):
config = ModelConfig.load(config)

return config.model_type.from_config(config)


model_registry: dict[str, type["BaseLayer"]] = {}
"""Registry of all model types.
This is used to dispatch when constructing a model from a config."""


class BaseLayerMetaClass(ABCMeta):
def __init__(cls, clsname, bases, methods):
super().__init__(clsname, bases, methods)
model_registry[get_model_type_id(cls)] = cls


class BaseLayer(nn.Module, metaclass=BaseLayerMetaClass):
"""Base class of all of our layers."""

def __init__(self):
def __init__(self, config: ModelConfig | None = None):
super().__init__()
self._trace_tensor_key_prefix = ""
self.config = config

# Can be overridden in derived classes.
self.default_export_function = "forward"
self.default_export_batch_sizes = [1]

@classmethod
def from_config(cls, config: ModelConfig, /) -> "BaseLayer":
"""Create a model from config.
Override in derived classes if any special behavior is desired."""
return cls(config=config)

@classmethod
def config_type(cls) -> type[ModelConfig]:
"""Return the type of the config for this model."""
raise NotImplementedError()

def set_recursively_submodules_default_trace_tensor_key_prefix(self):
"""All submodules get a trace key prefix that reflects their nesting with
@@ -107,8 +179,8 @@ def assert_not_nan(self, *ts: torch.Tensor):
raise AssertionError(f"Tensor contains nans! {t}")

def sample_inputs(
self, batch_size: int = 1, function: Optional[str] = None
) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]:
self, batch_size: int | None = 1, function: Optional[str] = None
) -> tuple[tuple[AnyTensor, ...], OrderedDict[str, AnyTensor]]:
"""Return sample inputs that can be used to run the function from the model.
If function is None then the layer itself is treated as the callable.
E.g.
@@ -121,15 +193,159 @@ def sample_inputs(
"""
raise NotImplementedError()

def dynamic_shapes_for_export(
self,
batch_size: int | DynamicBatchSize | None = 1,
function: Optional[str] = None,
) -> Dict[str, Any] | tuple[Any, ...] | list[Any] | None:
"""During export the result is directly passed to the underlying export function."""
return None
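# Illustration only, not part of this commit: a subclass might override
# dynamic_shapes_for_export to mark the batch dimension of an "x" kwarg
# as dynamic, mirroring the structure its sample_inputs returns:
#   batch = torch.export.Dim("batch", max=64)
#   return {"x": {0: batch}}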

def export_mlir(self, path: PathLike | None = None, /):
"""Export the model into MLIR format.
Exporting is driven by the model's configuration.
Can be overridden in derived classes."""

if path is None:
path = self.config.mlir_path
if path is None:
raise ValueError("Missing MLIR export path.")

export_functions = [
ExportFunctionConfig(
function=self.default_export_function,
batch_sizes=self.default_export_batch_sizes,
)
]
if self.config.export_functions is not None:
export_functions = self.config.export_functions

function_batch_size_pairs = {
export_function.function
or self.default_export_function: export_function.batch_sizes
or self.default_export_batch_sizes
for export_function in export_functions
}
from ..export import export_model_mlir

export_model_mlir(
model=self,
output_path=path,
function_batch_size_pairs=function_batch_size_pairs,
)

def export(self, mlir_path: PathLike | None = None, /, *args, **kwargs):
"""Export MLIR and any other artifacts required for compilation.
Can be overridden in derived classes."""
self.export_mlir(mlir_path)

def compile(self, output_path: PathLike | None = None, /):
"""Compile the model.
Does not auto-export; the model must be exported first."""
if output_path is None:
output_path = self.config.iree_module_path
if output_path is None:
raise ValueError("Missing compile output path.")

from iree.compiler import compile_file

compile_file(
self.config.mlir_path,
output_file=output_path,
extra_args=self.config.get_compile_args(),
)


class ThetaLayer(BaseLayer):
"Base class for layers that derive parameters from a Theta object."

def __init__(self, theta: Theta):
super().__init__()
def __init__(self, theta: Theta | None = None, config: ModelConfig | None = None):
super().__init__(config=config)
if theta is None:
theta = self.load_theta()
if theta is None:
theta = self.generate_random_theta()
self.theta = theta

def theta_tensor(self, name: str) -> InferenceTensor:
# TODO: We may need to do some bookkeeping here to ensure export
# tracks all of these.
return self.theta.tensor(name)

def shard_theta(self, theta: Theta) -> Theta:
"""Override to implement theta sharding.
This default implementation supports only the trivial case of no sharding."""
if (
self.config.tensor_parallelism is not None
and self.config.tensor_parallelism != 1
):
raise ValueError(
"Theta sharding for model "
f"{get_model_type_id(self.__class__)} is not supported."
)
return theta

def load_theta(self) -> Theta | None:
"""Load a theta if it exists.
The source can be an IRPA/GGUF parameters file or a Hugging Face model."""
assert self.config is not None

needs_sharding = True

parameters_path = self.config.parameters_path
if parameters_path is not None:
dataset = Dataset.load(parameters_path)
theta = dataset.root_theta
tensor_parallelism = dataset.properties.get("tensor_parallelism", 1)
if (
tensor_parallelism != 1
and tensor_parallelism != self.config.tensor_parallelism
):
raise ValueError(
"Could not shard theta that is already sharded "
"with different tensor_parallelism. "
f"Desired is {self.config.tensor_parallelism}, "
f"actual is {tensor_parallelism}"
)
needs_sharding = tensor_parallelism != self.config.tensor_parallelism
elif self.config.hugging_face_repo_id is not None:
theta = self.load_theta_from_hugging_face()
else:
return None

if needs_sharding:
theta = self.shard_theta(theta)

return theta

def load_theta_from_hugging_face(self) -> Theta:
"""Override to load a theta form Hugging Face."""
raise NotImplementedError()

def generate_random_theta(self) -> Theta:
"""Initialize a theta with random contents.
Generation should respect model configuration options such as rng_seed.
Override in derived classes."""
raise NotImplementedError()

def export_parameters(self, path: PathLike | None = None, /):
"Export model parameters (includes the theta) into an IRPA/GGUF file."
if path is None:
path = self.config.parameters_path
if path is None:
raise ValueError("Missing model parameters export path.")

properties = self.config.asdict_for_saving()
dataset = Dataset(properties=properties, root_theta=self.theta)
dataset.save(path)

def export(
self,
mlir_path: PathLike | None = None,
parameters_path: PathLike | None = None,
/,
*args,
**kwargs,
):
super().export(mlir_path)
self.export_parameters(parameters_path)
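Putting the ThetaLayer pieces together, a minimal sketch of the intended flow; the config file and explicit output paths are illustrative:
```python
from sharktank.layers import create_model

# On construction a ThetaLayer resolves its theta in order:
# parameters_path -> hugging_face_repo_id -> random initialization.
model = create_model("config.json")

# ThetaLayer.export writes both the MLIR and an IRPA parameters file;
# compile then reads config.mlir_path and uses config.get_compile_args().
model.export("model.mlir", "model.irpa")
model.compile("model.vmfb")
```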
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/configs/__init__.py
@@ -5,3 +5,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .llm_configs import *
from .config import *
(Diffs for the remaining changed files are not shown.)
