Skip to content

Commit

Permalink
【Hackathon 5th No.45】为torch.cuda.comm.scatter and torch.cuda.comm.gat…
Browse files Browse the repository at this point in the history
…her 添加test (#319)

add test for cuda_comm_scatter and cuda_comm_gather
  • Loading branch information
longranger2 authored Nov 6, 2023
1 parent 0e313a9 commit e8985bc
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 0 deletions.
71 changes: 71 additions & 0 deletions tests/test_cuda_comm_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.cuda.comm.gather")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
t1 = torch.randn((10, 3, 32, 32), device='cpu')
t2 = torch.randn((10, 3, 32, 32), device='cpu')
results = torch.cuda.comm.gather([t1, t2])
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
t1 = torch.randn((10, 3, 32, 32), device='cpu')
t2 = torch.randn((10, 3, 32, 32), device='cpu')
results = torch.cuda.comm.gather([t1, t2], dims=1)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
t1 = torch.randn((10, 3, 32, 32), device='cpu')
t2 = torch.randn((10, 3, 32, 32), device='cpu')
results = torch.cuda.comm.gather([t1, t2], destination='cpu')
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)
124 changes: 124 additions & 0 deletions tests/test_cuda_comm_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.cuda.comm.scatter")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
nhwc = torch.randn((10, 3, 32, 32), device='cpu')
results = torch.cuda.comm.scatter(tensor=nhwc)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
nhwc = torch.randn((10, 3, 32, 32), device='gpu')
results = torch.cuda.comm.scatter(tensor=nhwc)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
nhwc = torch.randn((10, 3, 32, 32), device='gpu')
devices = [torch.device('cuda:0'), torch.device('cuda:1')]
result = torch.cuda.comm.scatter(nhwc, devices=devices)
"""
)

obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
nhwc = torch.randn((10, 3, 32, 32), device='gpu')
chunk_sizes = [5, 5]
result = torch.cuda.comm.scatter(nhwc, chunk_sizes=chunk_sizes)
"""
)

obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
nhwc = torch.randn((10, 3, 32, 32), device='gpu')
result = torch.cuda.comm.scatter(nhwc, dim=1)
"""
)

obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
nhwc = torch.randn((10, 3, 32, 32), device='gpu')
t1 = torch.empty(5, 10, device='cuda')
t2 = torch.empty(5, 10, device='cuda')
result = torch.cuda.comm.scatter(nhwc, out=[t1, t2])
"""
)

obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle has no corresponding api tentatively",
)

0 comments on commit e8985bc

Please sign in to comment.