Commit c42eb22

Prepare 0.15.1 patch (#2459)
This is a patch release shipping a fix for #2450, an issue that might result in the loss of `modules_to_save` weights when training with DeepSpeed ZeRO stage 3.
1 parent b34d8a2 commit c42eb22
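For context, the symptom reported in #2450 is that the saved adapter checkpoint can come back with missing or empty `modules_to_save` tensors after a ZeRO stage 3 run. A quick way to check a checkpoint for this, as a minimal sketch assuming PEFT's usual output layout (an `adapter_model.safetensors` file in the save directory):

from safetensors.torch import load_file

# Adapter checkpoint written by `model.save_pretrained("output_dir")`; the file name and
# path are assumptions based on PEFT's default safetensors layout.
sd = load_file("output_dir/adapter_model.safetensors")

# With the bug, this list can be empty (or contain zero-size tensors) after ZeRO-3 training
# even though modules_to_save layers were trained; with the fix, the full tensors are present.
mts_keys = [k for k in sd if "modules_to_save" in k]
print(f"{len(mts_keys)} modules_to_save tensors found")
for k in mts_keys:
    print(k, tuple(sd[k].shape))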

File tree: 4 files changed, +38 −21 lines

setup.py (+1 −1)

@@ -15,7 +15,7 @@
 from setuptools import find_packages, setup


-VERSION = "0.15.0"
+VERSION = "0.15.1"

 extras = {}
 extras["quality"] = [

src/peft/__init__.py (+1 −1)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.15.0"
+__version__ = "0.15.1"

 from .auto import (
     MODEL_TYPE_TO_PEFT_MODEL_MAPPING,

src/peft/utils/other.py (+23 −10)

@@ -434,9 +434,10 @@ def adapter_state_dict_load_map(self, adapter_name):
         """Return a mapping from the key present in disk-loaded state dict
         and how it should be represented in the loaded model's state dict.

-        If a key is not present here, it is assumed to be mapped 1:1.
+        The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the
+        ground-truth for which keys are supposed to be loaded from a saved state dict.
         """
-        return {}
+        raise NotImplementedError

     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
@@ -550,15 +551,24 @@ def set_adapter(self, adapter_names: Union[str, list[str]]):
         self._active_adapter = adapter_name

     def adapter_state_dict_load_map(self, adapter_name):
-        # The state dict returned by ModulesToSaveWrapper
-        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}
+        # Maps the module keys as they are in the saved state dict to the in-memory state dict.
+        # Must contain all keys that are supposed to be loaded.
+        if adapter_name not in self._adapters:
+            # In case of multiple adapters, each bringing their own modules to save, each
+            # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
+            return {}
+        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()}

-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict(self, adapter_name, state_dict):
         if adapter_name not in self._adapters:
             # In case of multiple adapters, each bringing their own modules to save, each
             # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
             return {}
-        return self.modules_to_save[adapter_name].state_dict()
+
+        return {
+            k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
+            for k in self.modules_to_save[adapter_name].state_dict()
+        }

     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
@@ -651,17 +661,20 @@ def update(self, active_adapter, **kwargs):

         super().update(active_adapter)

-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict_load_map(self, adapter_name):
+        if self.token_adapter.tied_adapter:
+            return {}
+        return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
         if self.token_adapter.tied_adapter:
             # storing of weight-tied layers is not up to us and will be handled by
             # transformers. we're just here to keep those layers in sync during training.
             # therefore we return an empty state dict.
             return {}

         return {
-            f"token_adapter.{k}": v
-            for k, v in self.token_adapter.state_dict().items()
-            if k.startswith("trainable_tokens_") and k.endswith(f".{adapter_name}")
+            f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"]
         }

     def enable_adapters(self, enabled: bool):
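To restate the new contract outside the diff: `adapter_state_dict_load_map` is now the single source of truth for which keys belong to an adapter, and `adapter_state_dict` no longer reads its own `.state_dict()` but pulls values out of a state dict handed in by the caller. A toy illustration of that interface, not the real PEFT classes:

class ToyWrapper:
    """Stand-in for a wrapper with one extra trained module; names are illustrative only."""

    def __init__(self, adapter_name, keys):
        self._adapters = {adapter_name}
        self.keys = keys  # e.g. ["weight", "bias"]

    def adapter_state_dict_load_map(self, adapter_name):
        # Ground truth: exactly the keys listed here will be loaded, nothing else.
        if adapter_name not in self._adapters:
            return {}
        # saved key (adapter-agnostic) -> in-memory key (adapter-qualified)
        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.keys}

    def adapter_state_dict(self, adapter_name, state_dict):
        # Values come from the caller-provided, already gathered state dict,
        # never from this module's own (possibly sharded) parameters.
        if adapter_name not in self._adapters:
            return {}
        return {k: state_dict[f"modules_to_save.{adapter_name}.{k}"] for k in self.keys}


wrapper = ToyWrapper("default", ["weight", "bias"])
gathered = {"modules_to_save.default.weight": "W", "modules_to_save.default.bias": "b"}
print(wrapper.adapter_state_dict("default", gathered))  # {'weight': 'W', 'bias': 'b'}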

src/peft/utils/save_and_load.py (+13 −9)

@@ -197,7 +197,16 @@ def renamed_dora_weights(k):
     # ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
     for name, module in model.named_modules():
         if isinstance(module, AuxiliaryTrainingWrapper):
-            to_return.update({f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name).items()})
+            # Compute the module-relative state dict to make it easier for the adapter to fetch the appropriate
+            # keys that the module thinks need to be saved. We cannot rely on `.state_dict()` internally of the
+            # module since accelerators like DeepSpeed require special handling which is done for the model
+            # state dict from above but most likely not in the module itself. See #2450.
+            module_state_dict = {
+                k.removeprefix(f"{name}."): v for k, v in state_dict.items() if k.startswith(f"{name}.")
+            }
+            to_return.update(
+                {f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name, module_state_dict).items()}
+            )

     # DEAL WITH EMBEDDINGS
     # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
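The heart of the fix on the saving side is in the hunk above: instead of trusting the wrapper's own `.state_dict()`, the full, already gathered model state dict is sliced down to the wrapped module and passed to it. The same slicing, isolated as a small sketch with made-up key names:

def module_relative_state_dict(full_state_dict, module_name):
    # Cut out the entries belonging to one module and strip the module prefix,
    # mirroring what the new save path does before calling adapter_state_dict().
    prefix = f"{module_name}."
    return {k.removeprefix(prefix): v for k, v in full_state_dict.items() if k.startswith(prefix)}


full_sd = {
    "base_model.model.classifier.modules_to_save.default.weight": "W",
    "base_model.model.classifier.modules_to_save.default.bias": "b",
    "base_model.model.encoder.layer.weight": "unrelated",
}
print(module_relative_state_dict(full_sd, "base_model.model.classifier"))
# -> {'modules_to_save.default.weight': 'W', 'modules_to_save.default.bias': 'b'}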
@@ -343,14 +352,9 @@ def set_peft_model_state_dict(
             # `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
             # saved state dict which is why we fetch a load key map from the wrapper.
             key_map = module.adapter_state_dict_load_map(adapter_name)
-
-            for k in module.adapter_state_dict(adapter_name):
-                # each saved state dict is adapter specific, i.e. does not contain the adapter name
-                # but the loaded state dict does include adapter names since we can have multiple.
-                k_no_adapter = k.replace(f".{adapter_name}", "")
-
-                store_key = f"{name}.{key_map.get(k, k)}"
-                lookup_key = f"{name}.{k_no_adapter}"
+            for k in key_map:
+                lookup_key = f"{name}.{k}"
+                store_key = f"{name}.{key_map[k]}"

                 state_dict[store_key] = peft_model_state_dict[lookup_key]

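On the loading side, the loop now iterates over the load map itself, so only keys the wrapper declared are restored, and the adapter-qualified names come from the map rather than from string surgery on the adapter name. A minimal sketch of the same remapping with hypothetical keys:

def remap_saved_keys(name, key_map, saved_state_dict):
    # key_map: saved key (no adapter name) -> in-memory key (with adapter name)
    restored = {}
    for saved_key, in_memory_key in key_map.items():
        lookup_key = f"{name}.{saved_key}"     # key as written on disk
        store_key = f"{name}.{in_memory_key}"  # key the live model expects
        restored[store_key] = saved_state_dict[lookup_key]
    return restored


key_map = {"weight": "modules_to_save.default.weight", "bias": "modules_to_save.default.bias"}
saved = {"classifier.weight": "W", "classifier.bias": "b"}
print(remap_saved_keys("classifier", key_map, saved))
# -> {'classifier.modules_to_save.default.weight': 'W', 'classifier.modules_to_save.default.bias': 'b'}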