Contextual Retrieval (#4029)

* contextual rag implementation

* WIP

* indexing test fix

* workaround for chunking errors, WIP on fixing massive memory cost

* mypy and test fixes

* reformatting

* fixed rebase
evan-danswer 2025-03-30 11:49:09 -07:00 committed by GitHub
parent cb5bbd3812
commit 56f8ab927b
42 changed files with 1007 additions and 68 deletions

View File

@ -0,0 +1,50 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 3781a5eb12cb
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "3781a5eb12cb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"enable_contextual_rag",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_name",
sa.String(),
nullable=True,
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_provider",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("search_settings", "enable_contextual_rag")
op.drop_column("search_settings", "contextual_rag_llm_name")
op.drop_column("search_settings", "contextual_rag_llm_provider")

View File

@ -495,6 +495,11 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Enable contextual retrieval
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
@ -536,6 +541,17 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
# Average summary embeddings for contextual rag (not yet implemented)
AVERAGE_SUMMARY_EMBEDDINGS = (
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
)
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Miscellaneous
#####
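
A minimal sketch of how a local run might exercise the flags added above. The values are illustrative, and setting them via os.environ is only one option (a deployment would normally set them in its environment); the import path onyx.configs.app_configs is the module this hunk modifies.

# Assumption: local illustration only. The flags must be set before the config
# module is imported, since the values above are read from os.environ at import time.
import os

os.environ["ENABLE_CONTEXTUAL_RAG"] = "true"        # turn contextual RAG on
os.environ["USE_DOCUMENT_SUMMARY"] = "true"         # prepend a per-document summary
os.environ["USE_CHUNK_SUMMARY"] = "true"            # append per-chunk context
os.environ["AVERAGE_SUMMARY_EMBEDDINGS"] = "false"  # averaging not yet implemented

from onyx.configs import app_configs

assert app_configs.ENABLE_CONTEXTUAL_RAG
assert app_configs.USE_DOCUMENT_SUMMARY and app_configs.USE_CHUNK_SUMMARY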

View File

@ -30,6 +30,7 @@ from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -76,6 +77,26 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
return is_valid_image_type(mime_type)
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
"""
Download the file from Google Drive.
"""
# Download the raw file bytes via the Drive files API
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_id}")
return bytes()
return response
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
@ -114,41 +135,31 @@ def _download_and_extract_sections_basic(
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
response_call = lazy_eval(lambda: download_request(service, file_id))
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
text = response_call().decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
text = xlsx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
text = pptx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
@ -158,7 +169,7 @@ def _download_and_extract_sections_basic(
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
image_data=response_call(),
file_name=file_id,
display_name=file_name,
media_type=mime_type,
@ -171,7 +182,7 @@ def _download_and_extract_sections_basic(
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
@ -194,8 +205,15 @@ def _download_and_extract_sections_basic(
else:
# For unsupported file types, try to extract text
if mime_type in [
"application/vnd.google-apps.video",
"application/vnd.google-apps.audio",
"application/zip",
]:
return []
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")

View File

@ -163,6 +163,9 @@ class DocumentBase(BaseModel):
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
def get_text_content(self) -> str:
return " ".join([section.text for section in self.sections if section.text])
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""

View File

@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
inference_settings = InferenceSettings.from_db_model(search_settings)
indexing_setting = IndexingSetting.from_db_model(search_settings)
return cls(**inference_settings.dict(), **indexing_setting.dict())
return cls(**inference_settings.model_dump(), **indexing_setting.model_dump())
class SavedSearchSettings(InferenceSettings, IndexingSetting):
@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
@ -218,6 +221,8 @@ class InferenceChunk(BaseChunk):
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
doc_summary: str
chunk_context: str
# when the doc was last updated
updated_at: datetime | None

View File

@ -196,9 +196,21 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
RETURN_SEPARATOR
)
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
# remove document summary
if chunk.content.startswith(chunk.doc_summary):
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
# remove chunk context
if chunk.content.endswith(chunk.chunk_context):
chunk.content = chunk.content[
: len(chunk.content) - len(chunk.chunk_context)
].rstrip()
return chunk.content
for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)
chunk.content = _remove_contextual_rag(chunk)
return [chunk.to_inference_chunk() for chunk in chunks]
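
Taken together with the Vespa indexing change later in this diff, which stores doc_summary before the chunk body and chunk_context after it, this cleanup gives a simple round trip. An illustrative example with placeholder strings (title prefix and metadata suffix omitted for brevity):

# Illustrative only: the summary/context strings here are placeholders.
doc_summary = "Summary of the whole document. "
chunk_context = " How this chunk relates to the document."
chunk_body = "The actual chunk text."

indexed_content = f"{doc_summary}{chunk_body}{chunk_context}"  # what gets stored

# what _remove_contextual_rag recovers at query time
cleaned = indexed_content
if cleaned.startswith(doc_summary):
    cleaned = cleaned[len(doc_summary):].lstrip()
if cleaned.endswith(chunk_context):
    cleaned = cleaned[: len(cleaned) - len(chunk_context)].rstrip()

assert cleaned == chunk_body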

View File

@ -791,6 +791,15 @@ class SearchSettings(Base):
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
# Contextual RAG
enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)
# Contextual RAG LLM
contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
String, nullable=True
)
multilingual_expansion: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), default=[]
)

View File

@ -62,6 +62,9 @@ def create_search_settings(
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
@ -319,6 +322,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)
@ -333,5 +337,6 @@ def get_new_default_embedding_model() -> IndexingSetting:
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)

View File

@ -98,6 +98,12 @@ schema DANSWER_CHUNK_NAME {
field metadata type string {
indexing: summary | attribute
}
field chunk_context type string {
indexing: summary | attribute
}
field doc_summary type string {
indexing: summary | attribute
}
field metadata_suffix type string {
indexing: summary | attribute
}

View File

@ -24,9 +24,11 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@ -126,7 +128,8 @@ def _vespa_hit_to_inference_chunk(
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
content=fields[CONTENT], # Includes extra title prefix and metadata suffix;
# also sometimes context for contextual rag
source_links=source_links_dict or {0: ""},
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
@ -143,6 +146,8 @@ def _vespa_hit_to_inference_chunk(
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
doc_summary=fields.get(DOC_SUMMARY, ""),
chunk_context=fields.get(CHUNK_CONTEXT, ""),
match_highlights=match_highlights,
updated_at=updated_at,
)

View File

@ -187,7 +187,7 @@ class VespaIndex(DocumentIndex):
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None

View File

@ -25,9 +25,11 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@ -174,7 +176,7 @@ def _index_vespa_chunk(
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
# natural language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
@ -189,6 +191,8 @@ def _index_vespa_chunk(
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: metadata_list,
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
CHUNK_CONTEXT: chunk.chunk_context,
DOC_SUMMARY: chunk.doc_summary,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),

View File

@ -71,6 +71,8 @@ LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
DOC_SUMMARY = "doc_summary"
CHUNK_CONTEXT = "chunk_context"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
@ -106,6 +108,8 @@ YQL_BASE = (
f"{LARGE_CHUNK_REFERENCE_IDS}, "
f"{METADATA}, "
f"{METADATA_SUFFIX}, "
f"{DOC_SUMMARY}, "
f"{CHUNK_CONTEXT}, "
f"{CONTENT_SUMMARY} "
f"from {{index_name}} where "
)

View File

@ -1,7 +1,10 @@
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import MINI_CHUNK_SIZE
from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.constants import SECTION_SEPARATOR
@ -13,6 +16,7 @@ from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
@ -82,6 +86,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
mini_chunk_texts=None,
large_chunk_id=large_chunk_id,
chunk_context="",
doc_summary="",
contextual_rag_reserved_tokens=0,
)
offset = 0
@ -120,6 +127,7 @@ class Chunker:
tokenizer: BaseTokenizer,
enable_multipass: bool = False,
enable_large_chunks: bool = False,
enable_contextual_rag: bool = False,
blurb_size: int = BLURB_SIZE,
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
@ -133,9 +141,20 @@ class Chunker:
self.chunk_token_limit = chunk_token_limit
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.enable_contextual_rag = enable_contextual_rag
if enable_contextual_rag:
assert (
USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY
), "Contextual RAG requires at least one of chunk summary and document summary enabled"
self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * (
int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY)
)
self.tokenizer = tokenizer
self.callback = callback
self.max_context = 0
self.prompt_tokens = 0
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=blurb_size,
@ -221,6 +240,9 @@ class Chunker:
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
)
chunks_list.append(new_chunk)
@ -309,7 +331,7 @@ class Chunker:
continue
# CASE 2: Normal text section
section_token_count = len(self.tokenizer.tokenize(section_text))
section_token_count = len(self.tokenizer.encode(section_text))
# If the section is large on its own, split it separately
if section_token_count > content_token_limit:
@ -332,8 +354,7 @@ class Chunker:
# If even the split_text is bigger than strict limit, further split
if (
STRICT_CHUNK_TOKEN_LIMIT
and len(self.tokenizer.tokenize(split_text))
> content_token_limit
and len(self.tokenizer.encode(split_text)) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
@ -363,10 +384,10 @@ class Chunker:
continue
# If we can still fit this section into the current chunk, do so
current_token_count = len(self.tokenizer.tokenize(chunk_text))
current_token_count = len(self.tokenizer.encode(chunk_text))
current_offset = len(shared_precompare_cleanup(chunk_text))
next_section_tokens = (
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
)
if next_section_tokens + current_token_count <= content_token_limit:
@ -414,7 +435,7 @@ class Chunker:
# Title prep
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.tokenize(title_prefix))
title_tokens = len(self.tokenizer.encode(title_prefix))
# Metadata prep
metadata_suffix_semantic = ""
@ -427,15 +448,50 @@ class Chunker:
) = _get_metadata_suffix_for_document_index(
document.metadata, include_separator=True
)
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic))
# If metadata is too large, skip it in the semantic content
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
metadata_suffix_semantic = ""
metadata_tokens = 0
single_chunk_fits = True
doc_token_count = 0
if self.enable_contextual_rag:
doc_content = document.get_text_content()
tokenized_doc = self.tokenizer.tokenize(doc_content)
doc_token_count = len(tokenized_doc)
# check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG
single_chunk_fits = (
doc_token_count + title_tokens + metadata_tokens
<= self.chunk_token_limit
)
# expand the size of the context used for contextual rag based on whether chunk context and doc summary are used
context_size = 0
if (
self.enable_contextual_rag
and not single_chunk_fits
and not AVERAGE_SUMMARY_EMBEDDINGS
):
context_size += self.default_contextual_rag_reserved_tokens
# Adjust content token limit to accommodate title + metadata
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens - context_size
)
# first check: if there is not enough actual chunk content when including contextual rag,
# then don't do contextual rag
if content_token_limit <= CHUNK_MIN_CONTENT:
context_size = 0 # Don't do contextual RAG
# revert to previous content token limit
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens
)
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
# Not enough space left, so revert to full chunk without the prefix
content_token_limit = self.chunk_token_limit
@ -459,6 +515,9 @@ class Chunker:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
for chunk in normal_chunks:
chunk.contextual_rag_reserved_tokens = context_size
return normal_chunks
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:
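
A worked example of the token budget above, using illustrative numbers rather than the repo defaults (MAX_CONTEXT_TOKENS = 100 is taken from llm/utils.py later in this diff):

chunk_token_limit = 512   # assumed embedder chunk size, for illustration only
title_tokens = 6
metadata_tokens = 20
MAX_CONTEXT_TOKENS = 100  # reserved per enabled summary type
reserved = MAX_CONTEXT_TOKENS * 2  # both chunk summary and doc summary enabled

content_token_limit = chunk_token_limit - title_tokens - metadata_tokens - reserved
print(content_token_limit)  # 286 tokens left for the chunk text itself

# If this number fell to CHUNK_MIN_CONTENT or below, the reservation would be
# dropped (context_size = 0) and the limit would revert to 512 - 6 - 20 = 486.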

View File

@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
if chunk.large_chunk_reference_ids:
large_chunks_present = True
chunk_text = (
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
) or chunk.source_document.get_title_for_document_index()
if not chunk_text:

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Callable
from functools import partial
from typing import Protocol
@ -8,7 +9,13 @@ from sqlalchemy.orm import Session
from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
@ -36,9 +43,10 @@ from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_active_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.document_index.document_index_utils import (
@ -57,11 +65,24 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.llm.interfaces import LLM
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
@ -249,6 +270,8 @@ def index_doc_batch_with_handler(
db_session: Session,
tenant_id: str,
ignore_time_skip: bool = False,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
) -> IndexingPipelineResult:
try:
index_pipeline_result = index_doc_batch(
@ -261,6 +284,8 @@ def index_doc_batch_with_handler(
db_session=db_session,
ignore_time_skip=ignore_time_skip,
tenant_id=tenant_id,
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)
except Exception as e:
# don't log the batch directly, it's too much text
@ -523,6 +548,145 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
return indexed_documents
def add_document_summaries(
chunks_by_doc: list[DocAwareChunk],
llm: LLM,
tokenizer: BaseTokenizer,
trunc_doc_tokens: int,
) -> list[int] | None:
"""
Adds a document summary to a list of chunks from the same document.
Returns the document's token ids for reuse by the chunk-summary step,
or None if contextual RAG was skipped for this document.
"""
doc_tokens = []
# this value is the same for each chunk in the document; 0 indicates
# there is not enough space for contextual RAG (the chunk content
# and possibly metadata took up too much space)
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
return None
doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content())
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer)
summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
doc_summary = message_to_string(
llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
)
for chunk in chunks_by_doc:
chunk.doc_summary = doc_summary
return doc_tokens
def add_chunk_summaries(
chunks_by_doc: list[DocAwareChunk],
llm: LLM,
tokenizer: BaseTokenizer,
trunc_doc_chunk_tokens: int,
doc_tokens: list[int] | None,
) -> None:
"""
Adds chunk summaries to the chunks grouped by document id.
Chunk summaries look at the chunk as well as the entire document (or a summary,
if the document is too long) and describe how the chunk relates to the document.
"""
# all chunks within a document have the same contextual_rag_reserved_tokens
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
return
# use values computed in above doc summary section if available
doc_tokens = doc_tokens or tokenizer.encode(
chunks_by_doc[0].source_document.get_text_content()
)
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer)
# only compute doc summary if needed
doc_info = (
doc_content
if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION
else chunks_by_doc[0].doc_summary
)
if not doc_info:
# This happens if the document is too long AND document summaries are turned off
# In this case we compute a doc summary using the LLM
doc_info = message_to_string(
llm.invoke(
DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
max_tokens=MAX_CONTEXT_TOKENS,
)
)
context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info)
def assign_context(chunk: DocAwareChunk) -> None:
context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
try:
chunk.chunk_context = message_to_string(
llm.invoke(
context_prompt1 + context_prompt2,
max_tokens=MAX_CONTEXT_TOKENS,
)
)
except LLMRateLimitError as e:
# Erroring during chunker is undesirable, so we log the error and continue
# TODO: for v2, add robust retry logic
logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e)
chunk.chunk_context = ""
except Exception as e:
logger.exception(f"Error adding chunk summary: {e}", exc_info=e)
chunk.chunk_context = ""
run_functions_tuples_in_parallel(
[(assign_context, (chunk,)) for chunk in chunks_by_doc]
)
def add_contextual_summaries(
chunks: list[DocAwareChunk],
llm: LLM,
tokenizer: BaseTokenizer,
chunk_token_limit: int,
) -> list[DocAwareChunk]:
"""
Adds Document summary and chunk-within-document context to the chunks
based on which environment variables are set.
"""
max_context = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
output_tokens=MAX_CONTEXT_TOKENS,
)
doc2chunks = defaultdict(list)
for chunk in chunks:
doc2chunks[chunk.source_document.id].append(chunk)
# The number of tokens allowed for the document when computing a document summary
trunc_doc_summary_tokens = max_context - len(
tokenizer.encode(DOCUMENT_SUMMARY_PROMPT)
)
prompt_tokens = len(
tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2)
)
# The number of tokens allowed for the document when computing a
# "chunk in context of document" summary
trunc_doc_chunk_tokens = max_context - prompt_tokens - chunk_token_limit
for chunks_by_doc in doc2chunks.values():
doc_tokens = None
if USE_DOCUMENT_SUMMARY:
doc_tokens = add_document_summaries(
chunks_by_doc, llm, tokenizer, trunc_doc_summary_tokens
)
if USE_CHUNK_SUMMARY:
add_chunk_summaries(
chunks_by_doc, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens
)
return chunks
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@ -534,6 +698,8 @@ def index_doc_batch(
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
tenant_id: str,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
ignore_time_skip: bool = False,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> IndexingPipelineResult:
@ -596,6 +762,20 @@ def index_doc_batch(
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
# contextual RAG
if enable_contextual_rag:
assert llm is not None, "must provide an LLM for contextual RAG"
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Because the chunker's tokens are different from the LLM's tokens,
# we add a fudge factor to ensure prompts are truncated to the LLM's token limit
chunks = add_contextual_summaries(
chunks, llm, llm_tokenizer, chunker.chunk_token_limit * 2
)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
@ -791,13 +971,33 @@ def build_indexing_pipeline(
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
all_search_settings = get_active_search_settings(db_session)
if (
all_search_settings.secondary
and all_search_settings.secondary.status == IndexModelStatus.FUTURE
):
search_settings = all_search_settings.secondary
else:
search_settings = all_search_settings.primary
multipass_config = get_multipass_config(search_settings)
enable_contextual_rag = (
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
)
llm = None
if enable_contextual_rag:
llm = get_llm_for_contextual_rag(
search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME,
search_settings.contextual_rag_llm_provider
or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER,
)
chunker = chunker or Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass_config.multipass_indexing,
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
)
@ -811,4 +1011,6 @@ def build_indexing_pipeline(
ignore_time_skip=ignore_time_skip,
db_session=db_session,
tenant_id=tenant_id,
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)

View File

@ -49,6 +49,15 @@ class DocAwareChunk(BaseChunk):
metadata_suffix_semantic: str
metadata_suffix_keyword: str
# This is the number of tokens reserved for contextual RAG
# in the chunk. doc_summary and chunk_context combined should
# contain at most this many tokens.
contextual_rag_reserved_tokens: int
# This is the summary for the document generated for contextual RAG
doc_summary: str
# This is the context for this chunk generated for contextual RAG
chunk_context: str
mini_chunk_texts: list[str] | None
large_chunk_id: int | None
@ -154,6 +163,9 @@ class IndexingSetting(EmbeddingModelDetail):
reduced_dimension: int | None = None
background_reindex_enabled: bool = True
enable_contextual_rag: bool
contextual_rag_llm_name: str | None = None
contextual_rag_llm_provider: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@ -178,6 +190,7 @@ class IndexingSetting(EmbeddingModelDetail):
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
)

View File

@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM):
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
max_tokens=max_tokens,
# streaming choice
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@ -531,6 +531,7 @@ class DefaultMultiLLM(LLM):
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
return

View File

@ -16,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import LLMProvider
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
@ -154,6 +155,40 @@ def get_default_llm_with_vision(
return None
def llm_from_provider(
model_name: str,
llm_provider: LLMProvider,
timeout: int | None = None,
temperature: float | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model_name,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
with get_session_context_manager() as db_session:
llm_provider = fetch_llm_provider_view(db_session, model_provider)
if not llm_provider:
raise ValueError("No LLM provider with name {} found".format(model_provider))
return llm_from_provider(
model_name=model_name,
llm_provider=llm_provider,
)
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,
@ -179,14 +214,9 @@ def get_default_llms(
raise ValueError("No fast default model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
return llm_from_provider(
model_name=model,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,

View File

@ -29,13 +29,19 @@ from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
@ -44,6 +50,10 @@ from shared_configs.configs import LOG_LEVEL
logger = setup_logger()
MAX_CONTEXT_TOKENS = 100
ONE_MILLION = 1_000_000
CHUNKS_PER_DOC_ESTIMATE = 5
def litellm_exception_to_error_msg(
e: Exception,
@ -416,6 +426,72 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N
return None
def get_llm_contextual_cost(
llm: LLM,
) -> float:
"""
Approximate the cost of using the given LLM for indexing with Contextual RAG.
We use a precomputed estimate for the number of tokens in the contextualizing prompts,
and we assume that every chunk is maximized in terms of content and context.
We also assume that every document is maximized in terms of content, as currently if
a document is longer than a certain length, its summary is used instead of the full content.
We expect that the first assumption will overestimate more than the second one
underestimates, so this should be a fairly conservative price estimate. Also,
this does not account for the cost of documents that fit within a single chunk
which do not get contextualized.
"""
# calculate input costs
num_tokens = ONE_MILLION
num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE
# We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long
# on average.
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION
num_input_tokens = 0
num_output_tokens = 0
if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY:
return 0
if USE_CHUNK_SUMMARY:
# Each per-chunk prompt includes:
# - The prompt tokens
# - the document tokens
# - the chunk tokens
# for each chunk, we prompt the LLM with the contextual RAG prompt
# and the full document content (or the doc summary, so this is an overestimate)
num_input_tokens += num_input_chunks * (
CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
)
# in aggregate, each chunk content is used as a prompt input once
# so the full input size is covered
num_input_tokens += num_tokens
# A single MAX_CONTEXT_TOKENS worth of output is generated per chunk
num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS
# going over each doc once means all the tokens, plus the prompt tokens for
# the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false,
# since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true.
# So, we include this unconditionally to overestimate.
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
num_output_tokens += num_docs * MAX_CONTEXT_TOKENS
usd_per_prompt, usd_per_completion = litellm.cost_per_token(
model=llm.config.model_name,
prompt_tokens=num_input_tokens,
completion_tokens=num_output_tokens,
)
# Cost in USD to index one million tokens of documents with contextual RAG
return usd_per_prompt + usd_per_completion
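
The estimate bottoms out in litellm's pricing table; a hedged illustration of that underlying call, where the model name and token counts are examples only:

import litellm

# cost_per_token returns (prompt_cost_usd, completion_cost_usd) for the given
# token counts; gpt-4o-mini is just an example model here.
prompt_usd, completion_usd = litellm.cost_per_token(
    model="gpt-4o-mini",
    prompt_tokens=1_000_000,
    completion_tokens=50_000,
)
print(f"~${prompt_usd + completion_usd:.4f} for this token volume")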
def get_llm_max_tokens(
model_map: dict,
model_name: str,

View File

@ -11,6 +11,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -159,9 +161,26 @@ def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
if len(tokens) <= desired_length:
return content
return tokenizer.decode(tokens[:desired_length])
def tokenizer_trim_middle(
tokens: list[int], desired_length: int, tokenizer: BaseTokenizer
) -> str:
if len(tokens) <= desired_length:
return tokenizer.decode(tokens)
sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
sep_tokens = tokenizer.encode(sep_str)
slice_size = (desired_length - len(sep_tokens)) // 2
assert slice_size > 0, "Slice size is not positive, desired length is too short"
return (
tokenizer.decode(tokens[:slice_size])
+ sep_str
+ tokenizer.decode(tokens[-slice_size:])
)
def tokenizer_trim_chunks(

View File

@ -220,3 +220,29 @@ Chat History:
Based on the above, what is a short name to convey the topic of the conversation?
""".strip()
# NOTE: the prompt separation is partially done for efficiency; previously I tried
# to do it all in one prompt with sequential format() calls, but that causes a backend
# error when the document contains any {}, since Python expects the braces to be
# filled by format() arguments.
CONTEXTUAL_RAG_PROMPT1 = """<document>
{document}
</document>
Here is the chunk we want to situate within the whole document"""
CONTEXTUAL_RAG_PROMPT2 = """<chunk>
{chunk}
</chunk>
Please give a short succinct context to situate this chunk within the overall document
for the purposes of improving search retrieval of the chunk. Answer only with the succinct
context and nothing else. """
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45
DOCUMENT_SUMMARY_PROMPT = """<document>
{document}
</document>
Please give a short succinct summary of the entire document. Answer only with the succinct
summary and nothing else. """
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
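
The NOTE above is why the prompt is split in two; a short reproduction of the failure mode it avoids, assuming plain str.format as used here:

# If both placeholders lived in one template and it were formatted sequentially,
# the braces inside a real document would be re-parsed as placeholders by the
# second format() call.
template = "<document>\n{document}\n</document>\n<chunk>\n{{chunk}}\n</chunk>"
doc_with_braces = 'def f(): return {"retries": {}}'   # any document containing {}

step1 = template.format(document=doc_with_braces)     # fine; {{chunk}} becomes {chunk}
try:
    step1.format(chunk="the chunk text")              # doc braces now look like fields
except (KeyError, IndexError) as exc:
    print(f"sequential format() fails: {exc!r}")

# Formatting CONTEXTUAL_RAG_PROMPT1 and PROMPT2 separately means each template is
# formatted exactly once, so document/chunk content is never re-parsed.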

View File

@ -87,6 +87,9 @@ def _create_indexable_chunks(
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_reference_ids=[],
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=preprocessed_doc["content_embedding"],
mini_chunk_embeddings=[],

View File

@ -21,9 +21,11 @@ from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
@ -286,3 +288,38 @@ def list_llm_provider_basics(
db_session, user
)
]
@admin_router.get("/provider-contextual-cost")
def get_provider_contextual_cost(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMCost]:
"""
Get the cost of Re-indexing all documents for contextual retrieval.
See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token
This includes:
- The cost of invoking the LLM on each chunk-document pair to get
- the doc_summary
- the chunk_context
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
"""
providers = fetch_existing_llm_providers(db_session)
costs = []
for provider in providers:
for model_name in provider.display_model_names or provider.model_names or []:
llm = get_llm(
provider=provider.provider,
model=model_name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
)
cost = get_llm_contextual_cost(llm)
costs.append(
LLMCost(provider=provider.name, model_name=model_name, cost=cost)
)
return costs
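
A hedged example of calling the new endpoint from a script; the host and auth header are placeholders, and only the route path is taken from this PR (it matches LLM_CONTEXTUAL_COST_ADMIN_URL added on the frontend below):

import requests

# Assumptions: the API is reachable at localhost:8080 and an admin credential
# is available; adjust both for a real deployment.
resp = requests.get(
    "http://localhost:8080/api/admin/llm/provider-contextual-cost",
    headers={"Authorization": "Bearer <admin-api-key>"},
    timeout=30,
)
resp.raise_for_status()
for entry in resp.json():  # list of LLMCost entries
    print(f"{entry['provider']}/{entry['model_name']}: ${entry['cost']:.2f}")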

View File

@ -119,3 +119,9 @@ class VisionProviderResponse(LLMProviderView):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]
class LLMCost(BaseModel):
provider: str
model_name: str
cost: float

View File

@ -63,7 +63,10 @@ def generate_dummy_chunk(
title_prefix=f"Title prefix for doc {doc_id}",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=generate_random_embedding(embedding_dim),
mini_chunk_embeddings=[],

View File

@ -99,6 +99,7 @@ PRESERVED_SEARCH_FIELDS = [
"api_url",
"index_name",
"multipass_indexing",
"enable_contextual_rag",
"model_dim",
"normalize",
"passage_prefix",

View File

@ -32,6 +32,8 @@ def create_test_chunk(
match_highlights=[],
updated_at=datetime.now(),
image_file_name=None,
doc_summary="",
chunk_context="",
)

View File

@ -78,6 +78,8 @@ def mock_inference_sections() -> list[InferenceSection]:
source_links={0: "https://example.com/doc1"},
match_highlights=[],
image_file_name=None,
doc_summary="",
chunk_context="",
),
chunks=MagicMock(),
),
@ -101,6 +103,8 @@ def mock_inference_sections() -> list[InferenceSection]:
source_links={0: "https://example.com/doc2"},
match_highlights=[],
image_file_name=None,
doc_summary="",
chunk_context="",
),
chunks=MagicMock(),
),

View File

@ -151,6 +151,8 @@ def test_fuzzy_match_quotes_to_docs() -> None:
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
@ -170,6 +172,8 @@ def test_fuzzy_match_quotes_to_docs() -> None:
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)
test_quotes = [

View File

@ -38,6 +38,8 @@ def create_inference_chunk(
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)

View File

@ -1,5 +1,6 @@
import pytest
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@ -17,3 +18,13 @@ class MockHeartbeat(IndexingHeartbeatInterface):
@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
return MockHeartbeat()
@pytest.fixture
def embedder() -> DefaultIndexingEmbedder:
return DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)

View File

@ -1,25 +1,24 @@
from typing import Any
from unittest.mock import Mock
import pytest
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import process_image_sections
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from tests.unit.onyx.indexing.conftest import MockHeartbeat
@pytest.fixture
def embedder() -> DefaultIndexingEmbedder:
return DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)
def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
def test_chunk_document(
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
@ -45,9 +44,22 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
)
indexing_documents = process_image_sections([document])
mock_llm_invoke_count = 0
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
nonlocal mock_llm_invoke_count
mock_llm_invoke_count += 1
m = Mock()
m.content = f"Test{mock_llm_invoke_count}"
return m
mock_llm = Mock()
mock_llm.invoke = mock_llm_invoke
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
enable_contextual_rag=enable_contextual_rag,
)
chunks = chunker.chunk(indexing_documents)
@ -58,6 +70,14 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
rag_tokens = MAX_CONTEXT_TOKENS * (
int(USE_DOCUMENT_SUMMARY) + int(USE_CHUNK_SUMMARY)
)
for chunk in chunks:
assert chunk.contextual_rag_reserved_tokens == (
rag_tokens if enable_contextual_rag else 0
)
def test_chunker_heartbeat(
embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
@ -78,6 +98,7 @@ def test_chunker_heartbeat(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
callback=mock_heartbeat,
enable_contextual_rag=False,
)
chunks = chunker.chunk(indexing_documents)

View File

@ -21,7 +21,13 @@ def mock_embedding_model() -> Generator[Mock, None, None]:
yield mock
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
@pytest.mark.parametrize(
"chunk_context, doc_summary",
[("Test chunk context", "Test document summary"), ("", "")],
)
def test_default_indexing_embedder_embed_chunks(
mock_embedding_model: Mock, chunk_context: str, doc_summary: str
) -> None:
# Setup
embedder = DefaultIndexingEmbedder(
model_name="test-model",
@ -63,6 +69,9 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N
large_chunk_reference_ids=[],
large_chunk_id=None,
image_file_name=None,
chunk_context=chunk_context,
doc_summary=doc_summary,
contextual_rag_reserved_tokens=200,
)
]
@ -81,7 +90,7 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N
# Verify the embedding model was called correctly
mock_embedding_model.return_value.encode.assert_any_call(
texts=["Title: Test chunk"],
texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
text_type=EmbedTextType.PASSAGE,
large_chunks_present=False,
)

View File

@ -1,6 +1,8 @@
from typing import Any
from typing import cast
from typing import List
from unittest.mock import Mock
from unittest.mock import patch
import pytest
@ -9,8 +11,12 @@ from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
from onyx.indexing.indexing_pipeline import add_contextual_summaries
from onyx.indexing.indexing_pipeline import filter_documents
from onyx.indexing.indexing_pipeline import process_image_sections
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import (
@ -166,6 +172,9 @@ def create_test_chunk(
embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
title_embedding=None,
image_file_name=None,
chunk_context="",
doc_summary="",
contextual_rag_reserved_tokens=200,
)
@ -249,3 +258,76 @@ def test_get_aggregated_boost_factor_individual_failure() -> None:
)
assert "Failed to predict content classification for chunk" in str(exc_info.value)
@patch("onyx.llm.utils.GEN_AI_MAX_TOKENS", 4096)
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
def test_contextual_rag(
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
)
short_section_2 = "This is another short section."
short_section_3 = "This is another short section again."
short_section_4 = "Final short section."
semantic_identifier = "Test Document"
document = Document(
id="test_doc",
source=DocumentSource.WEB,
semantic_identifier=semantic_identifier,
metadata={"tags": ["tag1", "tag2"]},
doc_updated_at=None,
sections=[
TextSection(text=short_section_1, link="link1"),
TextSection(text=short_section_2, link="link2"),
TextSection(text=long_section, link="link3"),
TextSection(text=short_section_3, link="link4"),
TextSection(text=short_section_4, link="link5"),
],
)
indexing_documents = process_image_sections([document])
mock_llm_invoke_count = 0
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
nonlocal mock_llm_invoke_count
mock_llm_invoke_count += 1
m = Mock()
m.content = f"Test{mock_llm_invoke_count}"
return m
llm_tokenizer = embedder.embedding_model.tokenizer
mock_llm = Mock()
mock_llm.invoke = mock_llm_invoke
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
enable_contextual_rag=enable_contextual_rag,
)
chunks = chunker.chunk(indexing_documents)
chunks = add_contextual_summaries(
chunks, mock_llm, llm_tokenizer, chunker.chunk_token_limit * 2
)
assert len(chunks) == 5
assert short_section_1 in chunks[0].content
assert short_section_3 in chunks[-1].content
assert short_section_4 in chunks[-1].content
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
doc_summary = "Test1" if enable_contextual_rag else ""
chunk_context = ""
count = 2
for chunk in chunks:
if enable_contextual_rag:
chunk_context = f"Test{count}"
count += 1
assert chunk.doc_summary == doc_summary
assert chunk.chunk_context == chunk_context

View File

@ -140,12 +140,12 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
],
tools=tools,
tool_choice=None,
max_tokens=None,
stream=False,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)
@ -286,10 +286,10 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
],
tools=tools,
tool_choice=None,
max_tokens=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)

View File

@ -1,5 +1,8 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
"/api/admin/llm/provider-contextual-cost";
export const EMBEDDING_PROVIDERS_ADMIN_URL =
"/api/admin/embedding/embedding-provider";

View File

@ -143,6 +143,15 @@ function Main() {
</Text>
</div>
<div>
<Text className="font-semibold">Contextual RAG</Text>
<Text className="text-text-700">
{searchSettings.enable_contextual_rag
? "Enabled"
: "Disabled"}
</Text>
</div>
<div>
<Text className="font-semibold">
Disable Reranking for Streaming

View File

@ -26,9 +26,18 @@ export enum EmbeddingPrecision {
BFLOAT16 = "bfloat16",
}
export interface LLMContextualCost {
provider: string;
model_name: string;
cost: number;
}
export interface AdvancedSearchConfiguration {
index_name: string | null;
multipass_indexing: boolean;
enable_contextual_rag: boolean;
contextual_rag_llm_name: string | null;
contextual_rag_llm_provider: string | null;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
api_url: string | null;

View File

@ -3,7 +3,11 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik";
import * as Yup from "yup";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces";
import {
AdvancedSearchConfiguration,
EmbeddingPrecision,
LLMContextualCost,
} from "../interfaces";
import {
BooleanFormField,
Label,
@ -12,6 +16,13 @@ import {
} from "@/components/admin/connectors/Field";
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
import { StringOrNumberOption } from "@/components/Dropdown";
import useSWR from "swr";
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants";
import { getDisplayNameForModel } from "@/lib/hooks";
import { errorHandlingFetcher } from "@/lib/fetcher";
// Number of tokens to show cost calculation for
const COST_CALCULATION_TOKENS = 1_000_000;
interface AdvancedEmbeddingFormPageProps {
updateAdvancedEmbeddingDetails: (
@ -45,14 +56,66 @@ const AdvancedEmbeddingFormPage = forwardRef<
},
ref
) => {
// Fetch contextual costs
const { data: contextualCosts, error: costError } = useSWR<
LLMContextualCost[]
>(LLM_CONTEXTUAL_COST_ADMIN_URL, errorHandlingFetcher);
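// Build the dropdown options from the models the backend reports contextual RAG costs for.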
const llmOptions: StringOrNumberOption[] = React.useMemo(
() =>
(contextualCosts || []).map((cost) => {
return {
name: getDisplayNameForModel(cost.model_name),
value: cost.model_name,
};
}),
[contextualCosts]
);
// Helper function to format cost as USD
const formatCost = (cost: number) => {
return new Intl.NumberFormat("en-US", {
style: "currency",
currency: "USD",
}).format(cost);
};
// Get cost info for selected model
const getSelectedModelCost = (modelName: string | null) => {
if (!contextualCosts || !modelName) return null;
return contextualCosts.find((cost) => cost.model_name === modelName);
};
// Get the current value for the selector based on the parent state
const getCurrentLLMValue = React.useMemo(() => {
if (!advancedEmbeddingDetails.contextual_rag_llm_name) return null;
return advancedEmbeddingDetails.contextual_rag_llm_name;
}, [advancedEmbeddingDetails.contextual_rag_llm_name]);
return (
<div className="py-4 rounded-lg max-w-4xl px-4 mx-auto">
<Formik
innerRef={ref}
initialValues={advancedEmbeddingDetails}
initialValues={{
...advancedEmbeddingDetails,
contextual_rag_llm: getCurrentLLMValue,
}}
validationSchema={Yup.object().shape({
multilingual_expansion: Yup.array().of(Yup.string()),
multipass_indexing: Yup.boolean(),
enable_contextual_rag: Yup.boolean(),
contextual_rag_llm: Yup.string()
.nullable()
.test(
"required-if-contextual-rag",
"LLM must be selected when Contextual RAG is enabled",
function (value) {
const enableContextualRag = this.parent.enable_contextual_rag;
console.log("enableContextualRag", enableContextualRag);
console.log("value", value);
return !enableContextualRag || value !== null;
}
),
disable_rerank_for_streaming: Yup.boolean(),
num_rerank: Yup.number()
.required("Number of results to rerank is required")
@ -79,10 +142,26 @@ const AdvancedEmbeddingFormPage = forwardRef<
validate={(values) => {
// Call updateAdvancedEmbeddingDetails for each changed field
Object.entries(values).forEach(([key, value]) => {
updateAdvancedEmbeddingDetails(
key as keyof AdvancedSearchConfiguration,
value
);
if (key === "contextual_rag_llm") {
const selectedModel = (contextualCosts || []).find(
(cost) => cost.model_name === value
);
if (selectedModel) {
updateAdvancedEmbeddingDetails(
"contextual_rag_llm_provider",
selectedModel.provider
);
updateAdvancedEmbeddingDetails(
"contextual_rag_llm_name",
selectedModel.model_name
);
}
} else {
updateAdvancedEmbeddingDetails(
key as keyof AdvancedSearchConfiguration,
value
);
}
});
// Run validation and report errors
@ -96,6 +175,23 @@ const AdvancedEmbeddingFormPage = forwardRef<
.shape({
multilingual_expansion: Yup.array().of(Yup.string()),
multipass_indexing: Yup.boolean(),
enable_contextual_rag: Yup.boolean(),
contextual_rag_llm: Yup.string()
.nullable()
.test(
"required-if-contextual-rag",
"LLM must be selected when Contextual RAG is enabled",
function (value) {
const enableContextualRag =
this.parent.enable_contextual_rag;
return !enableContextualRag || value !== null;
}
),
disable_rerank_for_streaming: Yup.boolean(),
num_rerank: Yup.number()
.required("Number of results to rerank is required")
@ -190,6 +286,56 @@ const AdvancedEmbeddingFormPage = forwardRef<
label="Disable Rerank for Streaming"
name="disable_rerank_for_streaming"
/>
<BooleanFormField
subtext="Enable contextual RAG for all chunk sizes."
optional
label="Contextual RAG"
name="enable_contextual_rag"
/>
<div>
<SelectorFormField
name="contextual_rag_llm"
label="Contextual RAG LLM"
subtext={
costError
? "Error loading LLM models. Please try again later."
: !contextualCosts
? "Loading available LLM models..."
: values.enable_contextual_rag
? "Select the LLM model to use for contextual RAG processing."
: "Enable Contextual RAG above to select an LLM model."
}
options={llmOptions}
disabled={
!values.enable_contextual_rag ||
!contextualCosts ||
!!costError
}
/>
{values.enable_contextual_rag &&
values.contextual_rag_llm &&
!costError && (
<div className="mt-2 text-sm text-text-600">
{contextualCosts ? (
<>
Estimated cost for processing{" "}
{COST_CALCULATION_TOKENS.toLocaleString()} tokens:{" "}
<span className="font-medium">
{getSelectedModelCost(values.contextual_rag_llm)
? formatCost(
getSelectedModelCost(
values.contextual_rag_llm
)!.cost
)
: "Cost information not available"}
</span>
</>
) : (
"Loading cost information..."
)}
</div>
)}
</div>
<NumberInput
description="Number of results to rerank"
optional={false}

View File

@ -64,6 +64,9 @@ export default function EmbeddingForm() {
useState<AdvancedSearchConfiguration>({
index_name: "",
multipass_indexing: true,
enable_contextual_rag: false,
contextual_rag_llm_name: null,
contextual_rag_llm_provider: null,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
api_url: null,
@ -152,6 +155,9 @@ export default function EmbeddingForm() {
setAdvancedEmbeddingDetails({
index_name: searchSettings.index_name,
multipass_indexing: searchSettings.multipass_indexing,
enable_contextual_rag: searchSettings.enable_contextual_rag,
contextual_rag_llm_name: searchSettings.contextual_rag_llm_name,
contextual_rag_llm_provider: searchSettings.contextual_rag_llm_provider,
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
@ -197,7 +203,9 @@ export default function EmbeddingForm() {
searchSettings?.embedding_precision !=
advancedEmbeddingDetails.embedding_precision ||
searchSettings?.reduced_dimension !=
advancedEmbeddingDetails.reduced_dimension;
advancedEmbeddingDetails.reduced_dimension ||
searchSettings?.enable_contextual_rag !=
advancedEmbeddingDetails.enable_contextual_rag;
const updateSearch = useCallback(async () => {
if (!selectedProvider) {
@ -384,6 +392,14 @@ export default function EmbeddingForm() {
advancedEmbeddingDetails.reduced_dimension && (
<li>Reduced dimension modification</li>
)}
{(searchSettings?.enable_contextual_rag !=
advancedEmbeddingDetails.enable_contextual_rag ||
searchSettings?.contextual_rag_llm_name !=
advancedEmbeddingDetails.contextual_rag_llm_name ||
searchSettings?.contextual_rag_llm_provider !=
advancedEmbeddingDetails.contextual_rag_llm_provider) && (
<li>Contextual RAG modification</li>
)}
</ul>
</div>
</div>
@ -471,6 +487,11 @@ export default function EmbeddingForm() {
};
const handleReIndex = async () => {
console.log("handleReIndex");
console.log(selectedProvider);
console.log(advancedEmbeddingDetails);
console.log(rerankingDetails);
console.log(reindexType);
if (!selectedProvider) {
return;
}

View File

@ -676,6 +676,7 @@ interface SelectorFormFieldProps {
includeReset?: boolean;
fontSize?: "sm" | "md" | "lg";
small?: boolean;
disabled?: boolean;
}
export function SelectorFormField({
@ -691,6 +692,7 @@ export function SelectorFormField({
includeReset = false,
fontSize = "md",
small = false,
disabled = false,
}: SelectorFormFieldProps) {
const [field] = useField<string>(name);
const { setFieldValue } = useFormikContext();
@ -742,8 +744,9 @@ export function SelectorFormField({
: setFieldValue(name, selected))
}
defaultValue={defaultValue}
disabled={disabled}
>
<SelectTrigger className={sizeClass.input}>
<SelectTrigger className={sizeClass.input} disabled={disabled}>
<SelectValue placeholder="Select...">
{currentlySelected?.name || defaultValue || ""}
</SelectValue>