Commit
[zero++] Synchronize at the end of secondary partitioning and simplify the logic (#5216)

## 1. Why?

We have a very long thread investigating [the issue](#5059). To summarize, the cause is:

a. The 2nd partitioning is asynchronous because it copies device-to-device from the full tensor to the 2nd tensor.
b. When using prefetching, the all-gather of the 2nd tensor can happen before the 2nd partitioning ends. At that moment, the 2nd tensor might contain bad values.

Also, we found that the copying logic was wrong and lengthy, so we simplified it to only two lines.

Kudos to @yundai424, Haowen Ning, @samadejacobs for the investigation effort.

## 2. What?

After multiple careful tests, we found that patching in `get_accelerator().synchronize()` to ensure all CUDA streams have finished before 2nd partitioning prevents the issue.

## 3. Tests

I validated the correctness of the simplified 2nd partition logic. The loss is "exactly" the same before and after the simplification under the same random seed.

Before

```
[
  {"loss": 2.0731},
  {"loss": 2.0288},
  {"loss": 1.927},
  {"loss": 1.8347},
  {"loss": 1.8347},
  {"loss": 1.7896},
  {"loss": 1.602},
  {"loss": 1.766},
  {"loss": 1.8751},
  {"loss": 1.6776}
]
```

After

```
[
  {"loss": 2.0731},
  {"loss": 2.0288},
  {"loss": 1.927},
  {"loss": 1.8347},
  {"loss": 1.8347},
  {"loss": 1.7896},
  {"loss": 1.602},
  {"loss": 1.766},
  {"loss": 1.8751},
  {"loss": 1.6776}
]
```

## 4. TODO

We need further investigation on the issue @samadejacobs

1) Revisit ZeRO-3 prefetch design
2) Refactor hpz to reuse primary tensor for secondary partition.

---------

Signed-off-by: byhsu <byhsu@linkedin.com>
Co-authored-by: byhsu <byhsu@linkedin.com>
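
The race described under "Why?" can be pictured with a small, pure-Python toy model. This is only a sketch and not the DeepSpeed code: `FakeStream`, `secondary_partition`, `prefetch_all_gather`, and `run` are hypothetical names invented for illustration, and the "stream" is simulated with a queue of deferred callables rather than a real accelerator stream. The `sync` flag stands in for the `get_accelerator().synchronize()` call placed at the end of the secondary partitioning.

```python
from collections import deque


class FakeStream:
    """Toy model of a device stream: work is queued and runs later."""

    def __init__(self):
        self._pending = deque()

    def enqueue(self, fn):
        # Asynchronous launch: record the work, return immediately.
        self._pending.append(fn)

    def synchronize(self):
        # Block until every queued operation has run (here: run them now).
        while self._pending:
            self._pending.popleft()()


def secondary_partition(full, secondary, rank, stream, sync):
    """Queue a copy of this rank's shard of `full` into `secondary`."""
    n = len(secondary)
    start = rank * n

    def copy():
        secondary[:] = full[start:start + n]

    stream.enqueue(copy)
    if sync:
        # The fix in this sketch: wait before anyone reads `secondary`.
        stream.synchronize()


def prefetch_all_gather(shards):
    """Concatenate the secondary shards, as a prefetched all-gather might."""
    return [x for shard in shards for x in shard]


def run(sync):
    full = [1.0, 2.0, 3.0, 4.0]
    stream = FakeStream()
    # NaN marks buffers whose contents have not been written yet.
    shards = [[float("nan")] * 2 for _ in range(2)]
    for rank, shard in enumerate(shards):
        secondary_partition(full, shard, rank, stream, sync)
    gathered = prefetch_all_gather(shards)
    stream.synchronize()  # drain any leftover work so the stream ends empty
    return gathered
```

In this model, `run(True)` gathers the values of the full tensor, while `run(False)` gathers the NaN placeholders because the all-gather reads the shards before the queued copies have executed. A real fix on an accelerator has a cost this toy does not show: a full synchronize stalls the host until all streams drain, which is presumably why the TODO items call for revisiting the prefetch design.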