diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 24bc29895..78d62e805 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -38,6 +38,7 @@ Transitions, ) from imitation.policies import exploration_wrapper +from imitation.regularization import regularizers from imitation.rewards import reward_function, reward_nets, reward_wrapper from imitation.util import logger as imit_logger from imitation.util import networks, util @@ -389,7 +390,7 @@ def forward( `ensemble_member_index`. Args: - fragment_pairs: batch of pair of fragments. + fragment_pairs: batch of pairs of fragments. ensemble_member_index: index of member network in ensemble model. If the model is an ensemble of networks, this cannot be None. @@ -432,10 +433,7 @@ def forward( return probs, (gt_probs if gt_reward_available else None) - def rewards( - self, - transitions: Transitions, - ) -> th.Tensor: + def rewards(self, transitions: Transitions) -> th.Tensor: """Computes the reward for all transitions. Args: @@ -460,11 +458,7 @@ def rewards( assert rews.shape == (len(state),) return rews - def probability( - self, - rews1: th.Tensor, - rews2: th.Tensor, - ) -> th.Tensor: + def probability(self, rews1: th.Tensor, rews2: th.Tensor) -> th.Tensor: """Computes the Boltzmann rational probability that the first trajectory is best. Args: @@ -897,11 +891,7 @@ def __init__(self, max_size: Optional[int] = None): self.max_size = max_size self.preferences = np.array([]) - def push( - self, - fragments: Sequence[TrajectoryWithRewPair], - preferences: np.ndarray, - ): + def push(self, fragments: Sequence[TrajectoryWithRewPair], preferences: np.ndarray): """Add more samples to the dataset. Args: @@ -917,7 +907,7 @@ def push( if preferences.shape != (len(fragments),): raise ValueError( f"Unexpected preferences shape {preferences.shape}, " - f"expected {(len(fragments), )}", + f"expected {(len(fragments),)}", ) if preferences.dtype != np.float32: raise ValueError("preferences should have dtype float32") @@ -998,10 +988,7 @@ def _trajectory_pair_includes_reward(fragment_pair: TrajectoryPair): class CrossEntropyRewardLoss(RewardLoss): """Compute the cross entropy reward loss.""" - def __init__( - self, - preference_model: PreferenceModel, - ): + def __init__(self, preference_model: PreferenceModel): """Create cross entropy reward loss. Args: @@ -1032,13 +1019,13 @@ def forward( """ probs, gt_probs = self.preference_model(fragment_pairs, ensemble_member_index) # TODO(ejnnr): Here and below, > 0.5 is problematic - # because getting exactly 0.5 is actually somewhat - # common in some environments (as long as sample=False or temperature=0). - # In a sense that "only" creates class imbalance - # but it's still misleading. - predictions = (probs > 0.5).float() + # because getting exactly 0.5 is actually somewhat + # common in some environments (as long as sample=False or temperature=0). + # In a sense that "only" creates class imbalance + # but it's still misleading. 
+ predictions = probs > 0.5 preferences_th = th.as_tensor(preferences, dtype=th.float32) - ground_truth = (preferences_th > 0.5).float() + ground_truth = preferences_th > 0.5 metrics = {} metrics["accuracy"] = (predictions == ground_truth).float().mean() if gt_probs is not None: @@ -1094,6 +1081,8 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float) -> None: class BasicRewardTrainer(RewardTrainer): """Train a basic reward model.""" + regularizer: Optional[regularizers.Regularizer] + def __init__( self, model: reward_nets.RewardNet, @@ -1101,8 +1090,9 @@ def __init__( batch_size: int = 32, epochs: int = 1, lr: float = 1e-3, - weight_decay: float = 0.0, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + seed: Optional[int] = None, + regularizer_factory: Optional[regularizers.RegularizerFactory] = None, ): """Initialize the reward model trainer. @@ -1114,19 +1104,24 @@ def __init__( on the fly by specifying an `epoch_multiplier` in `self.train()` if longer training is desired in specific cases). lr: the learning rate - weight_decay: the weight decay factor for the reward model's weights - to use with ``th.optim.AdamW``. This is similar to but not equivalent - to L2 regularization, see https://arxiv.org/abs/1711.05101 + seed: the random seed to use for splitting the dataset into training + and validation. custom_logger: Where to log to; if None (default), creates a new logger. + regularizer_factory: if you would like to apply regularization during + training, specify a regularizer factory here. The factory will be + used to construct a regularizer. See + ``imitation.regularization.RegularizerFactory`` for more details. """ super().__init__(model, custom_logger) self.loss = loss self.batch_size = batch_size self.epochs = epochs - self.optim = th.optim.AdamW( - self._model.parameters(), - lr=lr, - weight_decay=weight_decay, + self.optim = th.optim.AdamW(self._model.parameters(), lr=lr) + self.seed = seed + self.regularizer = ( + regularizer_factory(optimizer=self.optim, logger=self.logger) + if regularizer_factory is not None + else None ) def _make_data_loader(self, dataset: PreferenceDataset) -> data_th.DataLoader: @@ -1138,30 +1133,104 @@ def _make_data_loader(self, dataset: PreferenceDataset) -> data_th.DataLoader: collate_fn=preference_collate_fn, ) + @property + def requires_regularizer_update(self) -> bool: + """Whether the regularizer requires updating. + + Returns: + If true, this means that a validation dataset will be used. + """ + return self.regularizer is not None and self.regularizer.val_split is not None + def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> None: """Trains for `epoch_multiplier * self.epochs` epochs over `dataset`.""" - dataloader = self._make_data_loader(dataset) + if self.regularizer is not None and self.regularizer.val_split is not None: + val_length = int(len(dataset) * self.regularizer.val_split) + train_length = len(dataset) - val_length + if val_length < 1 or train_length < 1: + raise ValueError( + "Not enough data samples to split into training and validation, " + "or the validation split is too large/small. " + "Make sure you've generated enough initial preference data. 
" + "You can adjust this through initial_comparison_frac in " + "PreferenceComparisons.", + ) + train_dataset, val_dataset = data_th.random_split( + dataset, + lengths=[train_length, val_length], + generator=th.Generator().manual_seed(self.seed) if self.seed else None, + ) + dataloader = self._make_data_loader(train_dataset) + val_dataloader = self._make_data_loader(val_dataset) + else: + dataloader = self._make_data_loader(dataset) + val_dataloader = None + epochs = round(self.epochs * epoch_multiplier) - for _ in tqdm(range(epochs), desc="Training reward model"): - for fragment_pairs, preferences in dataloader: - self.optim.zero_grad() - loss = self._training_inner_loop(fragment_pairs, preferences) - loss.backward() - self.optim.step() + assert epochs > 0, "Must train for at least one epoch." + epoch_num = 0 + with self.logger.accumulate_means("reward"): + for epoch_num in tqdm(range(epochs), desc="Training reward model"): + prefix = f"epoch-{epoch_num}" + train_loss = 0.0 + for fragment_pairs, preferences in dataloader: + self.optim.zero_grad() + loss = self._training_inner_loop( + fragment_pairs, + preferences, + prefix=f"{prefix}/train", + ) + train_loss += loss.item() + if self.regularizer: + self.regularizer.regularize_and_backward(loss) + else: + loss.backward() + self.optim.step() + + if not self.requires_regularizer_update: + continue + assert val_dataloader is not None + assert self.regularizer is not None + + val_loss = 0.0 + for fragment_pairs, preferences in val_dataloader: + loss = self._training_inner_loop( + fragment_pairs, + preferences, + prefix=f"{prefix}/val", + ) + val_loss += loss.item() + self.regularizer.update_params(train_loss, val_loss) + + # after training all the epochs, + # record also the final value in a separate key for easy access. + keys = list(self.logger.name_to_value.keys()) + for key in keys: + if key.startswith(f"mean/reward/epoch-{epoch_num}"): + val = self.logger.name_to_value[key] + new_key = key.replace(f"mean/reward/epoch-{epoch_num}", "reward/final") + self.logger.record(new_key, val) def _training_inner_loop( self, fragment_pairs: Sequence[TrajectoryPair], preferences: np.ndarray, + prefix: Optional[str] = None, ) -> th.Tensor: output = self.loss.forward(fragment_pairs, preferences) loss = output.loss - self.logger.record("loss", loss.item()) + self.logger.record(self._get_logger_key(prefix, "loss"), loss.item()) for name, value in output.metrics.items(): - self.logger.record(name, value.item()) + self.logger.record(self._get_logger_key(prefix, name), value.item()) return loss + # TODO(juan) refactor & remove once #529 is merged. + def _get_logger_key(self, mode: Optional[str], key: str) -> str: + if mode is None: + return key + return f"{mode}/{key}" + class EnsembleTrainer(BasicRewardTrainer): """Train a reward ensemble.""" @@ -1175,9 +1244,9 @@ def __init__( batch_size: int = 32, epochs: int = 1, lr: float = 1e-3, - weight_decay: float = 0.0, - seed: Optional[int] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + seed: Optional[int] = None, + regularizer_factory: Optional[regularizers.RegularizerFactory] = None, ): """Initialize the reward model trainer. @@ -1189,28 +1258,29 @@ def __init__( on the fly by specifying an `epoch_multiplier` in `self.train()` if longer training is desired in specific cases). lr: the learning rate - weight_decay: the weight decay factor for the reward model's weights - to use with ``th.optim.AdamW``. 
This is similar to but not equivalent - to L2 regularization, see https://arxiv.org/abs/1711.05101 - seed: seed for the internal RNG used in bagging + seed: the random seed to use for splitting the dataset into training + and validation, and for bagging. custom_logger: Where to log to; if None (default), creates a new logger. + regularizer_factory: A factory for creating a regularizer. If None, + no regularization is used. Raises: TypeError: if model is not a RewardEnsemble. """ if not isinstance(model, reward_nets.RewardEnsemble): raise TypeError( - f"RewardEnsemble expected by EnsembleTrainer not {type(model)}.", + f"RewardEnsemble expected by EnsembleTrainer, not {type(model)}.", ) super().__init__( - model, - loss, - batch_size, - epochs, - lr, - weight_decay, - custom_logger, + model=model, + loss=loss, + batch_size=batch_size, + epochs=epochs, + lr=lr, + custom_logger=custom_logger, + seed=seed, + regularizer_factory=regularizer_factory, ) self.rng = np.random.default_rng(seed=seed) @@ -1218,6 +1288,7 @@ def _training_inner_loop( self, fragment_pairs: Sequence[TrajectoryPair], preferences: np.ndarray, + prefix: Optional[str] = None, ) -> th.Tensor: assert len(fragment_pairs) == preferences.shape[0] losses = [] @@ -1237,14 +1308,17 @@ def _training_inner_loop( losses = th.stack(losses) loss = losses.sum() - self.logger.record("loss", loss.item()) - self.logger.record("loss_std", losses.std().item()) + self.logger.record(self._get_logger_key(prefix, "loss"), loss.item()) + self.logger.record( + self._get_logger_key(prefix, "loss_std"), + losses.std().item(), + ) # Turn metrics from a list of dictionaries into a dictionary of # tensors. metrics = {k: th.stack([di[k] for di in metrics]) for k in metrics[0]} for name, value in metrics.items(): - self.logger.record(name, value.mean().item()) + self.logger.record(self._get_logger_key(prefix, name), value.mean().item()) return loss @@ -1285,11 +1359,7 @@ def _make_reward_trainer( f" by AddSTDRewardWrapper but found {type(reward_model).__name__}.", ) else: - return BasicRewardTrainer( - reward_model, - loss=loss, - **reward_trainer_kwargs, - ) + return BasicRewardTrainer(reward_model, loss=loss, **reward_trainer_kwargs) QUERY_SCHEDULES: Dict[str, type_aliases.Schedule] = { @@ -1506,13 +1576,9 @@ def train( if i == 0: epoch_multiplier = self.initial_epoch_multiplier - with self.logger.accumulate_means("reward"): - self.reward_trainer.train( - self.dataset, - epoch_multiplier=epoch_multiplier, - ) - reward_loss = self.logger.name_to_value["mean/reward/loss"] - reward_accuracy = self.logger.name_to_value["mean/reward/accuracy"] + self.reward_trainer.train(self.dataset, epoch_multiplier=epoch_multiplier) + reward_loss = self.logger.name_to_value["reward/final/train/loss"] + reward_accuracy = self.logger.name_to_value["reward/final/train/accuracy"] ################### # Train the agent # diff --git a/src/imitation/regularization/__init__.py b/src/imitation/regularization/__init__.py new file mode 100644 index 000000000..5e40ad61b --- /dev/null +++ b/src/imitation/regularization/__init__.py @@ -0,0 +1 @@ +"""Implements a variety of regularization techniques for NN weights.""" diff --git a/src/imitation/regularization/regularizers.py b/src/imitation/regularization/regularizers.py new file mode 100644 index 000000000..b38141274 --- /dev/null +++ b/src/imitation/regularization/regularizers.py @@ -0,0 +1,306 @@ +"""Implements the regularizer base class and some standard regularizers.""" + +import abc +from typing import Generic, Optional, Protocol, 
Type, TypeVar, Union + +import numpy as np +import torch as th +from torch import optim + +from imitation.regularization import updaters +from imitation.util import logger as imit_logger + +# this is not actually a scalar, dimension check is still required for tensor. +Scalar = Union[th.Tensor, float] + +R = TypeVar("R") +Self = TypeVar("Self", bound="Regularizer") +T_Regularizer_co = TypeVar( # pytype: disable=not-supported-yet + "T_Regularizer_co", + covariant=True, +) + + +class RegularizerFactory(Protocol[T_Regularizer_co]): + """Protocol for functions that create regularizers. + + The regularizer factory is meant to be used as a way to create a regularizer + in two steps. First, the end-user creates a regularizer factory by calling + the `.create()` method of a regularizer class. This allows specifying + all the relevant configuration to the regularization algorithm. Then, the + network algorithm finishes setting up the optimizer and logger, and calls + the regularizer factory to create the regularizer. + + This two-step process separates the configuration of the regularization + algorithm from additional "operational" parameters. This is useful because it + solves two problems: + + #. The end-user does not have access to the optimizer and logger when + configuring the regularization algorithm. + #. Validation of the configuration is done outside the network constructor. + + It also allows re-using the same regularizer factory for multiple networks. + """ + + def __call__( + self, + *, + optimizer: optim.Optimizer, + logger: imit_logger.HierarchicalLogger, + ) -> T_Regularizer_co: + """Constructs a regularizer from the factory. + + Args: + optimizer: The optimizer used by the network. + logger: The logger used by the network. + """ + + +class Regularizer(abc.ABC, Generic[R]): + """Abstract class for creating regularizers with a common interface.""" + + optimizer: optim.Optimizer + lambda_: float + lambda_updater: Optional[updaters.LambdaUpdater] + logger: imit_logger.HierarchicalLogger + val_split: Optional[float] + + def __init__( + self, + optimizer: optim.Optimizer, + initial_lambda: float, + lambda_updater: Optional[updaters.LambdaUpdater], + logger: imit_logger.HierarchicalLogger, + val_split: Optional[float] = None, + ) -> None: + """Initialize the regularizer. + + Args: + optimizer: The optimizer to which the regularizer is attached. + initial_lambda: The initial value of the regularization parameter. + lambda_updater: A callable object that takes in the current lambda and + the train and val loss, and returns the new lambda. + logger: The logger to which the regularizer will log its parameters. + val_split: The fraction of the training data to use as validation data + for the lambda updater. Can be none if no lambda updater is provided. + + Raises: + ValueError: if no lambda updater (``lambda_updater``) is provided and the + initial regularization strength (``initial_lambda``) is zero. + ValueError: if a validation split (``val_split``) is provided but it's not a + float in the (0, 1) interval. + ValueError: if a lambda updater is provided but no validation split + is provided. + ValueError: if a validation split is set, but no lambda updater is + provided. 
+ """ + if lambda_updater is None and np.allclose(initial_lambda, 0.0): + raise ValueError( + "If you do not pass a regularizer parameter updater your " + "regularization strength must be non-zero, as this would " + "result in no regularization.", + ) + + if val_split is not None and ( + not isinstance(val_split, float) + or np.allclose(val_split, 0.0) + or val_split <= 0 + or val_split >= 1 + ): + raise ValueError( + f"val_split = {val_split} must be a float strictly between 0 and 1.", + ) + + if lambda_updater is not None and val_split is None: + raise ValueError( + "If you pass a regularizer parameter updater, you must also " + "specify a validation split. Otherwise the updater won't have any " + "validation data to use for updating.", + ) + elif lambda_updater is None and val_split is not None: + raise ValueError( + "If you pass a validation split, you must also " + "pass a regularizer parameter updater. Otherwise you are wasting" + " data into the validation split that will not be used.", + ) + + self.optimizer = optimizer + self.lambda_ = initial_lambda + self.lambda_updater = lambda_updater + self.logger = logger + self.val_split = val_split + + self.logger.record("regularization_lambda", self.lambda_) + + @classmethod + def create( + cls: Type[Self], + initial_lambda: float, + lambda_updater: Optional[updaters.LambdaUpdater] = None, + val_split: float = 0.0, + **kwargs, + ) -> RegularizerFactory[Self]: + """Create a regularizer.""" + + def factory( + *, + optimizer: optim.Optimizer, + logger: imit_logger.HierarchicalLogger, + ) -> Self: + return cls( + initial_lambda=initial_lambda, + optimizer=optimizer, + lambda_updater=lambda_updater, + logger=logger, + val_split=val_split, + **kwargs, + ) + + return factory + + @abc.abstractmethod + def regularize_and_backward(self, loss: th.Tensor) -> R: + """Abstract method for performing the regularization step. + + The return type is a generic and the specific implementation + must describe the meaning of the return type. + + This step will also call `loss.backward()` for the user. + This is because the regularizer may require the + loss to be called before or after the regularization step. + Leaving this to the user would force them to make their + implementation dependent on the regularizer algorithm used, + which is prone to errors. + + Args: + loss: The loss to regularize. + """ + + def update_params(self, train_loss: Scalar, val_loss: Scalar) -> None: + """Update the regularization parameter. + + This method calls the lambda_updater to update the regularization parameter, + and assigns the new value to `self.lambda_`. Then logs the new value using + the provided logger. + + Args: + train_loss: The loss on the training set. + val_loss: The loss on the validation set. + """ + if self.lambda_updater is not None: + self.lambda_ = self.lambda_updater(self.lambda_, train_loss, val_loss) + self.logger.record("regularization_lambda", self.lambda_) + + +class LossRegularizer(Regularizer[Scalar]): + """Abstract base class for regularizers that add a loss term to the loss function. + + Requires the user to implement the _loss_penalty method. + """ + + @abc.abstractmethod + def _loss_penalty(self, loss: Scalar) -> Scalar: + """Implement this method to add a loss term to the loss function. + + This method should return the term to be added to the loss function, + not the regularized loss itself. + + Args: + loss: The loss function to which the regularization term is added. 
+ """ + + def regularize_and_backward(self, loss: th.Tensor) -> Scalar: + """Add the regularization term to the loss and compute gradients. + + Args: + loss: The loss to regularize. + + Returns: + The regularized loss. + """ + regularized_loss = th.add(loss, self._loss_penalty(loss)) + regularized_loss.backward() + self.logger.record("regularized_loss", regularized_loss.item()) + return regularized_loss + + +class WeightRegularizer(Regularizer): + """Abstract base class for regularizers that regularize the weights of a network. + + Requires the user to implement the _weight_penalty method. + """ + + @abc.abstractmethod + def _weight_penalty(self, weight: th.Tensor, group: dict) -> Scalar: + """Implement this method to regularize the weights of the network. + + This method should return the regularization term to be added to the weight, + not the regularized weight itself. + + Args: + weight: The weight (network parameter) to regularize. + group: The group of parameters to which the weight belongs. + """ + + def regularize_and_backward(self, loss: th.Tensor) -> None: + """Regularize the weights of the network, and call ``loss.backward()``.""" + loss.backward() + for group in self.optimizer.param_groups: + for param in group["params"]: + param.data = th.add(param.data, self._weight_penalty(param, group)) + + +class LpRegularizer(LossRegularizer): + """Applies Lp regularization to a loss function.""" + + p: int + + def __init__( + self, + optimizer: optim.Optimizer, + initial_lambda: float, + lambda_updater: Optional[updaters.LambdaUpdater], + logger: imit_logger.HierarchicalLogger, + p: int, + val_split: Optional[float] = None, + ) -> None: + """Initialize the regularizer.""" + super().__init__(optimizer, initial_lambda, lambda_updater, logger, val_split) + if not isinstance(p, int) or p < 1: + raise ValueError("p must be a positive integer") + self.p = p + + def _loss_penalty(self, loss: Scalar) -> Scalar: + """Returns the loss penalty. + + Calculates the p-th power of the Lp norm of the weights in the optimizer, + and returns a scaled version of it as the penalty. + + Args: + loss: The loss to regularize. + + Returns: + The scaled pth power of the Lp norm of the network weights. + """ + del loss + penalty = 0 + for group in self.optimizer.param_groups: + for param in group["params"]: + penalty += th.linalg.vector_norm(param, ord=self.p).pow(self.p) + return self.lambda_ * penalty + + +class WeightDecayRegularizer(WeightRegularizer): + """Applies weight decay to a loss function.""" + + def _weight_penalty(self, weight, group) -> Scalar: + """Returns the weight penalty. + + Args: + weight: The weight to regularize. + group: The group of parameters to which the weight belongs. + + Returns: + The weight penalty (to add to the current value of the weight) + """ + return -self.lambda_ * group["lr"] * weight.data diff --git a/src/imitation/regularization/updaters.py b/src/imitation/regularization/updaters.py new file mode 100644 index 000000000..8dc508fed --- /dev/null +++ b/src/imitation/regularization/updaters.py @@ -0,0 +1,133 @@ +"""Implements parameter scaling algorithms to update the parameters of a regularizer.""" + +from typing import Protocol, Tuple, Union + +import numpy as np +import torch as th + +LossType = Union[th.Tensor, float] + + +class LambdaUpdater(Protocol): + """Protocol type for functions that update the regularizer parameter. + + A callable object that takes in the current lambda and the train and val loss, and + returns the new lambda. 
This has been implemented as a protocol and not an ABC + because a user might wish to provide their own implementation without having to + inherit from the base class, e.g. by defining a function instead of a class. + + Note: if you implement `LambdaUpdater`, your implementation MUST be purely + functional, i.e. side-effect free. The class structure should only be used + to store constant hyperparameters. (Alternatively, closures can be used for that). + """ + + def __call__(self, lambda_, train_loss: LossType, val_loss: LossType) -> float: + ... + + +class IntervalParamScaler(LambdaUpdater): + """Scales the lambda of the regularizer by some constant factor. + + Lambda is scaled up if the ratio of the validation loss to the training loss + is above the tolerable interval, and scaled down if the ratio is below the + tolerable interval. Nothing happens if the ratio is within the tolerable + interval. + """ + + def __init__(self, scaling_factor: float, tolerable_interval: Tuple[float, float]): + """Initialize the interval parameter scaler. + + Args: + scaling_factor: The factor by which to scale the lambda, a value in (0, 1). + tolerable_interval: The interval within which the ratio of the validation + loss to the training loss is considered acceptable. A tuple whose first + element is at least 0 and the second element is greater than the first. + + Raises: + ValueError: If the tolerable interval is not a tuple of length 2. + ValueError: if the scaling factor is not in (0, 1). + ValueError: if the tolerable interval is negative or not a proper interval. + """ + eps = np.finfo(float).eps + if not (eps < scaling_factor < 1 - eps): + raise ValueError( + "scaling_factor must be in (0, 1) within machine precision.", + ) + if len(tolerable_interval) != 2: + raise ValueError("tolerable_interval must be a tuple of length 2") + if not (0 <= tolerable_interval[0] < tolerable_interval[1]): + raise ValueError( + "tolerable_interval must be a tuple whose first element " + "is at least 0 and the second element is greater than " + "the first", + ) + + self.scaling_factor = scaling_factor + self.tolerable_interval = tolerable_interval + + def __call__( + self, + lambda_: float, + train_loss: LossType, + val_loss: LossType, + ) -> float: + """Scales the lambda of the regularizer by some constant factor. + + Lambda is scaled up if the ratio of the validation loss to the training loss + is above the tolerable interval, and scaled down if the ratio is below the + tolerable interval. Nothing happens if the ratio is within the tolerable + interval. + + Args: + lambda_: The current value of the lambda. + train_loss: The loss on the training set. + val_loss: The loss on the validation set. + + Returns: + The new value of the lambda. + + Raises: + ValueError: If the loss on the validation set is not a scalar. + ValueError: if lambda_ is zero (will result in no scaling). + ValueError: if lambda_ is not a float. + """ + # check that the tensors val_loss and train_loss are both scalars + if not ( + isinstance(val_loss, float) + or (isinstance(val_loss, th.Tensor) and val_loss.dim() == 0) + ): + raise ValueError("val_loss must be a scalar") + + if not ( + isinstance(train_loss, float) + or (isinstance(train_loss, th.Tensor) and train_loss.dim() == 0) + ): + raise ValueError("train_loss must be a scalar") + + if np.finfo(float).eps > abs(lambda_): + raise ValueError( + "lambda_ must not be zero. 
Make sure that you're not " + "scaling the value of lambda down too quickly or passing an " + "initial value of zero to the lambda parameter.", + ) + elif lambda_ < 0: + raise ValueError("lambda_ must be non-negative") + if not isinstance(lambda_, float): + raise ValueError("lambda_ must be a float") + if train_loss < 0 or val_loss < 0: + raise ValueError("losses must be non-negative for this updater") + + eps = np.finfo(float).eps + if train_loss < eps and val_loss < eps: + # 0/0 is undefined, so return the current lambda + return lambda_ + elif train_loss < eps <= val_loss: + # the ratio would be infinite + return lambda_ * (1 + self.scaling_factor) + + val_to_train_ratio = val_loss / train_loss + if val_to_train_ratio > self.tolerable_interval[1]: + lambda_ *= 1 + self.scaling_factor + elif val_to_train_ratio < self.tolerable_interval[0]: + lambda_ *= 1 - self.scaling_factor + return lambda_ diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 24bdf263a..9374b29bf 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -15,6 +15,7 @@ from imitation.algorithms import preference_comparisons from imitation.data import types from imitation.data.types import TrajectoryWithRew +from imitation.regularization import regularizers, updaters from imitation.rewards import reward_nets from imitation.util import networks, util @@ -214,11 +215,11 @@ def test_reward_ensemble_trainer_raises_type_error(venv): with pytest.raises( TypeError, - match=r"RewardEnsemble expected by EnsembleTrainer not .*", + match=r"RewardEnsemble expected by EnsembleTrainer, not .*", ): preference_comparisons.EnsembleTrainer( - reward_net, - loss, + model=reward_net, # type: ignore + loss=loss, ) @@ -490,8 +491,94 @@ def test_active_fragmenter_discount_rate_no_crash( main_trainer.train(100, 10) +@pytest.fixture(scope="module") +def interval_param_scaler() -> updaters.IntervalParamScaler: + return updaters.IntervalParamScaler( + scaling_factor=0.1, + tolerable_interval=(1.1, 1.5), + ) + + +def test_reward_trainer_regularization_no_crash( + agent_trainer, + venv, + random_fragmenter, + custom_logger, + preference_model, + interval_param_scaler, +): + reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) + loss = preference_comparisons.CrossEntropyRewardLoss(preference_model) + initial_lambda = 0.1 + regularizer_factory = regularizers.LpRegularizer.create( + initial_lambda=initial_lambda, + val_split=0.2, + lambda_updater=interval_param_scaler, + p=2, + ) + reward_trainer = preference_comparisons.BasicRewardTrainer( + reward_net, + loss, + regularizer_factory=regularizer_factory, + custom_logger=custom_logger, + ) + + main_trainer = preference_comparisons.PreferenceComparisons( + agent_trainer, + reward_net, + num_iterations=2, + transition_oversampling=2, + fragment_length=2, + fragmenter=random_fragmenter, + reward_trainer=reward_trainer, + custom_logger=custom_logger, + ) + main_trainer.train(50, 50) + + +def test_reward_trainer_regularization_raises( + agent_trainer, + venv, + random_fragmenter, + custom_logger, + preference_model, + interval_param_scaler, +): + reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) + loss = preference_comparisons.CrossEntropyRewardLoss(preference_model) + initial_lambda = 0.1 + regularizer_factory = regularizers.LpRegularizer.create( + initial_lambda=initial_lambda, + val_split=0.2, + 
lambda_updater=interval_param_scaler, + p=2, + ) + reward_trainer = preference_comparisons.BasicRewardTrainer( + reward_net, + loss, + regularizer_factory=regularizer_factory, + custom_logger=custom_logger, + ) + + main_trainer = preference_comparisons.PreferenceComparisons( + agent_trainer, + reward_net, + num_iterations=2, + transition_oversampling=2, + fragment_length=2, + fragmenter=random_fragmenter, + reward_trainer=reward_trainer, + custom_logger=custom_logger, + ) + with pytest.raises( + ValueError, + match="Not enough data samples to split " "into training and validation.*", + ): + main_trainer.train(100, 10) + + @pytest.fixture -def ensemble_preference_model(venv) -> preference_comparisons.PreferenceComparisons: +def ensemble_preference_model(venv) -> preference_comparisons.PreferenceModel: reward_net = reward_nets.RewardEnsemble( venv.observation_space, venv.action_space, @@ -509,7 +596,7 @@ def ensemble_preference_model(venv) -> preference_comparisons.PreferenceComparis @pytest.fixture -def preference_model(venv) -> preference_comparisons.PreferenceComparisons: +def preference_model(venv) -> preference_comparisons.PreferenceModel: reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) return preference_comparisons.PreferenceModel( model=reward_net, diff --git a/tests/test_regularization.py b/tests/test_regularization.py new file mode 100644 index 000000000..7c09a1ccf --- /dev/null +++ b/tests/test_regularization.py @@ -0,0 +1,528 @@ +"""Tests for `imitation.regularization.*`.""" +import itertools +import tempfile + +import numpy as np +import pytest +import torch as th + +from imitation.regularization import regularizers, updaters +from imitation.util import logger as imit_logger + + +@pytest.fixture( + scope="module", + params=[ + (0.5, (0.9, 1)), # unlikely to fall inside the interval + (0.5, (0.01, 10)), # likely to fall inside the interval + ], +) +def interval_param_scaler(request): + return updaters.IntervalParamScaler(*request.param) + + +@pytest.mark.parametrize( + "lambda_", + [ + 10.0, + 0.001, + ], +) +@pytest.mark.parametrize( + "train_loss", + [ + th.tensor(100.0), + th.tensor(10.0), + th.tensor(0.1), + th.tensor(0.0), + 100.0, + 10.0, + 0.1, + 0.0, + ], +) +def test_interval_param_scaler(lambda_, train_loss, interval_param_scaler): + scaler = interval_param_scaler + tolerable_interval = scaler.tolerable_interval + scaling_factor = scaler.scaling_factor + eps = np.finfo(float).eps + if train_loss > eps: + # The loss is a non-zero scalar, so we can construct a validation loss for + # three different cases: + + # case that the ratio between the validation loss and the training loss is + # above the tolerable interval + val_loss = train_loss * tolerable_interval[1] * 2 + assert scaler(lambda_, train_loss, val_loss) == lambda_ * (1 + scaling_factor) + + # case that the ratio between the validation loss and the training loss is + # below the tolerable interval + val_loss = train_loss * tolerable_interval[0] / 2 + assert scaler(lambda_, train_loss, val_loss) == lambda_ * (1 - scaling_factor) + + # case that the ratio between the validation loss and the training loss is + # within the tolerable interval + val_loss = train_loss * (tolerable_interval[0] + tolerable_interval[1]) / 2 + assert scaler(lambda_, train_loss, val_loss) == lambda_ + else: + # we have a zero loss. We try two cases. When the validation loss is zero, + # the ratio is undefined, so we should return the current lambda. 
When the + # validation loss is nonzero, the ratio is infinite, so we should see the lambda + # increase by the scaling factor. + # We try it for both a tensor and a float value. + val_loss = th.tensor(0.0) + assert scaler(lambda_, train_loss, val_loss) == lambda_ + val_loss = 0.0 + assert scaler(lambda_, train_loss, val_loss) == lambda_ + val_loss = th.tensor(1.0) + assert scaler(lambda_, train_loss, val_loss) == lambda_ * (1 + scaling_factor) + val_loss = 1.0 + assert scaler(lambda_, train_loss, val_loss) == lambda_ * (1 + scaling_factor) + + +def test_interval_param_scaler_raises(interval_param_scaler): + scaler = interval_param_scaler + with pytest.raises(ValueError, match="val_loss must be a scalar"): + scaler(1.0, 1.0, th.Tensor([3.0, 4.0])) + with pytest.raises(ValueError, match="train_loss must be a scalar"): + scaler(1.0, th.Tensor([1.0, 2.0]), 1.0) + with pytest.raises(ValueError, match="train_loss must be a scalar"): + scaler(1.0, "random value", th.tensor(1.0)) # type: ignore + with pytest.raises(ValueError, match="val_loss must be a scalar"): + scaler(1.0, 1.0, "random value") # type: ignore + with pytest.raises(ValueError, match="lambda_ must be a float"): + scaler(th.tensor(1.0), 1.0, 1.0) # type: ignore + with pytest.raises(ValueError, match="lambda_ must not be zero.*"): + scaler(0.0, 1.0, 1.0) + with pytest.raises(ValueError, match="lambda_ must be non-negative.*"): + scaler(-1.0, 1.0, 1.0) + with pytest.raises(ValueError, match="losses must be non-negative.*"): + scaler(1.0, -1.0, 1.0) + with pytest.raises(ValueError, match="losses must be non-negative.*"): + scaler(1.0, 1.0, -1.0) + + +def test_interval_param_scaler_init_raises(): + # this validates the value of scaling_factor. + interval_err_msg = r"scaling_factor must be in \(0, 1\) within machine precision." + + with pytest.raises(ValueError, match=interval_err_msg): + # cannot be negative as this is counter-intuitive to + # the direction of scaling (just use the reciprocal). + updaters.IntervalParamScaler(-1, (0.1, 0.9)) + + with pytest.raises(ValueError, match=interval_err_msg): + # cannot be larger than one as this would make lambda + # negative when scaling down. + updaters.IntervalParamScaler(1.1, (0.1, 0.9)) + + with pytest.raises(ValueError, match=interval_err_msg): + # cannot be exactly zero, as this never changes the value + # of lambda when scaling up. + updaters.IntervalParamScaler(0.0, (0.1, 0.9)) + + with pytest.raises(ValueError, match=interval_err_msg): + # cannot be exactly one, as when lambda is scaled down + # this brings it to zero. + updaters.IntervalParamScaler(1.0, (0.1, 0.9)) + + # an interval obviously needs two elements only. + with pytest.raises( + ValueError, + match="tolerable_interval must be a tuple of length 2", + ): + updaters.IntervalParamScaler(0.5, (0.1, 0.9, 0.5)) # type: ignore + with pytest.raises( + ValueError, + match="tolerable_interval must be a tuple of length 2", + ): + updaters.IntervalParamScaler(0.5, (0.1,)) # type: ignore + + # the first element of the interval must be at least 0. + with pytest.raises( + ValueError, + match="tolerable_interval must be a tuple whose first element " + "is at least 0.*", + ): + updaters.IntervalParamScaler(0.5, (-0.1, 0.9)) + + # the second element of the interval must be greater than the first. 
+ with pytest.raises( + ValueError, + match="tolerable_interval must be a tuple.*the second " + "element is greater than the first", + ): + updaters.IntervalParamScaler(0.5, (0.1, 0.05)) + + +@pytest.fixture(scope="module") +def hierarchical_logger(): + tmpdir = tempfile.mkdtemp() + return imit_logger.configure(tmpdir, ["tensorboard", "stdout", "csv"]) + + +@pytest.fixture(scope="module", params=[0.1, 1.0, 10.0]) +def simple_optimizer(request): + return th.optim.Adam([th.tensor(request.param, requires_grad=True)], lr=0.1) + + +@pytest.fixture(scope="module", params=[0.1, 1.0, 10.0]) +def initial_lambda(request): + return request.param + + +class SimpleRegularizer(regularizers.Regularizer[None]): + """A simple regularizer that does nothing.""" + + def regularize_and_backward(self, loss: th.Tensor) -> None: + pass # pragma: no cover + + +def test_regularizer_init_no_crash( + initial_lambda, + hierarchical_logger, + simple_optimizer, + interval_param_scaler, +): + SimpleRegularizer( + initial_lambda=initial_lambda, + optimizer=simple_optimizer, + logger=hierarchical_logger, + lambda_updater=interval_param_scaler, + val_split=0.2, + ) + + SimpleRegularizer( + initial_lambda=initial_lambda, + optimizer=simple_optimizer, + logger=hierarchical_logger, + lambda_updater=None, + val_split=None, + ) + + SimpleRegularizer.create( + initial_lambda=initial_lambda, + lambda_updater=interval_param_scaler, + val_split=0.2, + )( + optimizer=simple_optimizer, + logger=hierarchical_logger, + ) + + +@pytest.mark.parametrize( + "val_split", + [ + 0.0, + 1.0, + -10, + 10, + "random value", + 10**-100, + ], +) +def test_regularizer_init_raises_on_val_split( + initial_lambda, + hierarchical_logger, + simple_optimizer, + interval_param_scaler, + val_split, +): + val_split_err_msg = "val_split.*must be a float.*between.*" + with pytest.raises(ValueError, match=val_split_err_msg): + return SimpleRegularizer( + initial_lambda=initial_lambda, + optimizer=simple_optimizer, + logger=hierarchical_logger, + lambda_updater=interval_param_scaler, + val_split=val_split, + ) + + +def test_regularizer_init_raises( + initial_lambda, + hierarchical_logger, + simple_optimizer, + interval_param_scaler, +): + with pytest.raises( + ValueError, + match=".*do not pass.*parameter updater.*regularization strength.*non-zero", + ): + SimpleRegularizer( + initial_lambda=0.0, + optimizer=simple_optimizer, + logger=hierarchical_logger, + lambda_updater=None, + val_split=0.2, + ) + with pytest.raises( + ValueError, + match=".*pass.*parameter updater.*must.*specify.*validation split.*", + ): + SimpleRegularizer( + initial_lambda=initial_lambda, + optimizer=simple_optimizer, + logger=hierarchical_logger, + lambda_updater=interval_param_scaler, + val_split=None, + ) + with pytest.raises( + ValueError, + match=".*pass.*validation split.*must.*pass.*parameter updater.*", + ): + SimpleRegularizer( + initial_lambda=initial_lambda, + optimizer=simple_optimizer, + logger=hierarchical_logger, + lambda_updater=None, + val_split=0.2, + ) + + +@pytest.mark.parametrize( + "train_loss", + [ + th.tensor(10.0), + th.tensor(1.0), + th.tensor(0.1), + th.tensor(0.01), + ], +) +def test_regularizer_update_params( + initial_lambda, + hierarchical_logger, + simple_optimizer, + interval_param_scaler, + train_loss, +): + regularizer = SimpleRegularizer( + initial_lambda=initial_lambda, + logger=hierarchical_logger, + lambda_updater=interval_param_scaler, + optimizer=simple_optimizer, + val_split=0.1, + ) + val_to_train_loss_ratio = 
interval_param_scaler.tolerable_interval[1] * 2 + val_loss = train_loss * val_to_train_loss_ratio + assert regularizer.lambda_ == initial_lambda + assert ( + hierarchical_logger.default_logger.name_to_value["regularization_lambda"] + == initial_lambda + ) + regularizer.update_params(train_loss, val_loss) + expected_lambda_value = interval_param_scaler(initial_lambda, train_loss, val_loss) + assert regularizer.lambda_ == expected_lambda_value + assert expected_lambda_value != initial_lambda + assert ( + hierarchical_logger.default_logger.name_to_value["regularization_lambda"] + == expected_lambda_value + ) + + +class SimpleLossRegularizer(regularizers.LossRegularizer): + """A simple loss regularizer. + + It multiplies the total loss by lambda_+1. + """ + + def _loss_penalty(self, loss: th.Tensor) -> th.Tensor: + return loss * self.lambda_ # this multiplies the total loss by lambda_+1. + + +@pytest.mark.parametrize( + "train_loss_base", + [ + th.tensor(10.0), + th.tensor(1.0), + th.tensor(0.1), + th.tensor(0.01), + ], +) +def test_loss_regularizer( + hierarchical_logger, + simple_optimizer, + initial_lambda, + train_loss_base, +): + regularizer = SimpleLossRegularizer( + initial_lambda=initial_lambda, + logger=hierarchical_logger, + lambda_updater=None, + optimizer=simple_optimizer, + ) + loss_param = simple_optimizer.param_groups[0]["params"][0] + train_loss = train_loss_base * loss_param + regularizer.optimizer.zero_grad() + regularized_loss = regularizer.regularize_and_backward(train_loss) + assert th.allclose(regularized_loss.data, train_loss * (initial_lambda + 1)) + assert ( + hierarchical_logger.default_logger.name_to_value["regularized_loss"] + == regularized_loss + ) + assert th.allclose(loss_param.grad, train_loss_base * (initial_lambda + 1)) + + +class SimpleWeightRegularizer(regularizers.WeightRegularizer): + """A simple weight regularizer. + + It multiplies the total weight by lambda_+1. + """ + + def _weight_penalty(self, weight, group): + # this multiplies the total weight by lambda_+1. + # However, the grad is only calculated with respect to the + # previous value of the weight. + # This difference is only noticeable if the grad of the loss + # has a functional dependence on the weight (i.e. not linear). 
+ return weight * self.lambda_ + + +@pytest.mark.parametrize( + "train_loss_base", + [ + th.tensor(10.0), + th.tensor(1.0), + th.tensor(0.1), + th.tensor(0.01), + ], +) +def test_weight_regularizer( + hierarchical_logger, + simple_optimizer, + initial_lambda, + train_loss_base, +): + regularizer = SimpleWeightRegularizer( + initial_lambda=initial_lambda, + logger=hierarchical_logger, + lambda_updater=None, + optimizer=simple_optimizer, + ) + weight = simple_optimizer.param_groups[0]["params"][0] + initial_weight_value = weight.data.clone() + regularizer.optimizer.zero_grad() + train_loss = train_loss_base * th.pow(weight, 2) / 2 + regularizer.regularize_and_backward(train_loss) + assert th.allclose(weight.data, initial_weight_value * (initial_lambda + 1)) + assert th.allclose(weight.grad, train_loss_base * initial_weight_value) + + +@pytest.mark.parametrize("p", [0.5, 1.5, -1, 0, "random value"]) +def test_lp_regularizer_p_value_raises(hierarchical_logger, simple_optimizer, p): + with pytest.raises(ValueError, match="p must be a positive integer"): + regularizers.LpRegularizer( + initial_lambda=1.0, + logger=hierarchical_logger, + lambda_updater=None, + optimizer=simple_optimizer, + p=p, + ) + + +MULTI_PARAM_OPTIMIZER_INIT_VALS = [-1.0, 0.0, 1.0] +MULTI_PARAM_OPTIMIZER_ARGS = itertools.product( + MULTI_PARAM_OPTIMIZER_INIT_VALS, + MULTI_PARAM_OPTIMIZER_INIT_VALS, +) + + +@pytest.fixture(scope="module", params=MULTI_PARAM_OPTIMIZER_ARGS) +def multi_param_optimizer(request): + return th.optim.Adam( + [th.tensor(p, requires_grad=True) for p in request.param], + lr=0.1, + ) + + +MULTI_PARAM_AND_LR_OPTIMIZER_ARGS = itertools.product( + MULTI_PARAM_OPTIMIZER_INIT_VALS, + MULTI_PARAM_OPTIMIZER_INIT_VALS, + [0.001, 0.01, 0.1], +) + + +@pytest.fixture(scope="module", params=MULTI_PARAM_AND_LR_OPTIMIZER_ARGS) +def multi_param_and_lr_optimizer(request): + return th.optim.Adam( + [th.tensor(p, requires_grad=True) for p in request.param[:-1]], + lr=request.param[-1], + ) + + +@pytest.mark.parametrize( + "train_loss", + [ + th.tensor(10.0), + th.tensor(1.0), + th.tensor(0.1), + ], +) +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_lp_regularizer( + hierarchical_logger, + multi_param_optimizer, + initial_lambda, + train_loss, + p, +): + regularizer = regularizers.LpRegularizer( + initial_lambda=initial_lambda, + logger=hierarchical_logger, + lambda_updater=None, + optimizer=multi_param_optimizer, + p=p, + ) + params = multi_param_optimizer.param_groups[0]["params"] + regularizer.optimizer.zero_grad() + regularized_loss = regularizer.regularize_and_backward(train_loss) + loss_penalty = sum( + [th.linalg.vector_norm(param.data, ord=p).pow(p) for param in params], + ) + assert th.allclose( + regularized_loss.data, + train_loss + initial_lambda * loss_penalty, + ) + assert ( + regularized_loss + == hierarchical_logger.default_logger.name_to_value["regularized_loss"] + ) + for param in params: + assert th.allclose( + param.grad, + p * initial_lambda * th.abs(param).pow(p - 1) * th.sign(param), + ) + + +@pytest.mark.parametrize( + "train_loss_base", + [ + th.tensor(1.0), + th.tensor(0.1), + th.tensor(0.01), + ], +) +def test_weight_decay_regularizer( + multi_param_and_lr_optimizer, + hierarchical_logger, + initial_lambda, + train_loss_base, +): + regularizer = regularizers.WeightDecayRegularizer( + initial_lambda=initial_lambda, + logger=hierarchical_logger, + lambda_updater=None, + optimizer=multi_param_and_lr_optimizer, + ) + weights = regularizer.optimizer.param_groups[0]["params"] + lr = 
regularizer.optimizer.param_groups[0]["lr"] + initial_weight_values = [weight.data.clone() for weight in weights] + regularizer.optimizer.zero_grad() + train_loss = train_loss_base * sum(th.pow(weight, 2) / 2 for weight in weights) + regularizer.regularize_and_backward(train_loss) + for weight, initial_weight_value in zip(weights, initial_weight_values): + assert th.allclose( + weight.data, + initial_weight_value * (1 - lr * initial_lambda), + ) + assert th.allclose(weight.grad, train_loss_base * initial_weight_value)