From a2433aa70c68df793fd1017aae9e3c758c0f2dc3 Mon Sep 17 00:00:00 2001 From: Believethehype Date: Tue, 28 Nov 2023 10:08:43 +0100 Subject: [PATCH] simplify pubkey initalization --- bot.py | 5 +-- dvm.py | 17 ++------- interfaces/dvmtaskinterface.py | 8 +++- main.py | 3 +- playground.py | 5 --- tasks/imagegeneration_sdxl.py | 14 +++++-- utils/backend_utils.py | 68 ++++++++++++++++++++-------------- utils/nostr_utils.py | 31 +++++++++------- utils/zap_utils.py | 19 ++++------ 9 files changed, 89 insertions(+), 81 deletions(-) diff --git a/bot.py b/bot.py index ce20195..be5161d 100644 --- a/bot.py +++ b/bot.py @@ -201,7 +201,7 @@ class Bot: content = nostr_event.content() if is_encrypted: - if ptag == self.dvm_config.PUBLIC_KEY: + if ptag == self.keys.public_key().to_hex(): tags_str = nip04_decrypt(Keys.from_sk_str(dvm_config.PRIVATE_KEY).secret_key(), nostr_event.pubkey(), nostr_event.content()) params = json.loads(tags_str) @@ -311,7 +311,7 @@ class Bot: self.job_list.remove(entry) content = nostr_event.content() if is_encrypted: - if ptag == self.dvm_config.PUBLIC_KEY: + if ptag == self.keys.public_key().to_hex(): content = nip04_decrypt(self.keys.secret_key(), nostr_event.pubkey(), content) else: return @@ -335,7 +335,6 @@ class Bot: self.client, self.dvm_config) user = get_or_add_user(self.dvm_config.DB, sender, client=self.client, config=self.dvm_config) - print("ZAPED EVENT: " + zapped_event.as_json()) if zapped_event is not None: if not anon: print("[" + self.NAME + "] Note Zap received for Bot balance: " + str( diff --git a/dvm.py b/dvm.py index 90c88e3..c99a338 100644 --- a/dvm.py +++ b/dvm.py @@ -3,7 +3,7 @@ import typing from datetime import timedelta from nostr_sdk import PublicKey, Keys, Client, Tag, Event, EventBuilder, Filter, HandleNotification, Timestamp, \ - init_logger, LogLevel, nip04_decrypt, Options, nip04_encrypt + init_logger, LogLevel, Options, nip04_encrypt import time @@ -39,10 +39,8 @@ class DVM: .skip_disconnected_relays(skip_disconnected_relays)) self.client = Client.with_opts(self.keys, opts) - self.job_list = [] self.jobs_on_hold_list = [] - pk = self.keys.public_key() print("Nostr DVM public key: " + str(pk.to_bech32()) + " Hex: " + str(pk.to_hex()) + " Supported DVM tasks: " + @@ -53,9 +51,6 @@ class DVM: self.client.connect() zap_filter = Filter().pubkey(pk).kinds([EventDefinitions.KIND_ZAP]).since(Timestamp.now()) - # bot_dm_filter = Filter().pubkey(pk).kinds([EventDefinitions.KIND_DM]).authors(self.dvm_config.DM_ALLOWED).since( - # Timestamp.now()) - kinds = [EventDefinitions.KIND_NIP90_GENERIC] for dvm in self.dvm_config.SUPPORTED_DVMS: if dvm.KIND not in kinds: @@ -76,8 +71,6 @@ class DVM: handle_nip90_job_event(nostr_event) elif nostr_event.kind() == EventDefinitions.KIND_ZAP: handle_zap(nostr_event) - # elif nostr_event.kind() == EventDefinitions.KIND_DM: - # handle_dm(nostr_event) def handle_msg(self, relay_url, msg): return @@ -118,7 +111,6 @@ class DVM: task_is_free = True cashu_redeemed = False - cashu_message = "" if cashu != "": cashu_redeemed, cashu_message = redeem_cashu(cashu, amount, self.dvm_config, self.client) if cashu_message != "": @@ -136,7 +128,6 @@ class DVM: do_work(nip90_event) # if task is directed to us via p tag and user has balance, do the job and update balance - elif p_tag_str == Keys.from_sk_str( self.dvm_config.PUBLIC_KEY) and user.balance >= amount: balance = max(user.balance - amount, 0) @@ -147,7 +138,6 @@ class DVM: print( "[" + self.dvm_config.NIP89.name + "] Using user's balance for task: " + task + - ". Starting processing.. New balance is: " + str(balance)) send_job_status_reaction(nip90_event, "processing", True, 0, @@ -364,7 +354,7 @@ class DVM: content=None, dvm_config=None): - task = get_task(original_event, client=client, dvmconfig=dvm_config) + task = get_task(original_event, client=client, dvm_config=dvm_config) alt_description, reaction = build_status_reaction(status, task, amount, content) e_tag = Tag.parse(["e", original_event.id().to_hex()]) @@ -449,7 +439,7 @@ class DVM: if ((EventDefinitions.KIND_NIP90_EXTRACT_TEXT <= job_event.kind() <= EventDefinitions.KIND_NIP90_GENERIC) or job_event.kind() == EventDefinitions.KIND_DM): - task = get_task(job_event, client=self.client, dvmconfig=self.dvm_config) + task = get_task(job_event, client=self.client, dvm_config=self.dvm_config) for dvm in self.dvm_config.SUPPORTED_DVMS: try: @@ -459,7 +449,6 @@ class DVM: result = dvm.process(request_form) check_and_return_event(result, str(job_event.as_json())) - except Exception as e: print(e) send_job_status_reaction(job_event, "error", content=str(e), dvm_config=self.dvm_config) diff --git a/interfaces/dvmtaskinterface.py b/interfaces/dvmtaskinterface.py index af874c9..195c78c 100644 --- a/interfaces/dvmtaskinterface.py +++ b/interfaces/dvmtaskinterface.py @@ -1,6 +1,8 @@ import json from threading import Thread +from nostr_sdk import Keys + from utils.admin_utils import AdminConfig from utils.dvmconfig import DVMConfig from utils.nip89_utils import NIP89Announcement, NIP89Config @@ -26,6 +28,8 @@ class DVMTaskInterface: def init(self, name, dvm_config, admin_config=None, nip89config=None): self.NAME = name self.PRIVATE_KEY = dvm_config.PRIVATE_KEY + if dvm_config.PUBLIC_KEY == "" or dvm_config.PUBLIC_KEY is None: + dvm_config.PUBLIC_KEY = Keys.from_sk_str(dvm_config.PRIVATE_KEY).public_key().to_hex() self.PUBLIC_KEY = dvm_config.PUBLIC_KEY if dvm_config.COST is not None: self.COST = dvm_config.COST @@ -50,10 +54,12 @@ class DVMTaskInterface: nip89.content = nip89config.CONTENT return nip89 - def is_input_supported(self, input_type, input_content) -> bool: + def is_input_supported(self, tags) -> bool: """Check if input is supported for current Task.""" pass + + def create_request_form_from_nostr_event(self, event, client=None, dvm_config=None) -> dict: """Parse input into a request form that will be given to the process method""" pass diff --git a/main.py b/main.py index d7ed993..dcdf143 100644 --- a/main.py +++ b/main.py @@ -18,7 +18,6 @@ def run_nostr_dvm_with_local_config(): # Note this is very basic for now and still under development bot_config = DVMConfig() bot_config.PRIVATE_KEY = os.getenv("BOT_PRIVATE_KEY") - bot_config.PUBLIC_KEY = Keys.from_sk_str(bot_config.PRIVATE_KEY).public_key().to_hex() bot_config.LNBITS_INVOICE_KEY = os.getenv("LNBITS_INVOICE_KEY") bot_config.LNBITS_ADMIN_KEY = os.getenv("LNBITS_ADMIN_KEY") # The bot will forward zaps for us, use responsibly bot_config.LNBITS_URL = os.getenv("LNBITS_HOST") @@ -60,7 +59,7 @@ def run_nostr_dvm_with_local_config(): bot = Bot(bot_config) bot.run() - # Keep the main function alive for libraries like openai + # Keep the main function alive for libraries that require it, like openai try: while True: time.sleep(10) diff --git a/playground.py b/playground.py index 3ddcfd0..856386e 100644 --- a/playground.py +++ b/playground.py @@ -43,7 +43,6 @@ admin_config.REBROADCAST_NIP89 = False def build_pdf_extractor(name): dvm_config = DVMConfig() dvm_config.PRIVATE_KEY = os.getenv("NOSTR_PRIVATE_KEY") - dvm_config.PUBLIC_KEY = Keys.from_sk_str(dvm_config.PRIVATE_KEY).public_key().to_hex() dvm_config.LNBITS_INVOICE_KEY = os.getenv("LNBITS_INVOICE_KEY") dvm_config.LNBITS_URL = os.getenv("LNBITS_HOST") # Add NIP89 @@ -65,7 +64,6 @@ def build_pdf_extractor(name): def build_translator(name): dvm_config = DVMConfig() dvm_config.PRIVATE_KEY = os.getenv("NOSTR_PRIVATE_KEY") - dvm_config.PUBLIC_KEY = Keys.from_sk_str(dvm_config.PRIVATE_KEY).public_key().to_hex() dvm_config.LNBITS_INVOICE_KEY = os.getenv("LNBITS_INVOICE_KEY") dvm_config.LNBITS_URL = os.getenv("LNBITS_HOST") @@ -98,7 +96,6 @@ def build_translator(name): def build_unstable_diffusion(name): dvm_config = DVMConfig() dvm_config.PRIVATE_KEY = os.getenv("NOSTR_PRIVATE_KEY") - dvm_config.PUBLIC_KEY = Keys.from_sk_str(dvm_config.PRIVATE_KEY).public_key().to_hex() dvm_config.LNBITS_INVOICE_KEY = "" #This one will not use Lnbits to create invoices, but rely on zaps dvm_config.LNBITS_URL = "" @@ -132,7 +129,6 @@ def build_unstable_diffusion(name): def build_sketcher(name): dvm_config = DVMConfig() dvm_config.PRIVATE_KEY = os.getenv("NOSTR_PRIVATE_KEY2") - dvm_config.PUBLIC_KEY = Keys.from_sk_str(dvm_config.PRIVATE_KEY).public_key().to_hex() dvm_config.LNBITS_INVOICE_KEY = os.getenv("LNBITS_INVOICE_KEY") dvm_config.LNBITS_URL = os.getenv("LNBITS_HOST") @@ -168,7 +164,6 @@ def build_sketcher(name): def build_dalle(name): dvm_config = DVMConfig() dvm_config.PRIVATE_KEY = os.getenv("NOSTR_PRIVATE_KEY3") - dvm_config.PUBLIC_KEY = Keys.from_sk_str(dvm_config.PRIVATE_KEY).public_key().to_hex() dvm_config.LNBITS_INVOICE_KEY = os.getenv("LNBITS_INVOICE_KEY") dvm_config.LNBITS_URL = os.getenv("LNBITS_HOST") profit_in_sats = 10 diff --git a/tasks/imagegeneration_sdxl.py b/tasks/imagegeneration_sdxl.py index e32950c..6e0ae14 100644 --- a/tasks/imagegeneration_sdxl.py +++ b/tasks/imagegeneration_sdxl.py @@ -2,12 +2,11 @@ import json from multiprocessing.pool import ThreadPool from backends.nova_server import check_nova_server_status, send_request_to_nova_server -from dvm import DVM from interfaces.dvmtaskinterface import DVMTaskInterface from utils.admin_utils import AdminConfig -from utils.definitions import EventDefinitions from utils.dvmconfig import DVMConfig from utils.nip89_utils import NIP89Config +from utils.definitions import EventDefinitions """ This File contains a Module to transform Text input on NOVA-Server and receive results back. @@ -28,7 +27,16 @@ class ImageGenerationSDXL(DVMTaskInterface): admin_config: AdminConfig = None, options=None): super().__init__(name, dvm_config, nip89config, admin_config, options) - def is_input_supported(self, input_type, input_content): + def is_input_supported(self, tags): + for tag in tags: + if tag.as_vec()[0] == 'i': + if len(tag.as_vec()) < 3: + print("Job Event missing/malformed i tag, skipping..") + return False + else: + input_value = tag.as_vec()[1] + input_type = tag.as_vec()[2] + if input_type != "text": return False return True diff --git a/utils/backend_utils.py b/utils/backend_utils.py index cf3473e..aa0642d 100644 --- a/utils/backend_utils.py +++ b/utils/backend_utils.py @@ -7,8 +7,8 @@ from utils.definitions import EventDefinitions from utils.nostr_utils import get_event_by_id -def get_task(event, client, dvmconfig): - if event.kind() == EventDefinitions.KIND_NIP90_GENERIC: # use this for events that have no id yet +def get_task(event, client, dvm_config): + if event.kind() == EventDefinitions.KIND_NIP90_GENERIC: # use this for events that have no id yet, inclufr j tag for tag in event.tags(): if tag.as_vec()[0] == 'j': return tag.as_vec()[1] @@ -32,7 +32,7 @@ def get_task(event, client, dvmconfig): else: return "unknown job" elif tag.as_vec()[2] == "event": - evt = get_event_by_id(tag.as_vec()[1], client=client, config=dvmconfig) + evt = get_event_by_id(tag.as_vec()[1], client=client, config=dvm_config) if evt is not None: if evt.kind() == 1063: for tg in evt.tags(): @@ -44,40 +44,53 @@ def get_task(event, client, dvmconfig): return "unknown job" else: return "unknown type" - - elif event.kind() == EventDefinitions.KIND_NIP90_TRANSLATE_TEXT: - return "translation" - elif event.kind() == EventDefinitions.KIND_NIP90_GENERATE_IMAGE: - return "text-to-image" - + # TODO if a task can consist of multiple inputs add them here + # else if kind is supported, simply return task else: + for dvm in dvm_config.SUPPORTED_DVMS: + if dvm.KIND == event.kind(): + return dvm.TASK return "unknown type" +def is_input_supported__generic(tags, client, dvm_config) -> bool: + for tag in tags: + if tag.as_vec()[0] == 'i': + if len(tag.as_vec()) < 3: + print("Job Event missing/malformed i tag, skipping..") + return False + else: + input_value = tag.as_vec()[1] + input_type = tag.as_vec()[2] + + if input_type == "event": + evt = get_event_by_id(input_value, client=client, config=dvm_config) + if evt is None: + print("Event not found") + + return True + + + def check_task_is_supported(event: Event, client, get_duration=False, config=None): try: dvm_config = config input_value = "" input_type = "" duration = 1 - task = get_task(event, client=client, dvmconfig=dvm_config) + task = get_task(event, client=client, dvm_config=dvm_config) + + if not is_input_supported__generic(event.tags(), client, dvm_config): + return False, "", 0 + for tag in event.tags(): if tag.as_vec()[0] == 'i': - if len(tag.as_vec()) < 3: - print("Job Event missing/malformed i tag, skipping..") - return False, "", 0 - else: - input_value = tag.as_vec()[1] - input_type = tag.as_vec()[2] - if input_type == "event": - evt = get_event_by_id(input_value, client=client, config=dvm_config) - if evt is None: - print("Event not found") - return False, "", 0 - elif input_type == 'url' and check_url_is_readable(input_value) is None: - print("Url not readable / supported") - return False, task, duration # + input_value = tag.as_vec()[1] + input_type = tag.as_vec()[2] + if input_type == 'url' and check_url_is_readable(input_value) is None: + print("Url not readable / supported") + return False, task, duration # elif tag.as_vec()[0] == 'output': # TODO move this to individual modules @@ -90,15 +103,14 @@ def check_task_is_supported(event: Event, client, get_duration=False, config=Non print("Output format not supported, skipping..") return False, "", 0 + if task not in (x.TASK for x in dvm_config.SUPPORTED_DVMS): + return False, task, duration for dvm in dvm_config.SUPPORTED_DVMS: if dvm.TASK == task: - if not dvm.is_input_supported(input_type, event.content()): + if not dvm.is_input_supported(event.tags()): return False, task, duration - if task not in (x.TASK for x in dvm_config.SUPPORTED_DVMS): - return False, task, duration - return True, task, duration diff --git a/utils/nostr_utils.py b/utils/nostr_utils.py index 758a66f..23a45d9 100644 --- a/utils/nostr_utils.py +++ b/utils/nostr_utils.py @@ -1,7 +1,6 @@ import json -import typing from datetime import timedelta -from nostr_sdk import Filter, Client, Alphabet, EventId, Event, PublicKey, Tag, Keys, nip04_decrypt, EventBuilder +from nostr_sdk import Filter, Client, Alphabet, EventId, Event, PublicKey, Tag, Keys, nip04_decrypt def get_event_by_id(event_id: str, client: Client, config=None) -> Event | None: @@ -19,6 +18,7 @@ def get_event_by_id(event_id: str, client: Client, config=None) -> Event | None: id_filter = Filter().id(event_id).limit(1) events = client.get_events_of([id_filter], timedelta(seconds=config.RELAY_TIMEOUT)) if len(events) > 0: + return events[0] else: return None @@ -42,23 +42,26 @@ def get_referenced_event_by_id(event_id, client, dvm_config, kinds) -> Event | N def send_event(event: Event, client: Client, dvm_config) -> EventId: - relays = [] + try: + relays = [] - for tag in event.tags(): - if tag.as_vec()[0] == 'relays': - relays = tag.as_vec()[1].split(',') + for tag in event.tags(): + if tag.as_vec()[0] == 'relays': + relays = tag.as_vec()[1].split(',') - for relay in relays: - if relay not in dvm_config.RELAY_LIST: - client.add_relay(relay) + for relay in relays: + if relay not in dvm_config.RELAY_LIST: + client.add_relay(relay) - event_id = client.send_event(event) + event_id = client.send_event(event) - for relay in relays: - if relay not in dvm_config.RELAY_LIST: - client.remove_relay(relay) + for relay in relays: + if relay not in dvm_config.RELAY_LIST: + client.remove_relay(relay) - return event_id + return event_id + except Exception as e: + print(e) def check_and_decrypt_tags(event, dvm_config): diff --git a/utils/zap_utils.py b/utils/zap_utils.py index cd4f2d6..f8b6067 100644 --- a/utils/zap_utils.py +++ b/utils/zap_utils.py @@ -1,4 +1,4 @@ -# LIGHTNING FUNCTIONS +# LIGHTNING/CASHU/ZAP FUNCTIONS import base64 import json import os @@ -11,7 +11,7 @@ from nostr_sdk import nostr_sdk, PublicKey, SecretKey, Event, EventBuilder, Tag, from utils.database_utils import get_or_add_user from utils.dvmconfig import DVMConfig -from utils.nostr_utils import get_event_by_id, check_and_decrypt_tags, check_and_decrypt_own_tags +from utils.nostr_utils import get_event_by_id, check_and_decrypt_own_tags import lnurl from hashlib import sha256 @@ -240,15 +240,12 @@ def zap(lud16: str, amount: int, content, zapped_event: Event, keys, dvm_config, def parse_cashu(cashu_token): try: - try: - prefix = "cashuA" - assert cashu_token.startswith(prefix), Exception( - f"Token prefix not valid. Expected {prefix}." - ) - token_base64 = cashu_token[len(prefix):] - cashu = json.loads(base64.urlsafe_b64decode(token_base64)) - except Exception as e: - print(e) + prefix = "cashuA" + assert cashu_token.startswith(prefix), Exception( + f"Token prefix not valid. Expected {prefix}." + ) + token_base64 = cashu_token[len(prefix):] + cashu = json.loads(base64.urlsafe_b64decode(token_base64)) token = cashu["token"][0] proofs = token["proofs"]