This commit is contained in:
Believethehype 2024-10-11 10:11:06 +02:00
parent 1f2f692ced
commit c63292e503
5 changed files with 61 additions and 72 deletions

View File

@ -1,31 +1,34 @@
"""StableDiffusionXL Module
"""
import gc
import sys
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from ssl import Options
from nova_utils.interfaces.server_module import Processor
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline, logging
from compel import Compel, ReturnedEmbeddingsType
from nova_utils.utils.cache_utils import get_file
import numpy as np
PYTORCH_ENABLE_MPS_FALLBACK = 1
import torch
from PIL import Image
from lora import build_lora_xl
logging.disable_progress_bar()
logging.enable_explicit_format()
#logging.set_verbosity_info()
# logging.set_verbosity_info()
# Setting defaults
_default_options = {"model": "stabilityai/stable-diffusion-xl-base-1.0", "ratio": "1-1", "width": "", "height":"", "high_noise_frac" : "0.8", "n_steps" : "35", "lora" : "" }
_default_options = {"model": "stabilityai/stable-diffusion-xl-base-1.0", "ratio": "1-1", "width": "", "height": "",
"high_noise_frac": "0.8", "n_steps": "35", "lora": ""}
# TODO: add log infos,
# TODO: add log infos,
class StableDiffusionXL(Processor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -33,7 +36,6 @@ class StableDiffusionXL(Processor):
self.device = None
self.ds_iter = None
self.current_session = None
# IO shortcuts
self.input = [x for x in self.model_io if x.io_type == "input"]
@ -47,7 +49,7 @@ class StableDiffusionXL(Processor):
self.torch_d_type = torch.float16
self.ds_iter = ds_iter
current_session_name = self.ds_iter.session_names[0]
self.current_session = self.ds_iter.sessions[current_session_name]['manager']
self.current_session = self.ds_iter.sessions[current_session_name]['manager']
input_prompt = self.current_session.input_data['input_prompt'].data
input_prompt = ' '.join(input_prompt)
negative_prompt = self.current_session.input_data['negative_prompt'].data
@ -60,13 +62,13 @@ class StableDiffusionXL(Processor):
try:
if self.options['width'] != "" and self.options['height'] != "":
new_width = int(self.options['width'])
new_height = int(self.options['height'])
ratiow, ratioh = self.calculate_aspect(new_width, new_height)
new_height = int(self.options['height'])
ratiow, ratioh = self.calculate_aspect(new_width, new_height)
print("Ratio:" + str(ratiow) + ":" + str(ratioh))
else:
ratiow = str(self.options['ratio']).split('-')[0]
ratioh =str(self.options['ratio']).split('-')[1]
ratioh = str(self.options['ratio']).split('-')[1]
model = self.options["model"]
lora = self.options["lora"]
@ -77,30 +79,28 @@ class StableDiffusionXL(Processor):
width = mwidth
ratiown = int(ratiow)
ratiohn= int(ratioh)
ratiohn = int(ratioh)
if ratiown > ratiohn:
height = int((ratiohn/ratiown) * float(width))
height = int((ratiohn / ratiown) * float(width))
elif ratiown < ratiohn:
width = int((ratiown/ratiohn) * float(height))
width = int((ratiown / ratiohn) * float(height))
elif ratiown == ratiohn:
width = height
print("Processing Output width: " + str(width) + " Output height: " + str(height))
if model == "stabilityai/stable-diffusion-xl-base-1.0":
base = StableDiffusionXLPipeline.from_pretrained(model, torch_dtype=self.torch_d_type, variant=self.variant, use_safetensors=True).to(self.device)
base = StableDiffusionXLPipeline.from_pretrained(model, torch_dtype=self.torch_d_type,
variant=self.variant, use_safetensors=True).to(
self.device)
print("Loaded model: " + model)
else:
model_uri = [ x for x in self.trainer.meta_uri if x.uri_id == model][0]
model_uri = [x for x in self.trainer.meta_uri if x.uri_id == model][0]
if str(model_uri) == "":
return "Model not found"
return "Model not found"
model_path = get_file(
fname=str(model_uri.uri_id) + ".safetensors",
@ -108,43 +108,43 @@ class StableDiffusionXL(Processor):
file_hash=model_uri.uri_hash,
cache_dir=os.getenv("CACHE_DIR"),
tmp_dir=os.getenv("TMP_DIR"),
)
)
print(str(model_path))
base = StableDiffusionXLPipeline.from_single_file(str(model_path), torch_dtype=self.torch_d_type, variant=self.variant, use_safetensors=True).to(self.device)
base = StableDiffusionXLPipeline.from_single_file(str(model_path), torch_dtype=self.torch_d_type,
variant=self.variant, use_safetensors=True).to(
self.device)
print("Loaded model: " + model)
if lora != "" and lora != "None":
print("Loading lora...")
lora, input_prompt, existing_lora = build_lora_xl(lora, input_prompt, "")
print("Loading lora...")
lora, input_prompt, existing_lora = build_lora_xl(lora, input_prompt, "")
if existing_lora:
lora_uri = [ x for x in self.trainer.meta_uri if x.uri_id == lora][0]
lora_uri = [x for x in self.trainer.meta_uri if x.uri_id == lora][0]
if str(lora_uri) == "":
return "Lora not found"
return "Lora not found"
lora_path = get_file(
fname=str(lora_uri.uri_id) + ".safetensors",
origin=lora_uri.uri_url,
file_hash=lora_uri.uri_hash,
cache_dir=os.getenv("CACHE_DIR"),
tmp_dir=os.getenv("TMP_DIR"),
)
)
base.load_lora_weights(str(lora_path))
print("Loaded Lora: " + str(lora_path))
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=base.text_encoder_2,
vae=base.vae,
torch_dtype=self.torch_d_type,
use_safetensors=True,
variant=self.variant,
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=base.text_encoder_2,
vae=base.vae,
torch_dtype=self.torch_d_type,
use_safetensors=True,
variant=self.variant,
)
compel_base = Compel(
tokenizer=[base.tokenizer, base.tokenizer_2],
text_encoder=[base.text_encoder, base.text_encoder_2],
@ -164,15 +164,11 @@ class StableDiffusionXL(Processor):
conditioning_refiner, pooled_refiner = compel_refiner(input_prompt)
negative_conditioning_refiner, negative_pooled_refiner = compel_refiner(
negative_prompt)
n_steps = int(self.options['n_steps'])
high_noise_frac = float(self.options['high_noise_frac'])
#base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
# base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
img = base(
prompt_embeds=conditioning,
@ -211,10 +207,10 @@ class StableDiffusionXL(Processor):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if new_height != 0 or new_width != 0 and (new_width != mwidth or new_height != mheight) :
if new_height != 0 or new_width != 0 and (new_width != mwidth or new_height != mheight):
print("Resizing to width: " + str(new_width) + " height: " + str(new_height))
image = image.resize((new_width, new_height), Image.LANCZOS)
numpy_array = np.array(image)
return numpy_array
@ -223,7 +219,7 @@ class StableDiffusionXL(Processor):
print(e)
sys.stdout.flush()
return "Error"
def calculate_aspect(self, width: int, height: int):
def gcd(a, b):
"""The GCD (greatest common divisor) is the highest number that evenly divides both width and height."""
@ -235,8 +231,6 @@ class StableDiffusionXL(Processor):
return x, y
def to_output(self, data: dict):
self.current_session.output_data_templates['output_image'].data = data
return self.current_session.output_data_templates
return self.current_session.output_data_templates

View File

@ -3,11 +3,10 @@ import json
import math
import os
import signal
import time
from datetime import timedelta
from nostr_sdk import (Keys, Client, Timestamp, Filter, nip04_decrypt, HandleNotification, EventBuilder, PublicKey,
Options, Tag, Event, nip04_encrypt, NostrSigner, EventId, Nip19Event, nip44_decrypt, Kind)
Options, Tag, Event, nip04_encrypt, NostrSigner, EventId)
from nostr_dvm.utils.database_utils import fetch_user_metadata
from nostr_dvm.utils.definitions import EventDefinitions, relay_timeout
@ -19,7 +18,7 @@ from nostr_dvm.utils.nwc_tools import nwc_zap
from nostr_dvm.utils.subscription_utils import create_subscription_sql_table, add_to_subscription_sql_table, \
get_from_subscription_sql_table, update_subscription_sql_table, get_all_subscriptions_from_sql_table, \
delete_from_subscription_sql_table
from nostr_dvm.utils.zap_utils import create_bolt11_lud16, zaprequest
from nostr_dvm.utils.zap_utils import zaprequest
class Subscription:
@ -78,6 +77,7 @@ class Subscription:
await self.client.subscribe([zap_filter, dvm_filter, cancel_subscription_filter], None)
create_subscription_sql_table(dvm_config.DB)
class NotificationHandler(HandleNotification):
client = self.client
dvm_config = self.dvm_config
@ -383,8 +383,6 @@ class Subscription:
async def handle_subscription_renewal(subscription):
zaps = json.loads(subscription.zaps)
success = await pay_zap_split(subscription.nwc, subscription.amount, zaps, subscription.tier,
subscription.unit)
if success:
@ -414,21 +412,18 @@ class Subscription:
"Renewed Subscription to DVM " + subscription.tier + ". Next renewal: " + str(
Timestamp.from_secs(end).to_human_datetime().replace("Z", " ").replace("T",
" ")))
#await self.client.send_direct_msg(PublicKey.parse(subscription.subscriber), message, None)
# await self.client.send_direct_msg(PublicKey.parse(subscription.subscriber), message, None)
await self.client.send_private_msg(PublicKey.parse(subscription.subscriber), message, None)
async def check_subscriptions():
try:
subscriptions = get_all_subscriptions_from_sql_table(dvm_config.DB)
for subscription in subscriptions:
if subscription.nwc == "":
delete_from_subscription_sql_table(dvm_config.DB, subscription.id)
if subscription.active:
if subscription.end < Timestamp.now().as_secs():
# We could directly zap, but let's make another check if our subscription expired

