Improve bagging method for ensemble training (#530) #542
Conversation
- Refactor `PreferenceModel` and `CrossEntropyLoss` classes
Codecov Report
```
@@            Coverage Diff             @@
##           master     #542      +/-   ##
==========================================
+ Coverage   97.26%   97.28%   +0.02%
==========================================
  Files          88       88
  Lines        8002     8030      +28
==========================================
+ Hits         7783     7812      +29
+ Misses        219      218       -1
```
LGTM, I like the refactor, just left a few suggestions.
As you flagged on Slack, this will result in a different bagging dataset each time `train` gets called, because the dataset continually grows. I don't see a great way around this, so I'm happy to stick with this implementation for now and revisit in the future if the empirical results from this method are not strong enough. But I wanted to include it in the PR conversation for posterity.
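For readers of this thread, a minimal sketch of the behavior under discussion (the `bagging_indices` helper and the RNG seeding are illustrative, not the actual implementation):

```python
import numpy as np

def bagging_indices(dataset_size: int, rng: np.random.Generator) -> np.ndarray:
    # Bootstrap: sample with replacement over the *entire* current dataset.
    # Because the dataset grows between calls to `train`, each call draws
    # from a larger population, so the bagged subsets differ across rounds.
    return rng.integers(low=0, high=dataset_size, size=dataset_size)

rng = np.random.default_rng(0)
print(bagging_indices(dataset_size=10, rng=rng))
```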
```diff
@@ -1155,7 +1157,7 @@ def _training_inner_loop(
     fragment_pairs: Sequence[TrajectoryPair],
     preferences: np.ndarray,
 ) -> th.Tensor:
-    output = self.loss.forward(fragment_pairs, preferences)
+    output = self.loss.forward(fragment_pairs, preferences, self._preference_model)
     loss = output.loss
     self.logger.record("loss", loss.item())
```
We are kinda abusing an unintended feature of the hierarchical logger here that may go away after a refactor. We could add `loss` to the `metrics` dictionary and do the averaging ourselves in the `_train` method. At that point it would probably also make sense to just remove the `train_inner_loop` method entirely.
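A rough sketch of that alternative, with hypothetical `batches`, `loss_fn`, and `output` stand-ins rather than the real imitation API:

```python
from collections import defaultdict
from typing import Dict, List

def train_epoch(batches, loss_fn) -> Dict[str, float]:
    # Accumulate per-batch values (including the loss) and average them
    # here, rather than relying on the hierarchical logger's implicit
    # averaging of repeated `record` calls. `loss_fn` and the `output`
    # object are hypothetical stand-ins for illustration.
    accumulated: Dict[str, List[float]] = defaultdict(list)
    for fragment_pairs, preferences in batches:
        output = loss_fn(fragment_pairs, preferences)
        accumulated["loss"].append(output.loss.item())
        for key, value in output.metrics.items():
            accumulated[key].append(value.item())
    return {key: sum(vals) / len(vals) for key, vals in accumulated.items()}
```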
The ensemble model now logs the loss and the metrics for each member, and reports the mean and the standard deviation of the average of the final-epoch metrics. Also added a function to change the `max_length`.
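For illustration, a small sketch of the mean/std reporting described above, using made-up metric values:

```python
import numpy as np

# Hypothetical final-epoch metrics for three ensemble members.
member_metrics = [
    {"loss": 0.42, "accuracy": 0.81},
    {"loss": 0.39, "accuracy": 0.84},
    {"loss": 0.45, "accuracy": 0.79},
]

for key in member_metrics[0]:
    values = np.array([m[key] for m in member_metrics])
    # Report mean and standard deviation across ensemble members.
    print(f"{key}: mean={values.mean():.3f} std={values.std():.3f}")
```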
Thanks for making these changes. Appreciate the test case, and adding support for different `max_length`.
I'm confused by the way you're adding prefixes for log keys. Left some suggestions for a couple of ways that seem more intuitive to me, although there may be a reason I'm missing why we need to do it another way.
Other comments are pretty minor. Please request re-review once addressed.
"mean/reward", | ||
f"mean/reward/member-{member_idx}", | ||
) | ||
self.logger.record(new_key, val) |
I'm confused by this code. It seems like we're recording things to a new key (based on the member index), but:

- Only the last value is logged. (Maybe that's OK if there's only a single timestep, that is, `self.logger.dump` is never called inside `_train`?)
- We don't remove the old key.

Why not log with the right key in the first place? You could use the `add_prefix` method from Lev's PR https://github.com/HumanCompatibleAI/imitation/blob/4df85518203e5e83f34864d31ff807813f2a259c/src/imitation/util/logger.py (not yet merged, but you could always cherry-pick the logging commits into a separate PR, get that PR approved & merged, then go back and use it here). Alternatively, you could follow the approach Rocamonde took with the regularization PR and pass around a prefix explicitly, as in https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/algorithms/preference_comparisons.py#L1223
I have currently used Rocamonde's approach to log. Once #529 gets merged, we can switch the current code to using the `add_prefix` method in Lev's PR.
LGTM. Thanks for implementing this!
Description
Improved the bagging method to train ensemble members by sampling examples over the entire preference dataset, instead of the earlier implementation, which sampled over fixed batches. Also refactored `PreferenceModel`, `CrossEntropyRewardLoss`, `EnsembleTrainer`, and `_make_reward_trainer`.

The ensemble member models are trained sequentially using different samples of the dataset. I also tried training them simultaneously by running the training method in multiple threads. However, this doesn't give much gain and, in fact, increases the overall training time for more than 3 ensemble members. Training simultaneously using multiple processes is too much of a hassle, as we cannot directly pass the `_train` method to processes: they require pickling the `self` object. So I have currently let the member models train sequentially, and we can try training them in parallel in another PR if required.

One thing to check is the logging behavior of loss and other metrics for the ensemble members. As far as I understand, the sequential training of the member models will overwrite the logs of the model with those of the last member model trained. But the accumulated mean metrics should be calculated correctly across all the member models.
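A compact sketch of the sequential scheme described above, with hypothetical `members` and `dataset` stand-ins:

```python
import numpy as np

def train_ensemble(members, dataset, rng: np.random.Generator) -> None:
    # `members` and `dataset` are hypothetical stand-ins, not the actual
    # imitation classes. Each member trains on its own bootstrap sample
    # of the full preference dataset, one after another (no threads).
    for member in members:
        indices = rng.integers(0, len(dataset), size=len(dataset))
        member.train([dataset[i] for i in indices])
```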
Testing
I have separately tested that the sampler returns different indices across the member models for `bagging_dataset`. No new test case added.
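A sketch of the kind of check described, using assumed dataset sizes rather than the actual test code:

```python
import numpy as np

def test_member_samples_differ():
    # Successive bootstrap draws (one per ensemble member) should not all
    # return the same index sequence. Sizes here are hypothetical.
    rng = np.random.default_rng(42)
    n = 100
    draws = [rng.integers(0, n, size=n) for _ in range(3)]
    assert not all(np.array_equal(draws[0], d) for d in draws[1:])
```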