diff --git a/backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py b/backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py new file mode 100644 index 000000000..b3a23efaf --- /dev/null +++ b/backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py @@ -0,0 +1,50 @@ +"""enable contextual retrieval + +Revision ID: 8e1ac4f39a9f +Revises: 3781a5eb12cb +Create Date: 2024-12-20 13:29:09.918661 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "8e1ac4f39a9f" +down_revision = "3781a5eb12cb" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "search_settings", + sa.Column( + "enable_contextual_rag", + sa.Boolean(), + nullable=False, + server_default="false", + ), + ) + op.add_column( + "search_settings", + sa.Column( + "contextual_rag_llm_name", + sa.String(), + nullable=True, + ), + ) + op.add_column( + "search_settings", + sa.Column( + "contextual_rag_llm_provider", + sa.String(), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("search_settings", "enable_contextual_rag") + op.drop_column("search_settings", "contextual_rag_llm_name") + op.drop_column("search_settings", "contextual_rag_llm_provider") diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 6ada04d34..2a73fe60f 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -495,6 +495,11 @@ NUM_SECONDARY_INDEXING_WORKERS = int( ENABLE_MULTIPASS_INDEXING = ( os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true" ) +# Enable contextual retrieval +ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true" + +DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini" +DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI" # Finer grained chunking for more detail retention # Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE # tokens. 
But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end @@ -536,6 +541,17 @@ MAX_FILE_SIZE_BYTES = int( os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024 ) # 2GB in bytes +# Use document summary for contextual rag +USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true" +# Use chunk summary for contextual rag +USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true" +# Average summary embeddings for contextual rag (not yet implemented) +AVERAGE_SUMMARY_EMBEDDINGS = ( + os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true" +) + +MAX_TOKENS_FOR_FULL_INCLUSION = 4096 + ##### # Miscellaneous ##### diff --git a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index c4d015d0a..ce3800e59 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -30,6 +30,7 @@ from onyx.file_processing.file_validation import is_valid_image_type from onyx.file_processing.image_summarization import summarize_image_with_error_handling from onyx.file_processing.image_utils import store_image_and_create_section from onyx.llm.interfaces import LLM +from onyx.utils.lazy import lazy_eval from onyx.utils.logger import setup_logger logger = setup_logger() @@ -76,6 +77,26 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool: return is_valid_image_type(mime_type) +def download_request(service: GoogleDriveService, file_id: str) -> bytes: + """ + Download the file from Google Drive. + """ + # For other file types, download the file + # Use the correct API call for downloading files + request = service.files().get_media(fileId=file_id) + response_bytes = io.BytesIO() + downloader = MediaIoBaseDownload(response_bytes, request) + done = False + while not done: + _, done = downloader.next_chunk() + + response = response_bytes.getvalue() + if not response: + logger.warning(f"Failed to download {file_id}") + return bytes() + return response + + def _download_and_extract_sections_basic( file: dict[str, str], service: GoogleDriveService, @@ -114,41 +135,31 @@ def _download_and_extract_sections_basic( # For other file types, download the file # Use the correct API call for downloading files - request = service.files().get_media(fileId=file_id) - response_bytes = io.BytesIO() - downloader = MediaIoBaseDownload(response_bytes, request) - done = False - while not done: - _, done = downloader.next_chunk() - - response = response_bytes.getvalue() - if not response: - logger.warning(f"Failed to download {file_name}") - return [] + response_call = lazy_eval(lambda: download_request(service, file_id)) # Process based on mime type if mime_type == "text/plain": - text = response.decode("utf-8") + text = response_call().decode("utf-8") return [TextSection(link=link, text=text)] elif ( mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ): - text, _ = docx_to_text_and_images(io.BytesIO(response)) + text, _ = docx_to_text_and_images(io.BytesIO(response_call())) return [TextSection(link=link, text=text)] elif ( mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" ): - text = xlsx_to_text(io.BytesIO(response)) + text = xlsx_to_text(io.BytesIO(response_call())) return [TextSection(link=link, text=text)] elif ( mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation" ): - text = 
pptx_to_text(io.BytesIO(response)) + text = pptx_to_text(io.BytesIO(response_call())) return [TextSection(link=link, text=text)] elif is_gdrive_image_mime_type(mime_type): @@ -158,7 +169,7 @@ def _download_and_extract_sections_basic( with get_session_with_current_tenant() as db_session: section, embedded_id = store_image_and_create_section( db_session=db_session, - image_data=response, + image_data=response_call(), file_name=file_id, display_name=file_name, media_type=mime_type, @@ -171,7 +182,7 @@ def _download_and_extract_sections_basic( return sections elif mime_type == "application/pdf": - text, _pdf_meta, images = read_pdf_file(io.BytesIO(response)) + text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call())) pdf_sections: list[TextSection | ImageSection] = [ TextSection(link=link, text=text) ] @@ -194,8 +205,15 @@ def _download_and_extract_sections_basic( else: # For unsupported file types, try to extract text + if mime_type in [ + "application/vnd.google-apps.video", + "application/vnd.google-apps.audio", + "application/zip", + ]: + return [] + # For unsupported file types, try to extract text try: - text = extract_file_text(io.BytesIO(response), file_name) + text = extract_file_text(io.BytesIO(response_call()), file_name) return [TextSection(link=link, text=text)] except Exception as e: logger.warning(f"Failed to extract text from {file_name}: {e}") diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index 6ba15b6b1..ac3fa42bd 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -163,6 +163,9 @@ class DocumentBase(BaseModel): attributes.append(k + INDEX_SEPARATOR + v) return attributes + def get_text_content(self) -> str: + return " ".join([section.text for section in self.sections if section.text]) + class Document(DocumentBase): """Used for Onyx ingestion api, the ID is required""" diff --git a/backend/onyx/context/search/models.py b/backend/onyx/context/search/models.py index 980ed9644..3ce3dacae 100644 --- a/backend/onyx/context/search/models.py +++ b/backend/onyx/context/search/models.py @@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting): inference_settings = InferenceSettings.from_db_model(search_settings) indexing_setting = IndexingSetting.from_db_model(search_settings) - return cls(**inference_settings.dict(), **indexing_setting.dict()) + return cls(**inference_settings.model_dump(), **indexing_setting.model_dump()) class SavedSearchSettings(InferenceSettings, IndexingSetting): @@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting): reduced_dimension=search_settings.reduced_dimension, # Whether switching to this model requires re-indexing background_reindex_enabled=search_settings.background_reindex_enabled, + enable_contextual_rag=search_settings.enable_contextual_rag, + contextual_rag_llm_name=search_settings.contextual_rag_llm_name, + contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider, # Reranking Details rerank_model_name=search_settings.rerank_model_name, rerank_provider_type=search_settings.rerank_provider_type, @@ -218,6 +221,8 @@ class InferenceChunk(BaseChunk): # to specify that a set of words should be highlighted. 
For example: # ["the answer is 42", "he couldn't find an answer"] match_highlights: list[str] + doc_summary: str + chunk_context: str # when the doc was last updated updated_at: datetime | None diff --git a/backend/onyx/context/search/postprocessing/postprocessing.py b/backend/onyx/context/search/postprocessing/postprocessing.py index 41243eec7..cfa07ef02 100644 --- a/backend/onyx/context/search/postprocessing/postprocessing.py +++ b/backend/onyx/context/search/postprocessing/postprocessing.py @@ -196,9 +196,21 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk RETURN_SEPARATOR ) + def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str: + # remove document summary + if chunk.content.startswith(chunk.doc_summary): + chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip() + # remove chunk context + if chunk.content.endswith(chunk.chunk_context): + chunk.content = chunk.content[ + : len(chunk.content) - len(chunk.chunk_context) + ].rstrip() + return chunk.content + for chunk in chunks: chunk.content = _remove_title(chunk) chunk.content = _remove_metadata_suffix(chunk) + chunk.content = _remove_contextual_rag(chunk) return [chunk.to_inference_chunk() for chunk in chunks] diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index cbe3b1be3..951e2b760 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -791,6 +791,15 @@ class SearchSettings(Base): # Mini and Large Chunks (large chunk also checks for model max context) multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True) + # Contextual RAG + enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False) + + # Contextual RAG LLM + contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True) + contextual_rag_llm_provider: Mapped[str | None] = mapped_column( + String, nullable=True + ) + multilingual_expansion: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), default=[] ) diff --git a/backend/onyx/db/search_settings.py b/backend/onyx/db/search_settings.py index bddaf2115..1d21d0d0e 100644 --- a/backend/onyx/db/search_settings.py +++ b/backend/onyx/db/search_settings.py @@ -62,6 +62,9 @@ def create_search_settings( multipass_indexing=search_settings.multipass_indexing, embedding_precision=search_settings.embedding_precision, reduced_dimension=search_settings.reduced_dimension, + enable_contextual_rag=search_settings.enable_contextual_rag, + contextual_rag_llm_name=search_settings.contextual_rag_llm_name, + contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider, multilingual_expansion=search_settings.multilingual_expansion, disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming, rerank_model_name=search_settings.rerank_model_name, @@ -319,6 +322,7 @@ def get_old_default_embedding_model() -> IndexingSetting: passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, + enable_contextual_rag=False, api_url=None, ) @@ -333,5 +337,6 @@ def get_new_default_embedding_model() -> IndexingSetting: passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, + enable_contextual_rag=False, api_url=None, ) diff --git a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd index 70980852a..4b7c7c1e0 100644 --- 
a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -98,6 +98,12 @@ schema DANSWER_CHUNK_NAME { field metadata type string { indexing: summary | attribute } + field chunk_context type string { + indexing: summary | attribute + } + field doc_summary type string { + indexing: summary | attribute + } field metadata_suffix type string { indexing: summary | attribute } diff --git a/backend/onyx/document_index/vespa/chunk_retrieval.py b/backend/onyx/document_index/vespa/chunk_retrieval.py index 8c7ee6963..a19f2d087 100644 --- a/backend/onyx/document_index/vespa/chunk_retrieval.py +++ b/backend/onyx/document_index/vespa/chunk_retrieval.py @@ -24,9 +24,11 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import ( from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST +from onyx.document_index.vespa_constants import CHUNK_CONTEXT from onyx.document_index.vespa_constants import CHUNK_ID from onyx.document_index.vespa_constants import CONTENT from onyx.document_index.vespa_constants import CONTENT_SUMMARY +from onyx.document_index.vespa_constants import DOC_SUMMARY from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT @@ -126,7 +128,8 @@ def _vespa_hit_to_inference_chunk( return InferenceChunkUncleaned( chunk_id=fields[CHUNK_ID], blurb=fields.get(BLURB, ""), # Unused - content=fields[CONTENT], # Includes extra title prefix and metadata suffix + content=fields[CONTENT], # Includes extra title prefix and metadata suffix; + # also sometimes context for contextual rag source_links=source_links_dict or {0: ""}, section_continuation=fields[SECTION_CONTINUATION], document_id=fields[DOCUMENT_ID], @@ -143,6 +146,8 @@ def _vespa_hit_to_inference_chunk( large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []), metadata=metadata, metadata_suffix=fields.get(METADATA_SUFFIX), + doc_summary=fields.get(DOC_SUMMARY, ""), + chunk_context=fields.get(CHUNK_CONTEXT, ""), match_highlights=match_highlights, updated_at=updated_at, ) diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index 17aadb36c..b60eaa322 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -187,7 +187,7 @@ class VespaIndex(DocumentIndex): ) -> None: if MULTI_TENANT: logger.info( - "Skipping Vespa index seup for multitenant (would wipe all indices)" + "Skipping Vespa index setup for multitenant (would wipe all indices)" ) return None diff --git a/backend/onyx/document_index/vespa/indexing_utils.py b/backend/onyx/document_index/vespa/indexing_utils.py index ab08dc4d1..9145ce63c 100644 --- a/backend/onyx/document_index/vespa/indexing_utils.py +++ b/backend/onyx/document_index/vespa/indexing_utils.py @@ -25,9 +25,11 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST +from onyx.document_index.vespa_constants import CHUNK_CONTEXT from onyx.document_index.vespa_constants import CHUNK_ID from onyx.document_index.vespa_constants import CONTENT from 
onyx.document_index.vespa_constants import CONTENT_SUMMARY +from onyx.document_index.vespa_constants import DOC_SUMMARY from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT @@ -174,7 +176,7 @@ def _index_vespa_chunk( # For the BM25 index, the keyword suffix is used, the vector is already generated with the more # natural language representation of the metadata section CONTENT: remove_invalid_unicode_chars( - f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}" + f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}" ), # This duplication of `content` is needed for keyword highlighting # Note that it's not exactly the same as the actual content @@ -189,6 +191,8 @@ def _index_vespa_chunk( # Save as a list for efficient extraction as an Attribute METADATA_LIST: metadata_list, METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword), + CHUNK_CONTEXT: chunk.chunk_context, + DOC_SUMMARY: chunk.doc_summary, EMBEDDINGS: embeddings_name_vector_map, TITLE_EMBEDDING: chunk.title_embedding, DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at), diff --git a/backend/onyx/document_index/vespa_constants.py b/backend/onyx/document_index/vespa_constants.py index 8a32fb721..66a7fd99d 100644 --- a/backend/onyx/document_index/vespa_constants.py +++ b/backend/onyx/document_index/vespa_constants.py @@ -71,6 +71,8 @@ LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids" METADATA = "metadata" METADATA_LIST = "metadata_list" METADATA_SUFFIX = "metadata_suffix" +DOC_SUMMARY = "doc_summary" +CHUNK_CONTEXT = "chunk_context" BOOST = "boost" AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor" DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch @@ -106,6 +108,8 @@ YQL_BASE = ( f"{LARGE_CHUNK_REFERENCE_IDS}, " f"{METADATA}, " f"{METADATA_SUFFIX}, " + f"{DOC_SUMMARY}, " + f"{CHUNK_CONTEXT}, " f"{CONTENT_SUMMARY} " f"from {{index_name}} where " ) diff --git a/backend/onyx/indexing/chunker.py b/backend/onyx/indexing/chunker.py index 0dea6fa12..e84f40f0c 100644 --- a/backend/onyx/indexing/chunker.py +++ b/backend/onyx/indexing/chunker.py @@ -1,7 +1,10 @@ +from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS from onyx.configs.app_configs import BLURB_SIZE from onyx.configs.app_configs import LARGE_CHUNK_RATIO from onyx.configs.app_configs import MINI_CHUNK_SIZE from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK +from onyx.configs.app_configs import USE_CHUNK_SUMMARY +from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import DocumentSource from onyx.configs.constants import RETURN_SEPARATOR from onyx.configs.constants import SECTION_SEPARATOR @@ -13,6 +16,7 @@ from onyx.connectors.models import IndexingDocument from onyx.connectors.models import Section from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.models import DocAwareChunk +from onyx.llm.utils import MAX_CONTEXT_TOKENS from onyx.natural_language_processing.utils import BaseTokenizer from onyx.utils.logger import setup_logger from onyx.utils.text_processing import clean_text @@ -82,6 +86,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks], mini_chunk_texts=None, large_chunk_id=large_chunk_id, + 
chunk_context="", + doc_summary="", + contextual_rag_reserved_tokens=0, ) offset = 0 @@ -120,6 +127,7 @@ class Chunker: tokenizer: BaseTokenizer, enable_multipass: bool = False, enable_large_chunks: bool = False, + enable_contextual_rag: bool = False, blurb_size: int = BLURB_SIZE, include_metadata: bool = not SKIP_METADATA_IN_CHUNK, chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE, @@ -133,9 +141,20 @@ class Chunker: self.chunk_token_limit = chunk_token_limit self.enable_multipass = enable_multipass self.enable_large_chunks = enable_large_chunks + self.enable_contextual_rag = enable_contextual_rag + if enable_contextual_rag: + assert ( + USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY + ), "Contextual RAG requires at least one of chunk summary and document summary enabled" + self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * ( + int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY) + ) self.tokenizer = tokenizer self.callback = callback + self.max_context = 0 + self.prompt_tokens = 0 + self.blurb_splitter = SentenceSplitter( tokenizer=tokenizer.tokenize, chunk_size=blurb_size, @@ -221,6 +240,9 @@ class Chunker: metadata_suffix_keyword=metadata_suffix_keyword, mini_chunk_texts=self._get_mini_chunk_texts(text), large_chunk_id=None, + doc_summary="", + chunk_context="", + contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document ) chunks_list.append(new_chunk) @@ -309,7 +331,7 @@ class Chunker: continue # CASE 2: Normal text section - section_token_count = len(self.tokenizer.tokenize(section_text)) + section_token_count = len(self.tokenizer.encode(section_text)) # If the section is large on its own, split it separately if section_token_count > content_token_limit: @@ -332,8 +354,7 @@ class Chunker: # If even the split_text is bigger than strict limit, further split if ( STRICT_CHUNK_TOKEN_LIMIT - and len(self.tokenizer.tokenize(split_text)) - > content_token_limit + and len(self.tokenizer.encode(split_text)) > content_token_limit ): smaller_chunks = self._split_oversized_chunk( split_text, content_token_limit @@ -363,10 +384,10 @@ class Chunker: continue # If we can still fit this section into the current chunk, do so - current_token_count = len(self.tokenizer.tokenize(chunk_text)) + current_token_count = len(self.tokenizer.encode(chunk_text)) current_offset = len(shared_precompare_cleanup(chunk_text)) next_section_tokens = ( - len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count + len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count ) if next_section_tokens + current_token_count <= content_token_limit: @@ -414,7 +435,7 @@ class Chunker: # Title prep title = self._extract_blurb(document.get_title_for_document_index() or "") title_prefix = title + RETURN_SEPARATOR if title else "" - title_tokens = len(self.tokenizer.tokenize(title_prefix)) + title_tokens = len(self.tokenizer.encode(title_prefix)) # Metadata prep metadata_suffix_semantic = "" @@ -427,15 +448,50 @@ class Chunker: ) = _get_metadata_suffix_for_document_index( document.metadata, include_separator=True ) - metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic)) + metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic)) # If metadata is too large, skip it in the semantic content if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE: metadata_suffix_semantic = "" metadata_tokens = 0 + single_chunk_fits = True + doc_token_count = 0 + if self.enable_contextual_rag: + doc_content = document.get_text_content() + tokenized_doc = 
self.tokenizer.tokenize(doc_content) + doc_token_count = len(tokenized_doc) + + # check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG + single_chunk_fits = ( + doc_token_count + title_tokens + metadata_tokens + <= self.chunk_token_limit + ) + + # expand the size of the context used for contextual rag based on whether chunk context and doc summary are used + context_size = 0 + if ( + self.enable_contextual_rag + and not single_chunk_fits + and not AVERAGE_SUMMARY_EMBEDDINGS + ): + context_size += self.default_contextual_rag_reserved_tokens + # Adjust content token limit to accommodate title + metadata - content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens + content_token_limit = ( + self.chunk_token_limit - title_tokens - metadata_tokens - context_size + ) + + # first check: if there is not enough actual chunk content when including contextual rag, + # then don't do contextual rag + if content_token_limit <= CHUNK_MIN_CONTENT: + context_size = 0 # Don't do contextual RAG + # revert to previous content token limit + content_token_limit = ( + self.chunk_token_limit - title_tokens - metadata_tokens + ) + + # If there is not enough context remaining then just index the chunk with no prefix/suffix if content_token_limit <= CHUNK_MIN_CONTENT: # Not enough space left, so revert to full chunk without the prefix content_token_limit = self.chunk_token_limit @@ -459,6 +515,9 @@ class Chunker: large_chunks = generate_large_chunks(normal_chunks) normal_chunks.extend(large_chunks) + for chunk in normal_chunks: + chunk.contextual_rag_reserved_tokens = context_size + return normal_chunks def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]: diff --git a/backend/onyx/indexing/embedder.py b/backend/onyx/indexing/embedder.py index 67bf56fc8..78ea96340 100644 --- a/backend/onyx/indexing/embedder.py +++ b/backend/onyx/indexing/embedder.py @@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): if chunk.large_chunk_reference_ids: large_chunks_present = True chunk_text = ( - f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}" + f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}" ) or chunk.source_document.get_title_for_document_index() if not chunk_text: diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index 3967b7f7c..99401c709 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import Callable from functools import partial from typing import Protocol @@ -8,7 +9,13 @@ from sqlalchemy.orm import Session from onyx.access.access import get_access_for_documents from onyx.access.models import DocumentAccess +from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME +from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER +from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG from onyx.configs.app_configs import MAX_DOCUMENT_CHARS +from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION +from onyx.configs.app_configs import USE_CHUNK_SUMMARY +from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import DEFAULT_BOOST from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION @@ -36,9 +43,10 @@ from 
onyx.db.document import upsert_documents from onyx.db.document_set import fetch_document_sets_for_documents from onyx.db.engine import get_session_with_current_tenant from onyx.db.models import Document as DBDocument +from onyx.db.models import IndexModelStatus from onyx.db.pg_file_store import get_pgfilestore_by_file_name from onyx.db.pg_file_store import read_lobj -from onyx.db.search_settings import get_current_search_settings +from onyx.db.search_settings import get_active_search_settings from onyx.db.tag import create_or_add_document_tag from onyx.db.tag import create_or_add_document_tag_list from onyx.document_index.document_index_utils import ( @@ -57,11 +65,24 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk from onyx.indexing.models import IndexChunk from onyx.indexing.models import UpdatableChunkData from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff +from onyx.llm.chat_llm import LLMRateLimitError from onyx.llm.factory import get_default_llm_with_vision +from onyx.llm.factory import get_llm_for_contextual_rag +from onyx.llm.interfaces import LLM +from onyx.llm.utils import get_max_input_tokens +from onyx.llm.utils import MAX_CONTEXT_TOKENS +from onyx.llm.utils import message_to_string from onyx.natural_language_processing.search_nlp_models import ( InformationContentClassificationModel, ) +from onyx.natural_language_processing.utils import BaseTokenizer +from onyx.natural_language_processing.utils import get_tokenizer +from onyx.natural_language_processing.utils import tokenizer_trim_middle +from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1 +from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2 +from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT from onyx.utils.logger import setup_logger +from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.timing import log_function_time from shared_configs.configs import ( INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH, @@ -249,6 +270,8 @@ def index_doc_batch_with_handler( db_session: Session, tenant_id: str, ignore_time_skip: bool = False, + enable_contextual_rag: bool = False, + llm: LLM | None = None, ) -> IndexingPipelineResult: try: index_pipeline_result = index_doc_batch( @@ -261,6 +284,8 @@ def index_doc_batch_with_handler( db_session=db_session, ignore_time_skip=ignore_time_skip, tenant_id=tenant_id, + enable_contextual_rag=enable_contextual_rag, + llm=llm, ) except Exception as e: # don't log the batch directly, it's too much text @@ -523,6 +548,145 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]: return indexed_documents +def add_document_summaries( + chunks_by_doc: list[DocAwareChunk], + llm: LLM, + tokenizer: BaseTokenizer, + trunc_doc_tokens: int, +) -> list[int] | None: + """ + Adds a document summary to a list of chunks from the same document. + Returns the number of tokens in the document. 
+ """ + + doc_tokens = [] + # this value is the same for each chunk in the document; 0 indicates + # there is not enough space for contextual RAG (the chunk content + # and possibly metadata took up too much space) + if chunks_by_doc[0].contextual_rag_reserved_tokens == 0: + return None + + doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content()) + doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer) + summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content) + doc_summary = message_to_string( + llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS) + ) + + for chunk in chunks_by_doc: + chunk.doc_summary = doc_summary + + return doc_tokens + + +def add_chunk_summaries( + chunks_by_doc: list[DocAwareChunk], + llm: LLM, + tokenizer: BaseTokenizer, + trunc_doc_chunk_tokens: int, + doc_tokens: list[int] | None, +) -> None: + """ + Adds chunk summaries to the chunks grouped by document id. + Chunk summaries look at the chunk as well as the entire document (or a summary, + if the document is too long) and describe how the chunk relates to the document. + """ + # all chunks within a document have the same contextual_rag_reserved_tokens + if chunks_by_doc[0].contextual_rag_reserved_tokens == 0: + return + + # use values computed in above doc summary section if available + doc_tokens = doc_tokens or tokenizer.encode( + chunks_by_doc[0].source_document.get_text_content() + ) + doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer) + + # only compute doc summary if needed + doc_info = ( + doc_content + if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION + else chunks_by_doc[0].doc_summary + ) + if not doc_info: + # This happens if the document is too long AND document summaries are turned off + # In this case we compute a doc summary using the LLM + doc_info = message_to_string( + llm.invoke( + DOCUMENT_SUMMARY_PROMPT.format(document=doc_content), + max_tokens=MAX_CONTEXT_TOKENS, + ) + ) + + context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info) + + def assign_context(chunk: DocAwareChunk) -> None: + context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content) + try: + chunk.chunk_context = message_to_string( + llm.invoke( + context_prompt1 + context_prompt2, + max_tokens=MAX_CONTEXT_TOKENS, + ) + ) + except LLMRateLimitError as e: + # Erroring during chunker is undesirable, so we log the error and continue + # TODO: for v2, add robust retry logic + logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e) + chunk.chunk_context = "" + except Exception as e: + logger.exception(f"Error adding chunk summary: {e}", exc_info=e) + chunk.chunk_context = "" + + run_functions_tuples_in_parallel( + [(assign_context, (chunk,)) for chunk in chunks_by_doc] + ) + + +def add_contextual_summaries( + chunks: list[DocAwareChunk], + llm: LLM, + tokenizer: BaseTokenizer, + chunk_token_limit: int, +) -> list[DocAwareChunk]: + """ + Adds Document summary and chunk-within-document context to the chunks + based on which environment variables are set.
+ """ + max_context = get_max_input_tokens( + model_name=llm.config.model_name, + model_provider=llm.config.model_provider, + output_tokens=MAX_CONTEXT_TOKENS, + ) + doc2chunks = defaultdict(list) + for chunk in chunks: + doc2chunks[chunk.source_document.id].append(chunk) + + # The number of tokens allowed for the document when computing a document summary + trunc_doc_summary_tokens = max_context - len( + tokenizer.encode(DOCUMENT_SUMMARY_PROMPT) + ) + + prompt_tokens = len( + tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2) + ) + # The number of tokens allowed for the document when computing a + # "chunk in context of document" summary + trunc_doc_chunk_tokens = max_context - prompt_tokens - chunk_token_limit + for chunks_by_doc in doc2chunks.values(): + doc_tokens = None + if USE_DOCUMENT_SUMMARY: + doc_tokens = add_document_summaries( + chunks_by_doc, llm, tokenizer, trunc_doc_summary_tokens + ) + + if USE_CHUNK_SUMMARY: + add_chunk_summaries( + chunks_by_doc, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens + ) + + return chunks + + @log_function_time(debug_only=True) def index_doc_batch( *, @@ -534,6 +698,8 @@ def index_doc_batch( index_attempt_metadata: IndexAttemptMetadata, db_session: Session, tenant_id: str, + enable_contextual_rag: bool = False, + llm: LLM | None = None, ignore_time_skip: bool = False, filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents, ) -> IndexingPipelineResult: @@ -596,6 +762,20 @@ def index_doc_batch( # a common source of failure for the indexing pipeline chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs) + # contextual RAG + if enable_contextual_rag: + assert llm is not None, "must provide an LLM for contextual RAG" + llm_tokenizer = get_tokenizer( + model_name=llm.config.model_name, + provider_type=llm.config.model_provider, + ) + + # Because the chunker's tokens are different from the LLM's tokens, + # We add a fudge factor to ensure we truncate prompts to the LLM's token limit + chunks = add_contextual_summaries( + chunks, llm, llm_tokenizer, chunker.chunk_token_limit * 2 + ) + logger.debug("Starting embedding") chunks_with_embeddings, embedding_failures = ( embed_chunks_with_failure_handling( @@ -791,13 +971,33 @@ def build_indexing_pipeline( callback: IndexingHeartbeatInterface | None = None, ) -> IndexingPipelineProtocol: """Builds a pipeline which takes in a list (batch) of docs and indexes them.""" - search_settings = get_current_search_settings(db_session) + all_search_settings = get_active_search_settings(db_session) + if ( + all_search_settings.secondary + and all_search_settings.secondary.status == IndexModelStatus.FUTURE + ): + search_settings = all_search_settings.secondary + else: + search_settings = all_search_settings.primary + multipass_config = get_multipass_config(search_settings) + enable_contextual_rag = ( + search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG + ) + llm = None + if enable_contextual_rag: + llm = get_llm_for_contextual_rag( + search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME, + search_settings.contextual_rag_llm_provider + or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER, + ) + chunker = chunker or Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=multipass_config.multipass_indexing, enable_large_chunks=multipass_config.enable_large_chunks, + enable_contextual_rag=enable_contextual_rag, # after every doc, update status in case there are a bunch of really long docs callback=callback, ) @@ -811,4 +1011,6 @@ def 
build_indexing_pipeline( ignore_time_skip=ignore_time_skip, db_session=db_session, tenant_id=tenant_id, + enable_contextual_rag=enable_contextual_rag, + llm=llm, ) diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index 686ac2942..d6283ee21 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -49,6 +49,15 @@ class DocAwareChunk(BaseChunk): metadata_suffix_semantic: str metadata_suffix_keyword: str + # This is the number of tokens reserved for contextual RAG + # in the chunk. doc_summary and chunk_context combined should + # contain at most this many tokens. + contextual_rag_reserved_tokens: int + # This is the summary for the document generated for contextual RAG + doc_summary: str + # This is the context for this chunk generated for contextual RAG + chunk_context: str + mini_chunk_texts: list[str] | None large_chunk_id: int | None @@ -154,6 +163,9 @@ class IndexingSetting(EmbeddingModelDetail): reduced_dimension: int | None = None background_reindex_enabled: bool = True + enable_contextual_rag: bool + contextual_rag_llm_name: str | None = None + contextual_rag_llm_provider: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} @@ -178,6 +190,7 @@ class IndexingSetting(EmbeddingModelDetail): embedding_precision=search_settings.embedding_precision, reduced_dimension=search_settings.reduced_dimension, background_reindex_enabled=search_settings.background_reindex_enabled, + enable_contextual_rag=search_settings.enable_contextual_rag, ) diff --git a/backend/onyx/llm/chat_llm.py b/backend/onyx/llm/chat_llm.py index 2e9496856..1f33dc10e 100644 --- a/backend/onyx/llm/chat_llm.py +++ b/backend/onyx/llm/chat_llm.py @@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM): messages=processed_prompt, tools=tools, tool_choice=tool_choice if tools else None, + max_tokens=max_tokens, # streaming choice stream=stream, # model params temperature=0, timeout=timeout_override or self._timeout, - max_tokens=max_tokens, # For now, we don't support parallel tool calls # NOTE: we can't pass this in if tools are not specified # or else OpenAI throws an error @@ -531,6 +531,7 @@ class DefaultMultiLLM(LLM): tool_choice, structured_response_format, timeout_override, + max_tokens, ) return diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index c77518f51..ae0f309eb 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -16,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride from onyx.llm.utils import model_supports_image_input +from onyx.server.manage.llm.models import LLMProvider from onyx.server.manage.llm.models import LLMProviderView from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger @@ -154,6 +155,40 @@ def get_default_llm_with_vision( return None +def llm_from_provider( + model_name: str, + llm_provider: LLMProvider, + timeout: int | None = None, + temperature: float | None = None, + additional_headers: dict[str, str] | None = None, + long_term_logger: LongTermLogger | None = None, +) -> LLM: + return get_llm( + provider=llm_provider.provider, + model=model_name, + deployment_name=llm_provider.deployment_name, + api_key=llm_provider.api_key, + api_base=llm_provider.api_base, + api_version=llm_provider.api_version, + custom_config=llm_provider.custom_config, + timeout=timeout, + temperature=temperature, 
+ additional_headers=additional_headers, + long_term_logger=long_term_logger, + ) + + +def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM: + with get_session_context_manager() as db_session: + llm_provider = fetch_llm_provider_view(db_session, model_provider) + if not llm_provider: + raise ValueError("No LLM provider with name {} found".format(model_provider)) + return llm_from_provider( + model_name=model_name, + llm_provider=llm_provider, + ) + + def get_default_llms( timeout: int | None = None, temperature: float | None = None, @@ -179,14 +214,9 @@ def get_default_llms( raise ValueError("No fast default model name found") def _create_llm(model: str) -> LLM: - return get_llm( - provider=llm_provider.provider, - model=model, - deployment_name=llm_provider.deployment_name, - api_key=llm_provider.api_key, - api_base=llm_provider.api_base, - api_version=llm_provider.api_version, - custom_config=llm_provider.custom_config, + return llm_from_provider( + model_name=model, + llm_provider=llm_provider, timeout=timeout, temperature=temperature, additional_headers=additional_headers, diff --git a/backend/onyx/llm/utils.py b/backend/onyx/llm/utils.py index 04fc2260c..9f257b040 100644 --- a/backend/onyx/llm/utils.py +++ b/backend/onyx/llm/utils.py @@ -29,13 +29,19 @@ from litellm.exceptions import Timeout # type: ignore from litellm.exceptions import UnprocessableEntityError # type: ignore from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS +from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION +from onyx.configs.app_configs import USE_CHUNK_SUMMARY +from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import MessageType +from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from onyx.configs.model_configs import GEN_AI_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from onyx.file_store.models import ChatFileType from onyx.file_store.models import InMemoryChatFile from onyx.llm.interfaces import LLM +from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE +from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE from onyx.prompts.constants import CODE_BLOCK_PAT from onyx.utils.b64 import get_image_type from onyx.utils.b64 import get_image_type_from_bytes @@ -44,6 +50,10 @@ from shared_configs.configs import LOG_LEVEL logger = setup_logger() +MAX_CONTEXT_TOKENS = 100 +ONE_MILLION = 1_000_000 +CHUNKS_PER_DOC_ESTIMATE = 5 + def litellm_exception_to_error_msg( e: Exception, @@ -416,6 +426,72 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N return None +def get_llm_contextual_cost( + llm: LLM, +) -> float: + """ + Approximate the cost of using the given LLM for indexing with Contextual RAG. + + We use a precomputed estimate for the number of tokens in the contextualizing prompts, + and we assume that every chunk is maximized in terms of content and context. + We also assume that every document is maximized in terms of content, as currently if + a document is longer than a certain length, its summary is used instead of the full content. + + We expect that the first assumption will overestimate more than the second one + underestimates, so this should be a fairly conservative price estimate. Also, + this does not account for the cost of documents that fit within a single chunk + which do not get contextualized. 
+ """ + + # calculate input costs + num_tokens = ONE_MILLION + num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE + + # We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long + # on average. + num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION + + num_input_tokens = 0 + num_output_tokens = 0 + + if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY: + return 0 + + if USE_CHUNK_SUMMARY: + # Each per-chunk prompt includes: + # - The prompt tokens + # - the document tokens + # - the chunk tokens + + # for each chunk, we prompt the LLM with the contextual RAG prompt + # and the full document content (or the doc summary, so this is an overestimate) + num_input_tokens += num_input_chunks * ( + CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION + ) + + # in aggregate, each chunk content is used as a prompt input once + # so the full input size is covered + num_input_tokens += num_tokens + + # A single MAX_CONTEXT_TOKENS worth of output is generated per chunk + num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS + + # going over each doc once means all the tokens, plus the prompt tokens for + # the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false, + # since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true. + # So, we include this unconditionally to overestimate. + num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE + num_output_tokens += num_docs * MAX_CONTEXT_TOKENS + + usd_per_prompt, usd_per_completion = litellm.cost_per_token( + model=llm.config.model_name, + prompt_tokens=num_input_tokens, + completion_tokens=num_output_tokens, + ) + # Costs are in USD dollars per million tokens + return usd_per_prompt + usd_per_completion + + def get_llm_max_tokens( model_map: dict, model_name: str, diff --git a/backend/onyx/natural_language_processing/utils.py b/backend/onyx/natural_language_processing/utils.py index 3c4d13920..0860cf429 100644 --- a/backend/onyx/natural_language_processing/utils.py +++ b/backend/onyx/natural_language_processing/utils.py @@ -11,6 +11,8 @@ from onyx.context.search.models import InferenceChunk from onyx.utils.logger import setup_logger from shared_configs.enums import EmbeddingProvider +TRIM_SEP_PAT = "\n... 
{n} tokens removed...\n" + logger = setup_logger() transformer_logging.set_verbosity_error() os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -159,9 +161,26 @@ def tokenizer_trim_content( content: str, desired_length: int, tokenizer: BaseTokenizer ) -> str: tokens = tokenizer.encode(content) - if len(tokens) > desired_length: - content = tokenizer.decode(tokens[:desired_length]) - return content + if len(tokens) <= desired_length: + return content + + return tokenizer.decode(tokens[:desired_length]) + + +def tokenizer_trim_middle( + tokens: list[int], desired_length: int, tokenizer: BaseTokenizer +) -> str: + if len(tokens) <= desired_length: + return tokenizer.decode(tokens) + sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length) + sep_tokens = tokenizer.encode(sep_str) + slice_size = (desired_length - len(sep_tokens)) // 2 + assert slice_size > 0, "Slice size is not positive, desired length is too short" + return ( + tokenizer.decode(tokens[:slice_size]) + + sep_str + + tokenizer.decode(tokens[-slice_size:]) + ) def tokenizer_trim_chunks( diff --git a/backend/onyx/prompts/chat_prompts.py b/backend/onyx/prompts/chat_prompts.py index 04cc33488..65c4f9e85 100644 --- a/backend/onyx/prompts/chat_prompts.py +++ b/backend/onyx/prompts/chat_prompts.py @@ -220,3 +220,29 @@ Chat History: Based on the above, what is a short name to convey the topic of the conversation? """.strip() + +# NOTE: the prompt separation is partially done for efficiency; previously I tried +# to do it all in one prompt with sequential format() calls but this will cause a backend +# error when the document contains any {} as python will expect the {} to be filled by +# format() arguments +CONTEXTUAL_RAG_PROMPT1 = """ +{document} + +Here is the chunk we want to situate within the whole document""" + +CONTEXTUAL_RAG_PROMPT2 = """ +{chunk} + +Please give a short succinct context to situate this chunk within the overall document +for the purposes of improving search retrieval of the chunk. Answer only with the succinct +context and nothing else. """ + +CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45 + +DOCUMENT_SUMMARY_PROMPT = """ +{document} + +Please give a short succinct summary of the entire document. Answer only with the succinct +summary and nothing else. 
""" + +DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29 diff --git a/backend/onyx/seeding/load_docs.py b/backend/onyx/seeding/load_docs.py index a3aa99ea6..ab99f2b93 100644 --- a/backend/onyx/seeding/load_docs.py +++ b/backend/onyx/seeding/load_docs.py @@ -87,6 +87,9 @@ def _create_indexable_chunks( metadata_suffix_keyword="", mini_chunk_texts=None, large_chunk_reference_ids=[], + doc_summary="", + chunk_context="", + contextual_rag_reserved_tokens=0, embeddings=ChunkEmbedding( full_embedding=preprocessed_doc["content_embedding"], mini_chunk_embeddings=[], diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index 0a5ceb036..0e6a6ea03 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -21,9 +21,11 @@ from onyx.llm.factory import get_default_llms from onyx.llm.factory import get_llm from onyx.llm.llm_provider_options import fetch_available_well_known_llms from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor +from onyx.llm.utils import get_llm_contextual_cost from onyx.llm.utils import litellm_exception_to_error_msg from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import test_llm +from onyx.server.manage.llm.models import LLMCost from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderView @@ -286,3 +288,38 @@ def list_llm_provider_basics( db_session, user ) ] + + +@admin_router.get("/provider-contextual-cost") +def get_provider_contextual_cost( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[LLMCost]: + """ + Get the cost of Re-indexing all documents for contextual retrieval. 
+ + See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token + This includes: + - The cost of invoking the LLM on each chunk-document pair to get + - the doc_summary + - the chunk_context + - The per-token cost of the LLM used to generate the doc_summary and chunk_context + """ + providers = fetch_existing_llm_providers(db_session) + costs = [] + for provider in providers: + for model_name in provider.display_model_names or provider.model_names or []: + llm = get_llm( + provider=provider.provider, + model=model_name, + deployment_name=provider.deployment_name, + api_key=provider.api_key, + api_base=provider.api_base, + api_version=provider.api_version, + custom_config=provider.custom_config, + ) + cost = get_llm_contextual_cost(llm) + costs.append( + LLMCost(provider=provider.name, model_name=model_name, cost=cost) + ) + return costs diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 9d5544d96..7b9c7bc58 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -119,3 +119,9 @@ class VisionProviderResponse(LLMProviderView): """Response model for vision providers endpoint, including vision-specific fields.""" vision_models: list[str] + + +class LLMCost(BaseModel): + provider: str + model_name: str + cost: float diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py index 893dcaae0..29c668f8d 100644 --- a/backend/scripts/query_time_check/seed_dummy_docs.py +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -63,7 +63,10 @@ def generate_dummy_chunk( title_prefix=f"Title prefix for doc {doc_id}", metadata_suffix_semantic="", metadata_suffix_keyword="", + doc_summary="", + chunk_context="", mini_chunk_texts=None, + contextual_rag_reserved_tokens=0, embeddings=ChunkEmbedding( full_embedding=generate_random_embedding(embedding_dim), mini_chunk_embeddings=[], diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 13e7ba03e..30ff23fd6 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -99,6 +99,7 @@ PRESERVED_SEARCH_FIELDS = [ "api_url", "index_name", "multipass_indexing", + "enable_contextual_rag", "model_dim", "normalize", "passage_prefix", diff --git a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py index 162db23c5..a8a913336 100644 --- a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py +++ b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py @@ -32,6 +32,8 @@ def create_test_chunk( match_highlights=[], updated_at=datetime.now(), image_file_name=None, + doc_summary="", + chunk_context="", ) diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py index 323d174ee..263a1bc0f 100644 --- a/backend/tests/unit/onyx/chat/conftest.py +++ b/backend/tests/unit/onyx/chat/conftest.py @@ -78,6 +78,8 @@ def mock_inference_sections() -> list[InferenceSection]: source_links={0: "https://example.com/doc1"}, match_highlights=[], image_file_name=None, + doc_summary="", + chunk_context="", ), chunks=MagicMock(), ), @@ -101,6 +103,8 @@ def mock_inference_sections() -> list[InferenceSection]: source_links={0: "https://example.com/doc2"}, match_highlights=[], image_file_name=None, + doc_summary="", + chunk_context="", ), chunks=MagicMock(), ), 
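Note on the fixture updates above: InferenceChunk constructions now pass doc_summary and chunk_context because those fields are folded into the stored chunk content at index time and stripped back out by _remove_contextual_rag during post-processing. A minimal Python sketch of that round trip, using invented sample strings and omitting the title prefix and metadata suffix that the real stored content also carries:

# Hedged illustration, not shipped code: contextual fields ride inside the CONTENT
# field for BM25/embedding and are removed in cleanup_chunks before chat sees them.
doc_summary = "Summary of the whole document. "
chunk_context = " How this chunk relates to the document."
original_content = "The actual chunk body."

# roughly what _index_vespa_chunk writes to the CONTENT field (simplified)
stored = f"{doc_summary}{original_content}{chunk_context}"

# mirrors _remove_contextual_rag in postprocessing.py
cleaned = stored
if cleaned.startswith(doc_summary):
    cleaned = cleaned[len(doc_summary):].lstrip()
if cleaned.endswith(chunk_context):
    cleaned = cleaned[: len(cleaned) - len(chunk_context)].rstrip()

assert cleaned == original_content

This is why the generated summaries participate in keyword and vector search yet never surface in the chunk content returned to the chat flow.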
diff --git a/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py b/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py index 1bc7cc12e..7b406ae2e 100644 --- a/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py +++ b/backend/tests/unit/onyx/chat/stream_processing/test_quotes_processing.py @@ -151,6 +151,8 @@ def test_fuzzy_match_quotes_to_docs() -> None: match_highlights=[], updated_at=None, image_file_name=None, + doc_summary="", + chunk_context="", ) test_chunk_1 = InferenceChunk( document_id="test doc 1", @@ -170,6 +172,8 @@ def test_fuzzy_match_quotes_to_docs() -> None: match_highlights=[], updated_at=None, image_file_name=None, + doc_summary="", + chunk_context="", ) test_quotes = [ diff --git a/backend/tests/unit/onyx/chat/test_prune_and_merge.py b/backend/tests/unit/onyx/chat/test_prune_and_merge.py index a4037b3d0..ee6299024 100644 --- a/backend/tests/unit/onyx/chat/test_prune_and_merge.py +++ b/backend/tests/unit/onyx/chat/test_prune_and_merge.py @@ -38,6 +38,8 @@ def create_inference_chunk( match_highlights=[], updated_at=None, image_file_name=None, + doc_summary="", + chunk_context="", ) diff --git a/backend/tests/unit/onyx/indexing/conftest.py b/backend/tests/unit/onyx/indexing/conftest.py index 6832add97..402c7b8dd 100644 --- a/backend/tests/unit/onyx/indexing/conftest.py +++ b/backend/tests/unit/onyx/indexing/conftest.py @@ -1,5 +1,6 @@ import pytest +from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface @@ -17,3 +18,13 @@ class MockHeartbeat(IndexingHeartbeatInterface): @pytest.fixture def mock_heartbeat() -> MockHeartbeat: return MockHeartbeat() + + +@pytest.fixture +def embedder() -> DefaultIndexingEmbedder: + return DefaultIndexingEmbedder( + model_name="intfloat/e5-base-v2", + normalize=True, + query_prefix=None, + passage_prefix=None, + ) diff --git a/backend/tests/unit/onyx/indexing/test_chunker.py b/backend/tests/unit/onyx/indexing/test_chunker.py index 57ba3fe12..7786123a4 100644 --- a/backend/tests/unit/onyx/indexing/test_chunker.py +++ b/backend/tests/unit/onyx/indexing/test_chunker.py @@ -1,25 +1,24 @@ +from typing import Any +from unittest.mock import Mock + import pytest +from onyx.configs.app_configs import USE_CHUNK_SUMMARY +from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import DocumentSource from onyx.connectors.models import Document from onyx.connectors.models import TextSection from onyx.indexing.chunker import Chunker from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_pipeline import process_image_sections +from onyx.llm.utils import MAX_CONTEXT_TOKENS from tests.unit.onyx.indexing.conftest import MockHeartbeat -@pytest.fixture -def embedder() -> DefaultIndexingEmbedder: - return DefaultIndexingEmbedder( - model_name="intfloat/e5-base-v2", - normalize=True, - query_prefix=None, - passage_prefix=None, - ) - - -def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None: +@pytest.mark.parametrize("enable_contextual_rag", [True, False]) +def test_chunk_document( + embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool +) -> None: short_section_1 = "This is a short section." long_section = ( "This is a long section that should be split into multiple chunks. 
" * 100 @@ -45,9 +44,22 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None: ) indexing_documents = process_image_sections([document]) + mock_llm_invoke_count = 0 + + def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock: + nonlocal mock_llm_invoke_count + mock_llm_invoke_count += 1 + m = Mock() + m.content = f"Test{mock_llm_invoke_count}" + return m + + mock_llm = Mock() + mock_llm.invoke = mock_llm_invoke + chunker = Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=False, + enable_contextual_rag=enable_contextual_rag, ) chunks = chunker.chunk(indexing_documents) @@ -58,6 +70,14 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None: assert "tag1" in chunks[0].metadata_suffix_keyword assert "tag2" in chunks[0].metadata_suffix_semantic + rag_tokens = MAX_CONTEXT_TOKENS * ( + int(USE_DOCUMENT_SUMMARY) + int(USE_CHUNK_SUMMARY) + ) + for chunk in chunks: + assert chunk.contextual_rag_reserved_tokens == ( + rag_tokens if enable_contextual_rag else 0 + ) + def test_chunker_heartbeat( embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat @@ -78,6 +98,7 @@ def test_chunker_heartbeat( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=False, callback=mock_heartbeat, + enable_contextual_rag=False, ) chunks = chunker.chunk(indexing_documents) diff --git a/backend/tests/unit/onyx/indexing/test_embedder.py b/backend/tests/unit/onyx/indexing/test_embedder.py index d8589ecff..d49d28344 100644 --- a/backend/tests/unit/onyx/indexing/test_embedder.py +++ b/backend/tests/unit/onyx/indexing/test_embedder.py @@ -21,7 +21,13 @@ def mock_embedding_model() -> Generator[Mock, None, None]: yield mock -def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None: +@pytest.mark.parametrize( + "chunk_context, doc_summary", + [("Test chunk context", "Test document summary"), ("", "")], +) +def test_default_indexing_embedder_embed_chunks( + mock_embedding_model: Mock, chunk_context: str, doc_summary: str +) -> None: # Setup embedder = DefaultIndexingEmbedder( model_name="test-model", @@ -63,6 +69,9 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N large_chunk_reference_ids=[], large_chunk_id=None, image_file_name=None, + chunk_context=chunk_context, + doc_summary=doc_summary, + contextual_rag_reserved_tokens=200, ) ] @@ -81,7 +90,7 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N # Verify the embedding model was called correctly mock_embedding_model.return_value.encode.assert_any_call( - texts=["Title: Test chunk"], + texts=[f"Title: {doc_summary}Test chunk{chunk_context}"], text_type=EmbedTextType.PASSAGE, large_chunks_present=False, ) diff --git a/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py b/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py index a46d455a2..10dbef898 100644 --- a/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py +++ b/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py @@ -1,6 +1,8 @@ +from typing import Any from typing import cast from typing import List from unittest.mock import Mock +from unittest.mock import patch import pytest @@ -9,8 +11,12 @@ from onyx.connectors.models import Document from onyx.connectors.models import DocumentSource from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection +from onyx.indexing.chunker import Chunker +from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_pipeline import 
_get_aggregated_chunk_boost_factor +from onyx.indexing.indexing_pipeline import add_contextual_summaries from onyx.indexing.indexing_pipeline import filter_documents +from onyx.indexing.indexing_pipeline import process_image_sections from onyx.indexing.models import ChunkEmbedding from onyx.indexing.models import IndexChunk from onyx.natural_language_processing.search_nlp_models import ( @@ -166,6 +172,9 @@ def create_test_chunk( embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]), title_embedding=None, image_file_name=None, + chunk_context="", + doc_summary="", + contextual_rag_reserved_tokens=200, ) @@ -249,3 +258,76 @@ def test_get_aggregated_boost_factor_individual_failure() -> None: ) assert "Failed to predict content classification for chunk" in str(exc_info.value) + + +@patch("onyx.llm.utils.GEN_AI_MAX_TOKENS", 4096) +@pytest.mark.parametrize("enable_contextual_rag", [True, False]) +def test_contextual_rag( + embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool +) -> None: + short_section_1 = "This is a short section." + long_section = ( + "This is a long section that should be split into multiple chunks. " * 100 + ) + short_section_2 = "This is another short section." + short_section_3 = "This is another short section again." + short_section_4 = "Final short section." + semantic_identifier = "Test Document" + + document = Document( + id="test_doc", + source=DocumentSource.WEB, + semantic_identifier=semantic_identifier, + metadata={"tags": ["tag1", "tag2"]}, + doc_updated_at=None, + sections=[ + TextSection(text=short_section_1, link="link1"), + TextSection(text=short_section_2, link="link2"), + TextSection(text=long_section, link="link3"), + TextSection(text=short_section_3, link="link4"), + TextSection(text=short_section_4, link="link5"), + ], + ) + indexing_documents = process_image_sections([document]) + + mock_llm_invoke_count = 0 + + def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock: + nonlocal mock_llm_invoke_count + mock_llm_invoke_count += 1 + m = Mock() + m.content = f"Test{mock_llm_invoke_count}" + return m + + llm_tokenizer = embedder.embedding_model.tokenizer + + mock_llm = Mock() + mock_llm.invoke = mock_llm_invoke + + chunker = Chunker( + tokenizer=embedder.embedding_model.tokenizer, + enable_multipass=False, + enable_contextual_rag=enable_contextual_rag, + ) + chunks = chunker.chunk(indexing_documents) + + chunks = add_contextual_summaries( + chunks, mock_llm, llm_tokenizer, chunker.chunk_token_limit * 2 + ) + + assert len(chunks) == 5 + assert short_section_1 in chunks[0].content + assert short_section_3 in chunks[-1].content + assert short_section_4 in chunks[-1].content + assert "tag1" in chunks[0].metadata_suffix_keyword + assert "tag2" in chunks[0].metadata_suffix_semantic + + doc_summary = "Test1" if enable_contextual_rag else "" + chunk_context = "" + count = 2 + for chunk in chunks: + if enable_contextual_rag: + chunk_context = f"Test{count}" + count += 1 + assert chunk.doc_summary == doc_summary + assert chunk.chunk_context == chunk_context diff --git a/backend/tests/unit/onyx/llm/test_chat_llm.py b/backend/tests/unit/onyx/llm/test_chat_llm.py index b69b3b7de..4a34db8d4 100644 --- a/backend/tests/unit/onyx/llm/test_chat_llm.py +++ b/backend/tests/unit/onyx/llm/test_chat_llm.py @@ -140,12 +140,12 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None: ], tools=tools, tool_choice=None, + max_tokens=None, stream=False, temperature=0.0, # Default value from GEN_AI_TEMPERATURE timeout=30, 
parallel_tool_calls=False, mock_response=MOCK_LLM_RESPONSE, - max_tokens=None, ) @@ -286,10 +286,10 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No ], tools=tools, tool_choice=None, + max_tokens=None, stream=True, temperature=0.0, # Default value from GEN_AI_TEMPERATURE timeout=30, parallel_tool_calls=False, mock_response=MOCK_LLM_RESPONSE, - max_tokens=None, ) diff --git a/web/src/app/admin/configuration/llm/constants.ts b/web/src/app/admin/configuration/llm/constants.ts index d7e3449b3..0f1324c4a 100644 --- a/web/src/app/admin/configuration/llm/constants.ts +++ b/web/src/app/admin/configuration/llm/constants.ts @@ -1,5 +1,8 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; +export const LLM_CONTEXTUAL_COST_ADMIN_URL = + "/api/admin/llm/provider-contextual-cost"; + export const EMBEDDING_PROVIDERS_ADMIN_URL = "/api/admin/embedding/embedding-provider"; diff --git a/web/src/app/admin/configuration/search/page.tsx b/web/src/app/admin/configuration/search/page.tsx index 126f1a6a4..af948062a 100644 --- a/web/src/app/admin/configuration/search/page.tsx +++ b/web/src/app/admin/configuration/search/page.tsx @@ -143,6 +143,15 @@ function Main() { +
+ Contextual RAG + + {searchSettings.enable_contextual_rag + ? "Enabled" + : "Disabled"} + +
+
Disable Reranking for Streaming diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index f9e3c4161..394884955 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -26,9 +26,18 @@ export enum EmbeddingPrecision { BFLOAT16 = "bfloat16", } +export interface LLMContextualCost { + provider: string; + model_name: string; + cost: number; +} + export interface AdvancedSearchConfiguration { index_name: string | null; multipass_indexing: boolean; + enable_contextual_rag: boolean; + contextual_rag_llm_name: string | null; + contextual_rag_llm_provider: string | null; multilingual_expansion: string[]; disable_rerank_for_streaming: boolean; api_url: string | null; diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index b5003f835..e335e8acc 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -3,7 +3,11 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik"; import * as Yup from "yup"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces"; +import { + AdvancedSearchConfiguration, + EmbeddingPrecision, + LLMContextualCost, +} from "../interfaces"; import { BooleanFormField, Label, @@ -12,6 +16,13 @@ import { } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; import { StringOrNumberOption } from "@/components/Dropdown"; +import useSWR from "swr"; +import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants"; +import { getDisplayNameForModel } from "@/lib/hooks"; +import { errorHandlingFetcher } from "@/lib/fetcher"; + +// Number of tokens to show cost calculation for +const COST_CALCULATION_TOKENS = 1_000_000; interface AdvancedEmbeddingFormPageProps { updateAdvancedEmbeddingDetails: ( @@ -45,14 +56,66 @@ const AdvancedEmbeddingFormPage = forwardRef< }, ref ) => { + // Fetch contextual costs + const { data: contextualCosts, error: costError } = useSWR< + LLMContextualCost[] + >(LLM_CONTEXTUAL_COST_ADMIN_URL, errorHandlingFetcher); + + const llmOptions: StringOrNumberOption[] = React.useMemo( + () => + (contextualCosts || []).map((cost) => { + return { + name: getDisplayNameForModel(cost.model_name), + value: cost.model_name, + }; + }), + [contextualCosts] + ); + + // Helper function to format cost as USD + const formatCost = (cost: number) => { + return new Intl.NumberFormat("en-US", { + style: "currency", + currency: "USD", + }).format(cost); + }; + + // Get cost info for selected model + const getSelectedModelCost = (modelName: string | null) => { + if (!contextualCosts || !modelName) return null; + return contextualCosts.find((cost) => cost.model_name === modelName); + }; + + // Get the current value for the selector based on the parent state + const getCurrentLLMValue = React.useMemo(() => { + if (!advancedEmbeddingDetails.contextual_rag_llm_name) return null; + return advancedEmbeddingDetails.contextual_rag_llm_name; + }, [advancedEmbeddingDetails.contextual_rag_llm_name]); + return (
{ // Call updateAdvancedEmbeddingDetails for each changed field Object.entries(values).forEach(([key, value]) => { - updateAdvancedEmbeddingDetails( - key as keyof AdvancedSearchConfiguration, - value - ); + if (key === "contextual_rag_llm") { + const selectedModel = (contextualCosts || []).find( + (cost) => cost.model_name === value + ); + if (selectedModel) { + updateAdvancedEmbeddingDetails( + "contextual_rag_llm_provider", + selectedModel.provider + ); + updateAdvancedEmbeddingDetails( + "contextual_rag_llm_name", + selectedModel.model_name + ); + } + } else { + updateAdvancedEmbeddingDetails( + key as keyof AdvancedSearchConfiguration, + value + ); + } }); // Run validation and report errors @@ -96,6 +175,18 @@ const AdvancedEmbeddingFormPage = forwardRef< .shape({ multilingual_expansion: Yup.array().of(Yup.string()), multipass_indexing: Yup.boolean(), + enable_contextual_rag: Yup.boolean(), + contextual_rag_llm: Yup.string() + .nullable() + .test( + "required-if-contextual-rag", + "LLM must be selected when Contextual RAG is enabled", + function (value) { + const enableContextualRag = + this.parent.enable_contextual_rag; + return !enableContextualRag || value !== null; + } + ), disable_rerank_for_streaming: Yup.boolean(), num_rerank: Yup.number() .required("Number of results to rerank is required") @@ -190,6 +281,56 @@ const AdvancedEmbeddingFormPage = forwardRef< label="Disable Rerank for Streaming" name="disable_rerank_for_streaming" /> + +
+ + {values.enable_contextual_rag && + values.contextual_rag_llm && + !costError && ( +
+ {contextualCosts ? ( + <> + Estimated cost for processing{" "} + {COST_CALCULATION_TOKENS.toLocaleString()} tokens:{" "} + + {getSelectedModelCost(values.contextual_rag_llm) + ? formatCost( + getSelectedModelCost( + values.contextual_rag_llm + )!.cost + ) + : "Cost information not available"} + + + ) : ( + "Loading cost information..." + )} +
+ )} +
({ index_name: "", multipass_indexing: true, + enable_contextual_rag: false, + contextual_rag_llm_name: null, + contextual_rag_llm_provider: null, multilingual_expansion: [], disable_rerank_for_streaming: false, api_url: null, @@ -152,6 +155,9 @@ export default function EmbeddingForm() { setAdvancedEmbeddingDetails({ index_name: searchSettings.index_name, multipass_indexing: searchSettings.multipass_indexing, + enable_contextual_rag: searchSettings.enable_contextual_rag, + contextual_rag_llm_name: searchSettings.contextual_rag_llm_name, + contextual_rag_llm_provider: searchSettings.contextual_rag_llm_provider, multilingual_expansion: searchSettings.multilingual_expansion, disable_rerank_for_streaming: searchSettings.disable_rerank_for_streaming, @@ -197,7 +203,9 @@ export default function EmbeddingForm() { searchSettings?.embedding_precision != advancedEmbeddingDetails.embedding_precision || searchSettings?.reduced_dimension != - advancedEmbeddingDetails.reduced_dimension; + advancedEmbeddingDetails.reduced_dimension || + searchSettings?.enable_contextual_rag != + advancedEmbeddingDetails.enable_contextual_rag; const updateSearch = useCallback(async () => { if (!selectedProvider) { @@ -384,6 +392,14 @@ export default function EmbeddingForm() { advancedEmbeddingDetails.reduced_dimension && (
  • Reduced dimension modification
  • )} + {(searchSettings?.enable_contextual_rag != + advancedEmbeddingDetails.enable_contextual_rag || + searchSettings?.contextual_rag_llm_name != + advancedEmbeddingDetails.contextual_rag_llm_name || + searchSettings?.contextual_rag_llm_provider != + advancedEmbeddingDetails.contextual_rag_llm_provider) && ( +
  • Contextual RAG modification
  • + )}
diff --git a/web/src/components/admin/connectors/Field.tsx b/web/src/components/admin/connectors/Field.tsx index eeec14446..d87d2f5d3 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -676,6 +676,7 @@ interface SelectorFormFieldProps { includeReset?: boolean; fontSize?: "sm" | "md" | "lg"; small?: boolean; + disabled?: boolean; } export function SelectorFormField({ @@ -691,6 +692,7 @@ export function SelectorFormField({ includeReset = false, fontSize = "md", small = false, + disabled = false, }: SelectorFormFieldProps) { const [field] = useField(name); const { setFieldValue } = useFormikContext(); @@ -742,8 +744,9 @@ export function SelectorFormField({ : setFieldValue(name, selected)) } defaultValue={defaultValue} + disabled={disabled} > - + {currentlySelected?.name || defaultValue || ""}
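Note on the chunker and embedder assertions above: when contextual RAG is enabled, each chunk reserves a fixed token budget for its summaries, and those summaries are folded into the text that gets embedded. The sketch below restates that behaviour in plain Python under stated assumptions; the helper names and the MAX_CONTEXT_TOKENS value are illustrative, not the module's real API (the tests import MAX_CONTEXT_TOKENS from onyx.llm.utils).

# Illustrative sketch only -- mirrors what test_chunk_document and
# test_default_indexing_embedder_embed_chunks assert, not the real onyx code.
MAX_CONTEXT_TOKENS = 512       # assumed value for illustration
USE_DOCUMENT_SUMMARY = True    # mirrors the onyx.configs.app_configs defaults
USE_CHUNK_SUMMARY = True

def reserved_rag_tokens(enable_contextual_rag: bool) -> int:
    # Budget each chunk sets aside for its document/chunk summaries
    # (0 when contextual RAG is disabled), as asserted per chunk in the test.
    if not enable_contextual_rag:
        return 0
    return MAX_CONTEXT_TOKENS * (int(USE_DOCUMENT_SUMMARY) + int(USE_CHUNK_SUMMARY))

def embedding_text(title_prefix: str, doc_summary: str, content: str, chunk_context: str) -> str:
    # Text handed to the embedding model: the summaries wrap the chunk content,
    # e.g. "Title: <doc summary>Test chunk<chunk context>" in the embedder test.
    return f"{title_prefix}{doc_summary}{content}{chunk_context}"

assert reserved_rag_tokens(False) == 0
assert embedding_text("Title: ", "", "Test chunk", "") == "Title: Test chunk"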
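For test_contextual_rag, the ordering of the mocked LLM responses is what the assertions rely on: the first invoke call yields the document summary and every later call yields one chunk's context. A self-contained version of that stub, with the expected outcome spelled out, might look like the sketch below; only add_contextual_summaries is real pipeline code, the stub and the outcome comments simply restate the test's assertions.

from unittest.mock import Mock

def make_counting_llm() -> Mock:
    # LLM stub whose responses are "Test1", "Test2", ... in call order,
    # matching the closure-based counter used in the unit tests.
    call_count = 0

    def invoke(*args, **kwargs) -> Mock:
        nonlocal call_count
        call_count += 1
        reply = Mock()
        reply.content = f"Test{call_count}"
        return reply

    llm = Mock()
    llm.invoke = invoke
    return llm

# Expected result for a five-chunk document run through add_contextual_summaries
# with contextual RAG enabled (per the assertions in test_contextual_rag):
#   - every chunk.doc_summary == "Test1"          (one summary per document)
#   - chunk[i].chunk_context == f"Test{i + 2}"    (one context per chunk)
# With contextual RAG disabled, both fields remain "".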