mirror of
https://github.com/believethehype/nostrdvm.git
synced 2025-03-18 05:41:51 +01:00
fixes related to new nostr-sdk, fix mlx example
This commit is contained in:
parent
ae7d869474
commit
98c5826c22
8
main.py
8
main.py
@ -140,6 +140,14 @@ def playground():
|
||||
bot_config.SUPPORTED_DVMS.append(tts)
|
||||
tts.run()
|
||||
|
||||
search = advanced_search.build_example("Advanced Search", "discovery_content_search", admin_config)
|
||||
bot_config.SUPPORTED_DVMS.append(search)
|
||||
search.run()
|
||||
|
||||
|
||||
inactive = discovery_inactive_follows.build_example("Inactive People you follow", "discovery_inactive_follows", admin_config)
|
||||
bot_config.SUPPORTED_DVMS.append(inactive)
|
||||
inactive.run()
|
||||
|
||||
if platform == "darwin":
|
||||
# Test with MLX for OSX M1/M2/M3 chips
|
||||
|
@ -0,0 +1,97 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .model_io import (
|
||||
_DEFAULT_MODEL,
|
||||
load_autoencoder,
|
||||
load_diffusion_config,
|
||||
load_text_encoder,
|
||||
load_tokenizer,
|
||||
load_unet,
|
||||
)
|
||||
from .sampler import SimpleEulerSampler
|
||||
|
||||
|
||||
def _repeat(x, n, axis):
|
||||
# Make the expanded shape
|
||||
s = x.shape
|
||||
s.insert(axis + 1, n)
|
||||
|
||||
# Expand
|
||||
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
|
||||
|
||||
# Make the flattened shape
|
||||
s.pop(axis + 1)
|
||||
s[axis] *= n
|
||||
|
||||
return x.reshape(s)
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
self.dtype = mx.float16 if float16 else mx.float32
|
||||
self.diffusion_config = load_diffusion_config(model)
|
||||
self.unet = load_unet(model, float16)
|
||||
self.text_encoder = load_text_encoder(model, float16)
|
||||
self.autoencoder = load_autoencoder(model, float16)
|
||||
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||
self.tokenizer = load_tokenizer(model)
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 50,
|
||||
cfg_weight: float = 7.5,
|
||||
negative_text: str = "",
|
||||
latent_size: Tuple[int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
seed = seed or int(time.time())
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Tokenize the text
|
||||
tokens = [self.tokenizer.tokenize(text)]
|
||||
if cfg_weight > 1:
|
||||
tokens += [self.tokenizer.tokenize(negative_text)]
|
||||
lengths = [len(t) for t in tokens]
|
||||
N = max(lengths)
|
||||
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||
tokens = mx.array(tokens)
|
||||
|
||||
# Compute the features
|
||||
conditioning = self.text_encoder(tokens)
|
||||
|
||||
# Repeat the conditioning for each of the generated images
|
||||
if n_images > 1:
|
||||
conditioning = _repeat(conditioning, n_images, axis=0)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior(
|
||||
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
|
||||
)
|
||||
|
||||
# Perform the denoising loop
|
||||
x_t = x_T
|
||||
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
|
||||
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
|
||||
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
|
||||
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
|
||||
|
||||
if cfg_weight > 1:
|
||||
eps_text, eps_neg = eps_pred.split(2)
|
||||
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
|
||||
|
||||
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
|
||||
x_t = x_t_prev
|
||||
yield x_t
|
||||
|
||||
def decode(self, x_t):
|
||||
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
|
||||
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
|
||||
return x
|
@ -3,7 +3,7 @@ import os
|
||||
from datetime import timedelta
|
||||
from threading import Thread
|
||||
|
||||
from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options
|
||||
from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey
|
||||
|
||||
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface
|
||||
from nostr_dvm.utils.admin_utils import AdminConfig
|
||||
@ -38,7 +38,6 @@ class DiscoverInactiveFollows(DVMTaskInterface):
|
||||
return True
|
||||
|
||||
def create_request_from_nostr_event(self, event, client=None, dvm_config=None):
|
||||
self.client = client
|
||||
self.dvm_config = dvm_config
|
||||
|
||||
request_form = {"jobID": event.id().to_hex()}
|
||||
@ -67,11 +66,19 @@ class DiscoverInactiveFollows(DVMTaskInterface):
|
||||
from types import SimpleNamespace
|
||||
ns = SimpleNamespace()
|
||||
|
||||
opts = (Options().wait_for_send(False).send_timeout(timedelta(seconds=self.dvm_config.RELAY_TIMEOUT)))
|
||||
sk = SecretKey.from_hex(self.dvm_config.PRIVATE_KEY)
|
||||
keys = Keys.from_sk_str(sk.to_hex())
|
||||
cli = Client.with_opts(keys, opts)
|
||||
for relay in self.dvm_config.RELAY_LIST:
|
||||
cli.add_relay(relay)
|
||||
cli.connect()
|
||||
|
||||
options = DVMTaskInterface.set_options(request_form)
|
||||
step = 20
|
||||
|
||||
followers_filter = Filter().author(PublicKey.from_hex(options["user"])).kind(3).limit(1)
|
||||
followers = self.client.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))
|
||||
followers = cli.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))
|
||||
|
||||
if len(followers) > 0:
|
||||
result_list = []
|
||||
|
@ -8,7 +8,7 @@ import requests
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad
|
||||
from bech32 import bech32_decode, convertbits, bech32_encode
|
||||
from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys
|
||||
from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys, generate_shared_key
|
||||
|
||||
from nostr_dvm.utils.nostr_utils import get_event_by_id, check_and_decrypt_own_tags
|
||||
import lnurl
|
||||
@ -200,7 +200,7 @@ def check_for_zapplepay(pubkey_hex: str, content: str):
|
||||
|
||||
def enrypt_private_zap_message(message, privatekey, publickey):
|
||||
# Generate a random IV
|
||||
shared_secret = nostr_sdk.generate_shared_key(privatekey, publickey)
|
||||
shared_secret = generate_shared_key(privatekey, publickey)
|
||||
iv = os.urandom(16)
|
||||
|
||||
# Encrypt the message
|
||||
@ -215,7 +215,7 @@ def enrypt_private_zap_message(message, privatekey, publickey):
|
||||
|
||||
|
||||
def decrypt_private_zap_message(msg: str, privkey: SecretKey, pubkey: PublicKey):
|
||||
shared_secret = nostr_sdk.generate_shared_key(privkey, pubkey)
|
||||
shared_secret = generate_shared_key(privkey, pubkey)
|
||||
if len(shared_secret) != 16 and len(shared_secret) != 32:
|
||||
return "invalid shared secret size"
|
||||
parts = msg.split("_")
|
||||
|
Loading…
x
Reference in New Issue
Block a user