Regularization API for preference comparisons #481

Merged 120 commits on Sep 14, 2022

Commits (120)
44aa95e
first draft of reward ensembles
levmckinney Jul 2, 2022
3edde81
Fixed doc string
levmckinney Jul 7, 2022
65f6a82
addressed most of reviewers' comments
levmckinney Jul 7, 2022
fa9147f
Renamed UncertainRewardNet to RewardNetWithVariance
levmckinney Jul 7, 2022
8d490a4
moved implementation of make_reward_net to reward_nets.py and rewrote…
levmckinney Jul 7, 2022
7b6e285
fixed conservative reward wrapper
levmckinney Jul 7, 2022
2ec296e
added test for reward_moments
levmckinney Jul 8, 2022
eae4477
switched to a nn.ModuleList not sure how serialize identity was passi…
levmckinney Jul 8, 2022
83b971d
added test for conservative reward function
levmckinney Jul 8, 2022
2f7b4c3
pulled loss calculation out of reward trainer
levmckinney Jul 8, 2022
42378f8
created reward ensemble trainer
levmckinney Jul 9, 2022
59b7644
added documentation for cross_entropy_loss_kwarg
levmckinney Jul 9, 2022
f591631
Merge branch 'master' into reward_ensemble
levmckinney Jul 9, 2022
82048ac
fixed tests and implementation of ensemble trainer
levmckinney Jul 9, 2022
53fbd14
modified assert so that it is actually always true
levmckinney Jul 9, 2022
4577771
added loss to preference comparison notebook
levmckinney Jul 9, 2022
01add55
add named config to reward.py and integrated tests
yawen-d Jul 10, 2022
7edb59d
added logging of standard deviation
levmckinney Jul 12, 2022
9a909ff
changed conservative reward wrapper into reward function that adds st…
levmckinney Jul 12, 2022
6dc6463
added validate reward structure function and the ability to pass kwar…
levmckinney Jul 13, 2022
67e3909
fixed test_validate_wrapper_structure
levmckinney Jul 13, 2022
2d1547b
Added option to create and load reward functions that add std. Now by…
levmckinney Jul 14, 2022
0ae83d8
fixed test_validate_wrapper_structure again
levmckinney Jul 14, 2022
a35778b
removed failure_rate.sh
levmckinney Jul 14, 2022
9dbc75e
predict_processed in normalization wrapper now calls base classes pre…
levmckinney Jul 14, 2022
197512a
fixed test coverage
levmckinney Jul 14, 2022
d0e0ad9
added test that normalized reward net passes along its kwargs and tha…
levmckinney Jul 14, 2022
4b42892
addressed reviewers' comments.
levmckinney Jul 15, 2022
23b1f4d
now testing that all basic wrappers pass along kwargs when calling pr…
levmckinney Jul 15, 2022
0246fb2
added del kwargs where appropriate to improve readability
levmckinney Jul 15, 2022
89e3fdc
Made reward ensemble ignore additional kwargs.
levmckinney Jul 15, 2022
4e79210
fixed test for preference comparison and added new test documenting …
levmckinney Jul 18, 2022
acac694
Made doc string copy more explicit
levmckinney Jul 18, 2022
077c366
improved documentation and input validation of RewardEnsemble and Rew…
levmckinney Jul 18, 2022
dfefb09
removed extra kwargs from RewardEnsemble __init__
levmckinney Jul 19, 2022
d1512f5
Addressed reviewers comments
levmckinney Jul 19, 2022
169ec0a
rewrote validate wrapper structure to be more readable.
levmckinney Jul 19, 2022
dbd992a
fixed bug in implementation of _validate_wrapper_structure
levmckinney Jul 19, 2022
daaaa8b
added test that check that the kwargs passed to load_reward are passe…
levmckinney Jul 19, 2022
9583ea1
added test that check that the kwargs passed to load_reward are passe…
levmckinney Jul 19, 2022
4175b3d
fixed failing tests
levmckinney Jul 19, 2022
7480113
improved return types
levmckinney Jul 19, 2022
b431311
wrappers default to calling corresponding base methods and ensemble no…
levmckinney Jul 20, 2022
1a93321
addressed more comments by reviewer
levmckinney Jul 20, 2022
1d4b167
addressed even more comments by reviewer
levmckinney Jul 20, 2022
d08b8d5
fixed failing test
levmckinney Jul 20, 2022
27841ba
Merge branch 'master' into reward_ensemble
levmckinney Jul 20, 2022
fcee608
moved make reward back to scripts.common.rewards and reward ensemble …
levmckinney Jul 21, 2022
a559627
Apply suggestions from code review
levmckinney Jul 21, 2022
910b111
extracted make_ensemble and added it to test_preference_compairison.
levmckinney Jul 21, 2022
c78f05b
Updated loss and metrics docstring
levmckinney Jul 21, 2022
dc9934f
addressed more of reviewers comments and fixed loss mean
levmckinney Jul 22, 2022
3f6e5ae
added load_reward_kwargs to train_rl.py
levmckinney Jul 22, 2022
24e0b1d
added load_reward_kwargs to train_rl.py
levmckinney Jul 22, 2022
628e285
fixed failing tests and added load_reward_kwargs to retraining test
levmckinney Jul 22, 2022
2c48636
Skeleton of regularization techniques
Rocamonde Jul 25, 2022
9f70750
Merge branch 'master' into dynamic-l2-regularization
Rocamonde Jul 30, 2022
a65f536
Merge branch 'master' into dynamic-l2-regularization
Rocamonde Jul 30, 2022
30b529e
Added loss and weight regularizer, update_param fn protocol
Rocamonde Jul 31, 2022
1557e9e
Fix type error
Rocamonde Jul 31, 2022
68ac8f1
Merge remote-tracking branch 'origin/master' into dynamic-l2-regulari…
AdamGleave Aug 2, 2022
5685c3b
Added logging
Rocamonde Aug 16, 2022
39bb2ed
Renamed reg to regularization, passed logger down
Rocamonde Aug 16, 2022
e43bf7f
Merge remote-tracking branch 'origin/master' into dynamic-l2-regulari…
Rocamonde Aug 16, 2022
a2af49d
Fixes linting issues
Rocamonde Aug 16, 2022
a9e42ec
Silly linting error only visible on CircleCI
Rocamonde Aug 16, 2022
b2a32f7
Update src/imitation/regularization/__init__.py
Rocamonde Aug 22, 2022
dde3b14
Update src/imitation/regularization/__init__.py
Rocamonde Aug 22, 2022
b981fa3
Update src/imitation/regularization/__init__.py
Rocamonde Aug 22, 2022
9c58183
Make regularizer initialization control flow more readable
Rocamonde Aug 22, 2022
3e9be18
Typing improvements
Rocamonde Aug 22, 2022
3a954c9
Merge branch 'dynamic-l2-regularization' of github.com:HumanCompatibl…
Rocamonde Aug 22, 2022
8bfa8fb
Multiple fixes re: types and assertions
Rocamonde Aug 22, 2022
e9038e9
Improved input validation
Rocamonde Aug 22, 2022
8451a9c
Add support for Lp norms, remove singleton L1/L2 implementations
Rocamonde Aug 22, 2022
729faaa
Add support for arbitrary regularizers, improve input validation
Rocamonde Aug 22, 2022
1039562
Typing and linting
Rocamonde Aug 29, 2022
184bf2b
Fix silly errata causing test to fail
Rocamonde Aug 29, 2022
01f3be7
Add missing docstring args
Rocamonde Aug 29, 2022
56bb68b
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
Rocamonde Aug 29, 2022
7d491bb
Restructured folder into submodules and improved typing by adding gen…
Rocamonde Aug 29, 2022
1d53c4a
Updated imports after restructuring folder
Rocamonde Aug 30, 2022
abdab72
Added tests for updaters.py
Rocamonde Aug 30, 2022
3990341
Added tests for regularizers.py, except weight decay
Rocamonde Aug 31, 2022
ed3f574
Fixed tests for lp regularization
Rocamonde Aug 31, 2022
64f13bd
Added tests for weight decay
Rocamonde Aug 31, 2022
f6ea4b1
Linting / formatting
Rocamonde Aug 31, 2022
e103e86
Linting / typing
Rocamonde Aug 31, 2022
bc544b9
Final tests to improve code coverage
Rocamonde Aug 31, 2022
3f97de7
Tweaks for code coverage
Rocamonde Aug 31, 2022
abe6608
Tweaks for code coverage v2
Rocamonde Aug 31, 2022
9707b09
Fix logging issues
Rocamonde Sep 7, 2022
60d2dd7
Formatting
Rocamonde Sep 7, 2022
e48f715
Formatting (docstring)
Rocamonde Sep 7, 2022
a91eff7
Fix file open issue
Rocamonde Sep 8, 2022
71a3edb
Linting
Rocamonde Sep 8, 2022
34a0acb
Remove useless conversion to float
Rocamonde Sep 12, 2022
0636e5d
Replace assert with ValueError
Rocamonde Sep 12, 2022
0a3351e
Check for lambda being negative in the scaler
Rocamonde Sep 12, 2022
2b36e62
Guard against losses being negative in interval param scaler
Rocamonde Sep 12, 2022
e95f4b3
Split tests up to cover more cases
Rocamonde Sep 12, 2022
918e283
Clean up repetitive code for readability
Rocamonde Sep 12, 2022
9c108d9
Remove old TODO message
Rocamonde Sep 12, 2022
24c5729
Merge RewardTrainer seeds into one
Rocamonde Sep 12, 2022
6ed0b9b
Fix interval param tests to new error messages
Rocamonde Sep 12, 2022
64e38a0
Move regularization input validation to factory class
Rocamonde Sep 12, 2022
e3d3f6e
Remake the regularizer factory for better API design
Rocamonde Sep 12, 2022
78efe8f
Fix bugs and tests in new factory design
Rocamonde Sep 12, 2022
19f3518
Added docstring to regularizer factory.
Rocamonde Sep 12, 2022
a293322
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
Rocamonde Sep 13, 2022
e9218c6
Update src/imitation/regularization/regularizers.py
Rocamonde Sep 13, 2022
0e2c997
Update src/imitation/regularization/regularizers.py
Rocamonde Sep 13, 2022
b6cb2cc
Add todo to refactor once #529 is merged.
Rocamonde Sep 13, 2022
b29b7ef
Rename regularize to regularize_and_backward
Rocamonde Sep 13, 2022
0145066
Fix bug in tests and docstrings
Rocamonde Sep 13, 2022
1052398
Rename mode to prefix
Rocamonde Sep 13, 2022
c16a0ef
Added exceptions to docstrings
Rocamonde Sep 13, 2022
7ebec14
Make type ignore only specific to pytype
Rocamonde Sep 13, 2022
d1574c3
Add verbatim double-`` to some docstrings
AdamGleave Sep 14, 2022
8d0b0a9
Change phrasing in docstring
AdamGleave Sep 14, 2022
210 changes: 138 additions & 72 deletions src/imitation/algorithms/preference_comparisons.py
@@ -38,6 +38,7 @@
Transitions,
)
from imitation.policies import exploration_wrapper
from imitation.regularization import regularizers
from imitation.rewards import reward_function, reward_nets, reward_wrapper
from imitation.util import logger as imit_logger
from imitation.util import networks, util
@@ -389,7 +390,7 @@ def forward(
`ensemble_member_index`.

Args:
fragment_pairs: batch of pair of fragments.
fragment_pairs: batch of pairs of fragments.
ensemble_member_index: index of member network in ensemble model.
If the model is an ensemble of networks, this cannot be None.

@@ -432,10 +433,7 @@ def forward(

return probs, (gt_probs if gt_reward_available else None)

def rewards(
self,
transitions: Transitions,
) -> th.Tensor:
def rewards(self, transitions: Transitions) -> th.Tensor:
"""Computes the reward for all transitions.

Args:
@@ -460,11 +458,7 @@ def rewards(
assert rews.shape == (len(state),)
return rews

def probability(
self,
rews1: th.Tensor,
rews2: th.Tensor,
) -> th.Tensor:
def probability(self, rews1: th.Tensor, rews2: th.Tensor) -> th.Tensor:
"""Computes the Boltzmann rational probability that the first trajectory is best.

Args:
@@ -897,11 +891,7 @@ def __init__(self, max_size: Optional[int] = None):
self.max_size = max_size
self.preferences = np.array([])

def push(
self,
fragments: Sequence[TrajectoryWithRewPair],
preferences: np.ndarray,
):
def push(self, fragments: Sequence[TrajectoryWithRewPair], preferences: np.ndarray):
"""Add more samples to the dataset.

Args:
@@ -917,7 +907,7 @@ def push(
if preferences.shape != (len(fragments),):
raise ValueError(
f"Unexpected preferences shape {preferences.shape}, "
f"expected {(len(fragments), )}",
f"expected {(len(fragments),)}",
)
if preferences.dtype != np.float32:
raise ValueError("preferences should have dtype float32")
@@ -998,10 +988,7 @@ def _trajectory_pair_includes_reward(fragment_pair: TrajectoryPair):
class CrossEntropyRewardLoss(RewardLoss):
"""Compute the cross entropy reward loss."""

def __init__(
self,
preference_model: PreferenceModel,
):
def __init__(self, preference_model: PreferenceModel):
"""Create cross entropy reward loss.

Args:
@@ -1032,13 +1019,13 @@ def forward(
"""
probs, gt_probs = self.preference_model(fragment_pairs, ensemble_member_index)
# TODO(ejnnr): Here and below, > 0.5 is problematic
# because getting exactly 0.5 is actually somewhat
# common in some environments (as long as sample=False or temperature=0).
# In a sense that "only" creates class imbalance
# but it's still misleading.
predictions = (probs > 0.5).float()
# because getting exactly 0.5 is actually somewhat
# common in some environments (as long as sample=False or temperature=0).
# In a sense that "only" creates class imbalance
# but it's still misleading.
predictions = probs > 0.5
preferences_th = th.as_tensor(preferences, dtype=th.float32)
ground_truth = (preferences_th > 0.5).float()
ground_truth = preferences_th > 0.5
metrics = {}
metrics["accuracy"] = (predictions == ground_truth).float().mean()
if gt_probs is not None:
@@ -1094,15 +1081,18 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float) -> None:
class BasicRewardTrainer(RewardTrainer):
"""Train a basic reward model."""

regularizer: Optional[regularizers.Regularizer]

def __init__(
self,
model: reward_nets.RewardNet,
loss: RewardLoss,
batch_size: int = 32,
epochs: int = 1,
lr: float = 1e-3,
weight_decay: float = 0.0,
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
seed: Optional[int] = None,
regularizer_factory: Optional[regularizers.RegularizerFactory] = None,
):
"""Initialize the reward model trainer.

@@ -1114,19 +1104,24 @@ def __init__(
on the fly by specifying an `epoch_multiplier` in `self.train()`
if longer training is desired in specific cases).
lr: the learning rate
weight_decay: the weight decay factor for the reward model's weights
to use with ``th.optim.AdamW``. This is similar to but not equivalent
to L2 regularization, see https://arxiv.org/abs/1711.05101
seed: the random seed to use for splitting the dataset into training
and validation.
custom_logger: Where to log to; if None (default), creates a new logger.
regularizer_factory: if you would like to apply regularization during
training, specify a regularizer factory here. The factory will be
used to construct a regularizer. See
``imitation.regularization.RegularizerFactory`` for more details.
"""
super().__init__(model, custom_logger)
self.loss = loss
self.batch_size = batch_size
self.epochs = epochs
self.optim = th.optim.AdamW(
self._model.parameters(),
lr=lr,
weight_decay=weight_decay,
self.optim = th.optim.AdamW(self._model.parameters(), lr=lr)
self.seed = seed
self.regularizer = (
regularizer_factory(optimizer=self.optim, logger=self.logger)
if regularizer_factory is not None
else None
)

def _make_data_loader(self, dataset: PreferenceDataset) -> data_th.DataLoader:
@@ -1138,30 +1133,104 @@ def _make_data_loader(self, dataset: PreferenceDataset) -> data_th.DataLoader:
collate_fn=preference_collate_fn,
)

@property
def requires_regularizer_update(self) -> bool:
"""Whether the regularizer requires updating.

Returns:
If true, this means that a validation dataset will be used.
"""
return self.regularizer is not None and self.regularizer.val_split is not None

def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> None:
"""Trains for `epoch_multiplier * self.epochs` epochs over `dataset`."""
dataloader = self._make_data_loader(dataset)
if self.regularizer is not None and self.regularizer.val_split is not None:
val_length = int(len(dataset) * self.regularizer.val_split)
train_length = len(dataset) - val_length
if val_length < 1 or train_length < 1:
raise ValueError(
"Not enough data samples to split into training and validation, "
"or the validation split is too large/small. "
"Make sure you've generated enough initial preference data. "
"You can adjust this through initial_comparison_frac in "
"PreferenceComparisons.",
)
train_dataset, val_dataset = data_th.random_split(
dataset,
lengths=[train_length, val_length],
generator=th.Generator().manual_seed(self.seed) if self.seed else None,
)
dataloader = self._make_data_loader(train_dataset)
val_dataloader = self._make_data_loader(val_dataset)
else:
dataloader = self._make_data_loader(dataset)
val_dataloader = None

epochs = round(self.epochs * epoch_multiplier)

for _ in tqdm(range(epochs), desc="Training reward model"):
for fragment_pairs, preferences in dataloader:
self.optim.zero_grad()
loss = self._training_inner_loop(fragment_pairs, preferences)
loss.backward()
self.optim.step()
assert epochs > 0, "Must train for at least one epoch."
epoch_num = 0
with self.logger.accumulate_means("reward"):
for epoch_num in tqdm(range(epochs), desc="Training reward model"):
prefix = f"epoch-{epoch_num}"
train_loss = 0.0
for fragment_pairs, preferences in dataloader:
self.optim.zero_grad()
loss = self._training_inner_loop(
fragment_pairs,
preferences,
prefix=f"{prefix}/train",
)
train_loss += loss.item()
if self.regularizer:
self.regularizer.regularize_and_backward(loss)
else:
loss.backward()
self.optim.step()

if not self.requires_regularizer_update:
continue
assert val_dataloader is not None
assert self.regularizer is not None

val_loss = 0.0
for fragment_pairs, preferences in val_dataloader:
loss = self._training_inner_loop(
fragment_pairs,
preferences,
prefix=f"{prefix}/val",
)
val_loss += loss.item()
self.regularizer.update_params(train_loss, val_loss)

# after training all the epochs,
# record also the final value in a separate key for easy access.
keys = list(self.logger.name_to_value.keys())
for key in keys:
if key.startswith(f"mean/reward/epoch-{epoch_num}"):
val = self.logger.name_to_value[key]
new_key = key.replace(f"mean/reward/epoch-{epoch_num}", "reward/final")
self.logger.record(new_key, val)

def _training_inner_loop(
self,
fragment_pairs: Sequence[TrajectoryPair],
preferences: np.ndarray,
prefix: Optional[str] = None,
) -> th.Tensor:
output = self.loss.forward(fragment_pairs, preferences)
loss = output.loss
self.logger.record("loss", loss.item())
self.logger.record(self._get_logger_key(prefix, "loss"), loss.item())
for name, value in output.metrics.items():
self.logger.record(name, value.item())
self.logger.record(self._get_logger_key(prefix, name), value.item())
return loss

# TODO(juan) refactor & remove once #529 is merged.
def _get_logger_key(self, mode: Optional[str], key: str) -> str:
if mode is None:
return key
return f"{mode}/{key}"


class EnsembleTrainer(BasicRewardTrainer):
"""Train a reward ensemble."""
@@ -1175,9 +1244,9 @@ def __init__(
batch_size: int = 32,
epochs: int = 1,
lr: float = 1e-3,
weight_decay: float = 0.0,
seed: Optional[int] = None,
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
seed: Optional[int] = None,
regularizer_factory: Optional[regularizers.RegularizerFactory] = None,
):
"""Initialize the reward model trainer.

@@ -1189,35 +1258,37 @@ def __init__(
on the fly by specifying an `epoch_multiplier` in `self.train()`
if longer training is desired in specific cases).
lr: the learning rate
weight_decay: the weight decay factor for the reward model's weights
to use with ``th.optim.AdamW``. This is similar to but not equivalent
to L2 regularization, see https://arxiv.org/abs/1711.05101
seed: seed for the internal RNG used in bagging
seed: the random seed to use for splitting the dataset into training
and validation, and for bagging.
custom_logger: Where to log to; if None (default), creates a new logger.
regularizer_factory: A factory for creating a regularizer. If None,
no regularization is used.

Raises:
TypeError: if model is not a RewardEnsemble.
"""
if not isinstance(model, reward_nets.RewardEnsemble):
raise TypeError(
f"RewardEnsemble expected by EnsembleTrainer not {type(model)}.",
f"RewardEnsemble expected by EnsembleTrainer, not {type(model)}.",
)

super().__init__(
model,
loss,
batch_size,
epochs,
lr,
weight_decay,
custom_logger,
model=model,
loss=loss,
batch_size=batch_size,
epochs=epochs,
lr=lr,
custom_logger=custom_logger,
seed=seed,
regularizer_factory=regularizer_factory,
)
self.rng = np.random.default_rng(seed=seed)

def _training_inner_loop(
self,
fragment_pairs: Sequence[TrajectoryPair],
preferences: np.ndarray,
prefix: Optional[str] = None,
) -> th.Tensor:
assert len(fragment_pairs) == preferences.shape[0]
losses = []
@@ -1237,14 +1308,17 @@ def _training_inner_loop(
losses = th.stack(losses)
loss = losses.sum()

self.logger.record("loss", loss.item())
self.logger.record("loss_std", losses.std().item())
self.logger.record(self._get_logger_key(prefix, "loss"), loss.item())
self.logger.record(
self._get_logger_key(prefix, "loss_std"),
losses.std().item(),
)

# Turn metrics from a list of dictionaries into a dictionary of
# tensors.
metrics = {k: th.stack([di[k] for di in metrics]) for k in metrics[0]}
for name, value in metrics.items():
self.logger.record(name, value.mean().item())
self.logger.record(self._get_logger_key(prefix, name), value.mean().item())

return loss

@@ -1285,11 +1359,7 @@ def _make_reward_trainer(
f" by AddSTDRewardWrapper but found {type(reward_model).__name__}.",
)
else:
return BasicRewardTrainer(
reward_model,
loss=loss,
**reward_trainer_kwargs,
)
return BasicRewardTrainer(reward_model, loss=loss, **reward_trainer_kwargs)


QUERY_SCHEDULES: Dict[str, type_aliases.Schedule] = {
@@ -1506,13 +1576,9 @@ def train(
if i == 0:
epoch_multiplier = self.initial_epoch_multiplier

with self.logger.accumulate_means("reward"):
self.reward_trainer.train(
self.dataset,
epoch_multiplier=epoch_multiplier,
)
reward_loss = self.logger.name_to_value["mean/reward/loss"]
reward_accuracy = self.logger.name_to_value["mean/reward/accuracy"]
self.reward_trainer.train(self.dataset, epoch_multiplier=epoch_multiplier)
reward_loss = self.logger.name_to_value["reward/final/train/loss"]
reward_accuracy = self.logger.name_to_value["reward/final/train/accuracy"]

###################
# Train the agent #
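Taken together, the diff above swaps the old fixed `weight_decay` argument for a pluggable `regularizer_factory` that the trainer calls with its optimizer and logger, optionally holding out a validation split. A rough usage sketch follows; the environment setup and the exact `LpRegularizer.create` / `IntervalParamScaler` keyword names are assumptions inferred from this PR's commit history rather than taken verbatim from the diff:

```python
import gym
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms import preference_comparisons
from imitation.regularization import regularizers, updaters
from imitation.rewards import reward_nets

# A small environment and reward network to train on (illustrative only).
venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space)
preference_model = preference_comparisons.PreferenceModel(reward_net)
loss = preference_comparisons.CrossEntropyRewardLoss(preference_model)

# Assumed factory API: an Lp (here L2) weight-norm regularizer whose lambda is
# rescaled whenever the validation/train loss ratio leaves a tolerable interval.
regularizer_factory = regularizers.LpRegularizer.create(
    initial_lambda=0.1,
    val_split=0.2,  # fraction of the preference dataset held out for validation
    lambda_updater=updaters.IntervalParamScaler(
        scaling_factor=0.1,
        tolerable_interval=(1.1, 1.5),
    ),
    p=2,
)

reward_trainer = preference_comparisons.BasicRewardTrainer(
    model=reward_net,
    loss=loss,
    epochs=3,
    seed=0,
    regularizer_factory=regularizer_factory,  # replaces the removed weight_decay
)
```

The trainer then builds the regularizer internally via `regularizer_factory(optimizer=self.optim, logger=self.logger)`, as shown in `__init__` above.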
1 change: 1 addition & 0 deletions src/imitation/regularization/__init__.py
@@ -0,0 +1 @@
"""Implements a variety of regularization techniques for NN weights."""
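The training loop in `BasicRewardTrainer._train` relies on only a narrow contract from whatever this module supplies: a `val_split` attribute (used to carve out a validation set), `regularize_and_backward(loss)` (add the penalty to the loss and backpropagate), and `update_params(train_loss, val_loss)` (adjust the regularization strength between epochs). The following toy class is a minimal sketch of that contract only, not the `Regularizer`/`RegularizerFactory` classes actually added by this PR:

```python
from typing import Optional

import torch as th


class ToyL2Regularizer:
    """Illustrative stand-in exposing the interface BasicRewardTrainer uses."""

    def __init__(
        self,
        optimizer: th.optim.Optimizer,
        lambda_: float,
        val_split: Optional[float] = 0.2,
    ):
        self.optimizer = optimizer
        self.lambda_ = lambda_
        # The trainer checks this to decide whether to hold out validation data.
        self.val_split = val_split

    def _penalty(self) -> th.Tensor:
        # Sum of squared weights over every parameter the optimizer manages.
        params = [p for group in self.optimizer.param_groups for p in group["params"]]
        return sum((p ** 2).sum() for p in params)

    def regularize_and_backward(self, loss: th.Tensor) -> None:
        # Add the weighted penalty to the task loss, then backpropagate the total;
        # the trainer calls optimizer.step() afterwards.
        (loss + self.lambda_ * self._penalty()).backward()

    def update_params(self, train_loss: float, val_loss: float) -> None:
        # Crude schedule: if validation loss drifts well above training loss,
        # assume overfitting and strengthen the penalty; otherwise relax it.
        if val_loss > 1.5 * train_loss:
            self.lambda_ *= 2.0
        elif val_loss < 1.1 * train_loss:
            self.lambda_ /= 2.0
```

The real implementation in this package also handles logging and is produced by a factory so the trainer can inject its optimizer and logger, but its interaction with the training loop has the same shape as this sketch.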