has delete_single working in an optimized fashion

Richard Kuo (Danswer) 2024-10-02 16:32:36 -07:00
commit b22848be52
20 changed files with 523 additions and 221 deletions

View File

@ -3,6 +3,7 @@ import time
from datetime import timedelta
from typing import Any
import httpx
import redis
from celery import bootsteps # type: ignore
from celery import Celery
@ -30,6 +31,7 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.httpx.httpx_pool import HttpxPool
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
@ -113,12 +115,16 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # a few extra connections for side operations
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY
)
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
@ -126,6 +132,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
HttpxPool.init_client(
limits=httpx.Limits(
max_keepalive_connections=sender.concurrency + EXTRA_CONCURRENCY
)
)
r = get_redis_client()
WAIT_INTERVAL = 5
@ -212,6 +224,86 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
sender.primary_worker_lock = lock
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...")
time_start = time.monotonic()
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time.monotonic()
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_lock = lock
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

View File

@ -17,6 +17,7 @@ from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.httpx.httpx_pool import HttpxPool
# use this within celery tasks to get celery task specific logging
@ -95,7 +96,9 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
primary_index_name=curr_ind_name,
secondary_index_name=sec_ind_name,
httpx_client=HttpxPool.get(),
)
if len(doc_ids_to_remove) == 0:

View File

@ -41,6 +41,7 @@ from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.httpx.httpx_pool import HttpxPool
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@ -484,7 +485,9 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
primary_index_name=curr_ind_name,
secondary_index_name=sec_ind_name,
httpx_client=HttpxPool.get(),
)
doc = get_document(document_id, db_session)

View File

@ -154,6 +154,7 @@ def document_by_cc_pair_cleanup_task(
# delete it from vespa and the db
timing["db_read"] = time.monotonic()
document_index.delete(doc_ids=[document_id])
# document_index.delete_single(doc_id=document_id)
timing["indexed"] = time.monotonic()
delete_documents_complete__no_commit(
db_session=db_session,
@ -202,7 +203,8 @@ def document_by_cc_pair_cleanup_task(
mark_document_as_synced(document_id, db_session)
else:
pass
timing["db_read"] = time.monotonic()
timing["indexed"] = time.monotonic()
# update_docs_last_modified__no_commit(
# db_session=db_session,

View File

@ -239,7 +239,7 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000_000)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [

View File

@ -1,3 +1,4 @@
import httpx
from sqlalchemy.orm import Session
from danswer.db.search_settings import get_current_search_settings
@ -8,13 +9,16 @@ from danswer.document_index.vespa.index import VespaIndex
def get_default_document_index(
primary_index_name: str,
secondary_index_name: str | None,
httpx_client: httpx.Client | None = None,
) -> DocumentIndex:
"""Primary index is the index that is used for querying/updating etc.
Secondary index is for when both the currently used index and the upcoming
index both need to be updated, updates are applied to both indices"""
# Currently only supporting Vespa
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
httpx_client=httpx_client,
)

View File

@ -166,6 +166,16 @@ class Deletable(abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
def delete_single(self, doc_id: str) -> None:
"""
Given a single document id, hard delete it from the document index
Parameters:
- doc_id: document id as specified by the connector
"""
raise NotImplementedError
class Updatable(abc.ABC):
"""

View File

@ -7,6 +7,7 @@ from datetime import timezone
from typing import Any
from typing import cast
import httpx
import requests
from retry import retry
@ -149,6 +150,7 @@ def _get_chunks_via_visit_api(
chunk_request: VespaChunkRequest,
index_name: str,
filters: IndexFilters,
http_client: httpx.Client,
field_names: list[str] | None = None,
get_large_chunks: bool = False,
) -> list[dict]:
@ -181,21 +183,22 @@ def _get_chunks_via_visit_api(
selection += f" and {index_name}.large_chunk_reference_ids == null"
# Setting up the selection criteria in the query parameters
params = {
# NOTE: Document Selector Language doesn't allow `contains`, so we can't check
# for the ACL in the selection. Instead, we have to check as a postfilter
"selection": selection,
"continuation": None,
"wantedDocumentCount": 1_000,
"fieldSet": field_set,
}
params = httpx.QueryParams(
{
# NOTE: Document Selector Language doesn't allow `contains`, so we can't check
# for the ACL in the selection. Instead, we have to check as a postfilter
"selection": selection,
"wantedDocumentCount": 1_000,
"fieldSet": field_set,
}
)
document_chunks: list[dict] = []
while True:
response = requests.get(url, params=params)
response = http_client.get(url, params=params)
try:
response.raise_for_status()
except requests.HTTPError as e:
except httpx.HTTPStatusError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
error_base = f"Error occurred getting chunk by Document ID {chunk_request.document_id}"
@ -205,7 +208,9 @@ def _get_chunks_via_visit_api(
f"{response_info}\n"
f"Exception: {e}"
)
raise requests.HTTPError(error_base) from e
raise httpx.HTTPStatusError(
error_base, request=e.request, response=e.response
) from e
# Check if the response contains any documents
response_data = response.json()
@ -221,17 +226,21 @@ def _get_chunks_via_visit_api(
document_chunks.append(document)
# Check for continuation token to handle pagination
if "continuation" in response_data and response_data["continuation"]:
params["continuation"] = response_data["continuation"]
else:
if "continuation" not in response_data:
break # Exit loop if no continuation token
if not response_data["continuation"]:
break # Exit loop if continuation token is empty
params = params.set("continuation", response_data["continuation"])
return document_chunks
def get_all_vespa_ids_for_document_id(
document_id: str,
index_name: str,
http_client: httpx.Client,
filters: IndexFilters | None = None,
get_large_chunks: bool = False,
) -> list[str]:
@ -239,6 +248,7 @@ def get_all_vespa_ids_for_document_id(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=filters or IndexFilters(access_control_list=None),
http_client=http_client,
field_names=[DOCUMENT_ID],
get_large_chunks=get_large_chunks,
)
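
Note on the params change above: the old code built a plain dict and mutated it with the continuation token, while the new code uses httpx.QueryParams, which is immutable, so the pagination loop rebinds params via .set() instead of assigning into it. A minimal sketch of that behavior (the selection string and continuation token below are made-up values):

    import httpx

    # QueryParams is immutable: set() returns a new instance instead of mutating,
    # which is why the visit-API loop rebinds `params` on each continuation.
    params = httpx.QueryParams(
        {"selection": "danswer_chunk.document_id=='example-doc'", "wantedDocumentCount": 1_000}
    )
    updated = params.set("continuation", "AAAAEAAA")

    assert "continuation" not in params        # original object is unchanged
    assert updated["continuation"] == "AAAAEAAA"  # new instance carries the token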

View File

@ -26,6 +26,7 @@ def _delete_vespa_doc_chunks(
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
http_client=http_client,
get_large_chunks=True,
)
@ -47,7 +48,11 @@ def _delete_vespa_doc_chunks(
t_delete = t["end"] - t["chunks_fetched"]
t_all = t["end"] - t["start"]
logger.info(
f"chunk_fetch={t_chunk_fetch:.2f} delete={t_delete:.2f} all={t_all:.2f}"
f"_delete_vespa_doc_chunks: "
f"len={len(doc_chunk_ids)} "
f"chunk_fetch={t_chunk_fetch:.2f} "
f"delete={t_delete:.2f} "
f"all={t_all:.2f}"
)

View File

@ -13,6 +13,7 @@ from typing import cast
import httpx
import requests
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@ -110,9 +111,15 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex):
def __init__(self, index_name: str, secondary_index_name: str | None) -> None:
def __init__(
self,
index_name: str,
secondary_index_name: str | None,
httpx_client: httpx.Client | None = None,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
self.httpx_client = httpx_client or httpx.Client(http2=True)
def ensure_indices_exist(
self,
@ -204,8 +211,12 @@ class VespaIndex(DocumentIndex):
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx.Client(http2=True) as http_client,
httpx.Client(http2=True) as http_temp_client,
):
httpx_client = self.httpx_client
if not httpx_client:
httpx_client = http_temp_client
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [chunk for chunk in cleaned_chunks if chunk.chunk_id == 0]
@ -214,7 +225,7 @@ class VespaIndex(DocumentIndex):
get_existing_documents_from_chunks(
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
http_client=httpx_client,
executor=executor,
)
)
@ -223,7 +234,7 @@ class VespaIndex(DocumentIndex):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
http_client=httpx_client,
executor=executor,
)
@ -231,7 +242,7 @@ class VespaIndex(DocumentIndex):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
http_client=httpx_client,
executor=executor,
)
@ -248,6 +259,7 @@ class VespaIndex(DocumentIndex):
@staticmethod
def _apply_updates_batched(
updates: list[_VespaUpdateRequest],
http_client: httpx.Client,
batch_size: int = BATCH_SIZE,
) -> None:
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
@ -266,10 +278,7 @@ class VespaIndex(DocumentIndex):
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx.Client(http2=True) as http_client,
):
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
executor.submit(
@ -309,12 +318,20 @@ class VespaIndex(DocumentIndex):
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx.Client(http2=True) as http_temp_client,
):
httpx_client = self.httpx_client
if not httpx_client:
httpx_client = http_temp_client
future_to_doc_chunk_ids = {
executor.submit(
get_all_vespa_ids_for_document_id,
document_id=document_id,
index_name=index_name,
http_client=httpx_client,
filters=None,
get_large_chunks=True,
): (document_id, index_name)
@ -370,8 +387,15 @@ class VespaIndex(DocumentIndex):
update_request=update_dict,
)
)
with httpx.Client(http2=True) as http_temp_client:
httpx_client = self.httpx_client
if not httpx_client:
httpx_client = http_temp_client
self._apply_updates_batched(
processed_updates_requests, http_client=httpx_client
)
self._apply_updates_batched(processed_updates_requests)
logger.debug(
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
@ -402,24 +426,26 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
# chunk_id_start_time = time.monotonic()
timing["chunk_fetch_start"] = time.monotonic()
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
with httpx.Client(http2=True) as http_temp_client:
httpx_client = self.httpx_client
if not httpx_client:
httpx_client = http_temp_client
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
http_client=httpx_client,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
timing["chunk_fetch_end"] = time.monotonic()
timing_chunk_fetch = timing["chunk_fetch_end"] - timing["chunk_fetch_start"]
logger.debug(
f"Took {timing_chunk_fetch:.2f} seconds to fetch all Vespa chunk IDs"
)
# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
@ -453,9 +479,13 @@ class VespaIndex(DocumentIndex):
)
)
with httpx.Client(http2=True) as http_client:
with httpx.Client(http2=True) as http_temp_client:
httpx_client = self.httpx_client
if not httpx_client:
httpx_client = http_temp_client
for update in processed_update_requests:
http_client.put(
httpx_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
@ -488,19 +518,81 @@ class VespaIndex(DocumentIndex):
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with httpx.Client(http2=True) as http_client:
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_temp_client:
httpx_client = self.httpx_client
if not httpx_client:
httpx_client = http_temp_client
for index_name in index_names:
delete_vespa_docs(
document_ids=doc_ids, index_name=index_name, http_client=http_client
document_ids=doc_ids,
index_name=index_name,
http_client=httpx_client,
)
t_all = time.monotonic() - time_start
logger.info(f"VespaIndex.delete: all={t_all:.2f}")
def delete_single(self, doc_id: str) -> None:
# Vespa deletion is poorly documented ... luckily we found this
# https://docs.vespa.ai/en/operations/batch-delete.html#example
time_start = time.monotonic()
doc_id = replace_invalid_doc_id_characters(doc_id)
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
# if self.httpx_client:
# for index_name in index_names:
# _delete_vespa_doc_chunks(document_id=doc_id, index_name=index_name, http_client=self.httpx_client)
# else:
# with httpx.Client(http2=True) as httpx_client:
# for index_name in index_names:
# _delete_vespa_doc_chunks(document_id=doc_id, index_name=index_name, http_client=httpx_client)
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)
while True:
try:
resp = self.httpx_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
params=params,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"Failed to delete chunk, details: {e.response.text}")
raise
resp_data = resp.json()
if "documentCount" in resp_data:
count = resp_data["documentCount"]
logger.info(f"VespaIndex.delete_single: chunks_deleted={count}")
# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token
if not resp_data["continuation"]:
break # Exit loop if continuation token is empty
t_all = time.monotonic() - time_start
logger.info(f"VespaIndex.delete_single: all={t_all:.2f}")
def id_based_retrieval(
self,
chunk_requests: list[VespaChunkRequest],
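
For context on how the new per-document path is meant to be used: a minimal sketch of a caller switching from the batched delete to delete_single, reusing the worker-wide client from HttpxPool. The index names and document id below are placeholders, and the surrounding task/session setup is assumed rather than taken from this commit.

    from danswer.document_index.factory import get_default_document_index
    from danswer.httpx.httpx_pool import HttpxPool

    # Placeholder values for illustration; real tasks derive these from the
    # current search settings and the document row being cleaned up.
    curr_ind_name = "danswer_chunk"
    sec_ind_name = None
    document_id = "connector-doc-id"

    document_index = get_default_document_index(
        primary_index_name=curr_ind_name,
        secondary_index_name=sec_ind_name,
        httpx_client=HttpxPool.get(),  # shared HTTP/2 client initialized at worker startup
    )

    # previous path: enumerate every chunk id via the visit API, then delete chunk by chunk
    # document_index.delete(doc_ids=[document_id])

    # new path: one delete-by-selection request per index, with Vespa's
    # continuation token handled inside delete_single
    document_index.delete_single(doc_id=document_id)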

View File

@ -0,0 +1,42 @@
import threading
from typing import Any
import httpx
class HttpxPool:
"""Class to manage a global httpx Client instance"""
_client: httpx.Client | None = None
_lock: threading.Lock = threading.Lock()
# Default parameters for creation
DEFAULT_KWARGS = {
"http2": True,
"limits": httpx.Limits(),
}
def __init__(self) -> None:
pass
@classmethod
def _init_client(cls, **kwargs: Any) -> httpx.Client:
"""Private helper method to create and return an httpx.Client."""
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
return httpx.Client(**merged_kwargs)
@classmethod
def init_client(cls, **kwargs: Any) -> None:
"""Allow the caller to init the client with extra params."""
with cls._lock:
if not cls._client:
cls._client = cls._init_client(**kwargs)
@classmethod
def get(cls) -> httpx.Client:
"""Gets the httpx.Client. Will init to default settings if not init'd."""
if not cls._client:
with cls._lock:
if not cls._client:
cls._client = cls._init_client()
return cls._client
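
A quick usage sketch for the pool above, mirroring how this commit wires it up in on_worker_init and in the celery tasks; the concurrency numbers are illustrative, not taken from any particular deployment.

    import httpx
    from danswer.httpx.httpx_pool import HttpxPool

    # At worker startup: size the keep-alive pool to the worker's concurrency plus
    # a little headroom for side operations (mirrors on_worker_init above).
    HttpxPool.init_client(limits=httpx.Limits(max_keepalive_connections=16 + 8))

    # Anywhere in task code: every call returns the same HTTP/2 client, so Vespa
    # connections are reused across tasks instead of re-established per task.
    client = HttpxPool.get()
    assert client is HttpxPool.get()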

View File

@ -34,7 +34,7 @@ def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -
count = cast(int, r.scard(rug.taskset_key))
task_logger.info(
f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
return

View File

@ -20,12 +20,16 @@ import { UserRole, User } from "@/lib/types";
import { useUser } from "@/components/user/UserProvider";
function PersonaTypeDisplay({ persona }: { persona: Persona }) {
if (persona.is_default_persona) {
if (persona.builtin_persona) {
return <Text>Built-In</Text>;
}
if (persona.is_default_persona) {
return <Text>Default</Text>;
}
if (persona.is_public) {
return <Text>Global</Text>;
return <Text>Public</Text>;
}
if (persona.groups.length > 0 || persona.users.length > 0) {

View File

@ -759,7 +759,15 @@ export function ChatPage({
setAboveHorizon(scrollDist.current > 500);
};
scrollableDivRef?.current?.addEventListener("scroll", updateScrollTracking);
useEffect(() => {
const scrollableDiv = scrollableDivRef.current;
if (scrollableDiv) {
scrollableDiv.addEventListener("scroll", updateScrollTracking);
return () => {
scrollableDiv.removeEventListener("scroll", updateScrollTracking);
};
}
}, []);
const handleInputResize = () => {
setTimeout(() => {
@ -1137,7 +1145,9 @@ export function ChatPage({
await delay(50);
while (!stack.isComplete || !stack.isEmpty()) {
await delay(0.5);
if (stack.isEmpty()) {
await delay(0.5);
}
if (!stack.isEmpty() && !controller.signal.aborted) {
const packet = stack.nextPacket();

View File

@ -1,20 +1,22 @@
import React, { useState, ReactNode, useCallback, useMemo, memo } from "react";
import { FiCheck, FiCopy } from "react-icons/fi";
const CODE_BLOCK_PADDING_TYPE = { padding: "1rem" };
const CODE_BLOCK_PADDING = { padding: "1rem" };
interface CodeBlockProps {
className?: string | undefined;
className?: string;
children?: ReactNode;
content: string;
[key: string]: any;
codeText: string;
}
const MemoizedCodeLine = memo(({ content }: { content: ReactNode }) => (
<>{content}</>
));
export const CodeBlock = memo(function CodeBlock({
className = "",
children,
content,
...props
codeText,
}: CodeBlockProps) {
const [copied, setCopied] = useState(false);
@ -26,132 +28,99 @@ export const CodeBlock = memo(function CodeBlock({
.join(" ");
}, [className]);
const codeText = useMemo(() => {
let codeText: string | null = null;
if (
props.node?.position?.start?.offset &&
props.node?.position?.end?.offset
) {
codeText = content.slice(
props.node.position.start.offset,
props.node.position.end.offset
);
codeText = codeText.trim();
const handleCopy = useCallback(() => {
if (!codeText) return;
navigator.clipboard.writeText(codeText).then(() => {
setCopied(true);
setTimeout(() => setCopied(false), 2000);
});
}, [codeText]);
// Find the last occurrence of closing backticks
const lastBackticksIndex = codeText.lastIndexOf("```");
if (lastBackticksIndex !== -1) {
codeText = codeText.slice(0, lastBackticksIndex + 3);
}
const CopyButton = memo(() => (
<div
className="ml-auto cursor-pointer select-none"
onMouseDown={handleCopy}
>
{copied ? (
<div className="flex items-center space-x-2">
<FiCheck size={16} />
<span>Copied!</span>
</div>
) : (
<div className="flex items-center space-x-2">
<FiCopy size={16} />
<span>Copy code</span>
</div>
)}
</div>
));
CopyButton.displayName = "CopyButton";
// Remove the language declaration and trailing backticks
const codeLines = codeText.split("\n");
if (
codeLines.length > 1 &&
(codeLines[0].startsWith("```") ||
codeLines[0].trim().startsWith("```"))
) {
codeLines.shift(); // Remove the first line with the language declaration
if (
codeLines[codeLines.length - 1] === "```" ||
codeLines[codeLines.length - 1]?.trim() === "```"
) {
codeLines.pop(); // Remove the last line with the trailing backticks
}
const minIndent = codeLines
.filter((line) => line.trim().length > 0)
.reduce((min, line) => {
const match = line.match(/^\s*/);
return Math.min(min, match ? match[0].length : 0);
}, Infinity);
const formattedCodeLines = codeLines.map((line) =>
line.slice(minIndent)
const CodeContent = memo(() => {
if (!language) {
if (typeof children === "string") {
return (
<code
className={`
font-mono
text-gray-800
bg-gray-50
border
border-gray-300
rounded
px-1
py-[3px]
text-xs
whitespace-pre-wrap
break-words
overflow-hidden
mb-1
${className}
`}
>
{children}
</code>
);
codeText = formattedCodeLines.join("\n");
}
}
// handle unknown languages. They won't have a `node.position.start.offset`
if (!codeText) {
const findTextNode = (node: any): string | null => {
if (node.type === "text") {
return node.value;
}
let finalResult = "";
if (node.children) {
for (const child of node.children) {
const result = findTextNode(child);
if (result) {
finalResult += result;
}
}
}
return finalResult;
};
codeText = findTextNode(props.node);
}
return codeText;
}, [content, props.node]);
const handleCopy = useCallback(
(event: React.MouseEvent) => {
event.preventDefault();
if (!codeText) {
return;
}
navigator.clipboard.writeText(codeText).then(() => {
setCopied(true);
setTimeout(() => setCopied(false), 2000);
});
},
[codeText]
);
if (!language) {
if (typeof children === "string") {
return <code className={className}>{children}</code>;
return (
<pre style={CODE_BLOCK_PADDING}>
<code className={`text-sm ${className}`}>
{Array.isArray(children)
? children.map((child, index) => (
<MemoizedCodeLine key={index} content={child} />
))
: children}
</code>
</pre>
);
}
return (
<pre style={CODE_BLOCK_PADDING_TYPE}>
<code {...props} className={`text-sm ${className}`}>
{children}
<pre className="overflow-x-scroll" style={CODE_BLOCK_PADDING}>
<code className="text-xs overflow-x-auto">
{Array.isArray(children)
? children.map((child, index) => (
<MemoizedCodeLine key={index} content={child} />
))
: children}
</code>
</pre>
);
}
});
CodeContent.displayName = "CodeContent";
return (
<div className="overflow-x-hidden">
<div className="flex mx-3 py-2 text-xs">
{language}
{codeText && (
<div
className="ml-auto cursor-pointer select-none"
onMouseDown={handleCopy}
>
{copied ? (
<div className="flex items-center space-x-2">
<FiCheck size={16} />
<span>Copied!</span>
</div>
) : (
<div className="flex items-center space-x-2">
<FiCopy size={16} />
<span>Copy code</span>
</div>
)}
</div>
)}
</div>
<pre {...props} className="overflow-x-scroll" style={{ padding: "1rem" }}>
<code className={`text-xs overflow-x-auto `}>{children}</code>
</pre>
{language && (
<div className="flex mx-3 py-2 text-xs">
{language}
{codeText && <CopyButton />}
</div>
)}
<CodeContent />
</div>
);
});
CodeBlock.displayName = "CodeBlock";
MemoizedCodeLine.displayName = "MemoizedCodeLine";

View File

@ -25,9 +25,9 @@ export const MemoizedLink = memo((props: any) => {
}
});
export const MemoizedParagraph = memo(({ node, ...props }: any) => (
<p {...props} className="text-default" />
));
export const MemoizedParagraph = memo(({ ...props }: any) => {
return <p {...props} className="text-default" />;
});
MemoizedLink.displayName = "MemoizedLink";
MemoizedParagraph.displayName = "MemoizedParagraph";

View File

@ -54,6 +54,7 @@ import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText } from "./codeUtils";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@ -253,6 +254,40 @@ export const AIMessage = ({
new Set((docs || []).map((doc) => doc.source_type))
).slice(0, 3);
const markdownComponents = useMemo(
() => ({
a: MemoizedLink,
p: MemoizedParagraph,
code: ({ node, inline, className, children, ...props }: any) => {
const codeText = extractCodeText(
node,
finalContent as string,
children
);
return (
<CodeBlock className={className} codeText={codeText}>
{children}
</CodeBlock>
);
},
}),
[messageId, content]
);
const renderedMarkdown = useMemo(() => {
return (
<ReactMarkdown
className="prose max-w-full text-base"
components={markdownComponents}
remarkPlugins={[remarkGfm]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }]]}
>
{finalContent as string}
</ReactMarkdown>
);
}, [finalContent]);
const includeMessageSwitcher =
currentMessageInd !== undefined &&
onMessageSelection &&
@ -352,27 +387,7 @@ export const AIMessage = ({
{typeof content === "string" ? (
<div className="overflow-x-visible max-w-content-max">
<ReactMarkdown
key={messageId}
className="prose max-w-full text-base"
components={{
a: MemoizedLink,
p: MemoizedParagraph,
code: (props) => (
<CodeBlock
className="w-full"
{...props}
content={content as string}
/>
),
}}
remarkPlugins={[remarkGfm]}
rehypePlugins={[
[rehypePrism, { ignoreMissing: true }],
]}
>
{finalContent as string}
</ReactMarkdown>
{renderedMarkdown}
</div>
) : (
content

View File

@ -0,0 +1,47 @@
export function extractCodeText(
node: any,
content: string,
children: React.ReactNode
): string {
let codeText: string | null = null;
if (
node?.position?.start?.offset != null &&
node?.position?.end?.offset != null
) {
codeText = content.slice(
node.position.start.offset,
node.position.end.offset
);
codeText = codeText.trim();
// Find the last occurrence of closing backticks
const lastBackticksIndex = codeText.lastIndexOf("```");
if (lastBackticksIndex !== -1) {
codeText = codeText.slice(0, lastBackticksIndex + 3);
}
// Remove the language declaration and trailing backticks
const codeLines = codeText.split("\n");
if (codeLines.length > 1 && codeLines[0].trim().startsWith("```")) {
codeLines.shift(); // Remove the first line with the language declaration
if (codeLines[codeLines.length - 1]?.trim() === "```") {
codeLines.pop(); // Remove the last line with the trailing backticks
}
const minIndent = codeLines
.filter((line) => line.trim().length > 0)
.reduce((min, line) => {
const match = line.match(/^\s*/);
return Math.min(min, match ? match[0].length : 0);
}, Infinity);
const formattedCodeLines = codeLines.map((line) => line.slice(minIndent));
codeText = formattedCodeLines.join("\n");
}
} else {
// Fallback if position offsets are not available
codeText = children?.toString() || null;
}
return codeText || "";
}

View File

@ -1,4 +1,5 @@
import { CodeBlock } from "@/app/chat/message/CodeBlock";
import { extractCodeText } from "@/app/chat/message/codeUtils";
import {
MemoizedLink,
MemoizedParagraph,
@ -10,13 +11,11 @@ import remarkGfm from "remark-gfm";
interface MinimalMarkdownProps {
content: string;
className?: string;
useCodeBlock?: boolean;
}
export const MinimalMarkdown: React.FC<MinimalMarkdownProps> = ({
content,
className = "",
useCodeBlock = false,
}) => {
return (
<ReactMarkdown
@ -24,11 +23,15 @@ export const MinimalMarkdown: React.FC<MinimalMarkdownProps> = ({
components={{
a: MemoizedLink,
p: MemoizedParagraph,
code: useCodeBlock
? (props) => (
<CodeBlock className="w-full" {...props} content={content} />
)
: (props) => <code {...props} />,
code: ({ node, inline, className, children, ...props }: any) => {
const codeText = extractCodeText(node, content, children);
return (
<CodeBlock className={className} codeText={codeText}>
{children}
</CodeBlock>
);
},
}}
remarkPlugins={[remarkGfm]}
>

View File

@ -1,7 +1,5 @@
import { Quote } from "@/lib/search/interfaces";
import { ResponseSection, StatusOptions } from "./ResponseSection";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown";
const TEMP_STRING = "__$%^TEMP$%^__";
@ -40,12 +38,7 @@ export const AnswerSection = (props: AnswerSectionProps) => {
status = "success";
header = <></>;
body = (
<MinimalMarkdown
useCodeBlock
content={replaceNewlines(props.answer || "")}
/>
);
body = <MinimalMarkdown content={replaceNewlines(props.answer || "")} />;
// error while building answer (NOTE: if error occurs during quote generation
// the above if statement will hit and the error will not be displayed)
@ -61,9 +54,7 @@ export const AnswerSection = (props: AnswerSectionProps) => {
} else if (props.answer) {
status = "success";
header = <></>;
body = (
<MinimalMarkdown useCodeBlock content={replaceNewlines(props.answer)} />
);
body = <MinimalMarkdown content={replaceNewlines(props.answer)} />;
}
return (