diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9fba473 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.idea +.vscode +__pycache__/ +*.py[cod] +*$py.class diff --git a/requirements.txt b/requirements.txt index 2260c29..9ca2cf3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ h5py Pillow -ml_dtypes==0.2.0 +ml_dtypes tensorflow-addons tensorflow-macos tensorflow-metal @@ -11,7 +11,7 @@ regex gradio==3.50.2 scikit-image psutil -torch==2.1.0 +torch torchvision opencv-python numexpr diff --git a/stableDiffusionKeras/EncodeDecode.py b/stableDiffusionKeras/EncodeDecode.py new file mode 100644 index 0000000..36679d4 --- /dev/null +++ b/stableDiffusionKeras/EncodeDecode.py @@ -0,0 +1,192 @@ + +import keras +import tensorflow_addons as tfa + +from .layers import apply_seq +from .kerasCVDiffusionModels import GroupNormalization + +class Decoder(keras.Sequential): + def __init__( + self, + img_height, + img_width, + name = None, + download_weights = False): + super().__init__( + [ + keras.layers.Input((img_height // 8, img_width // 8, 4)), + keras.layers.Rescaling(1.0 / 0.18215), + PaddedConv2D(4, 1, name = "PostQuantConvolutionalIn"), + PaddedConv2D(512, 3, padding = "same", name = "ConvolutionalIn"), + ResnetBlock(512), + AttentionBlock(512), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + keras.layers.UpSampling2D(size = (2,2)), + PaddedConv2D(512, 3, padding = "same"), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + keras.layers.UpSampling2D(size = (2,2)), + PaddedConv2D(512, 3, padding = "same"), + ResnetBlock(256), + ResnetBlock(256), + ResnetBlock(256), + keras.layers.UpSampling2D(size = (2,2)), + PaddedConv2D(256, 3, padding = "same"), + ResnetBlock(128), + ResnetBlock(128), + ResnetBlock(128), + GroupNormalization(epsilon = 1e-5), + keras.layers.Activation("swish"), + PaddedConv2D(3, 3, padding = "same", name = "ConvolutionalOut"), + ], + name=name, + ) + + if download_weights: + decoder_weights_fpath = keras.utils.get_file( + origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5", + file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962", + ) + self.load_weights(decoder_weights_fpath) + +class ImageEncoder(keras.Sequential): + """ImageEncoder is the VAE Encoder for StableDiffusion.""" + + def __init__( + self, + img_height = 512, + img_width = 512, + download_weights = False + ): + super().__init__( + [ + keras.layers.Input((img_height, img_width, 3)), + PaddedConv2D(128, 3, padding = "same"), + ResnetBlock(128), + ResnetBlock(128), + PaddedConv2D(128, 3, padding = "same", strides = 2), + ResnetBlock(256), + ResnetBlock(256), + PaddedConv2D(256, 3, padding = "same", strides = 2), + ResnetBlock(512), + ResnetBlock(512), + PaddedConv2D(512, 3, padding = "same", strides = 2), + ResnetBlock(512), + ResnetBlock(512), + ResnetBlock(512), + AttentionBlock(512), + ResnetBlock(512), + GroupNormalization(epsilon = 1e-5), + keras.layers.Activation("swish"), + PaddedConv2D(8, 3, padding = "same"), + PaddedConv2D(8, 1), + # TODO(lukewood): can this be refactored to be a Rescaling layer? + # Perhaps some sort of rescale and gather? + # Either way, we may need a lambda to gather the first 4 dimensions. + keras.layers.Lambda(lambda x: x[..., :4] * 0.18215), + ] + ) + +""" +Blocks +""" + +class ResnetBlock(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.norm1 = GroupNormalization(epsilon=1e-5) + self.conv1 = PaddedConv2D(output_dim, 3, padding = "same") + self.norm2 = GroupNormalization(epsilon=1e-5) + self.conv2 = PaddedConv2D(output_dim, 3, padding = "same") + + def build(self, input_shape): + if input_shape[-1] != self.output_dim: + self.residual_projection = PaddedConv2D(self.output_dim, 1) + else: + self.residual_projection = lambda x: x + + def call(self, inputs): + x = self.conv1(keras.activations.swish(self.norm1(inputs))) + x = self.conv2(keras.activations.swish(self.norm2(x))) + return x + self.residual_projection(inputs) + + def get_config(self): + config = super().get_config() + config.update({ + "output_dim": self.output_dim, + }) + return config + +class AttentionBlock(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.norm = GroupNormalization(epsilon=1e-5) + self.q = PaddedConv2D(output_dim, 1) + self.k = PaddedConv2D(output_dim, 1) + self.v = PaddedConv2D(output_dim, 1) + self.proj_out = PaddedConv2D(output_dim, 1) + + def get_config(self): + config = super().get_config() + config.update({ + "output_dim": self.output_dim, + }) + return config + + def call(self, inputs): + x = self.norm(inputs) + q, k, v = self.q(x), self.k(x), self.v(x) + + # Compute attention + _, h, w, c = q.shape + q = keras.ops.reshape(q, (-1, h * w, c)) # b, hw, c + k = keras.ops.transpose(k, (0, 3, 1, 2)) + k = keras.ops.reshape(k, (-1, c, h * w)) # b, c, hw + y = q @ k + y = y * (c**-0.5) + y = keras.activations.softmax(y) + + # Attend to values + v = keras.ops.transpose(v, (0, 3, 1, 2)) + v = keras.ops.reshape(v, (-1, c, h * w)) + y = keras.ops.transpose(y, (0, 2, 1)) + x = v @ y + x = keras.ops.transpose(x, (0, 2, 1)) + x = keras.ops.reshape(x, (-1, h, w, c)) + return self.proj_out(x) + inputs + +class PaddedConv2D(keras.layers.Layer): + def __init__( + self, + filters, + kernel_size, + padding = "valid", + strides = 1, + name = None, + **kwargs + ): + super().__init__(**kwargs) + self.conv2d = keras.layers.Conv2D(filters, kernel_size, strides = strides, padding = padding, name = name) + self.filters = filters + self.kernel_size = kernel_size + self.padding = padding + self.strides = strides + + def call(self, inputs): + return self.conv2d(inputs) + + def get_config(self): + config = super().get_config() + config.update({ + "filters": self.filters, + "kernel_size": self.kernel_size, + "padding": self.padding, + "strides": self.strides, + }) + return config \ No newline at end of file diff --git a/stableDiffusionKeras/ReadMe.md b/stableDiffusionKeras/ReadMe.md new file mode 100644 index 0000000..ca591b2 --- /dev/null +++ b/stableDiffusionKeras/ReadMe.md @@ -0,0 +1,3 @@ +### Stable Diffusion TensorFlow ### + +Originally implemented by Divum Gupta, this heavily modified version is the nuts and bolts of MetalDiffusion. diff --git a/stableDiffusionKeras/__init__.py b/stableDiffusionKeras/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stableDiffusionKeras/clipEncoder.py b/stableDiffusionKeras/clipEncoder.py new file mode 100644 index 0000000..9dfb4f8 --- /dev/null +++ b/stableDiffusionKeras/clipEncoder.py @@ -0,0 +1,154 @@ + +import keras +import tensorflow_addons as tfa +import numpy as np + +from .layers import quick_gelu + +# Step 1 +# Create and return the CLIP Embeddings +class CLIPTextTransformer(keras.models.Model): + def __init__( + self, + maxLength = 77, + vocabularySize = 49408 + ): + super().__init__() + + # Create embeddings -> Step 2 + self.embeddings = CLIPTextEmbeddings(maxLength = maxLength, vocabularySize = vocabularySize) + + # Create encoder -> Step 3 + self.encoder = CLIPEncoder() + + self.final_layer_norm = keras.layers.LayerNormalization(epsilon = 1e-5, name = "FinalLayerNormalization") + self.causal_attention_mask = keras.initializers.Constant( + np.triu(np.ones((1, 1, 77, 77), dtype = "float32") * -np.inf, k = 1), + name = "CausalAttentionMask" + ) + + def call(self, inputs): + input_ids, position_ids = inputs + x = self.embeddings([input_ids, position_ids]) + x = self.encoder([x, self.causal_attention_mask]) + return self.final_layer_norm(x) + +# Step 2 +# Create and return word and position embeddings + +class CLIPTextEmbeddings(keras.layers.Layer): + def __init__( + self, + maxLength = 77, + vocabularySize = 49408, + embeddingSize = 768 + ): + super().__init__() + self.token_embedding_layer = keras.layers.Embedding( + vocabularySize, embeddingSize, name = "token_embedding" + ) + self.position_embedding_layer = keras.layers.Embedding( + maxLength, embeddingSize, name = "position_embedding" + ) + + def call(self, inputs): + input_ids, position_ids = inputs + word_embeddings = self.token_embedding_layer(input_ids) + position_embeddings = self.position_embedding_layer(position_ids) + return word_embeddings + position_embeddings + +# Step 3 +# Create and return the hidden states (aka hidden size) +class CLIPEncoder(keras.layers.Layer): + def __init__(self): + super().__init__() + self.layers = [CLIPEncoderLayer() for i in range(12)] + + def call(self, inputs): + [hidden_states, causal_attention_mask] = inputs + for l in self.layers: + hidden_states = l([hidden_states, causal_attention_mask]) + return hidden_states + +# Step 4 (also creatd in step 3) +# Create the layers +class CLIPEncoderLayer(keras.layers.Layer): + def __init__( + self, + intermediateSize = 3072, + embeddingSize = 768 + ): + super().__init__() + self.layer_norm1 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "LayerNormalization001") + self.self_attn = CLIPAttention() + self.layer_norm2 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "LayerNormalization002") + self.fc1 = keras.layers.Dense(intermediateSize, name = "FC1") + self.fc2 = keras.layers.Dense(embeddingSize, name = "FC2") + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn([hidden_states, causal_attention_mask]) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = quick_gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + + return residual + hidden_states + +class CLIPAttention(keras.layers.Layer): + def __init__(self): + super().__init__() + self.embed_dim = 768 + self.num_heads = 12 + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 + self.q_proj = keras.layers.Dense(self.embed_dim, name = "QueryState") + self.k_proj = keras.layers.Dense(self.embed_dim, name = "KeyState") + self.v_proj = keras.layers.Dense(self.embed_dim, name = "ValueState") + self.out_proj = keras.layers.Dense(self.embed_dim, name = "OutProjection") + + def _shape(self, tensor, seq_len: int, bsz: int): + a = keras.ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) + return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + bsz, tgt_len, embed_dim = hidden_states.shape + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1) + value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1) + + proj_shape = (-1, tgt_len, self.head_dim) + query_states = self._shape(query_states, tgt_len, -1) + query_states = keras.ops.reshape(query_states, proj_shape) + key_states = keras.ops.reshape(key_states, proj_shape) + + src_len = tgt_len + value_states = keras.ops.reshape(value_states, proj_shape) + attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states) + + attn_weights = keras.ops.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len)) + #print("attn_weights dtype:",attn_weights.dtype) + #print('casual dtype:',causal_attention_mask.dtype) + # Convert the causal_attention_mask tensor to the same data type as attn_weights + #causal_attention_mask = keras.ops.cast(causal_attention_mask, dtype=attn_weights.dtype) + attn_weights = attn_weights + causal_attention_mask + attn_weights = keras.ops.reshape(attn_weights, (-1, tgt_len, src_len)) + + attn_weights = keras.ops.softmax(attn_weights) + attn_output = attn_weights @ value_states + + attn_output = keras.ops.reshape( + attn_output, (-1, self.num_heads, tgt_len, self.head_dim) + ) + attn_output = keras.layers.Permute((2, 1, 3))(attn_output) + attn_output = keras.ops.reshape(attn_output, (-1, tgt_len, embed_dim)) + + return self.out_proj(attn_output) \ No newline at end of file diff --git a/stableDiffusionKeras/clipTokenizer/ReadMe.md b/stableDiffusionKeras/clipTokenizer/ReadMe.md new file mode 100644 index 0000000..78118d5 --- /dev/null +++ b/stableDiffusionKeras/clipTokenizer/ReadMe.md @@ -0,0 +1,3 @@ +## CLIP Tokenizer + +This folder contains the files necessary for the CLIP Tokenizer. diff --git a/stableDiffusionKeras/clipTokenizer/__init__.py b/stableDiffusionKeras/clipTokenizer/__init__.py new file mode 100644 index 0000000..134d99c --- /dev/null +++ b/stableDiffusionKeras/clipTokenizer/__init__.py @@ -0,0 +1,292 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +import keras + + +@lru_cache() +def default_bpe(): + p = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + if os.path.exists(p): + return p + else: + return keras.utils.get_file( + "bpe_simple_vocab_16e6.txt.gz", + "https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true", + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__( + self, + bpe_path: str = default_bpe(), + specialTokens = None + ): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + + """ + Special Tokens are words we want to add to the vocabularly that don't exist in real life. + We can use these special tokens to activate pre-trained vectors for the text-encoder + + The only special tokens that are always added indicate when the text starts and ends + """ + if not specialTokens: + addedTokens = None + specialTokens = ['',''] + else: + addedTokens = specialTokens + specialTokens = ['', ''] + specialTokens + vocab.extend(specialTokens) + + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in specialTokens} + + # Create special words to recognize + special = "|".join(specialTokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in specialTokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + +class LegacySimpleTokenizer(object): + def __init__( + self, + bpe_path: str = default_bpe(), + specialTokens = None + ): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + + """ + Special Tokens are words we want to add to the vocabularly that don't exist in real life. + We can use these special tokens to activate pre-trained vectors for the text-encoder + + The only special tokens that are always added indicate when the text starts and ends + """ + # Create the words to add to the vocabulary + if specialTokens is None: + addedTokens = None + specialTokens = ["<|startoftext|>", "<|endoftext|>"] + else: + addedTokens = specialTokens + specialTokens = ["<|startoftext|>", "<|endoftext|>"] + specialTokens + vocab.extend(specialTokens) + + # Create the list for the program to recognize the words + if addedTokens is not None: + special = "|".join(addedTokens) + special = special + "|" + else: + special = "" + + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in specialTokens} + + self.pat = re.compile( + special + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + def bpe(self, token): + # print("Tokenzing this: ",token) + if token in self.cache: + # print("Found in cache! Returning cache value") + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + #return [49406] + bpe_tokens + [49407] + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text diff --git a/stableDiffusionKeras/clipTokenizer/bpe_simple_vocab_16e6.txt.gz b/stableDiffusionKeras/clipTokenizer/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/stableDiffusionKeras/clipTokenizer/bpe_simple_vocab_16e6.txt.gz differ diff --git a/stableDiffusionKeras/constants.py b/stableDiffusionKeras/constants.py new file mode 100644 index 0000000..8200df1 --- /dev/null +++ b/stableDiffusionKeras/constants.py @@ -0,0 +1,3704 @@ +PYTORCH_CKPT_MAPPING = {'text_encoder_legacy': [('cond_stage_model.transformer.text_model.embeddings.token_embedding.weight', + None), + ('cond_stage_model.transformer.text_model.embeddings.position_embedding.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias', + None), + ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight', + (1, 0)), + ('cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias', + None), + ('cond_stage_model.transformer.text_model.final_layer_norm.weight', None), + ('cond_stage_model.transformer.text_model.final_layer_norm.bias', None)], + # new Stable Diffusion + 'text_encoder': [('cond_stage_model.model.token_embedding.weight', + None), + ('cond_stage_model.model.positional_embedding', + None), + ('cond_stage_model.model.transformer.resblocks.0.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.0.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.0.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.0.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.1.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.1.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.1.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.1.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.2.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.2.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.2.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.2.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.3.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.3.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.3.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.3.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.4.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.4.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.4.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.4.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.5.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.5.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.5.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.5.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.6.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.6.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.6.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.6.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.7.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.7.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.7.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.7.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.8.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.8.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.8.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.8.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.9.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.9.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.9.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.9.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.10.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.10.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.10.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.10.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.11.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.11.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.11.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.11.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.12.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.12.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.12.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.12.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.13.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.13.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.13.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.13.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.14.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.14.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.14.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.14.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.15.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.15.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.15.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.15.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.16.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.16.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.16.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.16.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.17.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.17.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.17.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.17.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.18.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.18.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.18.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.18.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.19.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.19.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.19.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.19.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.20.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.20.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.20.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.20.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.21.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.21.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.21.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.21.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.22.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.22.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.22.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.22.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.23.ln_1.weight', + None), + ('cond_stage_model.model.transformer.resblocks.23.ln_1.bias', + None), + ('cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight', + None), + ('cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias', + None), + ('cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias', + None), + ('cond_stage_model.model.transformer.resblocks.23.ln_2.weight', + None), + ('cond_stage_model.model.transformer.resblocks.23.ln_2.bias', + None), + ('cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias', + None), + ('cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight', + (1, 0)), + ('cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias', + None), + ('cond_stage_model.model.ln_final.weight', None), + ('cond_stage_model.model.ln_final.bias', None)], + 'diffusion_model': [('model.diffusion_model.time_embed.0.weight', (1, 0)), + ('model.diffusion_model.time_embed.0.bias', None), + ('model.diffusion_model.input_blocks.0.0.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.0.0.bias', None), + ('model.diffusion_model.time_embed.2.weight', (1, 0)), + ('model.diffusion_model.time_embed.2.bias', None), + ('model.diffusion_model.input_blocks.1.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.1.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.1.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.1.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.1.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.1.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.1.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.1.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.1.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.1.1.norm.weight', None), + ('model.diffusion_model.input_blocks.1.1.norm.bias', None), + ('model.diffusion_model.input_blocks.1.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.1.1.proj_in.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias', None), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight', (1, 0)), + ('model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias', None), + ('model.diffusion_model.input_blocks.1.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.1.1.proj_out.bias', None), + ('model.diffusion_model.input_blocks.2.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.2.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.2.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.2.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.2.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.2.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.2.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.2.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.2.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.2.1.norm.weight', None), + ('model.diffusion_model.input_blocks.2.1.norm.bias', None), + ('model.diffusion_model.input_blocks.2.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.2.1.proj_in.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias', None), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight', (1, 0)), + ('model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias', None), + ('model.diffusion_model.input_blocks.2.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.2.1.proj_out.bias', None), + ('model.diffusion_model.input_blocks.3.0.op.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.3.0.op.bias', None), + ('model.diffusion_model.input_blocks.4.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.4.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.4.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.4.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.4.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.4.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.4.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.4.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.4.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.4.0.skip_connection.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.4.0.skip_connection.bias', None), + ('model.diffusion_model.input_blocks.4.1.norm.weight', None), + ('model.diffusion_model.input_blocks.4.1.norm.bias', None), + ('model.diffusion_model.input_blocks.4.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.4.1.proj_in.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias', None), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight', (1, 0)), + ('model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias', None), + ('model.diffusion_model.input_blocks.4.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.4.1.proj_out.bias', None), + ('model.diffusion_model.input_blocks.5.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.5.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.5.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.5.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.5.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.5.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.5.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.5.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.5.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.5.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.5.1.norm.weight', None), + ('model.diffusion_model.input_blocks.5.1.norm.bias', None), + ('model.diffusion_model.input_blocks.5.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.5.1.proj_in.bias', None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.input_blocks.5.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.5.1.proj_out.bias', None), + ('model.diffusion_model.input_blocks.6.0.op.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.6.0.op.bias', None), + ('model.diffusion_model.input_blocks.7.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.7.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.7.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.7.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.7.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.7.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.7.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.7.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.7.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.7.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.7.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.7.0.skip_connection.bias', None), + ('model.diffusion_model.input_blocks.7.1.norm.weight', None), + ('model.diffusion_model.input_blocks.7.1.norm.bias', None), + ('model.diffusion_model.input_blocks.7.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.7.1.proj_in.bias', None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.input_blocks.7.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.7.1.proj_out.bias', None), + ('model.diffusion_model.input_blocks.8.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.8.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.8.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.8.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.8.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.8.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.8.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.8.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.8.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.8.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.8.1.norm.weight', None), + ('model.diffusion_model.input_blocks.8.1.norm.bias', None), + ('model.diffusion_model.input_blocks.8.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.8.1.proj_in.bias', None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.input_blocks.8.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.8.1.proj_out.bias', None), + ('model.diffusion_model.input_blocks.9.0.op.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.9.0.op.bias', None), + ('model.diffusion_model.input_blocks.10.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.10.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.10.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.10.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.10.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.10.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.10.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.10.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.10.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.10.0.out_layers.3.bias', None), + ('model.diffusion_model.input_blocks.11.0.in_layers.0.weight', None), + ('model.diffusion_model.input_blocks.11.0.in_layers.0.bias', None), + ('model.diffusion_model.input_blocks.11.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.11.0.in_layers.2.bias', None), + ('model.diffusion_model.input_blocks.11.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.input_blocks.11.0.emb_layers.1.bias', None), + ('model.diffusion_model.input_blocks.11.0.out_layers.0.weight', None), + ('model.diffusion_model.input_blocks.11.0.out_layers.0.bias', None), + ('model.diffusion_model.input_blocks.11.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.input_blocks.11.0.out_layers.3.bias', None), + ('model.diffusion_model.middle_block.0.in_layers.0.weight', None), + ('model.diffusion_model.middle_block.0.in_layers.0.bias', None), + ('model.diffusion_model.middle_block.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.middle_block.0.in_layers.2.bias', None), + ('model.diffusion_model.middle_block.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.middle_block.0.emb_layers.1.bias', None), + ('model.diffusion_model.middle_block.0.out_layers.0.weight', None), + ('model.diffusion_model.middle_block.0.out_layers.0.bias', None), + ('model.diffusion_model.middle_block.0.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.middle_block.0.out_layers.3.bias', None), + ('model.diffusion_model.middle_block.1.norm.weight', None), + ('model.diffusion_model.middle_block.1.norm.bias', None), + ('model.diffusion_model.middle_block.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.middle_block.1.proj_in.bias', None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.middle_block.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.middle_block.1.proj_out.bias', None), + ('model.diffusion_model.middle_block.2.in_layers.0.weight', None), + ('model.diffusion_model.middle_block.2.in_layers.0.bias', None), + ('model.diffusion_model.middle_block.2.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.middle_block.2.in_layers.2.bias', None), + ('model.diffusion_model.middle_block.2.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.middle_block.2.emb_layers.1.bias', None), + ('model.diffusion_model.middle_block.2.out_layers.0.weight', None), + ('model.diffusion_model.middle_block.2.out_layers.0.bias', None), + ('model.diffusion_model.middle_block.2.out_layers.3.weight', (2, 3, 1, 0)), + ('model.diffusion_model.middle_block.2.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.0.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.0.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.0.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.0.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.0.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.0.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.0.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.0.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.0.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.0.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.0.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.0.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.1.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.1.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.1.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.1.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.1.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.1.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.1.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.1.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.1.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.1.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.1.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.1.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.2.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.2.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.2.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.2.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.2.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.2.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.2.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.2.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.2.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.2.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.2.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.2.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.2.1.conv.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.2.1.conv.bias', None), + ('model.diffusion_model.output_blocks.3.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.3.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.3.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.3.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.3.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.3.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.3.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.3.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.3.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.3.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.3.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.3.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.3.1.norm.weight', None), + ('model.diffusion_model.output_blocks.3.1.norm.bias', None), + ('model.diffusion_model.output_blocks.3.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.3.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.3.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.3.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.4.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.4.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.4.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.4.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.4.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.4.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.4.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.4.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.4.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.4.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.4.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.4.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.4.1.norm.weight', None), + ('model.diffusion_model.output_blocks.4.1.norm.bias', None), + ('model.diffusion_model.output_blocks.4.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.4.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.4.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.4.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.5.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.5.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.5.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.5.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.5.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.5.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.5.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.5.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.5.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.5.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.5.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.5.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.5.1.norm.weight', None), + ('model.diffusion_model.output_blocks.5.1.norm.bias', None), + ('model.diffusion_model.output_blocks.5.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.5.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.5.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.5.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.5.2.conv.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.5.2.conv.bias', None), + ('model.diffusion_model.output_blocks.6.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.6.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.6.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.6.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.6.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.6.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.6.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.6.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.6.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.6.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.6.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.6.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.6.1.norm.weight', None), + ('model.diffusion_model.output_blocks.6.1.norm.bias', None), + ('model.diffusion_model.output_blocks.6.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.6.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.6.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.6.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.7.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.7.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.7.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.7.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.7.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.7.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.7.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.7.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.7.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.7.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.7.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.7.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.7.1.norm.weight', None), + ('model.diffusion_model.output_blocks.7.1.norm.bias', None), + ('model.diffusion_model.output_blocks.7.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.7.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.7.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.7.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.8.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.8.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.8.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.8.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.8.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.8.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.8.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.8.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.8.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.8.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.8.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.8.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.8.1.norm.weight', None), + ('model.diffusion_model.output_blocks.8.1.norm.bias', None), + ('model.diffusion_model.output_blocks.8.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.8.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.8.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.8.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.8.2.conv.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.8.2.conv.bias', None), + ('model.diffusion_model.output_blocks.9.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.9.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.9.0.in_layers.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.9.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.9.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.9.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.9.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.9.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.9.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.9.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.9.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.9.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.9.1.norm.weight', None), + ('model.diffusion_model.output_blocks.9.1.norm.bias', None), + ('model.diffusion_model.output_blocks.9.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.9.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.9.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.9.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.10.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.10.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.10.0.in_layers.2.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.10.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.10.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.10.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.10.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.10.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.10.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.10.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.10.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.10.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.10.1.norm.weight', None), + ('model.diffusion_model.output_blocks.10.1.norm.bias', None), + ('model.diffusion_model.output_blocks.10.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.10.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.10.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.10.1.proj_out.bias', None), + ('model.diffusion_model.output_blocks.11.0.in_layers.0.weight', None), + ('model.diffusion_model.output_blocks.11.0.in_layers.0.bias', None), + ('model.diffusion_model.output_blocks.11.0.in_layers.2.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.11.0.in_layers.2.bias', None), + ('model.diffusion_model.output_blocks.11.0.emb_layers.1.weight', (1, 0)), + ('model.diffusion_model.output_blocks.11.0.emb_layers.1.bias', None), + ('model.diffusion_model.output_blocks.11.0.out_layers.0.weight', None), + ('model.diffusion_model.output_blocks.11.0.out_layers.0.bias', None), + ('model.diffusion_model.output_blocks.11.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.11.0.out_layers.3.bias', None), + ('model.diffusion_model.output_blocks.11.0.skip_connection.weight', + (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.11.0.skip_connection.bias', None), + ('model.diffusion_model.output_blocks.11.1.norm.weight', None), + ('model.diffusion_model.output_blocks.11.1.norm.bias', None), + ('model.diffusion_model.output_blocks.11.1.proj_in.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.11.1.proj_in.bias', None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias', + None), + ('model.diffusion_model.output_blocks.11.1.proj_out.weight', (2, 3, 1, 0)), + ('model.diffusion_model.output_blocks.11.1.proj_out.bias', None), + ('model.diffusion_model.out.0.weight', None), + ('model.diffusion_model.out.0.bias', None), + ('model.diffusion_model.out.2.weight', (2, 3, 1, 0)), + ('model.diffusion_model.out.2.bias', None)], + 'decoder': [('first_stage_model.post_quant_conv.weight', (2, 3, 1, 0)), + ('first_stage_model.post_quant_conv.bias', None), + ('first_stage_model.decoder.conv_in.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.conv_in.bias', None), + ('first_stage_model.decoder.mid.block_1.norm1.weight', None), + ('first_stage_model.decoder.mid.block_1.norm1.bias', None), + ('first_stage_model.decoder.mid.block_1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.block_1.conv1.bias', None), + ('first_stage_model.decoder.mid.block_1.norm2.weight', None), + ('first_stage_model.decoder.mid.block_1.norm2.bias', None), + ('first_stage_model.decoder.mid.block_1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.block_1.conv2.bias', None), + ('first_stage_model.decoder.mid.attn_1.norm.weight', None), + ('first_stage_model.decoder.mid.attn_1.norm.bias', None), + ('first_stage_model.decoder.mid.attn_1.q.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.attn_1.q.bias', None), + ('first_stage_model.decoder.mid.attn_1.k.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.attn_1.k.bias', None), + ('first_stage_model.decoder.mid.attn_1.v.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.attn_1.v.bias', None), + ('first_stage_model.decoder.mid.attn_1.proj_out.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.attn_1.proj_out.bias', None), + ('first_stage_model.decoder.mid.block_2.norm1.weight', None), + ('first_stage_model.decoder.mid.block_2.norm1.bias', None), + ('first_stage_model.decoder.mid.block_2.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.block_2.conv1.bias', None), + ('first_stage_model.decoder.mid.block_2.norm2.weight', None), + ('first_stage_model.decoder.mid.block_2.norm2.bias', None), + ('first_stage_model.decoder.mid.block_2.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.mid.block_2.conv2.bias', None), + ('first_stage_model.decoder.up.3.block.0.norm1.weight', None), + ('first_stage_model.decoder.up.3.block.0.norm1.bias', None), + ('first_stage_model.decoder.up.3.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.block.0.conv1.bias', None), + ('first_stage_model.decoder.up.3.block.0.norm2.weight', None), + ('first_stage_model.decoder.up.3.block.0.norm2.bias', None), + ('first_stage_model.decoder.up.3.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.block.0.conv2.bias', None), + ('first_stage_model.decoder.up.3.block.1.norm1.weight', None), + ('first_stage_model.decoder.up.3.block.1.norm1.bias', None), + ('first_stage_model.decoder.up.3.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.block.1.conv1.bias', None), + ('first_stage_model.decoder.up.3.block.1.norm2.weight', None), + ('first_stage_model.decoder.up.3.block.1.norm2.bias', None), + ('first_stage_model.decoder.up.3.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.block.1.conv2.bias', None), + ('first_stage_model.decoder.up.3.block.2.norm1.weight', None), + ('first_stage_model.decoder.up.3.block.2.norm1.bias', None), + ('first_stage_model.decoder.up.3.block.2.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.block.2.conv1.bias', None), + ('first_stage_model.decoder.up.3.block.2.norm2.weight', None), + ('first_stage_model.decoder.up.3.block.2.norm2.bias', None), + ('first_stage_model.decoder.up.3.block.2.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.block.2.conv2.bias', None), + ('first_stage_model.decoder.up.3.upsample.conv.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.3.upsample.conv.bias', None), + ('first_stage_model.decoder.up.2.block.0.norm1.weight', None), + ('first_stage_model.decoder.up.2.block.0.norm1.bias', None), + ('first_stage_model.decoder.up.2.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.block.0.conv1.bias', None), + ('first_stage_model.decoder.up.2.block.0.norm2.weight', None), + ('first_stage_model.decoder.up.2.block.0.norm2.bias', None), + ('first_stage_model.decoder.up.2.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.block.0.conv2.bias', None), + ('first_stage_model.decoder.up.2.block.1.norm1.weight', None), + ('first_stage_model.decoder.up.2.block.1.norm1.bias', None), + ('first_stage_model.decoder.up.2.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.block.1.conv1.bias', None), + ('first_stage_model.decoder.up.2.block.1.norm2.weight', None), + ('first_stage_model.decoder.up.2.block.1.norm2.bias', None), + ('first_stage_model.decoder.up.2.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.block.1.conv2.bias', None), + ('first_stage_model.decoder.up.2.block.2.norm1.weight', None), + ('first_stage_model.decoder.up.2.block.2.norm1.bias', None), + ('first_stage_model.decoder.up.2.block.2.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.block.2.conv1.bias', None), + ('first_stage_model.decoder.up.2.block.2.norm2.weight', None), + ('first_stage_model.decoder.up.2.block.2.norm2.bias', None), + ('first_stage_model.decoder.up.2.block.2.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.block.2.conv2.bias', None), + ('first_stage_model.decoder.up.2.upsample.conv.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.2.upsample.conv.bias', None), + ('first_stage_model.decoder.up.1.block.0.norm1.weight', None), + ('first_stage_model.decoder.up.1.block.0.norm1.bias', None), + ('first_stage_model.decoder.up.1.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.0.conv1.bias', None), + ('first_stage_model.decoder.up.1.block.0.norm2.weight', None), + ('first_stage_model.decoder.up.1.block.0.norm2.bias', None), + ('first_stage_model.decoder.up.1.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.0.conv2.bias', None), + ('first_stage_model.decoder.up.1.block.0.nin_shortcut.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.0.nin_shortcut.bias', None), + ('first_stage_model.decoder.up.1.block.1.norm1.weight', None), + ('first_stage_model.decoder.up.1.block.1.norm1.bias', None), + ('first_stage_model.decoder.up.1.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.1.conv1.bias', None), + ('first_stage_model.decoder.up.1.block.1.norm2.weight', None), + ('first_stage_model.decoder.up.1.block.1.norm2.bias', None), + ('first_stage_model.decoder.up.1.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.1.conv2.bias', None), + ('first_stage_model.decoder.up.1.block.2.norm1.weight', None), + ('first_stage_model.decoder.up.1.block.2.norm1.bias', None), + ('first_stage_model.decoder.up.1.block.2.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.2.conv1.bias', None), + ('first_stage_model.decoder.up.1.block.2.norm2.weight', None), + ('first_stage_model.decoder.up.1.block.2.norm2.bias', None), + ('first_stage_model.decoder.up.1.block.2.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.block.2.conv2.bias', None), + ('first_stage_model.decoder.up.1.upsample.conv.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.1.upsample.conv.bias', None), + ('first_stage_model.decoder.up.0.block.0.norm1.weight', None), + ('first_stage_model.decoder.up.0.block.0.norm1.bias', None), + ('first_stage_model.decoder.up.0.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.0.conv1.bias', None), + ('first_stage_model.decoder.up.0.block.0.norm2.weight', None), + ('first_stage_model.decoder.up.0.block.0.norm2.bias', None), + ('first_stage_model.decoder.up.0.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.0.conv2.bias', None), + ('first_stage_model.decoder.up.0.block.0.nin_shortcut.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.0.nin_shortcut.bias', None), + ('first_stage_model.decoder.up.0.block.1.norm1.weight', None), + ('first_stage_model.decoder.up.0.block.1.norm1.bias', None), + ('first_stage_model.decoder.up.0.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.1.conv1.bias', None), + ('first_stage_model.decoder.up.0.block.1.norm2.weight', None), + ('first_stage_model.decoder.up.0.block.1.norm2.bias', None), + ('first_stage_model.decoder.up.0.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.1.conv2.bias', None), + ('first_stage_model.decoder.up.0.block.2.norm1.weight', None), + ('first_stage_model.decoder.up.0.block.2.norm1.bias', None), + ('first_stage_model.decoder.up.0.block.2.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.2.conv1.bias', None), + ('first_stage_model.decoder.up.0.block.2.norm2.weight', None), + ('first_stage_model.decoder.up.0.block.2.norm2.bias', None), + ('first_stage_model.decoder.up.0.block.2.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.up.0.block.2.conv2.bias', None), + ('first_stage_model.decoder.norm_out.weight', None), + ('first_stage_model.decoder.norm_out.bias', None), + ('first_stage_model.decoder.conv_out.weight', (2, 3, 1, 0)), + ('first_stage_model.decoder.conv_out.bias', None)], + 'encoder': [('first_stage_model.encoder.conv_in.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.conv_in.bias', None), + ('first_stage_model.encoder.down.0.block.0.norm1.weight', None), + ('first_stage_model.encoder.down.0.block.0.norm1.bias', None), + ('first_stage_model.encoder.down.0.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.0.block.0.conv1.bias', None), + ('first_stage_model.encoder.down.0.block.0.norm2.weight', None), + ('first_stage_model.encoder.down.0.block.0.norm2.bias', None), + ('first_stage_model.encoder.down.0.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.0.block.0.conv2.bias', None), + ('first_stage_model.encoder.down.0.block.1.norm1.weight', None), + ('first_stage_model.encoder.down.0.block.1.norm1.bias', None), + ('first_stage_model.encoder.down.0.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.0.block.1.conv1.bias', None), + ('first_stage_model.encoder.down.0.block.1.norm2.weight', None), + ('first_stage_model.encoder.down.0.block.1.norm2.bias', None), + ('first_stage_model.encoder.down.0.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.0.block.1.conv2.bias', None), + ('first_stage_model.encoder.down.0.downsample.conv.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.0.downsample.conv.bias', None), + ('first_stage_model.encoder.down.1.block.0.norm1.weight', None), + ('first_stage_model.encoder.down.1.block.0.norm1.bias', None), + ('first_stage_model.encoder.down.1.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.1.block.0.conv1.bias', None), + ('first_stage_model.encoder.down.1.block.0.norm2.weight', None), + ('first_stage_model.encoder.down.1.block.0.norm2.bias', None), + ('first_stage_model.encoder.down.1.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.1.block.0.conv2.bias', None), + ('first_stage_model.encoder.down.1.block.0.nin_shortcut.weight', + (2, 3, 1, 0)), + ('first_stage_model.encoder.down.1.block.0.nin_shortcut.bias', None), + ('first_stage_model.encoder.down.1.block.1.norm1.weight', None), + ('first_stage_model.encoder.down.1.block.1.norm1.bias', None), + ('first_stage_model.encoder.down.1.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.1.block.1.conv1.bias', None), + ('first_stage_model.encoder.down.1.block.1.norm2.weight', None), + ('first_stage_model.encoder.down.1.block.1.norm2.bias', None), + ('first_stage_model.encoder.down.1.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.1.block.1.conv2.bias', None), + ('first_stage_model.encoder.down.1.downsample.conv.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.1.downsample.conv.bias', None), + ('first_stage_model.encoder.down.2.block.0.norm1.weight', None), + ('first_stage_model.encoder.down.2.block.0.norm1.bias', None), + ('first_stage_model.encoder.down.2.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.2.block.0.conv1.bias', None), + ('first_stage_model.encoder.down.2.block.0.norm2.weight', None), + ('first_stage_model.encoder.down.2.block.0.norm2.bias', None), + ('first_stage_model.encoder.down.2.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.2.block.0.conv2.bias', None), + ('first_stage_model.encoder.down.2.block.0.nin_shortcut.weight', + (2, 3, 1, 0)), + ('first_stage_model.encoder.down.2.block.0.nin_shortcut.bias', None), + ('first_stage_model.encoder.down.2.block.1.norm1.weight', None), + ('first_stage_model.encoder.down.2.block.1.norm1.bias', None), + ('first_stage_model.encoder.down.2.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.2.block.1.conv1.bias', None), + ('first_stage_model.encoder.down.2.block.1.norm2.weight', None), + ('first_stage_model.encoder.down.2.block.1.norm2.bias', None), + ('first_stage_model.encoder.down.2.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.2.block.1.conv2.bias', None), + ('first_stage_model.encoder.down.2.downsample.conv.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.2.downsample.conv.bias', None), + ('first_stage_model.encoder.down.3.block.0.norm1.weight', None), + ('first_stage_model.encoder.down.3.block.0.norm1.bias', None), + ('first_stage_model.encoder.down.3.block.0.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.3.block.0.conv1.bias', None), + ('first_stage_model.encoder.down.3.block.0.norm2.weight', None), + ('first_stage_model.encoder.down.3.block.0.norm2.bias', None), + ('first_stage_model.encoder.down.3.block.0.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.3.block.0.conv2.bias', None), + ('first_stage_model.encoder.down.3.block.1.norm1.weight', None), + ('first_stage_model.encoder.down.3.block.1.norm1.bias', None), + ('first_stage_model.encoder.down.3.block.1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.3.block.1.conv1.bias', None), + ('first_stage_model.encoder.down.3.block.1.norm2.weight', None), + ('first_stage_model.encoder.down.3.block.1.norm2.bias', None), + ('first_stage_model.encoder.down.3.block.1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.down.3.block.1.conv2.bias', None), + ('first_stage_model.encoder.mid.block_1.norm1.weight', None), + ('first_stage_model.encoder.mid.block_1.norm1.bias', None), + ('first_stage_model.encoder.mid.block_1.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.block_1.conv1.bias', None), + ('first_stage_model.encoder.mid.block_1.norm2.weight', None), + ('first_stage_model.encoder.mid.block_1.norm2.bias', None), + ('first_stage_model.encoder.mid.block_1.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.block_1.conv2.bias', None), + ('first_stage_model.encoder.mid.attn_1.norm.weight', None), + ('first_stage_model.encoder.mid.attn_1.norm.bias', None), + ('first_stage_model.encoder.mid.attn_1.q.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.attn_1.q.bias', None), + ('first_stage_model.encoder.mid.attn_1.k.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.attn_1.k.bias', None), + ('first_stage_model.encoder.mid.attn_1.v.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.attn_1.v.bias', None), + ('first_stage_model.encoder.mid.attn_1.proj_out.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.attn_1.proj_out.bias', None), + ('first_stage_model.encoder.mid.block_2.norm1.weight', None), + ('first_stage_model.encoder.mid.block_2.norm1.bias', None), + ('first_stage_model.encoder.mid.block_2.conv1.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.block_2.conv1.bias', None), + ('first_stage_model.encoder.mid.block_2.norm2.weight', None), + ('first_stage_model.encoder.mid.block_2.norm2.bias', None), + ('first_stage_model.encoder.mid.block_2.conv2.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.mid.block_2.conv2.bias', None), + ('first_stage_model.encoder.norm_out.weight', None), + ('first_stage_model.encoder.norm_out.bias', None), + ('first_stage_model.encoder.conv_out.weight', (2, 3, 1, 0)), + ('first_stage_model.encoder.conv_out.bias', None), + ('first_stage_model.quant_conv.weight', (2, 3, 1, 0)), + ('first_stage_model.quant_conv.bias', None)], + 'controlNet': [('control_model.time_embed.0.weight', (1, 0)), + ('control_model.time_embed.0.bias', None), + ('control_model.input_blocks.0.0.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.0.0.bias', None), + ('control_model.input_hint_block.0.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.0.bias', None), + ('control_model.input_hint_block.2.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.2.bias', None), + ('control_model.input_hint_block.4.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.4.bias', None), + ('control_model.input_hint_block.6.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.6.bias', None), + ('control_model.input_hint_block.8.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.8.bias', None), + ('control_model.input_hint_block.10.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.10.bias', None), + ('control_model.input_hint_block.12.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.12.bias', None), + ('control_model.input_hint_block.14.weight', (2, 3, 1, 0)), + ('control_model.input_hint_block.14.bias', None), + ('control_model.time_embed.2.weight', (1, 0)), + ('control_model.time_embed.2.bias', None), + ('control_model.input_blocks.1.0.in_layers.0.weight', None), + ('control_model.input_blocks.1.0.in_layers.0.bias', None), + ('control_model.input_blocks.1.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.1.0.in_layers.2.bias', None), + ('control_model.input_blocks.1.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.1.0.emb_layers.1.bias', None), + ('control_model.input_blocks.1.0.out_layers.0.weight', None), + ('control_model.input_blocks.1.0.out_layers.0.bias', None), + ('control_model.input_blocks.1.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.1.0.out_layers.3.bias', None), + ('control_model.input_blocks.1.1.norm.weight', None), + ('control_model.input_blocks.1.1.norm.bias', None), + ('control_model.input_blocks.1.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.1.1.proj_in.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias', None), + ('control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight', (1, 0)), + ('control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias', None), + ('control_model.input_blocks.1.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.1.1.proj_out.bias', None), + ('control_model.input_blocks.2.0.in_layers.0.weight', None), + ('control_model.input_blocks.2.0.in_layers.0.bias', None), + ('control_model.input_blocks.2.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.2.0.in_layers.2.bias', None), + ('control_model.input_blocks.2.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.2.0.emb_layers.1.bias', None), + ('control_model.input_blocks.2.0.out_layers.0.weight', None), + ('control_model.input_blocks.2.0.out_layers.0.bias', None), + ('control_model.input_blocks.2.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.2.0.out_layers.3.bias', None), + ('control_model.input_blocks.2.1.norm.weight', None), + ('control_model.input_blocks.2.1.norm.bias', None), + ('control_model.input_blocks.2.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.2.1.proj_in.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias', None), + ('control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight', (1, 0)), + ('control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias', None), + ('control_model.input_blocks.2.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.2.1.proj_out.bias', None), + ('control_model.input_blocks.3.0.op.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.3.0.op.bias', None), + ('control_model.input_blocks.4.0.in_layers.0.weight', None), + ('control_model.input_blocks.4.0.in_layers.0.bias', None), + ('control_model.input_blocks.4.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.4.0.in_layers.2.bias', None), + ('control_model.input_blocks.4.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.4.0.emb_layers.1.bias', None), + ('control_model.input_blocks.4.0.out_layers.0.weight', None), + ('control_model.input_blocks.4.0.out_layers.0.bias', None), + ('control_model.input_blocks.4.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.4.0.out_layers.3.bias', None), + ('control_model.input_blocks.4.0.skip_connection.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.4.0.skip_connection.bias', None), + ('control_model.input_blocks.4.1.norm.weight', None), + ('control_model.input_blocks.4.1.norm.bias', None), + ('control_model.input_blocks.4.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.4.1.proj_in.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias', None), + ('control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight', (1, 0)), + ('control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias', None), + ('control_model.input_blocks.4.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.4.1.proj_out.bias', None), + ('control_model.input_blocks.5.0.in_layers.0.weight', None), + ('control_model.input_blocks.5.0.in_layers.0.bias', None), + ('control_model.input_blocks.5.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.5.0.in_layers.2.bias', None), + ('control_model.input_blocks.5.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.5.0.emb_layers.1.bias', None), + ('control_model.input_blocks.5.0.out_layers.0.weight', None), + ('control_model.input_blocks.5.0.out_layers.0.bias', None), + ('control_model.input_blocks.5.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.5.0.out_layers.3.bias', None), + ('control_model.input_blocks.5.1.norm.weight', None), + ('control_model.input_blocks.5.1.norm.bias', None), + ('control_model.input_blocks.5.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.5.1.proj_in.bias', None), + ('control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias', + None), + ('control_model.input_blocks.5.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.5.1.proj_out.bias', None), + ('control_model.input_blocks.6.0.op.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.6.0.op.bias', None), + ('control_model.input_blocks.7.0.in_layers.0.weight', None), + ('control_model.input_blocks.7.0.in_layers.0.bias', None), + ('control_model.input_blocks.7.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.7.0.in_layers.2.bias', None), + ('control_model.input_blocks.7.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.7.0.emb_layers.1.bias', None), + ('control_model.input_blocks.7.0.out_layers.0.weight', None), + ('control_model.input_blocks.7.0.out_layers.0.bias', None), + ('control_model.input_blocks.7.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.7.0.out_layers.3.bias', None), + ('control_model.input_blocks.7.0.skip_connection.weight', + (2, 3, 1, 0)), + ('control_model.input_blocks.7.0.skip_connection.bias', None), + ('control_model.input_blocks.7.1.norm.weight', None), + ('control_model.input_blocks.7.1.norm.bias', None), + ('control_model.input_blocks.7.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.7.1.proj_in.bias', None), + ('control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias', + None), + ('control_model.input_blocks.7.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.7.1.proj_out.bias', None), + ('control_model.input_blocks.8.0.in_layers.0.weight', None), + ('control_model.input_blocks.8.0.in_layers.0.bias', None), + ('control_model.input_blocks.8.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.8.0.in_layers.2.bias', None), + ('control_model.input_blocks.8.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.8.0.emb_layers.1.bias', None), + ('control_model.input_blocks.8.0.out_layers.0.weight', None), + ('control_model.input_blocks.8.0.out_layers.0.bias', None), + ('control_model.input_blocks.8.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.8.0.out_layers.3.bias', None), + ('control_model.input_blocks.8.1.norm.weight', None), + ('control_model.input_blocks.8.1.norm.bias', None), + ('control_model.input_blocks.8.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.8.1.proj_in.bias', None), + ('control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias', + None), + ('control_model.input_blocks.8.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.8.1.proj_out.bias', None), + ('control_model.input_blocks.9.0.op.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.9.0.op.bias', None), + ('control_model.input_blocks.10.0.in_layers.0.weight', None), + ('control_model.input_blocks.10.0.in_layers.0.bias', None), + ('control_model.input_blocks.10.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.10.0.in_layers.2.bias', None), + ('control_model.input_blocks.10.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.10.0.emb_layers.1.bias', None), + ('control_model.input_blocks.10.0.out_layers.0.weight', None), + ('control_model.input_blocks.10.0.out_layers.0.bias', None), + ('control_model.input_blocks.10.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('control_model.input_blocks.10.0.out_layers.3.bias', None), + ('control_model.input_blocks.11.0.in_layers.0.weight', None), + ('control_model.input_blocks.11.0.in_layers.0.bias', None), + ('control_model.input_blocks.11.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.input_blocks.11.0.in_layers.2.bias', None), + ('control_model.input_blocks.11.0.emb_layers.1.weight', (1, 0)), + ('control_model.input_blocks.11.0.emb_layers.1.bias', None), + ('control_model.input_blocks.11.0.out_layers.0.weight', None), + ('control_model.input_blocks.11.0.out_layers.0.bias', None), + ('control_model.input_blocks.11.0.out_layers.3.weight', + (2, 3, 1, 0)), + ('control_model.input_blocks.11.0.out_layers.3.bias', None), + ('control_model.middle_block.0.in_layers.0.weight', None), + ('control_model.middle_block.0.in_layers.0.bias', None), + ('control_model.middle_block.0.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.middle_block.0.in_layers.2.bias', None), + ('control_model.middle_block.0.emb_layers.1.weight', (1, 0)), + ('control_model.middle_block.0.emb_layers.1.bias', None), + ('control_model.middle_block.0.out_layers.0.weight', None), + ('control_model.middle_block.0.out_layers.0.bias', None), + ('control_model.middle_block.0.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.middle_block.0.out_layers.3.bias', None), + ('control_model.middle_block.1.norm.weight', None), + ('control_model.middle_block.1.norm.bias', None), + ('control_model.middle_block.1.proj_in.weight', (2, 3, 1, 0)), + ('control_model.middle_block.1.proj_in.bias', None), + ('control_model.middle_block.1.transformer_blocks.0.norm1.weight', + None), + ('control_model.middle_block.1.transformer_blocks.0.norm1.bias', + None), + ('control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias', + None), + ('control_model.middle_block.1.transformer_blocks.0.norm2.weight', + None), + ('control_model.middle_block.1.transformer_blocks.0.norm2.bias', + None), + ('control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias', + None), + ('control_model.middle_block.1.transformer_blocks.0.norm3.weight', + None), + ('control_model.middle_block.1.transformer_blocks.0.norm3.bias', + None), + ('control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias', + None), + ('control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight', + (1, 0)), + ('control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias', + None), + ('control_model.middle_block.1.proj_out.weight', (2, 3, 1, 0)), + ('control_model.middle_block.1.proj_out.bias', None), + ('control_model.middle_block.2.in_layers.0.weight', None), + ('control_model.middle_block.2.in_layers.0.bias', None), + ('control_model.middle_block.2.in_layers.2.weight', (2, 3, 1, 0)), + ('control_model.middle_block.2.in_layers.2.bias', None), + ('control_model.middle_block.2.emb_layers.1.weight', (1, 0)), + ('control_model.middle_block.2.emb_layers.1.bias', None), + ('control_model.middle_block.2.out_layers.0.weight', None), + ('control_model.middle_block.2.out_layers.0.bias', None), + ('control_model.middle_block.2.out_layers.3.weight', (2, 3, 1, 0)), + ('control_model.middle_block.2.out_layers.3.bias', None), + ('control_model.zero_convs.0.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.0.0.bias', None), + ('control_model.zero_convs.1.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.1.0.bias', None), + ('control_model.zero_convs.2.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.2.0.bias', None), + ('control_model.zero_convs.3.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.3.0.bias', None), + ('control_model.zero_convs.4.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.4.0.bias', None), + ('control_model.zero_convs.5.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.5.0.bias', None), + ('control_model.zero_convs.6.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.6.0.bias', None), + ('control_model.zero_convs.7.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.7.0.bias', None), + ('control_model.zero_convs.8.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.8.0.bias', None), + ('control_model.zero_convs.9.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.9.0.bias', None), + ('control_model.zero_convs.10.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.10.0.bias', None), + ('control_model.zero_convs.11.0.weight', (2, 3, 1, 0)), + ('control_model.zero_convs.11.0.bias', None), + ('control_model.middle_block_out.0.weight', (2, 3, 1, 0)), + ('control_model.middle_block_out.0.bias', None),]} + + +_UNCONDITIONAL_TOKENS = [ + 49406, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, +] +_ALPHAS_CUMPROD = [ + 0.99915, + 0.998296, + 0.9974381, + 0.9965762, + 0.99571025, + 0.9948404, + 0.9939665, + 0.9930887, + 0.9922069, + 0.9913211, + 0.9904313, + 0.98953754, + 0.9886398, + 0.9877381, + 0.9868324, + 0.98592263, + 0.98500896, + 0.9840913, + 0.9831696, + 0.982244, + 0.98131436, + 0.9803808, + 0.97944313, + 0.97850156, + 0.977556, + 0.9766064, + 0.97565293, + 0.9746954, + 0.9737339, + 0.9727684, + 0.97179896, + 0.97082555, + 0.96984816, + 0.96886677, + 0.9678814, + 0.96689206, + 0.96589875, + 0.9649015, + 0.96390027, + 0.9628951, + 0.9618859, + 0.96087277, + 0.95985574, + 0.95883465, + 0.9578097, + 0.95678073, + 0.95574784, + 0.954711, + 0.95367026, + 0.9526256, + 0.9515769, + 0.95052433, + 0.94946784, + 0.94840735, + 0.947343, + 0.94627476, + 0.9452025, + 0.9441264, + 0.9430464, + 0.9419625, + 0.9408747, + 0.939783, + 0.9386874, + 0.93758786, + 0.9364845, + 0.93537724, + 0.9342661, + 0.9331511, + 0.9320323, + 0.9309096, + 0.929783, + 0.9286526, + 0.9275183, + 0.9263802, + 0.92523825, + 0.92409253, + 0.92294294, + 0.9217895, + 0.92063236, + 0.9194713, + 0.9183065, + 0.9171379, + 0.91596556, + 0.9147894, + 0.9136095, + 0.91242576, + 0.9112383, + 0.9100471, + 0.9088522, + 0.9076535, + 0.9064511, + 0.90524495, + 0.9040351, + 0.90282154, + 0.9016043, + 0.90038335, + 0.8991587, + 0.8979304, + 0.8966984, + 0.89546275, + 0.89422345, + 0.8929805, + 0.89173394, + 0.89048374, + 0.88922995, + 0.8879725, + 0.8867115, + 0.88544685, + 0.88417864, + 0.88290685, + 0.8816315, + 0.88035256, + 0.8790701, + 0.87778413, + 0.8764946, + 0.8752016, + 0.873905, + 0.87260497, + 0.8713014, + 0.8699944, + 0.86868393, + 0.86737, + 0.8660526, + 0.8647318, + 0.86340755, + 0.8620799, + 0.8607488, + 0.85941434, + 0.8580765, + 0.8567353, + 0.8553907, + 0.8540428, + 0.85269153, + 0.85133696, + 0.84997904, + 0.84861785, + 0.8472533, + 0.8458856, + 0.8445145, + 0.84314024, + 0.84176266, + 0.8403819, + 0.8389979, + 0.8376107, + 0.8362203, + 0.83482677, + 0.83343, + 0.8320301, + 0.8306271, + 0.8292209, + 0.82781166, + 0.82639927, + 0.8249838, + 0.82356524, + 0.8221436, + 0.82071894, + 0.81929123, + 0.81786054, + 0.8164268, + 0.8149901, + 0.8135504, + 0.81210774, + 0.81066215, + 0.8092136, + 0.8077621, + 0.80630773, + 0.80485046, + 0.8033903, + 0.80192727, + 0.8004614, + 0.79899275, + 0.79752123, + 0.7960469, + 0.7945698, + 0.7930899, + 0.79160726, + 0.7901219, + 0.7886338, + 0.787143, + 0.7856495, + 0.7841533, + 0.78265446, + 0.78115296, + 0.7796488, + 0.77814204, + 0.7766327, + 0.7751208, + 0.7736063, + 0.77208924, + 0.7705697, + 0.7690476, + 0.767523, + 0.7659959, + 0.7644664, + 0.76293445, + 0.7614, + 0.7598632, + 0.75832397, + 0.75678235, + 0.75523835, + 0.75369203, + 0.7521434, + 0.75059247, + 0.7490392, + 0.7474837, + 0.7459259, + 0.7443659, + 0.74280363, + 0.7412392, + 0.7396726, + 0.7381038, + 0.73653287, + 0.7349598, + 0.7333846, + 0.73180735, + 0.730228, + 0.7286466, + 0.7270631, + 0.7254777, + 0.72389024, + 0.72230077, + 0.7207094, + 0.71911603, + 0.7175208, + 0.7159236, + 0.71432453, + 0.7127236, + 0.71112084, + 0.7095162, + 0.7079098, + 0.7063016, + 0.70469165, + 0.70307994, + 0.7014665, + 0.69985133, + 0.6982345, + 0.696616, + 0.6949958, + 0.69337404, + 0.69175065, + 0.69012564, + 0.6884991, + 0.68687093, + 0.6852413, + 0.68361014, + 0.6819775, + 0.6803434, + 0.67870784, + 0.6770708, + 0.6754324, + 0.6737926, + 0.67215145, + 0.670509, + 0.66886514, + 0.66722, + 0.6655736, + 0.66392595, + 0.662277, + 0.6606269, + 0.65897554, + 0.657323, + 0.65566933, + 0.6540145, + 0.6523586, + 0.6507016, + 0.6490435, + 0.64738435, + 0.6457241, + 0.64406294, + 0.6424008, + 0.64073765, + 0.63907355, + 0.63740855, + 0.6357426, + 0.6340758, + 0.6324082, + 0.6307397, + 0.6290704, + 0.6274003, + 0.6257294, + 0.62405777, + 0.6223854, + 0.62071234, + 0.6190386, + 0.61736417, + 0.6156891, + 0.61401343, + 0.6123372, + 0.6106603, + 0.6089829, + 0.607305, + 0.6056265, + 0.6039476, + 0.60226816, + 0.6005883, + 0.598908, + 0.59722733, + 0.5955463, + 0.59386486, + 0.5921831, + 0.59050107, + 0.5888187, + 0.5871361, + 0.5854532, + 0.5837701, + 0.5820868, + 0.5804033, + 0.5787197, + 0.5770359, + 0.575352, + 0.57366806, + 0.571984, + 0.5702999, + 0.5686158, + 0.56693166, + 0.56524754, + 0.5635635, + 0.5618795, + 0.56019557, + 0.5585118, + 0.5568281, + 0.55514455, + 0.5534612, + 0.551778, + 0.5500951, + 0.5484124, + 0.54673, + 0.5450478, + 0.54336596, + 0.54168445, + 0.54000324, + 0.53832245, + 0.5366421, + 0.53496206, + 0.5332825, + 0.53160346, + 0.5299248, + 0.52824676, + 0.5265692, + 0.52489215, + 0.5232157, + 0.5215398, + 0.51986456, + 0.51818997, + 0.51651603, + 0.51484275, + 0.5131702, + 0.5114983, + 0.5098272, + 0.50815684, + 0.5064873, + 0.50481856, + 0.50315064, + 0.50148356, + 0.4998174, + 0.4981521, + 0.49648774, + 0.49482432, + 0.49316183, + 0.49150035, + 0.48983985, + 0.4881804, + 0.486522, + 0.48486462, + 0.4832084, + 0.48155323, + 0.4798992, + 0.47824633, + 0.47659463, + 0.4749441, + 0.47329482, + 0.4716468, + 0.47, + 0.46835446, + 0.46671024, + 0.46506736, + 0.4634258, + 0.46178558, + 0.46014675, + 0.45850933, + 0.45687333, + 0.45523876, + 0.45360568, + 0.45197406, + 0.45034397, + 0.44871536, + 0.44708833, + 0.44546285, + 0.44383895, + 0.44221666, + 0.440596, + 0.43897697, + 0.43735963, + 0.43574396, + 0.43412998, + 0.43251774, + 0.43090722, + 0.4292985, + 0.42769152, + 0.42608637, + 0.42448303, + 0.4228815, + 0.42128187, + 0.4196841, + 0.41808826, + 0.4164943, + 0.4149023, + 0.41331223, + 0.41172415, + 0.41013804, + 0.40855396, + 0.4069719, + 0.4053919, + 0.40381396, + 0.4022381, + 0.40066436, + 0.39909273, + 0.39752322, + 0.3959559, + 0.39439073, + 0.39282778, + 0.39126703, + 0.3897085, + 0.3881522, + 0.3865982, + 0.38504648, + 0.38349706, + 0.38194993, + 0.38040516, + 0.37886274, + 0.37732267, + 0.375785, + 0.37424973, + 0.37271687, + 0.37118647, + 0.36965853, + 0.36813304, + 0.36661002, + 0.36508954, + 0.36357155, + 0.3620561, + 0.36054322, + 0.3590329, + 0.35752517, + 0.35602003, + 0.35451752, + 0.35301763, + 0.3515204, + 0.3500258, + 0.3485339, + 0.3470447, + 0.34555823, + 0.34407446, + 0.34259343, + 0.34111515, + 0.33963963, + 0.33816692, + 0.336697, + 0.3352299, + 0.33376563, + 0.3323042, + 0.33084565, + 0.32938993, + 0.32793713, + 0.3264872, + 0.32504022, + 0.32359615, + 0.32215503, + 0.32071686, + 0.31928164, + 0.31784943, + 0.3164202, + 0.314994, + 0.3135708, + 0.31215066, + 0.31073356, + 0.3093195, + 0.30790854, + 0.30650064, + 0.30509588, + 0.30369422, + 0.30229566, + 0.30090025, + 0.299508, + 0.2981189, + 0.29673296, + 0.29535022, + 0.2939707, + 0.29259437, + 0.29122123, + 0.28985137, + 0.28848472, + 0.28712133, + 0.2857612, + 0.28440437, + 0.2830508, + 0.28170055, + 0.2803536, + 0.27900997, + 0.27766964, + 0.27633268, + 0.27499905, + 0.2736688, + 0.27234194, + 0.27101842, + 0.2696983, + 0.26838157, + 0.26706827, + 0.26575837, + 0.26445192, + 0.26314887, + 0.2618493, + 0.26055318, + 0.2592605, + 0.25797132, + 0.2566856, + 0.2554034, + 0.25412467, + 0.25284946, + 0.25157773, + 0.2503096, + 0.24904492, + 0.24778382, + 0.24652626, + 0.24527225, + 0.2440218, + 0.24277493, + 0.24153163, + 0.24029191, + 0.23905578, + 0.23782326, + 0.23659433, + 0.23536903, + 0.23414734, + 0.23292927, + 0.23171483, + 0.23050404, + 0.22929688, + 0.22809339, + 0.22689353, + 0.22569734, + 0.22450483, + 0.22331597, + 0.2221308, + 0.22094932, + 0.21977153, + 0.21859743, + 0.21742703, + 0.21626033, + 0.21509734, + 0.21393807, + 0.21278252, + 0.21163069, + 0.21048258, + 0.20933822, + 0.20819758, + 0.2070607, + 0.20592754, + 0.20479813, + 0.20367248, + 0.20255059, + 0.20143245, + 0.20031808, + 0.19920748, + 0.19810064, + 0.19699757, + 0.19589828, + 0.19480278, + 0.19371104, + 0.1926231, + 0.19153893, + 0.19045855, + 0.18938197, + 0.18830918, + 0.18724018, + 0.18617497, + 0.18511358, + 0.18405597, + 0.18300217, + 0.18195218, + 0.18090598, + 0.1798636, + 0.17882504, + 0.17779027, + 0.1767593, + 0.17573217, + 0.17470883, + 0.1736893, + 0.1726736, + 0.1716617, + 0.17065361, + 0.16964935, + 0.1686489, + 0.16765225, + 0.16665943, + 0.16567042, + 0.16468522, + 0.16370384, + 0.16272627, + 0.16175252, + 0.16078258, + 0.15981644, + 0.15885411, + 0.1578956, + 0.15694089, + 0.15599, + 0.15504292, + 0.15409963, + 0.15316014, + 0.15222447, + 0.15129258, + 0.1503645, + 0.14944021, + 0.14851972, + 0.14760303, + 0.14669013, + 0.14578101, + 0.14487568, + 0.14397413, + 0.14307636, + 0.14218238, + 0.14129217, + 0.14040573, + 0.13952307, + 0.13864417, + 0.13776903, + 0.13689767, + 0.13603005, + 0.13516618, + 0.13430607, + 0.13344972, + 0.1325971, + 0.13174823, + 0.1309031, + 0.13006169, + 0.12922402, + 0.12839006, + 0.12755983, + 0.12673332, + 0.12591052, + 0.12509143, + 0.12427604, + 0.12346435, + 0.12265636, + 0.121852055, + 0.12105144, + 0.1202545, + 0.11946124, + 0.11867165, + 0.11788572, + 0.11710346, + 0.11632485, + 0.115549885, + 0.11477857, + 0.11401089, + 0.11324684, + 0.11248643, + 0.11172963, + 0.11097645, + 0.110226884, + 0.10948092, + 0.10873855, + 0.10799977, + 0.107264586, + 0.106532976, + 0.105804935, + 0.10508047, + 0.10435956, + 0.1036422, + 0.10292839, + 0.10221813, + 0.1015114, + 0.10080819, + 0.100108504, + 0.09941233, + 0.098719664, + 0.0980305, + 0.09734483, + 0.09666264, + 0.09598393, + 0.095308684, + 0.09463691, + 0.093968585, + 0.09330372, + 0.092642285, + 0.09198428, + 0.09132971, + 0.09067855, + 0.090030804, + 0.089386456, + 0.088745505, + 0.088107936, + 0.08747375, + 0.08684293, + 0.08621547, + 0.085591376, + 0.084970616, + 0.08435319, + 0.0837391, + 0.08312833, + 0.08252087, + 0.08191671, + 0.08131585, + 0.08071827, + 0.080123976, + 0.07953294, + 0.078945175, + 0.078360654, + 0.077779375, + 0.07720133, + 0.07662651, + 0.07605491, + 0.07548651, + 0.07492131, + 0.0743593, + 0.07380046, + 0.073244795, + 0.07269229, + 0.07214294, + 0.07159673, + 0.07105365, + 0.070513695, + 0.06997685, + 0.069443114, + 0.06891247, + 0.06838491, + 0.067860425, + 0.06733901, + 0.066820644, + 0.06630533, + 0.06579305, + 0.0652838, + 0.06477757, + 0.06427433, + 0.0637741, + 0.063276865, + 0.06278259, + 0.062291294, + 0.061802953, + 0.06131756, + 0.0608351, + 0.060355574, + 0.05987896, + 0.059405252, + 0.058934443, + 0.05846652, + 0.058001474, + 0.057539295, + 0.05707997, + 0.056623492, + 0.05616985, + 0.05571903, + 0.055271026, + 0.054825824, + 0.05438342, + 0.053943794, + 0.053506944, + 0.05307286, + 0.052641522, + 0.052212927, + 0.051787063, + 0.051363923, + 0.05094349, + 0.050525755, + 0.05011071, + 0.04969834, + 0.049288645, + 0.0488816, + 0.048477206, + 0.048075445, + 0.04767631, + 0.047279786, + 0.04688587, + 0.046494544, + 0.046105802, + 0.04571963, + 0.04533602, + 0.04495496, + 0.04457644, + 0.044200446, + 0.04382697, + 0.043456003, + 0.043087535, + 0.042721547, + 0.042358037, + 0.04199699, + 0.041638397, + 0.041282244, + 0.040928524, + 0.040577225, + 0.040228333, + 0.039881844, + 0.039537743, + 0.039196018, + 0.038856663, + 0.038519662, + 0.038185004, + 0.037852682, + 0.037522685, + 0.037195, + 0.036869615, + 0.036546525, + 0.036225714, + 0.03590717, + 0.035590887, + 0.035276853, + 0.034965057, + 0.034655485, + 0.03434813, + 0.03404298, + 0.033740025, + 0.033439253, + 0.033140652, + 0.032844216, + 0.03254993, + 0.032257784, + 0.03196777, + 0.031679876, + 0.031394087, + 0.031110398, + 0.030828796, + 0.030549273, + 0.030271813, + 0.02999641, + 0.029723052, + 0.029451728, + 0.029182427, + 0.02891514, + 0.028649855, + 0.028386563, + 0.028125253, + 0.02786591, + 0.027608532, + 0.027353102, + 0.027099613, + 0.026848052, + 0.026598409, + 0.026350675, + 0.02610484, + 0.02586089, + 0.02561882, + 0.025378617, + 0.025140269, + 0.024903767, + 0.0246691, + 0.02443626, + 0.024205236, + 0.023976017, + 0.023748592, + 0.023522953, + 0.023299087, + 0.023076987, + 0.022856642, + 0.02263804, + 0.022421172, + 0.022206029, + 0.0219926, + 0.021780876, + 0.021570845, + 0.021362498, + 0.021155827, + 0.020950818, + 0.020747466, + 0.020545758, + 0.020345684, + 0.020147236, + 0.019950403, + 0.019755175, + 0.019561544, + 0.019369498, + 0.019179028, + 0.018990126, + 0.01880278, + 0.018616982, + 0.018432721, + 0.01824999, + 0.018068777, + 0.017889075, + 0.017710872, + 0.01753416, + 0.017358929, + 0.017185168, + 0.017012872, + 0.016842028, + 0.016672628, + 0.016504662, + 0.016338123, + 0.016173, + 0.016009282, + 0.015846964, + 0.015686033, + 0.015526483, + 0.015368304, + 0.015211486, + 0.0150560215, + 0.014901901, + 0.014749114, + 0.014597654, + 0.014447511, + 0.0142986765, + 0.014151142, + 0.014004898, + 0.013859936, + 0.013716248, + 0.0135738235, + 0.013432656, + 0.013292736, + 0.013154055, + 0.013016605, + 0.012880377, + 0.012745362, + 0.012611552, + 0.012478939, + 0.012347515, + 0.01221727, + 0.012088198, + 0.0119602885, + 0.0118335355, + 0.011707929, + 0.011583461, + 0.011460125, + 0.011337912, + 0.011216813, + 0.011096821, + 0.010977928, + 0.0108601255, + 0.010743406, + 0.010627762, + 0.0105131855, + 0.010399668, + 0.010287202, + 0.01017578, + 0.010065395, + 0.009956039, + 0.009847702, + 0.009740381, + 0.0096340645, + 0.009528747, + 0.009424419, + 0.009321076, + 0.009218709, + 0.00911731, + 0.009016872, + 0.008917389, + 0.008818853, + 0.008721256, + 0.008624591, + 0.008528852, + 0.00843403, + 0.00834012, + 0.008247114, + 0.008155004, + 0.008063785, + 0.007973449, + 0.007883989, + 0.007795398, + 0.0077076694, + 0.0076207966, + 0.0075347726, + 0.007449591, + 0.0073652444, + 0.007281727, + 0.0071990318, + 0.007117152, + 0.0070360815, + 0.0069558136, + 0.0068763415, + 0.006797659, + 0.00671976, + 0.0066426382, + 0.0065662866, + 0.006490699, + 0.0064158696, + 0.006341792, + 0.00626846, + 0.0061958674, + 0.0061240084, + 0.0060528764, + 0.0059824656, + 0.0059127696, + 0.0058437833, + 0.0057755, + 0.0057079145, + 0.00564102, + 0.0055748112, + 0.0055092825, + 0.005444428, + 0.005380241, + 0.0053167176, + 0.005253851, + 0.005191636, + 0.005130066, + 0.0050691366, + 0.0050088423, + 0.0049491767, + 0.004890135, + 0.0048317118, + 0.004773902, + 0.004716699, + 0.0046600983, +] \ No newline at end of file diff --git a/stableDiffusionKeras/controlNetDiffusionModels.py b/stableDiffusionKeras/controlNetDiffusionModels.py new file mode 100644 index 0000000..c5bc7a4 --- /dev/null +++ b/stableDiffusionKeras/controlNetDiffusionModels.py @@ -0,0 +1,665 @@ +""" +Copyright 2022 The KerasCV Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +you may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +ControlNet Version + +by AJ Young + +""" + + +import keras + +""" +Models +""" + +class DiffusionModel(keras.Model): + def __init__( + self, + img_height, + img_width, + max_text_length, + name = "LockedDiffusionModel" + ): + context = keras.layers.Input((max_text_length, 768), name = "Context_Input") + t_embed_input = keras.layers.Input((320,), name = "TimeStepEmbed_Input") + latent = keras.layers.Input((img_height // 8, img_width // 8, 4), name = "LatentImage_Input") + + ### ControlNet Input ### + + controlNet1 = keras.layers.Input(shape = (1), name = "ControlNet_Input001") + controlNet2 = keras.layers.Input(shape = (1), name = "ControlNet_Input002") + controlNet3 = keras.layers.Input(shape = (1), name = "ControlNet_Input003") + controlNet4 = keras.layers.Input(shape = (1), name = "ControlNet_Input004") + controlNet5 = keras.layers.Input(shape = (1), name = "ControlNet_Input005") + controlNet6 = keras.layers.Input(shape = (1), name = "ControlNet_Input006") + controlNet7 = keras.layers.Input(shape = (1), name = "ControlNet_Input007") + controlNet8 = keras.layers.Input(shape = (1), name = "ControlNet_Input008") + controlNet9 = keras.layers.Input(shape = (1), name = "ControlNet_Input009") + controlNet10 = keras.layers.Input(shape = (1), name = "ControlNet_Input010") + controlNet11 = keras.layers.Input(shape = (1), name = "ControlNet_Input011") + controlNet12 = keras.layers.Input(shape = (1), name = "ControlNet_Input012") + controlNet13 = keras.layers.Input(shape = (1), name = "ControlNet_Input013") + + controlNetResults = [controlNet1, controlNet2, controlNet3, controlNet4, controlNet5, controlNet6, controlNet7, controlNet8, controlNet9, controlNet10, controlNet11, controlNet12, controlNet13] + + t_emb = keras.layers.Dense(1280, name = "TimeEmbed1")(t_embed_input) + t_emb = keras.layers.Activation("swish", name = "swishActivation")(t_emb) + t_emb = keras.layers.Dense(1280, name = "TimeEmbed2")(t_emb) + + # Downsampling flow, aka input_blocks + + outputs = [] + x = PaddedConv2D(320, kernel_size = 3, padding = 1, name = "inputBlocks")(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected = False)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected = False)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + x = ResBlock(1280)([x, t_emb]) + #controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = x + controlNetResults.pop() + + # Upsampling flow + + for _ in range(3): + #controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + #controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + #controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected = False)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + #controlNetResults, controlNetResult = tfPOP(controlNetResults) + x = keras.layers.Concatenate()([x, outputs.pop() + controlNetResults.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected = False)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon = 1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size = 3, padding = 1)(x) + + super().__init__( + inputs = [ + latent, + t_embed_input, + context, + controlNet1, + controlNet2, + controlNet3, + controlNet4, + controlNet5, + controlNet6, + controlNet7, + controlNet8, + controlNet9, + controlNet10, + controlNet11, + controlNet12, + controlNet13 + ], + outputs = output, + name = name + ) + +class ControlNetDiffusionModel(keras.Model): + def __init__( + self, + img_height, + img_width, + max_text_length, + name = "ControlNetModel", + ): + context = keras.layers.Input((max_text_length, 768), name = "Context_Input") + inputHint = keras.layers.Input((img_height, img_width, 3), name = "Hint_Input") + t_embed_input = keras.layers.Input((320,), name = "TimeStepEmbed_Input") + latent = keras.layers.Input((img_height // 8, img_width // 8, 4), name = "LatentImage_Input") + + t_emb = keras.layers.Dense(1280, name = "ControlTimeEmbed1")(t_embed_input) + t_emb = keras.layers.Activation("swish", name = "swishActivation")(t_emb) + t_emb = keras.layers.Dense(1280, name = "ControlTimeEmbed2")(t_emb) + + # Input Hint Blocks + + guidedHint = HintBlocks()(inputHint) + + # Downsampling flow, aka input_blocks + + outputs = [] + x = PaddedConv2D(320, kernel_size = 3, padding = 1, name = "inputBlocks")(latent) + x = x + guidedHint + outputs.append(zeroConv(x, 320,"zeroConv1")) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected = False)([x, context]) + outputs.append(zeroConv(x, 320)) + x = PaddedConv2D(320, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(zeroConv(x, 320, "zeroConv4")) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected = False)([x, context]) + outputs.append(zeroConv(x, 640)) + x = PaddedConv2D(640, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(zeroConv(x, 640, "zeroConv7")) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + outputs.append(zeroConv(x, 1280)) + x = PaddedConv2D(1280, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(zeroConv(x, 1280, "zeroConv10")) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(zeroConv(x, 1280)) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + x = ResBlock(1280)([x, t_emb]) + outputs.append(zeroConv(x, 1280,"zeroConv13")) + + super().__init__([latent, t_embed_input, context, inputHint], outputs, name = name) + # Input: Latent, TimestepEmbed, Context, Input Hint + # Output: Python List of each zeroConv (Zero Convolution Layer) + + +class DiffusionModelV2(keras.Model): + def __init__( + self, img_height, img_width, max_text_length, name = None + ): + context = keras.layers.Input((max_text_length, 1024)) + t_embed_input = keras.layers.Input((320,)) + latent = keras.layers.Input((img_height // 8, img_width // 8, 4)) + + t_emb = keras.layers.Dense(1280)(t_embed_input) + t_emb = keras.layers.Activation("swish")(t_emb) + t_emb = keras.layers.Dense(1280)(t_emb) + + # Downsampling flow + + outputs = [] + x = PaddedConv2D(320, kernel_size = 3, padding = 1)(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected = True)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected = True)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected = True)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = ResBlock(1280)([x, t_emb]) + + # Upsampling flow + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon = 1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size = 3, padding = 1)(x) + + super().__init__([latent, t_embed_input, context], output, name = name) + +""" +Blocks +""" + +class GroupNormalization(keras.layers.Layer): + """ + GroupNormalization layer. + + This layer is only here temporarily and will be removed + as we introduce GroupNormalization in core Keras. + """ + + def __init__( + self, + groups = 32, + axis = -1, + epsilon = 1e-5, + name = "GroupNormalization", + **kwargs, + ): + super().__init__(**kwargs) + self.groups = groups + self.axis = axis + self.epsilon = epsilon + + def get_config(self): + config = super().get_config() + config.update({ + "groups": self.groups, + "axis": self.axis, + "epsilon": self.epsilon + }) + return config + + def build(self, input_shape): + dim = input_shape[self.axis] + self.gamma = self.add_weight( + shape=(dim,), + name = "gamma", + initializer = "ones", + ) + self.beta = self.add_weight( + shape=(dim,), + name = "beta", + initializer = "zeros", + ) + + ## @tf.function + def call(self, inputs): + input_shape = keras.ops.shape(inputs) + reshaped_inputs = self._reshape_into_groups(inputs, input_shape) + normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) + return keras.ops.reshape(normalized_inputs, input_shape) + + def _reshape_into_groups(self, inputs, input_shape): + group_shape = [input_shape[i] for i in range(inputs.shape.rank)] + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + group_shape = keras.ops.stack(group_shape) + return keras.ops.reshape(inputs, group_shape) + + def _apply_normalization(self, reshaped_inputs, input_shape): + group_reduction_axes = list(range(1, reshaped_inputs.shape.rank)) + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + mean, variance = keras.ops.moments( + reshaped_inputs, group_reduction_axes, keepdims=True + ) + gamma, beta = self._get_reshaped_weights(input_shape) + return keras.ops.batch_normalization( + reshaped_inputs, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=self.epsilon, + ) + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = keras.ops.reshape(self.gamma, broadcast_shape) + beta = keras.ops.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * input_shape.shape.rank + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + return broadcast_shape + +class HintBlocks(keras.layers.Layer): + def __init__(self, hint_channels = 16, model_channels = 320, **kwargs): + super().__init__(**kwargs) + self.layers = [ + PaddedConv2D(filters = 16, kernel_size = 3, padding = 1), + keras.layers.Activation('swish'), + PaddedConv2D(filters = 16, kernel_size = 3, padding = 1), + keras.layers.Activation('swish'), + PaddedConv2D(filters = 32, kernel_size = 3, padding = 1, strides = 2), + keras.layers.Activation('swish'), + PaddedConv2D(filters = 32, kernel_size = 3, padding = 1), + keras.layers.Activation('swish'), + PaddedConv2D(filters = 96, kernel_size = 3, padding = 1, strides = 2), + keras.layers.Activation('swish'), + PaddedConv2D(filters = 96, kernel_size = 3, padding = 1), + keras.layers.Activation('swish'), + PaddedConv2D(filters = 256, kernel_size = 3, padding = 1, strides = 2), + keras.layers.Activation('swish'), + PaddedConv2D(filters = model_channels, kernel_size = 3, padding = 1) + ] + + ## @tf.function + def call(self, inputs): + x = inputs + layerNumber = 0 + layerLength = len(self.layers) + for layer in self.layers: + if layerNumber == layerLength: + for weight in layer.weights: + weight.assign(keras.ops.zeros_like(weight)) + x = layer(x) + layerNumber += 1 + return x + +class PaddedConv2D(keras.layers.Layer): + def __init__( + self, + filters, + kernel_size, + padding = 0, + strides = 1, + name = None, + **kwargs + ): + super().__init__(**kwargs) + self.padding2d = keras.layers.ZeroPadding2D(padding, name = name) + self.conv2d = keras.layers.Conv2D(filters, kernel_size, strides = strides, name = name) + self.filters = filters + self.kernel_size = kernel_size + self.padding = padding + self.strides = strides + + ## @tf.function + def call(self, inputs): + x = self.padding2d(inputs) + return self.conv2d(x) + + def get_config(self): + config = super().get_config() + config.update({ + "filters": self.filters, + "kernel_size": self.kernel_size, + "padding": self.padding, + "strides": self.strides, + }) + return config + +class ResBlock(keras.layers.Layer): + def __init__( + self, + output_dim, + **kwargs + ): + super().__init__(**kwargs) + self.output_dim = output_dim + self.entry_flow = [ + GroupNormalization(epsilon = 1e-5), + keras.layers.Activation("swish", name = "ResBlock_swish1"), + PaddedConv2D(output_dim, 3, padding = 1, name = "inLayers2"), + ] + self.embedding_flow = [ + keras.layers.Activation("swish"), + keras.layers.Dense(output_dim, name = "embeddingLayer"), + ] + self.exit_flow = [ + GroupNormalization(epsilon = 1e-5), + keras.layers.Activation("swish", name = "ResBlock_swish2"), + PaddedConv2D(output_dim, 3, padding = 1, name = "outLayers3"), + ] + + def build(self, input_shape): + if input_shape[0][-1] != self.output_dim: + self.residual_projection = PaddedConv2D(self.output_dim, 1) + else: + self.residual_projection = lambda x: x + + ## @tf.function + def call(self, inputs): + inputs, embeddings = inputs + x = inputs + for layer in self.entry_flow: + x = layer(x) + for layer in self.embedding_flow: + embeddings = layer(embeddings) + x = x + embeddings[:, None, None] + for layer in self.exit_flow: + x = layer(x) + return x + self.residual_projection(inputs) + + def get_config(self): + config = super().get_config() + config.update({ + "output_dim": self.output_dim, + }) + return config + + +class SpatialTransformer(keras.layers.Layer): + def __init__( + self, + num_heads, + head_size, + fully_connected = False, + **kwargs + ): + super().__init__(**kwargs) + self.norm = GroupNormalization(epsilon = 1e-5) + self.num_heads = num_heads + self.head_size = head_size + self.fully_connected = fully_connected + channels = num_heads * head_size + if fully_connected: + self.proj1 = keras.layers.Dense(num_heads * head_size, name = "proj_in1_fullyConnected") + else: + self.proj1 = PaddedConv2D(num_heads * head_size, 1, name = "proj_in") + self.transformer_block = BasicTransformerBlock(channels, num_heads, head_size) + if fully_connected: + self.proj2 = keras.layers.Dense(channels) + else: + self.proj2 = PaddedConv2D(channels, 1, name = "proj_in2_fullyConnected") + + def get_config(self): + config = super().get_config() + config.update({ + "num_heads": self.num_heads, + "head_size": self.head_size, + "fully_connected": self.fully_connected + }) + return config + + ## @tf.function + def call(self, inputs): + inputs, context = inputs + _, h, w, c = inputs.shape + x = self.norm(inputs) + x = self.proj1(x) + x = keras.ops.reshape(x, (-1, h * w, c)) + x = self.transformer_block([x, context]) + x = keras.ops.reshape(x, (-1, h, w, c)) + return self.proj2(x) + inputs + + +class BasicTransformerBlock(keras.layers.Layer): + def __init__( + self, + dim, + num_heads, + head_size, + **kwargs + ): + super().__init__(**kwargs) + self.norm1 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "norm1") + self.attn1 = CrossAttention(num_heads, head_size) + + self.norm2 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "norm2") + self.attn2 = CrossAttention(num_heads, head_size) + + self.norm3 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "norm3") + self.geglu = GEGLU(dim * 4) + self.dense = keras.layers.Dense(dim) + + ## @tf.function + def call(self, inputs): + inputs, context = inputs + x = self.attn1([self.norm1(inputs), None]) + inputs + x = self.attn2([self.norm2(x), context]) + x + return self.dense(self.geglu(self.norm3(x))) + x + + +class CrossAttention(keras.layers.Layer): + def __init__(self, num_heads, head_size, **kwargs): + super().__init__(**kwargs) + self.to_q = keras.layers.Dense(num_heads * head_size, use_bias = False, name = "to_q") + self.to_k = keras.layers.Dense(num_heads * head_size, use_bias = False, name = "to_k") + self.to_v = keras.layers.Dense(num_heads * head_size, use_bias = False, name = "to_v") + self.scale = head_size**-0.5 + self.num_heads = num_heads + self.head_size = head_size + self.out_proj = keras.layers.Dense(num_heads * head_size, name = "out_projection") + + ## @tf.function + def call(self, inputs): + inputs, context = inputs + context = inputs if context is None else context + q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context) + q = keras.ops.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size)) + k = keras.ops.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size)) + v = keras.ops.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size)) + + q = keras.ops.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + k = keras.ops.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time) + v = keras.ops.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + + score = td_dot(q, k) * self.scale + weights = keras.activations.softmax(score) # (bs, num_heads, time, time) + attn = td_dot(weights, v) + attn = keras.ops.transpose(attn, (0, 2, 1, 3)) # (bs, time, num_heads, head_size) + out = keras.ops.reshape(attn, (-1, inputs.shape[1], self.num_heads * self.head_size)) + return self.out_proj(out) + +class Upsample(keras.layers.Layer): + def __init__( + self, + channels, + **kwargs + ): + super().__init__(**kwargs) + self.channels = channels + self.ups = keras.layers.UpSampling2D(2) + self.conv = PaddedConv2D(channels, 3, padding = 1, name = "Upsample") + + ## @tf.function + def call(self, inputs): + return self.conv(self.ups(inputs)) + + def get_config(self): + config = super().get_config() + config.update({ + "channels": self.channels, + }) + return config + +class GEGLU(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.dense = keras.layers.Dense(output_dim * 2) + + ## @tf.function + def call(self, inputs): + x = self.dense(inputs) + x, gate = x[..., : self.output_dim], x[..., self.output_dim :] + tanh_res = keras.activations.tanh( + gate * 0.7978845608 * (1 + 0.044715 * (gate**2)) + ) + return x * 0.5 * gate * (1 + tanh_res) + +def td_dot(a, b): + aa = keras.ops.reshape(a, (-1, a.shape[2], a.shape[3])) + bb = keras.ops.reshape(b, (-1, b.shape[2], b.shape[3])) + cc = keras.backend.batch_dot(aa, bb) + return keras.ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) + +def zeroConv(tensor, channels, name = None): + layer = keras.layers.Conv2D(filters = channels, kernel_size = 1, padding = 'same', name = name) + for weight in layer.weights: + weight.assign(keras.ops.zeros_like(weight)) + + tensor = layer(tensor) + + return tensor \ No newline at end of file diff --git a/stableDiffusionKeras/kerasCVDiffusionModels.py b/stableDiffusionKeras/kerasCVDiffusionModels.py new file mode 100644 index 0000000..e7850d8 --- /dev/null +++ b/stableDiffusionKeras/kerasCVDiffusionModels.py @@ -0,0 +1,517 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import keras + +""" +Models +""" + +class DiffusionModel(keras.Model): + def __init__( + self, img_height, img_width, max_text_length, name = "DiffusionModel", download_weights = False + ): + context = keras.layers.Input((max_text_length, 768), name = "Context_Input") + t_embed_input = keras.layers.Input((320,), name = "TimeStepEmbed_Input") + latent = keras.layers.Input((img_height // 8, img_width // 8, 4), name = "LatentImage_Input") + + t_emb = keras.layers.Dense(1280, name = "TimeEmbed1")(t_embed_input) + t_emb = keras.layers.Activation("swish", name = "swishActivation")(t_emb) + t_emb = keras.layers.Dense(1280, name = "TimeEmbed2")(t_emb) + + # Downsampling flow, aka input_blocks + + outputs = [] + x = PaddedConv2D(320, kernel_size = 3, padding = 1, name = "inputBlocks")(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected = False)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected = False)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides = 2, padding = 1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + x = ResBlock(1280)([x, t_emb]) + + # Upsampling flow + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(8, 160, fully_connected = False)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(8, 80, fully_connected = False)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(8, 40, fully_connected = False)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon = 1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size = 3, padding = 1)(x) + + super().__init__([latent, t_embed_input, context], output, name = name) + + if download_weights: + diffusion_model_weights_fpath = keras.utils.get_file( + origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5", + file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe", + ) + self.load_weights(diffusion_model_weights_fpath) + + +class DiffusionModelV2(keras.Model): + def __init__( + self, img_height, img_width, max_text_length, name = None, download_weights = False + ): + context = keras.layers.Input((max_text_length, 1024)) + t_embed_input = keras.layers.Input((320,)) + latent = keras.layers.Input((img_height // 8, img_width // 8, 4)) + + t_emb = keras.layers.Dense(1280)(t_embed_input) + t_emb = keras.layers.Activation("swish")(t_emb) + t_emb = keras.layers.Dense(1280)(t_emb) + + # Downsampling flow + + outputs = [] + x = PaddedConv2D(320, kernel_size = 3, padding = 1)(latent) + outputs.append(x) + + for _ in range(2): + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(320, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(640, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + outputs.append(x) + x = PaddedConv2D(1280, 3, strides=2, padding=1)(x) # Downsample 2x + outputs.append(x) + + for _ in range(2): + x = ResBlock(1280)([x, t_emb]) + outputs.append(x) + + # Middle flow + + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = ResBlock(1280)([x, t_emb]) + + # Upsampling flow + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(1280)([x, t_emb]) + x = SpatialTransformer(20, 64, fully_connected=True)([x, context]) + x = Upsample(1280)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(640)([x, t_emb]) + x = SpatialTransformer(10, 64, fully_connected=True)([x, context]) + x = Upsample(640)(x) + + for _ in range(3): + x = keras.layers.Concatenate()([x, outputs.pop()]) + x = ResBlock(320)([x, t_emb]) + x = SpatialTransformer(5, 64, fully_connected=True)([x, context]) + + # Exit flow + + x = GroupNormalization(epsilon=1e-5)(x) + x = keras.layers.Activation("swish")(x) + output = PaddedConv2D(4, kernel_size=3, padding=1)(x) + + super().__init__([latent, t_embed_input, context], output, name = name) + + if download_weights: + diffusion_model_weights_fpath = keras.utils.get_file( + origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5", + file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d", + ) + self.load_weights(diffusion_model_weights_fpath) + +""" +Blocks +""" + +class GroupNormalization(keras.layers.Layer): + """GroupNormalization layer. + This layer is only here temporarily and will be removed + as we introduce GroupNormalization in core Keras. + """ + + def __init__( + self, + groups = 32, + axis = -1, + epsilon = 1e-5, + name = "GroupNormalization", + **kwargs, + ): + super().__init__(**kwargs) + self.groups = groups + self.axis = axis + self.epsilon = epsilon + + def get_config(self): + config = super().get_config() + config.update({ + "groups": self.groups, + "axis": self.axis, + "epsilon": self.epsilon + }) + return config + + def build(self, input_shape): + dim = input_shape[self.axis] + self.gamma = self.add_weight( + shape=(dim,), + name = "gamma", + initializer = "ones", + ) + self.beta = self.add_weight( + shape=(dim,), + name = "beta", + initializer = "zeros", + ) + + #@tf.function + def call(self, inputs): + input_shape = keras.ops.shape(inputs) + reshaped_inputs = self._reshape_into_groups(inputs, input_shape) + normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) + return keras.ops.reshape(normalized_inputs, input_shape) + + def _reshape_into_groups(self, inputs, input_shape): + group_shape = [input_shape[i] for i in range(inputs.shape.rank)] + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + group_shape = keras.ops.stack(group_shape) + return keras.ops.reshape(inputs, group_shape) + + def _apply_normalization(self, reshaped_inputs, input_shape): + group_reduction_axes = list(range(1, reshaped_inputs.shape.rank)) + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + mean, variance = keras.ops.moments( + reshaped_inputs, group_reduction_axes, keepdims=True + ) + gamma, beta = self._get_reshaped_weights(input_shape) + return keras.ops.batch_normalization( + reshaped_inputs, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=self.epsilon, + ) + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = keras.ops.reshape(self.gamma, broadcast_shape) + beta = keras.ops.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * input_shape.shape.rank + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + return broadcast_shape + +class PaddedConv2D(keras.layers.Layer): + def __init__( + self, + filters, + kernel_size, + padding = 0, + strides = 1, + name = None, + **kwargs + ): + super().__init__(**kwargs) + self.padding2d = keras.layers.ZeroPadding2D(padding, name = name) + self.conv2d = keras.layers.Conv2D(filters, kernel_size, strides = strides, name = name) + self.filters = filters + self.kernel_size = kernel_size + self.padding = padding + self.strides = strides + + #@tf.function + def call(self, inputs): + x = self.padding2d(inputs) + return self.conv2d(x) + + def get_config(self): + config = super().get_config() + config.update({ + "filters": self.filters, + "kernel_size": self.kernel_size, + "padding": self.padding, + "strides": self.strides, + }) + return config + +class ResBlock(keras.layers.Layer): + def __init__( + self, + output_dim, + **kwargs + ): + super().__init__(**kwargs) + self.output_dim = output_dim + self.entry_flow = [ + GroupNormalization(epsilon = 1e-5), + keras.layers.Activation("swish", name = "ResBlock_swish1"), + PaddedConv2D(output_dim, 3, padding = 1, name = "inLayers2"), + ] + self.embedding_flow = [ + keras.layers.Activation("swish"), + keras.layers.Dense(output_dim, name = "embeddingLayer"), + ] + self.exit_flow = [ + GroupNormalization(epsilon = 1e-5), + keras.layers.Activation("swish", name = "ResBlock_swish2"), + PaddedConv2D(output_dim, 3, padding = 1, name = "outLayers3"), + ] + + def build(self, input_shape): + if input_shape[0][-1] != self.output_dim: + self.residual_projection = PaddedConv2D(self.output_dim, 1) + else: + self.residual_projection = lambda x: x + + #@tf.function + def call(self, inputs): + inputs, embeddings = inputs + x = inputs + for layer in self.entry_flow: + x = layer(x) + for layer in self.embedding_flow: + embeddings = layer(embeddings) + x = x + embeddings[:, None, None] + for layer in self.exit_flow: + x = layer(x) + return x + self.residual_projection(inputs) + + def get_config(self): + config = super().get_config() + config.update({ + "output_dim": self.output_dim, + }) + return config + + +class SpatialTransformer(keras.layers.Layer): + def __init__( + self, + num_heads, + head_size, + fully_connected = False, + **kwargs + ): + super().__init__(**kwargs) + self.norm = GroupNormalization(epsilon = 1e-5) + self.num_heads = num_heads + self.head_size = head_size + self.fully_connected = fully_connected + channels = num_heads * head_size + if fully_connected: + self.proj1 = keras.layers.Dense(num_heads * head_size, name = "proj_in1_fullyConnected") + else: + self.proj1 = PaddedConv2D(num_heads * head_size, 1, name = "proj_in") + self.transformer_block = BasicTransformerBlock(channels, num_heads, head_size) + if fully_connected: + self.proj2 = keras.layers.Dense(channels) + else: + self.proj2 = PaddedConv2D(channels, 1, name = "proj_in2_fullyConnected") + + def get_config(self): + config = super().get_config() + config.update({ + "num_heads": self.num_heads, + "head_size": self.head_size, + "fully_connected": self.fully_connected + }) + return config + + #@tf.function + def call(self, inputs): + inputs, context = inputs + _, h, w, c = inputs.shape + x = self.norm(inputs) + x = self.proj1(x) + x = keras.ops.reshape(x, (-1, h * w, c)) + x = self.transformer_block([x, context]) + x = keras.ops.reshape(x, (-1, h, w, c)) + return self.proj2(x) + inputs + + +class BasicTransformerBlock(keras.layers.Layer): + def __init__( + self, + dim, + num_heads, + head_size, + **kwargs + ): + super().__init__(**kwargs) + self.norm1 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "norm1") + self.attn1 = CrossAttention(num_heads, head_size) + + self.norm2 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "norm2") + self.attn2 = CrossAttention(num_heads, head_size) + + self.norm3 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "norm3") + self.geglu = GEGLU(dim * 4) + self.dense = keras.layers.Dense(dim) + + #@tf.function + def call(self, inputs): + inputs, context = inputs + x = self.attn1([self.norm1(inputs), None]) + inputs + x = self.attn2([self.norm2(x), context]) + x + return self.dense(self.geglu(self.norm3(x))) + x + + +class CrossAttention(keras.layers.Layer): + def __init__(self, num_heads, head_size, **kwargs): + super().__init__(**kwargs) + self.to_q = keras.layers.Dense(num_heads * head_size, use_bias = False, name = "to_q") + self.to_k = keras.layers.Dense(num_heads * head_size, use_bias = False, name = "to_k") + self.to_v = keras.layers.Dense(num_heads * head_size, use_bias = False, name = "to_v") + self.scale = head_size**-0.5 + self.num_heads = num_heads + self.head_size = head_size + self.out_proj = keras.layers.Dense(num_heads * head_size, name = "out_projection") + + #@tf.function + def call(self, inputs): + inputs, context = inputs + context = inputs if context is None else context + q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context) + q = keras.ops.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size)) + k = keras.ops.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size)) + v = keras.ops.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size)) + + q = keras.ops.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + k = keras.ops.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time) + v = keras.ops.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) + + score = td_dot(q, k) * self.scale + weights = keras.activations.softmax(score) # (bs, num_heads, time, time) + attn = td_dot(weights, v) + attn = keras.ops.transpose(attn, (0, 2, 1, 3)) # (bs, time, num_heads, head_size) + out = keras.ops.reshape(attn, (-1, inputs.shape[1], self.num_heads * self.head_size)) + return self.out_proj(out) + + +class Upsample(keras.layers.Layer): + def __init__( + self, + channels, + **kwargs + ): + super().__init__(**kwargs) + self.channels = channels + self.ups = keras.layers.UpSampling2D(2) + self.conv = PaddedConv2D(channels, 3, padding = 1, name = "Upsample") + + #@tf.function + def call(self, inputs): + return self.conv(self.ups(inputs)) + + def get_config(self): + config = super().get_config() + config.update({ + "channels": self.channels, + }) + return config + + +class GEGLU(keras.layers.Layer): + def __init__(self, output_dim, **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.dense = keras.layers.Dense(output_dim * 2) + + #@tf.function + def call(self, inputs): + x = self.dense(inputs) + x, gate = x[..., : self.output_dim], x[..., self.output_dim :] + tanh_res = keras.activations.tanh( + gate * 0.7978845608 * (1 + 0.044715 * (gate**2)) + ) + return x * 0.5 * gate * (1 + tanh_res) + + +def td_dot(a, b): + aa = keras.ops.reshape(a, (-1, a.shape[2], a.shape[3])) + bb = keras.ops.reshape(b, (-1, b.shape[2], b.shape[3])) + cc = keras.backend.batch_dot(aa, bb) + return keras.ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) \ No newline at end of file diff --git a/stableDiffusionKeras/layers.py b/stableDiffusionKeras/layers.py new file mode 100644 index 0000000..5e20b63 --- /dev/null +++ b/stableDiffusionKeras/layers.py @@ -0,0 +1,64 @@ + +import keras + + +class PaddedConv2D(keras.layers.Layer): + def __init__( + self, + channels, + kernel_size, + padding = 0, + stride = 1, + name = None + ): + super().__init__() + self.padding2d = keras.layers.ZeroPadding2D((padding, padding), name = name) + self.conv2d = keras.layers.Conv2D( + channels, kernel_size, strides=(stride, stride), name = name + ) + + def call(self, x): + x = self.padding2d(x) + return self.conv2d(x) + + +class GEGLU(keras.layers.Layer): + def __init__(self, dim_out, name = None): + super().__init__() + self.proj = keras.layers.Dense(dim_out * 2, name = name) + self.dim_out = dim_out + + def call(self, x): + xp = self.proj(x) + x, gate = xp[..., : self.dim_out], xp[..., self.dim_out :] + return x * gelu(gate) + + +def gelu(x): + tanh_res = keras.activations.tanh(x * 0.7978845608 * (1 + 0.044715 * (x**2))) + return 0.5 * x * (1 + tanh_res) + + +def quick_gelu(x): + return x * keras.ops.sigmoid(x * 1.702) + + +"""def apply_seq(x, layers): + for l in layers: + x = l(x) + return x""" + +def apply_seq(x, seq_layer): + if isinstance(seq_layer, keras.Sequential): + x = seq_layer(x) + else: + for l in seq_layer: + x = l(x) + return x + + +def td_dot(a, b): + aa = keras.ops.reshape(a, (-1, a.shape[2], a.shape[3])) + bb = keras.ops.reshape(b, (-1, b.shape[2], b.shape[3])) + cc = keras.backend.batch_dot(aa, bb) + return keras.ops.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2])) diff --git a/stableDiffusionKeras/openClipEncoder.py b/stableDiffusionKeras/openClipEncoder.py new file mode 100644 index 0000000..354bcbb --- /dev/null +++ b/stableDiffusionKeras/openClipEncoder.py @@ -0,0 +1,152 @@ + +import keras +import tensorflow_addons as tfa +import numpy as np + +from .layers import quick_gelu, gelu + +# Step 1 +# Create and return the CLIP Embeddings +class OpenCLIPTextTransformer(keras.models.Model): + def __init__( + self, + maxLength = 77, + vocabularySize = 49408 + ): + super().__init__() + + # Create embeddings -> Step 2 + self.embeddings = OpenCLIPTextEmbeddings(maxLength = maxLength, vocabularySize = vocabularySize) + + # Create encoder -> Step 3 + self.encoder = OpenCLIPEncoder() + + self.final_layer_norm = keras.layers.LayerNormalization(epsilon = 1e-5, name = "FinalLayerNormalization") + self.causal_attention_mask = keras.initializers.Constant( + np.triu(np.ones((1, 1, 77, 77), dtype = "float32") * -np.inf, k = 1) + ) + + def call(self, inputs): + input_ids, position_ids = inputs + x = self.embeddings([input_ids, position_ids]) + x = self.encoder([x, self.causal_attention_mask]) + return self.final_layer_norm(x) + +# Step 2 +# Create and return word and position embeddings +class OpenCLIPTextEmbeddings(keras.layers.Layer): + def __init__( + self, + maxLength = 77, + vocabularySize = 49408, + embeddingSize = 1024 + ): + super().__init__() + # Token Embedding Layer - Representing a sequence of tokens (words) + self.token_embedding_layer = keras.layers.Embedding( + vocabularySize, embeddingSize, name = "token_embedding" + ) + # Position Embedding layer - Where is the word in the sentence? What does it mean in the context of the sentence? + self.position_embedding_layer = keras.layers.Embedding( + maxLength, embeddingSize, name = "position_embedding" + ) + + def call(self, inputs): + input_ids, position_ids = inputs + word_embeddings = self.token_embedding_layer(input_ids) + position_embeddings = self.position_embedding_layer(position_ids) + return word_embeddings + position_embeddings + +# Step 3 +# Create and return the hidden states (aka hidden size) +class OpenCLIPEncoder(keras.layers.Layer): + def __init__(self): + super().__init__() + self.layers = [OpenCLIPEncoderLayer() for i in range(24)] + + def call(self, inputs): + [hidden_states, causal_attention_mask] = inputs + for l in self.layers: + hidden_states = l([hidden_states, causal_attention_mask]) + return hidden_states + +# Step 4 (also creatd in step 3) +# Create the layers +class OpenCLIPEncoderLayer(keras.layers.Layer): + def __init__( + self, + intermediateSize = 4096, + embeddingSize = 1024 + ): + super().__init__() + self.layer_norm1 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "LayerNormalization01") # Layer Normalization 1 + self.self_attn = OpenCLIPAttention() # Attention Layers + self.layer_norm2 = keras.layers.LayerNormalization(epsilon = 1e-5, name = "LayerNormalization02") # Layer Normalization 2 + self.fc1 = keras.layers.Dense(intermediateSize, name = "FC1") # MLP layer? + self.fc2 = keras.layers.Dense(embeddingSize, name = "FC2") # ??? + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn([hidden_states, causal_attention_mask]) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + # MLP Steps + hidden_states = self.fc1(hidden_states) + hidden_states = gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + + return residual + hidden_states + +class OpenCLIPAttention(keras.layers.Layer): + def __init__(self): + super().__init__() + self.embed_dim = 1024 + self.num_heads = 16 + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 + self.q_proj = keras.layers.Dense(self.embed_dim, name = "QueryState") # Query states, the given word + self.k_proj = keras.layers.Dense(self.embed_dim, name = "KeyState") # Key states, all other words + self.v_proj = keras.layers.Dense(self.embed_dim, name = "ValueState") # Value states, the sentence + self.out_proj = keras.layers.Dense(self.embed_dim, name = "OutProjection") # Out Projection? + + def _shape(self, tensor, seq_len: int, bsz: int): + # Keys + a = keras.ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) + return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim + + def call(self, inputs): + hidden_states, causal_attention_mask = inputs + bsz, tgt_len, embed_dim = hidden_states.shape + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1) + value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1) + + proj_shape = (-1, tgt_len, self.head_dim) + query_states = self._shape(query_states, tgt_len, -1) + query_states = keras.ops.reshape(query_states, proj_shape) + key_states = keras.ops.reshape(key_states, proj_shape) + + src_len = tgt_len + value_states = keras.ops.reshape(value_states, proj_shape) + attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states) + + attn_weights = keras.ops.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len)) + attn_weights = attn_weights + causal_attention_mask + attn_weights = keras.ops.reshape(attn_weights, (-1, tgt_len, src_len)) + + attn_weights = keras.ops.softmax(attn_weights) + attn_output = attn_weights @ value_states + + attn_output = keras.ops.reshape( + attn_output, (-1, self.num_heads, tgt_len, self.head_dim) + ) + attn_output = keras.layers.Permute((2, 1, 3))(attn_output) + attn_output = keras.ops.reshape(attn_output, (-1, tgt_len, embed_dim)) + + return self.out_proj(attn_output) \ No newline at end of file diff --git a/stableDiffusionKeras/samplers/DPMSolverKerasCV.py b/stableDiffusionKeras/samplers/DPMSolverKerasCV.py new file mode 100644 index 0000000..c5239d1 --- /dev/null +++ b/stableDiffusionKeras/samplers/DPMSolverKerasCV.py @@ -0,0 +1,215 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""StableDiffusion Noise scheduler + +Adapted from https://github.com/huggingface/diffusers/blob/v0.3.0/src/diffusers/schedulers/scheduling_ddpm.py#L56 + +From https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion/noise_scheduler.py +""" +import keras +import numpy as np + + +class NoiseScheduler: + """ + Args: + train_timesteps: number of diffusion steps used to train the model. + beta_start: the starting `beta` value of inference. + beta_end: the final `beta` value. + beta_schedule: + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `quadratic`. + betas: a complete set of betas, in lieu of using one of the existing schedules. + variance_type: + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample: + option to clip predicted sample between -1 and 1 for numerical stability. + """ + + def __init__( + self, + train_timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + beta_schedule="linear", + betas=None, + variance_type="fixed_small", + clip_sample=True, + ): + self.train_timesteps = train_timesteps + + if beta_schedule == "linear": + self.betas = keras.ops.linspace(beta_start, beta_end, train_timesteps) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + keras.ops.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) + ** 2 + ) + else: + raise ValueError(f"Invalid beta schedule: {beta_schedule}.") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = keras.ops.cumprod(self.alphas) + + self.variance_type = variance_type + self.clip_sample = clip_sample + + def _get_variance(self, timestep, predicted_variance=None): + alpha_prod = self.alphas_cumprod[timestep] + alpha_prod_prev = ( + self.alphas_cumprod[timestep - 1] if timestep > 0 else 1.0 + ) + + variance = ( + (1 - alpha_prod_prev) / (1 - alpha_prod) * self.betas[timestep] + ) + + if self.variance_type == "fixed_small": + variance = keras.ops.clip( + variance, clip_value_min=1e-20, clip_value_max=1 + ) + elif self.variance_type == "fixed_small_log": + variance = keras.ops.log( + ( + keras.ops.clip( + variance, clip_value_min=1e-20, clip_value_max=1 + ) + ) + ) + elif self.variance_type == "fixed_large": + variance = self.betas[timestep] + elif self.variance_type == "fixed_large_log": + variance = keras.ops.log(self.betas[timestep]) + elif self.variance_type == "learned": + return predicted_variance + elif self.variance_type == "learned_range": + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + else: + raise ValueError(f"Invalid variance type: {self.variance_type}") + + return variance + + def step( + self, + model_output, + timestep, + sample, + predict_epsilon=True, + ): + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (usually the predicted noise). + Args: + model_output: a Tensor containing direct output from learned diffusion model + timestep: current discrete timestep in the diffusion chain. + sample: a Tensor containing the current instance of sample being created by diffusion process. + predict_epsilon: whether the model is predicting noise (epsilon) or samples + Returns: + The predicted sample at the previous timestep + """ + + if model_output.shape[1] == sample.shape[ + 1 + ] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: + model_output, predicted_variance = keras.ops.split( + model_output, sample.shape[1], axis=1 + ) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod = self.alphas_cumprod[timestep] + alpha_prod_prev = ( + self.alphas_cumprod[timestep - 1] if timestep > 0 else 1.0 + ) + beta_prod = 1 - alpha_prod + beta_prod_prev = 1 - alpha_prod_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = ( + sample - beta_prod ** (0.5) * model_output + ) / alpha_prod ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = keras.ops.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = ( + alpha_prod_prev ** (0.5) * self.betas[timestep] + ) / beta_prod + current_sample_coeff = ( + self.alphas[timestep] ** (0.5) * beta_prod_prev / beta_prod + ) + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = ( + pred_original_sample_coeff * pred_original_sample + + current_sample_coeff * sample + ) + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = keras.random.normal(model_output.shape) + variance = ( + self._get_variance( + timestep, predicted_variance=predicted_variance + ) + ** 0.5 + ) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample + + def add_noise( + self, + original_samples, + noise, + timesteps, + ): + sqrt_alpha_prod = keras.ops.take(self.alphas_cumprod, timesteps) ** 0.5 + sqrt_one_minus_alpha_prod = ( + 1 - keras.ops.take(self.alphas_cumprod, timesteps) + ) ** 0.5 + + for _ in range(3): + sqrt_alpha_prod = keras.ops.expand_dims(sqrt_alpha_prod, axis=-1) + sqrt_one_minus_alpha_prod = keras.ops.expand_dims( + sqrt_one_minus_alpha_prod, axis=-1 + ) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + def __len__(self): + return self.train_timesteps \ No newline at end of file diff --git a/stableDiffusionKeras/samplers/ReadMe.md b/stableDiffusionKeras/samplers/ReadMe.md new file mode 100644 index 0000000..7cbaba5 --- /dev/null +++ b/stableDiffusionKeras/samplers/ReadMe.md @@ -0,0 +1,5 @@ +### Samplers ### + +This folder contains the samplers used for calculating and generating the images. + +Because this is a TensorFlow implementation of Stable Diffusion, there are only a few options for sampling. diff --git a/stableDiffusionKeras/samplers/__init__.py b/stableDiffusionKeras/samplers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stableDiffusionKeras/samplers/basicSampler.py b/stableDiffusionKeras/samplers/basicSampler.py new file mode 100644 index 0000000..9df5525 --- /dev/null +++ b/stableDiffusionKeras/samplers/basicSampler.py @@ -0,0 +1,378 @@ +### Basic Modules +import math +import random + +### TensorFlow Modules + +import keras + +### Modules for image building +from PIL import Image +import cv2 #OpenCV + +from stableDiffusionKeras.utils import keras_print + + +class BasicSampler(): + def __init__( + self, + model = None, + timesteps = keras.ops.arange(1, 1000, 1000 // 50), + batchSize = 1, + seed = 1990, + inputImage = None, # Expecting a tensor + inputMask = None, # Expecting a tensor + inputImageStrength = 0.5, + temperature = 1, + AlphasCumprod = None, + controlNetInput = None + ): + print("...starting Basic Sampler...") + self.model = model + self.timesteps = timesteps + self.batchSize = batchSize + self.seed = seed + self.inputImage = inputImage + self.inputMask = inputMask + self.inputImageStrength = inputImageStrength + self.inputImageNoise_T = self.timesteps[ int(len(self.timesteps)*self.inputImageStrength) ] + self.temperature = temperature + self.AlphasCumprod = AlphasCumprod # Length = 1000 + + self.latent, self.alphas, self.alphas_prev, self.controlNetInput = self.getStartingParameters( + self.timesteps, + self.batchSize, + seed, + inputImage = self.inputImage, + inputImageNoise_T = self.inputImageNoise_T, + controlNetInput = controlNetInput + ) + + if self.inputImage is not None: + self.timesteps = self.timesteps[: int( len(self.timesteps) * self.inputImageStrength ) ] + + print("...sampler ready...") + + def addNoise( + self, + x, + t, + noise = None, + DType = keras.config.floatx() + ): + batch_size , w , h = x.shape[0] , x.shape[1] , x.shape[2] + if noise is None: + # Post-Encode version: + noise = keras.random.normal((batch_size,w,h,4), dtype = DType) + # Pre-Encode version: + # noise = keras.random.normal((batch_size,w,h,3), dtype = DType) + sqrt_alpha_prod = self.AlphasCumprod[t] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.AlphasCumprod[t]) ** 0.5 + + return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise + + def getStartingParameters( + self, + timesteps, + batchSize, + seed, + inputImage = None, + inputImageNoise_T = None, + controlNetInput = None + ): + # Use floor division to get minimum height/width of image size + # for the Diffusion and Decoder models + floorDividedImageHeight = self.model.imageHeight // 8 + floorDividedImageWidth = self.model.imageWidth // 8 + + alphas = [self.AlphasCumprod[t] for t in timesteps] # sample steps length + alphas_prev = [1.0] + alphas[:-1] + + if inputImage is None: + # Create a random input image from noise + latent = keras.random.stateless_normal( + (batchSize, floorDividedImageHeight, floorDividedImageWidth, 4), + seed = [seed, seed] + ) + else: + ## Debug Variables + randomNumber = str(random.randint(0, 2 ** 31)) + + # Noise the input image before encoding + #latent = self.addNoise(inputImage, inputImageNoise_T) + + # Encode the given image + print(inputImage.shape) + latent = self.model.encoder(inputImage, training = False) + print(latent.shape) + #self.displayImage(latent,("encoded" + randomNumber)) + # Repeat it within the tensor for the given batch size + latent = keras.ops.repeat(latent , batchSize , axis = 0) + # Noise the image after encode + latent = self.addNoise(latent, inputImageNoise_T) + + + if controlNetInput is None: + # Create a random input image from noise + controlNetLatent = keras.random.normal( + (batchSize, floorDividedImageHeight, floorDividedImageWidth, 3), + seed = seed + ) + else: + controlNetLatent = keras.ops.repeat(controlNetInput, batchSize , axis = 0) + + return latent, alphas, alphas_prev, controlNetLatent + + def get_x_prev_and_pred_x0( + self, + x, + e_t, + index, + a_t, + a_prev, + temperature, + seed + ): + sigma_t = keras.initializers.Constant(0.0) + sqrt_one_minus_at = keras.ops.sqrt(keras.initializers.Constant(1.0) - a_t) + pred_x0 = (x - sqrt_one_minus_at * e_t) / keras.ops.sqrt(a_t) + + # Direction pointing to x_t + dir_xt = keras.ops.sqrt(keras.initializers.Constant(1.0) - a_prev - keras.ops.square(sigma_t)) * e_t + noise = sigma_t * keras.random.normal(x.shape, seed = seed) * temperature + x_prev = keras.ops.sqrt(a_prev) * pred_x0 + dir_xt + return x_prev, pred_x0 + + # Keras Version + def sample( + self, + context, + unconditionalContext, + unconditionalGuidanceScale, + controlNet = [None, 1, None], #[0]Use ControlNet, [1]Strength, [2] Cache Input + vPrediction = False, + device = None + ): + with keras.device(device): + # Progress Bar set-up + progbar = keras.utils.Progbar(len(self.timesteps)) + iteration = 0 + + # ControlNet Cache + if controlNet[2] is not None: + keras_print("...using controlNet cache...") + controlNetCache = controlNet[2] + else: + if controlNet[0] is True: + keras_print("...creating controlNet cache...") + controlNetCache = [] + + if controlNet[2] is not None and len(controlNet[2]) != len(list(enumerate(self.timesteps))[::-1]): + keras_print("...updating controlNet cache...") + controlNetCache = [] + controlNet[2] = None + + keras_print("...sampling:") + + # Iteration loop + for index, timestep in list(enumerate(self.timesteps))[::-1]: + + latentPrevious = self.latent + + # Establish timestep embedding + #t_emb = self.timestepEmbedding(float(timestep)) + t_emb = self.timestepEmbedding(int(timestep)) + t_emb = keras.ops.repeat(t_emb, self.batchSize, axis = 0) #shape is (1, 320) + + inputsConditional = [self.latent, t_emb, context] + inputsUnconditional = [self.latent, t_emb, unconditionalContext] + + if controlNet[0] is True: + + if controlNet[2] is None: + # No cache was given, so we're starting from scratch + + # Get unconditional and conditional tensors(arrays) + controlNetUnconditionalArray = self.model.controlNet( + [self.latent, t_emb, unconditionalContext, keras.ops.concatenate(self.controlNetInput, axis = 3)], + training = False + ) + controlNetConditionalArray = self.model.controlNet( + [self.latent, t_emb, context, keras.ops.concatenate(self.controlNetInput, axis = 3)], + training = False + ) + + # Apply strength + controlNetUnconditionalArray = [result * scale for result, scale in zip(controlNetUnconditionalArray, controlNet[1])] + controlNetConditionalArray = [result * scale for result, scale in zip(controlNetConditionalArray, controlNet[1])] + + # Update Cache + controlNetCacheData = { + "unconditional" : controlNetUnconditionalArray, + "conditional" : controlNetConditionalArray + } + controlNetCache.insert(0, controlNetCacheData) + + # Add the resulting tensors from the contorlNet models to the list of inputs for the diffusion models + inputsUnconditional.append(controlNetUnconditionalArray) + inputsConditional.append(controlNetConditionalArray) + else: + # Use ControlNet Cache + inputsUnconditional.extend(controlNetCache[index]["unconditional"]) + inputsConditional.extend(controlNetCache[index]["conditional"]) + + # Get unconditional (negative prompt) latent image + unconditionalLatent = self.model.diffusion_model( + inputsUnconditional, + training = False + ) + + # Get conditional (positive prompt) latent image + self.latent = self.model.diffusion_model( + inputsConditional, + training = False + ) + + # Combine the two latent images + self.latent = unconditionalLatent + unconditionalGuidanceScale * (self.latent - unconditionalLatent) + + # Alphas + a_t, a_prev = self.alphas[index], self.alphas_prev[index] + + # Predictions + if vPrediction is False: + # Debug Info + if iteration == 0: + print("Latent Previous dtype:",latentPrevious.dtype) + print("Latent dtype:",self.latent.dtype) + + # Make the data types (dtypes) match + if latentPrevious.dtype != self.latent.dtype: + latentPrevious = keras.ops.cast(latentPrevious, dtype = self.latent.dtype) + + pred_x0 = (latentPrevious - math.sqrt(1.0 - a_t) * self.latent) / math.sqrt( + a_t + ) + + self.latent = ( + self.latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0 + ) + else: + # v-Prediction for SD 2.1-V models + self.latent = self.predictEpsFromZandV(latentPrevious, index, self.latent) + + # Keras Progress Bar Update + iteration += 1 + progbar.update(iteration) + + keras_print("...finished! Returning latent image...") + + return self.latent, controlNetCache + + def predictEpsFromZandV( + self, + latent, + timestep, + velocity + ): + + #sqrt_alphas_cumprod = keras.ops.sqrt(keras.ops.cumprod([1 - alpha for alpha in self.alphas], axis = 0, exclusive = True)) + sqrt_alphas_cumprod = keras.ops.sqrt(self.alphas) + #keras_print("\nSquare Root Alphas Cumprod:\n",len(sqrt_alphas_cumprod)) + tensorShape = sqrt_alphas_cumprod.shape[0] + # sqrt_alphas_cumprod = sqrt_alphas_cumprod[timestep] + #sqrt_alphas_cumprod = keras.ops.reshape(sqrt_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + sqrt_one_minus_alphas_cumprod = keras.ops.sqrt([1 - alpha for alpha in self.alphas]) + #keras_print("\nSquare Root Alphas Cumprod Minus One:\n",len(sqrt_one_minus_alphas_cumprod)) + tensorShape = sqrt_one_minus_alphas_cumprod.shape[0] + # sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timestep] + #sqrt_one_minus_alphas_cumprod = keras.ops.reshape(sqrt_one_minus_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + return ( sqrt_alphas_cumprod[timestep] * latent - + sqrt_one_minus_alphas_cumprod[timestep] * velocity + ) + + def predictStartFromZandV( + self, + latent, + timestep, + velocity + ): + #sqrt_alphas_cumprod = keras.ops.sqrt(keras.ops.cumprod([1 - alpha for alpha in self.alphas], axis = 0, exclusive = True)) + sqrt_alphas_cumprod = keras.ops.sqrt(self.alphas) + tensorShape = sqrt_alphas_cumprod.shape[0] + # sqrt_alphas_cumprod = sqrt_alphas_cumprod[timestep] + #sqrt_alphas_cumprod = keras.ops.reshape(sqrt_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + #sqrt_one_minus_alphas_cumprod = keras.ops.sqrt(1 - keras.ops.cumprod(self.alphas, axis = 0, exclusive = True)) + sqrt_one_minus_alphas_cumprod = keras.ops.sqrt([1 - alpha for alpha in self.alphas]) + tensorShape = sqrt_one_minus_alphas_cumprod.shape[0] + # sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[timestep] + #sqrt_one_minus_alphas_cumprod = keras.ops.reshape(sqrt_one_minus_alphas_cumprod, (tensorShape,) + (1,) * (len(latent.shape) - 1)) + + """sqrt_alphas_cumprod_t = extractIntoTensor(sqrt_alphas_cumprod, timestep, latent.shape) + sqrt_one_minus_alphas_cumprod_t = extractIntoTensor(sqrt_one_minus_alphas_cumprod, timestep, latent.shape) + + return sqrt_alphas_cumprod_t * latent - sqrt_one_minus_alphas_cumprod_t * velocity""" + + """print(sqrt_alphas_cumprod.shape) + print(timestep) + print(sqrt_alphas_cumprod[timestep])""" + + return ( + sqrt_alphas_cumprod[timestep] * velocity + + sqrt_one_minus_alphas_cumprod[timestep] * latent + ) + + def timestepEmbedding( + self, + timesteps, + dimensions = 320, + max_period = 10000.0 + ): + half = dimensions // 2 + freqs = keras.ops.exp( + -keras.ops.log(max_period) * keras.ops.arange(0, half, dtype = keras.config.floatx()) / half + ) + args = keras.ops.convert_to_tensor([timesteps], dtype = keras.config.floatx()) * freqs + embedding = keras.ops.concatenate([keras.ops.cos(args), keras.ops.sin(args)], 0) + embedding = keras.ops.reshape(embedding, [1, -1]) + return embedding + + def displayImage(self, image, name = "sampler"): + # Assuming input_image_tensor is a TensorFlow tensor representing the image + + try: + input_image_tensor = self.model.decoder(image, training = False) + except Exception as e: + print(e) + input_image_tensor = image + + # Assuming input_image_tensor is a TensorFlow tensor representing the image + # Remove the batch dimension + input_image_tensor = keras.ops.squeeze(input_image_tensor, axis = 0) + + #keras.ops.image.resize(input_image_tensor, [self.model.imageWidth, self.model.imageHeight]) + + # Convert the tensor to a NumPy array + input_image_array = input_image_tensor.numpy() + + # Rescale the array to the range [0, 255] + input_image_array = ((input_image_array + 1) / 2.0) * 255.0 + + # Convert the array to uint8 data type + input_image_array = input_image_array.astype('uint8') + + # Display the image using Matplotlib + imageFromBatch = Image.fromarray(input_image_array) + imageFromBatch.save("debug/"+name+".png") + +""" +Utilities +""" + +def extractIntoTensor(a, t, x_shape): + b, *_ = keras.ops.shape(t) + out = keras.ops.take(a, t, axis = -1) + return keras.ops.reshape(out, (b,) + (1,) * (len(x_shape) - 1)) \ No newline at end of file diff --git a/stableDiffusionKeras/samplers/basicVSampler.py b/stableDiffusionKeras/samplers/basicVSampler.py new file mode 100644 index 0000000..c30db4b --- /dev/null +++ b/stableDiffusionKeras/samplers/basicVSampler.py @@ -0,0 +1,220 @@ +# TensorFlow Modules + +import keras + +from stableDiffusionKeras.utils import keras_print + + +class BasicSampler(): + def __init__( + self, + model = None, + timesteps = keras.ops.numpy.arange(1, 1000, 1000 // 5), + batchSize = 1, + seed = 1990, + inputImage = None, # Expecting a tensor + inputMask = None, # Expecting a tensor + inputImageStrength = 0.5, + temperature = 1, + AlphasCumprod = None + ): + print("...starting Basic Sampler...") + self.model = model + self.timesteps = timesteps + self.batchSize = batchSize + self.seed = seed + self.inputImage = inputImage + self.inputMask = inputMask + self.inputImageStrength = inputImageStrength + self.inputImageNoise_T = self.timesteps[ int(len(self.timesteps) * self.inputImageStrength) ] + self.temperature = temperature + self.AlphasCumprod = AlphasCumprod + + self.latent, self.alphas, self.alphas_prev = self.getStartingParameters( + self.timesteps, + self.batchSize, + seed, + inputImage = self.inputImage, + inputImageNoise_T = self.inputImageNoise_T + ) + + if self.inputImage is not None: + self.timesteps = self.timesteps[: int( len(self.timesteps) * self.inputImageStrength ) ] + + print("...sampler ready...") + + def addNoise( + self, + x, + t, + noise = None, + DType = keras.config.floatx() + ): + batch_size , w , h = x.shape[0] , x.shape[1] , x.shape[2] + if noise is None: + noise = keras.random.normal((batch_size,w,h,4), dtype = DType) + sqrt_alpha_prod = self.AlphasCumprod[t] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.AlphasCumprod[t]) ** 0.5 + + return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise + + def getStartingParameters( + self, + timesteps, + batchSize, + seed, + inputImage = None, + inputImageNoise_T = None + ): + # Use floor division to get minimum height/width of image size + # for the Diffusion and Decoder models + floorDividedImageHeight = self.model.imageHeight // 8 + floorDividedImageWidth = self.model.imageWidth // 8 + + alphas = [self.AlphasCumprod[t] for t in timesteps] + alphas_prev = [1.0] + alphas[:-1] + + if inputImage is None: + # Create a random input image from noise + latent = keras.random.normal( + (batchSize, floorDividedImageHeight, floorDividedImageWidth, 4), + seed = seed + ) + else: + # Encode the given image + latent = self.model.encoder(inputImage, training = False) + # Repeat it within the tensor for the given batch size + latent = keras.ops.repeat(latent , batchSize , axis = 0) + # Noise the image + latent = self.addNoise(latent, inputImageNoise_T) + + return latent, alphas, alphas_prev + + def get_x_prev_and_pred_x0( + self, + x, + e_t, + index, + a_t, + a_prev, + temperature, + seed + ): + sigma_t = keras.initializers.Constant(0.0) + sqrt_one_minus_at = keras.ops.sqrt(keras.initializers.Constant(1.0).value - a_t) + pred_x0 = (x - sqrt_one_minus_at * e_t) / keras.ops.sqrt(a_t) + + # Direction pointing to x_t + dir_xt = keras.ops.sqrt(keras.initializers.Constant(1.0).value - a_prev - keras.ops.square(sigma_t)) * e_t + noise = sigma_t.value * keras.random.normal(x.shape, seed = seed) * temperature + x_prev = keras.ops.sqrt(a_prev) * pred_x0 + dir_xt + return x_prev, pred_x0 + + # Keras Version + def sample( + self, + context, + unconditionalContext, + unconditionalGuidanceScale + ): + keras_print("...sampling:") + + # Progress Bar set-up + progbar = keras.utils.Progbar(len(self.timesteps)) + iteration = 0 + + # Iteration loop + for index, timestep in list(enumerate(self.timesteps))[::-1]: + + latentPrevious = self.latent + + # Establish timestep embedding + t_emb = self.timestepEmbedding(float(timestep)) + t_emb = keras.ops.repeat(t_emb, self.batchSize, axis = 0) + + # Get unconditional (negative prompt) latent image + unconditionalLatent = self.model.diffusion_model( + [self.latent, t_emb, unconditionalContext], + training = False + ) + # Get conditional (positive prompt) latent image + self.latent = self.model.diffusion_model( + [self.latent, t_emb, context], + training = False + ) + + # Combine the two latent images, the et + self.latent = unconditionalLatent + unconditionalGuidanceScale * (self.latent - unconditionalLatent) + + # Alphas, the sigma + a_t, a_prev = self.alphas[index], self.alphas_prev[index] + + """# Predictions + predictV = (latentPrevious - keras.ops.sqrt(keras.initializers.Constant(1.0) - a_t) * self.latent) / keras.ops.sqrt( + a_t + ) + self.latent = ( + self.latent * keras.ops.sqrt(1.0 - a_prev) + keras.ops.sqrt(a_prev) * predictV + )""" + + # Predictions + predictV = (latentPrevious - keras.ops.sqrt(keras.initializers.Constant(1.0) - a_t) * self.latent) / keras.ops.sqrt( + a_t + ) + self.latent = ( + self.latent * keras.ops.sqrt(1.0 - a_prev) + keras.ops.sqrt(a_prev) * predictV + ) + + # Keras Progress Bar Update + iteration += 1 + progbar.update(iteration) + + keras_print("...finished! Returning latent image...") + + return self.latent + + + def getModelOutput( + self, + latent, + inputTimesteps, + context, + unconditionalContext, + unconditionalGuidanceScale, + batch_size, + ): + + # Establish timestep embedding + t_emb = self.timestepEmbedding(float(inputTimesteps)) + t_emb = keras.ops.repeat(t_emb, batch_size, axis = 0) + + # Get unconditional (negative prompt) latent image + unconditionalLatent = self.model.diffusion_model( + [latent, t_emb, unconditionalContext], + training = False + ) + # Get conditional (positive prompt) latent image + latent = self.model.diffusion_model( + [latent, t_emb, context], + training = False + ) + + # Combine the images and return the result + return unconditionalLatent + unconditionalGuidanceScale * ( + latent - unconditionalLatent + ) + + def timestepEmbedding( + self, + timesteps, + dimensions = 320, + max_period = 10000.0 + ): + half = dimensions // 2 + freqs = keras.ops.exp( + -keras.ops.log(max_period) * keras.ops.arange(0, half, dtype=keras.config.floatx()) / half + ) + args = keras.ops.convert_to_tensor([timesteps], dtype = keras.config.floatx()) * freqs + embedding = keras.ops.concatenate([keras.ops.cos(args), keras.ops.sin(args)], 0) + embedding = keras.ops.reshape(embedding, [1, -1]) + return embedding \ No newline at end of file diff --git a/stableDiffusionKeras/stableDiffusion.py b/stableDiffusionKeras/stableDiffusion.py new file mode 100644 index 0000000..cfbe534 --- /dev/null +++ b/stableDiffusionKeras/stableDiffusion.py @@ -0,0 +1,986 @@ +### System modules +import sys +import os +import warnings +import logging + +### Math modules +import numpy as np +import random + +### Time modules +import datetime + +### Memmory Management +import gc #Garbage Collector + +from jax import Array +### Console GUI +from rich import print, box +from rich.panel import Panel +from rich.text import Text + +from .utils import keras_print + +### Import TensorFlow module +### but with supressed warnings to clear up the terminal outputs +# Filter tensorflow version warnings +# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709 +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'} +# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning +warnings.simplefilter(action = 'ignore', category = FutureWarning) +warnings.simplefilter(action = 'ignore', category = Warning) + +# TensorFlow module + + +# More suppressed warnings from TensorFlow +#tf.get_logger().setLevel('INFO') +#tf.autograph.set_verbosity(0) +#tf.get_logger().setLevel(logging.ERROR) + +### Keras module +import keras +from keras import backend as K + +### Models from Modules +## VAE, encode and decode +from .EncodeDecode import Decoder, ImageEncoder +# from .autoencoderKl import Decoder, Encoder +## Diffusion +from .kerasCVDiffusionModels import DiffusionModel, DiffusionModelV2 +## Text encoder +from .clipEncoder import CLIPTextTransformer # SD 1.4/1.5 +from .openClipEncoder import OpenCLIPTextTransformer # SD 2.x +## Tokenizer +from .clipTokenizer import SimpleTokenizer, LegacySimpleTokenizer +## ControlNet +from .controlNetDiffusionModels import DiffusionModel as ControlDiffusionModel +from .controlNetDiffusionModels import ControlNetDiffusionModel as ControlNetModel + +### Pytorch (for converting pytorch weights) +import torch as torch + +### Safetensors (for converting safetensor weights) +from safetensors.torch import load_file + +### Tools +from .tools import textEmbeddings as textEmbeddingTools + +### Modules for image building +from PIL import Image +import cv2 #OpenCV + +### Sampler modules +from .samplers import DPMSolverKerasCV as DPMSolver +from .samplers.basicSampler import BasicSampler + +### Global Variables +MAX_TEXT_LEN = 77 +from .constants import _ALPHAS_CUMPROD, PYTORCH_CKPT_MAPPING + +### Main Class + +class StableDiffusion: + ### Base class/object for Stable Diffusion + def __init__( + self, + imageHeight = 512, + imageWidth = 512, + jit_compile = False, + weights = None, + legacy = True, + VAE = "Original", + textEmbeddings = None, + mixedPrecision = False, + optimizer = "nadam", + device = None, + controlNet = [False, None] # [0] = Use ControlNet? [1] = ControlNet Weights [2] = Input [3] = Strength + ): + self.device = device + + with keras.device(self.device): + ### Step 1: Establish image dimensions for UNet ### + ## requires multiples of 2**7, 2 to the power of 7 + self.imageHeight = round(imageHeight / 128) * 128 + self.imageWidth = round(imageWidth / 128) * 128 + + # Global policy + self.dtype = keras.config.floatx() # Default + + # Maaaybe float16 will result in faster images? + if mixedPrecision is True: + self.changePolicy("mixed_float16") + + ### Step 2: Load Text Embeddings ### + textEmbeddingTokens = [] + if textEmbeddings == None: + keras_print("\nIgnoring Text Embeddings") + self.textEmbeddings = None + self.textEmbeddingsTokens = None + else: + keras_print("\nUsing Text Embeddings") + self.textEmbeddings, self.textEmbeddingsTokens = textEmbeddingTools.loadTextEmbedding(textEmbeddings) + + ### Step 3: Which version of Stable Diffusion ### + + self.legacy = legacy + + ### Step 4: Create Tokenizer ### + if self.legacy is True: + if self.textEmbeddings is None: + # If no textEmbeddings were given, we're not adding to the special tokens list in the tokenizer + self.tokenizer = LegacySimpleTokenizer() + else: + self.tokenizer = LegacySimpleTokenizer(specialTokens = self.textEmbeddingsTokens) + else: + if self.textEmbeddings is None: + self.tokenizer = SimpleTokenizer() + else: + self.tokenizer = SimpleTokenizer(specialTokens = self.textEmbeddingsTokens) + + ### Step 5: Create Models ### + """ + We need to create empty models before we can compile them with + the weights of the trained models. + First, let's check for pytorch weights. If given, we will load them later. + If not, then we're loading in a pre-compiled model OR weights made for TensorFlow + """ + + ## Step 5.1: Create weightless models ## + if controlNet[0] == True: + keras_print("\nUsing ControlNet",controlNet[1]) + + text_encoder, diffusion_model, decoder, encoder, control_net = CreateModels( + self.imageHeight, + self.imageWidth, + preCompiled = None, # If not None, then we're passing on Keras weights ".h5" + legacy = legacy, + addedTokens = self.textEmbeddings, + useControlNet = [controlNet[0]], + device = self.device + ) + + ## Step 5.2 Create object/class variables that point to the compiled models + self.text_encoder = text_encoder + self.diffusion_model = diffusion_model + self.decoder = decoder + self.encoder = encoder + self.controlNet = control_net + + ## Step 5.4: Load Weights + # NOTE: must be done after creating models + self.weights = weights + + self.setWeights(weights, VAE) + + ### Step 6: Load Text Embedding Weights ### + if self.textEmbeddings is not None: + if legacy is True: + CLIP = CLIPTextTransformer + else: + CLIP = OpenCLIPTextTransformer + self.text_encoder = textEmbeddingTools.loadTextEmbeddingWeight( + textEncoder = text_encoder, + CLIP = CLIP, + maxTextLength = MAX_TEXT_LEN, + embeddings = self.textEmbeddings, + legacy = legacy + ) + + ### Step 7: Load ControlNet Weights ### + if controlNet[0] == True: + if ".safetensors" in controlNet[1]: + loadWeightsFromSafeTensor( + self, + controlNet[1], # Which weights to load, in this case maybe all four models + legacy, # Which version of Stable Diffusion + ['controlNet'] # Which specific Models to load + ) + elif ".pth" in controlNet[1]: + loadWeightsFromPytorchCKPT( + self, + controlNet[1], # Which weights to load, in this case maybe all four models + legacy, # Which version of Stable Diffusion + ['controlNet'] # Which specific Models to load + ) + + ### Step 8: Compile Models ### + self.jitCompile = jit_compile + self.compileModels(optimizer, self.jitCompile) + + ## Cache + self.prompt = None + self.negativePrompt = None + self.encodedPrompt = None + self.encodedNegativePrompt = None + self.batch_size = None + self.controlNetCache = None + + def compileModels( + self, + optimizer = "nadam", + jitCompile = False + ): + modules = ['text_encoder', 'diffusion_model', 'decoder', 'encoder' ] + + if jitCompile is True: + keras_print("\nCompiling models with XLA (Accelerated Linear Algebra):") + else: + keras_print("\nCompiling models") + + with keras.device(self.device): + for module in modules: + getattr(self, module).compile( + optimizer = keras.optimizers.Adam(), + jit_compile = jitCompile + ) + print(module,"compiled.") + + """ + Generate and image, the key function + """ + def generate( + self, + prompt, + negativePrompt = None, + batch_size = 1, + num_steps = 25, + unconditional_guidance_scale = 7.5, + temperature = 1, + seed = None, + input_image = None, # expecting file path as a string or np.ndarray + input_image_strength = 0.5, + input_mask = None, # expecting a file path as a string + sampler = None, + controlNetStrength = 1, + controlNetImage = None, + controlNetCache = False, + vPrediction = False + ): + with keras.device(self.device): + ## Memory Efficiency + # Clear up tensorflow memory + keras_print("\n...cleaning memory...") + keras.backend.clear_session() + gc.collect() + + keras_print("...getting to work...") + + ### Step 1: Cache Prompts + if self.prompt != prompt: # New prompt? + # Create prompt cache + self.prompt = prompt + self.encodedPrompt = None + + if self.negativePrompt != negativePrompt: # New negative prompt? + # Create negative prompt cache + self.negativePrompt = negativePrompt + self.encodedNegativePrompt = None + + if self.batch_size != batch_size: # New batch size? + # clear prompt caches if batch_size has changed + self.encodedPrompt = None + self.encodedNegativePrompt = None + self.batch_size = batch_size + + ### Step 2: Tokenize prompts + # the tokenized prompts are AKA "starting context" + # we'll also tokenize the negative prompt, the "unconditional context" + + if self.encodedPrompt is None: + # No cached encoded prompt exists + keras_print("\n...tokenizing prompt...") + + if self.textEmbeddings is not None: + keras_print("...checking for text embeddings...") + prompt = textEmbeddingTools.injectTokens( + prompt = prompt, + embeddings = self.textEmbeddings + ) + + phrase, pos_ids = self.encodeText(prompt, batch_size, self.legacy) + + keras_print("...encoding the tokenized prompt...") + context = self.text_encoder( + [phrase, pos_ids], + training = False + ) + + # Cache encoded prompt + self.encodedPrompt = context + else: + # Load cached encoded prompt + keras_print("...using cached encoded prompt...") + context = self.encodedPrompt + + if self.encodedNegativePrompt is None: + keras_print("...tokenizing negative prompt...") + if negativePrompt is None: + # Encoding text requires a string variable + negativePrompt = "" + + if self.textEmbeddings is not None: + keras_print("...checking for text embeddings...") + negativePrompt = textEmbeddingTools.injectTokens( + prompt = negativePrompt, + embeddings = self.textEmbeddings + ) + + unconditional_tokens, pos_ids = self.encodeText(negativePrompt, batch_size, self.legacy) + + keras_print("...encoding the tokenized negative prompt...") + unconditionalContext = self.text_encoder( + [unconditional_tokens, pos_ids], + training = False + ) + + # Cache encoded negative prompt + self.encodedNegativePrompt = unconditionalContext + else: + keras_print("...using cached encoded negative prompt...") + unconditionalContext = self.encodedNegativePrompt + + ### Step 3: Prepare the input image, if it was given + ## If given, we're expecting an np.ndarry + input_image_tensor = None + if input_image is not None: + + if isinstance(input_image, np.ndarray): + print("...received NumPy Array...") + print(input_image.shape) + + input_image = keras.ops.convert_to_tensor(input_image, dtype = keras.config.floatx()) + + # Resize the image to self.imageHeight x self.imageWidth + input_image = keras.ops.image.resize(input_image, [self.imageHeight, self.imageWidth]) + + inputImageArray = keras.initializers.Constant(input_image, dtype = keras.config.floatx()) + inputImageArray = keras.ops.expand_dims(input_image[..., :3], axis = 0) + input_image_tensor = keras.ops.cast((inputImageArray / 255.0) * 2 - 1, self.dtype) + + print(input_image_tensor.shape) + #displayImage(input_image_tensor, name = "1preppedImage") + elif isinstance(input_image, Array): + print("...received jax.Array (JAX Array)...") + input_image_tensor = input_image + #displayImage(input_image_tensor, name = "1preppedImage") + + ### Step 4: Prepare the image mask, if it was given + if type(input_mask) is str: + print("...preparing input mask...") + input_mask = Image.open(input_mask) + input_mask = input_mask.resize((self.imageWidth, self.imageHeight)) + input_mask_array = np.array(input_mask, dtype = np.float32)[None,...,None] + input_mask_array = input_mask_array / 255.0 + + latent_mask = input_mask.resize((self.imageWidth // 8, self.imageHeight // 8)) + latent_mask = np.array(latent_mask, dtype = np.float32)[None,...,None] + latent_mask = 1 - (latent_mask.astype("float") / 255.0) + latent_mask_tensor = keras.ops.cast(keras.ops.repeat(latent_mask, batch_size , axis = 0), self.dtype) + else: + latent_mask_tensor = None + + ### Step 5: Create a random seed if one is not provided + if seed is None: + keras_print("...generating random seed...") + seed = random.randint(1000, sys.maxsize) + seed = int(seed) + else: + seed = int(seed) + + ### Step 6: Create time steps + keras_print("...creating time steps...") + timesteps = keras.ops.arange(1, 1000, 1000 // num_steps) + + ### Step 7: Load Sampler and: + ### Step 8: Start Diffusion + if sampler == "DPMSolver": + keras_print("...using DPM Solver...\n...starting sampler...") + + alphasCumprod = keras.initializers.Constant(_ALPHAS_CUMPROD) + + noiseScheduler = DPMSolver.NoiseScheduler( + beta_schedule = "scaled_linear" + ) + + print("...starting diffusion...\n...this solver not supported yet!\nDividing by zero now:\n") + + x = 5 / 0 + else: + if sampler is None: keras_print("...no sampler given...") + + # ControlNet + # Parameters: [0]Use ControlNet, [1] Input Image, [2]Strength, [3] Cache Input + if self.controlNet is not None: + controlNetImage = [keras.initializers.Constant(controlNetImage[0].copy(), dtype = keras.config.floatx()) / 255.0] + if controlNetCache is False: + self.controlNetCache = None + if type(self.controlNetCache) is dict: + if len(self.controlNetCache["unconditional"]) != timesteps: + keras_print("Incompatible cache!") + self.controlNetCache = None + controlNetParamters = [True, controlNetImage, controlNetStrength, self.controlNetCache] + else: + controlNetParamters = [False, None, 1, None] + + # Create Sampler + sampler = BasicSampler( + model = self, + timesteps = timesteps, + batchSize = batch_size, + seed = seed, + inputImage = input_image_tensor, + inputMask = latent_mask_tensor, + inputImageStrength = input_image_strength, + temperature = temperature, + AlphasCumprod = _ALPHAS_CUMPROD, + controlNetInput = controlNetParamters[1] # Input Image, assuming pre-processed + ) + + if vPrediction is True: + keras_print("...using v-prediction...") + + # Sample, create image essentially + latentImage, self.controlNetCache = sampler.sample( + context, + unconditionalContext, + unconditional_guidance_scale, + controlNet = [controlNetParamters[0], controlNetParamters[2], controlNetParamters[3]], # [0]Use Control Net, [2]Strength, [3]Cache + vPrediction = vPrediction, + device = self.device + ) + + ### Step 9: Decoding stage + keras_print("\n...decoding latent image...") + decoded = self.decoder( + latentImage, + training = False + ) + decoded = ((decoded + 1) / 2) * 255 + + ### Step 10: Merge inpainting result of input mask with original image + if input_mask is not None: + decoded = inputImageArray * (1-input_mask_array) + np.array(decoded) * input_mask_array + + ### Memory cleanup + gc.collect() + + ### Step 11: return final image as an array + return np.clip(decoded, 0, 255).astype("uint8") + + def changePolicy(self, policy): + + if policy == "mixed_float16": + #self.dtype = tf.float16 + if keras.mixed_precision.global_policy().name != 'mixed_float16': + print("\n...using mixed precision...") + keras.mixed_precision.set_global_policy('mixed_float16') + #self.dtype = tf.float16 + + if policy == "float32": + #self.dtype = keras.config.floatx() + if keras.mixed_precision.global_policy().name != 'float32': + print("\n...using regular precision...") + keras.mixed_precision.set_global_policy('float32') + #self.dtype = keras.config.floatx() + + def encodeText( + self, + prompt, + batch_size, + legacy + ): + TextLimit = MAX_TEXT_LEN - 1 + with keras.device(self.device): + if legacy is True: + # First, encode the prompt + inputs = self.tokenizer.encode(prompt) + # Then check the inputs length and truncate if too long + if len(inputs) > TextLimit: + keras_print("Prompt is too long (should be less than 77 words). Truncating down to 77 words...") + inputs = inputs[:TextLimit] + + """## Create numpy array with the inputs + # Phrase - aka the prompt + phrase = [49406] + inputs + [49407] * (TextLimit - len(inputs)) + phrase = np.array(phrase)[None].astype("int32") + phrase = np.repeat(phrase, batch_size, axis = 0) + + # Position ID + pos_ids = np.array(list(range(77)))[None].astype("int32") + pos_ids = np.repeat(pos_ids, batch_size, axis = 0)""" + + # Phrase - aka the prompt + phrase = keras.ops.concatenate([[49406], inputs, [49407] * (TextLimit - len(inputs))], axis=0) + phrase = keras.ops.expand_dims(phrase, axis=0) + phrase = keras.ops.repeat(phrase, batch_size, axis=0) + phrase = keras.ops.cast(phrase, dtype="int32") + + # Position ID + pos_ids = keras.ops.expand_dims(keras.ops.arange(77), axis=0) + pos_ids = keras.ops.repeat(pos_ids, batch_size, axis=0) + pos_ids = keras.ops.cast(pos_ids, dtype="int32") + else: + # First, encode the prompt + TextLimit += 1 + if isinstance(prompt, str): + inputs = [prompt] + # Then tokenize the prompt + startOfToken = self.tokenizer.encoder[""] + endOfToken = self.tokenizer.encoder[""] + allTokens = [[startOfToken] + self.tokenizer.encode(input) + [endOfToken] for input in inputs] + # Create the empty tensor/numpy array to load the tokens into + phrase = np.zeros((len(allTokens), TextLimit), dtype = np.int32) + + for i, tokens in enumerate(allTokens): + if len(tokens) > TextLimit: + tokens = tokens[:TextLimit] # Truncate + tokens[-1] = endOfToken + phrase[i, :len(tokens)] = np.array(tokens) + + phrase = np.repeat(phrase, batch_size, axis = 0) + + pos_ids = np.array(list(range(TextLimit)))[None].astype("int32") + pos_ids = np.repeat(pos_ids, batch_size, axis = 0) + + return phrase, pos_ids + + def setWeights(self, weights, VAE = "Original"): + self.weights = weights + # Load weights for VAE models, if given + if VAE != "Original": + if ".ckpt" in VAE: + loadWeightsFromPytorchCKPT( + self, + VAE, # Which weights to load, in this case weights for VAE + self.legacy, # Which version of Stable Diffusion + ['decoder', 'encoder'], # Models to load + True + ) + elif ".safetensors" in VAE: + loadWeightsFromSafeTensor( + self, + VAE, # Which weights to load, in this case weights for VAE + self.legacy, # Which version of Stable Diffusion + ['decoder', 'encoder'], # Models to load + True + ) + else: + loadWeightsFromKeras( + self, + VAE, + VAEOnly = True + ) + + # Load all weights + if ".ckpt" in self.weights: + if VAE == "Original": # Load all weights from PyTorch .ckpt + modules = ['text_encoder', 'diffusion_model', 'decoder', 'encoder' ] + else: # only load weights for the text encoder and diffusion model if VAE was given + modules = ['text_encoder', 'diffusion_model'] + loadWeightsFromPytorchCKPT( + self, + self.weights, # Which weights to load, in this case maybe all four models + self.legacy, # Which version of Stable Diffusion + modules # Which specific Models to load + ) + elif ".safetensors" in self.weights: + if VAE == "Original": # Load all weights from PyTorch .ckpt + modules = ['text_encoder', 'diffusion_model', 'decoder', 'encoder' ] + else: # only load weights for the text encoder and diffusion model if VAE was given + modules = ['text_encoder', 'diffusion_model'] + loadWeightsFromSafeTensor( + self, + self.weights, # Which weights to load, in this case maybe all four models + self.legacy, # Which version of Stable Diffusion + modules # Which specific Models to load + ) + else: + if VAE == "Original": + loadWeightsFromKeras( + self, + self.weights, + VAEOnly = False + ) + else: + loadWeightsFromKeras( + self, + self.weights, + VAEOnly = VAE + ) + +### Functions ### + +def CreateModels( + imageHeight = 512, + imageWidth = 512, + preCompiled = None, + legacy = True, + addedTokens = 0, + useControlNet = [False], + device = None +): + with keras.device(device): + # Memory Clean up + keras.backend.clear_session() + gc.collect() + + controlNet = None + + if legacy is True: + # Are we using Pre-Stable Diffusion 2.0? + + keras_print("\nCreating models in legacy mode...") + + # Create Text Encoder model + input_word_ids = keras.layers.Input(shape = (MAX_TEXT_LEN,), dtype = "int32") + input_pos_ids = keras.layers.Input(shape = (MAX_TEXT_LEN,), dtype = "int32") + embeds = CLIPTextTransformer()([input_word_ids, input_pos_ids]) + text_encoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) + keras_print("Created text encoder model") + + if useControlNet[0] is False: + # Create Diffusion model + diffusion_model = DiffusionModel( + imageHeight, + imageWidth, + MAX_TEXT_LEN + ) + keras_print("Created diffusion model") + else: + # Create seperate control net model + controlNet = ControlNetModel( + imageHeight, + imageWidth, + MAX_TEXT_LEN + ) + + keras_print("Created ControlNet Model") + + # Create Diffusion model + diffusion_model = ControlDiffusionModel( + imageHeight, + imageWidth, + MAX_TEXT_LEN + ) + keras_print("Created diffusion model") + + # Create Decoder model + decoder = Decoder( + img_height = imageHeight, + img_width = imageWidth, + ) + keras_print("Created decoder model") + + # Create Image Encoder model + encoder = ImageEncoder( + img_height = imageHeight, + img_width = imageWidth + ) + keras_print("Created encoder model") + + else: + # We're using SD 2.0 and newer + + print("\nCreating models in contemporary mode...") + + # Create Text Encoder model + input_word_ids = keras.layers.Input(shape = (MAX_TEXT_LEN,), dtype = "int32") + input_pos_ids = keras.layers.Input(shape = (MAX_TEXT_LEN,), dtype = "int32") + embeds = OpenCLIPTextTransformer()([input_word_ids, input_pos_ids]) + text_encoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) + print("Created text encoder model") + + # Create Diffusion model + diffusion_model = DiffusionModelV2( + imageHeight, + imageWidth, + MAX_TEXT_LEN + ) + print("Created diffusion model") + + # Create Decoder model + decoder = Decoder( + img_height = imageHeight, + img_width = imageWidth, + ) + print("Created decoder model") + + # Create Image Encoder model + encoder = ImageEncoder( + img_height = imageHeight, + img_width = imageWidth + ) + print("Created encoder model") + + # return created models + return text_encoder, diffusion_model, decoder , encoder, controlNet + +def loadWeightsFromKeras( + models, + weightsPath, + VAEOnly = False +): + keras_print("\nLoading Keras weights for:", weightsPath) + textEncoderWeights = weightsPath + "/text_encoder.h5" + diffusionModelWeights = weightsPath + "/diffusion_model.h5" + imageEncoderWeights = weightsPath + "/encoder.h5" + decoderWeights = weightsPath + "/decoder.h5" + + if VAEOnly is False: + models.text_encoder.load_weights(textEncoderWeights) + keras_print("...Text Encoder weights loaded!") + models.diffusion_model.load_weights(diffusionModelWeights) + keras_print("...diffusion model weights loaded") + models.encoder.load_weights(imageEncoderWeights) + keras_print("...Image Encoder weights loaded!") + models.decoder.load_weights(decoderWeights) + keras_print("...Decoder weights loaded!") + keras_print("All weights loaded!") + +def loadWeightsFromPytorchCKPT( + model, + pytorch_ckpt_path, + legacy = True, + moduleName = ['text_encoder', 'diffusion_model', 'decoder', 'encoder' ], + VAEoverride = False + ): + print("\nLoading pytorch checkpoint " + pytorch_ckpt_path) + pytorchWeights = torch.load(pytorch_ckpt_path, map_location = "mps") + if legacy is True: + ## Legacy Mode + print("...loading pytroch weights in legacy mode...") + for module in moduleName: + module_weights = [] + if module == "text_encoder": + module = "text_encoder_legacy" + for i , (key , perm ) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if VAEoverride is True: + key = key.replace("first_stage_model.","") + if 'state_dict' in pytorchWeights: + weight = pytorchWeights['state_dict'][key].detach().numpy() + else: + weight = pytorchWeights[key].detach().numpy() + if perm is not None: + weight = np.transpose(weight , perm) + module_weights.append(weight) + if module == "text_encoder_legacy": + module = "text_encoder" + + getattr(model, module).set_weights(module_weights) + + print("Loaded %d pytorch weights for %s"%(len(module_weights) , module)) + else: + ## Contemporary Mode + print("...loading pytorch weights in contemporary mode...") + for module in moduleName: + module_weights = [] + in_projWeightConversion = [] + in_projBiasConversion = [] + for i , (key , perm ) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if "in_proj" not in key: + if VAEoverride is True: + key = key.replace("first_stage_model.","") + weight = pytorchWeights['state_dict'][key].detach().numpy() + + if module == "diffusion_model": + if "proj_in.weight" in key or "proj_out.weight" in key: + #print(i+1," Overriding premuation from constants:\n",key) + # This is so the constants.py "diffusion_model" dictionary keeps its legacy state + perm = (1,0) + + if perm is not None: + weight = np.transpose(weight , perm ) + module_weights.append(weight) + else: + if module == "text_encoder": + # "in_proj" layer of SD2.x is a matrix multiplcation of the query, key, and value layers of SD1.4/5 + # We will slice this layer into the the three vectors + if "weight" in key: + # Get the in_proj.weight + originalWeight = pytorchWeights['state_dict'][key].float().numpy() + + queryWeight = originalWeight[:1024, ...] + queryWeight = np.transpose(queryWeight, (1,0)) + + keyWeight = originalWeight[1024:2048, ...] + keyWeight = np.transpose(keyWeight, (1,0)) + + valueWeight = originalWeight[2048:, ...] + valueWeight = np.transpose(valueWeight, (1,0)) + + # Clear local variable to carry forward for bias + in_projWeightConversion = [] + + in_projWeightConversion.append(queryWeight) # Query states + in_projWeightConversion.append(keyWeight) # Key states + in_projWeightConversion.append(valueWeight) # Value states + elif "bias" in key: + originalBias = pytorchWeights['state_dict'][key].float().numpy() + + queryBias = originalBias[:1024] + + keyBias = originalBias[1024:2048] + + valueBias = originalBias[2048:] + + # Clear local variable to carry forward for bias + in_projBiasConversion = [] + + in_projBiasConversion.append(queryBias) # Query states + in_projBiasConversion.append(keyBias) # Key states + in_projBiasConversion.append(valueBias) # Value states + + # add the converted weights/biases in the correct order + # Query + module_weights.append(in_projWeightConversion[0]) + module_weights.append(in_projBiasConversion[0]) + # Key + module_weights.append(in_projWeightConversion[1]) + module_weights.append(in_projBiasConversion[1]) + # Value + module_weights.append(in_projWeightConversion[2]) + module_weights.append(in_projBiasConversion[2]) + + print("Loading weights for ", module) + + getattr(model, module).set_weights(module_weights) + print("Loaded %d pytorch weights for %s"%(len(module_weights) , module)) + + ## Memory Clean up + del pytorchWeights + +def loadWeightsFromSafeTensor( + model, + safetensor_path, + legacy = True, + moduleName = ['text_encoder', 'diffusion_model', 'decoder', 'encoder' ], + VAEoverride = False + ): + print("\nLoading safetensor " + safetensor_path) + safeTensorWeights = load_file(safetensor_path) + if legacy is True: + ## Legacy Mode + print("...loading safetensors weights in legacy mode...") + for module in moduleName: + module_weights = [] + if module == "text_encoder": + module = "text_encoder_legacy" + for i , (key , perm ) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if VAEoverride is True: + key = key.replace("first_stage_model.","") + if 'state_dict' in safeTensorWeights: + weight = safeTensorWeights['state_dict'][key].detach().numpy() + else: + if module == "controlNet": + # Repalce "control_model." in case the safetensor doesn't have that key + key = key.replace("control_model.","") + weight = safeTensorWeights[key].detach().numpy() + if perm is not None: + weight = np.transpose(weight , perm) + module_weights.append(weight) + if module == "text_encoder_legacy": + module = "text_encoder" + + getattr(model, module).set_weights(module_weights) + + print("Loaded %d safetensors weights for %s"%(len(module_weights) , module)) + else: + ## Contemporary Mode + print("...loading safetensors weights in contemporary mode...") + for module in moduleName: + module_weights = [] + in_projWeightConversion = [] + in_projBiasConversion = [] + for i , (key , perm ) in enumerate(PYTORCH_CKPT_MAPPING[module]): + if "in_proj" not in key: + if VAEoverride is True: + key = key.replace("first_stage_model.","") + weight = safeTensorWeights['state_dict'][key].detach().numpy() + + if module == "diffusion_model": + if "proj_in.weight" in key or "proj_out.weight" in key: + #print(i+1," Overriding premuation from constants:\n",key) + # This is so the constants.py "diffusion_model" dictionary keeps its legacy state + perm = (1,0) + + if perm is not None: + weight = np.transpose(weight , perm ) + module_weights.append(weight) + else: + if module == "text_encoder": + # "in_proj" layer of SD2.x is a matrix multiplcation of the query, key, and value layers of SD1.4/5 + # We will slice this layer into the the three vectors + if "weight" in key: + # Get the in_proj.weight + originalWeight = safeTensorWeights['state_dict'][key].float().numpy() + + queryWeight = originalWeight[:1024, ...] + queryWeight = np.transpose(queryWeight, (1,0)) + + keyWeight = originalWeight[1024:2048, ...] + keyWeight = np.transpose(keyWeight, (1,0)) + + valueWeight = originalWeight[2048:, ...] + valueWeight = np.transpose(valueWeight, (1,0)) + + # Clear local variable to carry forward for bias + in_projWeightConversion = [] + + in_projWeightConversion.append(queryWeight) # Query states + in_projWeightConversion.append(keyWeight) # Key states + in_projWeightConversion.append(valueWeight) # Value states + elif "bias" in key: + originalBias = safeTensorWeights['state_dict'][key].float().numpy() + + queryBias = originalBias[:1024] + + keyBias = originalBias[1024:2048] + + valueBias = originalBias[2048:] + + # Clear local variable to carry forward for bias + in_projBiasConversion = [] + + in_projBiasConversion.append(queryBias) # Query states + in_projBiasConversion.append(keyBias) # Key states + in_projBiasConversion.append(valueBias) # Value states + + # add the converted weights/biases in the correct order + # Query + module_weights.append(in_projWeightConversion[0]) + module_weights.append(in_projBiasConversion[0]) + # Key + module_weights.append(in_projWeightConversion[1]) + module_weights.append(in_projBiasConversion[1]) + # Value + module_weights.append(in_projWeightConversion[2]) + module_weights.append(in_projBiasConversion[2]) + + print("Loading weights for ", module) + + getattr(model, module).set_weights(module_weights) + print("Loaded %d safetensors weights for %s"%(len(module_weights) , module)) + + ## Memory Clean up + del safeTensorWeights + + +def displayImage(input_image_tensor, name = "image"): + # Assuming input_image_tensor is a TensorFlow tensor representing the image + # Remove the batch dimension + input_image_tensor = keras.ops.squeeze(input_image_tensor, axis = 0) + + # Convert the tensor to a NumPy array + input_image_array = input_image_tensor.numpy() + + # Rescale the array to the range [0, 255] + input_image_array = ((input_image_array + 1) / 2.0) * 255.0 + + # Convert the array to uint8 data type + input_image_array = input_image_array.astype('uint8') + + # Display the image using Matplotlib + imageFromBatch = Image.fromarray(input_image_array) + imageFromBatch.save("debug/"+name+".png") \ No newline at end of file diff --git a/stableDiffusionKeras/tools/ReadMe.md b/stableDiffusionKeras/tools/ReadMe.md new file mode 100644 index 0000000..ea9cba8 --- /dev/null +++ b/stableDiffusionKeras/tools/ReadMe.md @@ -0,0 +1,3 @@ +## Tools ## + +These files are helpful tools the TensorFlow pipeline uses. diff --git a/stableDiffusionKeras/tools/__init__.py b/stableDiffusionKeras/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stableDiffusionKeras/tools/textEmbeddings.py b/stableDiffusionKeras/tools/textEmbeddings.py new file mode 100644 index 0000000..370e87a --- /dev/null +++ b/stableDiffusionKeras/tools/textEmbeddings.py @@ -0,0 +1,246 @@ +import numpy as np + +import keras +import torch as torch + +from stableDiffusionKeras.utils import keras_print + + +class Embedding: + """ + This is an object class that stores the loaded Text Embedding + It only needs two key variables to exist: + Name: unique name of embedding + Vector: vector(s) of the embedding + """ + def __init__(self, vector, name, step = None): + self.vector = vector + self.name = name + + # Adjust the vector shape to (x, 768) + # This is for single vector text embeddings, which may come as a (768,) instead of (1,768) + if self.vector.ndim < 2: + if self.vector.shape[0] == 768: # Stable Diffusion 1.4/1.5 + self.vector = self.vector.reshape((1,768)) + elif self.vector.shape[0] == 1024: # Stable Diffusion 2.x + self.vector = self.vector.reshape((1,1024)) + + # Create the unique tokens + if self.vector.shape[0] > 1: + #If we have a multidimensional vector, then we'll split up the token per dimension + self.token = [] + for dimension in range(self.vector.shape[0]): + self.token.append("<" + self.name + "_" + str(dimension) + ">") + self.name = "<" + self.name + ">" + else: + # Single dimension vector, so the token is the name + self.name = "<" + self.name + ">" + self.token = self.name + + # Extra info + self.step = step + self.shape = self.vector.shape + self.vectors = 0 + self.cached_checksum = None + self.sd_checkpoint = None + self.sd_checkpoint_name = None + self.optimizer_state_dict = None + self.filename = self.name + ".pt" + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + "sd_checkpoint": self.sd_checkpoint, + "sd_checkpoint_name": self.sd_checkpoint_name, + } + + print("If I could save, I'd save this:\n",embedding_data) + +def injectTokens( + prompt, + embeddings +): + """ + This code searches the given prompt for any of the given embeddings and replaces it with the proper embedding for the tokenizer + + Only necessary for multi-vector embeddings because we've split the token up per vector. + + For example, if we have a multi-vector embedding like this: + Token: + Vectors: (3,768) + + Then, in the creation of the embedding class, we've automatically created the actual token for the tokenizer: + Token: + Vecors: (3,768) + + So, if we find a multi-vector token, then we replace it's user-friendly token name with the actual token name. For example: + + Prompt: A picture of , painted by Caravaggio + + becomes + + Prompt: A picture of , painted by Caravaggio + """ + prompt = prompt.lower() + foundTokens = 0 + + for embedding in embeddings: + if embedding.name in prompt: + # First, let's prepare the replacement tokens + replacementToken = "" + if type(embedding.token) is str: + replacementToken = embedding.name + elif type(embedding.token) is list: + for token in embedding.token: + replacementToken = replacementToken + " " + token + foundTokens += 1 + prompt = prompt.replace(embedding.name, replacementToken) + + keras_print("...found",foundTokens,"text embedding token(s)...") + + return prompt + +def loadTextEmbedding( + textEmbeddings + ): + """ + Using pytorch, we load in the text embedding weights as numpy arrays and store them in the Embeddings object class + + textEmbeddings REQUIRES a list expecting the first index to have the file path. For example: + + ['models/embeddings/','myEmbedding.pt','myOtherEmbedding.bin','etc.pt'] + + The code then seperates the file path as a variable and uses it to find the embeddings + """ + finalTextEmbeddings = [] + tokensToAdd = [] + # save file path into seperate location + embeddingsPath = textEmbeddings[0] + # delete file path from list + del textEmbeddings[0] + + for textEmbedding in textEmbeddings: + print("\nLoading text embedding " + textEmbedding) + # Load the text embedding file + textEmbeddingFile = torch.load(embeddingsPath + textEmbedding, map_location = "cpu") + + # Debug Info + # print("Data for",textEmbedding,"\n",textEmbeddingFile) + # print(textEmbeddingFile.keys()) # Shows the entire file data, which should be a dictionary + + if "pt" in textEmbedding: + # load the necessary values + stringToToken = textEmbeddingFile["string_to_token"] # Token assigned to vector + stringToParam = textEmbeddingFile["string_to_param"] # The vector(s) + textEmbeddingName = textEmbedding.replace(".pt","") + elif "bin" in textEmbedding: + # load the necessary values + for key, value in textEmbeddingFile.items(): + stringToToken = key # Token assigned to vector + stringToParam = value # The vector + textEmbeddingName = textEmbedding.replace(".bin","") + + # Save the token for finding the vector + if type(stringToToken) is dict: + token = list(stringToToken.keys())[0] # Convert dictionary to a list and then pull the first value + else: + token = stringToToken + + # Save the vector by finding it with the token + if type(stringToToken) is dict: + textEmbeddingVector = stringToParam[token] + else: + textEmbeddingVector = stringToParam + + # Debug info + # print("Weight type:\n",type(textEmbeddingVector)) + # print("Vector shape:\n", textEmbeddingVector.shape) + + # Make the token lowercase + token = textEmbeddingName.lower() + print("Unique Token: ","<"+token+">") + + embedding = Embedding(name = token, vector = textEmbeddingVector.detach().numpy()) + try: + embedding.step = textEmbeddingFile["step"] + embedding.sd_checkpoint_name = textEmbeddingFile["sd_checkpoint_name"] + except Exception as e: + embedding.step = 0 + embedding.sd_checkpoint_name = "N/A" + + finalTextEmbeddings.append(embedding) + + if type(embedding.token) is str: + tokensToAdd.append(embedding.token) + elif type(embedding.token) is list: + tokensToAdd.extend(embedding.token) + + # Memory Clean up + del textEmbeddingFile + + # add file path back to list for re-compiling later, if needed + textEmbeddings.insert(0,embeddingsPath) + + return finalTextEmbeddings, tokensToAdd + +def loadTextEmbeddingWeight( + textEncoder, + CLIP, + maxTextLength, + embeddings, + legacy +): + """ + This code is where the magic happens with Text Embeddings. + We're going to add our text embeddings to the Text Encoder Model + """ + keras_print("\nLoading Text Embedding weights...") + + if legacy == True: + columnLength = 768 + else: + columnLength = 1024 + + # First get the current weights of the text encoder + originalWeights = textEncoder.get_weights() + + # Find the "token_embedding" weights + updatedWeights = originalWeights[0] + successfulTokenCount = 0 + + # Add our token vectors to the "token_embedding" weights + for embedding in embeddings: + if np.size(embedding.vector[0]) != columnLength: + # if our vector column length doesn't match our version of stable diffusion, then skip this embedding + print(embedding.name,"not compatible with current version of Stable Diffusion") + continue + + # Add our vectors to the weights for the "token_embeddings" + updatedWeights = np.vstack((updatedWeights, embedding.vector)) + + # Update our token count, taking multidimensional vectors into account + if type(embedding.token) is list: + successfulTokenCount += len(embedding.token) + else: + successfulTokenCount += 1 + + keras_print("...found all compatible embeddings, total:",successfulTokenCount,"...") + + # Create new Text Encoder model, increasing the size of tokens for the CLIP model + keras_print("...creating new text encoder model with embeddings") + input_word_ids = keras.layers.Input(shape = (maxTextLength,), dtype = "int32") + input_pos_ids = keras.layers.Input(shape = (maxTextLength,), dtype = "int32") + embeds = CLIP(vocabularySize = 49408 + successfulTokenCount)([input_word_ids, input_pos_ids]) + textEncoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) + keras_print("...created text encoder model with", successfulTokenCount,"token(s) added") + + # Update the weights for "token_embedding" and then set the weights of the model + keras_print("...setting updated weights for token_embedding...") + originalWeights[0] = updatedWeights + textEncoder.set_weights(originalWeights) + keras_print("...weights loaded!") + + return textEncoder \ No newline at end of file diff --git a/stableDiffusionKeras/tools/tools.py b/stableDiffusionKeras/tools/tools.py new file mode 100644 index 0000000..00bc3dc --- /dev/null +++ b/stableDiffusionKeras/tools/tools.py @@ -0,0 +1,9 @@ + + +def getWeightsAndNames(model): + # For finding the order of weights + names = [weight.name for layer in model.layers for weight in layer.weights] + weights = model.get_weights() + + for name, weight in zip(names, weights): + keras_print(name,"\n",weight.shape) \ No newline at end of file diff --git a/stableDiffusionKeras/utils.py b/stableDiffusionKeras/utils.py new file mode 100644 index 0000000..24ba047 --- /dev/null +++ b/stableDiffusionKeras/utils.py @@ -0,0 +1,22 @@ +from keras.api import backend + + +def keras_print(*args, **kwargs): + back_end = backend.backend() + if back_end == "tensorflow": + import tensorflow as tf + return keras_print(*args, **kwargs) + elif back_end == "jax": + import jax.debug + return jax.debug.print(*args, **kwargs) + else: + return print(*args, **kwargs) + # print_fn = {"jax": jax.debug.print, + # "tensorflow": keras_print}.get(backend, print) + # "torch" https://pytorch.org/docs/stable/generated/torch.set_printoptions.html ? + # "openvino" + # "numpy" + # return print_fn(*args, **kwargs) + + +__all__ = ["keras_print"]