mirror of
https://github.com/believethehype/nostrdvm.git
synced 2025-06-11 03:00:48 +02:00
fixes related to new nostr-sdk, fix mlx example
This commit is contained in:
parent
ae7d869474
commit
98c5826c22
8
main.py
8
main.py
@ -140,6 +140,14 @@ def playground():
|
|||||||
bot_config.SUPPORTED_DVMS.append(tts)
|
bot_config.SUPPORTED_DVMS.append(tts)
|
||||||
tts.run()
|
tts.run()
|
||||||
|
|
||||||
|
search = advanced_search.build_example("Advanced Search", "discovery_content_search", admin_config)
|
||||||
|
bot_config.SUPPORTED_DVMS.append(search)
|
||||||
|
search.run()
|
||||||
|
|
||||||
|
|
||||||
|
inactive = discovery_inactive_follows.build_example("Inactive People you follow", "discovery_inactive_follows", admin_config)
|
||||||
|
bot_config.SUPPORTED_DVMS.append(inactive)
|
||||||
|
inactive.run()
|
||||||
|
|
||||||
if platform == "darwin":
|
if platform == "darwin":
|
||||||
# Test with MLX for OSX M1/M2/M3 chips
|
# Test with MLX for OSX M1/M2/M3 chips
|
||||||
|
@ -0,0 +1,97 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .model_io import (
|
||||||
|
_DEFAULT_MODEL,
|
||||||
|
load_autoencoder,
|
||||||
|
load_diffusion_config,
|
||||||
|
load_text_encoder,
|
||||||
|
load_tokenizer,
|
||||||
|
load_unet,
|
||||||
|
)
|
||||||
|
from .sampler import SimpleEulerSampler
|
||||||
|
|
||||||
|
|
||||||
|
def _repeat(x, n, axis):
|
||||||
|
# Make the expanded shape
|
||||||
|
s = x.shape
|
||||||
|
s.insert(axis + 1, n)
|
||||||
|
|
||||||
|
# Expand
|
||||||
|
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
|
||||||
|
|
||||||
|
# Make the flattened shape
|
||||||
|
s.pop(axis + 1)
|
||||||
|
s[axis] *= n
|
||||||
|
|
||||||
|
return x.reshape(s)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion:
|
||||||
|
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
|
||||||
|
self.dtype = mx.float16 if float16 else mx.float32
|
||||||
|
self.diffusion_config = load_diffusion_config(model)
|
||||||
|
self.unet = load_unet(model, float16)
|
||||||
|
self.text_encoder = load_text_encoder(model, float16)
|
||||||
|
self.autoencoder = load_autoencoder(model, float16)
|
||||||
|
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||||
|
self.tokenizer = load_tokenizer(model)
|
||||||
|
|
||||||
|
def generate_latents(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
n_images: int = 1,
|
||||||
|
num_steps: int = 50,
|
||||||
|
cfg_weight: float = 7.5,
|
||||||
|
negative_text: str = "",
|
||||||
|
latent_size: Tuple[int] = (64, 64),
|
||||||
|
seed=None,
|
||||||
|
):
|
||||||
|
# Set the PRNG state
|
||||||
|
seed = seed or int(time.time())
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
# Tokenize the text
|
||||||
|
tokens = [self.tokenizer.tokenize(text)]
|
||||||
|
if cfg_weight > 1:
|
||||||
|
tokens += [self.tokenizer.tokenize(negative_text)]
|
||||||
|
lengths = [len(t) for t in tokens]
|
||||||
|
N = max(lengths)
|
||||||
|
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||||
|
tokens = mx.array(tokens)
|
||||||
|
|
||||||
|
# Compute the features
|
||||||
|
conditioning = self.text_encoder(tokens)
|
||||||
|
|
||||||
|
# Repeat the conditioning for each of the generated images
|
||||||
|
if n_images > 1:
|
||||||
|
conditioning = _repeat(conditioning, n_images, axis=0)
|
||||||
|
|
||||||
|
# Create the latent variables
|
||||||
|
x_T = self.sampler.sample_prior(
|
||||||
|
(n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform the denoising loop
|
||||||
|
x_t = x_T
|
||||||
|
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
|
||||||
|
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
|
||||||
|
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
|
||||||
|
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
|
||||||
|
|
||||||
|
if cfg_weight > 1:
|
||||||
|
eps_text, eps_neg = eps_pred.split(2)
|
||||||
|
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
|
||||||
|
|
||||||
|
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
|
||||||
|
x_t = x_t_prev
|
||||||
|
yield x_t
|
||||||
|
|
||||||
|
def decode(self, x_t):
|
||||||
|
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
|
||||||
|
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
|
||||||
|
return x
|
@ -3,7 +3,7 @@ import os
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options
|
from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey
|
||||||
|
|
||||||
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface
|
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface
|
||||||
from nostr_dvm.utils.admin_utils import AdminConfig
|
from nostr_dvm.utils.admin_utils import AdminConfig
|
||||||
@ -38,7 +38,6 @@ class DiscoverInactiveFollows(DVMTaskInterface):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def create_request_from_nostr_event(self, event, client=None, dvm_config=None):
|
def create_request_from_nostr_event(self, event, client=None, dvm_config=None):
|
||||||
self.client = client
|
|
||||||
self.dvm_config = dvm_config
|
self.dvm_config = dvm_config
|
||||||
|
|
||||||
request_form = {"jobID": event.id().to_hex()}
|
request_form = {"jobID": event.id().to_hex()}
|
||||||
@ -67,11 +66,19 @@ class DiscoverInactiveFollows(DVMTaskInterface):
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
ns = SimpleNamespace()
|
ns = SimpleNamespace()
|
||||||
|
|
||||||
|
opts = (Options().wait_for_send(False).send_timeout(timedelta(seconds=self.dvm_config.RELAY_TIMEOUT)))
|
||||||
|
sk = SecretKey.from_hex(self.dvm_config.PRIVATE_KEY)
|
||||||
|
keys = Keys.from_sk_str(sk.to_hex())
|
||||||
|
cli = Client.with_opts(keys, opts)
|
||||||
|
for relay in self.dvm_config.RELAY_LIST:
|
||||||
|
cli.add_relay(relay)
|
||||||
|
cli.connect()
|
||||||
|
|
||||||
options = DVMTaskInterface.set_options(request_form)
|
options = DVMTaskInterface.set_options(request_form)
|
||||||
step = 20
|
step = 20
|
||||||
|
|
||||||
followers_filter = Filter().author(PublicKey.from_hex(options["user"])).kind(3).limit(1)
|
followers_filter = Filter().author(PublicKey.from_hex(options["user"])).kind(3).limit(1)
|
||||||
followers = self.client.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))
|
followers = cli.get_events_of([followers_filter], timedelta(seconds=self.dvm_config.RELAY_TIMEOUT))
|
||||||
|
|
||||||
if len(followers) > 0:
|
if len(followers) > 0:
|
||||||
result_list = []
|
result_list = []
|
||||||
|
@ -8,7 +8,7 @@ import requests
|
|||||||
from Crypto.Cipher import AES
|
from Crypto.Cipher import AES
|
||||||
from Crypto.Util.Padding import pad
|
from Crypto.Util.Padding import pad
|
||||||
from bech32 import bech32_decode, convertbits, bech32_encode
|
from bech32 import bech32_decode, convertbits, bech32_encode
|
||||||
from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys
|
from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, Keys, generate_shared_key
|
||||||
|
|
||||||
from nostr_dvm.utils.nostr_utils import get_event_by_id, check_and_decrypt_own_tags
|
from nostr_dvm.utils.nostr_utils import get_event_by_id, check_and_decrypt_own_tags
|
||||||
import lnurl
|
import lnurl
|
||||||
@ -200,7 +200,7 @@ def check_for_zapplepay(pubkey_hex: str, content: str):
|
|||||||
|
|
||||||
def enrypt_private_zap_message(message, privatekey, publickey):
|
def enrypt_private_zap_message(message, privatekey, publickey):
|
||||||
# Generate a random IV
|
# Generate a random IV
|
||||||
shared_secret = nostr_sdk.generate_shared_key(privatekey, publickey)
|
shared_secret = generate_shared_key(privatekey, publickey)
|
||||||
iv = os.urandom(16)
|
iv = os.urandom(16)
|
||||||
|
|
||||||
# Encrypt the message
|
# Encrypt the message
|
||||||
@ -215,7 +215,7 @@ def enrypt_private_zap_message(message, privatekey, publickey):
|
|||||||
|
|
||||||
|
|
||||||
def decrypt_private_zap_message(msg: str, privkey: SecretKey, pubkey: PublicKey):
|
def decrypt_private_zap_message(msg: str, privkey: SecretKey, pubkey: PublicKey):
|
||||||
shared_secret = nostr_sdk.generate_shared_key(privkey, pubkey)
|
shared_secret = generate_shared_key(privkey, pubkey)
|
||||||
if len(shared_secret) != 16 and len(shared_secret) != 32:
|
if len(shared_secret) != 16 and len(shared_secret) != 32:
|
||||||
return "invalid shared secret size"
|
return "invalid shared secret size"
|
||||||
parts = msg.split("_")
|
parts = msg.split("_")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user