diff --git a/benchmarks/kernels/benchmark_sgmv_triton.py b/benchmarks/kernels/benchmark_sgmv_triton.py
new file mode 100644
index 0000000000000..9cacebdd29d5c
--- /dev/null
+++ b/benchmarks/kernels/benchmark_sgmv_triton.py
@@ -0,0 +1,128 @@
+import torch
+import triton
+
+from tests.lora.test_sgmv_triton import MAX_TEST_POWER, setup_
+from vllm.model_executor.layers.lora import sgmv_triton as sgmv
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['S'],  # argument names to use as an x-axis for the plot
+        # different possible values for `x_name`
+        x_vals=[16 * 2**i for i in range(3, 6)] + [4096],
+        # argument name which corresponds to a different line in the plot
+        line_arg='R',
+        line_vals=[64, None],  # possible values for `line_arg`
+        # label names for the lines
+        line_names=['Rank=64',
+                    f'Random Rank up to {2**(MAX_TEST_POWER - 1)}'],
+        styles=[('blue', '-'), ('green', '-')],  # line styles
+        ylabel="ms",  # label name for the y-axis
+        # name for the plot; also used as the file name when saving the plot
+        plot_name="sgmv",
+        args={},
+    ))
+def benchmark_repeats_expand(S, R, repeats_per_lora=1):
+    weights, w_start, ranks, w_locs, indices, repeats, _, R, dtype = setup_(
+        S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora)
+
+    buffer = torch.randn((S, R), device='cuda', dtype=torch.float32)
+    out = torch.randn((S, 4096), device='cuda', dtype=dtype)
+    ms = triton.testing.do_bench(
+        lambda: sgmv.sgmv_expand(buffer, weights, out, w_start, ranks, w_locs,
+                                 indices, repeats, out_col_offset=0),
+        warmup=500,
+        rep=4000)
+    return ms
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['S'],  # argument names to use as an x-axis for the plot
+        # different possible values for `x_name`
+        x_vals=[16 * 2**i for i in range(3, 6)] + [4096],
+        # argument name which corresponds to a different line in the plot
+        line_arg='R',
+        line_vals=[64, None],  # possible values for `line_arg`
+        # label names for the lines
+        line_names=['Rank=64',
+                    f'Random Rank up to {2**(MAX_TEST_POWER - 1)}'],
+        styles=[('blue', '-'), ('green', '-')],  # line styles
+        ylabel="ms",  # label name for the y-axis
+        # name for the plot; also used as the file name when saving the plot
+        plot_name="sgmv",
+        args={},
+    ))
+def benchmark_repeats_shrink(S, R, repeats_per_lora=1):
+    weights, w_start, ranks, w_locs, indices, repeats, _, R, dtype = setup_(
+        S, R, 4096, dtype=torch.bfloat16, repeats_per_lora=repeats_per_lora)
+
+    x = torch.rand((S, 4096), device='cuda', dtype=dtype)
+    out = torch.zeros((S, R), device='cuda', dtype=torch.float32)
+    ms = triton.testing.do_bench(
+        lambda: sgmv.sgmv_shrink(x, weights, out, w_start, ranks, w_locs,
                                 indices, repeats, R),
+        warmup=500,
+        rep=4000)
+    return ms
+
+
+if __name__ == '__main__':
+    # NOTE: the random rank benchmark uses random ranks up to
+    # 2**(MAX_TEST_POWER - 1), not random ranks up to the rank specified,
+    # so it doesn't change when you change the rank you're testing
+    print('These benchmarks can vary a decent amount sometimes. ', end='')
+    print('They should be consistent across increasing seq length ')
+    print('(slower but with strong superlinear scaling), ', end='')
+    print('consistently faster with random ranks up to the same rank and')
+    print('faster with increasing repeats per lora.')
+    print('Times are in ms.')
+    print('-' * 40)
+    print('Expand | repeats [1]')
+    benchmark_repeats_expand.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=1)
+    print('-' * 40)
+    print('Shrink | repeats [1]')
+    benchmark_repeats_shrink.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=1)
+
+    print('-' * 40)
+    print('Expand | repeats [8]')
+    benchmark_repeats_expand.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=8)
+    print('-' * 40)
+    print('Shrink | repeats [8]')
+    benchmark_repeats_shrink.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=8)
+
+    # set repeats >= 16 for plaid mode
+    # (tl.dot is applicable, which makes it fast)
+    print('-' * 40)
+    print('Expand | repeats [16]')
+    benchmark_repeats_expand.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=16)
+    print('-' * 40)
+    print('Shrink | repeats [16]')
+    benchmark_repeats_shrink.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=16)
+
+    print('-' * 40)
+    print('Expand | repeats [32]')
+    benchmark_repeats_expand.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=32)
+    print('-' * 40)
+    print('Shrink | repeats [32]')
+    benchmark_repeats_shrink.run(show_plots=False, print_data=True,
+                                 repeats_per_lora=32)
+    print('-' * 40)
diff --git a/tests/lora/test_sgmv_triton.py b/tests/lora/test_sgmv_triton.py
new file mode 100644
index 0000000000000..a98131dd0c925
--- /dev/null
+++ b/tests/lora/test_sgmv_triton.py
@@ -0,0 +1,147 @@
+import math
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.lora import sgmv_triton as sgmv
+
+MAX_TEST_POWER = 6
+SEED = 42
+
+
+def assert_close(a, b, dtype, tl_dot=False):
+    rtol, atol = {
+        torch.float16: (5e-3, 5e-3) if not tl_dot else (1e-2, 7e-2),
+        torch.bfloat16: (3e-2, 2e-2) if not tl_dot else (3e-2, 1e-1),
+        torch.float32: (2e-3, 3e-4) if not tl_dot else (1e-2, 7e-2),
+    }[dtype]
+    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
+
+
+def setup_(S, R, H, dtype, repeats_per_lora=1):
+    S = math.ceil(S / repeats_per_lora) * repeats_per_lora
+    num_unique = S // repeats_per_lora
+    if R is None:
+        ranks = torch.randint(0, MAX_TEST_POWER, (S, ), device='cuda')
+        ranks = 2**ranks  # random powers of 2 in [1, 2**(MAX_TEST_POWER - 1)]
+        R = 2**(MAX_TEST_POWER - 1)
+    else:
+        ranks = torch.full((S, ), R, device='cuda', dtype=torch.int32)
+    weights = torch.randn((S * R, H), device='cuda', dtype=dtype)
+    w_locs = torch.randint(0,
+                           weights.shape[0], (ranks.sum().item(), ),
+                           device='cuda')
+    w_start = torch.cat([
+        torch.tensor([0], device='cuda', dtype=torch.int32),
+        ranks.cumsum(dim=-1)[:-1]
+    ])
+    indices = torch.arange(num_unique, device='cuda')
+    repeats = torch.full((num_unique, ),
+                         repeats_per_lora,
+                         device='cuda',
+                         dtype=torch.int32)
+    repeats = torch.cat([
+        torch.zeros((1, ), device='cuda', dtype=torch.int32),
+        repeats.cumsum(dim=-1)
+    ])
+    return (weights, w_start, ranks, w_locs, indices, repeats, num_unique, R,
+            dtype)
+
+
+@pytest.mark.parametrize("S", [16 * 2**i for i in range(3, 4)] + [4096])
+@pytest.mark.parametrize("R", [2**r for r in range(MAX_TEST_POWER)])
+@pytest.mark.parametrize("H", [64, 4096, 7491])
+@pytest.mark.parametrize("dtype",
+                         [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("repeats_per_lora", [1, 16])
+@pytest.mark.parametrize("seed", [SEED])
+@torch.inference_mode()
+def test_correct(S, R, H, dtype, repeats_per_lora, seed):
+    torch.manual_seed(seed)
+    weights, w_start, ranks, w_locs, indices, repeats, num_unique, R, dtype = (
+        setup_(S, R, H, dtype, repeats_per_lora))
+
+    buffer = torch.randn((S, R), device='cuda', dtype=torch.float32)
+    out_col_offset = 128
+    out = torch.randn((S, H + out_col_offset), device='cuda', dtype=dtype)
+    ref_outs = []
+    for ui in range(num_unique):
+        idx = indices[ui]
+        w_rows = w_locs[w_start[idx]:w_start[idx] + ranks[idx]]
+        w = weights[w_rows].contiguous()
+        inp = buffer[repeats[ui]:repeats[ui + 1], :ranks[idx]].contiguous()
+        ref_out = inp.to(dtype=torch.float32) @ w.to(dtype=torch.float32)
+        ref_outs.append(ref_out)
+
+    ref_out = torch.cat(ref_outs, dim=0)
+    # computing `torch.cat(ref_outs, dim=0) + out[:, out_col_offset:]` in one
+    # expression apparently leads to incorrect results in the first row,
+    # but the in-place add below does not (likely depends on torch version)
+    ref_out += out[:, out_col_offset:].to(dtype=torch.float32)
+
+    # run the autotuner, add to a tmp output
+    sgmv.sgmv_expand(buffer, weights,
+                     torch.rand((S, H + out_col_offset),
+                                device='cuda',
+                                dtype=dtype),
+                     w_start, ranks, w_locs, indices, repeats,
+                     out_col_offset=out_col_offset)
+
+    sgmv.sgmv_expand(buffer, weights, out, w_start, ranks, w_locs, indices,
+                     repeats, out_col_offset=out_col_offset)
+
+    # diff = (ref_out - out[:, out_col_offset:].to(dtype=torch.float32)).abs()
+    # print(f'max diff {diff.max():0.5f}, mean {diff.mean():0.5f}')
+    # triton.language.dot, which is used for improved speed when
+    # rank and repeats are >= 16, gives larger differences from torch
+    assert_close(ref_out,
+                 out[:, out_col_offset:].to(dtype=torch.float32),
+                 dtype=dtype,
+                 tl_dot=repeats_per_lora >= 9)
+
+    x = torch.rand((S, H), device='cuda', dtype=dtype)
+    out = torch.zeros((S, R), device='cuda', dtype=torch.float32)
+    ref_outs = []
+    for ui in range(num_unique):
+        idx = indices[ui]
+        w_rows = w_locs[w_start[idx]:w_start[idx] + ranks[idx]]
+        w = weights[w_rows].contiguous()
+        inp = x[repeats[ui]:repeats[ui + 1]].contiguous()
+        ref_out = inp.to(dtype=torch.float32) @ w.to(dtype=torch.float32).T
+        ref_out = torch.cat([
+            ref_out,
+            torch.zeros((ref_out.shape[0], R - ref_out.shape[-1]),
+                        dtype=ref_out.dtype,
+                        device='cuda')
+        ], dim=-1)
+        ref_outs.append(ref_out)
+
+    ref_out = torch.cat(ref_outs, dim=0)
+    ref_out += out
+
+    # run the autotuner, add to a tmp output
+    sgmv.sgmv_shrink(x, weights, torch.rand_like(out), w_start, ranks, w_locs,
+                     indices, repeats, R)
+
+    sgmv.sgmv_shrink(x, weights, out, w_start, ranks, w_locs, indices, repeats,
+                     R)
+
+    # diff = (ref_out - out).abs()
+    # print(f'max diff {diff.max():0.5f}, mean {diff.mean():0.5f}')
+    assert_close(ref_out, out, dtype=dtype, tl_dot=repeats_per_lora >= 9)
diff --git a/vllm/model_executor/layers/lora/__init__.py b/vllm/model_executor/layers/lora/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/model_executor/layers/lora/sgmv_triton.py b/vllm/model_executor/layers/lora/sgmv_triton.py
new file mode 100644
index 0000000000000..ec24af5f82052
--- /dev/null
+++ b/vllm/model_executor/layers/lora/sgmv_triton.py
@@ -0,0 +1,340 @@
+import torch
+import triton
+import triton.language as tl
+
+# generally faster than 16, but can be lowered to 16 to reduce the
+# shared memory required by the kernel.
+MAX_REPEATS_PER_BLOCK = 32
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=1),
+        triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=2),
+        triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=4),
+        triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=1),
+        triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=2),
+        triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=4),
+        triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=1),
+        triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=2),
+        triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=4),
+    ],
+    key=['R', 'H'],
+)
+@triton.jit
+def sgmv_shrink_multi_lora_rank(
+        # Same arguments as below, some renamed
+        x_ptr,
+        w_ptr,
+        o_ptr,
+        w_start,
+        ranks,
+        w_locs,
+        indices,
+        repeats,
+        S,
+        R: tl.constexpr,
+        H,
+        stride_xs,
+        stride_xh,
+        stride_wp,
+        stride_wh,
+        stride_os,
+        stride_or,
+        # Meta-parameters
+        BLOCK_SIZE_INPUT_PER_LORA: tl.constexpr,
+        BLOCK_SIZE_H_OUT: tl.constexpr):
+    """
+    The shrink side of the lora, very similar implementation to expand, but
+    uses the split-k strategy as in punica.
+    """
+    # grid will be [num_unique, h out // block size h out]
+    lora_id, h_id = tl.program_id(axis=0), tl.program_id(axis=1)
+    idx = tl.load(indices + lora_id)
+    if idx < 0:
+        return
+    rank = tl.load(ranks + idx)
+    w_start_ = tl.load(w_start + idx)
+    repeats_0, repeats_1 = (tl.load(repeats + lora_id),
+                            tl.load(repeats + lora_id + 1))
+
+    n_inputs = repeats_1 - repeats_0
+    input_range = tl.arange(0, BLOCK_SIZE_INPUT_PER_LORA)
+    offs_xs = repeats_0 + input_range
+    rank_range = tl.arange(0, R)
+    offs_wp = tl.load(w_locs + w_start_ + rank_range)
+    offs_h = h_id * BLOCK_SIZE_H_OUT + tl.arange(0, BLOCK_SIZE_H_OUT)
+    offs_os = offs_xs
+
+    w_ptrs = (w_ptr + offs_wp[:, None] * stride_wp +
+              offs_h[None, :] * stride_wh)
+    w = tl.load(w_ptrs,
+                mask=(offs_h[None, :] < H) & (rank_range[:, None] < rank),
+                other=0.0).to(dtype=tl.float32)  # [R, H_OUT]
+
+    # tl.dot works only on sizes >= 16
+    if BLOCK_SIZE_INPUT_PER_LORA >= 16 and R >= 16:
+        x_ptrs = (x_ptr + offs_xs[:, None] * stride_xs +
+                  offs_h[None, :] * stride_xh)
+        # [next pow 2 inputs for this lora, H_OUT block]
+        x = tl.load(x_ptrs,
+                    mask=(input_range[:, None] < n_inputs) &
+                    (offs_h[None, :] < H),
+                    other=0.0).to(dtype=tl.float32)
+
+        o_ptrs = (o_ptr + offs_os[:, None] * stride_os +
+                  rank_range[None, :] * stride_or)
+        tl.atomic_add(o_ptrs,
+                      tl.dot(x, tl.trans(w)),
+                      mask=(input_range[:, None] < n_inputs) &
+                      (rank_range[None, :] < rank))
+    else:
+        for i in range(n_inputs):
+            x_ptrs = x_ptr + (repeats_0 + i) * stride_xs + offs_h * stride_xh
+            o_ptrs = (o_ptr + (repeats_0 + i) * stride_os +
+                      rank_range * stride_or)
+            x = tl.load(x_ptrs, mask=offs_h < H,
+                        other=0.0).to(dtype=tl.float32)
+            tl.atomic_add(o_ptrs,
+                          tl.sum(x[None, :] * w, axis=1),
+                          mask=rank_range < rank)
+
+
+@torch.inference_mode()
+def sgmv_shrink(x, weights, out, w_start, ranks, w_locs, indices, repeats,
+                max_rank):
+    # Check constraints.
+    assert weights.shape[-1] == x.shape[-1], (
+        "weight hidden dim must match x hidden dim: " +
+        f"weight shape {weights.shape}, x shape {x.shape}")
+    assert x.shape[0] == out.shape[0], (
+        "x shape at 0 differs from out shape at 0: x shape " +
+        f"{x.shape}, out shape {out.shape}")
+    assert max_rank >= ranks.max(), (
+        "ranks tensor includes a rank that is higher than the given max_rank")
+    assert x.is_contiguous(), "x must be contiguous"
+    assert weights.is_contiguous(), "Weights must be contiguous"
+    assert out.is_contiguous(), "Out must be contiguous"
+    S, H = x.shape
+    R = max_rank
+    assert triton.next_power_of_2(R) == R
+
+    BLOCK_SIZE_INPUT_PER_LORA = triton.next_power_of_2(
+        (repeats[1:] - repeats[:-1]).max().item())
+    # for load balancing and shared memory limitations
+    assert BLOCK_SIZE_INPUT_PER_LORA <= MAX_REPEATS_PER_BLOCK, (
+        "Exceeded the maximum number of repeats for a single lora. " +
+        "Repeats should be split into groups of size at most " +
+        f"{MAX_REPEATS_PER_BLOCK}")
+    grid = lambda META: (len(repeats) - 1,
+                         triton.cdiv(H, META['BLOCK_SIZE_H_OUT']))
+    sgmv_shrink_multi_lora_rank[grid](
+        x, weights, out, w_start, ranks, w_locs, indices, repeats, S, R, H,
+        x.stride(0), x.stride(1), weights.stride(0), weights.stride(1),
+        out.stride(0), out.stride(1),
+        BLOCK_SIZE_INPUT_PER_LORA=BLOCK_SIZE_INPUT_PER_LORA)
+    return out
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=1),
+        triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=2),
+        triton.Config({'BLOCK_SIZE_H_OUT': 32}, num_warps=4),
+        triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=1),
+        triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=2),
+        triton.Config({'BLOCK_SIZE_H_OUT': 64}, num_warps=4),
+        triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=1),
+        triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=2),
+        triton.Config({'BLOCK_SIZE_H_OUT': 128}, num_warps=4),
+    ],
+    key=['R', 'H'],
+)
+@triton.jit
+def sgmv_expand_multi_lora_rank(
+        # NOTE: Inputs MUST be grouped by lora
+        # Pointers to buffer, weight page and output respectively
+        b_ptr,
+        w_ptr,
+        o_ptr,
+        # indices: tensor of [num unique loras in seq]; for each group of
+        # inputs it holds that group's lora idx
+        # repeats: tensor of [num unique loras in seq + 1]; for group sizes r
+        # with sum(r) = seq_length, repeats = cumsum(r) starting at 0, i.e.
+        # the cumulative count of inputs that use the same lora
+        # w_locs holds the row indices for a [page_size, hidden]
+        # tensor in which the weights are stored;
+        # rows of weights for a lora are not necessarily contiguous
+        # lora is w_ptr[
+        #     w_locs[
+        #         w_start[indices[lora_id]]:
+        #         w_start[indices[lora_id]] + ranks[indices[lora_id]]]
+        # ]
+        w_start,
+        ranks,
+        w_locs,
+        indices,
+        repeats,
+        # optional output column offset
+        out_col_offset,
+        scale,
+        # Dimensions: sequence length/batch, max rank, hidden out
+        S,
+        R: tl.constexpr,
+        H,
+        # row, col stride for each
+        stride_bs,
+        stride_br,
+        stride_wp,
+        stride_wh,
+        stride_os,
+        stride_oh,
+        # Meta-parameters
+        BLOCK_SIZE_INPUT_PER_LORA: tl.constexpr,
+        BLOCK_SIZE_H_OUT: tl.constexpr):
+    """
+    The punica expand kernel in Triton. Can take advantage of tl.dot() for
+    increased speed when the rank and number of inputs are larger than 16,
+    i.e. prefill or grouped.
+    """
+    # grid will be [num_unique, h out // block size h out]
+    lora_id, h_id = tl.program_id(axis=0), tl.program_id(axis=1)
+    idx = tl.load(indices + lora_id)
+    if idx < 0:
+        return
+    rank = tl.load(ranks + idx)
+    w_start_ = tl.load(w_start + idx)
+    repeats_0, repeats_1 = (tl.load(repeats + lora_id),
+                            tl.load(repeats + lora_id + 1))
+
+    n_inputs = repeats_1 - repeats_0
+    input_range = tl.arange(0, BLOCK_SIZE_INPUT_PER_LORA)
+    offs_bs = repeats_0 + input_range
+    rank_range = tl.arange(0, R)
+    offs_wp = tl.load(w_locs + w_start_ + rank_range)
+    offs_wh = h_id * BLOCK_SIZE_H_OUT + tl.arange(0, BLOCK_SIZE_H_OUT)
+    offs_r = rank_range
+
+    w_ptrs = (w_ptr + offs_wp[:, None] * stride_wp +
+              offs_wh[None, :] * stride_wh)
+
+    offs_os = offs_bs
+    offs_oh = offs_wh
+
+    w = tl.load(w_ptrs,
+                mask=(offs_wh[None, :] < H) & (rank_range[:, None] < rank),
+                other=0.0).to(dtype=tl.float32)  # [R, H_OUT]
+
+    # tl.dot works only on sizes >= 16
+    if BLOCK_SIZE_INPUT_PER_LORA >= 16 and R >= 16:
+        b_ptrs = (b_ptr + offs_bs[:, None] * stride_bs +
+                  offs_r[None, :] * stride_br)
+        buffer = tl.load(b_ptrs,
+                         mask=(input_range[:, None] < n_inputs) &
+                         (rank_range[None, :] < rank),
+                         other=0.0)  # [next pow 2 inputs for this lora, R]
+        buffer *= scale
+
+        o_ptrs = (o_ptr + offs_os[:, None] * stride_os +
+                  (offs_oh[None, :] + out_col_offset) * stride_oh)
+        accumulator = tl.load(o_ptrs,
+                              mask=(input_range[:, None] < n_inputs) &
+                              (offs_oh[None, :] < H),
+                              other=0.0).to(dtype=tl.float32)
+        accumulator += tl.dot(buffer, w)
+
+        tl.store(o_ptrs,
+                 accumulator,
+                 mask=(input_range[:, None] < n_inputs) &
+                 (offs_oh[None, :] < H))
+    else:
+        for i in range(n_inputs):
+            b_ptrs = b_ptr + (repeats_0 + i) * stride_bs + offs_r * stride_br
+            o_ptrs = (o_ptr + (repeats_0 + i) * stride_os +
+                      (offs_oh + out_col_offset) * stride_oh)
+            out = tl.load(o_ptrs, mask=offs_oh < H,
+                          other=0.0).to(dtype=tl.float32)
+            buffer = tl.load(b_ptrs, mask=rank_range < rank,
+                             other=0.0).to(dtype=tl.float32)
+            buffer *= scale
+
+            out += tl.sum(buffer[:, None] * w, axis=0)
+            tl.store(o_ptrs, out, mask=offs_oh < H)
+
+
+@torch.inference_mode()
+def sgmv_expand(buffer,
+                weights,
+                out,
+                w_start,
+                ranks,
+                w_locs,
+                indices,
+                repeats,
+                out_col_offset=0,
+                scale=1.0):
+    # Check constraints.
+    assert ranks.max() <= buffer.shape[1], (
+        "Ranks argument includes a higher rank than the buffer's " +
+        f"second dim: max rank {ranks.max()}, buffer shape {buffer.shape}")
+    assert weights.shape[-1] <= out.shape[-1], (
+        "Weight hidden dim is greater than the output tensor hidden " +
+        f"dim: weight shape {weights.shape}, out shape {out.shape}")
+    assert buffer.shape[0] == out.shape[0], (
+        "Buffer shape at 0 differs from out shape at 0: " +
+        f"buffer shape {buffer.shape}, out shape {out.shape}")
+    assert out_col_offset + weights.shape[-1] <= out.shape[-1], (
+        f"Output column offset {out_col_offset} with output dim " +
+        f"{weights.shape[-1]} is too high for given output tensor {out.shape}")
+    assert buffer.is_contiguous(), "Buffer must be contiguous"
+    assert weights.is_contiguous(), "Weights must be contiguous"
+    assert out.is_contiguous(), "Out must be contiguous"
+    S, R = buffer.shape
+    H = weights.shape[-1]
+    assert triton.next_power_of_2(R) == R
+
+    BLOCK_SIZE_INPUT_PER_LORA = triton.next_power_of_2(
+        (repeats[1:] - repeats[:-1]).max().item())
+    # for load balancing and shared memory limitations
+    assert BLOCK_SIZE_INPUT_PER_LORA <= MAX_REPEATS_PER_BLOCK, (
+        "Exceeded the maximum number of repeats for a single lora. " +
" + + "Repeats should be split into groups of size at most " + + f"{MAX_REPEATS_PER_BLOCK}") + grid = lambda META: (len(repeats) - 1, + triton.cdiv(H, META['BLOCK_SIZE_H_OUT'])) + sgmv_expand_multi_lora_rank[grid]( + buffer, + weights, + out, + w_start, + ranks, + w_locs, + indices, + repeats, + out_col_offset, + scale, + S, + R, + H, + buffer.stride(0), + buffer.stride(1), + weights.stride(0), + weights.stride(1), + out.stride(0), + out.stride(1), + BLOCK_SIZE_INPUT_PER_LORA=BLOCK_SIZE_INPUT_PER_LORA) + return out