Skip to content

Commit

Permalink
test: ensure float32 works as a dtype input for sharded
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Jan 29, 2025
1 parent e5025d1 commit 614a1fb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test/test_shards.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ def prod(x):
},
]

@pytest.mark.parametrize("dtype", [np.uint8, np.float32])
@pytest.mark.parametrize("scale", SCALES)
def test_sharded_image_bits(scale):
def test_sharded_image_bits(scale, dtype):
dataset_size = Vec(*scale["size"])
chunk_size = Vec(*scale["chunk_sizes"][0])

spec = create_sharded_image_info(
dataset_size=dataset_size,
chunk_size=chunk_size,
encoding=scale["encoding"],
dtype=np.uint8
dtype=dtype
)

shape = image_shard_shape_from_spec(
Expand Down

0 comments on commit 614a1fb

Please sign in to comment.