From 1b6e12f7669ff252b238d3417e9aaed5ad43c872 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 16 Aug 2022 15:55:15 +0200
Subject: [PATCH 1/3] Allow passing non-default modules to pipeline.

Override modules are recognized and replaced in the pipeline. However, no
check for mismatched classes is performed yet. This is because the override
module is already instantiated and we have no library or class name to
compare against.
---
 src/diffusers/pipeline_utils.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index ee593b463210..4fe206c49fa8 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -156,6 +156,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         from diffusers import pipelines

         # 3. Load each module in the pipeline
+
+        # 3.1 Assign override modules first (they are already instantiated in `init_dict`)
+        passed_class_obj = {k: v for k, v in kwargs.items() if k in init_dict}
+        for name, module in passed_class_obj.items():
+            # TODO: verify that the module class belongs to one of the supported classes
+            init_kwargs[name] = module
+            init_dict.pop(name)
+
+        # 3.2 Load standard modules
         for name, (library_name, class_name) in init_dict.items():
             is_pipeline_module = hasattr(pipelines, library_name)
             # if the model is in a pipeline module, then we load it from the pipeline

From 59b5fecbe1efede8eeb7e1fda6773384097b8c53 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 16 Aug 2022 15:12:05 +0000
Subject: [PATCH 2/3] up

---
 src/diffusers/pipeline_utils.py | 71 +++++++++++++++++++++++----------
 1 file changed, 49 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 4fe206c49fa8..94a6c67b1cc4 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -15,6 +15,7 @@
 # limitations under the License.

 import importlib
+import inspect
 import os
 from typing import Optional, Union

@@ -148,6 +149,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
         pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

+        # some modules can be passed directly to the init
+        # in this case they are already instantiated in `kwargs`
+        # extract them here
+        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
         init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

         init_kwargs = {}
@@ -156,19 +163,38 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         from diffusers import pipelines

         # 3. Load each module in the pipeline
-
-        # 3.1 Assign override modules first (they are already instantiated in `init_dict`)
-        passed_class_obj = {k: v for k, v in kwargs.items() if k in init_dict}
-        for name, module in passed_class_obj.items():
-            # TODO: verify that the module class belongs to one of the supported classes
-            init_kwargs[name] = module
-            init_dict.pop(name)
-
-        # 3.2 Load standard modules
         for name, (library_name, class_name) in init_dict.items():
             is_pipeline_module = hasattr(pipelines, library_name)
+            loaded_sub_model = None
+
             # if the model is in a pipeline module, then we load it from the pipeline
-            if is_pipeline_module:
+            if name in passed_class_obj:
+                # 1. check that passed_class_obj has correct parent class
+                if not is_pipeline_module:
+                    library = importlib.import_module(library_name)
+                    class_obj = getattr(library, class_name)
+                    importable_classes = LOADABLE_CLASSES[library_name]
+                    class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+                    expected_class_obj = None
+                    for class_name, class_candidate in class_candidates.items():
+                        if issubclass(class_obj, class_candidate):
+                            expected_class_obj = class_candidate
+
+                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+                        raise ValueError(
+                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+                            f" {expected_class_obj}"
+                        )
+                else:
+                    logger.warn(
+                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+                        " has the correct type"
+                    )
+
+                # set passed class object
+                loaded_sub_model = passed_class_obj[name]
+            elif is_pipeline_module:
                 pipeline_module = getattr(pipelines, library_name)
                 class_obj = getattr(pipeline_module, class_name)
                 importable_classes = ALL_IMPORTABLE_CLASSES
@@ -180,23 +206,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 importable_classes = LOADABLE_CLASSES[library_name]
                 class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

-            load_method_name = None
-            for class_name, class_candidate in class_candidates.items():
-                if issubclass(class_obj, class_candidate):
-                    load_method_name = importable_classes[class_name][1]
+            if loaded_sub_model is None:
+                load_method_name = None
+                for class_name, class_candidate in class_candidates.items():
+                    if issubclass(class_obj, class_candidate):
+                        load_method_name = importable_classes[class_name][1]

-            load_method = getattr(class_obj, load_method_name)
+                load_method = getattr(class_obj, load_method_name)

-            # check if the module is in a subdirectory
-            if os.path.isdir(os.path.join(cached_folder, name)):
-                loaded_sub_model = load_method(os.path.join(cached_folder, name))
-            else:
-                # else load from the root directory
-                loaded_sub_model = load_method(cached_folder)
+                # check if the module is in a subdirectory
+                if os.path.isdir(os.path.join(cached_folder, name)):
+                    loaded_sub_model = load_method(os.path.join(cached_folder, name))
+                else:
+                    # else load from the root directory
+                    loaded_sub_model = load_method(cached_folder)

             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

-        # 5. Instantiate the pipeline
+        # 4. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model

From eba502b25fae5e7536fd35d94b55cea6969e732f Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 16 Aug 2022 15:17:07 +0000
Subject: [PATCH 3/3] add test

---
 tests/test_modeling_utils.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 894a4294d664..28766b58fb96 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -717,6 +717,28 @@ def test_from_pretrained_hub(self):

         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

+    @slow
+    def test_from_pretrained_hub_pass_model(self):
+        model_path = "google/ddpm-cifar10-32"
+
+        # pass unet into DiffusionPipeline
+        unet = UNet2DModel.from_pretrained(model_path)
+        ddpm_from_hub_custom_model = DDPMPipeline.from_pretrained(model_path, unet=unet)
+        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)
+
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
+
+        ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
+        ddpm_from_hub.scheduler.num_timesteps = 10
+
+        generator = torch.manual_seed(0)
+
+        image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
+        generator = generator.manual_seed(0)
+        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
+
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
     @slow
     def test_output_format(self):
         model_path = "google/ddpm-cifar10-32"
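
For context, a minimal usage sketch of the behavior this series adds, mirroring the new
test (same model id and unet override) and assuming the diffusers API at the time of
these commits:

    # Minimal sketch, assuming the diffusers API at the time of these commits.
    from diffusers import DDPMPipeline, DiffusionPipeline, UNet2DModel

    model_path = "google/ddpm-cifar10-32"

    # Instantiate (or customize) a sub-module yourself ...
    unet = UNet2DModel.from_pretrained(model_path)

    # ... and pass it to `from_pretrained`: the name is matched against the
    # pipeline's __init__ signature, type-checked against the loadable classes
    # where possible, and used instead of loading that module from the checkpoint.
    ddpm = DDPMPipeline.from_pretrained(model_path, unet=unet)

    # The generic entry point accepts the same override.
    ddpm = DiffusionPipeline.from_pretrained(model_path, unet=unet)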