
Preserve parameters #27

Merged
merged 1 commit into chengzeyi:dev on Nov 13, 2023

Conversation

skirsten
Contributor

This enables the preserve_parameters setting to preserve the parameters instead of inlining them into the graph as constants.

Unfortunately, this is not exposed by torch.jit.freeze, so I had to copy some logic from PyTorch.

This allows the user to update the weights on the original model (model.unet.load_state_dict()) without having to retrace and recompile the module.

The only performance impact is that, previously, _jit_pass_concat_frozen_linear concatenated a few time_emb_proj linear layers into one, which it can no longer do when the parameters are preserved. For CUDA graph inference there is no noticeable impact in my testing; non-CUDA-graph inference might be slightly slower because it has to traverse the model to get the weights.

This should also make LoRA fusing with PEFT work correctly, but I have not tested that yet.
The freeze=False on the ControlNet can probably also be removed with this.
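
For context, the intended flow looks roughly like this (a minimal sketch; the sfast import path, the CompilationConfig field names, and the checkpoint filename are assumptions from memory of the README and may differ):

import torch
from diffusers import StableDiffusionPipeline
# Assumed import path / config fields -- check the stable-fast README:
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile, CompilationConfig)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

config = CompilationConfig.Default()
config.preserve_parameters = True  # the setting added by this PR (assumed name)
compiled_pipe = compile(pipe, config)  # run inference with compiled_pipe(...)

# Later: swap in new UNet weights without retracing or recompiling.
# load_state_dict copies into the existing parameter tensors in place, and the
# compiled graph references those same tensors.
new_state = torch.load("finetuned_unet_state_dict.pt")  # hypothetical checkpoint
pipe.unet.load_state_dict(new_state)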

@chengzeyi
Owner

I am also trying to make LoRA dynamically loadable, and this seems like a good solution. I will test it later.

@chengzeyi chengzeyi changed the base branch from main to dev November 13, 2023 05:58
@chengzeyi chengzeyi merged commit fe1433b into chengzeyi:dev Nov 13, 2023
@chengzeyi
Owner

Hey, I just found that it is unnecessary to pass preserve_parameters when you want to use load_state_dict, because load_state_dict copies into the parameters of the model in place, and the compiled JIT graph and CUDA graph share those tensors.
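
To illustrate the in-place behaviour (plain PyTorch, nothing stable-fast specific):

import torch

lin = torch.nn.Linear(4, 4)
weight_before = lin.weight          # keep a reference to the live tensor
ptr_before = lin.weight.data_ptr()

lin.load_state_dict(torch.nn.Linear(4, 4).state_dict())

# Same Parameter object and same storage: anything that still references the
# old tensor (a traced graph, a captured CUDA graph) sees the new values.
assert lin.weight is weight_before
assert lin.weight.data_ptr() == ptr_before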

Have you tried that?

@chengzeyi
Owner

I have made it possible to switch LoRAs dynamically:
https://github.com/chengzeyi/stable-fast#dynamically-switch-lora
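
The core of it is the same in-place copy discussed above: fuse the LoRA into a plain state_dict, then copy it into the live parameters so the compiled graphs pick it up without retracing. A rough sketch of that last step (my own illustration, not the exact recipe from the linked README section):

import torch

def copy_state_dict_inplace(module: torch.nn.Module, new_state: dict) -> None:
    # Copy values into the module's existing tensors; this keeps the tensor
    # identities that the traced JIT graph and the captured CUDA graph alias.
    live_state = module.state_dict()
    with torch.no_grad():
        for name, tensor in new_state.items():
            live_state[name].copy_(tensor)

# e.g. copy_state_dict_inplace(pipe.unet, fused_lora_state_dict)  # hypothetical names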

@skirsten
Contributor Author

skirsten commented Nov 13, 2023

Hi,

> Hey, I just found that it is unnecessary to pass preserve_parameters when you want to use load_state_dict, because load_state_dict copies into the parameters of the model in place, and the compiled JIT graph and CUDA graph share those tensors.

Yes, this is the first thing I tried, and it is almost true, except for the aforementioned time_emb_proj layers. Specifically, these weights:

down_blocks.0.resnets.0.time_emb_proj.bias
down_blocks.0.resnets.0.time_emb_proj.weight
down_blocks.0.resnets.1.time_emb_proj.bias
down_blocks.0.resnets.1.time_emb_proj.weight
down_blocks.1.resnets.0.time_emb_proj.bias
down_blocks.1.resnets.0.time_emb_proj.weight
mid_block.resnets.1.time_emb_proj.bias
mid_block.resnets.1.time_emb_proj.weight
up_blocks.0.resnets.0.time_emb_proj.bias
up_blocks.0.resnets.0.time_emb_proj.weight
up_blocks.2.resnets.0.time_emb_proj.bias
up_blocks.2.resnets.0.time_emb_proj.weight
up_blocks.2.resnets.1.time_emb_proj.bias
up_blocks.2.resnets.1.time_emb_proj.weight
up_blocks.2.resnets.2.time_emb_proj.bias
up_blocks.2.resnets.2.time_emb_proj.weight

They get concatenated by _jit_pass_concat_frozen_linear and then no longer share the same underlying memory as the state_dict, so they will not be updated unless preserve_parameters is set.

You can verify that with:

import torch

# `traced_module` is the traced/frozen UNet ScriptModule and `unet` is the
# original diffusers UNet (both taken from the compiled pipeline).
constants = {}

for node in traced_module.graph.nodes():
    if node.kind() != "prim::Constant":
        continue

    output = node.output()
    name = output.debugName()
    name = name.removeprefix("self.module.unet.")  # might have a different prefix for you

    if isinstance(output.type(), torch._C.TensorType):
        tensor = output.toIValue()

        if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
            assert name not in constants
            constants[name] = tensor

parameters = {}

for name, param in unet.named_parameters():
    assert isinstance(param, torch.Tensor)
    parameters[name] = param

for key in sorted(parameters.keys() - constants.keys()):
    print(key.ljust(80), parameters[key].shape, parameters[key].dtype)

# ^ this will output parameters with no matching graph constant (the missing constants)

for key in sorted(constants.keys() - parameters.keys()):
    print(key.ljust(80), constants[key].shape, constants[key].dtype)

# ^ this will output the additional constants with no matching parameter.
# You can see from their shapes that some of them are concatenated; if you
# print their data_ptr() you will see that they do not share the data.
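
And to check the sharing directly (reusing the constants / parameters dicts from above):

# Constants that exist under the same name as a parameter should alias its
# storage; the concatenated tensors only show up in the set differences above.
for key in sorted(parameters.keys() & constants.keys()):
    if parameters[key].data_ptr() != constants[key].data_ptr():
        print("not shared:", key)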

Btw, I love this project and its simplicity, and I'm looking forward to it beating TensorRT soon 💪 😄

@chengzeyi
Owner

You are right, and I think it is more important to make LoRA, ControlNet, or any other part of the model dynamically loadable than to chase small marginal performance gains. So I will set preserve_parameters=True for freeze in the next release.

I believe TorchScript and torch.compile can be improved to be powerful enough to beat those heavily promoted, "over-engineered" alternatives. Beating TensorRT is possible, and making LoRA dynamically loadable could be a game changer 😄.
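
For reference, a rough sketch of what exposing that flag looks like at the TorchScript level (the private _freeze_module binding and its argument order here are from memory, so treat this as an assumption, not a stable API):

import torch

# Hypothetical: `scripted` is an already-traced ScriptModule in eval mode.
# torch.jit.freeze does not expose the flag, but the internal binding takes
# (module, preserved attribute names, freeze interfaces, preserve parameters).
frozen_c = torch._C._freeze_module(scripted._c, [], True, True)
frozen = torch.jit._recursive.wrap_cpp_module(frozen_c)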
