Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Initial commit containing new Triton kernels for multi lora serving. #5025

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions benchmarks/kernels/benchmark_sgmv_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch
import triton

from tests.lora.test_sgmv_triton import MAX_TEST_POWER, setup_
from vllm.model_executor.layers.lora import sgmv_triton as sgmv


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['S'], # argument names to use as an x-axis for the plot
x_vals=[16 * 2**i for i in range(3, 6)] +
[4096], # different possible values for `x_name`
line_arg=
'R', # argument name which corresponds to a different line in the plot
line_vals=[64, None], # possible values for `line_arg``
line_names=['Rank=64', f'Random Rank up to {2**MAX_TEST_POWER}'
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="ms", # label name for the y-axis
plot_name=
"sgmv", # name for the plot. Used as file name for saving the plot too.
args={},
))
def benchmark_repeats_expand(S, R, repeats_per_lora=1):
weights, w_start, ranks, w_locs, indices, repeats, _, R, dtype = setup_(
S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora)

buffer = torch.randn((S, R), device='cuda', dtype=torch.float32)
out = torch.randn((S, 4096), device='cuda', dtype=dtype)
ms = triton.testing.do_bench(lambda: sgmv.sgmv_expand(buffer,
weights,
out,
w_start,
ranks,
w_locs,
indices,
repeats,
out_col_offset=0),
warmup=500,
rep=4000)
return ms


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['S'], # argument names to use as an x-axis for the plot
x_vals=[16 * 2**i for i in range(3, 6)] +
[4096], # different possible values for `x_name`
line_arg=
'R', # argument name which corresponds to a different line in the plot
line_vals=[64, None], # possible values for `line_arg``
line_names=['Rank=64', f'Random Rank up to {2**MAX_TEST_POWER}'
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="ms", # label name for the y-axis
plot_name=
"sgmv", # name for the plot. Used as file name for saving the plot too.
args={},
))
def benchmark_repeats_shrink(S, R, repeats_per_lora=1):
weights, w_start, ranks, w_locs, indices, repeats, _, R, dtype = setup_(
S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora)

x = torch.rand((S, 4096), device='cuda', dtype=dtype)
out = torch.zeros((S, R), device='cuda', dtype=torch.float32)
ms = triton.testing.do_bench(lambda: sgmv.sgmv_shrink(
x, weights, out, w_start, ranks, w_locs, indices, repeats, R),
warmup=500,
rep=4000)
return ms


if __name__ == '__main__':
# NOTE: the random rank benchmark is random ranks up to 2^MAX_TEST_POWER,
# not random up to the rank specified,
# so it doesn't change when you change the rank you're testing
print('These benchmarks can vary a decent amount sometimes. ', end='')
print('They should be consistent across increasing seq length ')
print('(slower but with strong superlinear scaling), ', end='')
print('consistently faster with random ranks up to the same rank and')
print('faster with increasing repeats per lora.')
print('Times are in ms.')
print('-' * 40)
print('Expand | repeats [1]')
benchmark_repeats_expand.run(show_plots=False,
print_data=True,
repeats_per_lora=1)
print('-' * 40)
print('Shrink | repeats [1]')
benchmark_repeats_shrink.run(show_plots=False,
print_data=True,
repeats_per_lora=1)

print('-' * 40)
print('Expand | repeats [8]')
benchmark_repeats_expand.run(show_plots=False,
print_data=True,
repeats_per_lora=8)
print('-' * 40)
print('Shrink | repeats [8]')
benchmark_repeats_shrink.run(show_plots=False,
print_data=True,
repeats_per_lora=8)

# set repeats >= 16 for plaid mode
# (tl.dot is applicable which makes it fast)
print('-' * 40)
print('Expand | repeats [16]')
benchmark_repeats_expand.run(show_plots=False,
print_data=True,
repeats_per_lora=16)
print('-' * 40)
print('Shrink | repeats [16]')
benchmark_repeats_shrink.run(show_plots=False,
print_data=True,
repeats_per_lora=16)

print('-' * 40)
print('Expand | repeats [32]')
benchmark_repeats_expand.run(show_plots=False,
print_data=True,
repeats_per_lora=32)
print('-' * 40)
print('Shrink | repeats [32]')
benchmark_repeats_shrink.run(show_plots=False,
print_data=True,
repeats_per_lora=32)
print('-' * 40)
147 changes: 147 additions & 0 deletions tests/lora/test_sgmv_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import math

import pytest
import torch

from vllm.model_executor.layers.lora import sgmv_triton as sgmv

MAX_TEST_POWER = 6
SEED = 42


def assert_close(a, b, dtype, tl_dot=False):
rtol, atol = {
torch.float16: (5e-3, 5e-3) if not tl_dot else (1e-2, 7e-2),
torch.bfloat16: (3e-2, 2e-2) if not tl_dot else (3e-2, 1e-1),
torch.float32: (2e-3, 3e-4) if not tl_dot else (1e-2, 7e-2),
}[dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def setup_(S, R, H, dtype, repeats_per_lora=1):
S = math.ceil(S / repeats_per_lora) * repeats_per_lora
num_unique = S // repeats_per_lora
if R is None:
ranks = torch.randint(0, MAX_TEST_POWER, (S, ), device='cuda')
ranks = 2**ranks # random powers of 2 between [1, MAX_TEST_POWER]
R = 2**(MAX_TEST_POWER - 1)
else:
ranks = torch.full((S, ), R, device='cuda', dtype=torch.int32)
weights = torch.randn((S * R, H), device='cuda', dtype=dtype)
w_locs = torch.randint(0,
weights.shape[0], (ranks.sum().item(), ),
device='cuda')
w_start = torch.cat([
torch.tensor([
0,
], device='cuda', dtype=torch.int32),
ranks.cumsum(dim=-1)[:-1]
])
indices = torch.arange(num_unique, device='cuda')
repeats = torch.full((num_unique, ),
repeats_per_lora,
device='cuda',
dtype=torch.int32)
repeats = torch.cat([
torch.zeros((1, ), device='cuda', dtype=torch.int32),
repeats.cumsum(dim=-1)
])
return (weights, w_start, ranks, w_locs, indices, repeats, num_unique, R,
dtype)


@pytest.mark.parametrize("S", [16 * 2**i for i in range(3, 4)] + [4096])
@pytest.mark.parametrize("R", [2**r for r in range(MAX_TEST_POWER)])
@pytest.mark.parametrize("H", [64, 4096, 7491])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("repeats_per_lora", [1, 16])
@pytest.mark.parametrize("seed", [SEED])
@torch.inference_mode()
def test_correct(S, R, H, dtype, repeats_per_lora, seed):
torch.manual_seed(seed)
weights, w_start, ranks, w_locs, indices, repeats, num_unique, R, dtype = (
setup_(S, R, H, dtype, repeats_per_lora))

buffer = torch.randn((S, R), device='cuda', dtype=torch.float32)
out_col_offset = 128
out = torch.randn((S, H + out_col_offset), device='cuda', dtype=dtype)
ref_outs = []
for ui in range(num_unique):
idx = indices[ui]
w_rows = w_locs[w_start[idx]:w_start[idx] + ranks[idx]]
w = weights[w_rows].contiguous()
inp = buffer[repeats[ui]:repeats[ui + 1], :ranks[idx]].contiguous()
ref_out = inp.to(dtype=torch.float32) @ w.to(dtype=torch.float32)
ref_outs.append(ref_out)

ref_out = torch.cat(ref_outs, dim=0)
# doing this apparently leads to incorrect results in the first row
# + out[:, out_col_offset:]
ref_out += out[:, out_col_offset:].to(dtype=torch.float32)
# but this does not (likely depends on torch version)

# run the autotuner, add to a tmp output
sgmv.sgmv_expand(buffer,
weights,
torch.rand((S, H + out_col_offset),
device='cuda',
dtype=dtype),
w_start,
ranks,
w_locs,
indices,
repeats,
out_col_offset=out_col_offset)

sgmv.sgmv_expand(buffer,
weights,
out,
w_start,
ranks,
w_locs,
indices,
repeats,
out_col_offset=out_col_offset)

# diff = (ref_out - out[:, out_col_offset:].to(dtype=torch.float32)).abs()
# print(f'max diff {diff.max():0.5f}, mean {diff.mean():0.5f}')
# triton.language.dot, which is used for improved speed when
# rank and repeats are >= 16
# gives larger differences from torch
assert_close(ref_out,
out[:, out_col_offset:].to(dtype=torch.float32),
dtype=dtype,
tl_dot=repeats_per_lora >= 9)

x = torch.rand((S, H), device='cuda', dtype=dtype)
out = torch.zeros((S, R), device='cuda', dtype=torch.float32)
ref_outs = []
for ui in range(num_unique):
idx = indices[ui]
w_rows = w_locs[w_start[idx]:w_start[idx] + ranks[idx]]
w = weights[w_rows].contiguous()
inp = x[repeats[ui]:repeats[ui + 1]].contiguous()
ref_out = inp.to(dtype=torch.float32) @ w.to(dtype=torch.float32).T
ref_out = torch.cat([
ref_out,
torch.zeros((ref_out.shape[0], R - ref_out.shape[-1]),
dtype=ref_out.dtype,
device='cuda')
],
dim=-1)
ref_outs.append(ref_out)

ref_out = torch.cat(ref_outs, dim=0)
ref_out += out

# run the autotuner, add to a tmp output
sgmv.sgmv_shrink(x, weights, torch.rand_like(out), w_start, ranks, w_locs,
indices, repeats, R)

sgmv.sgmv_shrink(x, weights, out, w_start, ranks, w_locs, indices, repeats,
R)

# diff = (ref_out - out).abs()
# print(f'max diff {diff.max():0.5f}, mean {diff.mean():0.5f}')
assert_close(ref_out, out, dtype=dtype, tl_dot=repeats_per_lora >= 9)
Empty file.
Loading
Loading