Skip to content

Commit

Permalink
Remove sparse test collection warning (pytorch#489)
Browse files Browse the repository at this point in the history
* push

* remove sparse test warning

* push
  • Loading branch information
msaroufim authored Jul 9, 2024
1 parent 1479c7d commit 48f48a8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/sparsity/test_fast_sparse_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from torchao.utils import TORCH_VERSION_AFTER_2_4, is_fbcode

class TestModel(nn.Module):
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(128, 256, bias=False)
Expand All @@ -36,7 +36,7 @@ def test_runtime_weight_sparsification(self):
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
input = torch.rand((128, 128)).half().cuda()
grad = torch.rand((128, 128)).half().cuda()
model = TestModel().half().cuda()
model = ToyModel().half().cuda()
model_c = copy.deepcopy(model)

for name, mod in model.named_modules():
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_runtime_weight_sparsification_compile(self):
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
input = torch.rand((128, 128)).half().cuda()
grad = torch.rand((128, 128)).half().cuda()
model = TestModel().half().cuda()
model = ToyModel().half().cuda()
model_c = copy.deepcopy(model)

for name, mod in model.named_modules():
Expand Down

0 comments on commit 48f48a8

Please sign in to comment.