Add nn.Embedding Support to Lora #337

Merged: 4 commits, May 3, 2023
src/peft/tuners/lora.py: 177 changes (165 additions, 12 deletions)
@@ -119,6 +119,28 @@ class LoraModel(torch.nn.Module):
>>> lora_model = LoraModel(config, model)
```

```py
>>> import torch
>>> import transformers
>>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_int8_training

>>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]
>>> config = LoraConfig(
... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
... )

>>> model = transformers.GPTJForCausalLM.from_pretrained(
... "kakaobrain/kogpt",
... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
... pad_token_id=tokenizer.eos_token_id,
... use_cache=False,
... device_map={"": rank},
... torch_dtype=torch.float16,
... load_in_8bit=True,
... )
>>> model = prepare_model_for_int8_training(model)
>>> lora_model = get_peft_model(model, config)
```

**Attributes**:
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
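
As a quick check that the new code path is exercised, here is a minimal sketch (not part of this PR's docstring; it assumes the public `gpt2` checkpoint, whose token embedding is named `wte`) confirming that a targeted `nn.Embedding` is swapped for the LoRA `Embedding` layer added in this PR:

```py
>>> from transformers import AutoModelForCausalLM
>>> from peft import LoraConfig, get_peft_model

>>> base = AutoModelForCausalLM.from_pretrained("gpt2")  # assumption: a causal LM whose token embedding is named "wte"
>>> config = LoraConfig(r=8, lora_alpha=16, target_modules=["wte"], task_type="CAUSAL_LM")
>>> lora_model = get_peft_model(base, config)

>>> # The token embedding is now the LoRA Embedding wrapper defined in this file,
>>> # carrying the new lora_embedding_A / lora_embedding_B factors.
>>> wte = lora_model.base_model.model.transformer.wte
>>> sorted(wte.lora_embedding_A.keys())
['default']
>>> lora_model.print_trainable_parameters()  # only the adapter factors are trainable
```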
@@ -171,7 +193,9 @@ def _find_and_replace(self, adapter_name):
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = target.bias is not None
if hasattr(target, "bias"):
bias = target.bias is not None

if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
@@ -194,6 +218,11 @@ def _find_and_replace(self, adapter_name):
new_module = Linear8bitLt(
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
)
elif isinstance(target, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop("fan_in_fan_out", None)
in_features, out_features = target.num_embeddings, target.embedding_dim
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
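
Read in isolation, the new branch does two things: it strips `fan_in_fan_out` (which only applies to `Linear`/`Conv1D` weights) and maps `num_embeddings`/`embedding_dim` onto the generic `in_features`/`out_features` slots. A hedged stand-alone sketch of that step (the helper name `wrap_embedding` is illustrative only; the `Embedding` class is the one added later in this diff and is importable from `peft.tuners.lora` once the PR is merged):

```py
import torch.nn as nn
from peft.tuners.lora import Embedding  # the class added later in this diff

def wrap_embedding(adapter_name: str, target: nn.Embedding, lora_kwargs: dict) -> Embedding:
    kwargs = dict(lora_kwargs)
    kwargs.pop("fan_in_fan_out", None)  # meaningless for embeddings, so it is dropped
    # num_embeddings takes the place of in_features, embedding_dim of out_features
    return Embedding(adapter_name, target.num_embeddings, target.embedding_dim, **kwargs)
```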
@@ -230,8 +259,10 @@ def _find_and_replace(self, adapter_name):
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
new_module.weight = old_module.weight
if old_module.bias is not None:
new_module.bias = old_module.bias
if hasattr(old_module, "bias"):
if old_module.bias is not None:
new_module.bias = old_module.bias

if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
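
The `hasattr` guards above (here and in `_find_and_replace`) are needed because, unlike `nn.Linear`, `nn.Embedding` defines no `bias` attribute at all, so accessing `target.bias` would raise an `AttributeError` once embeddings become valid targets. A one-line illustration:

```py
import torch.nn as nn

assert hasattr(nn.Linear(4, 4), "bias")          # Linear always exposes a bias attribute (it may be None)
assert not hasattr(nn.Embedding(10, 4), "bias")  # Embedding has no bias attribute at all
```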
@@ -337,15 +368,27 @@ def add_weighted_adapter(self, adapters, weights, adapter_name):
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):
target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
for adapter, weight in zip(adapters, weights):
if adapter not in target.lora_A:
continue
target.lora_A[adapter_name].weight.data += (
target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
)
target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight
if adapter_name in target.lora_A:
target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
for adapter, weight in zip(adapters, weights):
if adapter not in target.lora_A:
continue
target.lora_A[adapter_name].weight.data += (
target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
)
target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight

elif adapter_name in target.lora_embedding_A:
target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0
target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0
for adapter, weight in zip(adapters, weights):
if adapter not in target.lora_embedding_A:
continue
target.lora_embedding_A[adapter_name].data += (
target.lora_embedding_A[adapter].data * weight * target.scaling[adapter]
)
target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight
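
The embedding branch of `add_weighted_adapter` applies the same combination rule as the linear one, now on the raw `ParameterDict` tensors: the target A is the sum of each adapter's A weighted by its mixing weight and scaling, while the target B is weighted by the mixing weight alone. A toy sketch of that arithmetic with plain tensors (shapes and adapter names are made up for illustration):

```py
import torch

r, num_embeddings, embedding_dim = 4, 10, 8
factors = {  # adapter name -> (lora_embedding_A, lora_embedding_B, scaling)
    "adapter_1": (torch.randn(r, num_embeddings), torch.randn(embedding_dim, r), 16 / 4),
    "adapter_2": (torch.randn(r, num_embeddings), torch.randn(embedding_dim, r), 32 / 4),
}
weights = {"adapter_1": 0.7, "adapter_2": 0.3}

# Combination rule mirrored from the loop above.
A_new = sum(weights[name] * scaling * A for name, (A, _, scaling) in factors.items())
B_new = sum(weights[name] * B for name, (_, B, _) in factors.items())
```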


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
@@ -389,6 +432,9 @@ def __init__(
self.lora_dropout = nn.ModuleDict({})
self.lora_A = nn.ModuleDict({})
self.lora_B = nn.ModuleDict({})
# For Embedding layer
self.lora_embedding_A = nn.ParameterDict({})
self.lora_embedding_B = nn.ParameterDict({})
# Mark the weight as unmerged
self.merged = False
self.disable_adapters = False
@@ -413,11 +459,37 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)

def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()

self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
if r > 0:
self.lora_embedding_A.update(
nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((r, self.in_features)))})
)
self.lora_embedding_B.update(
nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((self.out_features, r)))})
)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)

def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B[adapter_name].weight)
if adapter_name in self.lora_embedding_A.keys():
# initialize A to zeros and B the same way as the default for nn.Embedding
nn.init.zeros_(self.lora_embedding_A[adapter_name])
nn.init.normal_(self.lora_embedding_B[adapter_name])
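
For embeddings the initialisation is the mirror image of the linear case: A starts at zero and B is drawn from a normal distribution (the `nn.Embedding` default), so the initial update `(B @ A).T` is exactly zero and has the same shape as the embedding weight. A quick sanity sketch of those shapes, using illustrative sizes:

```py
import torch

num_embeddings, embedding_dim, r = 100, 16, 4
lora_embedding_A = torch.zeros(r, num_embeddings)   # zero-initialised, as in reset_lora_parameters
lora_embedding_B = torch.randn(embedding_dim, r)    # normal-initialised

delta = (lora_embedding_B @ lora_embedding_A).T     # what merge() would add to self.weight.data
assert delta.shape == (num_embeddings, embedding_dim)
assert torch.count_nonzero(delta) == 0              # merging at init is a no-op
```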


class Linear(nn.Linear, LoraLayer):
@@ -508,6 +580,87 @@ def forward(self, x: torch.Tensor):
return result


class Embedding(nn.Embedding, LoraLayer):
# LoRA implemented in an Embedding layer
def __init__(
self,
adapter_name: str,
num_embeddings: int,
embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)

nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim)

self.weight.requires_grad = False

nn.Embedding.reset_parameters(self)
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name

def unmerge(self, mode: bool = True):
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
self.weight.data -= (
transpose(
self.lora_embedding_B[self.active_adapter] @ self.lora_embedding_A[self.active_adapter], True
)
* self.scaling[self.active_adapter]
)
self.merged = False

def merge(self):
if self.merged:
warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
self.weight.data += (
transpose(
self.lora_embedding_B[self.active_adapter] @ self.lora_embedding_A[self.active_adapter], True
)
* self.scaling[self.active_adapter]
)
self.merged = True

def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.weight.data -= (
transpose(
self.lora_embedding_B[self.active_adapter]
@ self.lora_embedding_A[self.active_adapter],
True,
)
* self.scaling[self.active_adapter]
)
self.merged = False
return nn.Embedding.forward(self, x)

elif self.r[self.active_adapter] > 0 and not self.merged:
result = nn.Embedding.forward(self, x)
if self.r[self.active_adapter] > 0:
after_A = F.embedding(
x,
self.lora_embedding_A[self.active_adapter].T,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
result += (after_A @ self.lora_embedding_B[self.active_adapter].T) * self.scaling[self.active_adapter]
return result
else:
return nn.Embedding.forward(self, x)
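
The un-merged forward path above relies on an identity: an embedding lookup into `A.T` followed by a matmul with `B.T` and the scaling factor produces the same rows as looking up the merged delta `scaling * (B @ A).T` directly, which is why `merge()`/`unmerge()` and the adapter path agree. A small numerical check of that identity (shapes are illustrative):

```py
import torch
import torch.nn.functional as F

num_embeddings, embedding_dim, r, scaling = 50, 8, 4, 16 / 4
A = torch.randn(r, num_embeddings)
B = torch.randn(embedding_dim, r)
x = torch.randint(0, num_embeddings, (2, 7))  # a batch of token ids

adapter_path = F.embedding(x, A.T) @ B.T * scaling   # what Embedding.forward computes
merged_path = F.embedding(x, (B @ A).T * scaling)    # lookup into the merged delta
assert torch.allclose(adapter_path, merged_path, atol=1e-5)
```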


if is_bnb_available():

class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):