Skip to content

Commit

Permalink
Update checkpoint_merger pipeline to pass the "variant" argument (#6670)
Browse files Browse the repository at this point in the history
* make checkpoint_merger pipeline pass the "variant" argument to from_pretrained()

* make style

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
  • Loading branch information
3 people authored Feb 22, 2024
1 parent 5a54dc9 commit d5f444d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion examples/community/checkpoint_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
variant - which variant of a pretrained model to load, e.g. "fp16" (None)
"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", None)
Expand All @@ -89,6 +91,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
variant = kwargs.pop("variant", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None)
Expand Down Expand Up @@ -173,7 +176,10 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
# Step 3:-
# Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
cached_folders[0],
torch_dtype=torch_dtype,
device_map=device_map,
variant=variant,
)
final_pipe.to(self.device)

Expand Down

0 comments on commit d5f444d

Please sign in to comment.