Patch for Cambricon MLUs test #1747
Changes from 13 commits: d853d39, f605d3c, e0cfce6, 4f462b8, 2fbdb33, 3b2b8d6, 2d0ef7c, 0965469, 4447fe2, 3e8e62b, 99965ac, 83eaba2, ad14bfe, 2b1ad33
@@ -24,6 +24,7 @@
 import torch
 import yaml
 from diffusers import StableDiffusionPipeline
+from packaging import version

 from peft import (
     AdaLoraConfig,
@@ -464,13 +465,16 @@ def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
        if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

+       if (self.torch_device in ["cpu"]) and (version.parse(torch.__version__) <= version.parse("2.1")):
+           self.skipTest("addmm_impl_cpu_ does not support Half on PyTorch <= 2.1")
+
        model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16)
        config = config_cls(
Review discussion on this change:

Reviewer: Hmm, I cannot replicate this, whether with or without GPU. The idea of this test is exactly to check that this error does not occur with fp16, so not using this dtype is counter-productive. Is this only occurring with MLU devices?

Author: The reproduction code is as follows.

Reviewer: Okay, so instead of changing the dtype, how about skipping the test if an old PyTorch version is detected?

Author: Hmm, maybe we can use fp16 with pt>=2.3 and fp32 with pt<2.3?

Reviewer: We really don't need to test merging with fp32 here, as it's tested extensively in other tests. This test is very specifically for merging with fp16, so if we don't use fp16, we can skip it.

Author: Ha, got it! I will fix it.

Author: Hmm, I found that it is the "cpu" device that leads to the error. (Line 473 in cb0bf07)

Reviewer: So IIRC, there is an error when using CPU + float16 + old PyTorch. If we change either of those variables, there is no error. On CI, we have a new PyTorch version, so it passes, despite using CPU. If we switch to [...]. I assume this fails on your CI because it uses an older PyTorch version. This is why I suggested to just skip the test with older PyTorch versions. If you want, you could add a specific test for merging float16 with MLU, which would be skipped if the device is not available.

Author: @BenjaminBossan
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
-       model = model.to(device="cpu", dtype=torch.float16)
+       model = model.to(device=self.torch_device, dtype=torch.float16)

        model.eval()
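For context, the review thread above refers to a long-standing PyTorch limitation: on older releases, a float16 matmul on CPU fails with an error along the lines of `RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'`. Below is a minimal, hypothetical sketch of the version-gated skip pattern used in the hunk above; the test class, attribute, and test name are illustrative only and not part of PEFT's test suite.

```python
# A minimal sketch of gating a fp16-on-CPU test on the installed PyTorch version.
# Class, attribute, and test names here are illustrative, not from the PR.
import unittest

import torch
from packaging import version


class Fp16MergeSmokeTest(unittest.TestCase):
    torch_device = "cpu"  # stand-in for self.torch_device in the real tester mixin

    def test_fp16_linear_forward(self):
        # Older PyTorch builds can fail with:
        #   RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
        # when a float16 Linear runs on CPU, so skip in that configuration.
        if self.torch_device == "cpu" and version.parse(torch.__version__) <= version.parse("2.1"):
            self.skipTest("addmm_impl_cpu_ does not support Half on PyTorch <= 2.1")

        layer = torch.nn.Linear(8, 8).to(device=self.torch_device, dtype=torch.float16)
        x = torch.randn(2, 8, device=self.torch_device, dtype=torch.float16)
        out = layer(x)
        self.assertEqual(out.dtype, torch.float16)


if __name__ == "__main__":
    unittest.main()
```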
@@ -561,6 +565,8 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs):
        logits_merged_unloaded = model(**dummy_input)[0]

        atol, rtol = 1e-4, 1e-4
+       if self.torch_device in ["mlu"]:
+           atol, rtol = 1e-3, 1e-3  # MLU
        if (config.peft_type == "IA3") and (model_id == "Conv2d"):
            # for some reason, the IA³ Conv2d introduces a larger error
            atol, rtol = 0.3, 0.01
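The hunk above relaxes the comparison tolerances when the tests run on a Cambricon MLU, where kernel results can drift slightly further from the reference than on CPU/CUDA. A small sketch of that device-dependent tolerance pattern follows; the helper name and its CPU defaults are hypothetical, and only the MLU values mirror the diff.

```python
# Illustrative only: a device-aware tolerance helper in the spirit of the diff above.
import torch


def allclose_for_device(a: torch.Tensor, b: torch.Tensor, device: str) -> bool:
    # MLU kernels are allowed a looser tolerance than CPU/CUDA here.
    atol, rtol = (1e-3, 1e-3) if device == "mlu" else (1e-4, 1e-4)
    return torch.allclose(a, b, atol=atol, rtol=rtol)


# Usage (runs on CPU tensors, no MLU required):
x = torch.randn(4, 4)
print(allclose_for_device(x, x + 5e-5, device="cpu"))  # True: within the 1e-4 atol
print(allclose_for_device(x, x + 5e-4, device="cpu"))  # False under CPU tolerances
print(allclose_for_device(x, x + 5e-4, device="mlu"))  # True under the looser MLU tolerances
```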
@@ -689,16 +695,19 @@ def _test_safe_merge(self, model_id, config_cls, config_kwargs):
        model = get_peft_model(model, config).eval()
        logits_peft = model(**inputs)[0]

+       atol, rtol = 1e-6, 1e-6  # default
        # Initializing with LN tuning cannot be configured to change the outputs (unlike init_lora_weights=False)
        if not issubclass(config_cls, LNTuningConfig):
            # sanity check that the logits are different
-           assert not torch.allclose(logits_base, logits_peft, atol=1e-6, rtol=1e-6)
+           assert not torch.allclose(logits_base, logits_peft, atol=atol, rtol=rtol)

        model_unloaded = model.merge_and_unload(safe_merge=True)
        logits_unloaded = model_unloaded(**inputs)[0]

+       if self.torch_device in ["mlu"]:
+           atol, rtol = 1e-3, 1e-3  # MLU
        # check that the logits are the same after unloading
-       assert torch.allclose(logits_peft, logits_unloaded, atol=1e-6, rtol=1e-6)
+       assert torch.allclose(logits_peft, logits_unloaded, atol=atol, rtol=rtol)

    def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
        # Test for mixing different adapters in a single batch by passing the adapter_names argument
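Since the MLU-specific tolerance override now appears in both `_test_merge_layers` and `_test_safe_merge`, a possible follow-up would be to centralize it on the shared tester mixin. A rough sketch under that assumption; the class and method names below are hypothetical and not part of this PR.

```python
# Rough sketch (hypothetical, not in this PR): centralizing the MLU tolerance
# override that now appears in several of these tests.
import torch


class PeftCommonTesterSketch:
    torch_device = "cpu"  # would be "mlu" on a Cambricon CI runner

    def _tolerances(self, default=(1e-6, 1e-6), mlu=(1e-3, 1e-3)):
        # Return (atol, rtol) suited to the device under test.
        return mlu if self.torch_device == "mlu" else default

    def assert_logits_close(self, a: torch.Tensor, b: torch.Tensor):
        atol, rtol = self._tolerances()
        assert torch.allclose(a, b, atol=atol, rtol=rtol), "logits diverged beyond device tolerance"


# Usage:
tester = PeftCommonTesterSketch()
tester.assert_logits_close(torch.ones(3), torch.ones(3) + 1e-8)
```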
Author: Sorry for that. Fixed!