diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 5b2b0e5dfb..cd7a077a09 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1378,6 +1378,13 @@ def create_or_update_model_card(self, output_dir: str): card.text = "\n".join(lines) card.save(filename) + def gradient_checkpointing_enable(self): + if hasattr(self.base_model, "gradient_checkpointing_enable"): + self.base_model.gradient_checkpointing_enable() + self.base_model = self._prepare_model_for_gradient_checkpointing(self.base_model) + else: + raise AttributeError("gradient_checkpointing_enable is not defined") + class PeftModelForSequenceClassification(PeftModel): """