Skip to content

Commit

Permalink
Add tests for penalty coefficients
Browse files Browse the repository at this point in the history
  • Loading branch information
merajhashemi committed Jul 2, 2024
1 parent 180ba62 commit 37fad6f
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tests/multipliers/test_penalty_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,29 @@ def test_penalty_coefficient_init_and_forward(penalty_coefficient_class, num_con
penalty_coefficient = penalty_coefficient_class(init_tensor)
indices = torch.arange(num_constraints, dtype=torch.long)

assert torch.allclose(evaluate_penalty_coefficient(penalty_coefficient, indices), init_tensor)
assert torch.equal(evaluate_penalty_coefficient(penalty_coefficient, indices), init_tensor)


def test_penalty_coefficient_failure_with_grad(penalty_coefficient_class, num_constraints):
generator = testing_utils.frozen_rand_generator()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="PenaltyCoefficient should not require gradients."):
penalty_coefficient_class(torch.randn(num_constraints, requires_grad=True, generator=generator))


def test_penalty_coefficient_failure_with_wrong_shape(penalty_coefficient_class):
generator = testing_utils.frozen_rand_generator()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="init must either be a scalar or a 1D tensor of shape `(num_constraints,)`."):
penalty_coefficient_class(torch.randn(1, generator=generator).unsqueeze(0))


def test_indexed_penalty_coefficient_forward_invalid_indices(num_constraints):
multiplier = multipliers.IndexedPenaltyCoefficient(torch.randn(num_constraints))
indices = torch.arange(num_constraints, dtype=torch.float32)

with pytest.raises(ValueError, match="Indices must be of type torch.long."):
multiplier(indices)


def test_save_and_load_state_dict(penalty_coefficient_class, num_constraints):
generator = testing_utils.frozen_rand_generator()
init_tensor = torch.randn(num_constraints, generator=generator)
Expand All @@ -57,4 +65,4 @@ def test_save_and_load_state_dict(penalty_coefficient_class, num_constraints):
new_penalty_coefficient.load_state_dict(state_dict)
new_penalty_coefficient_value = evaluate_penalty_coefficient(new_penalty_coefficient, indices)

assert torch.allclose(new_penalty_coefficient_value, penalty_coefficient_value)
assert torch.equal(new_penalty_coefficient_value, penalty_coefficient_value)

0 comments on commit 37fad6f

Please sign in to comment.