Improve bagging method for ensemble training (#530) #542

Merged: taufeeque9 merged 17 commits into master from bagging-active-selection on Sep 18, 2022

Conversation

taufeeque9 (Collaborator)

Description

Improved the bagging method for training ensemble members: examples are now sampled over the entire preference dataset, rather than over fixed batches as in the earlier implementation. Also refactored PreferenceModel, CrossEntropyRewardLoss, EnsembleTrainer, and _make_reward_trainer.
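Roughly, the new sampling works like this (a simplified sketch of the idea, not the actual code; `bagged_indices` and its arguments are illustrative):

```python
import numpy as np

def bagged_indices(dataset_size, n_members, rng):
    """Draw one bootstrap sample (with replacement) over the whole dataset
    for each ensemble member, instead of sampling within fixed batches."""
    return [
        rng.choice(dataset_size, size=dataset_size, replace=True)
        for _ in range(n_members)
    ]

# Each member sees a different resampling of the full preference dataset.
rng = np.random.default_rng(0)
member_indices = bagged_indices(dataset_size=1000, n_members=3, rng=rng)
```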

The ensemble member models are trained sequentially, each on a different sample of the dataset. I also tried training them simultaneously by running the training method in multiple threads. However, this doesn't give much gain and, in fact, increases the overall training time once there are more than 3 ensemble members. Training in parallel with multiple processes is too much of a hassle, since we cannot pass the _train method to worker processes directly: doing so requires pickling the self object. So for now the member models train sequentially, and we can try training them in parallel in another PR if required.
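In outline, the sequential loop amounts to the following (illustrative sketch reusing `bagged_indices` and `rng` from the sketch above; `DummyMember` stands in for the actual member trainers):

```python
class DummyMember:
    """Stand-in for a reward-network trainer (illustration only)."""
    def train(self, samples):
        print(f"trained on {len(samples)} samples")

def train_ensemble_sequentially(members, dataset, rng):
    # One member at a time, each on its own bootstrap sample of the dataset.
    for member, idx in zip(members, bagged_indices(len(dataset), len(members), rng)):
        member.train([dataset[i] for i in idx])

train_ensemble_sequentially([DummyMember() for _ in range(3)], list(range(100)), rng)
```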

One thing to check is the logging behavior of the loss and other metrics for the ensemble members. As far as I understand, with sequential training each member model overwrites the logs of the previous one, so the raw logs reflect only the last member trained. The accumulated mean metrics, however, should still be calculated correctly across all the member models.

Testing

I have separately tested that the sampler returns different indices across member models when bagging_dataset is used. No new test case was added.

codecov bot commented Aug 29, 2022

Codecov Report

Merging #542 (5ba5a13) into master (c3a63d5) will increase coverage by 0.02%.
The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##           master     #542      +/-   ##
==========================================
+ Coverage   97.26%   97.28%   +0.02%
==========================================
  Files          88       88
  Lines        8002     8030      +28
==========================================
+ Hits         7783     7812      +29
+ Misses        219      218       -1
```
| Impacted Files | Coverage Δ |
|---|---|
| src/imitation/algorithms/preference_comparisons.py | 99.24% <100.00%> (+0.20%) ⬆️ |
| .../imitation/scripts/train_preference_comparisons.py | 96.72% <100.00%> (ø) |
| src/imitation/util/logger.py | 100.00% <100.00%> (ø) |
| tests/algorithms/test_preference_comparisons.py | 100.00% <100.00%> (ø) |
| tests/util/test_logger.py | 100.00% <100.00%> (ø) |


AdamGleave (Member) left a comment

LGTM, I like the refactor, just left a few suggestions.

As you flagged on Slack, this will result in a different bagging dataset each time train gets called, because the dataset continually grows. I don't see a great way around this, so I'm happy to stick with this implementation for now and revisit it in the future if the empirical results from this method are not strong enough. But I wanted to include it in the PR conversation for posterity.

```diff
@@ -1155,7 +1157,7 @@ def _training_inner_loop(
         fragment_pairs: Sequence[TrajectoryPair],
         preferences: np.ndarray,
     ) -> th.Tensor:
-        output = self.loss.forward(fragment_pairs, preferences)
+        output = self.loss.forward(fragment_pairs, preferences, self._preference_model)
         loss = output.loss
         self.logger.record("loss", loss.item())
```
Collaborator
We are kinda abusing an unintended feature of the hierarchical logger here that may go away after a refactor. We could add loss to the metrics dictionary and do the averaging ourselves in the _train method. At that point it would probably also make sense to just remove the train_inner_loop method entirely.
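Concretely, the suggestion is something like this (hypothetical sketch; the metrics-dict layout is made up for illustration):

```python
import numpy as np

def average_metrics(batch_outputs):
    """Average per-batch metrics (including the loss) ourselves, rather than
    relying on the logger's implicit mean accumulation."""
    accumulated = {}
    for metrics in batch_outputs:
        for key, value in metrics.items():
            accumulated.setdefault(key, []).append(value)
    return {key: float(np.mean(values)) for key, values in accumulated.items()}

# e.g. metrics collected over three batches inside _train:
print(average_metrics([
    {"loss": 0.9, "accuracy": 0.50},
    {"loss": 0.7, "accuracy": 0.60},
    {"loss": 0.5, "accuracy": 0.80},
]))  # -> loss averages to 0.7, accuracy to roughly 0.63
```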

@taufeeque9 (Collaborator, Author)
The ensemble model now logs the loss and the other metrics for each member, and reports the mean and standard deviation, across members, of the averaged final-epoch metrics. I also added a function to change the max_length used when truncating log keys.
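For illustration, the reporting works roughly like this (a simplified sketch; the numbers and keys are made up):

```python
import numpy as np

# Hypothetical final-epoch metrics for each of three ensemble members.
final_epoch_metrics = [
    {"loss": 0.41, "accuracy": 0.83},  # member-0
    {"loss": 0.38, "accuracy": 0.86},  # member-1
    {"loss": 0.45, "accuracy": 0.81},  # member-2
]

# Report the mean and standard deviation across members for each metric.
for key in final_epoch_metrics[0]:
    values = [m[key] for m in final_epoch_metrics]
    print(f"{key}: mean={np.mean(values):.3f} std={np.std(values):.3f}")
```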

AdamGleave (Member) left a comment

Thanks for making these changes. I appreciate the test case and the added support for different max_length values.

I'm confused by the way you're adding prefixes to log keys. I've left suggestions for a couple of approaches that seem more intuitive to me, although there may be a reason I'm missing why it needs to be done differently.

Other comments pretty minor. Please request re-review once addressed.

"mean/reward",
f"mean/reward/member-{member_idx}",
)
self.logger.record(new_key, val)
Member
I'm confused by this code. It seems like we're recording things to a new key (based on the member index), but:

  1. Only the last value is logged. (Maybe that's OK if there's only a single timestep, i.e. self.logger.dump is never called inside _train?)
  2. We don't remove the old key.

Why not log with the right key in the first place? You could use the add_prefix method from Lev's PR https://github.com/HumanCompatibleAI/imitation/blob/4df85518203e5e83f34864d31ff807813f2a259c/src/imitation/util/logger.py (not yet merged, but you could always cherry-pick the logging commits into a separate PR, get that PR approved and merged, then come back and use it here). Alternatively, you could follow the approach Rocamonde took in the regularization PR and pass a prefix around explicitly, as in https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/algorithms/preference_comparisons.py#L1223. Rough sketches of both styles are below.
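(Hypothetical sketch of the two styles; add_prefix is assumed to be a context manager per Lev's unmerged PR, so the exact API may differ, and the keys and values here are made up:)

```python
from imitation.util import logger as imit_logger

hier_logger = imit_logger.configure()

# Style 1: prefix via context manager (assumed API from Lev's PR, not yet
# merged, hence commented out):
# with hier_logger.add_prefix(f"member-{member_idx}"):
#     hier_logger.record("loss", loss.item())

# Style 2: pass the prefix around explicitly, as in the regularization PR.
for member_idx in range(3):
    hier_logger.record(f"member-{member_idx}/loss", 0.5 - 0.1 * member_idx)
hier_logger.dump(step=0)
```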

Collaborator Author
For now I have used Rocamonde's approach for logging. Once #529 is merged, we can switch the current code to the add_prefix method from Lev's PR.

AdamGleave (Member) left a comment

LGTM. Thanks for implementing this!

@taufeeque9 taufeeque9 merged commit 32467bf into master Sep 18, 2022
@taufeeque9 taufeeque9 deleted the bagging-active-selection branch September 18, 2022 15:25