mirror of
https://github.com/believethehype/nostrdvm.git
synced 2025-03-17 21:31:52 +01:00
adding MLX backend and example for stable diffusion 2.1
from https://github.com/ml-explore/mlx-examples
This commit is contained in:
parent
64710b4d1d
commit
15d3384dce
0
backends/__init__.py
Normal file
0
backends/__init__.py
Normal file
0
backends/mlx/__init__.py
Normal file
0
backends/mlx/__init__.py
Normal file
0
backends/mlx/stable_diffusion/__init__.py
Normal file
0
backends/mlx/stable_diffusion/__init__.py
Normal file
70
backends/mlx/stable_diffusion/clip.py
Normal file
70
backends/mlx/stable_diffusion/clip.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .config import CLIPTextModelConfig
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, model_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||
|
||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
# Add biases to the attention projections to match CLIP
|
||||
self.attention.query_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.key_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.value_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||
|
||||
def __call__(self, x, attn_mask=None):
|
||||
y = self.layer_norm1(x)
|
||||
y = self.attention(y, y, y, attn_mask)
|
||||
x = y + x
|
||||
|
||||
y = self.layer_norm2(x)
|
||||
y = self.linear1(y)
|
||||
y = nn.gelu_approx(y)
|
||||
y = self.linear2(y)
|
||||
x = y + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextModelConfig):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||
self.layers = [
|
||||
CLIPEncoderLayer(config.model_dims, config.num_heads)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||
|
||||
def __call__(self, x):
|
||||
# Extract some shapes
|
||||
B, N = x.shape
|
||||
|
||||
# Compute the embeddings
|
||||
x = self.token_embedding(x)
|
||||
x = x + self.position_embedding.weight[:N]
|
||||
|
||||
# Compute the features from the transformer
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
|
||||
for l in self.layers:
|
||||
x = l(x, mask)
|
||||
|
||||
# Apply the final layernorm and return
|
||||
return self.final_layer_norm(x)
|
48
backends/mlx/stable_diffusion/config.py
Normal file
48
backends/mlx/stable_diffusion/config.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoencoderConfig:
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
latent_channels_out: int = 8
|
||||
latent_channels_in: int = 4
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
scaling_factor: float = 0.18215
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextModelConfig:
|
||||
num_layers: int = 23
|
||||
model_dims: int = 1024
|
||||
num_heads: int = 16
|
||||
max_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetConfig:
|
||||
in_channels: int = 4
|
||||
out_channels: int = 4
|
||||
conv_in_kernel: int = 3
|
||||
conv_out_kernel: int = 3
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2)
|
||||
mid_block_layers: int = 2
|
||||
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||
norm_num_groups: int = 32
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
beta_schedule: str = "scaled_linear"
|
||||
beta_start: float = 0.00085
|
||||
beta_end: float = 0.012
|
||||
num_train_steps: int = 1000
|
292
backends/mlx/stable_diffusion/model_io.py
Normal file
292
backends/mlx/stable_diffusion/model_io.py
Normal file
@ -0,0 +1,292 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open as safetensor_open
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .clip import CLIPTextModel
|
||||
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
|
||||
from .tokenizer import Tokenizer
|
||||
from .unet import UNetModel
|
||||
from .vae import Autoencoder
|
||||
|
||||
|
||||
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
|
||||
_MODELS = {
|
||||
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
|
||||
"stabilityai/stable-diffusion-2-1-base": {
|
||||
"unet_config": "unet/config.json",
|
||||
"unet": "unet/diffusion_pytorch_model.safetensors",
|
||||
"text_encoder_config": "text_encoder/config.json",
|
||||
"text_encoder": "text_encoder/model.safetensors",
|
||||
"vae_config": "vae/config.json",
|
||||
"vae": "vae/diffusion_pytorch_model.safetensors",
|
||||
"diffusion_config": "scheduler/scheduler_config.json",
|
||||
"tokenizer_vocab": "tokenizer/vocab.json",
|
||||
"tokenizer_merges": "tokenizer/merges.txt",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _from_numpy(x):
|
||||
return mx.array(np.ascontiguousarray(x))
|
||||
|
||||
|
||||
def map_unet_weights(key, value):
|
||||
# Map up/downsampling
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map attention layers
|
||||
if "to_k" in key:
|
||||
key = key.replace("to_k", "key_proj")
|
||||
if "to_out.0" in key:
|
||||
key = key.replace("to_out.0", "out_proj")
|
||||
if "to_q" in key:
|
||||
key = key.replace("to_q", "query_proj")
|
||||
if "to_v" in key:
|
||||
key = key.replace("to_v", "value_proj")
|
||||
|
||||
# Map transformer ffn
|
||||
if "ff.net.2" in key:
|
||||
key = key.replace("ff.net.2", "linear3")
|
||||
if "ff.net.0" in key:
|
||||
k1 = key.replace("ff.net.0.proj", "linear1")
|
||||
k2 = key.replace("ff.net.0.proj", "linear2")
|
||||
v1, v2 = np.split(value, 2)
|
||||
|
||||
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
|
||||
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
# Transform the weights from 1x1 convs to linear
|
||||
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
|
||||
return [(key, _from_numpy(value))]
|
||||
|
||||
|
||||
def map_clip_text_encoder_weights(key, value):
|
||||
# Remove prefixes
|
||||
if key.startswith("text_model."):
|
||||
key = key[11:]
|
||||
if key.startswith("embeddings."):
|
||||
key = key[11:]
|
||||
if key.startswith("encoder."):
|
||||
key = key[8:]
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
|
||||
return [(key, _from_numpy(value))]
|
||||
|
||||
|
||||
def map_vae_weights(key, value):
|
||||
# Map up/downsampling
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map attention layers
|
||||
if "to_k" in key:
|
||||
key = key.replace("to_k", "key_proj")
|
||||
if "to_out.0" in key:
|
||||
key = key.replace("to_out.0", "out_proj")
|
||||
if "to_q" in key:
|
||||
key = key.replace("to_q", "query_proj")
|
||||
if "to_v" in key:
|
||||
key = key.replace("to_v", "value_proj")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map the quant/post_quant layers
|
||||
if "quant_conv" in key:
|
||||
key = key.replace("quant_conv", "quant_proj")
|
||||
value = value.squeeze()
|
||||
|
||||
# Map the conv_shortcut to linear
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
|
||||
return [(key, _from_numpy(value))]
|
||||
|
||||
|
||||
def _flatten(params):
|
||||
return [(k, v) for p in params for (k, v) in p]
|
||||
|
||||
|
||||
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
|
||||
dtype = np.float16 if float16 else np.float32
|
||||
with safetensor_open(weight_file, framework="numpy") as f:
|
||||
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
|
||||
model.update(tree_unflatten(weights))
|
||||
|
||||
|
||||
def _check_key(key: str, part: str):
|
||||
if key not in _MODELS:
|
||||
raise ValueError(
|
||||
f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
|
||||
)
|
||||
|
||||
|
||||
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
"""Load the stable diffusion UNet from Hugging Face Hub."""
|
||||
_check_key(key, "load_unet")
|
||||
|
||||
# Download the config and create the model
|
||||
unet_config = _MODELS[key]["unet_config"]
|
||||
with open(hf_hub_download(key, unet_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
n_blocks = len(config["block_out_channels"])
|
||||
model = UNetModel(
|
||||
UNetConfig(
|
||||
in_channels=config["in_channels"],
|
||||
out_channels=config["out_channels"],
|
||||
block_out_channels=config["block_out_channels"],
|
||||
layers_per_block=[config["layers_per_block"]] * n_blocks,
|
||||
num_attention_heads=[config["attention_head_dim"]] * n_blocks
|
||||
if isinstance(config["attention_head_dim"], int)
|
||||
else config["attention_head_dim"],
|
||||
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
)
|
||||
)
|
||||
|
||||
# Download the weights and map them into the model
|
||||
unet_weights = _MODELS[key]["unet"]
|
||||
weight_file = hf_hub_download(key, unet_weights)
|
||||
_load_safetensor_weights(map_unet_weights, model, weight_file, float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
"""Load the stable diffusion text encoder from Hugging Face Hub."""
|
||||
_check_key(key, "load_text_encoder")
|
||||
|
||||
# Download the config and create the model
|
||||
text_encoder_config = _MODELS[key]["text_encoder_config"]
|
||||
with open(hf_hub_download(key, text_encoder_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = CLIPTextModel(
|
||||
CLIPTextModelConfig(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
model_dims=config["hidden_size"],
|
||||
num_heads=config["num_attention_heads"],
|
||||
max_length=config["max_position_embeddings"],
|
||||
vocab_size=config["vocab_size"],
|
||||
)
|
||||
)
|
||||
|
||||
# Download the weights and map them into the model
|
||||
text_encoder_weights = _MODELS[key]["text_encoder"]
|
||||
weight_file = hf_hub_download(key, text_encoder_weights)
|
||||
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
"""Load the stable diffusion autoencoder from Hugging Face Hub."""
|
||||
_check_key(key, "load_autoencoder")
|
||||
|
||||
# Download the config and create the model
|
||||
vae_config = _MODELS[key]["vae_config"]
|
||||
with open(hf_hub_download(key, vae_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = Autoencoder(
|
||||
AutoencoderConfig(
|
||||
in_channels=config["in_channels"],
|
||||
out_channels=config["out_channels"],
|
||||
latent_channels_out=2 * config["latent_channels"],
|
||||
latent_channels_in=config["latent_channels"],
|
||||
block_out_channels=config["block_out_channels"],
|
||||
layers_per_block=config["layers_per_block"],
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
)
|
||||
)
|
||||
|
||||
# Download the weights and map them into the model
|
||||
vae_weights = _MODELS[key]["vae"]
|
||||
weight_file = hf_hub_download(key, vae_weights)
|
||||
_load_safetensor_weights(map_vae_weights, model, weight_file, float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_diffusion_config(key: str = _DEFAULT_MODEL):
|
||||
"""Load the stable diffusion config from Hugging Face Hub."""
|
||||
_check_key(key, "load_diffusion_config")
|
||||
|
||||
diffusion_config = _MODELS[key]["diffusion_config"]
|
||||
with open(hf_hub_download(key, diffusion_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
return DiffusionConfig(
|
||||
beta_start=config["beta_start"],
|
||||
beta_end=config["beta_end"],
|
||||
beta_schedule=config["beta_schedule"],
|
||||
num_train_steps=config["num_train_timesteps"],
|
||||
)
|
||||
|
||||
|
||||
def load_tokenizer(key: str = _DEFAULT_MODEL):
|
||||
_check_key(key, "load_tokenizer")
|
||||
|
||||
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
|
||||
with open(merges_file, encoding="utf-8") as f:
|
||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return Tokenizer(bpe_ranks, vocab)
|
74
backends/mlx/stable_diffusion/sampler.py
Normal file
74
backends/mlx/stable_diffusion/sampler.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from .config import DiffusionConfig
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def _linspace(a, b, num):
|
||||
x = mx.arange(0, num) / (num - 1)
|
||||
return (b - a) * x + a
|
||||
|
||||
|
||||
def _interp(y, x_new):
|
||||
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
||||
x_low = x_new.astype(mx.int32)
|
||||
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
||||
|
||||
y_low = y[x_low]
|
||||
y_high = y[x_high]
|
||||
delta_x = x_new - x_low
|
||||
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
||||
|
||||
return y_new
|
||||
|
||||
|
||||
class SimpleEulerSampler:
|
||||
"""A simple Euler integrator that can be used to sample from our diffusion models.
|
||||
|
||||
The method ``step()`` performs one Euler step from x_t to x_t_prev.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
# Compute the noise schedule
|
||||
if config.beta_schedule == "linear":
|
||||
betas = _linspace(
|
||||
config.beta_start, config.beta_end, config.num_train_steps
|
||||
)
|
||||
elif config.beta_schedule == "scaled_linear":
|
||||
betas = _linspace(
|
||||
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
|
||||
).square()
|
||||
else:
|
||||
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
|
||||
|
||||
alphas = 1 - betas
|
||||
alphas_cumprod = mx.cumprod(alphas)
|
||||
|
||||
self._sigmas = mx.concatenate(
|
||||
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
|
||||
)
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
noise = mx.random.normal(shape, key=key)
|
||||
return (
|
||||
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||
).astype(dtype)
|
||||
|
||||
def sigmas(self, t):
|
||||
return _interp(self._sigmas, t)
|
||||
|
||||
def timesteps(self, num_steps: int, dtype=mx.float32):
|
||||
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
|
||||
return list(zip(steps, steps[1:]))
|
||||
|
||||
def step(self, eps_pred, x_t, t, t_prev):
|
||||
sigma = self.sigmas(t).astype(eps_pred.dtype)
|
||||
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
|
||||
|
||||
dt = sigma_prev - sigma
|
||||
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
|
||||
|
||||
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
|
||||
|
||||
return x_t_prev
|
100
backends/mlx/stable_diffusion/tokenizer.py
Normal file
100
backends/mlx/stable_diffusion/tokenizer.py
Normal file
@ -0,0 +1,100 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import regex
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab):
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return "<|startoftext|>"
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self.vocab[self.bos]
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return "<|endoftext|>"
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self.vocab[self.eos]
|
||||
|
||||
def bpe(self, text):
|
||||
if text in self._cache:
|
||||
return self._cache[text]
|
||||
|
||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
if not unique_bigrams:
|
||||
return unigrams
|
||||
|
||||
# In every iteration try to merge the two most likely bigrams. If none
|
||||
# was merged we are done.
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(
|
||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||
)
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
new_unigrams = []
|
||||
skip = False
|
||||
for a, b in zip(unigrams, unigrams[1:]):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
if (a, b) == bigram:
|
||||
new_unigrams.append(a + b)
|
||||
skip = True
|
||||
|
||||
else:
|
||||
new_unigrams.append(a)
|
||||
|
||||
if not skip:
|
||||
new_unigrams.append(b)
|
||||
|
||||
unigrams = new_unigrams
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
self._cache[text] = unigrams
|
||||
|
||||
return unigrams
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||
# a much more thorough job here but this should suffice for 95% of
|
||||
# cases.
|
||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||
tokens = regex.findall(self.pat, clean_text)
|
||||
|
||||
# Split the tokens according to the byte-pair merge file
|
||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||
|
||||
# Map to token ids and return
|
||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||
if prepend_bos:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
return tokens
|
425
backends/mlx/stable_diffusion/unet.py
Normal file
425
backends/mlx/stable_diffusion/unet.py
Normal file
@ -0,0 +1,425 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .config import UNetConfig
|
||||
|
||||
|
||||
def upsample_nearest(x, scale: int = 2):
|
||||
B, H, W, C = x.shape
|
||||
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||
x = x.reshape(B, H * scale, W * scale, C)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = nn.silu(x)
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dims: int,
|
||||
num_heads: int,
|
||||
hidden_dims: Optional[int] = None,
|
||||
memory_dims: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(model_dims)
|
||||
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
self.attn1.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
memory_dims = memory_dims or model_dims
|
||||
self.norm2 = nn.LayerNorm(model_dims)
|
||||
self.attn2 = nn.MultiHeadAttention(
|
||||
model_dims, num_heads, key_input_dims=memory_dims
|
||||
)
|
||||
self.attn2.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
hidden_dims = hidden_dims or 4 * model_dims
|
||||
self.norm3 = nn.LayerNorm(model_dims)
|
||||
self.linear1 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear2 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear3 = nn.Linear(hidden_dims, model_dims)
|
||||
|
||||
def __call__(self, x, memory, attn_mask, memory_mask):
|
||||
# Self attention
|
||||
y = self.norm1(x)
|
||||
y = self.attn1(y, y, y, attn_mask)
|
||||
x = x + y
|
||||
|
||||
# Cross attention
|
||||
y = self.norm2(x)
|
||||
y = self.attn2(y, memory, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
# FFN
|
||||
y = self.norm3(x)
|
||||
y_a = self.linear1(y)
|
||||
y_b = self.linear2(y)
|
||||
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Transformer2D(nn.Module):
|
||||
"""A transformer model for inputs with 2 spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_dims: int,
|
||||
encoder_dims: int,
|
||||
num_heads: int,
|
||||
num_layers: int = 1,
|
||||
norm_num_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
|
||||
self.proj_in = nn.Linear(in_channels, model_dims)
|
||||
self.transformer_blocks = [
|
||||
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
self.proj_out = nn.Linear(model_dims, in_channels)
|
||||
|
||||
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
|
||||
# Save the input to add to the output
|
||||
input_x = x
|
||||
|
||||
# Perform the input norm and projection
|
||||
B, H, W, C = x.shape
|
||||
x = self.norm(x).reshape(B, -1, C)
|
||||
x = self.proj_in(x)
|
||||
|
||||
# Apply the transformer
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
# Apply the output projection and reshape
|
||||
x = self.proj_out(x)
|
||||
x = x.reshape(B, H, W, C)
|
||||
|
||||
return x + input_x
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
groups: int = 32,
|
||||
temb_channels: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def __call__(self, x, temb=None):
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(nn.silu(temb))
|
||||
|
||||
y = self.norm1(x)
|
||||
y = nn.silu(y)
|
||||
y = self.conv1(y)
|
||||
if temb is not None:
|
||||
y = y + temb[:, None, None, :]
|
||||
y = self.norm2(y)
|
||||
y = nn.silu(y)
|
||||
y = self.conv2(y)
|
||||
|
||||
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UNetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
prev_out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
num_attention_heads: int = 8,
|
||||
cross_attention_dim=1280,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
add_cross_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Prepare the in channels list for the resnets
|
||||
if prev_out_channels is None:
|
||||
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
|
||||
else:
|
||||
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
|
||||
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
|
||||
in_channels_list = [
|
||||
a + b for a, b in zip(in_channels_list, res_channels_list)
|
||||
]
|
||||
|
||||
# Add resnet blocks that also process the time embedding
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ic,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for ic in in_channels_list
|
||||
]
|
||||
|
||||
# Add optional cross attention layers
|
||||
if add_cross_attention:
|
||||
self.attentions = [
|
||||
Transformer2D(
|
||||
in_channels=out_channels,
|
||||
model_dims=out_channels,
|
||||
num_heads=num_attention_heads,
|
||||
num_layers=transformer_layers_per_block,
|
||||
encoder_dims=cross_attention_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
encoder_x=None,
|
||||
temb=None,
|
||||
attn_mask=None,
|
||||
encoder_attn_mask=None,
|
||||
residual_hidden_states=None,
|
||||
):
|
||||
output_states = []
|
||||
|
||||
for i in range(len(self.resnets)):
|
||||
if residual_hidden_states is not None:
|
||||
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
|
||||
|
||||
x = self.resnets[i](x, temb)
|
||||
|
||||
if "attentions" in self:
|
||||
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
output_states.append(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = self.downsample(x)
|
||||
output_states.append(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
output_states.append(x)
|
||||
|
||||
return x, output_states
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""The conditional 2D UNet model that actually performs the denoising."""
|
||||
|
||||
def __init__(self, config: UNetConfig):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
config.in_channels,
|
||||
config.block_out_channels[0],
|
||||
config.conv_in_kernel,
|
||||
padding=(config.conv_in_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
self.timesteps = nn.SinusoidalPositionalEncoding(
|
||||
config.block_out_channels[0],
|
||||
max_freq=1,
|
||||
min_freq=math.exp(
|
||||
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
|
||||
),
|
||||
scale=1.0,
|
||||
cos_first=True,
|
||||
full_turns=False,
|
||||
)
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
config.block_out_channels[0],
|
||||
config.block_out_channels[0] * 4,
|
||||
)
|
||||
|
||||
# Make the downsampling blocks
|
||||
block_channels = [config.block_out_channels[0]] + list(
|
||||
config.block_out_channels
|
||||
)
|
||||
self.down_blocks = [
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
num_layers=config.layers_per_block[i],
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||
add_upsample=False,
|
||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(
|
||||
zip(block_channels, block_channels[1:])
|
||||
)
|
||||
]
|
||||
|
||||
# Make the middle block
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
Transformer2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
model_dims=config.block_out_channels[-1],
|
||||
num_heads=config.num_attention_heads[-1],
|
||||
num_layers=config.transformer_layers_per_block[-1],
|
||||
encoder_dims=config.cross_attention_dim[-1],
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
]
|
||||
|
||||
# Make the upsampling blocks
|
||||
block_channels = (
|
||||
[config.block_out_channels[0]]
|
||||
+ list(config.block_out_channels)
|
||||
+ [config.block_out_channels[-1]]
|
||||
)
|
||||
self.up_blocks = [
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
prev_out_channels=prev_out_channels,
|
||||
num_layers=config.layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=(i > 0),
|
||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
)
|
||||
for i, (in_channels, out_channels, prev_out_channels) in reversed(
|
||||
list(
|
||||
enumerate(
|
||||
zip(block_channels, block_channels[1:], block_channels[2:])
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
config.norm_num_groups,
|
||||
config.block_out_channels[0],
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv_out = nn.Conv2d(
|
||||
config.block_out_channels[0],
|
||||
config.out_channels,
|
||||
config.conv_out_kernel,
|
||||
padding=(config.conv_out_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
|
||||
|
||||
# Compute the time embeddings
|
||||
temb = self.timesteps(timestep).astype(x.dtype)
|
||||
temb = self.time_embedding(temb)
|
||||
|
||||
# Preprocess the input
|
||||
x = self.conv_in(x)
|
||||
|
||||
# Run the downsampling part of the unet
|
||||
residuals = [x]
|
||||
for block in self.down_blocks:
|
||||
x, res = block(
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
)
|
||||
residuals.extend(res)
|
||||
|
||||
# Run the middle part of the unet
|
||||
x = self.mid_blocks[0](x, temb)
|
||||
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
x = self.mid_blocks[2](x, temb)
|
||||
|
||||
# Run the upsampling part of the unet
|
||||
for block in self.up_blocks:
|
||||
x, _ = block(
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
residual_hidden_states=residuals,
|
||||
)
|
||||
|
||||
# Postprocess the output
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
268
backends/mlx/stable_diffusion/vae.py
Normal file
268
backends/mlx/stable_diffusion/vae.py
Normal file
@ -0,0 +1,268 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .config import AutoencoderConfig
|
||||
from .unet import ResnetBlock2D, upsample_nearest
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""A single head unmasked attention for use with the VAE."""
|
||||
|
||||
def __init__(self, dims: int, norm_groups: int = 32):
|
||||
super().__init__()
|
||||
|
||||
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
|
||||
self.query_proj = nn.Linear(dims, dims)
|
||||
self.key_proj = nn.Linear(dims, dims)
|
||||
self.value_proj = nn.Linear(dims, dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x):
|
||||
B, H, W, C = x.shape
|
||||
|
||||
y = self.group_norm(x)
|
||||
|
||||
queries = self.query_proj(y).reshape(B, H * W, C)
|
||||
keys = self.key_proj(y).reshape(B, H * W, C)
|
||||
values = self.value_proj(y).reshape(B, H * W, C)
|
||||
|
||||
scale = 1 / math.sqrt(queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 2, 1)
|
||||
attn = mx.softmax(scores, axis=-1)
|
||||
y = (attn @ values).reshape(B, H, W, C)
|
||||
|
||||
y = self.out_proj(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EncoderDecoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add the resnet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = self.downsample(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Implements the encoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
channels = [block_out_channels[0]] + list(block_out_channels)
|
||||
self.down_blocks = [
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=i < len(block_out_channels) - 1,
|
||||
add_upsample=False,
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
|
||||
]
|
||||
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[-1], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for l in self.down_blocks:
|
||||
x = l(x)
|
||||
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Implements the decoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
channels = list(reversed(block_out_channels))
|
||||
channels = [channels[0]] + channels
|
||||
self.up_blocks = [
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=i < len(block_out_channels) - 1,
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[0], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
for l in self.up_blocks:
|
||||
x = l(x)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Autoencoder(nn.Module):
|
||||
"""The autoencoder that allows us to perform diffusion in the latent space."""
|
||||
|
||||
def __init__(self, config: AutoencoderConfig):
|
||||
super().__init__()
|
||||
|
||||
self.latent_channels = config.latent_channels_in
|
||||
self.scaling_factor = config.scaling_factor
|
||||
self.encoder = Encoder(
|
||||
config.in_channels,
|
||||
config.latent_channels_out,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
config.latent_channels_in,
|
||||
config.out_channels,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block + 1,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
)
|
||||
|
||||
self.quant_proj = nn.Linear(
|
||||
config.latent_channels_out, config.latent_channels_out
|
||||
)
|
||||
self.post_quant_proj = nn.Linear(
|
||||
config.latent_channels_in, config.latent_channels_in
|
||||
)
|
||||
|
||||
def decode(self, z):
|
||||
return self.decoder(self.post_quant_proj(z))
|
||||
|
||||
def __call__(self, x, key=None):
|
||||
x = self.encoder(x)
|
||||
x = self.quant_proj(x)
|
||||
|
||||
mean, logvar = x.split(2, axis=-1)
|
||||
std = mx.exp(0.5 * logvar)
|
||||
z = mx.random.normal(mean.shape, key=key) * std + mean
|
||||
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
11
main.py
11
main.py
@ -5,7 +5,8 @@ import dotenv
|
||||
from nostr_dvm.bot import Bot
|
||||
from nostr_dvm.tasks import videogeneration_replicate_svd, imagegeneration_replicate_sdxl, textgeneration_llmlite, \
|
||||
trending_notes_nostrband, discovery_inactive_follows, translation_google, textextraction_pdf, \
|
||||
translation_libretranslate, textextraction_google, convert_media, imagegeneration_openai_dalle, texttospeech
|
||||
translation_libretranslate, textextraction_google, convert_media, imagegeneration_openai_dalle, texttospeech, \
|
||||
imagegeneration_mlx, advanced_search, textextraction_whisper_mlx
|
||||
from nostr_dvm.utils.admin_utils import AdminConfig
|
||||
from nostr_dvm.utils.backend_utils import keep_alive
|
||||
from nostr_dvm.utils.definitions import EventDefinitions
|
||||
@ -138,6 +139,14 @@ def playground():
|
||||
bot_config.SUPPORTED_DVMS.append(tts)
|
||||
tts.run()
|
||||
|
||||
from sys import platform
|
||||
if platform == "darwin":
|
||||
# Test with MLX for OSX M1/M2/M3 chips
|
||||
mlx = imagegeneration_mlx.build_example("SD with MLX", "mlx_sd", admin_config)
|
||||
bot_config.SUPPORTED_DVMS.append(mlx)
|
||||
mlx.run()
|
||||
|
||||
|
||||
# Run the bot
|
||||
Bot(bot_config)
|
||||
# Keep the main function alive for libraries that require it, like openai
|
||||
|
186
nostr_dvm/tasks/imagegeneration_mlx.py
Normal file
186
nostr_dvm/tasks/imagegeneration_mlx.py
Normal file
@ -0,0 +1,186 @@
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface
|
||||
from nostr_dvm.utils.admin_utils import AdminConfig
|
||||
from nostr_dvm.utils.definitions import EventDefinitions
|
||||
from nostr_dvm.utils.dvmconfig import DVMConfig, build_default_config
|
||||
from nostr_dvm.utils.nip89_utils import NIP89Config, check_and_set_d_tag
|
||||
from nostr_dvm.utils.output_utils import upload_media_to_hoster
|
||||
from nostr_dvm.utils.zap_utils import get_price_per_sat
|
||||
|
||||
"""
|
||||
This File contains a Module to generate an Image on replicate and receive results back.
|
||||
|
||||
Accepted Inputs: Prompt (text)
|
||||
Outputs: An url to an Image
|
||||
Params:
|
||||
"""
|
||||
|
||||
|
||||
class ImageGenerationMLX(DVMTaskInterface):
|
||||
KIND: int = EventDefinitions.KIND_NIP90_GENERATE_IMAGE
|
||||
TASK: str = "text-to-image"
|
||||
FIX_COST: float = 120
|
||||
dependencies = [("nostr-dvm", "nostr-dvm"),
|
||||
("mlx", "mlx"),
|
||||
("safetensors", "safetensors"),
|
||||
("huggingface-hub", "huggingface-hub"),
|
||||
("regex", "regex"),
|
||||
("tqdm", "tqdm"),
|
||||
]
|
||||
|
||||
def __init__(self, name, dvm_config: DVMConfig, nip89config: NIP89Config,
|
||||
admin_config: AdminConfig = None, options=None):
|
||||
dvm_config.SCRIPT = os.path.abspath(__file__)
|
||||
super().__init__(name, dvm_config, nip89config, admin_config, options)
|
||||
|
||||
def is_input_supported(self, tags):
|
||||
for tag in tags:
|
||||
if tag.as_vec()[0] == 'i':
|
||||
input_value = tag.as_vec()[1]
|
||||
input_type = tag.as_vec()[2]
|
||||
if input_type != "text":
|
||||
return False
|
||||
|
||||
elif tag.as_vec()[0] == 'output':
|
||||
output = tag.as_vec()[1]
|
||||
if (output == "" or
|
||||
not (output == "image/png" or "image/jpg"
|
||||
or output == "image/png;format=url" or output == "image/jpg;format=url")):
|
||||
print("Output format not supported, skipping..")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create_request_from_nostr_event(self, event, client=None, dvm_config=None):
|
||||
request_form = {"jobID": event.id().to_hex() + "_" + self.NAME.replace(" ", "")}
|
||||
prompt = ""
|
||||
width = "1024"
|
||||
height = "1024"
|
||||
|
||||
for tag in event.tags():
|
||||
if tag.as_vec()[0] == 'i':
|
||||
input_type = tag.as_vec()[2]
|
||||
if input_type == "text":
|
||||
prompt = tag.as_vec()[1]
|
||||
|
||||
elif tag.as_vec()[0] == 'param':
|
||||
print("Param: " + tag.as_vec()[1] + ": " + tag.as_vec()[2])
|
||||
if tag.as_vec()[1] == "size":
|
||||
if len(tag.as_vec()) > 3:
|
||||
width = (tag.as_vec()[2])
|
||||
height = (tag.as_vec()[3])
|
||||
elif len(tag.as_vec()) == 3:
|
||||
split = tag.as_vec()[2].split("x")
|
||||
if len(split) > 1:
|
||||
width = split[0]
|
||||
height = split[1]
|
||||
elif tag.as_vec()[1] == "model":
|
||||
model = tag.as_vec()[2]
|
||||
elif tag.as_vec()[1] == "quality":
|
||||
quality = tag.as_vec()[2]
|
||||
|
||||
options = {
|
||||
"prompt": prompt,
|
||||
"size": width + "x" + height,
|
||||
"number": 1
|
||||
}
|
||||
request_form['options'] = json.dumps(options)
|
||||
|
||||
return request_form
|
||||
|
||||
def process(self, request_form):
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from backends.mlx.stable_diffusion import StableDiffusion
|
||||
options = DVMTaskInterface.set_options(request_form)
|
||||
|
||||
sd = StableDiffusion()
|
||||
cfg_weight = 7.5
|
||||
batchsize = 1
|
||||
n_rows = 1
|
||||
steps = 50
|
||||
n_images = options["number"]
|
||||
|
||||
# Generate the latent vectors using diffusion
|
||||
latents = sd.generate_latents(
|
||||
options["prompt"],
|
||||
n_images=n_images,
|
||||
cfg_weight=cfg_weight,
|
||||
num_steps=steps,
|
||||
negative_text="",
|
||||
)
|
||||
for x_t in tqdm(latents, total=steps):
|
||||
mx.simplify(x_t)
|
||||
mx.simplify(x_t)
|
||||
mx.eval(x_t)
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, 1, batchsize)):
|
||||
decoded.append(sd.decode(x_t[i: i + batchsize]))
|
||||
mx.eval(decoded[-1])
|
||||
|
||||
# Arrange them on a grid
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
image = Image.fromarray(x.__array__())
|
||||
image.save("./outputs/image.jpg")
|
||||
result = upload_media_to_hoster("./outputs/image.jpg")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print("Error in Module")
|
||||
raise Exception(e)
|
||||
|
||||
|
||||
# We build an example here that we can call by either calling this file directly from the main directory,
|
||||
# or by adding it to our playground. You can call the example and adjust it to your needs or redefine it in the
|
||||
# playground or elsewhere
|
||||
def build_example(name, identifier, admin_config):
|
||||
dvm_config = build_default_config(identifier)
|
||||
admin_config.LUD16 = dvm_config.LN_ADDRESS
|
||||
profit_in_sats = 10
|
||||
dvm_config.FIX_COST = int(((4.0 / (get_price_per_sat("USD") * 100)) + profit_in_sats))
|
||||
|
||||
nip89info = {
|
||||
"name": name,
|
||||
"image": "https://image.nostr.build/c33ca6fc4cc038ca4adb46fdfdfda34951656f87ee364ef59095bae1495ce669.jpg",
|
||||
"about": "I use Replicate to run StableDiffusion XL",
|
||||
"encryptionSupported": True,
|
||||
"cashuAccepted": True,
|
||||
"nip90Params": {
|
||||
"size": {
|
||||
"required": False,
|
||||
"values": ["1024:1024", "1024x1792", "1792x1024"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nip89config = NIP89Config()
|
||||
nip89config.DTAG = check_and_set_d_tag(identifier, name, dvm_config.PRIVATE_KEY, nip89info["image"])
|
||||
nip89config.CONTENT = json.dumps(nip89info)
|
||||
|
||||
return ImageGenerationMLX(name=name, dvm_config=dvm_config, nip89config=nip89config,
|
||||
admin_config=admin_config)
|
||||
|
||||
|
||||
def process_venv():
|
||||
args = DVMTaskInterface.process_args()
|
||||
dvm_config = build_default_config(args.identifier)
|
||||
dvm = ImageGenerationMLX(name="", dvm_config=dvm_config, nip89config=NIP89Config(), admin_config=None)
|
||||
result = dvm.process(json.loads(args.request))
|
||||
DVMTaskInterface.write_output(result, args.output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
process_venv()
|
@ -300,10 +300,6 @@ def get_price_per_sat(currency):
|
||||
|
||||
|
||||
def make_ln_address_nostdress(identifier, npub, pin, nostdressdomain):
|
||||
# env_path = Path('.env')
|
||||
# if env_path.is_file():
|
||||
# dotenv.load_dotenv(env_path, verbose=True, override=True)
|
||||
|
||||
print(os.getenv("LNBITS_INVOICE_KEY_" + identifier.upper()))
|
||||
data = {
|
||||
'name': identifier,
|
||||
|
10
setup.py
10
setup.py
@ -1,6 +1,6 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
VERSION = '0.0.9'
|
||||
VERSION = '0.1.0'
|
||||
DESCRIPTION = 'A framework to build and run Nostr NIP90 Data Vending Machines'
|
||||
LONG_DESCRIPTION = ('A framework to build and run Nostr NIP90 Data Vending Machines. '
|
||||
'This is an early stage release. Interfaces might change/brick')
|
||||
@ -13,8 +13,10 @@ setup(
|
||||
author_email="believethehypeonnostr@proton.me",
|
||||
description=DESCRIPTION,
|
||||
long_description=LONG_DESCRIPTION,
|
||||
packages=find_packages(include=['nostr_dvm', 'nostr_dvm.backends', 'nostr_dvm.interfaces', 'nostr_dvm.tasks',
|
||||
'nostr_dvm.utils', 'nostr_dvm.utils.scrapper']),
|
||||
packages=find_packages(include=['nostr_dvm', 'nostr_dvm.interfaces', 'nostr_dvm.tasks',
|
||||
'nostr_dvm.utils', 'nostr_dvm.utils.scrapper',
|
||||
'nostr_dvm.backends', 'nostr_dvm.backends.mlx',
|
||||
'nostr_dvm.backends.mlx.stablediffusion']),
|
||||
install_requires=["nostr-sdk==0.0.5",
|
||||
"bech32==1.2.0",
|
||||
"pycryptodome==3.19.0",
|
||||
@ -32,7 +34,7 @@ setup(
|
||||
"moviepy==2.0.0.dev2",
|
||||
"zipp==3.17.0",
|
||||
"urllib3==2.1.0",
|
||||
"typing_extensions==4.8.0"
|
||||
"typing_extensions>=4.9.0"
|
||||
],
|
||||
keywords=['nostr', 'nip90', 'dvm', 'data vending machine'],
|
||||
url="https://github.com/believethehype/nostrdvm",
|
||||
|
@ -151,8 +151,8 @@ def nostr_client():
|
||||
#nostr_client_test_translation("This is the result of the DVM in spanish", "text", "es", 20, 20)
|
||||
#nostr_client_test_translation("note1p8cx2dz5ss5gnk7c59zjydcncx6a754c0hsyakjvnw8xwlm5hymsnc23rs", "event", "es", 20,20)
|
||||
#nostr_client_test_translation("44a0a8b395ade39d46b9d20038b3f0c8a11168e67c442e3ece95e4a1703e2beb", "event", "zh", 20, 20)
|
||||
#nostr_client_test_image("a beautiful purple ostrich watching the sunset")
|
||||
nostr_client_test_tts("Hello, this is a test. Mic check one, two.")
|
||||
nostr_client_test_image("a beautiful purple ostrich watching the sunset")
|
||||
#nostr_client_test_tts("Hello, this is a test. Mic check one, two.")
|
||||
|
||||
|
||||
#cashutoken = "cashuAeyJ0b2tlbiI6W3sicHJvb2ZzIjpbeyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6MSwiQyI6IjAyNWU3ODZhOGFkMmExYTg0N2YxMzNiNGRhM2VhMGIyYWRhZGFkOTRiYzA4M2E2NWJjYjFlOTgwYTE1NGIyMDA2NCIsInNlY3JldCI6InQ1WnphMTZKMGY4UElQZ2FKTEg4V3pPck5rUjhESWhGa291LzVzZFd4S0U9In0seyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6NCwiQyI6IjAyOTQxNmZmMTY2MzU5ZWY5ZDc3MDc2MGNjZmY0YzliNTMzMzVmZTA2ZGI5YjBiZDg2Njg5Y2ZiZTIzMjVhYWUwYiIsInNlY3JldCI6IlRPNHB5WE43WlZqaFRQbnBkQ1BldWhncm44UHdUdE5WRUNYWk9MTzZtQXM9In0seyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6MTYsIkMiOiIwMmRiZTA3ZjgwYmMzNzE0N2YyMDJkNTZiMGI3ZTIzZTdiNWNkYTBhNmI3Yjg3NDExZWYyOGRiZDg2NjAzNzBlMWIiLCJzZWNyZXQiOiJHYUNIdHhzeG9HM3J2WWNCc0N3V0YxbU1NVXczK0dDN1RKRnVwOHg1cURzPSJ9XSwibWludCI6Imh0dHBzOi8vbG5iaXRzLmJpdGNvaW5maXhlc3RoaXMub3JnL2Nhc2h1L2FwaS92MS9ScDlXZGdKZjlxck51a3M1eVQ2SG5rIn1dfQ=="
|
||||
|
Loading…
x
Reference in New Issue
Block a user