diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index 5244d9b94..58fc76aa8 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -3,6 +3,7 @@ import time from datetime import timedelta from typing import Any +import httpx import redis from celery import bootsteps # type: ignore from celery import Celery @@ -30,6 +31,7 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME from danswer.db.engine import SqlEngine +from danswer.httpx.httpx_pool import HttpxPool from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import ColoredFormatter from danswer.utils.logger import PlainFormatter @@ -113,12 +115,16 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None: @worker_init.connect def on_worker_init(sender: Any, **kwargs: Any) -> None: + EXTRA_CONCURRENCY = 8 # a few extra connections for side operations + # decide some initial startup settings based on the celery worker's hostname # (set at the command line) hostname = sender.hostname if hostname.startswith("light"): SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME) - SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) + SqlEngine.init_engine( + pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY + ) elif hostname.startswith("heavy"): SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME) SqlEngine.init_engine(pool_size=8, max_overflow=0) @@ -126,6 +132,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME) SqlEngine.init_engine(pool_size=8, max_overflow=0) + HttpxPool.init_client( + limits=httpx.Limits( + max_keepalive_connections=sender.concurrency + EXTRA_CONCURRENCY + ) + ) + r = get_redis_client() WAIT_INTERVAL = 5 @@ -212,6 +224,86 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: sender.primary_worker_lock = lock + WAIT_INTERVAL = 5 + WAIT_LIMIT = 60 + + time_start = time.monotonic() + logger.info("Redis: Readiness check starting.") + while True: + try: + if r.ping(): + break + except Exception: + pass + + time_elapsed = time.monotonic() - time_start + logger.info( + f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" + ) + if time_elapsed > WAIT_LIMIT: + msg = ( + f"Redis: Readiness check did not succeed within the timeout " + f"({WAIT_LIMIT} seconds). Exiting..." + ) + logger.error(msg) + raise WorkerShutdown(msg) + + time.sleep(WAIT_INTERVAL) + + logger.info("Redis: Readiness check succeeded. Continuing...") + + if not celery_is_worker_primary(sender): + logger.info("Running as a secondary celery worker.") + logger.info("Waiting for primary worker to be ready...") + time_start = time.monotonic() + while True: + if r.exists(DanswerRedisLocks.PRIMARY_WORKER): + break + + time.monotonic() + time_elapsed = time.monotonic() - time_start + logger.info( + f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" + ) + if time_elapsed > WAIT_LIMIT: + msg = ( + f"Primary worker was not ready within the timeout. " + f"({WAIT_LIMIT} seconds). Exiting..." + ) + logger.error(msg) + raise WorkerShutdown(msg) + + time.sleep(WAIT_INTERVAL) + + logger.info("Wait for primary worker completed successfully. Continuing...") + return + + logger.info("Running as the primary celery worker.") + + # This is singleton work that should be done on startup exactly once + # by the primary worker + r = get_redis_client() + + # For the moment, we're assuming that we are the only primary worker + # that should be running. + # TODO: maybe check for or clean up another zombie primary worker if we detect it + r.delete(DanswerRedisLocks.PRIMARY_WORKER) + + lock = r.lock( + DanswerRedisLocks.PRIMARY_WORKER, + timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, + ) + + logger.info("Primary worker lock: Acquire starting.") + acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2) + if acquired: + logger.info("Primary worker lock: Acquire succeeded.") + else: + logger.error("Primary worker lock: Acquire failed!") + raise WorkerShutdown("Primary worker lock could not be acquired!") + + sender.primary_worker_lock = lock + r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 2f840e430..4dc36f016 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -17,6 +17,7 @@ from danswer.db.document import get_documents_for_connector_credential_pair from danswer.db.engine import get_sqlalchemy_engine from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index +from danswer.httpx.httpx_pool import HttpxPool # use this within celery tasks to get celery task specific logging @@ -95,7 +96,9 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: curr_ind_name, sec_ind_name = get_both_index_names(db_session) document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + primary_index_name=curr_ind_name, + secondary_index_name=sec_ind_name, + httpx_client=HttpxPool.get(), ) if len(doc_ids_to_remove) == 0: diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 0ae214ca4..5b1c9327d 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -41,6 +41,7 @@ from danswer.db.models import UserGroup from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import UpdateRequest +from danswer.httpx.httpx_pool import HttpxPool from danswer.redis.redis_pool import get_redis_client from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import ( @@ -484,7 +485,9 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool: with Session(get_sqlalchemy_engine()) as db_session: curr_ind_name, sec_ind_name = get_both_index_names(db_session) document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + primary_index_name=curr_ind_name, + secondary_index_name=sec_ind_name, + httpx_client=HttpxPool.get(), ) doc = get_document(document_id, db_session) diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index f57b43bdf..b9cc2cfeb 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -154,6 +154,7 @@ def document_by_cc_pair_cleanup_task( # delete it from vespa and the db timing["db_read"] = time.monotonic() document_index.delete(doc_ids=[document_id]) + # document_index.delete_single(doc_id=document_id) timing["indexed"] = time.monotonic() delete_documents_complete__no_commit( db_session=db_session, @@ -202,7 +203,8 @@ def document_by_cc_pair_cleanup_task( mark_document_as_synced(document_id, db_session) else: - pass + timing["db_read"] = time.monotonic() + timing["indexed"] = time.monotonic() # update_docs_last_modified__no_commit( # db_session=db_session, diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 460e15bd1..e0f612ff1 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -239,7 +239,7 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int( # Attachments with more chars than this will not be indexed. This is to prevent extremely # large files from freezing indexing. 200,000 is ~100 google doc pages. CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int( - os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000) + os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000_000) ) JIRA_CONNECTOR_LABELS_TO_SKIP = [ diff --git a/backend/danswer/document_index/factory.py b/backend/danswer/document_index/factory.py index aedaec147..1d65d5e99 100644 --- a/backend/danswer/document_index/factory.py +++ b/backend/danswer/document_index/factory.py @@ -1,3 +1,4 @@ +import httpx from sqlalchemy.orm import Session from danswer.db.search_settings import get_current_search_settings @@ -8,13 +9,16 @@ from danswer.document_index.vespa.index import VespaIndex def get_default_document_index( primary_index_name: str, secondary_index_name: str | None, + httpx_client: httpx.Client | None = None, ) -> DocumentIndex: """Primary index is the index that is used for querying/updating etc. Secondary index is for when both the currently used index and the upcoming index both need to be updated, updates are applied to both indices""" # Currently only supporting Vespa return VespaIndex( - index_name=primary_index_name, secondary_index_name=secondary_index_name + index_name=primary_index_name, + secondary_index_name=secondary_index_name, + httpx_client=httpx_client, ) diff --git a/backend/danswer/document_index/interfaces.py b/backend/danswer/document_index/interfaces.py index eaa34b377..c3a521215 100644 --- a/backend/danswer/document_index/interfaces.py +++ b/backend/danswer/document_index/interfaces.py @@ -166,6 +166,16 @@ class Deletable(abc.ABC): """ raise NotImplementedError + @abc.abstractmethod + def delete_single(self, doc_id: str) -> None: + """ + Given a single document ids, hard delete it from the document index + + Parameters: + - doc_id: document id as specified by the connector + """ + raise NotImplementedError + class Updatable(abc.ABC): """ diff --git a/backend/danswer/document_index/vespa/chunk_retrieval.py b/backend/danswer/document_index/vespa/chunk_retrieval.py index e4b2ad83c..cec04234d 100644 --- a/backend/danswer/document_index/vespa/chunk_retrieval.py +++ b/backend/danswer/document_index/vespa/chunk_retrieval.py @@ -7,6 +7,7 @@ from datetime import timezone from typing import Any from typing import cast +import httpx import requests from retry import retry @@ -149,6 +150,7 @@ def _get_chunks_via_visit_api( chunk_request: VespaChunkRequest, index_name: str, filters: IndexFilters, + http_client: httpx.Client, field_names: list[str] | None = None, get_large_chunks: bool = False, ) -> list[dict]: @@ -181,21 +183,22 @@ def _get_chunks_via_visit_api( selection += f" and {index_name}.large_chunk_reference_ids == null" # Setting up the selection criteria in the query parameters - params = { - # NOTE: Document Selector Language doesn't allow `contains`, so we can't check - # for the ACL in the selection. Instead, we have to check as a postfilter - "selection": selection, - "continuation": None, - "wantedDocumentCount": 1_000, - "fieldSet": field_set, - } + params = httpx.QueryParams( + { + # NOTE: Document Selector Language doesn't allow `contains`, so we can't check + # for the ACL in the selection. Instead, we have to check as a postfilter + "selection": selection, + "wantedDocumentCount": 1_000, + "fieldSet": field_set, + } + ) document_chunks: list[dict] = [] while True: - response = requests.get(url, params=params) + response = http_client.get(url, params=params) try: response.raise_for_status() - except requests.HTTPError as e: + except httpx.HTTPStatusError as e: request_info = f"Headers: {response.request.headers}\nPayload: {params}" response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}" error_base = f"Error occurred getting chunk by Document ID {chunk_request.document_id}" @@ -205,7 +208,9 @@ def _get_chunks_via_visit_api( f"{response_info}\n" f"Exception: {e}" ) - raise requests.HTTPError(error_base) from e + raise httpx.HTTPStatusError( + error_base, request=e.request, response=e.response + ) from e # Check if the response contains any documents response_data = response.json() @@ -221,17 +226,21 @@ def _get_chunks_via_visit_api( document_chunks.append(document) # Check for continuation token to handle pagination - if "continuation" in response_data and response_data["continuation"]: - params["continuation"] = response_data["continuation"] - else: + if "continuation" not in response_data: break # Exit loop if no continuation token + if not response_data["continuation"]: + break # Exit loop if continuation token is empty + + params = params.set("continuation", response_data["continuation"]) + return document_chunks def get_all_vespa_ids_for_document_id( document_id: str, index_name: str, + http_client: httpx.Client, filters: IndexFilters | None = None, get_large_chunks: bool = False, ) -> list[str]: @@ -239,6 +248,7 @@ def get_all_vespa_ids_for_document_id( chunk_request=VespaChunkRequest(document_id=document_id), index_name=index_name, filters=filters or IndexFilters(access_control_list=None), + http_client=http_client, field_names=[DOCUMENT_ID], get_large_chunks=get_large_chunks, ) diff --git a/backend/danswer/document_index/vespa/deletion.py b/backend/danswer/document_index/vespa/deletion.py index 742c3ad00..61f66fb4c 100644 --- a/backend/danswer/document_index/vespa/deletion.py +++ b/backend/danswer/document_index/vespa/deletion.py @@ -26,6 +26,7 @@ def _delete_vespa_doc_chunks( doc_chunk_ids = get_all_vespa_ids_for_document_id( document_id=document_id, index_name=index_name, + http_client=http_client, get_large_chunks=True, ) @@ -47,7 +48,11 @@ def _delete_vespa_doc_chunks( t_delete = t["end"] - t["chunks_fetched"] t_all = t["end"] - t["start"] logger.info( - f"chunk_fetch={t_chunk_fetch:.2f} delete={t_delete:.2f} all={t_all:.2f}" + f"_delete_vespa_doc_chunks: " + f"len={len(doc_chunk_ids)} " + f"chunk_fetch={t_chunk_fetch:.2f} " + f"delete={t_delete:.2f} " + f"all={t_all:.2f}" ) diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 125e6bb85..ee255bd75 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -13,6 +13,7 @@ from typing import cast import httpx import requests +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.chat_configs import DOC_TIME_DECAY from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.chat_configs import TITLE_CONTENT_RATIO @@ -110,9 +111,15 @@ def add_ngrams_to_schema(schema_content: str) -> str: class VespaIndex(DocumentIndex): - def __init__(self, index_name: str, secondary_index_name: str | None) -> None: + def __init__( + self, + index_name: str, + secondary_index_name: str | None, + httpx_client: httpx.Client | None = None, + ) -> None: self.index_name = index_name self.secondary_index_name = secondary_index_name + self.httpx_client = httpx_client or httpx.Client(http2=True) def ensure_indices_exist( self, @@ -204,8 +211,12 @@ class VespaIndex(DocumentIndex): # indexing / updates / deletes since we have to make a large volume of requests. with ( concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor, - httpx.Client(http2=True) as http_client, + httpx.Client(http2=True) as http_temp_client, ): + httpx_client = self.httpx_client + if not httpx_client: + httpx_client = http_temp_client + # Check for existing documents, existing documents need to have all of their chunks deleted # prior to indexing as the document size (num chunks) may have shrunk first_chunks = [chunk for chunk in cleaned_chunks if chunk.chunk_id == 0] @@ -214,7 +225,7 @@ class VespaIndex(DocumentIndex): get_existing_documents_from_chunks( chunks=chunk_batch, index_name=self.index_name, - http_client=http_client, + http_client=httpx_client, executor=executor, ) ) @@ -223,7 +234,7 @@ class VespaIndex(DocumentIndex): delete_vespa_docs( document_ids=doc_id_batch, index_name=self.index_name, - http_client=http_client, + http_client=httpx_client, executor=executor, ) @@ -231,7 +242,7 @@ class VespaIndex(DocumentIndex): batch_index_vespa_chunks( chunks=chunk_batch, index_name=self.index_name, - http_client=http_client, + http_client=httpx_client, executor=executor, ) @@ -248,6 +259,7 @@ class VespaIndex(DocumentIndex): @staticmethod def _apply_updates_batched( updates: list[_VespaUpdateRequest], + http_client: httpx.Client, batch_size: int = BATCH_SIZE, ) -> None: """Runs a batch of updates in parallel via the ThreadPoolExecutor.""" @@ -266,10 +278,7 @@ class VespaIndex(DocumentIndex): # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficient for # indexing / updates / deletes since we have to make a large volume of requests. - with ( - concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor, - httpx.Client(http2=True) as http_client, - ): + with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: for update_batch in batch_generator(updates, batch_size): future_to_document_id = { executor.submit( @@ -309,12 +318,20 @@ class VespaIndex(DocumentIndex): index_names.append(self.secondary_index_name) chunk_id_start_time = time.monotonic() - with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + with ( + concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor, + httpx.Client(http2=True) as http_temp_client, + ): + httpx_client = self.httpx_client + if not httpx_client: + httpx_client = http_temp_client + future_to_doc_chunk_ids = { executor.submit( get_all_vespa_ids_for_document_id, document_id=document_id, index_name=index_name, + http_client=httpx_client, filters=None, get_large_chunks=True, ): (document_id, index_name) @@ -370,8 +387,15 @@ class VespaIndex(DocumentIndex): update_request=update_dict, ) ) + with httpx.Client(http2=True) as http_temp_client: + httpx_client = self.httpx_client + if not httpx_client: + httpx_client = http_temp_client + + self._apply_updates_batched( + processed_updates_requests, http_client=httpx_client + ) - self._apply_updates_batched(processed_updates_requests) logger.debug( "Finished updating Vespa documents in %.2f seconds", time.monotonic() - update_start, @@ -402,24 +426,26 @@ class VespaIndex(DocumentIndex): if self.secondary_index_name: index_names.append(self.secondary_index_name) - # chunk_id_start_time = time.monotonic() timing["chunk_fetch_start"] = time.monotonic() - all_doc_chunk_ids: list[str] = [] - for index_name in index_names: - for document_id in update_request.document_ids: - # this calls vespa and can raise http exceptions - doc_chunk_ids = get_all_vespa_ids_for_document_id( - document_id=document_id, - index_name=index_name, - filters=None, - get_large_chunks=True, - ) - all_doc_chunk_ids.extend(doc_chunk_ids) + with httpx.Client(http2=True) as http_temp_client: + httpx_client = self.httpx_client + if not httpx_client: + httpx_client = http_temp_client + + all_doc_chunk_ids: list[str] = [] + for index_name in index_names: + for document_id in update_request.document_ids: + # this calls vespa and can raise http exceptions + doc_chunk_ids = get_all_vespa_ids_for_document_id( + document_id=document_id, + index_name=index_name, + http_client=httpx_client, + filters=None, + get_large_chunks=True, + ) + all_doc_chunk_ids.extend(doc_chunk_ids) + timing["chunk_fetch_end"] = time.monotonic() - timing_chunk_fetch = timing["chunk_fetch_end"] - timing["chunk_fetch_start"] - logger.debug( - f"Took {timing_chunk_fetch:.2f} seconds to fetch all Vespa chunk IDs" - ) # Build the _VespaUpdateRequest objects update_dict: dict[str, dict] = {"fields": {}} @@ -453,9 +479,13 @@ class VespaIndex(DocumentIndex): ) ) - with httpx.Client(http2=True) as http_client: + with httpx.Client(http2=True) as http_temp_client: + httpx_client = self.httpx_client + if not httpx_client: + httpx_client = http_temp_client + for update in processed_update_requests: - http_client.put( + httpx_client.put( update.url, headers={"Content-Type": "application/json"}, json=update.update_request, @@ -488,19 +518,81 @@ class VespaIndex(DocumentIndex): # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for # indexing / updates / deletes since we have to make a large volume of requests. - with httpx.Client(http2=True) as http_client: - index_names = [self.index_name] - if self.secondary_index_name: - index_names.append(self.secondary_index_name) + index_names = [self.index_name] + if self.secondary_index_name: + index_names.append(self.secondary_index_name) + + with httpx.Client(http2=True) as http_temp_client: + httpx_client = self.httpx_client + if not httpx_client: + httpx_client = http_temp_client for index_name in index_names: delete_vespa_docs( - document_ids=doc_ids, index_name=index_name, http_client=http_client + document_ids=doc_ids, + index_name=index_name, + http_client=httpx_client, ) t_all = time.monotonic() - time_start logger.info(f"VespaIndex.delete: all={t_all:.2f}") + def delete_single(self, doc_id: str) -> None: + # Vespa deletion is poorly documented ... luckily we found this + # https://docs.vespa.ai/en/operations/batch-delete.html#example + + time_start = time.monotonic() + + doc_id = replace_invalid_doc_id_characters(doc_id) + + # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for + # indexing / updates / deletes since we have to make a large volume of requests. + index_names = [self.index_name] + if self.secondary_index_name: + index_names.append(self.secondary_index_name) + + # if self.httpx_client: + # for index_name in index_names: + # _delete_vespa_doc_chunks(document_id=doc_id, index_name=index_name, http_client=self.httpx_client) + # else: + # with httpx.Client(http2=True) as httpx_client: + # for index_name in index_names: + # _delete_vespa_doc_chunks(document_id=doc_id, index_name=index_name, http_client=httpx_client) + + for index_name in index_names: + params = httpx.QueryParams( + { + "selection": f"{index_name}.document_id=='{doc_id}'", + "cluster": DOCUMENT_INDEX_NAME, + } + ) + + while True: + try: + resp = self.httpx_client.delete( + f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}", + params=params, + ) + resp.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"Failed to delete chunk, details: {e.response.text}") + raise + + resp_data = resp.json() + if "documentCount" in resp_data: + count = resp_data["documentCount"] + logger.info(f"VespaIndex.delete_single: chunks_deleted={count}") + + # Check for continuation token to handle pagination + if "continuation" not in resp_data: + break # Exit loop if no continuation token + + if not resp_data["continuation"]: + break # Exit loop if continuation token is empty + + t_all = time.monotonic() - time_start + logger.info(f"VespaIndex.delete_single: all={t_all:.2f}") + def id_based_retrieval( self, chunk_requests: list[VespaChunkRequest], diff --git a/backend/danswer/httpx/httpx_pool.py b/backend/danswer/httpx/httpx_pool.py new file mode 100644 index 000000000..e5a13a34b --- /dev/null +++ b/backend/danswer/httpx/httpx_pool.py @@ -0,0 +1,42 @@ +import threading +from typing import Any + +import httpx + + +class HttpxPool: + """Class to manage a global httpx Client instance""" + + _client: httpx.Client | None = None + _lock: threading.Lock = threading.Lock() + + # Default parameters for creation + DEFAULT_KWARGS = { + "http2": True, + "limits": httpx.Limits(), + } + + def __init__(self) -> None: + pass + + @classmethod + def _init_client(cls, **kwargs: Any) -> httpx.Client: + """Private helper method to create and return an httpx.Client.""" + merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs} + return httpx.Client(**merged_kwargs) + + @classmethod + def init_client(cls, **kwargs: Any) -> None: + """Allow the caller to init the client with extra params.""" + with cls._lock: + if not cls._client: + cls._client = cls._init_client(**kwargs) + + @classmethod + def get(cls) -> httpx.Client: + """Gets the httpx.Client. Will init to default settings if not init'd.""" + if not cls._client: + with cls._lock: + if not cls._client: + cls._client = cls._init_client() + return cls._client diff --git a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py index d194b2ef9..805083efc 100644 --- a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py @@ -34,7 +34,7 @@ def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) - count = cast(int, r.scard(rug.taskset_key)) task_logger.info( - f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}" + f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}" ) if count > 0: return diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index e1e871372..dac260a18 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -20,12 +20,16 @@ import { UserRole, User } from "@/lib/types"; import { useUser } from "@/components/user/UserProvider"; function PersonaTypeDisplay({ persona }: { persona: Persona }) { - if (persona.is_default_persona) { + if (persona.builtin_persona) { return Built-In; } + if (persona.is_default_persona) { + return Default; + } + if (persona.is_public) { - return Global; + return Public; } if (persona.groups.length > 0 || persona.users.length > 0) { diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 8f9b36b62..3513b2dc3 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -759,7 +759,15 @@ export function ChatPage({ setAboveHorizon(scrollDist.current > 500); }; - scrollableDivRef?.current?.addEventListener("scroll", updateScrollTracking); + useEffect(() => { + const scrollableDiv = scrollableDivRef.current; + if (scrollableDiv) { + scrollableDiv.addEventListener("scroll", updateScrollTracking); + return () => { + scrollableDiv.removeEventListener("scroll", updateScrollTracking); + }; + } + }, []); const handleInputResize = () => { setTimeout(() => { @@ -1137,7 +1145,9 @@ export function ChatPage({ await delay(50); while (!stack.isComplete || !stack.isEmpty()) { - await delay(0.5); + if (stack.isEmpty()) { + await delay(0.5); + } if (!stack.isEmpty() && !controller.signal.aborted) { const packet = stack.nextPacket(); diff --git a/web/src/app/chat/message/CodeBlock.tsx b/web/src/app/chat/message/CodeBlock.tsx index 55a6ea7be..66cc82a6e 100644 --- a/web/src/app/chat/message/CodeBlock.tsx +++ b/web/src/app/chat/message/CodeBlock.tsx @@ -1,20 +1,22 @@ import React, { useState, ReactNode, useCallback, useMemo, memo } from "react"; import { FiCheck, FiCopy } from "react-icons/fi"; -const CODE_BLOCK_PADDING_TYPE = { padding: "1rem" }; +const CODE_BLOCK_PADDING = { padding: "1rem" }; interface CodeBlockProps { - className?: string | undefined; + className?: string; children?: ReactNode; - content: string; - [key: string]: any; + codeText: string; } +const MemoizedCodeLine = memo(({ content }: { content: ReactNode }) => ( + <>{content} +)); + export const CodeBlock = memo(function CodeBlock({ className = "", children, - content, - ...props + codeText, }: CodeBlockProps) { const [copied, setCopied] = useState(false); @@ -26,132 +28,99 @@ export const CodeBlock = memo(function CodeBlock({ .join(" "); }, [className]); - const codeText = useMemo(() => { - let codeText: string | null = null; - if ( - props.node?.position?.start?.offset && - props.node?.position?.end?.offset - ) { - codeText = content.slice( - props.node.position.start.offset, - props.node.position.end.offset - ); - codeText = codeText.trim(); + const handleCopy = useCallback(() => { + if (!codeText) return; + navigator.clipboard.writeText(codeText).then(() => { + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }); + }, [codeText]); - // Find the last occurrence of closing backticks - const lastBackticksIndex = codeText.lastIndexOf("```"); - if (lastBackticksIndex !== -1) { - codeText = codeText.slice(0, lastBackticksIndex + 3); - } + const CopyButton = memo(() => ( +
+ {copied ? ( +
+ + Copied! +
+ ) : ( +
+ + Copy code +
+ )} +
+ )); + CopyButton.displayName = "CopyButton"; - // Remove the language declaration and trailing backticks - const codeLines = codeText.split("\n"); - if ( - codeLines.length > 1 && - (codeLines[0].startsWith("```") || - codeLines[0].trim().startsWith("```")) - ) { - codeLines.shift(); // Remove the first line with the language declaration - if ( - codeLines[codeLines.length - 1] === "```" || - codeLines[codeLines.length - 1]?.trim() === "```" - ) { - codeLines.pop(); // Remove the last line with the trailing backticks - } - - const minIndent = codeLines - .filter((line) => line.trim().length > 0) - .reduce((min, line) => { - const match = line.match(/^\s*/); - return Math.min(min, match ? match[0].length : 0); - }, Infinity); - - const formattedCodeLines = codeLines.map((line) => - line.slice(minIndent) + const CodeContent = memo(() => { + if (!language) { + if (typeof children === "string") { + return ( + + {children} + ); - codeText = formattedCodeLines.join("\n"); } - } - - // handle unknown languages. They won't have a `node.position.start.offset` - if (!codeText) { - const findTextNode = (node: any): string | null => { - if (node.type === "text") { - return node.value; - } - let finalResult = ""; - if (node.children) { - for (const child of node.children) { - const result = findTextNode(child); - if (result) { - finalResult += result; - } - } - } - return finalResult; - }; - - codeText = findTextNode(props.node); - } - - return codeText; - }, [content, props.node]); - - const handleCopy = useCallback( - (event: React.MouseEvent) => { - event.preventDefault(); - if (!codeText) { - return; - } - - navigator.clipboard.writeText(codeText).then(() => { - setCopied(true); - setTimeout(() => setCopied(false), 2000); - }); - }, - [codeText] - ); - - if (!language) { - if (typeof children === "string") { - return {children}; + return ( +
+          
+            {Array.isArray(children)
+              ? children.map((child, index) => (
+                  
+                ))
+              : children}
+          
+        
+ ); } return ( -
-        
-          {children}
+      
+        
+          {Array.isArray(children)
+            ? children.map((child, index) => (
+                
+              ))
+            : children}
         
       
); - } + }); + CodeContent.displayName = "CodeContent"; return (
-
- {language} - {codeText && ( -
- {copied ? ( -
- - Copied! -
- ) : ( -
- - Copy code -
- )} -
- )} -
-
-        {children}
-      
+ {language && ( +
+ {language} + {codeText && } +
+ )} +
); }); + +CodeBlock.displayName = "CodeBlock"; +MemoizedCodeLine.displayName = "MemoizedCodeLine"; diff --git a/web/src/app/chat/message/MemoizedTextComponents.tsx b/web/src/app/chat/message/MemoizedTextComponents.tsx index 4ab8bc810..9ab0e28e3 100644 --- a/web/src/app/chat/message/MemoizedTextComponents.tsx +++ b/web/src/app/chat/message/MemoizedTextComponents.tsx @@ -25,9 +25,9 @@ export const MemoizedLink = memo((props: any) => { } }); -export const MemoizedParagraph = memo(({ node, ...props }: any) => ( -

-)); +export const MemoizedParagraph = memo(({ ...props }: any) => { + return

; +}); MemoizedLink.displayName = "MemoizedLink"; MemoizedParagraph.displayName = "MemoizedParagraph"; diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index edb18138c..e10f5cea0 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -54,6 +54,7 @@ import RegenerateOption from "../RegenerateOption"; import { LlmOverride } from "@/lib/hooks"; import { ContinueGenerating } from "./ContinueMessage"; import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents"; +import { extractCodeText } from "./codeUtils"; const TOOLS_WITH_CUSTOM_HANDLING = [ SEARCH_TOOL_NAME, @@ -253,6 +254,40 @@ export const AIMessage = ({ new Set((docs || []).map((doc) => doc.source_type)) ).slice(0, 3); + const markdownComponents = useMemo( + () => ({ + a: MemoizedLink, + p: MemoizedParagraph, + code: ({ node, inline, className, children, ...props }: any) => { + const codeText = extractCodeText( + node, + finalContent as string, + children + ); + + return ( + + {children} + + ); + }, + }), + [messageId, content] + ); + + const renderedMarkdown = useMemo(() => { + return ( + + {finalContent as string} + + ); + }, [finalContent]); + const includeMessageSwitcher = currentMessageInd !== undefined && onMessageSelection && @@ -352,27 +387,7 @@ export const AIMessage = ({ {typeof content === "string" ? (

