Skip to content

Commit

Permalink
Merge pull request #91 from cooper-org/merajhashemi/refactor-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
merajhashemi authored Jul 8, 2024
2 parents 9f94213 + 3ee14f9 commit 1b5b926
Show file tree
Hide file tree
Showing 47 changed files with 2,190 additions and 1,755 deletions.
7 changes: 1 addition & 6 deletions cooper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
"""Top-level package for Cooper."""

import sys

if sys.version_info >= (3, 8):
from importlib.metadata import PackageNotFoundError, version
else:
from importlib_metadata import PackageNotFoundError, version
from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version("cooper")
Expand Down
10 changes: 9 additions & 1 deletion cooper/cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def observed_strict_violations(self):
for constraint_state in self.observed_constraints.values():
yield constraint_state.strict_violation

def observed_constraint_features(self):
    """Yield the ``constraint_features`` of every observed constraint state.

    Values are produced lazily, in the iteration order of
    ``self.observed_constraints.values()``. Each value is yielded exactly as
    stored on the constraint state; nothing is filtered out here.
    """
    for constraint_state in self.observed_constraints.values():
        yield constraint_state.constraint_features

def observed_strict_constraint_features(self):
    """Yield the ``strict_constraint_features`` of every observed constraint state.

    Values are produced lazily, in the iteration order of
    ``self.observed_constraints.values()``. Each value is yielded exactly as
    stored on the constraint state; nothing is filtered out here.
    """
    for constraint_state in self.observed_constraints.values():
        yield constraint_state.strict_constraint_features


class ConstrainedMinimizationProblem(abc.ABC):
"""Template for constrained minimization problems."""
Expand Down Expand Up @@ -177,7 +185,7 @@ def named_penalty_coefficients(self) -> Iterator[tuple[str, PenaltyCoefficient]]
if constraint.penalty_coefficient is not None:
yield constraint_name, constraint.penalty_coefficient

def dual_parameters(self) -> Iterator[Multiplier]:
def dual_parameters(self) -> Iterator[torch.nn.Parameter]:
"""Return an iterator over the parameters of the multipliers associated with the
registered constraints of the CMP. This method is useful for instantiating the
dual optimizers. If a multiplier is shared by several constraints, we only
Expand Down
21 changes: 2 additions & 19 deletions cooper/constraints/constraint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Literal, Optional, Type

import torch

from cooper.constraints.constraint_state import ConstraintState
from cooper.constraints.constraint_type import ConstraintType
from cooper.formulations import ContributionStore, Formulation, LagrangianFormulation
Expand All @@ -26,25 +24,10 @@ def __init__(
self.formulation = formulation_type(constraint_type=self.constraint_type)

self.multiplier = multiplier
if multiplier.constraint_type != self.constraint_type:
raise ValueError(
f"Attempted to pair {self.constraint_type} constraint, with {multiplier.constraint_type} multiplier."
)
self.multiplier.sanity_check()
self.multiplier.set_constraint_type(constraint_type)

self.penalty_coefficient = penalty_coefficient
self.sanity_check_penalty_coefficient()

def sanity_check_penalty_coefficient(self) -> None:
if self.formulation.expects_penalty_coefficient:
if self.penalty_coefficient is None:
raise ValueError(f"{self.formulation_type} expects a penalty coefficient but none was provided.")
else:
if torch.any(self.penalty_coefficient.value < 0):
raise ValueError("All entries of the penalty coefficient must be non-negative.")
else:
if self.penalty_coefficient is not None:
raise ValueError(f"Received unexpected penalty coefficient for {self.formulation_type}.")
self.formulation.sanity_check_penalty_coefficient(penalty_coefficient)

def compute_contribution_to_lagrangian(
self, constraint_state: ConstraintState, primal_or_dual: Literal["primal", "dual"]
Expand Down
7 changes: 2 additions & 5 deletions cooper/constraints/constraint_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ class ConstraintState:
contributes_to_dual_update: bool = True

def __post_init__(self):
if self.constraint_features is not None and self.violation is None:
raise ValueError("violation must be provided if constraint_features is provided.")

if self.strict_constraint_features is not None and self.strict_violation is None:
raise ValueError("strict_violation must be provided if strict_constraint_features is provided.")

Expand All @@ -73,9 +70,9 @@ def extract_violations(self, do_unsqueeze=True) -> tuple[torch.Tensor, torch.Ten
if do_unsqueeze:
# If the violation is a scalar, we unsqueeze it to ensure that it has at
# least one dimension for using einsum.
if len(violation.shape) == 0:
if violation.dim() == 0:
violation = violation.unsqueeze(0)
if len(strict_violation.shape) == 0:
if strict_violation.dim() == 0:
strict_violation = strict_violation.unsqueeze(0)

return violation, strict_violation
Expand Down
9 changes: 7 additions & 2 deletions cooper/formulations/formulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def __init__(self, constraint_type: ConstraintType):
def __repr__(self):
return f"{type(self).__name__}(constraint_type={self.constraint_type})"

def sanity_check_penalty_coefficient(self, penalty_coefficient: Optional[PenaltyCoefficient]) -> None:
    """Check that a penalty coefficient is present exactly when the formulation expects one.

    Args:
        penalty_coefficient: The penalty coefficient paired with this
            formulation, or ``None`` if there is none.

    Raises:
        ValueError: If ``self.expects_penalty_coefficient`` is truthy and
            ``penalty_coefficient`` is ``None``, or if it is falsy and a
            penalty coefficient was provided anyway.
    """
    if self.expects_penalty_coefficient and penalty_coefficient is None:
        raise ValueError(f"{type(self).__name__} expects a penalty coefficient but none was provided.")
    if not self.expects_penalty_coefficient and penalty_coefficient is not None:
        raise ValueError(f"Received unexpected penalty coefficient for {type(self).__name__}.")

def _prepare_kwargs_for_lagrangian_contribution(
self,
constraint_state: ConstraintState,
Expand All @@ -40,8 +46,7 @@ def _prepare_kwargs_for_lagrangian_contribution(
primal_or_dual: Literal["primal", "dual"],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

if self.expects_penalty_coefficient and penalty_coefficient is None:
raise ValueError(f"{type(self).__name__} expects a penalty coefficient but none was provided.")
self.sanity_check_penalty_coefficient(penalty_coefficient)

violation, strict_violation = constraint_state.extract_violations()
constraint_features, strict_constraint_features = constraint_state.extract_constraint_features()
Expand Down
13 changes: 7 additions & 6 deletions cooper/formulations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@


def evaluate_constraint_factor(
module: Union[Multiplier, PenaltyCoefficient], constraint_features: torch.Tensor, expand_shape: torch.Tensor
module: Union[Multiplier, PenaltyCoefficient],
constraint_features: Optional[torch.Tensor],
expand_shape: tuple[int, ...],
) -> torch.Tensor:
"""Evaluate the Lagrange multiplier or penalty coefficient associated with a
constraint.
Args:
module: Multiplier or penalty coefficient module.
constraint_state: The current state of the constraint.
constraint_features: The observed features of the constraint.
expand_shape: Shape of the constraint violation tensor.
"""

# TODO(gallego-posada): This way of calling the modules assumes either 0 or 1
# arguments. This should be generalized to allow for multiple arguments.
value = module() if constraint_features is None else module(constraint_features)
value = module(constraint_features) if module.expects_constraint_features else module()

if len(value.shape) == 0:
if value.dim() == 0:
# Unsqueeze value to make it a 1D tensor for consistent use in Formulations' einsum calls
value.unsqueeze_(0)

Expand Down Expand Up @@ -155,5 +158,3 @@ def compute_primal_quadratic_augmented_contribution(
linear_term = compute_primal_weighted_violation(multiplier_value, violation)
quadratic_penalty = compute_quadratic_penalty(penalty_coefficient_value, violation, constraint_type)
return linear_term + quadratic_penalty
else:
raise ValueError(f"{constraint_type} is incompatible with quadratic penalties.")
29 changes: 16 additions & 13 deletions cooper/multipliers/multipliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@


class Multiplier(torch.nn.Module, abc.ABC):
expects_constraint_features: bool
constraint_type: ConstraintType

@abc.abstractmethod
def forward(self, *args, **kwargs):
"""Return the current value of the multiplier."""
Expand All @@ -27,6 +30,10 @@ def sanity_check(self):
# TODO(gallego-posada): Add docstring
pass

def set_constraint_type(self, constraint_type):
    """Record the constraint type on this multiplier, then validate it.

    Args:
        constraint_type: The constraint type to store on
            ``self.constraint_type``.

    ``self.sanity_check()`` is called after the assignment, so the check
    runs against the newly stored constraint type.
    """
    self.constraint_type = constraint_type
    self.sanity_check()


class ExplicitMultiplier(Multiplier):
"""
Expand All @@ -46,23 +53,15 @@ class ExplicitMultiplier(Multiplier):

def __init__(
self,
constraint_type: ConstraintType,
num_constraints: Optional[int] = None,
init: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
):
super().__init__()

self.constraint_type = constraint_type
self.weight = self.initialize_weight(num_constraints=num_constraints, init=init, device=device, dtype=dtype)

