From 98c5826c2236d1bbb529505b301f2b116c01c28d Mon Sep 17 00:00:00 2001 From: Believethehype Date: Thu, 21 Dec 2023 16:31:58 +0100 Subject: [PATCH] fixes related to new nostr-sdk, fix mlx example --- main.py | 8 ++ .../mlx/modules/stable_diffusion/__init__.py | 97 +++++++++++++++++++ nostr_dvm/tasks/discovery_inactive_follows.py | 13 ++- nostr_dvm/utils/zap_utils.py | 6 +- 4 files changed, 118 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 7f410cd..f10f666 100644 --- a/main.py +++ b/main.py @@ -140,6 +140,14 @@ def playground(): bot_config.SUPPORTED_DVMS.append(tts) tts.run() + search = advanced_search.build_example("Advanced Search", "discovery_content_search", admin_config) + bot_config.SUPPORTED_DVMS.append(search) + search.run() + + + inactive = discovery_inactive_follows.build_example("Inactive People you follow", "discovery_inactive_follows", admin_config) + bot_config.SUPPORTED_DVMS.append(inactive) + inactive.run() if platform == "darwin": # Test with MLX for OSX M1/M2/M3 chips diff --git a/nostr_dvm/backends/mlx/modules/stable_diffusion/__init__.py b/nostr_dvm/backends/mlx/modules/stable_diffusion/__init__.py index e69de29..079e10f 100644 --- a/nostr_dvm/backends/mlx/modules/stable_diffusion/__init__.py +++ b/nostr_dvm/backends/mlx/modules/stable_diffusion/__init__.py @@ -0,0 +1,97 @@ +# Copyright © 2023 Apple Inc. + +import time +from typing import Tuple + +import mlx.core as mx + +from .model_io import ( + _DEFAULT_MODEL, + load_autoencoder, + load_diffusion_config, + load_text_encoder, + load_tokenizer, + load_unet, +) +from .sampler import SimpleEulerSampler + + +def _repeat(x, n, axis): + # Make the expanded shape + s = x.shape + s.insert(axis + 1, n) + + # Expand + x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s) + + # Make the flattened shape + s.pop(axis + 1) + s[axis] *= n + + return x.reshape(s) + + +class StableDiffusion: + def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False): + self.dtype = mx.float16 if float16 else mx.float32 + self.diffusion_config = load_diffusion_config(model) + self.unet = load_unet(model, float16) + self.text_encoder = load_text_encoder(model, float16) + self.autoencoder = load_autoencoder(model, float16) + self.sampler = SimpleEulerSampler(self.diffusion_config) + self.tokenizer = load_tokenizer(model) + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 50, + cfg_weight: float = 7.5, + negative_text: str = "", + latent_size: Tuple[int] = (64, 64), + seed=None, + ): + # Set the PRNG state + seed = seed or int(time.time()) + mx.random.seed(seed) + + # Tokenize the text + tokens = [self.tokenizer.tokenize(text)] + if cfg_weight > 1: + tokens += [self.tokenizer.tokenize(negative_text)] + lengths = [len(t) for t in tokens] + N = max(lengths) + tokens = [t + [0] * (N - len(t)) for t in tokens] + tokens = mx.array(tokens) + + # Compute the features + conditioning = self.text_encoder(tokens) + + # Repeat the conditioning for each of the generated images + if n_images > 1: + conditioning = _repeat(conditioning, n_images, axis=0) + + # Create the latent variables + x_T = self.sampler.sample_prior( + (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype + ) + + # Perform the denoising loop + x_t = x_T + for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype): + x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t + t_unet = mx.broadcast_to(t, [len(x_t_unet)]) + eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning) + + if cfg_weight > 1: + eps_text, eps_neg = eps_pred.split(2) + eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg) + + x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev) + x_t = x_t_prev + yield x_t + + def decode(self, x_t): + x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor) + x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5)) + return x \ No newline at end of file diff --git a/nostr_dvm/tasks/discovery_inactive_follows.py b/nostr_dvm/tasks/discovery_inactive_follows.py index 5dd2e47..964c0ef 100644 --- a/nostr_dvm/tasks/discovery_inactive_follows.py +++ b/nostr_dvm/tasks/discovery_inactive_follows.py @@ -3,7 +3,7 @@ import os from datetime import timedelta from threading import Thread -from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options +from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface from nostr_dvm.utils.admin_utils import AdminConfig @@ -38,7 +38,6 @@ class DiscoverInactiveFollows(DVMTaskInterface): return True def create_request_from_nostr_event(self, event, client=None, dvm_config=None): - self.client = client self.dvm_config = dvm_config request_form = {"jobID": event.id().to_hex()} @@ -67,11 +66,19 @@ class DiscoverInactiveFollows(DVMTaskInterface): from types import SimpleNamespace ns = SimpleNamespace() + opts = (Options().wait_for_send(False).send_timeout(timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))) + sk = SecretKey.from_hex(self.dvm_config.PRIVATE_KEY) + keys = Keys.from_sk_str(sk.to_hex()) + cli = Client.with_opts(keys, opts) + for relay in self.dvm_config.RELAY_LIST: + cli.add_relay(relay) + cli.connect() + options = DVMTaskInterface.set_options(request_form) step = 20 followers_filter = Filter().author(PublicKey.from_hex(options["user"])).kind(3).limit(1) - followers = self.client.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT)) + followers = cli.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT)) if len(followers) > 0: result_list = [] diff --git a/nostr_dvm/utils/zap_utils.py b/nostr_dvm/utils/zap_utils.py index 59c3a2f..fa1b6ae 100644 --- a/nostr_dvm/utils/zap_utils.py +++ b/nostr_dvm/utils/zap_utils.py @@ -8,7 +8,7 @@ import requests from Crypto.Cipher import AES from Crypto.Util.Padding import pad from bech32 import bech32_decode, convertbits, bech32_encode -from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys +from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys, generate_shared_key from nostr_dvm.utils.nostr_utils import get_event_by_id, check_and_decrypt_own_tags import lnurl @@ -200,7 +200,7 @@ def check_for_zapplepay(pubkey_hex: str, content: str): def enrypt_private_zap_message(message, privatekey, publickey): # Generate a random IV - shared_secret = nostr_sdk.generate_shared_key(privatekey, publickey) + shared_secret = generate_shared_key(privatekey, publickey) iv = os.urandom(16) # Encrypt the message @@ -215,7 +215,7 @@ def enrypt_private_zap_message(message, privatekey, publickey): def decrypt_private_zap_message(msg: str, privkey: SecretKey, pubkey: PublicKey): - shared_secret = nostr_sdk.generate_shared_key(privkey, pubkey) + shared_secret = generate_shared_key(privkey, pubkey) if len(shared_secret) != 16 and len(shared_secret) != 32: return "invalid shared secret size" parts = msg.split("_")