Refactor Tests #91

Merged 105 commits on Jul 8, 2024

Commits
4aee1d6
Implement new `SquaredNormLinearCMP` for tests
merajhashemi May 1, 2024
efbbe85
Refactor old convergence tests to use new cmp
merajhashemi May 1, 2024
becceaa
Fix augmented lagrangian convergence test
merajhashemi May 1, 2024
c49f65d
Fix line break style
merajhashemi May 1, 2024
4485eb1
Remove hardcoded constants
merajhashemi May 15, 2024
a5f8b53
Rename variables
merajhashemi May 15, 2024
ff8c2d3
Add observed constraint feature size as parameter
merajhashemi May 15, 2024
3cfbab1
Randomize convergence test
merajhashemi May 15, 2024
8a5add0
Fix type hint
merajhashemi May 16, 2024
7f9cdcf
Add cvxpy to requirements
merajhashemi May 16, 2024
c740a30
Fix randomized convergence test
merajhashemi May 16, 2024
570d38c
Fix stochasticity in noise
merajhashemi May 18, 2024
89a9e3e
Add optimality check for cvxpy
merajhashemi May 18, 2024
8808a0f
Add check to skip overconstrained settings
merajhashemi May 18, 2024
eb501d8
Add manual tests
merajhashemi May 20, 2024
e76b085
Fix tests augmented lagrangian
merajhashemi May 21, 2024
437244f
Fix manual indexed multiplier tests
merajhashemi May 22, 2024
d0fa38d
Fix manual surrogate tests
merajhashemi May 26, 2024
e359825
Add manual test for equality constraints
merajhashemi Jun 3, 2024
6ecc2a2
Remove unused test util functions
merajhashemi Jun 3, 2024
da26c90
Move common fixtures to conftest.py
merajhashemi Jun 4, 2024
0af638f
Fix checkpoint test
merajhashemi Jun 4, 2024
6cc912e
Add test for cmp state_dict
merajhashemi Jun 4, 2024
8dcedb8
Remove `Toy2dCMP`
merajhashemi Jun 4, 2024
2600dc9
Remove unused utils for tests
merajhashemi Jun 4, 2024
d2a0549
Add tests for multiple primal optimizers
merajhashemi Jun 4, 2024
29b3cdd
Move fixtures from utils to conftest
merajhashemi Jun 5, 2024
13d679a
Fix tests for multiple primal optimizers
merajhashemi Jun 5, 2024
01c2528
Skip convergence test for equality constraint
merajhashemi Jun 5, 2024
ef299f3
Add convergence test for no constraints
merajhashemi Jun 5, 2024
c7796c6
Fix convergence test for surrogate
merajhashemi Jun 7, 2024
a028e68
Fix manual test for surrogate
merajhashemi Jun 10, 2024
1c78b70
Adjust hyperparameters for convergence test
merajhashemi Jun 10, 2024
4db784f
Fix tests for multiple primal optimizers
merajhashemi Jun 12, 2024
650b906
Updates to tests
juan43ramirez Jun 17, 2024
11c5de4
pipeline test updates
juan43ramirez Jun 17, 2024
3422f86
Revert "pipeline test updates"
merajhashemi Jun 17, 2024
60189da
(Fix tests on GPU): Ensure CPU and GPU create consistent tensors
merajhashemi Jun 17, 2024
97ebd1d
Fix style
merajhashemi Jun 18, 2024
3b7c7b1
Add expected failure test for formulation
merajhashemi Jun 18, 2024
756d6fb
Revert "Fix style"
merajhashemi Jun 18, 2024
4ba95ad
Remove extra line
merajhashemi Jun 20, 2024
e2e69f3
(Fix test setup): Generate random full-rank matrices with uniform sin…
merajhashemi Jun 20, 2024
9aec4b8
Simplify primal_lr selection for new tests
merajhashemi Jun 20, 2024
2ccb572
Fix primal_lr for Adam
merajhashemi Jun 23, 2024
b5a8a2e
Add docstring
merajhashemi Jun 23, 2024
23fe34f
Fix comment
merajhashemi Jun 23, 2024
c808461
Parameterize surrogate constraint noise magnitude
merajhashemi Jun 23, 2024
e1c891c
Add type hints
merajhashemi Jun 23, 2024
f2f57c7
Remove unnecessary function
merajhashemi Jun 23, 2024
765573e
Add comment
merajhashemi Jun 23, 2024
4e97452
Remove unnecessary imports
merajhashemi Jun 23, 2024
89e8c4c
Add docstring
merajhashemi Jun 24, 2024
1ee2e2e
Rename fixture
merajhashemi Jun 24, 2024
af4444a
Fix checkpoint test with multiple primal optimizers
merajhashemi Jun 24, 2024
62a4c48
Add docstring
merajhashemi Jun 24, 2024
442a2db
Improve code readability
merajhashemi Jun 24, 2024
53b2837
Tighter tolerance for manual tests
merajhashemi Jun 24, 2024
c9acfaa
Add seed fixture to create different tests
merajhashemi Jun 24, 2024
6397a50
Improve code readability
merajhashemi Jun 24, 2024
f2439a6
Rename variables for clarity
merajhashemi Jun 26, 2024
295ccdc
Rename variables for clarity
merajhashemi Jun 27, 2024
99c3015
Split tests in separate files
merajhashemi Jun 27, 2024
f21e504
Refactor setup steps in separate fixtures
merajhashemi Jun 27, 2024
f43734d
Move test to separate module
merajhashemi Jun 27, 2024
ee40cd7
Refactor setup steps in separate fixtures
merajhashemi Jun 27, 2024
e8d4155
Move fixtures to conftest module
merajhashemi Jun 27, 2024
435de8f
Move manual test module
merajhashemi Jun 27, 2024
22f6cd8
Fix code repetition
merajhashemi Jun 27, 2024
f9515dc
Add helper function for extracting constraint features from cmpstate
merajhashemi Jun 27, 2024
98cb04a
Use fixtures
merajhashemi Jun 27, 2024
728bea3
Add tests for `LagrangianStore`
merajhashemi Jun 27, 2024
670747b
Remove importlib for python 3.7 and older
merajhashemi Jun 28, 2024
94c6ca0
Add functional tests for cmp module
merajhashemi Jun 29, 2024
8e87eed
Remove unnecessary line
merajhashemi Jun 29, 2024
a5e31a7
Refactor conftest fixtures
merajhashemi Jun 29, 2024
be99380
Fix init for scalar penalty coefficient
merajhashemi Jun 29, 2024
7e1f9ea
Add functional tests for penalty coefficient updater
merajhashemi Jun 30, 2024
e5871ff
Remove unnecessary line when required positional argument isn't provided
merajhashemi Jul 1, 2024
180ba62
Add tests for multipliers
merajhashemi Jul 1, 2024
37fad6f
Add tests for penalty coefficients
merajhashemi Jul 2, 2024
bc32603
Add tests for constraint
merajhashemi Jul 3, 2024
da51c05
Add tests for constraint_state
merajhashemi Jul 3, 2024
5571335
Refactor move test_formulation
merajhashemi Jul 3, 2024
1efd124
Fix type hints
merajhashemi Jul 4, 2024
27ffc3a
Fix style
merajhashemi Jul 4, 2024
8e08adf
Fix penalty coefficient for scalar
merajhashemi Jul 4, 2024
dd16cb4
Add formulation_utils test
merajhashemi Jul 4, 2024
1b973fd
Ensure tensors are on same device
merajhashemi Jul 4, 2024
91ab48b
Improve multiplier & penalty coefficients design
merajhashemi Jul 4, 2024
6503e7a
Fix device check in test
merajhashemi Jul 4, 2024
5018079
Add more tests for formulation_utils
merajhashemi Jul 4, 2024
3ccbf46
Add more tests for multipliers
merajhashemi Jul 5, 2024
2028262
Add more tests for penalty coefficients
merajhashemi Jul 5, 2024
d87e133
Add more tests for constraint_state
merajhashemi Jul 5, 2024
a08a46e
Add more tests for formulations
merajhashemi Jul 5, 2024
d695156
Add tests for cmp
merajhashemi Jul 6, 2024
c552070
Fix typo
merajhashemi Jul 6, 2024
874ec27
Add tests for optimizers
merajhashemi Jul 5, 2024
ca94269
Undo change to ensure_sequence util function
merajhashemi Jul 8, 2024
1264ad3
Remove unused variable
merajhashemi Jul 8, 2024
488edc3
Rename fixture
merajhashemi Jul 8, 2024
2df71d1
Fix typo
merajhashemi Jul 8, 2024
1074ba2
Replace repetition with loop
merajhashemi Jul 8, 2024
3ee14f9
Remove code repetition for penalty coefficient sanity_check
merajhashemi Jul 8, 2024
7 changes: 1 addition & 6 deletions cooper/__init__.py
@@ -1,11 +1,6 @@
 """Top-level package for Cooper."""
 
-import sys
-
-if sys.version_info >= (3, 8):
-    from importlib.metadata import PackageNotFoundError, version
-else:
-    from importlib_metadata import PackageNotFoundError, version
+from importlib.metadata import PackageNotFoundError, version
 
 try:
     __version__ = version("cooper")
10 changes: 9 additions & 1 deletion cooper/cmp.py
@@ -113,6 +113,14 @@ def observed_strict_violations(self):
         for constraint_state in self.observed_constraints.values():
             yield constraint_state.strict_violation
 
+    def observed_constraint_features(self):
+        for constraint_state in self.observed_constraints.values():
+            yield constraint_state.constraint_features
+
+    def observed_strict_constraint_features(self):
+        for constraint_state in self.observed_constraints.values():
+            yield constraint_state.strict_constraint_features
+
 
 class ConstrainedMinimizationProblem(abc.ABC):
     """Template for constrained minimization problems."""
@@ -177,7 +185,7 @@ def named_penalty_coefficients(self) -> Iterator[tuple[str, PenaltyCoefficient]]
             if constraint.penalty_coefficient is not None:
                 yield constraint_name, constraint.penalty_coefficient
 
-    def dual_parameters(self) -> Iterator[Multiplier]:
+    def dual_parameters(self) -> Iterator[torch.nn.Parameter]:
         """Return an iterator over the parameters of the multipliers associated with the
         registered constraints of the CMP. This method is useful for instantiating the
         dual optimizers. If a multiplier is shared by several constraints, we only
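Note: the two new generators mirror `observed_strict_violations` above: walk the `observed_constraints` dict and yield one field per `ConstraintState`. A minimal consumption sketch (hedged: the `CMPState`/`ConstraintState` keyword arguments are inferred from the diffs in this PR, and `constraint` stands in for an already-built cooper constraint):

import torch
import cooper

# Hypothetical state with one observed constraint; the keyword arguments
# are assumptions based on the fields visible in this PR.
state = cooper.CMPState(
    loss=torch.tensor(1.0),
    observed_constraints={
        constraint: cooper.ConstraintState(
            violation=torch.tensor([0.3, -0.1]),
            constraint_features=torch.tensor([0, 2], dtype=torch.long),
        )
    },
)
# Each generator yields one tensor per observed constraint.
features = list(state.observed_constraint_features())  # [tensor([0, 2])]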
21 changes: 2 additions & 19 deletions cooper/constraints/constraint.py
@@ -1,7 +1,5 @@
 from typing import Literal, Optional, Type
 
-import torch
-
 from cooper.constraints.constraint_state import ConstraintState
 from cooper.constraints.constraint_type import ConstraintType
 from cooper.formulations import ContributionStore, Formulation, LagrangianFormulation
@@ -26,25 +24,10 @@ def __init__(
         self.formulation = formulation_type(constraint_type=self.constraint_type)
 
         self.multiplier = multiplier
-        if multiplier.constraint_type != self.constraint_type:
-            raise ValueError(
-                f"Attempted to pair {self.constraint_type} constraint, with {multiplier.constraint_type} multiplier."
-            )
-        self.multiplier.sanity_check()
+        self.multiplier.set_constraint_type(constraint_type)
 
         self.penalty_coefficient = penalty_coefficient
-        self.sanity_check_penalty_coefficient()
-
-    def sanity_check_penalty_coefficient(self) -> None:
-        if self.formulation.expects_penalty_coefficient:
-            if self.penalty_coefficient is None:
-                raise ValueError(f"{self.formulation_type} expects a penalty coefficient but none was provided.")
-            else:
-                if torch.any(self.penalty_coefficient.value < 0):
-                    raise ValueError("All entries of the penalty coefficient must be non-negative.")
-        else:
-            if self.penalty_coefficient is not None:
-                raise ValueError(f"Received unexpected penalty coefficient for {self.formulation_type}.")
+        self.formulation.sanity_check_penalty_coefficient(penalty_coefficient)
 
     def compute_contribution_to_lagrangian(
         self, constraint_state: ConstraintState, primal_or_dual: Literal["primal", "dual"]
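Note: the constructor now delegates both checks: multiplier pairing goes through `Multiplier.set_constraint_type` (added in the multipliers diff below) and penalty validation through `Formulation.sanity_check_penalty_coefficient`. A hedged sketch of the resulting construction path (constructor keywords are inferred from this diff and may differ):

import cooper

# Hypothetical setup: a 3-dimensional inequality constraint under the plain
# Lagrangian formulation, which expects no penalty coefficient.
multiplier = cooper.multipliers.DenseMultiplier(num_constraints=3)
constraint = cooper.Constraint(
    constraint_type=cooper.ConstraintType.INEQUALITY,
    formulation_type=cooper.LagrangianFormulation,
    multiplier=multiplier,
)
# __init__ calls multiplier.set_constraint_type(...), which runs the
# multiplier's sanity_check (non-negative weights for inequalities), and
# formulation.sanity_check_penalty_coefficient(None) passes because the
# Lagrangian formulation does not expect a penalty coefficient.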
7 changes: 2 additions & 5 deletions cooper/constraints/constraint_state.py
@@ -51,9 +51,6 @@ class ConstraintState:
     contributes_to_dual_update: bool = True
 
     def __post_init__(self):
-        if self.constraint_features is not None and self.violation is None:
-            raise ValueError("violation must be provided if constraint_features is provided.")
-
         if self.strict_constraint_features is not None and self.strict_violation is None:
             raise ValueError("strict_violation must be provided if strict_constraint_features is provided.")
 
@@ -73,9 +70,9 @@ def extract_violations(self, do_unsqueeze=True) -> tuple[torch.Tensor, torch.Tensor]:
         if do_unsqueeze:
             # If the violation is a scalar, we unsqueeze it to ensure that it has at
             # least one dimension for using einsum.
-            if len(violation.shape) == 0:
+            if violation.dim() == 0:
                 violation = violation.unsqueeze(0)
-            if len(strict_violation.shape) == 0:
+            if strict_violation.dim() == 0:
                 strict_violation = strict_violation.unsqueeze(0)
 
         return violation, strict_violation
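Note: `Tensor.dim()` is the idiomatic spelling of `len(tensor.shape)`, so the two replacements above are behavior-preserving. The unsqueeze exists so that scalar violations flow through the same einsum as vector ones; a self-contained illustration in plain torch:

import torch

violation = torch.tensor(0.5)           # 0-dim scalar violation
if violation.dim() == 0:
    violation = violation.unsqueeze(0)  # now shape (1,)

multiplier = torch.tensor([2.0])
# With at least one dimension, the same contraction covers scalar and
# vector-valued constraints alike.
contribution = torch.einsum("i,i->", multiplier, violation)  # tensor(1.)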
9 changes: 7 additions & 2 deletions cooper/formulations/formulations.py
@@ -32,6 +32,12 @@ def __init__(self, constraint_type: ConstraintType):
     def __repr__(self):
         return f"{type(self).__name__}(constraint_type={self.constraint_type})"
 
+    def sanity_check_penalty_coefficient(self, penalty_coefficient: Optional[PenaltyCoefficient]) -> None:
+        if self.expects_penalty_coefficient and penalty_coefficient is None:
+            raise ValueError(f"{type(self).__name__} expects a penalty coefficient but none was provided.")
+        if not self.expects_penalty_coefficient and penalty_coefficient is not None:
+            raise ValueError(f"Received unexpected penalty coefficient for {type(self).__name__}.")
+
     def _prepare_kwargs_for_lagrangian_contribution(
         self,
         constraint_state: ConstraintState,
@@ -40,8 +46,7 @@ def _prepare_kwargs_for_lagrangian_contribution(
         primal_or_dual: Literal["primal", "dual"],
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
 
-        if self.expects_penalty_coefficient and penalty_coefficient is None:
-            raise ValueError(f"{type(self).__name__} expects a penalty coefficient but none was provided.")
+        self.sanity_check_penalty_coefficient(penalty_coefficient)
 
         violation, strict_violation = constraint_state.extract_violations()
         constraint_features, strict_constraint_features = constraint_state.extract_constraint_features()
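Note: the hoisted check now fails fast in both directions: a formulation that expects a penalty coefficient but receives `None`, and one that receives a coefficient it does not expect. A hedged sketch (it assumes `LagrangianFormulation` sets `expects_penalty_coefficient = False` and that `DensePenaltyCoefficient` lives under `cooper.multipliers`, neither of which this hunk shows directly):

import torch
import cooper

formulation = cooper.LagrangianFormulation(constraint_type=cooper.ConstraintType.INEQUALITY)
penalty = cooper.multipliers.DensePenaltyCoefficient(init=torch.tensor(1.0))

formulation.sanity_check_penalty_coefficient(None)     # OK: none expected, none given
formulation.sanity_check_penalty_coefficient(penalty)  # raises ValueError: unexpected coefficient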
13 changes: 7 additions & 6 deletions cooper/formulations/utils.py
@@ -8,21 +8,24 @@
 
 
 def evaluate_constraint_factor(
-    module: Union[Multiplier, PenaltyCoefficient], constraint_features: torch.Tensor, expand_shape: torch.Tensor
+    module: Union[Multiplier, PenaltyCoefficient],
+    constraint_features: Optional[torch.Tensor],
+    expand_shape: tuple[int, ...],
 ) -> torch.Tensor:
     """Evaluate the Lagrange multiplier or penalty coefficient associated with a
     constraint.
 
     Args:
         module: Multiplier or penalty coefficient module.
-        constraint_state: The current state of the constraint.
+        constraint_features: The observed features of the constraint.
+        expand_shape: Shape of the constraint violation tensor.
     """
 
     # TODO(gallego-posada): This way of calling the modules assumes either 0 or 1
     # arguments. This should be generalized to allow for multiple arguments.
-    value = module() if constraint_features is None else module(constraint_features)
+    value = module(constraint_features) if module.expects_constraint_features else module()
 
-    if len(value.shape) == 0:
+    if value.dim() == 0:
         # Unsqueeze value to make it a 1D tensor for consistent use in Formulations' einsum calls
         value.unsqueeze_(0)
 
@@ -155,5 +158,3 @@ def compute_primal_quadratic_augmented_contribution(
         linear_term = compute_primal_weighted_violation(multiplier_value, violation)
         quadratic_penalty = compute_quadratic_penalty(penalty_coefficient_value, violation, constraint_type)
         return linear_term + quadratic_penalty
-    else:
-        raise ValueError(f"{constraint_type} is incompatible with quadratic penalties.")
29 changes: 16 additions & 13 deletions cooper/multipliers/multipliers.py
@@ -9,6 +9,9 @@
 
 
 class Multiplier(torch.nn.Module, abc.ABC):
+    expects_constraint_features: bool
+    constraint_type: ConstraintType
+
     @abc.abstractmethod
     def forward(self, *args, **kwargs):
         """Return the current value of the multiplier."""
@@ -27,6 +30,10 @@ def sanity_check(self):
         # TODO(gallego-posada): Add docstring
         pass
 
+    def set_constraint_type(self, constraint_type):
+        self.constraint_type = constraint_type
+        self.sanity_check()
+
 
 class ExplicitMultiplier(Multiplier):
     """
@@ -46,23 +53,15 @@ class ExplicitMultiplier(Multiplier):
 
     def __init__(
         self,
-        constraint_type: ConstraintType,
         num_constraints: Optional[int] = None,
         init: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
         dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
-        self.constraint_type = constraint_type
         self.weight = self.initialize_weight(num_constraints=num_constraints, init=init, device=device, dtype=dtype)
 
-        self.sanity_check()
-
-    @property
-    def is_inequality(self):
-        return self.constraint_type == ConstraintType.INEQUALITY
-
     def initialize_weight(
         self,
         num_constraints: Optional[int],
@@ -90,7 +89,7 @@ def device(self):
         return self.weight.device
 
     def sanity_check(self):
-        if self.is_inequality and torch.any(self.weight.data < 0):
+        if self.constraint_type == ConstraintType.INEQUALITY and torch.any(self.weight.data < 0):
             raise ValueError("For inequality constraint, all entries in multiplier must be non-negative.")
 
     @torch.no_grad()
@@ -100,12 +99,12 @@ def post_step_(self):
         the dual optimizer, and ensures that (if required) the multipliers are
         non-negative.
         """
-        if self.is_inequality:
+        if self.constraint_type == ConstraintType.INEQUALITY:
            # Ensures non-negativity for multipliers associated with inequality constraints.
            self.weight.data = torch.relu(self.weight.data)
 
     def __repr__(self):
-        return f"{type(self).__name__}(constraint_type={self.constraint_type}, num_constraints={self.weight.shape[0]})"
+        return f"{type(self).__name__}(num_constraints={self.weight.shape[0]})"
 
 
 class DenseMultiplier(ExplicitMultiplier):
@@ -120,6 +119,8 @@ class DenseMultiplier(ExplicitMultiplier):
     :py:class:`~cooper.multipliers.IndexedMultiplier`.
     """
 
+    expects_constraint_features = False
+
     def forward(self):
         """Return the current value of the multiplier."""
         return torch.clone(self.weight)
@@ -139,8 +140,10 @@ class IndexedMultiplier(ExplicitMultiplier):
     and memory-efficient sparse gradients (on GPU).
     """
 
-    def __init__(self, *args, **kwargs):
-        super(IndexedMultiplier, self).__init__(*args, **kwargs)
+    expects_constraint_features = True
+
+    def __init__(self, num_constraints=None, init=None, device=None, dtype=torch.float32):
+        super().__init__(num_constraints, init, device, dtype)
         if self.weight.dim() == 1:
             # To use the forward call in F.embedding, we must reshape the weight to be a
             # 2-dim tensor
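Note: since `constraint_type` is no longer a constructor argument, the non-negativity requirement for inequality multipliers is enforced when the type is attached, and re-enforced after each dual step. A hedged sketch (the `init` keyword is inferred from `initialize_weight` above):

import torch
import cooper

m = cooper.multipliers.DenseMultiplier(init=torch.tensor([0.5, 0.0]))
m.set_constraint_type(cooper.ConstraintType.INEQUALITY)  # runs sanity_check

m.weight.data -= 0.3  # emulate a dual ascent step that dips negative
m.post_step_()        # relu projection restores tensor([0.2, 0.0])

bad = cooper.multipliers.DenseMultiplier(init=torch.tensor([-1.0]))
bad.set_constraint_type(cooper.ConstraintType.INEQUALITY)  # raises ValueError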
28 changes: 20 additions & 8 deletions cooper/multipliers/penalty_coefficients.py
@@ -1,5 +1,4 @@
 import abc
-import warnings
 from typing import Optional
 
 import torch
@@ -12,12 +11,13 @@ class PenaltyCoefficient(abc.ABC):
         init: Value of the penalty coefficient.
     """
 
+    expects_constraint_features: bool
+    _value: Optional[torch.Tensor] = None
+
     def __init__(self, init: torch.Tensor):
         if init.requires_grad:
             raise ValueError("PenaltyCoefficient should not require gradients.")
-        if init.dim() > 1:
-            raise ValueError("init must either be a scalar or a 1D tensor of shape `(num_constraints,)`.")
-        self._value = init.clone()
+        self.value = init
 
     @property
     def value(self):
@@ -28,16 +28,17 @@ def value(self):
     def value(self, value: torch.Tensor):
         """Update the value of the penalty."""
         if value.requires_grad:
-            raise ValueError("New value of PenaltyCoefficient should not require gradients.")
-        if value.shape != self._value.shape:
-            warnings.warn(
+            raise ValueError("PenaltyCoefficient should not require gradients.")
+        if self._value is not None and value.shape != self._value.shape:
+            raise ValueError(
                 f"New shape {value.shape} of PenaltyCoefficient does not match existing shape {self._value.shape}."
             )
         self._value = value.clone()
+        self.sanity_check()
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         """Move the penalty to a new device and/or change its dtype."""
-        self.value = self.value.to(device=device, dtype=dtype)
+        self._value = self._value.to(device=device, dtype=dtype)
         return self
 
@@ -46,6 +47,10 @@ def state_dict(self):
 
     def load_state_dict(self, state_dict):
         self._value = state_dict["value"]
 
+    def sanity_check(self):
+        if torch.any(self._value < 0):
+            raise ValueError("All entries of the penalty coefficient must be non-negative.")
+
     def __repr__(self):
         if self.value.numel() <= 10:
             return f"{type(self).__name__}({self.value})"
@@ -61,6 +66,8 @@ def __call__(self, *args, **kwargs):
 class DensePenaltyCoefficient(PenaltyCoefficient):
     """Constant (non-trainable) coefficient class used for Augmented Lagrangian formulation."""
 
+    expects_constraint_features = False
+
     @torch.no_grad()
     def __call__(self):
         """Return the current value of the penalty coefficient."""
@@ -73,6 +80,8 @@ class IndexedPenaltyCoefficient(PenaltyCoefficient):
     value of the penalty for a subset of constraints.
     """
 
+    expects_constraint_features = True
+
     @torch.no_grad()
     def __call__(self, indices: torch.Tensor):
         """Return the current value of the penalty coefficient at the provided indices.
@@ -86,6 +95,9 @@ def __call__(self, indices: torch.Tensor):
             # torch.nn.functional.embedding and *not* as masks.
             raise ValueError("Indices must be of type torch.long.")
 
+        if self.value.dim() == 0:
+            return self.value.clone()
+
         coefficient_values = torch.nn.functional.embedding(indices, self.value.unsqueeze(1), sparse=False)
 
         # Flatten coefficient values to 1D since Embedding works with 2D tensors.
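Note: together these changes make invalid penalty states loud: shape mismatches and negative entries now raise instead of warning, and scalar coefficients short-circuit the embedding lookup. A hedged sketch of the resulting behavior (class locations under `cooper.multipliers` are assumed):

import torch
import cooper

penalty = cooper.multipliers.IndexedPenaltyCoefficient(init=torch.ones(4))
idx = torch.tensor([0, 2], dtype=torch.long)
print(penalty(idx))  # tensor([1., 1.]) via torch.nn.functional.embedding

scalar = cooper.multipliers.IndexedPenaltyCoefficient(init=torch.tensor(1.0))
print(scalar(idx))   # 0-dim coefficients are cloned and returned directly

penalty.value = torch.ones(3)   # raises ValueError: shape mismatch with (4,)
penalty.value = -torch.ones(4)  # raises ValueError via the new sanity_check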
4 changes: 2 additions & 2 deletions cooper/optim/constrained_optimizers/constrained_optimizer.py
@@ -57,9 +57,9 @@ def base_sanity_checks(self):
         """
 
         if self.primal_optimizers is None:
-            raise RuntimeError("No primal optimizer(s) was provided for building a ConstrainedOptimizer.")
+            raise TypeError("No primal optimizer(s) was provided for building a ConstrainedOptimizer.")
         if self.dual_optimizers is None:
-            raise RuntimeError("No dual optimizer(s) was provided for building a ConstrainedOptimizer.")
+            raise TypeError("No dual optimizer(s) was provided for building a ConstrainedOptimizer.")
         for dual_optimizer in self.dual_optimizers:
             for param_group in dual_optimizer.param_groups:
                 if not param_group["maximize"]:
(extrapolation optimizer — file path not captured in the source page)
@@ -39,9 +39,6 @@ def primal_extrapolation_step(self):
         """
         Perform an extrapolation step on the parameters associated with the primal variables.
         """
-        if not all(hasattr(primal_optimizer, "extrapolation") for primal_optimizer in self.primal_optimizers):
-            raise ValueError("All primal optimizers must implement an `extrapolation` method.")
-
         for primal_optimizer in self.primal_optimizers:
             primal_optimizer.extrapolation()
 
@@ -53,9 +50,6 @@ def dual_extrapolation_step(self):
         After being updated by the dual optimizer steps, the multipliers are
         post-processed (e.g. to ensure non-negativity for inequality constraints).
         """
-        if not all(hasattr(dual_optimizer, "extrapolation") for dual_optimizer in self.dual_optimizers):
-            raise ValueError("All dual optimizers must implement an `extrapolation` method.")
-
         # Update multipliers based on current constraint violations (gradients)
         # For unobserved constraints the gradient is None, so this is a no-op.
         for dual_optimizer in self.dual_optimizers:
9 changes: 6 additions & 3 deletions cooper/optim/optimizer.py
@@ -30,19 +30,22 @@ def __init__(
     ):
         self.cmp = cmp
         self.primal_optimizers = ensure_sequence(primal_optimizers)
-        self.dual_optimizers = ensure_sequence(dual_optimizers) if dual_optimizers is not None else None
+        self.dual_optimizers = ensure_sequence(dual_optimizers)
 
     def zero_grad(self):
         """
         Sets the gradients of all optimized :py:class:`~torch.nn.parameter.Parameter`\\s
        to zero. This includes both the primal and dual variables.
         """
         for primal_optimizer in self.primal_optimizers:
-            primal_optimizer.zero_grad()
+            # Prior to PyTorch 2.0, set_to_none=False was the default behavior.
+            # The default behavior was changed to set_to_none=True in PyTorch 2.0.
+            # We set set_to_none=True explicitly to ensure compatibility with both versions.
+            primal_optimizer.zero_grad(set_to_none=True)
 
         if self.dual_optimizers is not None:
             for dual_optimizer in self.dual_optimizers:
-                dual_optimizer.zero_grad()
+                dual_optimizer.zero_grad(set_to_none=True)
 
     @torch.no_grad()
     def primal_step(self):
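Note: the comment block documents a real PyTorch behavior change: `zero_grad()` started defaulting to `set_to_none=True` in PyTorch 2.0, so passing it explicitly pins the semantics on both old and new versions. A self-contained illustration of what the flag does:

import torch

w = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([w], lr=0.1)

w.sum().backward()
opt.zero_grad(set_to_none=True)
print(w.grad)  # None: the gradient tensor is freed rather than zero-filled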
(penalty coefficient updater — file path not captured in the source page)
@@ -72,7 +72,7 @@ def update_penalty_coefficient_(self, constraint: Constraint, constraint_state:
         # less than the tolerance, the penalty value is left unchanged.
         new_value = torch.where(condition, observed_penalty_values * self.growth_factor, observed_penalty_values)
 
-        if isinstance(penalty_coefficient, DensePenaltyCoefficient):
-            penalty_coefficient.value = new_value.detach()
-        elif isinstance(penalty_coefficient, IndexedPenaltyCoefficient):
+        if isinstance(penalty_coefficient, IndexedPenaltyCoefficient) and new_value.dim() > 0:
             penalty_coefficient.value[strict_constraint_features] = new_value.detach()
+        else:
+            penalty_coefficient.value = new_value.detach()
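Note: for indexed coefficients the grown values are written in place only at the observed indices, while the `new_value.dim() > 0` guard routes scalar penalties through full replacement. The indexed write in plain torch:

import torch

value = torch.ones(4)                 # current penalty coefficients
observed = torch.tensor([1, 3])       # strict_constraint_features seen this step
new_value = torch.tensor([2.0, 5.0])  # grown penalties for the observed slots

value[observed] = new_value           # in-place indexed assignment
print(value)                          # tensor([1., 2., 1., 5.])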