mirror of
https://github.com/believethehype/nostrdvm.git
synced 2025-10-10 01:02:41 +02:00
adding MLX backend and example for stable diffusion 2.1
from https://github.com/ml-explore/mlx-examples
This commit is contained in:
0
backends/__init__.py
Normal file
0
backends/__init__.py
Normal file
0
backends/mlx/__init__.py
Normal file
0
backends/mlx/__init__.py
Normal file
0
backends/mlx/stable_diffusion/__init__.py
Normal file
0
backends/mlx/stable_diffusion/__init__.py
Normal file
70
backends/mlx/stable_diffusion/clip.py
Normal file
70
backends/mlx/stable_diffusion/clip.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .config import CLIPTextModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(nn.Module):
|
||||||
|
"""The transformer encoder layer from CLIP."""
|
||||||
|
|
||||||
|
def __init__(self, model_dims: int, num_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||||
|
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||||
|
|
||||||
|
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
|
||||||
|
# Add biases to the attention projections to match CLIP
|
||||||
|
self.attention.query_proj.bias = mx.zeros(model_dims)
|
||||||
|
self.attention.key_proj.bias = mx.zeros(model_dims)
|
||||||
|
self.attention.value_proj.bias = mx.zeros(model_dims)
|
||||||
|
self.attention.out_proj.bias = mx.zeros(model_dims)
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||||
|
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||||
|
|
||||||
|
def __call__(self, x, attn_mask=None):
|
||||||
|
y = self.layer_norm1(x)
|
||||||
|
y = self.attention(y, y, y, attn_mask)
|
||||||
|
x = y + x
|
||||||
|
|
||||||
|
y = self.layer_norm2(x)
|
||||||
|
y = self.linear1(y)
|
||||||
|
y = nn.gelu_approx(y)
|
||||||
|
y = self.linear2(y)
|
||||||
|
x = y + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextModel(nn.Module):
|
||||||
|
"""Implements the text encoder transformer from CLIP."""
|
||||||
|
|
||||||
|
def __init__(self, config: CLIPTextModelConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||||
|
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||||
|
self.layers = [
|
||||||
|
CLIPEncoderLayer(config.model_dims, config.num_heads)
|
||||||
|
for i in range(config.num_layers)
|
||||||
|
]
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
# Extract some shapes
|
||||||
|
B, N = x.shape
|
||||||
|
|
||||||
|
# Compute the embeddings
|
||||||
|
x = self.token_embedding(x)
|
||||||
|
x = x + self.position_embedding.weight[:N]
|
||||||
|
|
||||||
|
# Compute the features from the transformer
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
|
||||||
|
for l in self.layers:
|
||||||
|
x = l(x, mask)
|
||||||
|
|
||||||
|
# Apply the final layernorm and return
|
||||||
|
return self.final_layer_norm(x)
|
48
backends/mlx/stable_diffusion/config.py
Normal file
48
backends/mlx/stable_diffusion/config.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoencoderConfig:
|
||||||
|
in_channels: int = 3
|
||||||
|
out_channels: int = 3
|
||||||
|
latent_channels_out: int = 8
|
||||||
|
latent_channels_in: int = 4
|
||||||
|
block_out_channels: Tuple[int] = (128, 256, 512, 512)
|
||||||
|
layers_per_block: int = 2
|
||||||
|
norm_num_groups: int = 32
|
||||||
|
scaling_factor: float = 0.18215
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CLIPTextModelConfig:
|
||||||
|
num_layers: int = 23
|
||||||
|
model_dims: int = 1024
|
||||||
|
num_heads: int = 16
|
||||||
|
max_length: int = 77
|
||||||
|
vocab_size: int = 49408
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UNetConfig:
|
||||||
|
in_channels: int = 4
|
||||||
|
out_channels: int = 4
|
||||||
|
conv_in_kernel: int = 3
|
||||||
|
conv_out_kernel: int = 3
|
||||||
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||||
|
layers_per_block: Tuple[int] = (2, 2, 2, 2)
|
||||||
|
mid_block_layers: int = 2
|
||||||
|
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||||
|
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||||
|
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||||
|
norm_num_groups: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiffusionConfig:
|
||||||
|
beta_schedule: str = "scaled_linear"
|
||||||
|
beta_start: float = 0.00085
|
||||||
|
beta_end: float = 0.012
|
||||||
|
num_train_steps: int = 1000
|
292
backends/mlx/stable_diffusion/model_io.py
Normal file
292
backends/mlx/stable_diffusion/model_io.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from safetensors import safe_open as safetensor_open
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.utils import tree_unflatten
|
||||||
|
|
||||||
|
from .clip import CLIPTextModel
|
||||||
|
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
|
||||||
|
from .tokenizer import Tokenizer
|
||||||
|
from .unet import UNetModel
|
||||||
|
from .vae import Autoencoder
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
|
||||||
|
_MODELS = {
|
||||||
|
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
|
||||||
|
"stabilityai/stable-diffusion-2-1-base": {
|
||||||
|
"unet_config": "unet/config.json",
|
||||||
|
"unet": "unet/diffusion_pytorch_model.safetensors",
|
||||||
|
"text_encoder_config": "text_encoder/config.json",
|
||||||
|
"text_encoder": "text_encoder/model.safetensors",
|
||||||
|
"vae_config": "vae/config.json",
|
||||||
|
"vae": "vae/diffusion_pytorch_model.safetensors",
|
||||||
|
"diffusion_config": "scheduler/scheduler_config.json",
|
||||||
|
"tokenizer_vocab": "tokenizer/vocab.json",
|
||||||
|
"tokenizer_merges": "tokenizer/merges.txt",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _from_numpy(x):
|
||||||
|
return mx.array(np.ascontiguousarray(x))
|
||||||
|
|
||||||
|
|
||||||
|
def map_unet_weights(key, value):
|
||||||
|
# Map up/downsampling
|
||||||
|
if "downsamplers" in key:
|
||||||
|
key = key.replace("downsamplers.0.conv", "downsample")
|
||||||
|
if "upsamplers" in key:
|
||||||
|
key = key.replace("upsamplers.0.conv", "upsample")
|
||||||
|
|
||||||
|
# Map the mid block
|
||||||
|
if "mid_block.resnets.0" in key:
|
||||||
|
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||||
|
if "mid_block.attentions.0" in key:
|
||||||
|
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||||
|
if "mid_block.resnets.1" in key:
|
||||||
|
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||||
|
|
||||||
|
# Map attention layers
|
||||||
|
if "to_k" in key:
|
||||||
|
key = key.replace("to_k", "key_proj")
|
||||||
|
if "to_out.0" in key:
|
||||||
|
key = key.replace("to_out.0", "out_proj")
|
||||||
|
if "to_q" in key:
|
||||||
|
key = key.replace("to_q", "query_proj")
|
||||||
|
if "to_v" in key:
|
||||||
|
key = key.replace("to_v", "value_proj")
|
||||||
|
|
||||||
|
# Map transformer ffn
|
||||||
|
if "ff.net.2" in key:
|
||||||
|
key = key.replace("ff.net.2", "linear3")
|
||||||
|
if "ff.net.0" in key:
|
||||||
|
k1 = key.replace("ff.net.0.proj", "linear1")
|
||||||
|
k2 = key.replace("ff.net.0.proj", "linear2")
|
||||||
|
v1, v2 = np.split(value, 2)
|
||||||
|
|
||||||
|
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
|
||||||
|
|
||||||
|
if "conv_shortcut.weight" in key:
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
# Transform the weights from 1x1 convs to linear
|
||||||
|
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
if len(value.shape) == 4:
|
||||||
|
value = value.transpose(0, 2, 3, 1)
|
||||||
|
|
||||||
|
return [(key, _from_numpy(value))]
|
||||||
|
|
||||||
|
|
||||||
|
def map_clip_text_encoder_weights(key, value):
|
||||||
|
# Remove prefixes
|
||||||
|
if key.startswith("text_model."):
|
||||||
|
key = key[11:]
|
||||||
|
if key.startswith("embeddings."):
|
||||||
|
key = key[11:]
|
||||||
|
if key.startswith("encoder."):
|
||||||
|
key = key[8:]
|
||||||
|
|
||||||
|
# Map attention layers
|
||||||
|
if "self_attn." in key:
|
||||||
|
key = key.replace("self_attn.", "attention.")
|
||||||
|
if "q_proj." in key:
|
||||||
|
key = key.replace("q_proj.", "query_proj.")
|
||||||
|
if "k_proj." in key:
|
||||||
|
key = key.replace("k_proj.", "key_proj.")
|
||||||
|
if "v_proj." in key:
|
||||||
|
key = key.replace("v_proj.", "value_proj.")
|
||||||
|
|
||||||
|
# Map ffn layers
|
||||||
|
if "mlp.fc1" in key:
|
||||||
|
key = key.replace("mlp.fc1", "linear1")
|
||||||
|
if "mlp.fc2" in key:
|
||||||
|
key = key.replace("mlp.fc2", "linear2")
|
||||||
|
|
||||||
|
return [(key, _from_numpy(value))]
|
||||||
|
|
||||||
|
|
||||||
|
def map_vae_weights(key, value):
|
||||||
|
# Map up/downsampling
|
||||||
|
if "downsamplers" in key:
|
||||||
|
key = key.replace("downsamplers.0.conv", "downsample")
|
||||||
|
if "upsamplers" in key:
|
||||||
|
key = key.replace("upsamplers.0.conv", "upsample")
|
||||||
|
|
||||||
|
# Map attention layers
|
||||||
|
if "to_k" in key:
|
||||||
|
key = key.replace("to_k", "key_proj")
|
||||||
|
if "to_out.0" in key:
|
||||||
|
key = key.replace("to_out.0", "out_proj")
|
||||||
|
if "to_q" in key:
|
||||||
|
key = key.replace("to_q", "query_proj")
|
||||||
|
if "to_v" in key:
|
||||||
|
key = key.replace("to_v", "value_proj")
|
||||||
|
|
||||||
|
# Map the mid block
|
||||||
|
if "mid_block.resnets.0" in key:
|
||||||
|
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||||
|
if "mid_block.attentions.0" in key:
|
||||||
|
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||||
|
if "mid_block.resnets.1" in key:
|
||||||
|
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||||
|
|
||||||
|
# Map the quant/post_quant layers
|
||||||
|
if "quant_conv" in key:
|
||||||
|
key = key.replace("quant_conv", "quant_proj")
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
# Map the conv_shortcut to linear
|
||||||
|
if "conv_shortcut.weight" in key:
|
||||||
|
value = value.squeeze()
|
||||||
|
|
||||||
|
if len(value.shape) == 4:
|
||||||
|
value = value.transpose(0, 2, 3, 1)
|
||||||
|
|
||||||
|
return [(key, _from_numpy(value))]
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten(params):
|
||||||
|
return [(k, v) for p in params for (k, v) in p]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
|
||||||
|
dtype = np.float16 if float16 else np.float32
|
||||||
|
with safetensor_open(weight_file, framework="numpy") as f:
|
||||||
|
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
|
||||||
|
model.update(tree_unflatten(weights))
|
||||||
|
|
||||||
|
|
||||||
|
def _check_key(key: str, part: str):
|
||||||
|
if key not in _MODELS:
|
||||||
|
raise ValueError(
|
||||||
|
f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||||
|
"""Load the stable diffusion UNet from Hugging Face Hub."""
|
||||||
|
_check_key(key, "load_unet")
|
||||||
|
|
||||||
|
# Download the config and create the model
|
||||||
|
unet_config = _MODELS[key]["unet_config"]
|
||||||
|
with open(hf_hub_download(key, unet_config)) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
n_blocks = len(config["block_out_channels"])
|
||||||
|
model = UNetModel(
|
||||||
|
UNetConfig(
|
||||||
|
in_channels=config["in_channels"],
|
||||||
|
out_channels=config["out_channels"],
|
||||||
|
block_out_channels=config["block_out_channels"],
|
||||||
|
layers_per_block=[config["layers_per_block"]] * n_blocks,
|
||||||
|
num_attention_heads=[config["attention_head_dim"]] * n_blocks
|
||||||
|
if isinstance(config["attention_head_dim"], int)
|
||||||
|
else config["attention_head_dim"],
|
||||||
|
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||||
|
norm_num_groups=config["norm_num_groups"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download the weights and map them into the model
|
||||||
|
unet_weights = _MODELS[key]["unet"]
|
||||||
|
weight_file = hf_hub_download(key, unet_weights)
|
||||||
|
_load_safetensor_weights(map_unet_weights, model, weight_file, float16)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||||
|
"""Load the stable diffusion text encoder from Hugging Face Hub."""
|
||||||
|
_check_key(key, "load_text_encoder")
|
||||||
|
|
||||||
|
# Download the config and create the model
|
||||||
|
text_encoder_config = _MODELS[key]["text_encoder_config"]
|
||||||
|
with open(hf_hub_download(key, text_encoder_config)) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
model = CLIPTextModel(
|
||||||
|
CLIPTextModelConfig(
|
||||||
|
num_layers=config["num_hidden_layers"],
|
||||||
|
model_dims=config["hidden_size"],
|
||||||
|
num_heads=config["num_attention_heads"],
|
||||||
|
max_length=config["max_position_embeddings"],
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download the weights and map them into the model
|
||||||
|
text_encoder_weights = _MODELS[key]["text_encoder"]
|
||||||
|
weight_file = hf_hub_download(key, text_encoder_weights)
|
||||||
|
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||||
|
"""Load the stable diffusion autoencoder from Hugging Face Hub."""
|
||||||
|
_check_key(key, "load_autoencoder")
|
||||||
|
|
||||||
|
# Download the config and create the model
|
||||||
|
vae_config = _MODELS[key]["vae_config"]
|
||||||
|
with open(hf_hub_download(key, vae_config)) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
model = Autoencoder(
|
||||||
|
AutoencoderConfig(
|
||||||
|
in_channels=config["in_channels"],
|
||||||
|
out_channels=config["out_channels"],
|
||||||
|
latent_channels_out=2 * config["latent_channels"],
|
||||||
|
latent_channels_in=config["latent_channels"],
|
||||||
|
block_out_channels=config["block_out_channels"],
|
||||||
|
layers_per_block=config["layers_per_block"],
|
||||||
|
norm_num_groups=config["norm_num_groups"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download the weights and map them into the model
|
||||||
|
vae_weights = _MODELS[key]["vae"]
|
||||||
|
weight_file = hf_hub_download(key, vae_weights)
|
||||||
|
_load_safetensor_weights(map_vae_weights, model, weight_file, float16)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_diffusion_config(key: str = _DEFAULT_MODEL):
|
||||||
|
"""Load the stable diffusion config from Hugging Face Hub."""
|
||||||
|
_check_key(key, "load_diffusion_config")
|
||||||
|
|
||||||
|
diffusion_config = _MODELS[key]["diffusion_config"]
|
||||||
|
with open(hf_hub_download(key, diffusion_config)) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
return DiffusionConfig(
|
||||||
|
beta_start=config["beta_start"],
|
||||||
|
beta_end=config["beta_end"],
|
||||||
|
beta_schedule=config["beta_schedule"],
|
||||||
|
num_train_steps=config["num_train_timesteps"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tokenizer(key: str = _DEFAULT_MODEL):
|
||||||
|
_check_key(key, "load_tokenizer")
|
||||||
|
|
||||||
|
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
|
||||||
|
with open(vocab_file, encoding="utf-8") as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
|
||||||
|
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
|
||||||
|
with open(merges_file, encoding="utf-8") as f:
|
||||||
|
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||||
|
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||||
|
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||||
|
|
||||||
|
return Tokenizer(bpe_ranks, vocab)
|
74
backends/mlx/stable_diffusion/sampler.py
Normal file
74
backends/mlx/stable_diffusion/sampler.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
from .config import DiffusionConfig
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def _linspace(a, b, num):
|
||||||
|
x = mx.arange(0, num) / (num - 1)
|
||||||
|
return (b - a) * x + a
|
||||||
|
|
||||||
|
|
||||||
|
def _interp(y, x_new):
|
||||||
|
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
||||||
|
x_low = x_new.astype(mx.int32)
|
||||||
|
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
||||||
|
|
||||||
|
y_low = y[x_low]
|
||||||
|
y_high = y[x_high]
|
||||||
|
delta_x = x_new - x_low
|
||||||
|
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
||||||
|
|
||||||
|
return y_new
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleEulerSampler:
|
||||||
|
"""A simple Euler integrator that can be used to sample from our diffusion models.
|
||||||
|
|
||||||
|
The method ``step()`` performs one Euler step from x_t to x_t_prev.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: DiffusionConfig):
|
||||||
|
# Compute the noise schedule
|
||||||
|
if config.beta_schedule == "linear":
|
||||||
|
betas = _linspace(
|
||||||
|
config.beta_start, config.beta_end, config.num_train_steps
|
||||||
|
)
|
||||||
|
elif config.beta_schedule == "scaled_linear":
|
||||||
|
betas = _linspace(
|
||||||
|
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
|
||||||
|
).square()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
|
||||||
|
|
||||||
|
alphas = 1 - betas
|
||||||
|
alphas_cumprod = mx.cumprod(alphas)
|
||||||
|
|
||||||
|
self._sigmas = mx.concatenate(
|
||||||
|
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||||
|
noise = mx.random.normal(shape, key=key)
|
||||||
|
return (
|
||||||
|
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||||
|
).astype(dtype)
|
||||||
|
|
||||||
|
def sigmas(self, t):
|
||||||
|
return _interp(self._sigmas, t)
|
||||||
|
|
||||||
|
def timesteps(self, num_steps: int, dtype=mx.float32):
|
||||||
|
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
|
||||||
|
return list(zip(steps, steps[1:]))
|
||||||
|
|
||||||
|
def step(self, eps_pred, x_t, t, t_prev):
|
||||||
|
sigma = self.sigmas(t).astype(eps_pred.dtype)
|
||||||
|
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
|
||||||
|
|
||||||
|
dt = sigma_prev - sigma
|
||||||
|
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
|
||||||
|
|
||||||
|
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
|
||||||
|
|
||||||
|
return x_t_prev
|
100
backends/mlx/stable_diffusion/tokenizer.py
Normal file
100
backends/mlx/stable_diffusion/tokenizer.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import regex
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||||
|
|
||||||
|
def __init__(self, bpe_ranks, vocab):
|
||||||
|
self.bpe_ranks = bpe_ranks
|
||||||
|
self.vocab = vocab
|
||||||
|
self.pat = regex.compile(
|
||||||
|
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||||
|
regex.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos(self):
|
||||||
|
return "<|startoftext|>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos_token(self):
|
||||||
|
return self.vocab[self.bos]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos(self):
|
||||||
|
return "<|endoftext|>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token(self):
|
||||||
|
return self.vocab[self.eos]
|
||||||
|
|
||||||
|
def bpe(self, text):
|
||||||
|
if text in self._cache:
|
||||||
|
return self._cache[text]
|
||||||
|
|
||||||
|
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||||
|
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||||
|
|
||||||
|
if not unique_bigrams:
|
||||||
|
return unigrams
|
||||||
|
|
||||||
|
# In every iteration try to merge the two most likely bigrams. If none
|
||||||
|
# was merged we are done.
|
||||||
|
#
|
||||||
|
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||||
|
while unique_bigrams:
|
||||||
|
bigram = min(
|
||||||
|
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||||
|
)
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
|
||||||
|
new_unigrams = []
|
||||||
|
skip = False
|
||||||
|
for a, b in zip(unigrams, unigrams[1:]):
|
||||||
|
if skip:
|
||||||
|
skip = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (a, b) == bigram:
|
||||||
|
new_unigrams.append(a + b)
|
||||||
|
skip = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_unigrams.append(a)
|
||||||
|
|
||||||
|
if not skip:
|
||||||
|
new_unigrams.append(b)
|
||||||
|
|
||||||
|
unigrams = new_unigrams
|
||||||
|
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||||
|
|
||||||
|
self._cache[text] = unigrams
|
||||||
|
|
||||||
|
return unigrams
|
||||||
|
|
||||||
|
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||||
|
if isinstance(text, list):
|
||||||
|
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||||
|
|
||||||
|
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||||
|
# a much more thorough job here but this should suffice for 95% of
|
||||||
|
# cases.
|
||||||
|
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||||
|
tokens = regex.findall(self.pat, clean_text)
|
||||||
|
|
||||||
|
# Split the tokens according to the byte-pair merge file
|
||||||
|
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||||
|
|
||||||
|
# Map to token ids and return
|
||||||
|
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||||
|
if prepend_bos:
|
||||||
|
tokens = [self.bos_token] + tokens
|
||||||
|
if append_eos:
|
||||||
|
tokens.append(self.eos_token)
|
||||||
|
|
||||||
|
return tokens
|
425
backends/mlx/stable_diffusion/unet.py
Normal file
425
backends/mlx/stable_diffusion/unet.py
Normal file
@@ -0,0 +1,425 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .config import UNetConfig
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_nearest(x, scale: int = 2):
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||||
|
x = x.reshape(B, H * scale, W * scale, C)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||||
|
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.linear_1(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.linear_2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dims: int,
|
||||||
|
num_heads: int,
|
||||||
|
hidden_dims: Optional[int] = None,
|
||||||
|
memory_dims: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(model_dims)
|
||||||
|
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
|
||||||
|
self.attn1.out_proj.bias = mx.zeros(model_dims)
|
||||||
|
|
||||||
|
memory_dims = memory_dims or model_dims
|
||||||
|
self.norm2 = nn.LayerNorm(model_dims)
|
||||||
|
self.attn2 = nn.MultiHeadAttention(
|
||||||
|
model_dims, num_heads, key_input_dims=memory_dims
|
||||||
|
)
|
||||||
|
self.attn2.out_proj.bias = mx.zeros(model_dims)
|
||||||
|
|
||||||
|
hidden_dims = hidden_dims or 4 * model_dims
|
||||||
|
self.norm3 = nn.LayerNorm(model_dims)
|
||||||
|
self.linear1 = nn.Linear(model_dims, hidden_dims)
|
||||||
|
self.linear2 = nn.Linear(model_dims, hidden_dims)
|
||||||
|
self.linear3 = nn.Linear(hidden_dims, model_dims)
|
||||||
|
|
||||||
|
def __call__(self, x, memory, attn_mask, memory_mask):
|
||||||
|
# Self attention
|
||||||
|
y = self.norm1(x)
|
||||||
|
y = self.attn1(y, y, y, attn_mask)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
# Cross attention
|
||||||
|
y = self.norm2(x)
|
||||||
|
y = self.attn2(y, memory, memory, memory_mask)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
y = self.norm3(x)
|
||||||
|
y_a = self.linear1(y)
|
||||||
|
y_b = self.linear2(y)
|
||||||
|
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
|
||||||
|
y = self.linear3(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer2D(nn.Module):
|
||||||
|
"""A transformer model for inputs with 2 spatial dimensions."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
model_dims: int,
|
||||||
|
encoder_dims: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_layers: int = 1,
|
||||||
|
norm_num_groups: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
|
||||||
|
self.proj_in = nn.Linear(in_channels, model_dims)
|
||||||
|
self.transformer_blocks = [
|
||||||
|
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
self.proj_out = nn.Linear(model_dims, in_channels)
|
||||||
|
|
||||||
|
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
|
||||||
|
# Save the input to add to the output
|
||||||
|
input_x = x
|
||||||
|
|
||||||
|
# Perform the input norm and projection
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = self.norm(x).reshape(B, -1, C)
|
||||||
|
x = self.proj_in(x)
|
||||||
|
|
||||||
|
# Apply the transformer
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
|
||||||
|
|
||||||
|
# Apply the output projection and reshape
|
||||||
|
x = self.proj_out(x)
|
||||||
|
x = x.reshape(B, H, W, C)
|
||||||
|
|
||||||
|
return x + input_x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
groups: int = 32,
|
||||||
|
temb_channels: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
out_channels = out_channels or in_channels
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
if temb_channels is not None:
|
||||||
|
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
||||||
|
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.conv_shortcut = nn.Linear(in_channels, out_channels)
|
||||||
|
|
||||||
|
def __call__(self, x, temb=None):
|
||||||
|
if temb is not None:
|
||||||
|
temb = self.time_emb_proj(nn.silu(temb))
|
||||||
|
|
||||||
|
y = self.norm1(x)
|
||||||
|
y = nn.silu(y)
|
||||||
|
y = self.conv1(y)
|
||||||
|
if temb is not None:
|
||||||
|
y = y + temb[:, None, None, :]
|
||||||
|
y = self.norm2(y)
|
||||||
|
y = nn.silu(y)
|
||||||
|
y = self.conv2(y)
|
||||||
|
|
||||||
|
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UNetBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
temb_channels: int,
|
||||||
|
prev_out_channels: Optional[int] = None,
|
||||||
|
num_layers: int = 1,
|
||||||
|
transformer_layers_per_block: int = 1,
|
||||||
|
num_attention_heads: int = 8,
|
||||||
|
cross_attention_dim=1280,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
add_downsample=True,
|
||||||
|
add_upsample=True,
|
||||||
|
add_cross_attention=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Prepare the in channels list for the resnets
|
||||||
|
if prev_out_channels is None:
|
||||||
|
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
|
||||||
|
else:
|
||||||
|
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
|
||||||
|
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
|
||||||
|
in_channels_list = [
|
||||||
|
a + b for a, b in zip(in_channels_list, res_channels_list)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add resnet blocks that also process the time embedding
|
||||||
|
self.resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=ic,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
groups=resnet_groups,
|
||||||
|
)
|
||||||
|
for ic in in_channels_list
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add optional cross attention layers
|
||||||
|
if add_cross_attention:
|
||||||
|
self.attentions = [
|
||||||
|
Transformer2D(
|
||||||
|
in_channels=out_channels,
|
||||||
|
model_dims=out_channels,
|
||||||
|
num_heads=num_attention_heads,
|
||||||
|
num_layers=transformer_layers_per_block,
|
||||||
|
encoder_dims=cross_attention_dim,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add an optional downsampling layer
|
||||||
|
if add_downsample:
|
||||||
|
self.downsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# or upsampling layer
|
||||||
|
if add_upsample:
|
||||||
|
self.upsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
encoder_x=None,
|
||||||
|
temb=None,
|
||||||
|
attn_mask=None,
|
||||||
|
encoder_attn_mask=None,
|
||||||
|
residual_hidden_states=None,
|
||||||
|
):
|
||||||
|
output_states = []
|
||||||
|
|
||||||
|
for i in range(len(self.resnets)):
|
||||||
|
if residual_hidden_states is not None:
|
||||||
|
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
|
||||||
|
|
||||||
|
x = self.resnets[i](x, temb)
|
||||||
|
|
||||||
|
if "attentions" in self:
|
||||||
|
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||||
|
|
||||||
|
output_states.append(x)
|
||||||
|
|
||||||
|
if "downsample" in self:
|
||||||
|
x = self.downsample(x)
|
||||||
|
output_states.append(x)
|
||||||
|
|
||||||
|
if "upsample" in self:
|
||||||
|
x = self.upsample(upsample_nearest(x))
|
||||||
|
output_states.append(x)
|
||||||
|
|
||||||
|
return x, output_states
|
||||||
|
|
||||||
|
|
||||||
|
class UNetModel(nn.Module):
|
||||||
|
"""The conditional 2D UNet model that actually performs the denoising."""
|
||||||
|
|
||||||
|
def __init__(self, config: UNetConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
config.in_channels,
|
||||||
|
config.block_out_channels[0],
|
||||||
|
config.conv_in_kernel,
|
||||||
|
padding=(config.conv_in_kernel - 1) // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.timesteps = nn.SinusoidalPositionalEncoding(
|
||||||
|
config.block_out_channels[0],
|
||||||
|
max_freq=1,
|
||||||
|
min_freq=math.exp(
|
||||||
|
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
|
||||||
|
),
|
||||||
|
scale=1.0,
|
||||||
|
cos_first=True,
|
||||||
|
full_turns=False,
|
||||||
|
)
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
config.block_out_channels[0],
|
||||||
|
config.block_out_channels[0] * 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make the downsampling blocks
|
||||||
|
block_channels = [config.block_out_channels[0]] + list(
|
||||||
|
config.block_out_channels
|
||||||
|
)
|
||||||
|
self.down_blocks = [
|
||||||
|
UNetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
num_layers=config.layers_per_block[i],
|
||||||
|
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||||
|
num_attention_heads=config.num_attention_heads[i],
|
||||||
|
cross_attention_dim=config.cross_attention_dim[i],
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||||
|
add_upsample=False,
|
||||||
|
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||||
|
)
|
||||||
|
for i, (in_channels, out_channels) in enumerate(
|
||||||
|
zip(block_channels, block_channels[1:])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Make the middle block
|
||||||
|
self.mid_blocks = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=config.block_out_channels[-1],
|
||||||
|
out_channels=config.block_out_channels[-1],
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
groups=config.norm_num_groups,
|
||||||
|
),
|
||||||
|
Transformer2D(
|
||||||
|
in_channels=config.block_out_channels[-1],
|
||||||
|
model_dims=config.block_out_channels[-1],
|
||||||
|
num_heads=config.num_attention_heads[-1],
|
||||||
|
num_layers=config.transformer_layers_per_block[-1],
|
||||||
|
encoder_dims=config.cross_attention_dim[-1],
|
||||||
|
),
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=config.block_out_channels[-1],
|
||||||
|
out_channels=config.block_out_channels[-1],
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
groups=config.norm_num_groups,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Make the upsampling blocks
|
||||||
|
block_channels = (
|
||||||
|
[config.block_out_channels[0]]
|
||||||
|
+ list(config.block_out_channels)
|
||||||
|
+ [config.block_out_channels[-1]]
|
||||||
|
)
|
||||||
|
self.up_blocks = [
|
||||||
|
UNetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
temb_channels=config.block_out_channels[0] * 4,
|
||||||
|
prev_out_channels=prev_out_channels,
|
||||||
|
num_layers=config.layers_per_block[i] + 1,
|
||||||
|
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||||
|
num_attention_heads=config.num_attention_heads[i],
|
||||||
|
cross_attention_dim=config.cross_attention_dim[i],
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
add_downsample=False,
|
||||||
|
add_upsample=(i > 0),
|
||||||
|
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||||
|
)
|
||||||
|
for i, (in_channels, out_channels, prev_out_channels) in reversed(
|
||||||
|
list(
|
||||||
|
enumerate(
|
||||||
|
zip(block_channels, block_channels[1:], block_channels[2:])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
config.norm_num_groups,
|
||||||
|
config.block_out_channels[0],
|
||||||
|
pytorch_compatible=True,
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(
|
||||||
|
config.block_out_channels[0],
|
||||||
|
config.out_channels,
|
||||||
|
config.conv_out_kernel,
|
||||||
|
padding=(config.conv_out_kernel - 1) // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
|
||||||
|
|
||||||
|
# Compute the time embeddings
|
||||||
|
temb = self.timesteps(timestep).astype(x.dtype)
|
||||||
|
temb = self.time_embedding(temb)
|
||||||
|
|
||||||
|
# Preprocess the input
|
||||||
|
x = self.conv_in(x)
|
||||||
|
|
||||||
|
# Run the downsampling part of the unet
|
||||||
|
residuals = [x]
|
||||||
|
for block in self.down_blocks:
|
||||||
|
x, res = block(
|
||||||
|
x,
|
||||||
|
encoder_x=encoder_x,
|
||||||
|
temb=temb,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
encoder_attn_mask=encoder_attn_mask,
|
||||||
|
)
|
||||||
|
residuals.extend(res)
|
||||||
|
|
||||||
|
# Run the middle part of the unet
|
||||||
|
x = self.mid_blocks[0](x, temb)
|
||||||
|
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||||
|
x = self.mid_blocks[2](x, temb)
|
||||||
|
|
||||||
|
# Run the upsampling part of the unet
|
||||||
|
for block in self.up_blocks:
|
||||||
|
x, _ = block(
|
||||||
|
x,
|
||||||
|
encoder_x=encoder_x,
|
||||||
|
temb=temb,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
encoder_attn_mask=encoder_attn_mask,
|
||||||
|
residual_hidden_states=residuals,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Postprocess the output
|
||||||
|
x = self.conv_norm_out(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
|
||||||
|
return x
|
268
backends/mlx/stable_diffusion/vae.py
Normal file
268
backends/mlx/stable_diffusion/vae.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .config import AutoencoderConfig
|
||||||
|
from .unet import ResnetBlock2D, upsample_nearest
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""A single head unmasked attention for use with the VAE."""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, norm_groups: int = 32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
|
||||||
|
self.query_proj = nn.Linear(dims, dims)
|
||||||
|
self.key_proj = nn.Linear(dims, dims)
|
||||||
|
self.value_proj = nn.Linear(dims, dims)
|
||||||
|
self.out_proj = nn.Linear(dims, dims)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
|
||||||
|
y = self.group_norm(x)
|
||||||
|
|
||||||
|
queries = self.query_proj(y).reshape(B, H * W, C)
|
||||||
|
keys = self.key_proj(y).reshape(B, H * W, C)
|
||||||
|
values = self.value_proj(y).reshape(B, H * W, C)
|
||||||
|
|
||||||
|
scale = 1 / math.sqrt(queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys.transpose(0, 2, 1)
|
||||||
|
attn = mx.softmax(scores, axis=-1)
|
||||||
|
y = (attn @ values).reshape(B, H, W, C)
|
||||||
|
|
||||||
|
y = self.out_proj(y)
|
||||||
|
x = x + y
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderDecoderBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_layers: int = 1,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
add_downsample=True,
|
||||||
|
add_upsample=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Add the resnet blocks
|
||||||
|
self.resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels if i == 0 else out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
groups=resnet_groups,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add an optional downsampling layer
|
||||||
|
if add_downsample:
|
||||||
|
self.downsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# or upsampling layer
|
||||||
|
if add_upsample:
|
||||||
|
self.upsample = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
for resnet in self.resnets:
|
||||||
|
x = resnet(x)
|
||||||
|
|
||||||
|
if "downsample" in self:
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
if "upsample" in self:
|
||||||
|
x = self.upsample(upsample_nearest(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
"""Implements the encoder side of the Autoencoder."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
block_out_channels: List[int] = [64],
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
channels = [block_out_channels[0]] + list(block_out_channels)
|
||||||
|
self.down_blocks = [
|
||||||
|
EncoderDecoderBlock2D(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_groups=resnet_groups,
|
||||||
|
add_downsample=i < len(block_out_channels) - 1,
|
||||||
|
add_upsample=False,
|
||||||
|
)
|
||||||
|
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
|
||||||
|
]
|
||||||
|
|
||||||
|
self.mid_blocks = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
Attention(block_out_channels[-1], resnet_groups),
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
resnet_groups, block_out_channels[-1], pytorch_compatible=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv_in(x)
|
||||||
|
|
||||||
|
for l in self.down_blocks:
|
||||||
|
x = l(x)
|
||||||
|
|
||||||
|
x = self.mid_blocks[0](x)
|
||||||
|
x = self.mid_blocks[1](x)
|
||||||
|
x = self.mid_blocks[2](x)
|
||||||
|
|
||||||
|
x = self.conv_norm_out(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
"""Implements the decoder side of the Autoencoder."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
block_out_channels: List[int] = [64],
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mid_blocks = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
Attention(block_out_channels[-1], resnet_groups),
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=block_out_channels[-1],
|
||||||
|
out_channels=block_out_channels[-1],
|
||||||
|
groups=resnet_groups,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
channels = list(reversed(block_out_channels))
|
||||||
|
channels = [channels[0]] + channels
|
||||||
|
self.up_blocks = [
|
||||||
|
EncoderDecoderBlock2D(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
resnet_groups=resnet_groups,
|
||||||
|
add_downsample=False,
|
||||||
|
add_upsample=i < len(block_out_channels) - 1,
|
||||||
|
)
|
||||||
|
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
|
||||||
|
]
|
||||||
|
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
resnet_groups, block_out_channels[0], pytorch_compatible=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv_in(x)
|
||||||
|
|
||||||
|
x = self.mid_blocks[0](x)
|
||||||
|
x = self.mid_blocks[1](x)
|
||||||
|
x = self.mid_blocks[2](x)
|
||||||
|
|
||||||
|
for l in self.up_blocks:
|
||||||
|
x = l(x)
|
||||||
|
|
||||||
|
x = self.conv_norm_out(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Autoencoder(nn.Module):
|
||||||
|
"""The autoencoder that allows us to perform diffusion in the latent space."""
|
||||||
|
|
||||||
|
def __init__(self, config: AutoencoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.latent_channels = config.latent_channels_in
|
||||||
|
self.scaling_factor = config.scaling_factor
|
||||||
|
self.encoder = Encoder(
|
||||||
|
config.in_channels,
|
||||||
|
config.latent_channels_out,
|
||||||
|
config.block_out_channels,
|
||||||
|
config.layers_per_block,
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
config.latent_channels_in,
|
||||||
|
config.out_channels,
|
||||||
|
config.block_out_channels,
|
||||||
|
config.layers_per_block + 1,
|
||||||
|
resnet_groups=config.norm_num_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.quant_proj = nn.Linear(
|
||||||
|
config.latent_channels_out, config.latent_channels_out
|
||||||
|
)
|
||||||
|
self.post_quant_proj = nn.Linear(
|
||||||
|
config.latent_channels_in, config.latent_channels_in
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
return self.decoder(self.post_quant_proj(z))
|
||||||
|
|
||||||
|
def __call__(self, x, key=None):
|
||||||
|
x = self.encoder(x)
|
||||||
|
x = self.quant_proj(x)
|
||||||
|
|
||||||
|
mean, logvar = x.split(2, axis=-1)
|
||||||
|
std = mx.exp(0.5 * logvar)
|
||||||
|
z = mx.random.normal(mean.shape, key=key) * std + mean
|
||||||
|
|
||||||
|
x_hat = self.decode(z)
|
||||||
|
|
||||||
|
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
11
main.py
11
main.py
@@ -5,7 +5,8 @@ import dotenv
|
|||||||
from nostr_dvm.bot import Bot
|
from nostr_dvm.bot import Bot
|
||||||
from nostr_dvm.tasks import videogeneration_replicate_svd, imagegeneration_replicate_sdxl, textgeneration_llmlite, \
|
from nostr_dvm.tasks import videogeneration_replicate_svd, imagegeneration_replicate_sdxl, textgeneration_llmlite, \
|
||||||
trending_notes_nostrband, discovery_inactive_follows, translation_google, textextraction_pdf, \
|
trending_notes_nostrband, discovery_inactive_follows, translation_google, textextraction_pdf, \
|
||||||
translation_libretranslate, textextraction_google, convert_media, imagegeneration_openai_dalle, texttospeech
|
translation_libretranslate, textextraction_google, convert_media, imagegeneration_openai_dalle, texttospeech, \
|
||||||
|
imagegeneration_mlx, advanced_search, textextraction_whisper_mlx
|
||||||
from nostr_dvm.utils.admin_utils import AdminConfig
|
from nostr_dvm.utils.admin_utils import AdminConfig
|
||||||
from nostr_dvm.utils.backend_utils import keep_alive
|
from nostr_dvm.utils.backend_utils import keep_alive
|
||||||
from nostr_dvm.utils.definitions import EventDefinitions
|
from nostr_dvm.utils.definitions import EventDefinitions
|
||||||
@@ -138,6 +139,14 @@ def playground():
|
|||||||
bot_config.SUPPORTED_DVMS.append(tts)
|
bot_config.SUPPORTED_DVMS.append(tts)
|
||||||
tts.run()
|
tts.run()
|
||||||
|
|
||||||
|
from sys import platform
|
||||||
|
if platform == "darwin":
|
||||||
|
# Test with MLX for OSX M1/M2/M3 chips
|
||||||
|
mlx = imagegeneration_mlx.build_example("SD with MLX", "mlx_sd", admin_config)
|
||||||
|
bot_config.SUPPORTED_DVMS.append(mlx)
|
||||||
|
mlx.run()
|
||||||
|
|
||||||
|
|
||||||
# Run the bot
|
# Run the bot
|
||||||
Bot(bot_config)
|
Bot(bot_config)
|
||||||
# Keep the main function alive for libraries that require it, like openai
|
# Keep the main function alive for libraries that require it, like openai
|
||||||
|
186
nostr_dvm/tasks/imagegeneration_mlx.py
Normal file
186
nostr_dvm/tasks/imagegeneration_mlx.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface
|
||||||
|
from nostr_dvm.utils.admin_utils import AdminConfig
|
||||||
|
from nostr_dvm.utils.definitions import EventDefinitions
|
||||||
|
from nostr_dvm.utils.dvmconfig import DVMConfig, build_default_config
|
||||||
|
from nostr_dvm.utils.nip89_utils import NIP89Config, check_and_set_d_tag
|
||||||
|
from nostr_dvm.utils.output_utils import upload_media_to_hoster
|
||||||
|
from nostr_dvm.utils.zap_utils import get_price_per_sat
|
||||||
|
|
||||||
|
"""
|
||||||
|
This File contains a Module to generate an Image on replicate and receive results back.
|
||||||
|
|
||||||
|
Accepted Inputs: Prompt (text)
|
||||||
|
Outputs: An url to an Image
|
||||||
|
Params:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenerationMLX(DVMTaskInterface):
|
||||||
|
KIND: int = EventDefinitions.KIND_NIP90_GENERATE_IMAGE
|
||||||
|
TASK: str = "text-to-image"
|
||||||
|
FIX_COST: float = 120
|
||||||
|
dependencies = [("nostr-dvm", "nostr-dvm"),
|
||||||
|
("mlx", "mlx"),
|
||||||
|
("safetensors", "safetensors"),
|
||||||
|
("huggingface-hub", "huggingface-hub"),
|
||||||
|
("regex", "regex"),
|
||||||
|
("tqdm", "tqdm"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, name, dvm_config: DVMConfig, nip89config: NIP89Config,
|
||||||
|
admin_config: AdminConfig = None, options=None):
|
||||||
|
dvm_config.SCRIPT = os.path.abspath(__file__)
|
||||||
|
super().__init__(name, dvm_config, nip89config, admin_config, options)
|
||||||
|
|
||||||
|
def is_input_supported(self, tags):
|
||||||
|
for tag in tags:
|
||||||
|
if tag.as_vec()[0] == 'i':
|
||||||
|
input_value = tag.as_vec()[1]
|
||||||
|
input_type = tag.as_vec()[2]
|
||||||
|
if input_type != "text":
|
||||||
|
return False
|
||||||
|
|
||||||
|
elif tag.as_vec()[0] == 'output':
|
||||||
|
output = tag.as_vec()[1]
|
||||||
|
if (output == "" or
|
||||||
|
not (output == "image/png" or "image/jpg"
|
||||||
|
or output == "image/png;format=url" or output == "image/jpg;format=url")):
|
||||||
|
print("Output format not supported, skipping..")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def create_request_from_nostr_event(self, event, client=None, dvm_config=None):
|
||||||
|
request_form = {"jobID": event.id().to_hex() + "_" + self.NAME.replace(" ", "")}
|
||||||
|
prompt = ""
|
||||||
|
width = "1024"
|
||||||
|
height = "1024"
|
||||||
|
|
||||||
|
for tag in event.tags():
|
||||||
|
if tag.as_vec()[0] == 'i':
|
||||||
|
input_type = tag.as_vec()[2]
|
||||||
|
if input_type == "text":
|
||||||
|
prompt = tag.as_vec()[1]
|
||||||
|
|
||||||
|
elif tag.as_vec()[0] == 'param':
|
||||||
|
print("Param: " + tag.as_vec()[1] + ": " + tag.as_vec()[2])
|
||||||
|
if tag.as_vec()[1] == "size":
|
||||||
|
if len(tag.as_vec()) > 3:
|
||||||
|
width = (tag.as_vec()[2])
|
||||||
|
height = (tag.as_vec()[3])
|
||||||
|
elif len(tag.as_vec()) == 3:
|
||||||
|
split = tag.as_vec()[2].split("x")
|
||||||
|
if len(split) > 1:
|
||||||
|
width = split[0]
|
||||||
|
height = split[1]
|
||||||
|
elif tag.as_vec()[1] == "model":
|
||||||
|
model = tag.as_vec()[2]
|
||||||
|
elif tag.as_vec()[1] == "quality":
|
||||||
|
quality = tag.as_vec()[2]
|
||||||
|
|
||||||
|
options = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"size": width + "x" + height,
|
||||||
|
"number": 1
|
||||||
|
}
|
||||||
|
request_form['options'] = json.dumps(options)
|
||||||
|
|
||||||
|
return request_form
|
||||||
|
|
||||||
|
def process(self, request_form):
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
from backends.mlx.stable_diffusion import StableDiffusion
|
||||||
|
options = DVMTaskInterface.set_options(request_form)
|
||||||
|
|
||||||
|
sd = StableDiffusion()
|
||||||
|
cfg_weight = 7.5
|
||||||
|
batchsize = 1
|
||||||
|
n_rows = 1
|
||||||
|
steps = 50
|
||||||
|
n_images = options["number"]
|
||||||
|
|
||||||
|
# Generate the latent vectors using diffusion
|
||||||
|
latents = sd.generate_latents(
|
||||||
|
options["prompt"],
|
||||||
|
n_images=n_images,
|
||||||
|
cfg_weight=cfg_weight,
|
||||||
|
num_steps=steps,
|
||||||
|
negative_text="",
|
||||||
|
)
|
||||||
|
for x_t in tqdm(latents, total=steps):
|
||||||
|
mx.simplify(x_t)
|
||||||
|
mx.simplify(x_t)
|
||||||
|
mx.eval(x_t)
|
||||||
|
|
||||||
|
# Decode them into images
|
||||||
|
decoded = []
|
||||||
|
for i in tqdm(range(0, 1, batchsize)):
|
||||||
|
decoded.append(sd.decode(x_t[i: i + batchsize]))
|
||||||
|
mx.eval(decoded[-1])
|
||||||
|
|
||||||
|
# Arrange them on a grid
|
||||||
|
x = mx.concatenate(decoded, axis=0)
|
||||||
|
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||||
|
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||||
|
x = (x * 255).astype(mx.uint8)
|
||||||
|
|
||||||
|
# Save them to disc
|
||||||
|
image = Image.fromarray(x.__array__())
|
||||||
|
image.save("./outputs/image.jpg")
|
||||||
|
result = upload_media_to_hoster("./outputs/image.jpg")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Error in Module")
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
|
||||||
|
# We build an example here that we can call by either calling this file directly from the main directory,
|
||||||
|
# or by adding it to our playground. You can call the example and adjust it to your needs or redefine it in the
|
||||||
|
# playground or elsewhere
|
||||||
|
def build_example(name, identifier, admin_config):
|
||||||
|
dvm_config = build_default_config(identifier)
|
||||||
|
admin_config.LUD16 = dvm_config.LN_ADDRESS
|
||||||
|
profit_in_sats = 10
|
||||||
|
dvm_config.FIX_COST = int(((4.0 / (get_price_per_sat("USD") * 100)) + profit_in_sats))
|
||||||
|
|
||||||
|
nip89info = {
|
||||||
|
"name": name,
|
||||||
|
"image": "https://image.nostr.build/c33ca6fc4cc038ca4adb46fdfdfda34951656f87ee364ef59095bae1495ce669.jpg",
|
||||||
|
"about": "I use Replicate to run StableDiffusion XL",
|
||||||
|
"encryptionSupported": True,
|
||||||
|
"cashuAccepted": True,
|
||||||
|
"nip90Params": {
|
||||||
|
"size": {
|
||||||
|
"required": False,
|
||||||
|
"values": ["1024:1024", "1024x1792", "1792x1024"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nip89config = NIP89Config()
|
||||||
|
nip89config.DTAG = check_and_set_d_tag(identifier, name, dvm_config.PRIVATE_KEY, nip89info["image"])
|
||||||
|
nip89config.CONTENT = json.dumps(nip89info)
|
||||||
|
|
||||||
|
return ImageGenerationMLX(name=name, dvm_config=dvm_config, nip89config=nip89config,
|
||||||
|
admin_config=admin_config)
|
||||||
|
|
||||||
|
|
||||||
|
def process_venv():
|
||||||
|
args = DVMTaskInterface.process_args()
|
||||||
|
dvm_config = build_default_config(args.identifier)
|
||||||
|
dvm = ImageGenerationMLX(name="", dvm_config=dvm_config, nip89config=NIP89Config(), admin_config=None)
|
||||||
|
result = dvm.process(json.loads(args.request))
|
||||||
|
DVMTaskInterface.write_output(result, args.output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
process_venv()
|
@@ -300,10 +300,6 @@ def get_price_per_sat(currency):
|
|||||||
|
|
||||||
|
|
||||||
def make_ln_address_nostdress(identifier, npub, pin, nostdressdomain):
|
def make_ln_address_nostdress(identifier, npub, pin, nostdressdomain):
|
||||||
# env_path = Path('.env')
|
|
||||||
# if env_path.is_file():
|
|
||||||
# dotenv.load_dotenv(env_path, verbose=True, override=True)
|
|
||||||
|
|
||||||
print(os.getenv("LNBITS_INVOICE_KEY_" + identifier.upper()))
|
print(os.getenv("LNBITS_INVOICE_KEY_" + identifier.upper()))
|
||||||
data = {
|
data = {
|
||||||
'name': identifier,
|
'name': identifier,
|
||||||
|
10
setup.py
10
setup.py
@@ -1,6 +1,6 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
VERSION = '0.0.9'
|
VERSION = '0.1.0'
|
||||||
DESCRIPTION = 'A framework to build and run Nostr NIP90 Data Vending Machines'
|
DESCRIPTION = 'A framework to build and run Nostr NIP90 Data Vending Machines'
|
||||||
LONG_DESCRIPTION = ('A framework to build and run Nostr NIP90 Data Vending Machines. '
|
LONG_DESCRIPTION = ('A framework to build and run Nostr NIP90 Data Vending Machines. '
|
||||||
'This is an early stage release. Interfaces might change/brick')
|
'This is an early stage release. Interfaces might change/brick')
|
||||||
@@ -13,8 +13,10 @@ setup(
|
|||||||
author_email="believethehypeonnostr@proton.me",
|
author_email="believethehypeonnostr@proton.me",
|
||||||
description=DESCRIPTION,
|
description=DESCRIPTION,
|
||||||
long_description=LONG_DESCRIPTION,
|
long_description=LONG_DESCRIPTION,
|
||||||
packages=find_packages(include=['nostr_dvm', 'nostr_dvm.backends', 'nostr_dvm.interfaces', 'nostr_dvm.tasks',
|
packages=find_packages(include=['nostr_dvm', 'nostr_dvm.interfaces', 'nostr_dvm.tasks',
|
||||||
'nostr_dvm.utils', 'nostr_dvm.utils.scrapper']),
|
'nostr_dvm.utils', 'nostr_dvm.utils.scrapper',
|
||||||
|
'nostr_dvm.backends', 'nostr_dvm.backends.mlx',
|
||||||
|
'nostr_dvm.backends.mlx.stablediffusion']),
|
||||||
install_requires=["nostr-sdk==0.0.5",
|
install_requires=["nostr-sdk==0.0.5",
|
||||||
"bech32==1.2.0",
|
"bech32==1.2.0",
|
||||||
"pycryptodome==3.19.0",
|
"pycryptodome==3.19.0",
|
||||||
@@ -32,7 +34,7 @@ setup(
|
|||||||
"moviepy==2.0.0.dev2",
|
"moviepy==2.0.0.dev2",
|
||||||
"zipp==3.17.0",
|
"zipp==3.17.0",
|
||||||
"urllib3==2.1.0",
|
"urllib3==2.1.0",
|
||||||
"typing_extensions==4.8.0"
|
"typing_extensions>=4.9.0"
|
||||||
],
|
],
|
||||||
keywords=['nostr', 'nip90', 'dvm', 'data vending machine'],
|
keywords=['nostr', 'nip90', 'dvm', 'data vending machine'],
|
||||||
url="https://github.com/believethehype/nostrdvm",
|
url="https://github.com/believethehype/nostrdvm",
|
||||||
|
@@ -151,8 +151,8 @@ def nostr_client():
|
|||||||
#nostr_client_test_translation("This is the result of the DVM in spanish", "text", "es", 20, 20)
|
#nostr_client_test_translation("This is the result of the DVM in spanish", "text", "es", 20, 20)
|
||||||
#nostr_client_test_translation("note1p8cx2dz5ss5gnk7c59zjydcncx6a754c0hsyakjvnw8xwlm5hymsnc23rs", "event", "es", 20,20)
|
#nostr_client_test_translation("note1p8cx2dz5ss5gnk7c59zjydcncx6a754c0hsyakjvnw8xwlm5hymsnc23rs", "event", "es", 20,20)
|
||||||
#nostr_client_test_translation("44a0a8b395ade39d46b9d20038b3f0c8a11168e67c442e3ece95e4a1703e2beb", "event", "zh", 20, 20)
|
#nostr_client_test_translation("44a0a8b395ade39d46b9d20038b3f0c8a11168e67c442e3ece95e4a1703e2beb", "event", "zh", 20, 20)
|
||||||
#nostr_client_test_image("a beautiful purple ostrich watching the sunset")
|
nostr_client_test_image("a beautiful purple ostrich watching the sunset")
|
||||||
nostr_client_test_tts("Hello, this is a test. Mic check one, two.")
|
#nostr_client_test_tts("Hello, this is a test. Mic check one, two.")
|
||||||
|
|
||||||
|
|
||||||
#cashutoken = "cashuAeyJ0b2tlbiI6W3sicHJvb2ZzIjpbeyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6MSwiQyI6IjAyNWU3ODZhOGFkMmExYTg0N2YxMzNiNGRhM2VhMGIyYWRhZGFkOTRiYzA4M2E2NWJjYjFlOTgwYTE1NGIyMDA2NCIsInNlY3JldCI6InQ1WnphMTZKMGY4UElQZ2FKTEg4V3pPck5rUjhESWhGa291LzVzZFd4S0U9In0seyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6NCwiQyI6IjAyOTQxNmZmMTY2MzU5ZWY5ZDc3MDc2MGNjZmY0YzliNTMzMzVmZTA2ZGI5YjBiZDg2Njg5Y2ZiZTIzMjVhYWUwYiIsInNlY3JldCI6IlRPNHB5WE43WlZqaFRQbnBkQ1BldWhncm44UHdUdE5WRUNYWk9MTzZtQXM9In0seyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6MTYsIkMiOiIwMmRiZTA3ZjgwYmMzNzE0N2YyMDJkNTZiMGI3ZTIzZTdiNWNkYTBhNmI3Yjg3NDExZWYyOGRiZDg2NjAzNzBlMWIiLCJzZWNyZXQiOiJHYUNIdHhzeG9HM3J2WWNCc0N3V0YxbU1NVXczK0dDN1RKRnVwOHg1cURzPSJ9XSwibWludCI6Imh0dHBzOi8vbG5iaXRzLmJpdGNvaW5maXhlc3RoaXMub3JnL2Nhc2h1L2FwaS92MS9ScDlXZGdKZjlxck51a3M1eVQ2SG5rIn1dfQ=="
|
#cashutoken = "cashuAeyJ0b2tlbiI6W3sicHJvb2ZzIjpbeyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6MSwiQyI6IjAyNWU3ODZhOGFkMmExYTg0N2YxMzNiNGRhM2VhMGIyYWRhZGFkOTRiYzA4M2E2NWJjYjFlOTgwYTE1NGIyMDA2NCIsInNlY3JldCI6InQ1WnphMTZKMGY4UElQZ2FKTEg4V3pPck5rUjhESWhGa291LzVzZFd4S0U9In0seyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6NCwiQyI6IjAyOTQxNmZmMTY2MzU5ZWY5ZDc3MDc2MGNjZmY0YzliNTMzMzVmZTA2ZGI5YjBiZDg2Njg5Y2ZiZTIzMjVhYWUwYiIsInNlY3JldCI6IlRPNHB5WE43WlZqaFRQbnBkQ1BldWhncm44UHdUdE5WRUNYWk9MTzZtQXM9In0seyJpZCI6InZxc1VRSVorb0sxOSIsImFtb3VudCI6MTYsIkMiOiIwMmRiZTA3ZjgwYmMzNzE0N2YyMDJkNTZiMGI3ZTIzZTdiNWNkYTBhNmI3Yjg3NDExZWYyOGRiZDg2NjAzNzBlMWIiLCJzZWNyZXQiOiJHYUNIdHhzeG9HM3J2WWNCc0N3V0YxbU1NVXczK0dDN1RKRnVwOHg1cURzPSJ9XSwibWludCI6Imh0dHBzOi8vbG5iaXRzLmJpdGNvaW5maXhlc3RoaXMub3JnL2Nhc2h1L2FwaS92MS9ScDlXZGdKZjlxck51a3M1eVQ2SG5rIn1dfQ=="
|
||||||
|
Reference in New Issue
Block a user