- ( - - ), - }} - remarkPlugins={[remarkGfm]} - rehypePlugins={[ - [rehypePrism, { ignoreMissing: true }], - ]} - > - {finalContent as string} - + {renderedMarkdown}
) : ( content diff --git a/web/src/app/chat/message/codeUtils.ts b/web/src/app/chat/message/codeUtils.ts new file mode 100644 index 000000000..2aaae71bc --- /dev/null +++ b/web/src/app/chat/message/codeUtils.ts @@ -0,0 +1,47 @@ +export function extractCodeText( + node: any, + content: string, + children: React.ReactNode +): string { + let codeText: string | null = null; + if ( + node?.position?.start?.offset != null && + node?.position?.end?.offset != null + ) { + codeText = content.slice( + node.position.start.offset, + node.position.end.offset + ); + codeText = codeText.trim(); + + // Find the last occurrence of closing backticks + const lastBackticksIndex = codeText.lastIndexOf("```"); + if (lastBackticksIndex !== -1) { + codeText = codeText.slice(0, lastBackticksIndex + 3); + } + + // Remove the language declaration and trailing backticks + const codeLines = codeText.split("\n"); + if (codeLines.length > 1 && codeLines[0].trim().startsWith("```")) { + codeLines.shift(); // Remove the first line with the language declaration + if (codeLines[codeLines.length - 1]?.trim() === "```") { + codeLines.pop(); // Remove the last line with the trailing backticks + } + + const minIndent = codeLines + .filter((line) => line.trim().length > 0) + .reduce((min, line) => { + const match = line.match(/^\s*/); + return Math.min(min, match ? match[0].length : 0); + }, Infinity); + + const formattedCodeLines = codeLines.map((line) => line.slice(minIndent)); + codeText = formattedCodeLines.join("\n"); + } + } else { + // Fallback if position offsets are not available + codeText = children?.toString() || null; + } + + return codeText || ""; +} diff --git a/web/src/components/chat_search/MinimalMarkdown.tsx b/web/src/components/chat_search/MinimalMarkdown.tsx index 3516749d1..4731e2de9 100644 --- a/web/src/components/chat_search/MinimalMarkdown.tsx +++ b/web/src/components/chat_search/MinimalMarkdown.tsx @@ -1,4 +1,5 @@ import { CodeBlock } from "@/app/chat/message/CodeBlock"; +import { extractCodeText } from "@/app/chat/message/codeUtils"; import { MemoizedLink, MemoizedParagraph, @@ -10,13 +11,11 @@ import remarkGfm from "remark-gfm"; interface MinimalMarkdownProps { content: string; className?: string; - useCodeBlock?: boolean; } export const MinimalMarkdown: React.FC = ({ content, className = "", - useCodeBlock = false, }) => { return ( = ({ components={{ a: MemoizedLink, p: MemoizedParagraph, - code: useCodeBlock - ? (props) => ( - - ) - : (props) => , + code: ({ node, inline, className, children, ...props }: any) => { + const codeText = extractCodeText(node, content, children); + + return ( + + {children} + + ); + }, }} remarkPlugins={[remarkGfm]} > diff --git a/web/src/components/search/results/AnswerSection.tsx b/web/src/components/search/results/AnswerSection.tsx index 225623b00..324e41e0a 100644 --- a/web/src/components/search/results/AnswerSection.tsx +++ b/web/src/components/search/results/AnswerSection.tsx @@ -1,7 +1,5 @@ import { Quote } from "@/lib/search/interfaces"; import { ResponseSection, StatusOptions } from "./ResponseSection"; -import ReactMarkdown from "react-markdown"; -import remarkGfm from "remark-gfm"; import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown"; const TEMP_STRING = "__$%^TEMP$%^__"; @@ -40,12 +38,7 @@ export const AnswerSection = (props: AnswerSectionProps) => { status = "success"; header = <>; - body = ( - - ); + body = ; // error while building answer (NOTE: if error occurs during quote generation // the above if statement will hit and the error will not be displayed) @@ -61,9 +54,7 @@ export const AnswerSection = (props: AnswerSectionProps) => { } else if (props.answer) { status = "success"; header = <>; - body = ( - - ); + body = ; } return (