self.sanity_check()

@property
def is_inequality(self):
return self.constraint_type == ConstraintType.INEQUALITY

def initialize_weight(
self,
num_constraints: Optional[int],
Expand Down Expand Up @@ -90,7 +89,7 @@ def device(self):
return self.weight.device

def sanity_check(self):
if self.is_inequality and torch.any(self.weight.data < 0):
if self.constraint_type == ConstraintType.INEQUALITY and torch.any(self.weight.data < 0):
raise ValueError("For inequality constraint, all entries in multiplier must be non-negative.")

@torch.no_grad()
Expand All @@ -100,12 +99,12 @@ def post_step_(self):
the dual optimizer, and ensures that (if required) the multipliers are
non-negative.
"""
if self.is_inequality:
if self.constraint_type == ConstraintType.INEQUALITY:
# Ensures non-negativity for multipliers associated with inequality constraints.
self.weight.data = torch.relu(self.weight.data)

def __repr__(self):
return f"{type(self).__name__}(constraint_type={self.constraint_type}, num_constraints={self.weight.shape[0]})"
return f"{type(self).__name__}(num_constraints={self.weight.shape[0]})"


class DenseMultiplier(ExplicitMultiplier):
Expand All @@ -120,6 +119,8 @@ class DenseMultiplier(ExplicitMultiplier):
:py:class:`~cooper.multipliers.IndexedMultiplier`.
"""

expects_constraint_features = False

def forward(self):
    """Return the current value of the multiplier.

    The result is a ``torch.clone`` of ``self.weight``, so callers receive a
    separate tensor rather than the stored weight object itself.
    """
    return torch.clone(self.weight)
Expand All @@ -139,8 +140,10 @@ class IndexedMultiplier(ExplicitMultiplier):
and memory-efficient sparse gradients (on GPU).
"""

def __init__(self, *args, **kwargs):
super(IndexedMultiplier, self).__init__(*args, **kwargs)
expects_constraint_features = True

def __init__(self, num_constraints=None, init=None, device=None, dtype=torch.float32):
super().__init__(num_constraints, init, device, dtype)
if self.weight.dim() == 1:
# To use the forward call in F.embedding, we must reshape the weight to be a
# 2-dim tensor
Expand Down
28 changes: 20 additions & 8 deletions cooper/multipliers/penalty_coefficients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import warnings
from typing import Optional

import torch
Expand All @@ -12,12 +11,13 @@ class PenaltyCoefficient(abc.ABC):
init: Value of the penalty coefficient.
"""

expects_constraint_features: bool
_value: Optional[torch.Tensor] = None

def __init__(self, init: torch.Tensor):
if init.requires_grad:
raise ValueError("PenaltyCoefficient should not require gradients.")
if init.dim() > 1:
raise ValueError("init must either be a scalar or a 1D tensor of shape `(num_constraints,)`.")
self._value = init.clone()
self.value = init

@property
def value(self):
Expand All @@ -28,16 +28,17 @@ def value(self):
def value(self, value: torch.Tensor):
"""Update the value of the penalty."""
if value.requires_grad:
raise ValueError("New value of PenaltyCoefficient should not require gradients.")
if value.shape != self._value.shape:
warnings.warn(
raise ValueError("PenaltyCoefficient should not require gradients.")
if self._value is not None and value.shape != self._value.shape:
raise ValueError(
f"New shape {value.shape} of PenaltyCoefficient does not match existing shape {self._value.shape}."
)
self._value = value.clone()
self.sanity_check()

def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
"""Move the penalty to a new device and/or change its dtype."""
self.value = self.value.to(device=device, dtype=dtype)
self._value = self._value.to(device=device, dtype=dtype)
return self

def state_dict(self):
Expand All @@ -46,6 +47,10 @@ def state_dict(self):
def load_state_dict(self, state_dict):
    """Restore the stored value from ``state_dict["value"]``.

    The entry is assigned to ``self._value`` as-is: this method performs no
    copy and no validation of the loaded object.

    Raises:
        KeyError: If ``state_dict`` has no ``"value"`` entry.
    """
    self._value = state_dict["value"]

def sanity_check(self):
    """Validate that the stored penalty coefficient has no negative entries.

    Raises:
        ValueError: If any entry of ``self._value`` is strictly negative.
    """
    if torch.any(self._value < 0):
        raise ValueError("All entries of the penalty coefficient must be non-negative.")

def __repr__(self):
if self.value.numel() <= 10:
return f"{type(self).__name__}({self.value})"
Expand All @@ -61,6 +66,8 @@ def __call__(self, *args, **kwargs):
class DensePenaltyCoefficient(PenaltyCoefficient):
"""Constant (non-trainable) coefficient class used for Augmented Lagrangian formulation."""

expects_constraint_features = False

@torch.no_grad()
def __call__(self):
"""Return the current value of the penalty coefficient."""
Expand All @@ -73,6 +80,8 @@ class IndexedPenaltyCoefficient(PenaltyCoefficient):
value of the penalty for a subset of constraints.
"""

expects_constraint_features = True

@torch.no_grad()
def __call__(self, indices: torch.Tensor):
"""Return the current value of the penalty coefficient at the provided indices.
Expand All @@ -86,6 +95,9 @@ def __call__(self, indices: torch.Tensor):
# torch.nn.functional.embedding and *not* as masks.
raise ValueError("Indices must be of type torch.long.")

if self.value.dim() == 0:
return self.value.clone()

coefficient_values = torch.nn.functional.embedding(indices, self.value.unsqueeze(1), sparse=False)

# Flatten coefficient values to 1D since Embedding works with 2D tensors.
Expand Down
4 changes: 2 additions & 2 deletions cooper/optim/constrained_optimizers/constrained_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def base_sanity_checks(self):
"""

if self.primal_optimizers is None:
raise RuntimeError("No primal optimizer(s) was provided for building a ConstrainedOptimizer.")
raise TypeError("No primal optimizer(s) was provided for building a ConstrainedOptimizer.")
if self.dual_optimizers is None:
raise RuntimeError("No dual optimizer(s) was provided for building a ConstrainedOptimizer.")
raise TypeError("No dual optimizer(s) was provided for building a ConstrainedOptimizer.")
for dual_optimizer in self.dual_optimizers:
for param_group in dual_optimizer.param_groups:
if not param_group["maximize"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def primal_extrapolation_step(self):
"""
Perform an extrapolation step on the parameters associated with the primal variables.
"""
if not all(hasattr(primal_optimizer, "extrapolation") for primal_optimizer in self.primal_optimizers):
raise ValueError("All primal optimizers must implement an `extrapolation` method.")

for primal_optimizer in self.primal_optimizers:
primal_optimizer.extrapolation()

Expand All @@ -53,9 +50,6 @@ def dual_extrapolation_step(self):
After being updated by the dual optimizer steps, the multipliers are
post-processed (e.g. to ensure non-negativity for inequality constraints).
"""
if not all(hasattr(dual_optimizer, "extrapolation") for dual_optimizer in self.dual_optimizers):
raise ValueError("All dual optimizers must implement an `extrapolation` method.")

# Update multipliers based on current constraint violations (gradients)
# For unobserved constraints the gradient is None, so this is a no-op.
for dual_optimizer in self.dual_optimizers:
Expand Down
9 changes: 6 additions & 3 deletions cooper/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,22 @@ def __init__(
):
self.cmp = cmp
self.primal_optimizers = ensure_sequence(primal_optimizers)
self.dual_optimizers = ensure_sequence(dual_optimizers) if dual_optimizers is not None else None
self.dual_optimizers = ensure_sequence(dual_optimizers)

def zero_grad(self):
"""
Sets the gradients of all optimized :py:class:`~torch.nn.parameter.Parameter`\\s
to zero. This includes both the primal and dual variables.
"""
for primal_optimizer in self.primal_optimizers:
primal_optimizer.zero_grad()
# Prior to PyTorch 2.0, set_to_none=False was the default behavior.
# The default behavior was changed to set_to_none=True in PyTorch 2.0.
# We set set_to_none=True explicitly to ensure compatibility with both versions.
primal_optimizer.zero_grad(set_to_none=True)

if self.dual_optimizers is not None:
for dual_optimizer in self.dual_optimizers:
dual_optimizer.zero_grad()
dual_optimizer.zero_grad(set_to_none=True)

@torch.no_grad()
def primal_step(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def update_penalty_coefficient_(self, constraint: Constraint, constraint_state:
# less than the tolerance, the penalty value is left unchanged.
new_value = torch.where(condition, observed_penalty_values * self.growth_factor, observed_penalty_values)

if isinstance(penalty_coefficient, DensePenaltyCoefficient):
penalty_coefficient.value = new_value.detach()
elif isinstance(penalty_coefficient, IndexedPenaltyCoefficient):
if isinstance(penalty_coefficient, IndexedPenaltyCoefficient) and new_value.dim() > 0:
penalty_coefficient.value[strict_constraint_features] = new_value.detach()
else:
penalty_coefficient.value = new_value.detach()
Loading

0 comments on commit 1b5b926

Please sign in to comment.