From 7b2e8bcd6629d5c55112fad8e23f70dd5adee6b1 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 13 Mar 2025 12:27:31 -0700 Subject: [PATCH] port contextual rag to new repo --- ...e1ac4f39a9f_enable_contextual_retrieval.py | 50 ++++++ backend/onyx/configs/app_configs.py | 13 ++ backend/onyx/connectors/models.py | 3 + backend/onyx/context/search/models.py | 5 + .../search/postprocessing/postprocessing.py | 12 ++ backend/onyx/db/models.py | 9 + backend/onyx/db/search_settings.py | 5 + .../vespa/app_config/schemas/danswer_chunk.sd | 6 + .../document_index/vespa/chunk_retrieval.py | 7 +- backend/onyx/document_index/vespa/index.py | 2 +- .../document_index/vespa/indexing_utils.py | 6 +- .../onyx/document_index/vespa_constants.py | 4 + backend/onyx/indexing/chunker.py | 158 +++++++++++++++++- backend/onyx/indexing/embedder.py | 2 +- backend/onyx/indexing/indexing_pipeline.py | 14 ++ backend/onyx/indexing/models.py | 9 + backend/onyx/llm/chat_llm.py | 3 +- backend/onyx/llm/factory.py | 50 +++++- backend/onyx/llm/utils.py | 76 +++++++++ .../onyx/natural_language_processing/utils.py | 25 ++- backend/onyx/prompts/chat_prompts.py | 26 +++ backend/onyx/seeding/load_docs.py | 2 + backend/onyx/server/manage/llm/api.py | 37 ++++ backend/onyx/server/manage/llm/models.py | 6 + .../query_time_check/seed_dummy_docs.py | 2 + backend/shared_configs/configs.py | 1 + .../salesforce/test_postprocessing.py | 2 + backend/tests/unit/onyx/chat/conftest.py | 4 + .../test_quotes_processing.py | 4 + .../unit/onyx/chat/test_prune_and_merge.py | 2 + .../tests/unit/onyx/indexing/test_chunker.py | 33 +++- .../tests/unit/onyx/indexing/test_embedder.py | 12 +- backend/tests/unit/onyx/llm/test_chat_llm.py | 4 +- .../app/admin/configuration/llm/constants.ts | 3 + .../app/admin/configuration/search/page.tsx | 9 + web/src/app/admin/embeddings/interfaces.ts | 9 + .../pages/AdvancedEmbeddingFormPage.tsx | 130 +++++++++++++- .../embeddings/pages/EmbeddingFormPage.tsx | 10 +- web/src/components/admin/connectors/Field.tsx | 5 +- 39 files changed, 723 insertions(+), 37 deletions(-) create mode 100644 backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py diff --git a/backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py b/backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py new file mode 100644 index 000000000..319ce7e27 --- /dev/null +++ b/backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py @@ -0,0 +1,50 @@ +"""enable contextual retrieval + +Revision ID: 8e1ac4f39a9f +Revises: 3934b1bc7b62 +Create Date: 2024-12-20 13:29:09.918661 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "8e1ac4f39a9f" +down_revision = "3934b1bc7b62" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "search_settings", + sa.Column( + "enable_contextual_rag", + sa.Boolean(), + nullable=False, + server_default="false", + ), + ) + op.add_column( + "search_settings", + sa.Column( + "contextual_rag_llm_name", + sa.String(), + nullable=True, + ), + ) + op.add_column( + "search_settings", + sa.Column( + "contextual_rag_llm_provider", + sa.String(), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("search_settings", "enable_contextual_rag") + op.drop_column("search_settings", "contextual_rag_llm_name") + op.drop_column("search_settings", "contextual_rag_llm_provider") diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 74be29c6d..e57c62447 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -470,6 +470,8 @@ NUM_SECONDARY_INDEXING_WORKERS = int( ENABLE_MULTIPASS_INDEXING = ( os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true" ) +# Enable contextual retrieval +ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true" # Finer grained chunking for more detail retention # Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE # tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end @@ -511,6 +513,17 @@ MAX_FILE_SIZE_BYTES = int( os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024 ) # 2GB in bytes +# Use document summary for contextual rag +USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true" +# Use chunk summary for contextual rag +USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true" +# Average summary embeddings for contextual rag (not yet implemented) +AVERAGE_SUMMARY_EMBEDDINGS = ( + os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true" +) + +MAX_TOKENS_FOR_FULL_INCLUSION = 4096 + ##### # Miscellaneous ##### diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index 00335cded..c1e88f9a8 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -164,6 +164,9 @@ class DocumentBase(BaseModel): attributes.append(k + INDEX_SEPARATOR + v) return attributes + def get_content(self) -> str: + return " ".join([section.text for section in self.sections]) + class Document(DocumentBase): """Used for Onyx ingestion api, the ID is required""" diff --git a/backend/onyx/context/search/models.py b/backend/onyx/context/search/models.py index 980ed9644..efeeb4286 100644 --- a/backend/onyx/context/search/models.py +++ b/backend/onyx/context/search/models.py @@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting): reduced_dimension=search_settings.reduced_dimension, # Whether switching to this model requires re-indexing background_reindex_enabled=search_settings.background_reindex_enabled, + enable_contextual_rag=search_settings.enable_contextual_rag, + contextual_rag_llm_name=search_settings.contextual_rag_llm_name, + contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider, # Reranking Details rerank_model_name=search_settings.rerank_model_name, rerank_provider_type=search_settings.rerank_provider_type, @@ -218,6 +221,8 @@ class InferenceChunk(BaseChunk): # to specify that a set of words should be highlighted. 
For example: # ["the answer is 42", "he couldn't find an answer"] match_highlights: list[str] + doc_summary: str + chunk_context: str # when the doc was last updated updated_at: datetime | None diff --git a/backend/onyx/context/search/postprocessing/postprocessing.py b/backend/onyx/context/search/postprocessing/postprocessing.py index 41243eec7..cfa07ef02 100644 --- a/backend/onyx/context/search/postprocessing/postprocessing.py +++ b/backend/onyx/context/search/postprocessing/postprocessing.py @@ -196,9 +196,21 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk RETURN_SEPARATOR ) + def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str: + # remove document summary + if chunk.content.startswith(chunk.doc_summary): + chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip() + # remove chunk context + if chunk.content.endswith(chunk.chunk_context): + chunk.content = chunk.content[ + : len(chunk.content) - len(chunk.chunk_context) + ].rstrip() + return chunk.content + for chunk in chunks: chunk.content = _remove_title(chunk) chunk.content = _remove_metadata_suffix(chunk) + chunk.content = _remove_contextual_rag(chunk) return [chunk.to_inference_chunk() for chunk in chunks] diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index cbe3b1be3..951e2b760 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -791,6 +791,15 @@ class SearchSettings(Base): # Mini and Large Chunks (large chunk also checks for model max context) multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True) + # Contextual RAG + enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False) + + # Contextual RAG LLM + contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True) + contextual_rag_llm_provider: Mapped[str | None] = mapped_column( + String, nullable=True + ) + multilingual_expansion: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), default=[] ) diff --git a/backend/onyx/db/search_settings.py b/backend/onyx/db/search_settings.py index bddaf2115..1d21d0d0e 100644 --- a/backend/onyx/db/search_settings.py +++ b/backend/onyx/db/search_settings.py @@ -62,6 +62,9 @@ def create_search_settings( multipass_indexing=search_settings.multipass_indexing, embedding_precision=search_settings.embedding_precision, reduced_dimension=search_settings.reduced_dimension, + enable_contextual_rag=search_settings.enable_contextual_rag, + contextual_rag_llm_name=search_settings.contextual_rag_llm_name, + contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider, multilingual_expansion=search_settings.multilingual_expansion, disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming, rerank_model_name=search_settings.rerank_model_name, @@ -319,6 +322,7 @@ def get_old_default_embedding_model() -> IndexingSetting: passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, + enable_contextual_rag=False, api_url=None, ) @@ -333,5 +337,6 @@ def get_new_default_embedding_model() -> IndexingSetting: passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, + enable_contextual_rag=False, api_url=None, ) diff --git a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd index 70980852a..4b7c7c1e0 100644 --- 
a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -98,6 +98,12 @@ schema DANSWER_CHUNK_NAME { field metadata type string { indexing: summary | attribute } + field chunk_context type string { + indexing: summary | attribute + } + field doc_summary type string { + indexing: summary | attribute + } field metadata_suffix type string { indexing: summary | attribute } diff --git a/backend/onyx/document_index/vespa/chunk_retrieval.py b/backend/onyx/document_index/vespa/chunk_retrieval.py index 8c7ee6963..a19f2d087 100644 --- a/backend/onyx/document_index/vespa/chunk_retrieval.py +++ b/backend/onyx/document_index/vespa/chunk_retrieval.py @@ -24,9 +24,11 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import ( from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST +from onyx.document_index.vespa_constants import CHUNK_CONTEXT from onyx.document_index.vespa_constants import CHUNK_ID from onyx.document_index.vespa_constants import CONTENT from onyx.document_index.vespa_constants import CONTENT_SUMMARY +from onyx.document_index.vespa_constants import DOC_SUMMARY from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT @@ -126,7 +128,8 @@ def _vespa_hit_to_inference_chunk( return InferenceChunkUncleaned( chunk_id=fields[CHUNK_ID], blurb=fields.get(BLURB, ""), # Unused - content=fields[CONTENT], # Includes extra title prefix and metadata suffix + content=fields[CONTENT], # Includes extra title prefix and metadata suffix; + # also sometimes context for contextual rag source_links=source_links_dict or {0: ""}, section_continuation=fields[SECTION_CONTINUATION], document_id=fields[DOCUMENT_ID], @@ -143,6 +146,8 @@ def _vespa_hit_to_inference_chunk( large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []), metadata=metadata, metadata_suffix=fields.get(METADATA_SUFFIX), + doc_summary=fields.get(DOC_SUMMARY, ""), + chunk_context=fields.get(CHUNK_CONTEXT, ""), match_highlights=match_highlights, updated_at=updated_at, ) diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index 17aadb36c..b60eaa322 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -187,7 +187,7 @@ class VespaIndex(DocumentIndex): ) -> None: if MULTI_TENANT: logger.info( - "Skipping Vespa index seup for multitenant (would wipe all indices)" + "Skipping Vespa index setup for multitenant (would wipe all indices)" ) return None diff --git a/backend/onyx/document_index/vespa/indexing_utils.py b/backend/onyx/document_index/vespa/indexing_utils.py index ab08dc4d1..9145ce63c 100644 --- a/backend/onyx/document_index/vespa/indexing_utils.py +++ b/backend/onyx/document_index/vespa/indexing_utils.py @@ -25,9 +25,11 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST +from onyx.document_index.vespa_constants import CHUNK_CONTEXT from onyx.document_index.vespa_constants import CHUNK_ID from onyx.document_index.vespa_constants import CONTENT from 
onyx.document_index.vespa_constants import CONTENT_SUMMARY +from onyx.document_index.vespa_constants import DOC_SUMMARY from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT @@ -174,7 +176,7 @@ def _index_vespa_chunk( # For the BM25 index, the keyword suffix is used, the vector is already generated with the more # natural language representation of the metadata section CONTENT: remove_invalid_unicode_chars( - f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}" + f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}" ), # This duplication of `content` is needed for keyword highlighting # Note that it's not exactly the same as the actual content @@ -189,6 +191,8 @@ def _index_vespa_chunk( # Save as a list for efficient extraction as an Attribute METADATA_LIST: metadata_list, METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword), + CHUNK_CONTEXT: chunk.chunk_context, + DOC_SUMMARY: chunk.doc_summary, EMBEDDINGS: embeddings_name_vector_map, TITLE_EMBEDDING: chunk.title_embedding, DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at), diff --git a/backend/onyx/document_index/vespa_constants.py b/backend/onyx/document_index/vespa_constants.py index 8a32fb721..66a7fd99d 100644 --- a/backend/onyx/document_index/vespa_constants.py +++ b/backend/onyx/document_index/vespa_constants.py @@ -71,6 +71,8 @@ LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids" METADATA = "metadata" METADATA_LIST = "metadata_list" METADATA_SUFFIX = "metadata_suffix" +DOC_SUMMARY = "doc_summary" +CHUNK_CONTEXT = "chunk_context" BOOST = "boost" AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor" DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch @@ -106,6 +108,8 @@ YQL_BASE = ( f"{LARGE_CHUNK_REFERENCE_IDS}, " f"{METADATA}, " f"{METADATA_SUFFIX}, " + f"{DOC_SUMMARY}, " + f"{CHUNK_CONTEXT}, " f"{CONTENT_SUMMARY} " f"from {{index_name}} where " ) diff --git a/backend/onyx/indexing/chunker.py b/backend/onyx/indexing/chunker.py index 0dea6fa12..faeb8a9cb 100644 --- a/backend/onyx/indexing/chunker.py +++ b/backend/onyx/indexing/chunker.py @@ -1,7 +1,11 @@ +from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS from onyx.configs.app_configs import BLURB_SIZE from onyx.configs.app_configs import LARGE_CHUNK_RATIO +from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION from onyx.configs.app_configs import MINI_CHUNK_SIZE from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK +from onyx.configs.app_configs import USE_CHUNK_SUMMARY +from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import DocumentSource from onyx.configs.constants import RETURN_SEPARATOR from onyx.configs.constants import SECTION_SEPARATOR @@ -13,10 +17,19 @@ from onyx.connectors.models import IndexingDocument from onyx.connectors.models import Section from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.models import DocAwareChunk +from onyx.llm.interfaces import LLM +from onyx.llm.utils import get_max_input_tokens +from onyx.llm.utils import MAX_CONTEXT_TOKENS +from onyx.llm.utils import message_to_string from onyx.natural_language_processing.utils import BaseTokenizer +from onyx.natural_language_processing.utils import tokenizer_trim_middle +from onyx.prompts.chat_prompts import 
CONTEXTUAL_RAG_PROMPT1 +from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2 +from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT from onyx.utils.logger import setup_logger from onyx.utils.text_processing import clean_text from onyx.utils.text_processing import shared_precompare_cleanup +from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT # Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps @@ -82,6 +95,8 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks], mini_chunk_texts=None, large_chunk_id=large_chunk_id, + chunk_context="", + doc_summary="", ) offset = 0 @@ -118,8 +133,10 @@ class Chunker: def __init__( self, tokenizer: BaseTokenizer, + llm: LLM | None = None, enable_multipass: bool = False, enable_large_chunks: bool = False, + enable_contextual_rag: bool = False, blurb_size: int = BLURB_SIZE, include_metadata: bool = not SKIP_METADATA_IN_CHUNK, chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE, @@ -129,13 +146,33 @@ class Chunker: ) -> None: from llama_index.text_splitter import SentenceSplitter + if llm is None and enable_contextual_rag: + raise ValueError("LLM must be provided for contextual RAG") + self.include_metadata = include_metadata self.chunk_token_limit = chunk_token_limit self.enable_multipass = enable_multipass self.enable_large_chunks = enable_large_chunks + self.enable_contextual_rag = enable_contextual_rag + if enable_contextual_rag: + assert ( + USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY + ), "Contextual RAG requires at least one of chunk summary and document summary enabled" self.tokenizer = tokenizer + self.llm = llm self.callback = callback + self.max_context = 0 + self.prompt_tokens = 0 + if self.llm is not None: + self.max_context = get_max_input_tokens( + model_name=self.llm.config.model_name, + model_provider=self.llm.config.model_provider, + ) + self.prompt_tokens = len( + self.tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2) + ) + self.blurb_splitter = SentenceSplitter( tokenizer=tokenizer.tokenize, chunk_size=blurb_size, @@ -204,6 +241,8 @@ class Chunker: metadata_suffix_semantic: str = "", metadata_suffix_keyword: str = "", image_file_name: str | None = None, + doc_summary: str = "", + chunk_context: str = "", ) -> None: """ Helper to create a new DocAwareChunk, append it to chunks_list. 
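For orientation, the two new per-chunk fields introduced here travel with every chunk and are stitched around the raw content wherever the chunk is serialized for keyword indexing or embedding (see the CONTENT f-string in indexing_utils.py above and the embedder change below). A minimal sketch of that assembly — `ChunkSketch` and the sample values are hypothetical; only the field ordering is taken from this patch:

```python
# Illustrative only: how a contextual-RAG chunk's text is assembled before it is
# embedded / indexed, mirroring the ordering used in indexing_utils.py and embedder.py.
from dataclasses import dataclass


@dataclass
class ChunkSketch:
    title_prefix: str
    doc_summary: str      # LLM-generated summary of the whole document
    content: str          # the original chunk text
    chunk_context: str    # LLM-generated context situating the chunk within the document
    metadata_suffix: str

    def indexed_text(self) -> str:
        # doc_summary is prepended and chunk_context appended around the raw content;
        # postprocessing later strips both back out (see _remove_contextual_rag above).
        return (
            f"{self.title_prefix}{self.doc_summary}"
            f"{self.content}{self.chunk_context}{self.metadata_suffix}"
        )


chunk = ChunkSketch(
    title_prefix="Onboarding Guide\n",
    doc_summary="A guide covering laptop setup and account provisioning. ",
    content="Step 3: request VPN access from IT.",
    chunk_context=" This step is part of the first-week checklist.",
    metadata_suffix="\ndepartment: IT",
)
print(chunk.indexed_text())
```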
@@ -221,6 +260,8 @@ class Chunker: metadata_suffix_keyword=metadata_suffix_keyword, mini_chunk_texts=self._get_mini_chunk_texts(text), large_chunk_id=None, + doc_summary=doc_summary, + chunk_context=chunk_context, ) chunks_list.append(new_chunk) @@ -309,7 +350,7 @@ class Chunker: continue # CASE 2: Normal text section - section_token_count = len(self.tokenizer.tokenize(section_text)) + section_token_count = len(self.tokenizer.encode(section_text)) # If the section is large on its own, split it separately if section_token_count > content_token_limit: @@ -332,8 +373,7 @@ class Chunker: # If even the split_text is bigger than strict limit, further split if ( STRICT_CHUNK_TOKEN_LIMIT - and len(self.tokenizer.tokenize(split_text)) - > content_token_limit + and len(self.tokenizer.encode(split_text)) > content_token_limit ): smaller_chunks = self._split_oversized_chunk( split_text, content_token_limit @@ -363,10 +403,10 @@ class Chunker: continue # If we can still fit this section into the current chunk, do so - current_token_count = len(self.tokenizer.tokenize(chunk_text)) + current_token_count = len(self.tokenizer.encode(chunk_text)) current_offset = len(shared_precompare_cleanup(chunk_text)) next_section_tokens = ( - len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count + len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count ) if next_section_tokens + current_token_count <= content_token_limit: @@ -414,7 +454,7 @@ class Chunker: # Title prep title = self._extract_blurb(document.get_title_for_document_index() or "") title_prefix = title + RETURN_SEPARATOR if title else "" - title_tokens = len(self.tokenizer.tokenize(title_prefix)) + title_tokens = len(self.tokenizer.encode(title_prefix)) # Metadata prep metadata_suffix_semantic = "" @@ -427,15 +467,53 @@ class Chunker: ) = _get_metadata_suffix_for_document_index( document.metadata, include_separator=True ) - metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic)) + metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic)) # If metadata is too large, skip it in the semantic content if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE: metadata_suffix_semantic = "" metadata_tokens = 0 + single_chunk_fits = True + doc_token_count = 0 + if self.enable_contextual_rag: + doc_content = document.get_content() + tokenized_doc = self.tokenizer.tokenize(doc_content) + doc_token_count = len(tokenized_doc) + + # check if doc + title + metadata fits in a single chunk. 
If so, no need for contextual RAG + single_chunk_fits = ( + doc_token_count + title_tokens + metadata_tokens + <= self.chunk_token_limit + ) + + # expand the size of the context used for contextual rag based on whether chunk context and doc summary are used + context_size = 0 + if ( + self.enable_contextual_rag + and not single_chunk_fits + and not AVERAGE_SUMMARY_EMBEDDINGS + ): + if USE_CHUNK_SUMMARY: + context_size += MAX_CONTEXT_TOKENS + if USE_DOCUMENT_SUMMARY: + context_size += MAX_CONTEXT_TOKENS + # Adjust content token limit to accommodate title + metadata - content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens + content_token_limit = ( + self.chunk_token_limit - title_tokens - metadata_tokens - context_size + ) + + # first check: if there is not enough actual chunk content when including contextual rag, + # then don't do contextual rag + if content_token_limit <= CHUNK_MIN_CONTENT: + context_size = 0 # Don't do contextual RAG + # revert to previous content token limit + content_token_limit = ( + self.chunk_token_limit - title_tokens - metadata_tokens + ) + + # If there is not enough context remaining then just index the chunk with no prefix/suffix if content_token_limit <= CHUNK_MIN_CONTENT: # Not enough space left, so revert to full chunk without the prefix content_token_limit = self.chunk_token_limit @@ -459,6 +537,70 @@ class Chunker: large_chunks = generate_large_chunks(normal_chunks) normal_chunks.extend(large_chunks) + # exit early if contextual rag disabled or not used for this document + if not self.enable_contextual_rag or context_size == 0: + return normal_chunks + + if self.llm is None: + raise ValueError("LLM must be available for contextual RAG") + + doc_summary = "" + doc_content = "" + doc_tokens = [] + if USE_DOCUMENT_SUMMARY: + trunc_doc_tokens = self.max_context - len( + self.tokenizer.encode(DOCUMENT_SUMMARY_PROMPT) + ) + doc_tokens = self.tokenizer.encode(doc_content) + doc_content = tokenizer_trim_middle( + doc_tokens, trunc_doc_tokens, self.tokenizer + ) + summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content) + doc_summary = message_to_string( + self.llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS) + ) + for chunk in normal_chunks: + chunk.doc_summary = doc_summary + + if USE_CHUNK_SUMMARY: + # Truncate the document content to fit in the model context + trunc_doc_tokens = ( + self.max_context - self.prompt_tokens - self.chunk_token_limit + ) + + # use values computed in above doc summary section if available + doc_tokens = doc_tokens or self.tokenizer.encode(doc_content) + doc_content = doc_content or tokenizer_trim_middle( + doc_tokens, trunc_doc_tokens, self.tokenizer + ) + + # only compute doc summary if needed + doc_info = ( + doc_content + if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION + else ( + doc_summary or DOCUMENT_SUMMARY_PROMPT.format(document=doc_content) + ) + ) + + context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info) + + def assign_context(chunk: DocAwareChunk) -> None: + context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content) + chunk.chunk_context = ( + "" + if self.llm is None + else message_to_string( + self.llm.invoke( + context_prompt1 + context_prompt2, + max_tokens=MAX_CONTEXT_TOKENS, + ) + ) + ) + + run_functions_tuples_in_parallel( + [(assign_context, (chunk,)) for chunk in normal_chunks] + ) return normal_chunks def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]: diff --git a/backend/onyx/indexing/embedder.py 
b/backend/onyx/indexing/embedder.py index 67bf56fc8..78ea96340 100644 --- a/backend/onyx/indexing/embedder.py +++ b/backend/onyx/indexing/embedder.py @@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): if chunk.large_chunk_reference_ids: large_chunks_present = True chunk_text = ( - f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}" + f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}" ) or chunk.source_document.get_title_for_document_index() if not chunk_text: diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index 3967b7f7c..1b5f09477 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import Session from onyx.access.access import get_access_for_documents from onyx.access.models import DocumentAccess +from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG from onyx.configs.app_configs import MAX_DOCUMENT_CHARS from onyx.configs.constants import DEFAULT_BOOST from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled @@ -58,6 +59,7 @@ from onyx.indexing.models import IndexChunk from onyx.indexing.models import UpdatableChunkData from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff from onyx.llm.factory import get_default_llm_with_vision +from onyx.llm.factory import get_llm_for_contextual_rag from onyx.natural_language_processing.search_nlp_models import ( InformationContentClassificationModel, ) @@ -794,12 +796,24 @@ def build_indexing_pipeline( search_settings = get_current_search_settings(db_session) multipass_config = get_multipass_config(search_settings) + enable_contextual_rag = ( + search_settings.enable_contextual_rag + if search_settings + else ENABLE_CONTEXTUAL_RAG + ) + llm = get_llm_for_contextual_rag( + search_settings.contextual_rag_llm_name, + search_settings.contextual_rag_llm_provider, + ) + chunker = chunker or Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=multipass_config.multipass_indexing, enable_large_chunks=multipass_config.enable_large_chunks, + enable_contextual_rag=enable_contextual_rag, # after every doc, update status in case there are a bunch of really long docs callback=callback, + llm=llm, ) return partial( diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index 686ac2942..cee6499eb 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -49,6 +49,11 @@ class DocAwareChunk(BaseChunk): metadata_suffix_semantic: str metadata_suffix_keyword: str + # This is the summary for the document generated for contextual RAG + doc_summary: str + # This is the context for this chunk generated for contextual RAG + chunk_context: str + mini_chunk_texts: list[str] | None large_chunk_id: int | None @@ -154,6 +159,9 @@ class IndexingSetting(EmbeddingModelDetail): reduced_dimension: int | None = None background_reindex_enabled: bool = True + enable_contextual_rag: bool + contextual_rag_llm_name: str | None = None + contextual_rag_llm_provider: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} @@ -178,6 +186,7 @@ class IndexingSetting(EmbeddingModelDetail): embedding_precision=search_settings.embedding_precision, reduced_dimension=search_settings.reduced_dimension, background_reindex_enabled=search_settings.background_reindex_enabled, + 
enable_contextual_rag=search_settings.enable_contextual_rag, ) diff --git a/backend/onyx/llm/chat_llm.py b/backend/onyx/llm/chat_llm.py index 2e9496856..1f33dc10e 100644 --- a/backend/onyx/llm/chat_llm.py +++ b/backend/onyx/llm/chat_llm.py @@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM): messages=processed_prompt, tools=tools, tool_choice=tool_choice if tools else None, + max_tokens=max_tokens, # streaming choice stream=stream, # model params temperature=0, timeout=timeout_override or self._timeout, - max_tokens=max_tokens, # For now, we don't support parallel tool calls # NOTE: we can't pass this in if tools are not specified # or else OpenAI throws an error @@ -531,6 +531,7 @@ class DefaultMultiLLM(LLM): tool_choice, structured_response_format, timeout_override, + max_tokens, ) return diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index 3d0bb6b3b..2d6f9bebf 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -17,6 +17,7 @@ from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride from onyx.llm.utils import model_supports_image_input from onyx.server.manage.llm.models import FullLLMProvider +from onyx.server.manage.llm.models import LLMProvider from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger @@ -154,6 +155,44 @@ def get_default_llm_with_vision( return None +def llm_from_provider( + model_name: str, + llm_provider: LLMProvider, + timeout: int | None = None, + temperature: float | None = None, + additional_headers: dict[str, str] | None = None, + long_term_logger: LongTermLogger | None = None, +) -> LLM: + return get_llm( + provider=llm_provider.provider, + model=model_name, + deployment_name=llm_provider.deployment_name, + api_key=llm_provider.api_key, + api_base=llm_provider.api_base, + api_version=llm_provider.api_version, + custom_config=llm_provider.custom_config, + timeout=timeout, + temperature=temperature, + additional_headers=additional_headers, + long_term_logger=long_term_logger, + ) + + +def get_llm_for_contextual_rag( + model_name: str | None, model_provider: str | None +) -> LLM | None: + if not model_name or not model_provider: + return None + with get_session_context_manager() as db_session: + llm_provider = fetch_provider(db_session, model_provider) + if not llm_provider: + raise ValueError("No LLM provider with name {} found".format(model_provider)) + return llm_from_provider( + model_name=model_name, + llm_provider=llm_provider, + ) + + def get_default_llms( timeout: int | None = None, temperature: float | None = None, @@ -179,14 +218,9 @@ def get_default_llms( raise ValueError("No fast default model name found") def _create_llm(model: str) -> LLM: - return get_llm( - provider=llm_provider.provider, - model=model, - deployment_name=llm_provider.deployment_name, - api_key=llm_provider.api_key, - api_base=llm_provider.api_base, - api_version=llm_provider.api_version, - custom_config=llm_provider.custom_config, + return llm_from_provider( + model_name=model, + llm_provider=llm_provider, timeout=timeout, temperature=temperature, additional_headers=additional_headers, diff --git a/backend/onyx/llm/utils.py b/backend/onyx/llm/utils.py index 04fc2260c..9f257b040 100644 --- a/backend/onyx/llm/utils.py +++ b/backend/onyx/llm/utils.py @@ -29,13 +29,19 @@ from litellm.exceptions import Timeout # type: ignore from litellm.exceptions import UnprocessableEntityError # type: ignore from 
onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS +from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION +from onyx.configs.app_configs import USE_CHUNK_SUMMARY +from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import MessageType +from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from onyx.configs.model_configs import GEN_AI_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from onyx.file_store.models import ChatFileType from onyx.file_store.models import InMemoryChatFile from onyx.llm.interfaces import LLM +from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE +from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE from onyx.prompts.constants import CODE_BLOCK_PAT from onyx.utils.b64 import get_image_type from onyx.utils.b64 import get_image_type_from_bytes @@ -44,6 +50,10 @@ from shared_configs.configs import LOG_LEVEL logger = setup_logger() +MAX_CONTEXT_TOKENS = 100 +ONE_MILLION = 1_000_000 +CHUNKS_PER_DOC_ESTIMATE = 5 + def litellm_exception_to_error_msg( e: Exception, @@ -416,6 +426,72 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N return None +def get_llm_contextual_cost( + llm: LLM, +) -> float: + """ + Approximate the cost of using the given LLM for indexing with Contextual RAG. + + We use a precomputed estimate for the number of tokens in the contextualizing prompts, + and we assume that every chunk is maximized in terms of content and context. + We also assume that every document is maximized in terms of content, as currently if + a document is longer than a certain length, its summary is used instead of the full content. + + We expect that the first assumption will overestimate more than the second one + underestimates, so this should be a fairly conservative price estimate. Also, + this does not account for the cost of documents that fit within a single chunk + which do not get contextualized. + """ + + # calculate input costs + num_tokens = ONE_MILLION + num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE + + # We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long + # on average. + num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION + + num_input_tokens = 0 + num_output_tokens = 0 + + if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY: + return 0 + + if USE_CHUNK_SUMMARY: + # Each per-chunk prompt includes: + # - The prompt tokens + # - the document tokens + # - the chunk tokens + + # for each chunk, we prompt the LLM with the contextual RAG prompt + # and the full document content (or the doc summary, so this is an overestimate) + num_input_tokens += num_input_chunks * ( + CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION + ) + + # in aggregate, each chunk content is used as a prompt input once + # so the full input size is covered + num_input_tokens += num_tokens + + # A single MAX_CONTEXT_TOKENS worth of output is generated per chunk + num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS + + # going over each doc once means all the tokens, plus the prompt tokens for + # the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false, + # since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true. + # So, we include this unconditionally to overestimate. 
+ num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE + num_output_tokens += num_docs * MAX_CONTEXT_TOKENS + + usd_per_prompt, usd_per_completion = litellm.cost_per_token( + model=llm.config.model_name, + prompt_tokens=num_input_tokens, + completion_tokens=num_output_tokens, + ) + # Costs are in USD dollars per million tokens + return usd_per_prompt + usd_per_completion + + def get_llm_max_tokens( model_map: dict, model_name: str, diff --git a/backend/onyx/natural_language_processing/utils.py b/backend/onyx/natural_language_processing/utils.py index 3c4d13920..0860cf429 100644 --- a/backend/onyx/natural_language_processing/utils.py +++ b/backend/onyx/natural_language_processing/utils.py @@ -11,6 +11,8 @@ from onyx.context.search.models import InferenceChunk from onyx.utils.logger import setup_logger from shared_configs.enums import EmbeddingProvider +TRIM_SEP_PAT = "\n... {n} tokens removed...\n" + logger = setup_logger() transformer_logging.set_verbosity_error() os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -159,9 +161,26 @@ def tokenizer_trim_content( content: str, desired_length: int, tokenizer: BaseTokenizer ) -> str: tokens = tokenizer.encode(content) - if len(tokens) > desired_length: - content = tokenizer.decode(tokens[:desired_length]) - return content + if len(tokens) <= desired_length: + return content + + return tokenizer.decode(tokens[:desired_length]) + + +def tokenizer_trim_middle( + tokens: list[int], desired_length: int, tokenizer: BaseTokenizer +) -> str: + if len(tokens) <= desired_length: + return tokenizer.decode(tokens) + sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length) + sep_tokens = tokenizer.encode(sep_str) + slice_size = (desired_length - len(sep_tokens)) // 2 + assert slice_size > 0, "Slice size is not positive, desired length is too short" + return ( + tokenizer.decode(tokens[:slice_size]) + + sep_str + + tokenizer.decode(tokens[-slice_size:]) + ) def tokenizer_trim_chunks( diff --git a/backend/onyx/prompts/chat_prompts.py b/backend/onyx/prompts/chat_prompts.py index 04cc33488..65c4f9e85 100644 --- a/backend/onyx/prompts/chat_prompts.py +++ b/backend/onyx/prompts/chat_prompts.py @@ -220,3 +220,29 @@ Chat History: Based on the above, what is a short name to convey the topic of the conversation? """.strip() + +# NOTE: the prompt separation is partially done for efficiency; previously I tried +# to do it all in one prompt with sequential format() calls but this will cause a backend +# error when the document contains any {} as python will expect the {} to be filled by +# format() arguments +CONTEXTUAL_RAG_PROMPT1 = """ +{document} + +Here is the chunk we want to situate within the whole document""" + +CONTEXTUAL_RAG_PROMPT2 = """ +{chunk} + +Please give a short succinct context to situate this chunk within the overall document +for the purposes of improving search retrieval of the chunk. Answer only with the succinct +context and nothing else. """ + +CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45 + +DOCUMENT_SUMMARY_PROMPT = """ +{document} + +Please give a short succinct summary of the entire document. Answer only with the succinct +summary and nothing else. 
""" + +DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29 diff --git a/backend/onyx/seeding/load_docs.py b/backend/onyx/seeding/load_docs.py index a3aa99ea6..4c2db16be 100644 --- a/backend/onyx/seeding/load_docs.py +++ b/backend/onyx/seeding/load_docs.py @@ -87,6 +87,8 @@ def _create_indexable_chunks( metadata_suffix_keyword="", mini_chunk_texts=None, large_chunk_reference_ids=[], + doc_summary="", + chunk_context="", embeddings=ChunkEmbedding( full_embedding=preprocessed_doc["content_embedding"], mini_chunk_embeddings=[], diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index ceafca2e3..f00e1c59c 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -21,10 +21,12 @@ from onyx.llm.factory import get_default_llms from onyx.llm.factory import get_llm from onyx.llm.llm_provider_options import fetch_available_well_known_llms from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor +from onyx.llm.utils import get_llm_contextual_cost from onyx.llm.utils import litellm_exception_to_error_msg from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import test_llm from onyx.server.manage.llm.models import FullLLMProvider +from onyx.server.manage.llm.models import LLMCost from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import TestLLMRequest @@ -259,3 +261,38 @@ def list_llm_provider_basics( db_session, user ) ] + + +@admin_router.get("/provider-contextual-cost") +def get_provider_contextual_cost( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[LLMCost]: + """ + Get the cost of Re-indexing all documents for contextual retrieval. 
+ + See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token + This includes: + - The cost of invoking the LLM on each chunk-document pair to get + - the doc_summary + - the chunk_context + - The per-token cost of the LLM used to generate the doc_summary and chunk_context + """ + providers = fetch_existing_llm_providers(db_session) + costs = [] + for provider in providers: + for model_name in provider.display_model_names or provider.model_names or []: + llm = get_llm( + provider=provider.provider, + model=model_name, + deployment_name=provider.deployment_name, + api_key=provider.api_key, + api_base=provider.api_base, + api_version=provider.api_version, + custom_config=provider.custom_config, + ) + cost = get_llm_contextual_cost(llm) + costs.append( + LLMCost(provider=provider.name, model_name=model_name, cost=cost) + ) + return costs diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 3172f5adf..a25f3ae43 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -115,3 +115,9 @@ class VisionProviderResponse(FullLLMProvider): """Response model for vision providers endpoint, including vision-specific fields.""" vision_models: list[str] + + +class LLMCost(BaseModel): + provider: str + model_name: str + cost: float diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py index da1edd8db..3e9af8df9 100644 --- a/backend/scripts/query_time_check/seed_dummy_docs.py +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -63,6 +63,8 @@ def generate_dummy_chunk( title_prefix=f"Title prefix for doc {doc_id}", metadata_suffix_semantic="", metadata_suffix_keyword="", + doc_summary="", + chunk_context="", mini_chunk_texts=None, embeddings=ChunkEmbedding( full_embedding=generate_random_embedding(embedding_dim), diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 13e7ba03e..30ff23fd6 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -99,6 +99,7 @@ PRESERVED_SEARCH_FIELDS = [ "api_url", "index_name", "multipass_indexing", + "enable_contextual_rag", "model_dim", "normalize", "passage_prefix", diff --git a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py index 162db23c5..a8a913336 100644 --- a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py +++ b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py @@ -32,6 +32,8 @@ def create_test_chunk( match_highlights=[], updated_at=datetime.now(), image_file_name=None, + doc_summary="", + chunk_context="", ) diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py index 8837b1eb2..cbe7f9c76 100644 --- a/backend/tests/unit/onyx/chat/conftest.py +++ b/backend/tests/unit/onyx/chat/conftest.py @@ -81,6 +81,8 @@ def mock_inference_sections() -> list[InferenceSection]: source_links={0: "https://example.com/doc1"}, match_highlights=[], image_file_name=None, + doc_summary="", + chunk_context="", ), chunks=MagicMock(), ), @@ -104,6 +106,8 @@ def mock_inference_sections() -> list[InferenceSection]: source_links={0: "https://example.com/doc2"}, match_highlights=[], image_file_name=None, + doc_summary="", + chunk_context="", ), chunks=MagicMock(), ), diff --git 
a/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py b/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py index 1bc7cc12e..7b406ae2e 100644 --- a/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py +++ b/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py @@ -151,6 +151,8 @@ def test_fuzzy_match_quotes_to_docs() -> None: match_highlights=[], updated_at=None, image_file_name=None, + doc_summary="", + chunk_context="", ) test_chunk_1 = InferenceChunk( document_id="test doc 1", @@ -170,6 +172,8 @@ def test_fuzzy_match_quotes_to_docs() -> None: match_highlights=[], updated_at=None, image_file_name=None, + doc_summary="", + chunk_context="", ) test_quotes = [ diff --git a/backend/tests/unit/onyx/chat/test_prune_and_merge.py b/backend/tests/unit/onyx/chat/test_prune_and_merge.py index a4037b3d0..ee6299024 100644 --- a/backend/tests/unit/onyx/chat/test_prune_and_merge.py +++ b/backend/tests/unit/onyx/chat/test_prune_and_merge.py @@ -38,6 +38,8 @@ def create_inference_chunk( match_highlights=[], updated_at=None, image_file_name=None, + doc_summary="", + chunk_context="", ) diff --git a/backend/tests/unit/onyx/indexing/test_chunker.py b/backend/tests/unit/onyx/indexing/test_chunker.py index 57ba3fe12..b7fa84b3a 100644 --- a/backend/tests/unit/onyx/indexing/test_chunker.py +++ b/backend/tests/unit/onyx/indexing/test_chunker.py @@ -1,3 +1,6 @@ +from typing import Any +from unittest.mock import Mock + import pytest from onyx.configs.constants import DocumentSource @@ -19,7 +22,10 @@ def embedder() -> DefaultIndexingEmbedder: ) -def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None: +@pytest.mark.parametrize("enable_contextual_rag", [True, False]) +def test_chunk_document( + embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool +) -> None: short_section_1 = "This is a short section." long_section = ( "This is a long section that should be split into multiple chunks. 
" * 100 @@ -45,9 +51,23 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None: ) indexing_documents = process_image_sections([document]) + mock_llm_invoke_count = 0 + + def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock: + nonlocal mock_llm_invoke_count + mock_llm_invoke_count += 1 + m = Mock() + m.content = f"Test{mock_llm_invoke_count}" + return m + + mock_llm = Mock() + mock_llm.invoke = mock_llm_invoke + chunker = Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=False, + enable_contextual_rag=enable_contextual_rag, + llm=mock_llm, ) chunks = chunker.chunk(indexing_documents) @@ -58,6 +78,16 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None: assert "tag1" in chunks[0].metadata_suffix_keyword assert "tag2" in chunks[0].metadata_suffix_semantic + doc_summary = "Test1" if enable_contextual_rag else "" + chunk_context = "" + count = 2 + for chunk in chunks: + if enable_contextual_rag: + chunk_context = f"Test{count}" + count += 1 + assert chunk.doc_summary == doc_summary + assert chunk.chunk_context == chunk_context + def test_chunker_heartbeat( embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat @@ -78,6 +108,7 @@ def test_chunker_heartbeat( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=False, callback=mock_heartbeat, + enable_contextual_rag=False, ) chunks = chunker.chunk(indexing_documents) diff --git a/backend/tests/unit/onyx/indexing/test_embedder.py b/backend/tests/unit/onyx/indexing/test_embedder.py index d8589ecff..2fe824a74 100644 --- a/backend/tests/unit/onyx/indexing/test_embedder.py +++ b/backend/tests/unit/onyx/indexing/test_embedder.py @@ -21,7 +21,13 @@ def mock_embedding_model() -> Generator[Mock, None, None]: yield mock -def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None: +@pytest.mark.parametrize( + "chunk_context, doc_summary", + [("Test chunk context", "Test document summary"), ("", "")], +) +def test_default_indexing_embedder_embed_chunks( + mock_embedding_model: Mock, chunk_context: str, doc_summary: str +) -> None: # Setup embedder = DefaultIndexingEmbedder( model_name="test-model", @@ -63,6 +69,8 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N large_chunk_reference_ids=[], large_chunk_id=None, image_file_name=None, + chunk_context=chunk_context, + doc_summary=doc_summary, ) ] @@ -81,7 +89,7 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N # Verify the embedding model was called correctly mock_embedding_model.return_value.encode.assert_any_call( - texts=["Title: Test chunk"], + texts=[f"Title: {doc_summary}Test chunk{chunk_context}"], text_type=EmbedTextType.PASSAGE, large_chunks_present=False, ) diff --git a/backend/tests/unit/onyx/llm/test_chat_llm.py b/backend/tests/unit/onyx/llm/test_chat_llm.py index b69b3b7de..4a34db8d4 100644 --- a/backend/tests/unit/onyx/llm/test_chat_llm.py +++ b/backend/tests/unit/onyx/llm/test_chat_llm.py @@ -140,12 +140,12 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None: ], tools=tools, tool_choice=None, + max_tokens=None, stream=False, temperature=0.0, # Default value from GEN_AI_TEMPERATURE timeout=30, parallel_tool_calls=False, mock_response=MOCK_LLM_RESPONSE, - max_tokens=None, ) @@ -286,10 +286,10 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No ], tools=tools, tool_choice=None, + max_tokens=None, stream=True, temperature=0.0, # Default value from GEN_AI_TEMPERATURE 
timeout=30, parallel_tool_calls=False, mock_response=MOCK_LLM_RESPONSE, - max_tokens=None, ) diff --git a/web/src/app/admin/configuration/llm/constants.ts b/web/src/app/admin/configuration/llm/constants.ts index d7e3449b3..0f1324c4a 100644 --- a/web/src/app/admin/configuration/llm/constants.ts +++ b/web/src/app/admin/configuration/llm/constants.ts @@ -1,5 +1,8 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; +export const LLM_CONTEXTUAL_COST_ADMIN_URL = + "/api/admin/llm/provider-contextual-cost"; + export const EMBEDDING_PROVIDERS_ADMIN_URL = "/api/admin/embedding/embedding-provider"; diff --git a/web/src/app/admin/configuration/search/page.tsx b/web/src/app/admin/configuration/search/page.tsx index 126f1a6a4..af948062a 100644 --- a/web/src/app/admin/configuration/search/page.tsx +++ b/web/src/app/admin/configuration/search/page.tsx @@ -143,6 +143,15 @@ function Main() { +
+ Contextual RAG + + {searchSettings.enable_contextual_rag + ? "Enabled" + : "Disabled"} + +
+
Disable Reranking for Streaming diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index f9e3c4161..394884955 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -26,9 +26,18 @@ export enum EmbeddingPrecision { BFLOAT16 = "bfloat16", } +export interface LLMContextualCost { + provider: string; + model_name: string; + cost: number; +} + export interface AdvancedSearchConfiguration { index_name: string | null; multipass_indexing: boolean; + enable_contextual_rag: boolean; + contextual_rag_llm_name: string | null; + contextual_rag_llm_provider: string | null; multilingual_expansion: string[]; disable_rerank_for_streaming: boolean; api_url: string | null; diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index b5003f835..52f29811f 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -3,7 +3,11 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik"; import * as Yup from "yup"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces"; +import { + AdvancedSearchConfiguration, + EmbeddingPrecision, + LLMContextualCost, +} from "../interfaces"; import { BooleanFormField, Label, @@ -12,6 +16,13 @@ import { } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; import { StringOrNumberOption } from "@/components/Dropdown"; +import useSWR from "swr"; +import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants"; +import { getDisplayNameForModel } from "@/lib/hooks"; +import { errorHandlingFetcher } from "@/lib/fetcher"; + +// Number of tokens to show cost calculation for +const COST_CALCULATION_TOKENS = 1_000_000; interface AdvancedEmbeddingFormPageProps { updateAdvancedEmbeddingDetails: ( @@ -45,14 +56,55 @@ const AdvancedEmbeddingFormPage = forwardRef< }, ref ) => { + // Fetch contextual costs + const { data: contextualCosts, error: costError } = useSWR< + LLMContextualCost[] + >(LLM_CONTEXTUAL_COST_ADMIN_URL, errorHandlingFetcher); + + const llmOptions: StringOrNumberOption[] = React.useMemo( + () => + (contextualCosts || []).map((cost) => { + return { + name: getDisplayNameForModel(cost.model_name), + value: cost.model_name, + }; + }), + [contextualCosts] + ); + + // Helper function to format cost as USD + const formatCost = (cost: number) => { + return new Intl.NumberFormat("en-US", { + style: "currency", + currency: "USD", + }).format(cost); + }; + + // Get cost info for selected model + const getSelectedModelCost = (modelName: string | null) => { + if (!contextualCosts || !modelName) return null; + return contextualCosts.find((cost) => cost.model_name === modelName); + }; + + // Get the current value for the selector based on the parent state + const getCurrentLLMValue = React.useMemo(() => { + if (!advancedEmbeddingDetails.contextual_rag_llm_name) return null; + return advancedEmbeddingDetails.contextual_rag_llm_name; + }, [advancedEmbeddingDetails.contextual_rag_llm_name]); + return (
{ // Call updateAdvancedEmbeddingDetails for each changed field Object.entries(values).forEach(([key, value]) => { - updateAdvancedEmbeddingDetails( - key as keyof AdvancedSearchConfiguration, - value - ); + if (key === "contextual_rag_llm") { + const selectedModel = (contextualCosts || []).find( + (cost) => cost.model_name === value + ); + if (selectedModel) { + updateAdvancedEmbeddingDetails( + "contextual_rag_llm_provider", + selectedModel.provider + ); + updateAdvancedEmbeddingDetails( + "contextual_rag_llm_name", + selectedModel.model_name + ); + } + } else { + updateAdvancedEmbeddingDetails( + key as keyof AdvancedSearchConfiguration, + value + ); + } }); // Run validation and report errors @@ -190,6 +258,56 @@ const AdvancedEmbeddingFormPage = forwardRef< label="Disable Rerank for Streaming" name="disable_rerank_for_streaming" /> + +
+ + {values.enable_contextual_rag && + values.contextual_rag_llm && + !costError && ( +
+ {contextualCosts ? ( + <> + Estimated cost for processing{" "} + {COST_CALCULATION_TOKENS.toLocaleString()} tokens:{" "} + + {getSelectedModelCost(values.contextual_rag_llm) + ? formatCost( + getSelectedModelCost( + values.contextual_rag_llm + )!.cost + ) + : "Cost information not available"} + + + ) : ( + "Loading cost information..." + )} +
+ )} +
({ index_name: "", multipass_indexing: true, + enable_contextual_rag: false, + contextual_rag_llm_name: null, + contextual_rag_llm_provider: null, multilingual_expansion: [], disable_rerank_for_streaming: false, api_url: null, @@ -152,6 +155,9 @@ export default function EmbeddingForm() { setAdvancedEmbeddingDetails({ index_name: searchSettings.index_name, multipass_indexing: searchSettings.multipass_indexing, + enable_contextual_rag: searchSettings.enable_contextual_rag, + contextual_rag_llm_name: searchSettings.contextual_rag_llm_name, + contextual_rag_llm_provider: searchSettings.contextual_rag_llm_provider, multilingual_expansion: searchSettings.multilingual_expansion, disable_rerank_for_streaming: searchSettings.disable_rerank_for_streaming, @@ -197,7 +203,9 @@ export default function EmbeddingForm() { searchSettings?.embedding_precision != advancedEmbeddingDetails.embedding_precision || searchSettings?.reduced_dimension != - advancedEmbeddingDetails.reduced_dimension; + advancedEmbeddingDetails.reduced_dimension || + searchSettings?.enable_contextual_rag != + advancedEmbeddingDetails.enable_contextual_rag; const updateSearch = useCallback(async () => { if (!selectedProvider) { diff --git a/web/src/components/admin/connectors/Field.tsx b/web/src/components/admin/connectors/Field.tsx index eeec14446..d87d2f5d3 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -676,6 +676,7 @@ interface SelectorFormFieldProps { includeReset?: boolean; fontSize?: "sm" | "md" | "lg"; small?: boolean; + disabled?: boolean; } export function SelectorFormField({ @@ -691,6 +692,7 @@ export function SelectorFormField({ includeReset = false, fontSize = "md", small = false, + disabled = false, }: SelectorFormFieldProps) { const [field] = useField(name); const { setFieldValue } = useFormikContext(); @@ -742,8 +744,9 @@ export function SelectorFormField({ : setFieldValue(name, selected)) } defaultValue={defaultValue} + disabled={disabled} > - + {currentlySelected?.name || defaultValue || ""}
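To tie the admin UI's cost preview back to the backend: the form above displays the figure returned by /provider-contextual-cost for COST_CALCULATION_TOKENS (one million) tokens, which get_llm_contextual_cost computes from the constants added in this patch. A standalone sketch of that arithmetic — the 512-token chunk size and the per-token prices are assumptions for illustration; the endpoint itself reads real prices via litellm.cost_per_token:

```python
# Standalone sketch of the estimate in get_llm_contextual_cost(); constants are copied
# from this patch except DOC_EMBEDDING_CONTEXT_SIZE and the prices, which are assumed.
MAX_CONTEXT_TOKENS = 100                 # tokens generated per summary/context call
MAX_TOKENS_FOR_FULL_INCLUSION = 4096     # assumed average document length in tokens
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64       # prompt overhead per chunk-context call
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29     # prompt overhead per document-summary call
DOC_EMBEDDING_CONTEXT_SIZE = 512         # assumed chunk size in tokens

num_tokens = 1_000_000                   # ONE_MILLION: the estimate is per 1M indexed tokens
num_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION

# chunk-context pass: per chunk, the prompt plus the (full or summarized) document...
input_tokens = num_chunks * (CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION)
input_tokens += num_tokens               # ...plus every chunk's own content once
# document-summary pass: the whole corpus once, plus the summary prompt per document
input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE

# each call generates at most MAX_CONTEXT_TOKENS of output
output_tokens = (num_chunks + num_docs) * MAX_CONTEXT_TOKENS

usd_per_input, usd_per_output = 0.15 / 1e6, 0.60 / 1e6   # hypothetical prices per token
cost = input_tokens * usd_per_input + output_tokens * usd_per_output
print(f"~${cost:.2f} to contextualize 1M tokens of documents")
```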