View File

@ -4,8 +4,7 @@ import os
from datetime import timedelta
from threading import Thread
from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey, NostrSigner, Kind, RelayOptions, \
RelayLimits, Event
from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey, NostrSigner, Kind, RelayLimits
from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface, process_venv
from nostr_dvm.utils.admin_utils import AdminConfig
@ -13,7 +12,7 @@ from nostr_dvm.utils.definitions import EventDefinitions, relay_timeout_long, re
from nostr_dvm.utils.dvmconfig import DVMConfig, build_default_config
from nostr_dvm.utils.nip88_utils import NIP88Config
from nostr_dvm.utils.nip89_utils import NIP89Config, check_and_set_d_tag
from nostr_dvm.utils.output_utils import post_process_list_to_users, post_process_list_to_events
from nostr_dvm.utils.output_utils import post_process_list_to_events
"""
This File contains a Module to find inactive follows for a user on nostr
@ -83,8 +82,8 @@ class Discoverlatestperfollower(DVMTaskInterface):
cli = Client.with_opts(signer, opts)
for relay in self.dvm_config.RELAY_LIST:
await cli.add_relay(relay)
#ropts = RelayOptions().ping(False)
#await cli.add_relay_with_opts("wss://nostr.band", ropts)
# ropts = RelayOptions().ping(False)
# await cli.add_relay_with_opts("wss://nostr.band", ropts)
await cli.connect()
@ -171,13 +170,13 @@ class Discoverlatestperfollower(DVMTaskInterface):
result = {v for (k, v) in ns.dic.items() if v is not None}
#print(result)
#result = sorted(result, key=lambda x: x.created_at().as_secs(), reverse=True)
# print(result)
# result = sorted(result, key=lambda x: x.created_at().as_secs(), reverse=True)
new_list = sorted(result, key=lambda evt: evt.created_at().as_secs(), reverse=True)
new_res = new_list[:int(options["max_results"])]
#result = {v.id().to_hex() for (k, v) in finallist_sorted if v is not None}
# result = {v.id().to_hex() for (k, v) in finallist_sorted if v is not None}
#[: int(options["max_results"])]
# [: int(options["max_results"])]
print("events found: " + str(len(new_res)))
for v in new_res:
e_tag = Tag.parse(["e", v.id().to_hex()])

View File

@ -1,6 +1,7 @@
import json
import os
from io import BytesIO
import requests
from PIL import Image
from nostr_sdk import Kind

View File

@ -47,7 +47,8 @@ async def admin_make_database_updates(adminconfig: AdminConfig = None, dvmconfig
if not isinstance(adminconfig, AdminConfig):
return
if ((adminconfig.WHITELISTUSER is True or adminconfig.UNWHITELISTUSER is True or adminconfig.BLACKLISTUSER is True or adminconfig.DELETEUSER is True)
if ((
adminconfig.WHITELISTUSER is True or adminconfig.UNWHITELISTUSER is True or adminconfig.BLACKLISTUSER is True or adminconfig.DELETEUSER is True)
and adminconfig.USERNPUBS == []):
return
@ -82,8 +83,6 @@ async def admin_make_database_updates(adminconfig: AdminConfig = None, dvmconfig
if adminconfig.DELETEUSER:
delete_from_sql_table(db, publickey)
if adminconfig.ClEANDB:
clean_db(db)
@ -96,7 +95,8 @@ async def admin_make_database_updates(adminconfig: AdminConfig = None, dvmconfig
nut_wallet = await nutzap_wallet.get_nut_wallet(client, keys)
lud16 = adminconfig.LUD16
npub = keys.public_key().to_hex()
await nutzap_wallet.melt_cashu(nut_wallet, DVMConfig.NUZAP_MINTS[0], nut_wallet.balance, client, keys, lud16, npub)
await nutzap_wallet.melt_cashu(nut_wallet, DVMConfig.NUZAP_MINTS[0], nut_wallet.balance, client, keys, lud16,
npub)
await nutzap_wallet.get_nut_wallet(client, keys)
if adminconfig.REBROADCAST_NIP89: