port contextual rag to new repo

Evan Lohn 2025-03-13 12:27:31 -07:00
parent 463340b8a1
commit 7b2e8bcd66
39 changed files with 723 additions and 37 deletions

View File

@ -0,0 +1,50 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 3934b1bc7b62
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"enable_contextual_rag",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_name",
sa.String(),
nullable=True,
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_provider",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("search_settings", "enable_contextual_rag")
op.drop_column("search_settings", "contextual_rag_llm_name")
op.drop_column("search_settings", "contextual_rag_llm_provider")

View File

@ -470,6 +470,8 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Enable contextual retrieval
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
@ -511,6 +513,17 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
# Average summary embeddings for contextual rag (not yet implemented)
AVERAGE_SUMMARY_EMBEDDINGS = (
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
)
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
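A minimal sketch (hypothetical helper, not part of this commit) of how these flags interact: contextual RAG only does real work when it is enabled and at least one of the two summary types is on, mirroring the assertion added to the Chunker further down.

import os

def contextual_rag_active() -> bool:
    # Hypothetical helper for illustration; mirrors the env parsing above.
    enabled = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
    use_doc = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
    use_chunk = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
    return enabled and (use_doc or use_chunk)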
#####
# Miscellaneous
#####

View File

@ -164,6 +164,9 @@ class DocumentBase(BaseModel):
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
def get_content(self) -> str:
return " ".join([section.text for section in self.sections])
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""

View File

@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
@ -218,6 +221,8 @@ class InferenceChunk(BaseChunk):
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
doc_summary: str
chunk_context: str
# when the doc was last updated
updated_at: datetime | None

View File

@ -196,9 +196,21 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
RETURN_SEPARATOR
)
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
# remove document summary
if chunk.content.startswith(chunk.doc_summary):
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
# remove chunk context
if chunk.content.endswith(chunk.chunk_context):
chunk.content = chunk.content[
: len(chunk.content) - len(chunk.chunk_context)
].rstrip()
return chunk.content
for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)
chunk.content = _remove_contextual_rag(chunk)
return [chunk.to_inference_chunk() for chunk in chunks]
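For illustration, a toy sketch (standalone stand-in class, not the real InferenceChunkUncleaned) of what the cleanup above undoes: at indexing time the stored content is assembled as title prefix + doc summary + chunk content + chunk context + metadata suffix, and cleanup strips the summary off the front and the context off the back.

from dataclasses import dataclass

@dataclass
class FakeChunk:  # stand-in for InferenceChunkUncleaned, illustration only
    content: str
    doc_summary: str
    chunk_context: str

c = FakeChunk(
    content="Summary of the doc. The actual chunk text. Context for this chunk.",
    doc_summary="Summary of the doc.",
    chunk_context="Context for this chunk.",
)
# Applying the same prefix/suffix stripping as _remove_contextual_rag:
cleaned = c.content[len(c.doc_summary):].lstrip()
cleaned = cleaned[: len(cleaned) - len(c.chunk_context)].rstrip()
print(cleaned)  # "The actual chunk text."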

View File

@ -791,6 +791,15 @@ class SearchSettings(Base):
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
# Contextual RAG
enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)
# Contextual RAG LLM
contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
String, nullable=True
)
multilingual_expansion: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), default=[]
)

View File

@ -62,6 +62,9 @@ def create_search_settings(
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
@ -319,6 +322,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)
@ -333,5 +337,6 @@ def get_new_default_embedding_model() -> IndexingSetting:
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)

View File

@ -98,6 +98,12 @@ schema DANSWER_CHUNK_NAME {
field metadata type string {
indexing: summary | attribute
}
field chunk_context type string {
indexing: summary | attribute
}
field doc_summary type string {
indexing: summary | attribute
}
field metadata_suffix type string {
indexing: summary | attribute
}

View File

@ -24,9 +24,11 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@ -126,7 +128,8 @@ def _vespa_hit_to_inference_chunk(
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
content=fields[CONTENT], # Includes extra title prefix and metadata suffix;
# also sometimes context for contextual rag
source_links=source_links_dict or {0: ""},
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
@ -143,6 +146,8 @@ def _vespa_hit_to_inference_chunk(
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
doc_summary=fields.get(DOC_SUMMARY, ""),
chunk_context=fields.get(CHUNK_CONTEXT, ""),
match_highlights=match_highlights,
updated_at=updated_at,
)

View File

@ -187,7 +187,7 @@ class VespaIndex(DocumentIndex):
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None

View File

@ -25,9 +25,11 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@ -174,7 +176,7 @@ def _index_vespa_chunk(
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
# natural language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
@ -189,6 +191,8 @@ def _index_vespa_chunk(
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: metadata_list,
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
CHUNK_CONTEXT: chunk.chunk_context,
DOC_SUMMARY: chunk.doc_summary,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),

View File

@ -71,6 +71,8 @@ LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
DOC_SUMMARY = "doc_summary"
CHUNK_CONTEXT = "chunk_context"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
@ -106,6 +108,8 @@ YQL_BASE = (
f"{LARGE_CHUNK_REFERENCE_IDS}, "
f"{METADATA}, "
f"{METADATA_SUFFIX}, "
f"{DOC_SUMMARY}, "
f"{CHUNK_CONTEXT}, "
f"{CONTENT_SUMMARY} "
f"from {{index_name}} where "
)

View File

@ -1,7 +1,11 @@
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import MINI_CHUNK_SIZE
from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.constants import SECTION_SEPARATOR
@ -13,10 +17,19 @@ from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.llm.interfaces import LLM
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
from onyx.utils.text_processing import shared_precompare_cleanup
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
@ -82,6 +95,8 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
mini_chunk_texts=None,
large_chunk_id=large_chunk_id,
chunk_context="",
doc_summary="",
)
offset = 0
@ -118,8 +133,10 @@ class Chunker:
def __init__(
self,
tokenizer: BaseTokenizer,
llm: LLM | None = None,
enable_multipass: bool = False,
enable_large_chunks: bool = False,
enable_contextual_rag: bool = False,
blurb_size: int = BLURB_SIZE,
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
@ -129,13 +146,33 @@ class Chunker:
) -> None:
from llama_index.text_splitter import SentenceSplitter
if llm is None and enable_contextual_rag:
raise ValueError("LLM must be provided for contextual RAG")
self.include_metadata = include_metadata
self.chunk_token_limit = chunk_token_limit
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.enable_contextual_rag = enable_contextual_rag
if enable_contextual_rag:
assert (
USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY
), "Contextual RAG requires at least one of chunk summary and document summary enabled"
self.tokenizer = tokenizer
self.llm = llm
self.callback = callback
self.max_context = 0
self.prompt_tokens = 0
if self.llm is not None:
self.max_context = get_max_input_tokens(
model_name=self.llm.config.model_name,
model_provider=self.llm.config.model_provider,
)
self.prompt_tokens = len(
self.tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2)
)
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=blurb_size,
@ -204,6 +241,8 @@ class Chunker:
metadata_suffix_semantic: str = "",
metadata_suffix_keyword: str = "",
image_file_name: str | None = None,
doc_summary: str = "",
chunk_context: str = "",
) -> None:
"""
Helper to create a new DocAwareChunk, append it to chunks_list.
@ -221,6 +260,8 @@ class Chunker:
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
doc_summary=doc_summary,
chunk_context=chunk_context,
)
chunks_list.append(new_chunk)
@ -309,7 +350,7 @@ class Chunker:
continue
# CASE 2: Normal text section
section_token_count = len(self.tokenizer.tokenize(section_text))
section_token_count = len(self.tokenizer.encode(section_text))
# If the section is large on its own, split it separately
if section_token_count > content_token_limit:
@ -332,8 +373,7 @@ class Chunker:
# If even the split_text is bigger than strict limit, further split
if (
STRICT_CHUNK_TOKEN_LIMIT
and len(self.tokenizer.tokenize(split_text))
> content_token_limit
and len(self.tokenizer.encode(split_text)) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
@ -363,10 +403,10 @@ class Chunker:
continue
# If we can still fit this section into the current chunk, do so
current_token_count = len(self.tokenizer.tokenize(chunk_text))
current_token_count = len(self.tokenizer.encode(chunk_text))
current_offset = len(shared_precompare_cleanup(chunk_text))
next_section_tokens = (
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
)
if next_section_tokens + current_token_count <= content_token_limit:
@ -414,7 +454,7 @@ class Chunker:
# Title prep
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.tokenize(title_prefix))
title_tokens = len(self.tokenizer.encode(title_prefix))
# Metadata prep
metadata_suffix_semantic = ""
@ -427,15 +467,53 @@ class Chunker:
) = _get_metadata_suffix_for_document_index(
document.metadata, include_separator=True
)
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic))
# If metadata is too large, skip it in the semantic content
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
metadata_suffix_semantic = ""
metadata_tokens = 0
single_chunk_fits = True
doc_token_count = 0
if self.enable_contextual_rag:
doc_content = document.get_content()
tokenized_doc = self.tokenizer.tokenize(doc_content)
doc_token_count = len(tokenized_doc)
# check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG
single_chunk_fits = (
doc_token_count + title_tokens + metadata_tokens
<= self.chunk_token_limit
)
# expand the size of the context used for contextual rag based on whether chunk context and doc summary are used
context_size = 0
if (
self.enable_contextual_rag
and not single_chunk_fits
and not AVERAGE_SUMMARY_EMBEDDINGS
):
if USE_CHUNK_SUMMARY:
context_size += MAX_CONTEXT_TOKENS
if USE_DOCUMENT_SUMMARY:
context_size += MAX_CONTEXT_TOKENS
# Adjust content token limit to accommodate title + metadata
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens - context_size
)
# first check: if there is not enough actual chunk content when including contextual rag,
# then don't do contextual rag
if content_token_limit <= CHUNK_MIN_CONTENT:
context_size = 0 # Don't do contextual RAG
# revert to previous content token limit
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens
)
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
# Not enough space left, so revert to full chunk without the prefix
content_token_limit = self.chunk_token_limit
@ -459,6 +537,70 @@ class Chunker:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
# exit early if contextual rag disabled or not used for this document
if not self.enable_contextual_rag or context_size == 0:
return normal_chunks
if self.llm is None:
raise ValueError("LLM must be available for contextual RAG")
doc_summary = ""
doc_content = ""
doc_tokens = []
if USE_DOCUMENT_SUMMARY:
trunc_doc_tokens = self.max_context - len(
self.tokenizer.encode(DOCUMENT_SUMMARY_PROMPT)
)
doc_tokens = self.tokenizer.encode(doc_content)
doc_content = tokenizer_trim_middle(
doc_tokens, trunc_doc_tokens, self.tokenizer
)
summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
doc_summary = message_to_string(
self.llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
)
for chunk in normal_chunks:
chunk.doc_summary = doc_summary
if USE_CHUNK_SUMMARY:
# Truncate the document content to fit in the model context
trunc_doc_tokens = (
self.max_context - self.prompt_tokens - self.chunk_token_limit
)
# use values computed in above doc summary section if available
doc_tokens = doc_tokens or self.tokenizer.encode(doc_content)
doc_content = doc_content or tokenizer_trim_middle(
doc_tokens, trunc_doc_tokens, self.tokenizer
)
# only compute doc summary if needed
doc_info = (
doc_content
if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION
else (
doc_summary or DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
)
)
context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info)
def assign_context(chunk: DocAwareChunk) -> None:
context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
chunk.chunk_context = (
""
if self.llm is None
else message_to_string(
self.llm.invoke(
context_prompt1 + context_prompt2,
max_tokens=MAX_CONTEXT_TOKENS,
)
)
)
run_functions_tuples_in_parallel(
[(assign_context, (chunk,)) for chunk in normal_chunks]
)
return normal_chunks
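A worked example of the token budgeting above, using assumed numbers (a 512-token chunk limit, 10 title tokens, 20 metadata tokens) with both summary types enabled:

# All values below are assumptions for illustration; MAX_CONTEXT_TOKENS = 100
# matches the constant added in onyx/llm/utils.py in this commit.
chunk_token_limit = 512
title_tokens = 10
metadata_tokens = 20
MAX_CONTEXT_TOKENS = 100
context_size = 2 * MAX_CONTEXT_TOKENS  # USE_CHUNK_SUMMARY and USE_DOCUMENT_SUMMARY both on

content_token_limit = chunk_token_limit - title_tokens - metadata_tokens - context_size
print(content_token_limit)  # 282 tokens left for the chunk text itself
# If this dropped to CHUNK_MIN_CONTENT or below, context_size would be reset to 0
# and the limit recomputed without the contextual RAG reservation.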
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:

View File

@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
if chunk.large_chunk_reference_ids:
large_chunks_present = True
chunk_text = (
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
) or chunk.source_document.get_title_for_document_index()
if not chunk_text:

View File

@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
@ -58,6 +59,7 @@ from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@ -794,12 +796,24 @@ def build_indexing_pipeline(
search_settings = get_current_search_settings(db_session)
multipass_config = get_multipass_config(search_settings)
enable_contextual_rag = (
search_settings.enable_contextual_rag
if search_settings
else ENABLE_CONTEXTUAL_RAG
)
llm = get_llm_for_contextual_rag(
search_settings.contextual_rag_llm_name,
search_settings.contextual_rag_llm_provider,
)
chunker = chunker or Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass_config.multipass_indexing,
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
llm=llm,
)
return partial(

View File

@ -49,6 +49,11 @@ class DocAwareChunk(BaseChunk):
metadata_suffix_semantic: str
metadata_suffix_keyword: str
# This is the summary for the document generated for contextual RAG
doc_summary: str
# This is the context for this chunk generated for contextual RAG
chunk_context: str
mini_chunk_texts: list[str] | None
large_chunk_id: int | None
@ -154,6 +159,9 @@ class IndexingSetting(EmbeddingModelDetail):
reduced_dimension: int | None = None
background_reindex_enabled: bool = True
enable_contextual_rag: bool
contextual_rag_llm_name: str | None = None
contextual_rag_llm_provider: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@ -178,6 +186,7 @@ class IndexingSetting(EmbeddingModelDetail):
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
)

View File

@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM):
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
max_tokens=max_tokens,
# streaming choice
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@ -531,6 +531,7 @@ class DefaultMultiLLM(LLM):
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
return

View File

@ -17,6 +17,7 @@ from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProvider
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@ -154,6 +155,44 @@ def get_default_llm_with_vision(
return None
def llm_from_provider(
model_name: str,
llm_provider: LLMProvider,
timeout: int | None = None,
temperature: float | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model_name,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
def get_llm_for_contextual_rag(
model_name: str | None, model_provider: str | None
) -> LLM | None:
if not model_name or not model_provider:
return None
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, model_provider)
if not llm_provider:
raise ValueError("No LLM provider with name {} found".format(model_provider))
return llm_from_provider(
model_name=model_name,
llm_provider=llm_provider,
)
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,
@ -179,14 +218,9 @@ def get_default_llms(
raise ValueError("No fast default model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
return llm_from_provider(
model_name=model,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,

View File

@ -29,13 +29,19 @@ from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
@ -44,6 +50,10 @@ from shared_configs.configs import LOG_LEVEL
logger = setup_logger()
MAX_CONTEXT_TOKENS = 100
ONE_MILLION = 1_000_000
CHUNKS_PER_DOC_ESTIMATE = 5
def litellm_exception_to_error_msg(
e: Exception,
@ -416,6 +426,72 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N
return None
def get_llm_contextual_cost(
llm: LLM,
) -> float:
"""
Approximate the cost of using the given LLM for indexing with Contextual RAG.
We use a precomputed estimate for the number of tokens in the contextualizing prompts,
and we assume that every chunk is maximized in terms of content and context.
We also assume that every document is maximized in terms of content, as currently if
a document is longer than a certain length, its summary is used instead of the full content.
We expect the first assumption to overestimate more than the second one
underestimates, so this should be a fairly conservative price estimate. Also,
this does not account for documents that fit within a single chunk, which skip
contextualization entirely.
"""
# calculate input costs
num_tokens = ONE_MILLION
num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE
# We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long
# on average.
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION
num_input_tokens = 0
num_output_tokens = 0
if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY:
return 0
if USE_CHUNK_SUMMARY:
# Each per-chunk prompt includes the contextual RAG prompt tokens, the full
# document content (or the doc summary, so this is an overestimate), and the
# chunk content itself
num_input_tokens += num_input_chunks * (
CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
)
# in aggregate, each chunk content is used as a prompt input once
# so the full input size is covered
num_input_tokens += num_tokens
# A single MAX_CONTEXT_TOKENS worth of output is generated per chunk
num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS
# going over each doc once means all the tokens, plus the prompt tokens for
# the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false,
# since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true.
# So, we include this unconditionally to overestimate.
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
num_output_tokens += num_docs * MAX_CONTEXT_TOKENS
usd_per_prompt, usd_per_completion = litellm.cost_per_token(
model=llm.config.model_name,
prompt_tokens=num_input_tokens,
completion_tokens=num_output_tokens,
)
# Costs are in USD per one million indexed tokens
return usd_per_prompt + usd_per_completion
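A back-of-envelope run of the estimate above with both summary flags on; DOC_EMBEDDING_CONTEXT_SIZE is assumed to be 512 here, while the other constants come from this commit:

ONE_MILLION = 1_000_000
DOC_EMBEDDING_CONTEXT_SIZE = 512          # assumption for this sketch
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
MAX_CONTEXT_TOKENS = 100

num_chunks = ONE_MILLION // DOC_EMBEDDING_CONTEXT_SIZE      # 1953 chunks
num_docs = ONE_MILLION // MAX_TOKENS_FOR_FULL_INCLUSION     # 244 documents

input_tokens = (
    num_chunks * (CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION)
    + ONE_MILLION                                            # every chunk's own content
    + ONE_MILLION + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE  # document summary pass
)
output_tokens = (num_chunks + num_docs) * MAX_CONTEXT_TOKENS
print(input_tokens, output_tokens)  # ~10.1M input tokens, ~0.22M output tokens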
def get_llm_max_tokens(
model_map: dict,
model_name: str,

View File

@ -11,6 +11,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -159,9 +161,26 @@ def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
if len(tokens) <= desired_length:
return content
return tokenizer.decode(tokens[:desired_length])
def tokenizer_trim_middle(
tokens: list[int], desired_length: int, tokenizer: BaseTokenizer
) -> str:
if len(tokens) <= desired_length:
return tokenizer.decode(tokens)
sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
sep_tokens = tokenizer.encode(sep_str)
slice_size = (desired_length - len(sep_tokens)) // 2
assert slice_size > 0, "Slice size is not positive, desired length is too short"
return (
tokenizer.decode(tokens[:slice_size])
+ sep_str
+ tokenizer.decode(tokens[-slice_size:])
)
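A usage sketch for the new helper; the tokenizer instance and the sizes here are illustrative assumptions:

# Assumes `tokenizer` is an existing BaseTokenizer; numbers are illustrative.
very_long_document_text = "word " * 5000        # hypothetical long document
tokens = tokenizer.encode(very_long_document_text)
trimmed = tokenizer_trim_middle(tokens, 1000, tokenizer)
# Keeps roughly the first and last 500 tokens and splices in a
# "\n... {n} tokens removed...\n" marker in place of the dropped middle.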
def tokenizer_trim_chunks(

View File

@ -220,3 +220,29 @@ Chat History:
Based on the above, what is a short name to convey the topic of the conversation?
""".strip()
# NOTE: the prompt separation is partially done for efficiency; doing this in one
# prompt with sequential format() calls causes a backend error when the document
# contains any {}, since python expects every {} to be filled by format() arguments
CONTEXTUAL_RAG_PROMPT1 = """<document>
{document}
</document>
Here is the chunk we want to situate within the whole document"""
CONTEXTUAL_RAG_PROMPT2 = """<chunk>
{chunk}
</chunk>
Please give a short succinct context to situate this chunk within the overall document
for the purposes of improving search retrieval of the chunk. Answer only with the succinct
context and nothing else. """
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45
DOCUMENT_SUMMARY_PROMPT = """<document>
{document}
</document>
Please give a short succinct summary of the entire document. Answer only with the succinct
summary and nothing else. """
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
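To make the NOTE above concrete, a small sketch (the document text is made up): formatting the two halves separately means braces inside the document never hit a second format() pass.

doc = 'settings = {"retries": 3}'   # hypothetical document containing literal braces

# Single-template, sequential formatting breaks: after the first .format() injects
# the document, a later .format(chunk=...) would try to treat {"retries": 3} as a
# replacement field and raise.
#
# Formatting each half exactly once, as the chunker does, avoids that:
part1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc)
part2 = CONTEXTUAL_RAG_PROMPT2.format(chunk="some chunk text")
prompt = part1 + part2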

View File

@ -87,6 +87,8 @@ def _create_indexable_chunks(
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_reference_ids=[],
doc_summary="",
chunk_context="",
embeddings=ChunkEmbedding(
full_embedding=preprocessed_doc["content_embedding"],
mini_chunk_embeddings=[],

View File

@ -21,10 +21,12 @@ from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest
@ -259,3 +261,38 @@ def list_llm_provider_basics(
db_session, user
)
]
@admin_router.get("/provider-contextual-cost")
def get_provider_contextual_cost(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMCost]:
"""
Get the cost of re-indexing all documents for contextual retrieval.
See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token
This includes:
- the cost of invoking the LLM on each chunk-document pair to get
  - the doc_summary
  - the chunk_context
- the per-token cost of the LLM used to generate the doc_summary and chunk_context
"""
providers = fetch_existing_llm_providers(db_session)
costs = []
for provider in providers:
for model_name in provider.display_model_names or provider.model_names or []:
llm = get_llm(
provider=provider.provider,
model=model_name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
)
cost = get_llm_contextual_cost(llm)
costs.append(
LLMCost(provider=provider.name, model_name=model_name, cost=cost)
)
return costs
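A hypothetical way to call the new endpoint from a script; the base URL and authentication are deployment-specific assumptions, not part of the commit:

import requests

BASE_URL = "http://localhost:8080"  # assumption; admin authentication omitted
resp = requests.get(f"{BASE_URL}/api/admin/llm/provider-contextual-cost")
resp.raise_for_status()
for entry in resp.json():           # one LLMCost entry per provider/model pair
    print(entry["provider"], entry["model_name"], entry["cost"])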

View File

@ -115,3 +115,9 @@ class VisionProviderResponse(FullLLMProvider):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]
class LLMCost(BaseModel):
provider: str
model_name: str
cost: float

View File

@ -63,6 +63,8 @@ def generate_dummy_chunk(
title_prefix=f"Title prefix for doc {doc_id}",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
embeddings=ChunkEmbedding(
full_embedding=generate_random_embedding(embedding_dim),

View File

@ -99,6 +99,7 @@ PRESERVED_SEARCH_FIELDS = [
"api_url",
"index_name",
"multipass_indexing",
"enable_contextual_rag",
"model_dim",
"normalize",
"passage_prefix",

View File

@ -32,6 +32,8 @@ def create_test_chunk(
match_highlights=[],
updated_at=datetime.now(),
image_file_name=None,
doc_summary="",
chunk_context="",
)

View File

@ -81,6 +81,8 @@ def mock_inference_sections() -> list[InferenceSection]:
source_links={0: "https://example.com/doc1"},
match_highlights=[],
image_file_name=None,
doc_summary="",
chunk_context="",
),
chunks=MagicMock(),
),
@ -104,6 +106,8 @@ def mock_inference_sections() -> list[InferenceSection]:
source_links={0: "https://example.com/doc2"},
match_highlights=[],
image_file_name=None,
doc_summary="",
chunk_context="",
),
chunks=MagicMock(),
),

View File

@ -151,6 +151,8 @@ def test_fuzzy_match_quotes_to_docs() -> None:
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
@ -170,6 +172,8 @@ def test_fuzzy_match_quotes_to_docs() -> None:
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)
test_quotes = [

View File

@ -38,6 +38,8 @@ def create_inference_chunk(
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)

View File

@ -1,3 +1,6 @@
from typing import Any
from unittest.mock import Mock
import pytest
from onyx.configs.constants import DocumentSource
@ -19,7 +22,10 @@ def embedder() -> DefaultIndexingEmbedder:
)
def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
def test_chunk_document(
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
@ -45,9 +51,23 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
)
indexing_documents = process_image_sections([document])
mock_llm_invoke_count = 0
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
nonlocal mock_llm_invoke_count
mock_llm_invoke_count += 1
m = Mock()
m.content = f"Test{mock_llm_invoke_count}"
return m
mock_llm = Mock()
mock_llm.invoke = mock_llm_invoke
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
enable_contextual_rag=enable_contextual_rag,
llm=mock_llm,
)
chunks = chunker.chunk(indexing_documents)
@ -58,6 +78,16 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
doc_summary = "Test1" if enable_contextual_rag else ""
chunk_context = ""
count = 2
for chunk in chunks:
if enable_contextual_rag:
chunk_context = f"Test{count}"
count += 1
assert chunk.doc_summary == doc_summary
assert chunk.chunk_context == chunk_context
def test_chunker_heartbeat(
embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
@ -78,6 +108,7 @@ def test_chunker_heartbeat(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
callback=mock_heartbeat,
enable_contextual_rag=False,
)
chunks = chunker.chunk(indexing_documents)

View File

@ -21,7 +21,13 @@ def mock_embedding_model() -> Generator[Mock, None, None]:
yield mock
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
@pytest.mark.parametrize(
"chunk_context, doc_summary",
[("Test chunk context", "Test document summary"), ("", "")],
)
def test_default_indexing_embedder_embed_chunks(
mock_embedding_model: Mock, chunk_context: str, doc_summary: str
) -> None:
# Setup
embedder = DefaultIndexingEmbedder(
model_name="test-model",
@ -63,6 +69,8 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N
large_chunk_reference_ids=[],
large_chunk_id=None,
image_file_name=None,
chunk_context=chunk_context,
doc_summary=doc_summary,
)
]
@ -81,7 +89,7 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N
# Verify the embedding model was called correctly
mock_embedding_model.return_value.encode.assert_any_call(
texts=["Title: Test chunk"],
texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
text_type=EmbedTextType.PASSAGE,
large_chunks_present=False,
)

View File

@ -140,12 +140,12 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
],
tools=tools,
tool_choice=None,
max_tokens=None,
stream=False,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)
@ -286,10 +286,10 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
],
tools=tools,
tool_choice=None,
max_tokens=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)

View File

@ -1,5 +1,8 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
"/api/admin/llm/provider-contextual-cost";
export const EMBEDDING_PROVIDERS_ADMIN_URL =
"/api/admin/embedding/embedding-provider";

View File

@ -143,6 +143,15 @@ function Main() {
</Text>
</div>
<div>
<Text className="font-semibold">Contextual RAG</Text>
<Text className="text-text-700">
{searchSettings.enable_contextual_rag
? "Enabled"
: "Disabled"}
</Text>
</div>
<div>
<Text className="font-semibold">
Disable Reranking for Streaming

View File

@ -26,9 +26,18 @@ export enum EmbeddingPrecision {
BFLOAT16 = "bfloat16",
}
export interface LLMContextualCost {
provider: string;
model_name: string;
cost: number;
}
export interface AdvancedSearchConfiguration {
index_name: string | null;
multipass_indexing: boolean;
enable_contextual_rag: boolean;
contextual_rag_llm_name: string | null;
contextual_rag_llm_provider: string | null;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
api_url: string | null;

View File

@ -3,7 +3,11 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik";
import * as Yup from "yup";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces";
import {
AdvancedSearchConfiguration,
EmbeddingPrecision,
LLMContextualCost,
} from "../interfaces";
import {
BooleanFormField,
Label,
@ -12,6 +16,13 @@ import {
} from "@/components/admin/connectors/Field";
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
import { StringOrNumberOption } from "@/components/Dropdown";
import useSWR from "swr";
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants";
import { getDisplayNameForModel } from "@/lib/hooks";
import { errorHandlingFetcher } from "@/lib/fetcher";
// Number of tokens to show cost calculation for
const COST_CALCULATION_TOKENS = 1_000_000;
interface AdvancedEmbeddingFormPageProps {
updateAdvancedEmbeddingDetails: (
@ -45,14 +56,55 @@ const AdvancedEmbeddingFormPage = forwardRef<
},
ref
) => {
// Fetch contextual costs
const { data: contextualCosts, error: costError } = useSWR<
LLMContextualCost[]
>(LLM_CONTEXTUAL_COST_ADMIN_URL, errorHandlingFetcher);
const llmOptions: StringOrNumberOption[] = React.useMemo(
() =>
(contextualCosts || []).map((cost) => {
return {
name: getDisplayNameForModel(cost.model_name),
value: cost.model_name,
};
}),
[contextualCosts]
);
// Helper function to format cost as USD
const formatCost = (cost: number) => {
return new Intl.NumberFormat("en-US", {
style: "currency",
currency: "USD",
}).format(cost);
};
// Get cost info for selected model
const getSelectedModelCost = (modelName: string | null) => {
if (!contextualCosts || !modelName) return null;
return contextualCosts.find((cost) => cost.model_name === modelName);
};
// Get the current value for the selector based on the parent state
const getCurrentLLMValue = React.useMemo(() => {
if (!advancedEmbeddingDetails.contextual_rag_llm_name) return null;
return advancedEmbeddingDetails.contextual_rag_llm_name;
}, [advancedEmbeddingDetails.contextual_rag_llm_name]);
return (
<div className="py-4 rounded-lg max-w-4xl px-4 mx-auto">
<Formik
innerRef={ref}
initialValues={advancedEmbeddingDetails}
initialValues={{
...advancedEmbeddingDetails,
contextual_rag_llm: getCurrentLLMValue,
}}
validationSchema={Yup.object().shape({
multilingual_expansion: Yup.array().of(Yup.string()),
multipass_indexing: Yup.boolean(),
enable_contextual_rag: Yup.boolean(),
contextual_rag_llm: Yup.string().nullable(),
disable_rerank_for_streaming: Yup.boolean(),
num_rerank: Yup.number()
.required("Number of results to rerank is required")
@ -79,10 +131,26 @@ const AdvancedEmbeddingFormPage = forwardRef<
validate={(values) => {
// Call updateAdvancedEmbeddingDetails for each changed field
Object.entries(values).forEach(([key, value]) => {
updateAdvancedEmbeddingDetails(
key as keyof AdvancedSearchConfiguration,
value
);
if (key === "contextual_rag_llm") {
const selectedModel = (contextualCosts || []).find(
(cost) => cost.model_name === value
);
if (selectedModel) {
updateAdvancedEmbeddingDetails(
"contextual_rag_llm_provider",
selectedModel.provider
);
updateAdvancedEmbeddingDetails(
"contextual_rag_llm_name",
selectedModel.model_name
);
}
} else {
updateAdvancedEmbeddingDetails(
key as keyof AdvancedSearchConfiguration,
value
);
}
});
// Run validation and report errors
@ -190,6 +258,56 @@ const AdvancedEmbeddingFormPage = forwardRef<
label="Disable Rerank for Streaming"
name="disable_rerank_for_streaming"
/>
<BooleanFormField
subtext="Enable contextual RAG for all chunk sizes."
optional
label="Contextual RAG"
name="enable_contextual_rag"
/>
<div>
<SelectorFormField
name="contextual_rag_llm"
label="Contextual RAG LLM"
subtext={
costError
? "Error loading LLM models. Please try again later."
: !contextualCosts
? "Loading available LLM models..."
: values.enable_contextual_rag
? "Select the LLM model to use for contextual RAG processing."
: "Enable Contextual RAG above to select an LLM model."
}
options={llmOptions}
disabled={
!values.enable_contextual_rag ||
!contextualCosts ||
!!costError
}
/>
{values.enable_contextual_rag &&
values.contextual_rag_llm &&
!costError && (
<div className="mt-2 text-sm text-text-600">
{contextualCosts ? (
<>
Estimated cost for processing{" "}
{COST_CALCULATION_TOKENS.toLocaleString()} tokens:{" "}
<span className="font-medium">
{getSelectedModelCost(values.contextual_rag_llm)
? formatCost(
getSelectedModelCost(
values.contextual_rag_llm
)!.cost
)
: "Cost information not available"}
</span>
</>
) : (
"Loading cost information..."
)}
</div>
)}
</div>
<NumberInput
description="Number of results to rerank"
optional={false}

View File

@ -64,6 +64,9 @@ export default function EmbeddingForm() {
useState<AdvancedSearchConfiguration>({
index_name: "",
multipass_indexing: true,
enable_contextual_rag: false,
contextual_rag_llm_name: null,
contextual_rag_llm_provider: null,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
api_url: null,
@ -152,6 +155,9 @@ export default function EmbeddingForm() {
setAdvancedEmbeddingDetails({
index_name: searchSettings.index_name,
multipass_indexing: searchSettings.multipass_indexing,
enable_contextual_rag: searchSettings.enable_contextual_rag,
contextual_rag_llm_name: searchSettings.contextual_rag_llm_name,
contextual_rag_llm_provider: searchSettings.contextual_rag_llm_provider,
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
@ -197,7 +203,9 @@ export default function EmbeddingForm() {
searchSettings?.embedding_precision !=
advancedEmbeddingDetails.embedding_precision ||
searchSettings?.reduced_dimension !=
advancedEmbeddingDetails.reduced_dimension;
advancedEmbeddingDetails.reduced_dimension ||
searchSettings?.enable_contextual_rag !=
advancedEmbeddingDetails.enable_contextual_rag;
const updateSearch = useCallback(async () => {
if (!selectedProvider) {

View File

@ -676,6 +676,7 @@ interface SelectorFormFieldProps {
includeReset?: boolean;
fontSize?: "sm" | "md" | "lg";
small?: boolean;
disabled?: boolean;
}
export function SelectorFormField({
@ -691,6 +692,7 @@ export function SelectorFormField({
includeReset = false,
fontSize = "md",
small = false,
disabled = false,
}: SelectorFormFieldProps) {
const [field] = useField<string>(name);
const { setFieldValue } = useFormikContext();
@ -742,8 +744,9 @@ export function SelectorFormField({
: setFieldValue(name, selected))
}
defaultValue={defaultValue}
disabled={disabled}
>
<SelectTrigger className={sizeClass.input}>
<SelectTrigger className={sizeClass.input} disabled={disabled}>
<SelectValue placeholder="Select...">
{currentlySelected?.name || defaultValue || ""}
</SelectValue>