diff --git a/nostr_dvm/backends/nova_server/modules/stablediffusionxl/stablediffusionxl.py b/nostr_dvm/backends/nova_server/modules/stablediffusionxl/stablediffusionxl.py index 3f446eb..4e3824c 100644 --- a/nostr_dvm/backends/nova_server/modules/stablediffusionxl/stablediffusionxl.py +++ b/nostr_dvm/backends/nova_server/modules/stablediffusionxl/stablediffusionxl.py @@ -1,31 +1,34 @@ """StableDiffusionXL Module """ import gc -import sys import os +import sys sys.path.insert(0, os.path.dirname(__file__)) -from ssl import Options from nova_utils.interfaces.server_module import Processor from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline, logging from compel import Compel, ReturnedEmbeddingsType from nova_utils.utils.cache_utils import get_file import numpy as np + PYTORCH_ENABLE_MPS_FALLBACK = 1 import torch from PIL import Image from lora import build_lora_xl + logging.disable_progress_bar() logging.enable_explicit_format() -#logging.set_verbosity_info() +# logging.set_verbosity_info() # Setting defaults -_default_options = {"model": "stabilityai/stable-diffusion-xl-base-1.0", "ratio": "1-1", "width": "", "height":"", "high_noise_frac" : "0.8", "n_steps" : "35", "lora" : "" } +_default_options = {"model": "stabilityai/stable-diffusion-xl-base-1.0", "ratio": "1-1", "width": "", "height": "", + "high_noise_frac": "0.8", "n_steps": "35", "lora": ""} -# TODO: add log infos, + +# TODO: add log infos, class StableDiffusionXL(Processor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -33,7 +36,6 @@ class StableDiffusionXL(Processor): self.device = None self.ds_iter = None self.current_session = None - # IO shortcuts self.input = [x for x in self.model_io if x.io_type == "input"] @@ -47,7 +49,7 @@ class StableDiffusionXL(Processor): self.torch_d_type = torch.float16 self.ds_iter = ds_iter current_session_name = self.ds_iter.session_names[0] - self.current_session = self.ds_iter.sessions[current_session_name]['manager'] + self.current_session = self.ds_iter.sessions[current_session_name]['manager'] input_prompt = self.current_session.input_data['input_prompt'].data input_prompt = ' '.join(input_prompt) negative_prompt = self.current_session.input_data['negative_prompt'].data @@ -60,13 +62,13 @@ class StableDiffusionXL(Processor): try: if self.options['width'] != "" and self.options['height'] != "": new_width = int(self.options['width']) - new_height = int(self.options['height']) - ratiow, ratioh = self.calculate_aspect(new_width, new_height) + new_height = int(self.options['height']) + ratiow, ratioh = self.calculate_aspect(new_width, new_height) print("Ratio:" + str(ratiow) + ":" + str(ratioh)) else: ratiow = str(self.options['ratio']).split('-')[0] - ratioh =str(self.options['ratio']).split('-')[1] + ratioh = str(self.options['ratio']).split('-')[1] model = self.options["model"] lora = self.options["lora"] @@ -77,30 +79,28 @@ class StableDiffusionXL(Processor): width = mwidth ratiown = int(ratiow) - ratiohn= int(ratioh) + ratiohn = int(ratioh) if ratiown > ratiohn: - height = int((ratiohn/ratiown) * float(width)) + height = int((ratiohn / ratiown) * float(width)) elif ratiown < ratiohn: - width = int((ratiown/ratiohn) * float(height)) + width = int((ratiown / ratiohn) * float(height)) elif ratiown == ratiohn: width = height - print("Processing Output width: " + str(width) + " Output height: " + str(height)) - - - if model == "stabilityai/stable-diffusion-xl-base-1.0": - base = StableDiffusionXLPipeline.from_pretrained(model, torch_dtype=self.torch_d_type, variant=self.variant, use_safetensors=True).to(self.device) + base = StableDiffusionXLPipeline.from_pretrained(model, torch_dtype=self.torch_d_type, + variant=self.variant, use_safetensors=True).to( + self.device) print("Loaded model: " + model) else: - - model_uri = [ x for x in self.trainer.meta_uri if x.uri_id == model][0] + + model_uri = [x for x in self.trainer.meta_uri if x.uri_id == model][0] if str(model_uri) == "": - return "Model not found" + return "Model not found" model_path = get_file( fname=str(model_uri.uri_id) + ".safetensors", @@ -108,43 +108,43 @@ class StableDiffusionXL(Processor): file_hash=model_uri.uri_hash, cache_dir=os.getenv("CACHE_DIR"), tmp_dir=os.getenv("TMP_DIR"), - ) - + ) + print(str(model_path)) - - base = StableDiffusionXLPipeline.from_single_file(str(model_path), torch_dtype=self.torch_d_type, variant=self.variant, use_safetensors=True).to(self.device) + base = StableDiffusionXLPipeline.from_single_file(str(model_path), torch_dtype=self.torch_d_type, + variant=self.variant, use_safetensors=True).to( + self.device) print("Loaded model: " + model) if lora != "" and lora != "None": - print("Loading lora...") - lora, input_prompt, existing_lora = build_lora_xl(lora, input_prompt, "") + print("Loading lora...") + lora, input_prompt, existing_lora = build_lora_xl(lora, input_prompt, "") if existing_lora: - lora_uri = [ x for x in self.trainer.meta_uri if x.uri_id == lora][0] + lora_uri = [x for x in self.trainer.meta_uri if x.uri_id == lora][0] if str(lora_uri) == "": - return "Lora not found" + return "Lora not found" lora_path = get_file( fname=str(lora_uri.uri_id) + ".safetensors", origin=lora_uri.uri_url, file_hash=lora_uri.uri_hash, cache_dir=os.getenv("CACHE_DIR"), tmp_dir=os.getenv("TMP_DIR"), - ) - + ) + base.load_lora_weights(str(lora_path)) print("Loaded Lora: " + str(lora_path)) refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-refiner-1.0", - text_encoder_2=base.text_encoder_2, - vae=base.vae, - torch_dtype=self.torch_d_type, - use_safetensors=True, - variant=self.variant, + "stabilityai/stable-diffusion-xl-refiner-1.0", + text_encoder_2=base.text_encoder_2, + vae=base.vae, + torch_dtype=self.torch_d_type, + use_safetensors=True, + variant=self.variant, ) - compel_base = Compel( tokenizer=[base.tokenizer, base.tokenizer_2], text_encoder=[base.text_encoder, base.text_encoder_2], @@ -164,15 +164,11 @@ class StableDiffusionXL(Processor): conditioning_refiner, pooled_refiner = compel_refiner(input_prompt) negative_conditioning_refiner, negative_pooled_refiner = compel_refiner( negative_prompt) - - + n_steps = int(self.options['n_steps']) high_noise_frac = float(self.options['high_noise_frac']) - - #base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True) - - + # base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True) img = base( prompt_embeds=conditioning, @@ -211,10 +207,10 @@ class StableDiffusionXL(Processor): torch.cuda.empty_cache() torch.cuda.ipc_collect() - if new_height != 0 or new_width != 0 and (new_width != mwidth or new_height != mheight) : + if new_height != 0 or new_width != 0 and (new_width != mwidth or new_height != mheight): print("Resizing to width: " + str(new_width) + " height: " + str(new_height)) image = image.resize((new_width, new_height), Image.LANCZOS) - + numpy_array = np.array(image) return numpy_array @@ -223,7 +219,7 @@ class StableDiffusionXL(Processor): print(e) sys.stdout.flush() return "Error" - + def calculate_aspect(self, width: int, height: int): def gcd(a, b): """The GCD (greatest common divisor) is the highest number that evenly divides both width and height.""" @@ -235,8 +231,6 @@ class StableDiffusionXL(Processor): return x, y - - def to_output(self, data: dict): self.current_session.output_data_templates['output_image'].data = data - return self.current_session.output_data_templates \ No newline at end of file + return self.current_session.output_data_templates diff --git a/nostr_dvm/subscription.py b/nostr_dvm/subscription.py index c840147..720cc5d 100644 --- a/nostr_dvm/subscription.py +++ b/nostr_dvm/subscription.py @@ -3,11 +3,10 @@ import json import math import os import signal -import time from datetime import timedelta from nostr_sdk import (Keys, Client, Timestamp, Filter, nip04_decrypt, HandleNotification, EventBuilder, PublicKey, - Options, Tag, Event, nip04_encrypt, NostrSigner, EventId, Nip19Event, nip44_decrypt, Kind) + Options, Tag, Event, nip04_encrypt, NostrSigner, EventId) from nostr_dvm.utils.database_utils import fetch_user_metadata from nostr_dvm.utils.definitions import EventDefinitions, relay_timeout @@ -19,7 +18,7 @@ from nostr_dvm.utils.nwc_tools import nwc_zap from nostr_dvm.utils.subscription_utils import create_subscription_sql_table, add_to_subscription_sql_table, \ get_from_subscription_sql_table, update_subscription_sql_table, get_all_subscriptions_from_sql_table, \ delete_from_subscription_sql_table -from nostr_dvm.utils.zap_utils import create_bolt11_lud16, zaprequest +from nostr_dvm.utils.zap_utils import zaprequest class Subscription: @@ -78,6 +77,7 @@ class Subscription: await self.client.subscribe([zap_filter, dvm_filter, cancel_subscription_filter], None) create_subscription_sql_table(dvm_config.DB) + class NotificationHandler(HandleNotification): client = self.client dvm_config = self.dvm_config @@ -383,8 +383,6 @@ class Subscription: async def handle_subscription_renewal(subscription): zaps = json.loads(subscription.zaps) - - success = await pay_zap_split(subscription.nwc, subscription.amount, zaps, subscription.tier, subscription.unit) if success: @@ -414,21 +412,18 @@ class Subscription: "Renewed Subscription to DVM " + subscription.tier + ". Next renewal: " + str( Timestamp.from_secs(end).to_human_datetime().replace("Z", " ").replace("T", " "))) - #await self.client.send_direct_msg(PublicKey.parse(subscription.subscriber), message, None) + # await self.client.send_direct_msg(PublicKey.parse(subscription.subscriber), message, None) await self.client.send_private_msg(PublicKey.parse(subscription.subscriber), message, None) - async def check_subscriptions(): try: subscriptions = get_all_subscriptions_from_sql_table(dvm_config.DB) for subscription in subscriptions: - if subscription.nwc == "": delete_from_subscription_sql_table(dvm_config.DB, subscription.id) - if subscription.active: if subscription.end < Timestamp.now().as_secs(): # We could directly zap, but let's make another check if our subscription expired diff --git a/nostr_dvm/tasks/content_discovery_latest_one_per_follower.py b/nostr_dvm/tasks/content_discovery_latest_one_per_follower.py index 2b54178..ea40a16 100644 --- a/nostr_dvm/tasks/content_discovery_latest_one_per_follower.py +++ b/nostr_dvm/tasks/content_discovery_latest_one_per_follower.py @@ -4,8 +4,7 @@ import os from datetime import timedelta from threading import Thread -from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey, NostrSigner, Kind, RelayOptions, \ - RelayLimits, Event +from nostr_sdk import Client, Timestamp, PublicKey, Tag, Keys, Options, SecretKey, NostrSigner, Kind, RelayLimits from nostr_dvm.interfaces.dvmtaskinterface import DVMTaskInterface, process_venv from nostr_dvm.utils.admin_utils import AdminConfig @@ -13,7 +12,7 @@ from nostr_dvm.utils.definitions import EventDefinitions, relay_timeout_long, re from nostr_dvm.utils.dvmconfig import DVMConfig, build_default_config from nostr_dvm.utils.nip88_utils import NIP88Config from nostr_dvm.utils.nip89_utils import NIP89Config, check_and_set_d_tag -from nostr_dvm.utils.output_utils import post_process_list_to_users, post_process_list_to_events +from nostr_dvm.utils.output_utils import post_process_list_to_events """ This File contains a Module to find inactive follows for a user on nostr @@ -83,8 +82,8 @@ class Discoverlatestperfollower(DVMTaskInterface): cli = Client.with_opts(signer, opts) for relay in self.dvm_config.RELAY_LIST: await cli.add_relay(relay) - #ropts = RelayOptions().ping(False) - #await cli.add_relay_with_opts("wss://nostr.band", ropts) + # ropts = RelayOptions().ping(False) + # await cli.add_relay_with_opts("wss://nostr.band", ropts) await cli.connect() @@ -171,13 +170,13 @@ class Discoverlatestperfollower(DVMTaskInterface): result = {v for (k, v) in ns.dic.items() if v is not None} - #print(result) - #result = sorted(result, key=lambda x: x.created_at().as_secs(), reverse=True) + # print(result) + # result = sorted(result, key=lambda x: x.created_at().as_secs(), reverse=True) new_list = sorted(result, key=lambda evt: evt.created_at().as_secs(), reverse=True) new_res = new_list[:int(options["max_results"])] - #result = {v.id().to_hex() for (k, v) in finallist_sorted if v is not None} + # result = {v.id().to_hex() for (k, v) in finallist_sorted if v is not None} - #[: int(options["max_results"])] + # [: int(options["max_results"])] print("events found: " + str(len(new_res))) for v in new_res: e_tag = Tag.parse(["e", v.id().to_hex()]) diff --git a/nostr_dvm/tasks/imagegeneration_replicate_sdxl.py b/nostr_dvm/tasks/imagegeneration_replicate_sdxl.py index 96724aa..b827ac4 100644 --- a/nostr_dvm/tasks/imagegeneration_replicate_sdxl.py +++ b/nostr_dvm/tasks/imagegeneration_replicate_sdxl.py @@ -1,6 +1,7 @@ import json import os from io import BytesIO + import requests from PIL import Image from nostr_sdk import Kind diff --git a/nostr_dvm/utils/admin_utils.py b/nostr_dvm/utils/admin_utils.py index bb69273..76eb5a0 100644 --- a/nostr_dvm/utils/admin_utils.py +++ b/nostr_dvm/utils/admin_utils.py @@ -47,7 +47,8 @@ async def admin_make_database_updates(adminconfig: AdminConfig = None, dvmconfig if not isinstance(adminconfig, AdminConfig): return - if ((adminconfig.WHITELISTUSER is True or adminconfig.UNWHITELISTUSER is True or adminconfig.BLACKLISTUSER is True or adminconfig.DELETEUSER is True) + if (( + adminconfig.WHITELISTUSER is True or adminconfig.UNWHITELISTUSER is True or adminconfig.BLACKLISTUSER is True or adminconfig.DELETEUSER is True) and adminconfig.USERNPUBS == []): return @@ -82,8 +83,6 @@ async def admin_make_database_updates(adminconfig: AdminConfig = None, dvmconfig if adminconfig.DELETEUSER: delete_from_sql_table(db, publickey) - - if adminconfig.ClEANDB: clean_db(db) @@ -96,7 +95,8 @@ async def admin_make_database_updates(adminconfig: AdminConfig = None, dvmconfig nut_wallet = await nutzap_wallet.get_nut_wallet(client, keys) lud16 = adminconfig.LUD16 npub = keys.public_key().to_hex() - await nutzap_wallet.melt_cashu(nut_wallet, DVMConfig.NUZAP_MINTS[0], nut_wallet.balance, client, keys, lud16, npub) + await nutzap_wallet.melt_cashu(nut_wallet, DVMConfig.NUZAP_MINTS[0], nut_wallet.balance, client, keys, lud16, + npub) await nutzap_wallet.get_nut_wallet(client, keys) if adminconfig.REBROADCAST_NIP89: