From 4c061f988dbe09be0cf8fc8c7e03463539eaecac Mon Sep 17 00:00:00 2001 From: Joe Halliwell Date: Sat, 10 Jul 2021 08:13:30 +0100 Subject: [PATCH 1/2] Add __init__.py --- taming/__init__.py | 0 taming/data/__init__.py | 0 taming/models/__init__.py | 0 taming/modules/__init__.py | 0 taming/modules/diffusionmodules/__init__.py | 0 taming/modules/discriminator/__init__.py | 0 taming/modules/misc/__init__.py | 0 taming/modules/transformer/__init__.py | 0 taming/modules/vqvae/__init__.py | 0 9 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 taming/__init__.py create mode 100644 taming/data/__init__.py create mode 100644 taming/models/__init__.py create mode 100644 taming/modules/__init__.py create mode 100644 taming/modules/diffusionmodules/__init__.py create mode 100644 taming/modules/discriminator/__init__.py create mode 100644 taming/modules/misc/__init__.py create mode 100644 taming/modules/transformer/__init__.py create mode 100644 taming/modules/vqvae/__init__.py diff --git a/taming/__init__.py b/taming/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/data/__init__.py b/taming/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/models/__init__.py b/taming/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/__init__.py b/taming/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/diffusionmodules/__init__.py b/taming/modules/diffusionmodules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/discriminator/__init__.py b/taming/modules/discriminator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/misc/__init__.py b/taming/modules/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/transformer/__init__.py b/taming/modules/transformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/taming/modules/vqvae/__init__.py b/taming/modules/vqvae/__init__.py new file mode 100644 index 00000000..e69de29b From 8f92908403cc25a6edeab1eee0eb89950c08e4a6 Mon Sep 17 00:00:00 2001 From: Joe Halliwell Date: Sat, 10 Jul 2021 08:40:55 +0100 Subject: [PATCH 2/2] Move get_obj_from_str() and instantiate_from_config() into package --- main.py | 8 +------- taming/__init__.py | 13 +++++++++++++ taming/models/cond_transformer.py | 2 +- taming/models/vqgan.py | 2 +- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 7b4f94c5..4a191dc6 100644 --- a/main.py +++ b/main.py @@ -11,13 +11,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from pytorch_lightning.utilities.distributed import rank_zero_only -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - +from taming import get_obj_from_str, instantiate_from_config def get_parser(**parser_kwargs): def str2bool(v): diff --git a/taming/__init__.py b/taming/__init__.py index e69de29b..ac572368 100644 --- a/taming/__init__.py +++ b/taming/__init__.py @@ -0,0 +1,13 @@ +import importlib + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) diff --git a/taming/models/cond_transformer.py b/taming/models/cond_transformer.py index 6e6869b0..f4bc583f 100644 --- a/taming/models/cond_transformer.py +++ b/taming/models/cond_transformer.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config +from taming import instantiate_from_config from taming.modules.util import SOSProvider diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py index 121d01fd..9d751d2d 100644 --- a/taming/models/vqgan.py +++ b/taming/models/vqgan.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config +from taming import instantiate_from_config from taming.modules.diffusionmodules.model import Encoder, Decoder from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer