@@ -434,9 +434,10 @@ def adapter_state_dict_load_map(self, adapter_name):
         """Return a mapping from the key present in the disk-loaded state dict
         to how it should be represented in the loaded model's state dict.

-        If a key is not present here, it is assumed to be mapped 1:1.
+        The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the
+        ground truth for which keys are supposed to be loaded from a saved state dict.
         """
-        return {}
+        raise NotImplementedError

     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
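
To make the new contract concrete, here is a minimal sketch (toy class and key names, not the actual PEFT classes) of how a loader can treat adapter_state_dict_load_map as the ground truth for which keys get loaded: anything missing from the map is silently skipped.

```python
# Toy illustration of the contract above; class and key names are made up.
class WrapperSketch:
    def adapter_state_dict_load_map(self, adapter_name):
        """Map each on-disk key to its key in the loaded model's state dict."""
        raise NotImplementedError


class OneToOneWrapper(WrapperSketch):
    def __init__(self, keys):
        self.keys = keys  # parameter names this wrapper owns (hypothetical)

    def adapter_state_dict_load_map(self, adapter_name):
        # Even a plain 1:1 mapping must be spelled out: a key missing from
        # this dict is simply never loaded.
        return {k: k for k in self.keys}


saved = {"weight": 1.0, "bias": 0.5, "stale_key": -1.0}
load_map = OneToOneWrapper(["weight", "bias"]).adapter_state_dict_load_map("default")
loaded = {model_key: saved[disk_key] for disk_key, model_key in load_map.items()}
assert loaded == {"weight": 1.0, "bias": 0.5}  # "stale_key" was filtered out
```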
@@ -550,15 +551,24 @@ def set_adapter(self, adapter_names: Union[str, list[str]]):
         self._active_adapter = adapter_name

     def adapter_state_dict_load_map(self, adapter_name):
-        # The state dict returned by ModulesToSaveWrapper
-        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}
+        # Maps the module keys as they are in the saved state dict to the in-memory state dict.
+        # Must contain all keys that are supposed to be loaded.
+        if adapter_name not in self._adapters:
+            # In case of multiple adapters, each bringing their own modules to save, each
+            # ModulesToSaveWrapper will be queried, but not every wrapper is obliged to serve the same adapters.
+            return {}
+        return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()}

-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict(self, adapter_name, state_dict):
         if adapter_name not in self._adapters:
             # In case of multiple adapters, each bringing their own modules to save, each
             # ModulesToSaveWrapper will be queried, but not every wrapper is obliged to serve the same adapters.
             return {}
-        return self.modules_to_save[adapter_name].state_dict()
+
+        return {
+            k: state_dict[f"modules_to_save.{adapter_name}.{k}"]
+            for k in self.modules_to_save[adapter_name].state_dict()
+        }

     def unload_and_optionally_merge_module(
         self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
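
For the wrapper above, a hedged round-trip sketch with plain dicts (stand-ins for the real module state dicts) shows how the two methods are meant to cooperate: adapter_state_dict strips the modules_to_save.<adapter> prefix when producing the on-disk format, and adapter_state_dict_load_map restores it at load time.

```python
# Round-trip sketch with plain dicts; the real wrapper operates on nn.Module
# state dicts, but the key arithmetic is the same.
adapter_name = "default"
module_keys = ["weight", "bias"]  # stand-in for modules_to_save[adapter_name].state_dict()

# In-memory keys carry the wrapper prefix.
in_memory = {f"modules_to_save.{adapter_name}.{k}": f"tensor_{k}" for k in module_keys}

# Saving (adapter_state_dict): look each module key up in the full state dict,
# producing unprefixed on-disk keys.
saved = {k: in_memory[f"modules_to_save.{adapter_name}.{k}"] for k in module_keys}
assert saved == {"weight": "tensor_weight", "bias": "tensor_bias"}

# Loading (adapter_state_dict_load_map): the map re-adds the prefix; any key
# absent from the map is never loaded.
load_map = {k: f"modules_to_save.{adapter_name}.{k}" for k in module_keys}
restored = {load_map[k]: saved[k] for k in load_map}
assert restored == in_memory
```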
@@ -651,17 +661,20 @@ def update(self, active_adapter, **kwargs):

         super().update(active_adapter)

-    def adapter_state_dict(self, adapter_name):
+    def adapter_state_dict_load_map(self, adapter_name):
+        if self.token_adapter.tied_adapter:
+            return {}
+        return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"}
+
+    def adapter_state_dict(self, adapter_name, state_dict):
         if self.token_adapter.tied_adapter:
             # storing of weight-tied layers is not up to us and will be handled by
             # transformers. we're just here to keep those layers in sync during training.
             # therefore we return an empty state dict.
             return {}

         return {
-            f"token_adapter.{k}": v
-            for k, v in self.token_adapter.state_dict().items()
-            if k.startswith("trainable_tokens_") and k.endswith(f".{adapter_name}")
+            f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"]
         }

     def enable_adapters(self, enabled: bool):
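
Finally, a small sketch of the two branches the TrainableTokensWrapper methods above distinguish, written as a standalone function with an assumed tied flag rather than the real class: tied adapters report no loadable keys (transformers saves them), while untied adapters expose a single per-adapter delta key.

```python
# Sketch of the tied/untied branching above; `tied` stands in for
# self.token_adapter.tied_adapter.
def load_map_sketch(adapter_name, tied=False):
    if tied:
        # Weight-tied layers are saved by transformers, so the wrapper
        # deliberately reports no loadable keys.
        return {}
    return {
        "token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"
    }

assert load_map_sketch("default", tied=True) == {}
assert load_map_sketch("default") == {
    "token_adapter.trainable_tokens_delta": "token_adapter.trainable_tokens_delta.default"
}
```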