fixes related to new nostr-sdk, fix mlx example

This commit is contained in:
Believethehype 2023-12-21 16:31:58 +01:00
parent ae7d869474
commit 98c5826c22
4 changed files with 118 additions and 6 deletions

View File

@@ -140,6 +140,14 @@ def playground():
    bot_config.SUPPORTED_DVMS.append(tts)
    tts.run()

    search = advanced_search.build_example("Advanced Search", "discovery_content_search", admin_config)
    bot_config.SUPPORTED_DVMS.append(search)
    search.run()

    inactive = discovery_inactive_follows.build_example("Inactive People you follow", "discovery_inactive_follows", admin_config)
    bot_config.SUPPORTED_DVMS.append(inactive)
    inactive.run()

    if platform == "darwin":
        # Test with MLX for OSX M1/M2/M3 chips

View File

@@ -0,0 +1,97 @@
# Copyright © 2023 Apple Inc.

import time
from typing import Tuple

import mlx.core as mx

from .model_io import (
    _DEFAULT_MODEL,
    load_autoencoder,
    load_diffusion_config,
    load_text_encoder,
    load_tokenizer,
    load_unet,
)
from .sampler import SimpleEulerSampler


def _repeat(x, n, axis):
    # Make the expanded shape
    s = x.shape
    s.insert(axis + 1, n)

    # Expand
    x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)

    # Make the flattened shape
    s.pop(axis + 1)
    s[axis] *= n

    return x.reshape(s)


class StableDiffusion:
    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
        self.dtype = mx.float16 if float16 else mx.float32
        self.diffusion_config = load_diffusion_config(model)
        self.unet = load_unet(model, float16)
        self.text_encoder = load_text_encoder(model, float16)
        self.autoencoder = load_autoencoder(model, float16)
        self.sampler = SimpleEulerSampler(self.diffusion_config)
        self.tokenizer = load_tokenizer(model)

    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 50,
        cfg_weight: float = 7.5,
        negative_text: str = "",
        latent_size: Tuple[int] = (64, 64),
        seed=None,
    ):
        # Set the PRNG state
        seed = seed or int(time.time())
        mx.random.seed(seed)

        # Tokenize the text
        tokens = [self.tokenizer.tokenize(text)]
        if cfg_weight > 1:
            tokens += [self.tokenizer.tokenize(negative_text)]
        lengths = [len(t) for t in tokens]
        N = max(lengths)
        tokens = [t + [0] * (N - len(t)) for t in tokens]
        tokens = mx.array(tokens)

        # Compute the features
        conditioning = self.text_encoder(tokens)

        # Repeat the conditioning for each of the generated images
        if n_images > 1:
            conditioning = _repeat(conditioning, n_images, axis=0)

        # Create the latent variables
        x_T = self.sampler.sample_prior(
            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
        )

        # Perform the denoising loop
        x_t = x_T
        for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
            x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
            t_unet = mx.broadcast_to(t, [len(x_t_unet)])
            eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)

            if cfg_weight > 1:
                eps_text, eps_neg = eps_pred.split(2)
                eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)

            x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
            x_t = x_t_prev
            yield x_t

    def decode(self, x_t):
        x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
        x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
        return x
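The new file above is consumed as a generator: generate_latents yields one latent per denoising step and decode maps the final latent back to pixel space. A minimal usage sketch, assuming the module is importable as stable_diffusion and that numpy and PIL are available; the module path, prompt, and output handling are illustrative and not part of this commit:

# Usage sketch, not part of the diff: driving the generator-style API above.
import mlx.core as mx
import numpy as np
from PIL import Image

from stable_diffusion import StableDiffusion  # assumed import path

sd = StableDiffusion(float16=True)

latents = None
for latents in sd.generate_latents("a photo of an astronaut riding a horse",
                                   n_images=1, num_steps=50, cfg_weight=7.5):
    mx.eval(latents)  # force each denoising step to run

image = sd.decode(latents)  # floats in [0, 1], NHWC
mx.eval(image)
Image.fromarray((np.array(image[0]) * 255).astype(np.uint8)).save("out.png")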

View File

@@ -3,7 +3,7 @@ import os
from datetime import timedelta
from threading import Thread
-from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options
+from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface
from nostr_dvm.utils.admin_utils import AdminConfig
@@ -38,7 +38,6 @@ class DiscoverInactiveFollows(DVMTaskInterface):
        return True

    def create_request_from_nostr_event(self, event, client=None, dvm_config=None):
        self.client = client
        self.dvm_config = dvm_config
        request_form = {"jobID": event.id().to_hex()}
@@ -67,11 +66,19 @@ class DiscoverInactiveFollows(DVMTaskInterface):
        from types import SimpleNamespace
        ns = SimpleNamespace()

        opts = (Options().wait_for_send(False).send_timeout(timedelta(seconds=self.dvm_config.RELAY_TIMEOUT)))
+        sk = SecretKey.from_hex(self.dvm_config.PRIVATE_KEY)
+        keys = Keys.from_sk_str(sk.to_hex())
+        cli = Client.with_opts(keys, opts)
+        for relay in self.dvm_config.RELAY_LIST:
+            cli.add_relay(relay)
+        cli.connect()

        options = DVMTaskInterface.set_options(request_form)
        step = 20

        followers_filter = Filter().author(PublicKey.from_hex(options["user"])).kind(3).limit(1)
-        followers = self.client.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))
+        followers = cli.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))

        if len(followers) > 0:
            result_list = []
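The nostr-sdk-related fix here is the pattern of deriving keys from the configured private key and building a dedicated client with the relay options, then querying with that client instead of the request-time self.client. A condensed sketch of the same pattern; the helper name fetch_contact_list is illustrative, while the calls and config fields (PRIVATE_KEY, RELAY_LIST, RELAY_TIMEOUT) mirror the lines above:

# Sketch of the per-task client setup shown in this hunk; not part of the diff itself.
from datetime import timedelta
from nostr_sdk import Client, Filter, Keys, Options, PublicKey, SecretKey

def fetch_contact_list(dvm_config, user_hex: str):
    opts = Options().wait_for_send(False).send_timeout(timedelta(seconds=dvm_config.RELAY_TIMEOUT))
    sk = SecretKey.from_hex(dvm_config.PRIVATE_KEY)
    keys = Keys.from_sk_str(sk.to_hex())   # key construction for the newer nostr-sdk
    cli = Client.with_opts(keys, opts)     # client owned by this task, not reused
    for relay in dvm_config.RELAY_LIST:
        cli.add_relay(relay)
    cli.connect()
    flt = Filter().author(PublicKey.from_hex(user_hex)).kind(3).limit(1)  # kind 3 = contact list
    return cli.get_events_of([flt], timedelta(seconds=dvm_config.RELAY_TIMEOUT))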

View File

@@ -8,7 +8,7 @@ import requests
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from bech32 import bech32_decode, convertbits, bech32_encode
-from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys
+from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys, generate_shared_key
from nostr_dvm.utils.nostr_utils import get_event_by_id, check_and_decrypt_own_tags
import lnurl
@@ -200,7 +200,7 @@ def check_for_zapplepay(pubkey_hex: str, content: str):
def enrypt_private_zap_message(message, privatekey, publickey):
    # Generate a random IV
-    shared_secret = nostr_sdk.generate_shared_key(privatekey, publickey)
+    shared_secret = generate_shared_key(privatekey, publickey)
    iv = os.urandom(16)

    # Encrypt the message
@@ -215,7 +215,7 @@ def enrypt_private_zap_message(message, privatekey, publickey):
def decrypt_private_zap_message(msg: str, privkey: SecretKey, pubkey: PublicKey):
-    shared_secret = nostr_sdk.generate_shared_key(privkey, pubkey)
+    shared_secret = generate_shared_key(privkey, pubkey)
    if len(shared_secret) != 16 and len(shared_secret) != 32:
        return "invalid shared secret size"
    parts = msg.split("_")
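In both functions the only change is that generate_shared_key is now imported directly instead of being reached through the nostr_sdk module object. The primitive is a symmetric ECDH secret, which is what lets the encrypt and decrypt sides agree on an AES key. A standalone round-trip sketch under that assumption, reusing the pycryptodome imports from this file; it shows the primitive only, not nostr-dvm's actual private-zap message format:

# Round-trip sketch of the shared-key primitive; illustrative, not the zap wire format.
import os
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from nostr_sdk import Keys, generate_shared_key

alice = Keys.generate()
bob = Keys.generate()

# ECDH is symmetric: both sides derive the same secret from (own secret key, other public key).
k_ab = generate_shared_key(alice.secret_key(), bob.public_key())
k_ba = generate_shared_key(bob.secret_key(), alice.public_key())
assert bytes(k_ab) == bytes(k_ba)

iv = os.urandom(16)
ciphertext = AES.new(bytes(k_ab), AES.MODE_CBC, iv).encrypt(pad(b"private zap note", 16))
plaintext = unpad(AES.new(bytes(k_ba), AES.MODE_CBC, iv).decrypt(ciphertext), 16)
assert plaintext == b"private zap note"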