# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.transformer.inference.op_binding.gated_activation import GatedActivationOp
from deepspeed.utils.types import ActivationFuncType
from .inference_test_utils import allclose, get_dtypes

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system", allow_module_level=True)
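
# These tests compare DeepSpeed's fused bias + gated-activation inference op
# (GatedActivationOp) against a plain PyTorch reference that upcasts the gate
# to float32, for both GEGLU (GeLU gating) and gated SiLU.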


def run_bias_geglu_reference(activations, bias):
    # Expected behavior is that of casting to float32 internally
    # Explicitly using the default GeLU
    activations = activations + bias.reshape(1, 1, -1)
    # Split the doubled channel dimension into a value half and a gate half
    hidden_states, gate = activations.chunk(2, dim=-1)
    return hidden_states * torch.nn.functional.gelu(gate.to(torch.float32)).to(activations.dtype)


def run_bias_geglu_ds(activation, bias):
    return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_GELU)


@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_bias_geglu(batch, sequence, channels, dtype):
    activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name())
    bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name())

    ds_out = run_bias_geglu_ds(activation, bias)
    ref_out = run_bias_geglu_reference(activation, bias)
    assert allclose(ds_out, ref_out)
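
# The gated-SiLU path mirrors the GEGLU checks above, swapping GeLU for SiLU
# as the gating nonlinearity (ActivationFuncType.GATED_SILU).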


def run_gated_silu_reference(activations, bias):
    # Expected behavior is that of casting to float32 internally
    # Explicitly using SiLU rather than GeLU for the gate
    activations = activations + bias.reshape(1, 1, -1)
    hidden_states, gate = activations.chunk(2, dim=-1)
    return hidden_states * torch.nn.functional.silu(gate.to(torch.float32)).to(activations.dtype)


def run_gated_silu_ds(activation, bias):
    return GatedActivationOp()(activation, bias, ActivationFuncType.GATED_SILU)


@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_gated_silu(batch, sequence, channels, dtype):
    activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device=get_accelerator().device_name())
    bias = torch.randn((channels * 2), dtype=dtype, device=get_accelerator().device_name())

    ds_out = run_gated_silu_ds(activation, bias)
    ref_out = run_gated_silu_reference(activation, bias)
    assert allclose(ds_out, ref_out)
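
# Usage note (a sketch, assuming the repo's usual pytest configuration
# registers the `inference_ops` marker): these tests can be selected with,
# for example,
#   pytest -m inference_ops test_bias_geglu.py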