Preserve parameters #27
Conversation
I am also trying to make LoRA dynamically loadable. This seems to be a good solution; I will test it later.
Hey, I just found that it is unnecessary to pass … Have you tried that?
I have made switching LoRA dynamically possible:
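For illustration only (not the code from the example linked above), here is a minimal sketch of what dynamic LoRA switching can look like when the compiled graph still references the original parameter tensors. The names `unet` and `fused_lora_state_dict` are hypothetical, and the dict is assumed to already hold fused weights keyed like `unet.named_parameters()`:

import torch

@torch.no_grad()
def switch_lora(unet, fused_lora_state_dict):
    # Copy the new values into the existing parameter tensors in place.
    # A traced/frozen graph that still references these tensors will see
    # the update without retracing or recompiling.
    params = dict(unet.named_parameters())
    for name, value in fused_lora_state_dict.items():
        if name in params:
            params[name].copy_(value.to(device=params[name].device, dtype=params[name].dtype))

The in-place `copy_` is the important part: replacing the parameter objects outright would break the references held by the traced graph.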
Hi,
Yes, this is the first thing that I tried, and it is almost true, except for the aforementioned `time_emb_proj` linear layers.
They get concatted by `_jit_pass_concat_frozen_linear`. You can verify that with:

import torch

constants = {}
for node in traced_module.graph.nodes():
    if node.kind() != "prim::Constant":
        continue
    output = node.output()
    name = output.debugName()
    name = name.removeprefix("self.module.unet.")  # might have a different prefix for you
    if isinstance(output.type(), torch._C.TensorType):
        tensor = output.toIValue()
        if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
            assert name not in constants
            constants[name] = tensor

parameters = {}
for name, param in unet.named_parameters():
    assert isinstance(param, torch.Tensor)
    parameters[name] = param

for key in sorted(parameters.keys() - constants.keys()):
    print(key.ljust(80), parameters[key].shape, parameters[key].dtype)
# ^ this will output the missing constants

for key in sorted(constants.keys() - parameters.keys()):
    print(key.ljust(80), constants[key].shape, constants[key].dtype)
# ^ this will output the additional constants
You can see from their shapes that some of them are concatted. If you output the `data_ptr()` you will see that they do not share the data. Btw, love this project and its simplicity, and looking forward to it beating TensorRT soon 💪 😄
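To illustrate the `data_ptr()` check, here is a minimal sketch that reuses the `constants` and `parameters` dicts from the snippet above and reports which frozen constants still share storage with the live parameters:

# For names present in both dicts, compare storage pointers.
# Constants produced by concatenation (or any other copy) during freezing
# will not share data with the original parameters.
for key in sorted(constants.keys() & parameters.keys()):
    shared = constants[key].data_ptr() == parameters[key].data_ptr()
    print(key.ljust(80), "shared" if shared else "copied")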
You are right, and I think it is more important to make LoRA, ControlNet, or any other part of the model dynamically loadable than to pursue small marginal performance gains. So I would set … It is my belief that TorchScript or …
This enables the `preserve_parameter` setting to preserve the parameters instead of inlining them as constants. Unfortunately this is not exposed on `torch.jit.freeze`, so I had to copy-paste some logic from PyTorch. This allows the user to update the weights on the original model (e.g. via `model.unet.load_state_dict()`) without having to retrace and compile the module.

The only performance impact is that `_jit_pass_concat_frozen_linear` previously concatted a few `time_emb_proj` linear layers into one. For CUDA graph inference there is no noticeable impact in my testing. For non-CUDA-graph inference it might be slower because it has to traverse the model to get the weights.

This should also make LoRA fusing with PEFT work correctly, but I have not tested that yet.
Also, the `freeze=False` on the ControlNet can probably be removed with this.
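As a rough illustration of the mechanism described above (not the actual code in this PR), the sketch below freezes a toy module while keeping its parameters, using the internal binding `torch._C._freeze_module` that `torch.jit.freeze` wraps. The `preserveParameters` keyword is an internal, unexposed argument, so its name and behaviour may vary between PyTorch versions; treat this as an assumption rather than a stable API:

import torch

class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

scripted = torch.jit.script(TinyModule().eval())

# Regular freezing inlines the weights as prim::Constant nodes, so later
# updates to the original module's parameters no longer affect the graph.
frozen = torch.jit.freeze(scripted)
print(sum(node.kind() == "prim::Constant" for node in frozen.graph.nodes()))

# The internal freezing entry point can keep the parameters as attributes
# (prim::GetAttr) instead of baking them in as constants, so the constant
# count below is expected to be lower (assumed keyword: preserveParameters).
preserved_c = torch._C._freeze_module(scripted._c, preserveParameters=True)
preserved = torch.jit.RecursiveScriptModule(preserved_c)
torch.jit.RecursiveScriptModule._finalize_scriptmodule(preserved)
print(sum(node.kind() == "prim::Constant" for node in preserved.graph.nodes()))

With the parameters kept out of the constant table, updating the weights on the original model (for example via `load_state_dict()`) can take effect without retracing, which is the behaviour this PR relies on.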