Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-06-28 17:01:10 +02:00
Expanded basic search (#4517)

* initial working version
* ranking profile
* modification for keyword/instruction retrieval
* mypy fixes
* EL comments
* added env var (True for now)
* flipped default to False
* mypy & final EL/CW comments + import issue
This commit is contained in:
parent
e3aab8e85e
commit
2683207a24
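
At a high level, this change expands Basic Search: the user query is rewritten twice by an LLM (once keyword-style, once semantic-style), the original and rewritten queries are retrieved against the index in parallel, and the merged results are deduped. A minimal, runnable sketch of that flow; llm_rewrite and hybrid_retrieve below are hypothetical stand-ins, not Onyx APIs:

from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-ins for the LLM rewrite and index retrieval calls.
def llm_rewrite(query: str, style: str) -> str:
    return f"[{style}] {query}"

def hybrid_retrieve(query: str) -> list[tuple[str, float]]:
    return [(f"chunk-for::{query}", 1.0)]

def expanded_basic_search(query: str) -> list[tuple[str, float]]:
    # Fan out: original query plus keyword- and semantic-style rewrites.
    queries = [query, llm_rewrite(query, "keyword"), llm_rewrite(query, "semantic")]
    with ThreadPoolExecutor() as pool:
        result_lists = list(pool.map(hybrid_retrieve, queries))
    # Fan in: merge and dedupe, keeping the highest score per chunk.
    best: dict[str, float] = {}
    for results in result_lists:
        for doc_id, score in results:
            best[doc_id] = max(score, best.get(doc_id, 0.0))
    return sorted(best.items(), key=lambda kv: kv[1], reverse=True)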
@@ -1,6 +1,8 @@
 from typing import cast
 from uuid import uuid4
 
+from langchain_core.messages import AIMessage
+from langchain_core.messages import HumanMessage
 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig
 from langgraph.types import StreamWriter
@@ -10,13 +12,21 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.orchestration.states import ToolChoice
 from onyx.agents.agent_search.orchestration.states import ToolChoiceState
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
 )
+from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
 from onyx.context.search.preprocessing.preprocessing import query_analysis
 from onyx.context.search.retrieval.search_runner import get_query_embedding
+from onyx.llm.factory import get_default_llms
+from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
+from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
+from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
+from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
+from onyx.tools.models import QueryExpansions
 from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -30,6 +40,49 @@ from shared_configs.model_server_models import Embedding
 logger = setup_logger()
 
 
+def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
+    # TODO: Add trimming logic
+    history_segments = []
+    for msg in prompt_builder.message_history:
+        if isinstance(msg, HumanMessage):
+            role = "User"
+        elif isinstance(msg, AIMessage):
+            role = "Assistant"
+        else:
+            continue
+        history_segments.append(f"{role}:\n {msg.content}\n\n")
+    return "\n".join(history_segments)
+
+
+def _expand_query(
+    query: str,
+    expansion_type: QueryExpansionType,
+    prompt_builder: AnswerPromptBuilder,
+) -> str:
+
+    history_str = _create_history_str(prompt_builder)
+
+    if history_str:
+        if expansion_type == QueryExpansionType.KEYWORD:
+            base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
+        else:
+            base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
+        expansion_prompt = base_prompt.format(question=query, history=history_str)
+    else:
+        if expansion_type == QueryExpansionType.KEYWORD:
+            base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
+        else:
+            base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
+        expansion_prompt = base_prompt.format(question=query)
+
+    msg = HumanMessage(content=expansion_prompt)
+    primary_llm, _ = get_default_llms()
+    response = primary_llm.invoke([msg])
+    rephrased_query: str = cast(str, response.content)
+
+    return rephrased_query
+
+
 # TODO: break this out into an implementation function
 # and a function that handles extracting the necessary fields
 # from the state and config
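
_expand_query is a blocking LLM round trip, so choose_tool (next hunks) launches it on background threads. Called directly it would look roughly like this; prompt_builder is assumed to already be in hand, e.g. from the graph state:

# Hypothetical direct invocation; the committed code wraps these calls
# in run_in_background instead (see the following hunks).
keyword_query = _expand_query(
    "latest refund policy changes",
    QueryExpansionType.KEYWORD,
    prompt_builder,
)
semantic_query = _expand_query(
    "latest refund policy changes",
    QueryExpansionType.SEMANTIC,
    prompt_builder,
)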
@@ -52,7 +105,16 @@ def choose_tool(
 
     embedding_thread: TimeoutThread[Embedding] | None = None
     keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
+    expanded_keyword_thread: TimeoutThread[str] | None = None
+    expanded_semantic_thread: TimeoutThread[str] | None = None
     override_kwargs: SearchToolOverrideKwargs | None = None
+
+    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
+    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
+
+    llm = agent_config.tooling.primary_llm
+    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
+
     if (
         not agent_config.behavior.use_agentic_search
         and agent_config.tooling.search_tool is not None
@@ -72,11 +134,20 @@ def choose_tool(
             agent_config.inputs.search_request.query,
         )
 
-    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
-    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
-
-    llm = agent_config.tooling.primary_llm
-    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
+        if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
+
+            expanded_keyword_thread = run_in_background(
+                _expand_query,
+                agent_config.inputs.search_request.query,
+                QueryExpansionType.KEYWORD,
+                prompt_builder,
+            )
+            expanded_semantic_thread = run_in_background(
+                _expand_query,
+                agent_config.inputs.search_request.query,
+                QueryExpansionType.SEMANTIC,
+                prompt_builder,
+            )
 
     structured_response_format = agent_config.inputs.structured_response_format
     tools = [
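
run_in_background and wait_on_background come from onyx.utils.threadpool_concurrency and are not shown in this diff. A minimal stdlib sketch of the pattern they implement (the real helpers return TimeoutThread handles; this is an illustration, not the actual implementation):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, TypeVar

T = TypeVar("T")
_executor = ThreadPoolExecutor(max_workers=4)

def run_in_background(fn: Callable[..., T], *args: Any) -> Future[T]:
    # Start the call immediately and hand back a handle.
    return _executor.submit(fn, *args)

def wait_on_background(handle: Future[T]) -> T:
    # Block until the call finishes; re-raises any exception it hit.
    return handle.result()

# Both LLM rewrites can then run concurrently with the rest of choose_tool:
handle = run_in_background(str.upper, "keyword query")
print(wait_on_background(handle))  # KEYWORD QUERY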
@@ -209,6 +280,19 @@ def choose_tool(
         override_kwargs.precomputed_is_keyword = is_keyword
         override_kwargs.precomputed_keywords = keywords
 
+    if (
+        selected_tool.name == SearchTool._NAME
+        and expanded_keyword_thread
+        and expanded_semantic_thread
+    ):
+        keyword_expansion = wait_on_background(expanded_keyword_thread)
+        semantic_expansion = wait_on_background(expanded_semantic_thread)
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.expanded_queries = QueryExpansions(
+            keywords_expansions=[keyword_expansion],
+            semantic_expansions=[semantic_expansion],
+        )
+
     return ToolChoiceUpdate(
         tool_choice=ToolChoice(
             tool=selected_tool,
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Any
 
 from pydantic import BaseModel

@@ -153,3 +154,8 @@ class AnswerGenerationDocuments(BaseModel):
 
 
 BaseMessage_Content = str | list[str | dict[str, Any]]
+
+
+class QueryExpansionType(Enum):
+    KEYWORD = "keyword"
+    SEMANTIC = "semantic"
@@ -96,3 +96,9 @@ BING_API_KEY = os.environ.get("BING_API_KEY") or None
 ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
 
 VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
+
+# Whether or not to use the semantic & keyword search expansions for Basic Search
+USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
+    os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
+    == "true"
+)
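
Only the literal string "true" (any casing) enables the flag; a quick check of the same parsing rule:

# Mirrors the os.environ parsing above; "1", "yes", or an unset
# variable all leave the feature disabled.
def parse_flag(raw: str | None) -> bool:
    return (raw if raw is not None else "false").lower() == "true"

assert parse_flag("True") is True
assert parse_flag("1") is False
assert parse_flag(None) is False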
@@ -18,11 +18,17 @@ from onyx.indexing.models import IndexingSetting
 from shared_configs.enums import RerankerProvider
 from shared_configs.model_server_models import Embedding
 
 MAX_METRICS_CONTENT = (
     200  # Just need enough characters to identify where in the doc the chunk is
 )
 
 
+class QueryExpansions(BaseModel):
+    keywords_expansions: list[str] | None = None
+    semantic_expansions: list[str] | None = None
+
+
 class RerankingDetails(BaseModel):
     # If model is None (or num_rerank is 0), then reranking is turned off
     rerank_model_name: str | None

@@ -139,6 +145,8 @@ class ChunkContext(BaseModel):
 class SearchRequest(ChunkContext):
     query: str
 
+    expanded_queries: QueryExpansions | None = None
+
     search_type: SearchType = SearchType.SEMANTIC
 
     human_selected_filters: BaseFilters | None = None

@@ -187,6 +195,8 @@ class SearchQuery(ChunkContext):
 
     precomputed_query_embedding: Embedding | None = None
 
+    expanded_queries: QueryExpansions | None = None
+
 
 class RetrievalDetails(ChunkContext):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
@@ -20,7 +20,7 @@ from onyx.context.search.models import SearchRequest
 from onyx.context.search.preprocessing.access_filters import (
     build_access_filters_for_user,
 )
-from onyx.context.search.retrieval.search_runner import (
+from onyx.context.search.utils import (
     remove_stop_words_and_punctuation,
 )
 from onyx.db.models import User

@@ -36,7 +36,6 @@ from onyx.utils.timing import log_function_time
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.contextvars import get_current_tenant_id
 
-
 logger = setup_logger()
 
 

@@ -264,4 +263,5 @@ def retrieval_preprocessing(
         chunks_below=chunks_below,
         full_doc=search_request.full_doc,
         precomputed_query_embedding=search_request.precomputed_query_embedding,
+        expanded_queries=search_request.expanded_queries,
     )
@@ -2,10 +2,10 @@ import string
 from collections.abc import Callable
 
 import nltk  # type:ignore
-from nltk.corpus import stopwords  # type:ignore
-from nltk.tokenize import word_tokenize  # type:ignore
 from sqlalchemy.orm import Session
 
+from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
+from onyx.context.search.enums import SearchType
 from onyx.context.search.models import ChunkMetric
 from onyx.context.search.models import IndexFilters
 from onyx.context.search.models import InferenceChunk

@@ -15,6 +15,8 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
 from onyx.context.search.models import RetrievalMetricsContainer
 from onyx.context.search.models import SearchQuery
 from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
+from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
+from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
 from onyx.context.search.utils import inference_section_from_chunks
 from onyx.db.search_settings import get_current_search_settings
 from onyx.db.search_settings import get_multilingual_expansion

@@ -27,6 +29,9 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
 from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
 from onyx.utils.logger import setup_logger
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+from onyx.utils.threadpool_concurrency import run_in_background
+from onyx.utils.threadpool_concurrency import TimeoutThread
+from onyx.utils.threadpool_concurrency import wait_on_background
 from onyx.utils.timing import log_function_time
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT

@@ -36,6 +41,23 @@ from shared_configs.model_server_models import Embedding
 logger = setup_logger()
 
 
+def _dedupe_chunks(
+    chunks: list[InferenceChunkUncleaned],
+) -> list[InferenceChunkUncleaned]:
+    used_chunks: dict[tuple[str, int], InferenceChunkUncleaned] = {}
+    for chunk in chunks:
+        key = (chunk.document_id, chunk.chunk_id)
+        if key not in used_chunks:
+            used_chunks[key] = chunk
+        else:
+            stored_chunk_score = used_chunks[key].score or 0
+            this_chunk_score = chunk.score or 0
+            if stored_chunk_score < this_chunk_score:
+                used_chunks[key] = chunk
+
+    return list(used_chunks.values())
+
+
 def download_nltk_data() -> None:
     resources = {
         "stopwords": "corpora/stopwords",

@@ -69,22 +91,6 @@ def lemmatize_text(keywords: list[str]) -> list[str]:
     # return keywords
 
 
-def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
-    try:
-        # Re-tokenize using the NLTK tokenizer for better matching
-        query = " ".join(keywords)
-        stop_words = set(stopwords.words("english"))
-        word_tokens = word_tokenize(query)
-        text_trimmed = [
-            word
-            for word in word_tokens
-            if (word.casefold() not in stop_words and word not in string.punctuation)
-        ]
-        return text_trimmed or word_tokens
-    except Exception:
-        return keywords
-
-
 def combine_retrieval_results(
     chunk_sets: list[list[InferenceChunk]],
 ) -> list[InferenceChunk]:

@@ -123,6 +129,20 @@ def get_query_embedding(query: str, db_session: Session) -> Embedding:
     return query_embedding
 
 
+def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
+    search_settings = get_current_search_settings(db_session)
+
+    model = EmbeddingModel.from_db_model(
+        search_settings=search_settings,
+        # The below are globally set, this flow always uses the indexing one
+        server_host=MODEL_SERVER_HOST,
+        server_port=MODEL_SERVER_PORT,
+    )
+
+    query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
+    return query_embedding
+
+
 @log_function_time(print_only=True)
 def doc_index_retrieval(
     query: SearchQuery,
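
A toy run of the _dedupe_chunks rule above, with a hypothetical minimal chunk type standing in for InferenceChunkUncleaned: duplicates are keyed on (document_id, chunk_id), the higher-scored copy wins, and a missing score counts as 0.

from dataclasses import dataclass

@dataclass
class Chunk:  # stand-in for InferenceChunkUncleaned
    document_id: str
    chunk_id: int
    score: float | None

def dedupe(chunks: list[Chunk]) -> list[Chunk]:
    used: dict[tuple[str, int], Chunk] = {}
    for chunk in chunks:
        key = (chunk.document_id, chunk.chunk_id)
        # Condensed form of the if/else above: the first occurrence wins
        # unless a later duplicate has a strictly higher score.
        if key not in used or (used[key].score or 0) < (chunk.score or 0):
            used[key] = chunk
    return list(used.values())

print(dedupe([Chunk("doc-a", 0, 0.42), Chunk("doc-a", 0, 0.77), Chunk("doc-b", 3, None)]))
# doc-a keeps its 0.77 copy; doc-b's None score is treated as 0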
@@ -139,17 +159,113 @@ def doc_index_retrieval(
         query.query, db_session
     )
 
-    top_chunks = document_index.hybrid_retrieval(
-        query=query.query,
-        query_embedding=query_embedding,
-        final_keywords=query.processed_keywords,
-        filters=query.filters,
-        hybrid_alpha=query.hybrid_alpha,
-        time_decay_multiplier=query.recency_bias_multiplier,
-        num_to_retrieve=query.num_hits,
-        offset=query.offset,
-    )
+    keyword_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
+    semantic_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
+    top_base_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = None
+
+    top_semantic_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = (
+        None
+    )
+
+    keyword_embeddings: list[Embedding] | None = None
+    semantic_embeddings: list[Embedding] | None = None
+
+    top_semantic_chunks: list[InferenceChunkUncleaned] | None = None
+
+    # original retrieval method
+    top_base_chunks_thread = run_in_background(
+        document_index.hybrid_retrieval,
+        query.query,
+        query_embedding,
+        query.processed_keywords,
+        query.filters,
+        query.hybrid_alpha,
+        query.recency_bias_multiplier,
+        query.num_hits,
+        QueryExpansionType.SEMANTIC,
+        query.offset,
+    )
+
+    if (
+        query.expanded_queries
+        and query.expanded_queries.keywords_expansions
+        and query.expanded_queries.semantic_expansions
+    ):
+        keyword_embeddings_thread = run_in_background(
+            get_query_embeddings,
+            query.expanded_queries.keywords_expansions,
+            db_session,
+        )
+        if query.search_type == SearchType.SEMANTIC:
+            semantic_embeddings_thread = run_in_background(
+                get_query_embeddings,
+                query.expanded_queries.semantic_expansions,
+                db_session,
+            )
+
+        keyword_embeddings = wait_on_background(keyword_embeddings_thread)
+        if query.search_type == SearchType.SEMANTIC:
+            assert semantic_embeddings_thread is not None
+            semantic_embeddings = wait_on_background(semantic_embeddings_thread)
+
+        # Use original query embedding for keyword retrieval embedding
+        keyword_embeddings = [query_embedding]
+
+        # Note: we generally prepped earlier for multiple expansions, but for now we only use one.
+        top_keyword_chunks_thread = run_in_background(
+            document_index.hybrid_retrieval,
+            query.expanded_queries.keywords_expansions[0],
+            keyword_embeddings[0],
+            query.processed_keywords,
+            query.filters,
+            HYBRID_ALPHA_KEYWORD,
+            query.recency_bias_multiplier,
+            query.num_hits,
+            QueryExpansionType.KEYWORD,
+            query.offset,
+        )
+
+        if query.search_type == SearchType.SEMANTIC:
+            assert semantic_embeddings is not None
+            top_semantic_chunks_thread = run_in_background(
+                document_index.hybrid_retrieval,
+                query.expanded_queries.semantic_expansions[0],
+                semantic_embeddings[0],
+                query.processed_keywords,
+                query.filters,
+                HYBRID_ALPHA,
+                query.recency_bias_multiplier,
+                query.num_hits,
+                QueryExpansionType.SEMANTIC,
+                query.offset,
+            )
+
+        top_base_chunks = wait_on_background(top_base_chunks_thread)
+
+        top_keyword_chunks = wait_on_background(top_keyword_chunks_thread)
+
+        if query.search_type == SearchType.SEMANTIC:
+            assert top_semantic_chunks_thread is not None
+            top_semantic_chunks = wait_on_background(top_semantic_chunks_thread)
+
+        all_top_chunks = top_base_chunks + top_keyword_chunks
+
+        # use all three retrieval methods to retrieve top chunks
+        if query.search_type == SearchType.SEMANTIC and top_semantic_chunks is not None:
+            all_top_chunks += top_semantic_chunks
+
+        top_chunks = _dedupe_chunks(all_top_chunks)
+
+    else:
+        top_base_chunks = wait_on_background(top_base_chunks_thread)
+        top_chunks = _dedupe_chunks(top_base_chunks)
 
     retrieval_requests: list[VespaChunkRequest] = []
     normal_chunks: list[InferenceChunkUncleaned] = []
     referenced_chunk_scores: dict[tuple[str, int], float] = {}
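
Condensed, the branching above runs these retrievals (a sketch of the control flow, not the committed code):

def retrievals_to_run(has_expansions: bool, search_type: str) -> list[str]:
    runs = ["base"]  # the original query always runs
    if has_expansions:
        runs.append("keyword-expanded")
        if search_type == "SEMANTIC":
            runs.append("semantic-expanded")
    return runs

print(retrievals_to_run(True, "SEMANTIC"))   # ['base', 'keyword-expanded', 'semantic-expanded']
print(retrievals_to_run(True, "KEYWORD"))    # ['base', 'keyword-expanded']
print(retrievals_to_run(False, "SEMANTIC"))  # ['base']

All result lists are then merged through _dedupe_chunks, so a chunk surfaced by several retrievals appears once, with its best score.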
@@ -1,6 +1,10 @@
+import string
 from collections.abc import Sequence
 from typing import TypeVar
 
+from nltk.corpus import stopwords  # type:ignore
+from nltk.tokenize import word_tokenize  # type:ignore
+
 from onyx.chat.models import SectionRelevancePiece
 from onyx.context.search.models import InferenceChunk
 from onyx.context.search.models import InferenceSection

@@ -136,3 +140,19 @@ def chunks_or_sections_to_search_docs(
     ]
 
     return search_docs
+
+
+def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
+    try:
+        # Re-tokenize using the NLTK tokenizer for better matching
+        query = " ".join(keywords)
+        stop_words = set(stopwords.words("english"))
+        word_tokens = word_tokenize(query)
+        text_trimmed = [
+            word
+            for word in word_tokens
+            if (word.casefold() not in stop_words and word not in string.punctuation)
+        ]
+        return text_trimmed or word_tokens
+    except Exception:
+        return keywords
@@ -4,6 +4,8 @@ from datetime import datetime
 from typing import Any
 
 from onyx.access.models import DocumentAccess
+from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
+from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
 from onyx.context.search.models import IndexFilters
 from onyx.context.search.models import InferenceChunkUncleaned
 from onyx.db.enums import EmbeddingPrecision

@@ -351,7 +353,9 @@ class HybridCapable(abc.ABC):
         hybrid_alpha: float,
         time_decay_multiplier: float,
         num_to_retrieve: int,
+        ranking_profile_type: QueryExpansionType,
         offset: int = 0,
+        title_content_ratio: float | None = TITLE_CONTENT_RATIO,
     ) -> list[InferenceChunkUncleaned]:
         """
         Run hybrid search and return a list of inference chunks.
@@ -176,7 +176,7 @@ schema DANSWER_CHUNK_NAME {
         match-features: recency_bias
     }
 
-    rank-profile hybrid_searchVARIABLE_DIM inherits default, default_rank {
+    rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank {
         inputs {
             query(query_embedding) tensor<float>(x[VARIABLE_DIM])
         }
@@ -192,7 +192,75 @@ schema DANSWER_CHUNK_NAME {
 
         # First phase must be vector to allow hits that have no keyword matches
         first-phase {
-            expression: closeness(field, embeddings)
+            expression: query(title_content_ratio) * closeness(field, title_embedding) + (1 - query(title_content_ratio)) * closeness(field, embeddings)
+        }
+
+        # Weighted average between Vector Search and BM-25
+        global-phase {
+            expression {
+                (
+                    # Weighted Vector Similarity Score
+                    (
+                        query(alpha) * (
+                            (query(title_content_ratio) * normalize_linear(title_vector_score))
+                            +
+                            ((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
+                        )
+                    )
+
+                    +
+
+                    # Weighted Keyword Similarity Score
+                    # Note: for the BM25 Title score, it requires decent stopword removal in the query
+                    # This needs to be the case so there aren't irrelevant titles being normalized to a score of 1
+                    (
+                        (1 - query(alpha)) * (
+                            (query(title_content_ratio) * normalize_linear(bm25(title)))
+                            +
+                            ((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
+                        )
+                    )
+                )
+                # Boost based on user feedback
+                * document_boost
+                # Decay factor based on time document was last updated
+                * recency_bias
+                # Boost based on aggregated boost calculation
+                * aggregated_chunk_boost
+            }
+            rerank-count: 1000
+        }
+
+        match-features {
+            bm25(title)
+            bm25(content)
+            closeness(field, title_embedding)
+            closeness(field, embeddings)
+            document_boost
+            recency_bias
+            aggregated_chunk_boost
+            closest(embeddings)
+        }
+    }
+
+
+    rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank {
+        inputs {
+            query(query_embedding) tensor<float>(x[VARIABLE_DIM])
+        }
+
+        function title_vector_score() {
+            expression {
+                # If no good matching titles, then it should use the context embeddings rather than having some
+                # irrelevant title have a vector score of 1. This way at least it will be the doc with the highest
+                # matching content score getting the full score
+                max(closeness(field, embeddings), closeness(field, title_embedding))
+            }
+        }
+
+        # First phase must be vector to allow hits that have no keyword matches
+        first-phase {
+            expression: query(title_content_ratio) * bm25(title) + (1 - query(title_content_ratio)) * bm25(content)
         }
 
         # Weighted average between Vector Search and BM-25
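
The global-phase expression is a weighted blend of vector and BM25 similarity, scaled multiplicatively by the boost and recency factors. A worked numeric instance in Python, assuming every match feature has already been normalized to [0, 1] (normalize_linear does this on the Vespa side):

alpha, title_ratio = 0.6, 0.3          # query(alpha), query(title_content_ratio)
title_vec, content_vec = 0.9, 0.7      # normalized vector similarities
bm25_title, bm25_content = 0.2, 0.8    # normalized BM25 scores
doc_boost, recency, chunk_boost = 1.1, 0.95, 1.0

vector_part = alpha * (title_ratio * title_vec + (1 - title_ratio) * content_vec)
keyword_part = (1 - alpha) * (title_ratio * bm25_title + (1 - title_ratio) * bm25_content)
score = (vector_part + keyword_part) * doc_boost * recency * chunk_boost
print(round(score, 4))  # 0.7357

The keyword profile flips the emphasis: its first phase ranks by the BM25 blend instead of vector closeness, so keyword-style expansions are scored lexically first.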
@@ -19,6 +19,7 @@ import httpx  # type: ignore
 import requests  # type: ignore
 from retry import retry
 
+from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
 from onyx.configs.chat_configs import DOC_TIME_DECAY
 from onyx.configs.chat_configs import NUM_RETURNED_HITS
 from onyx.configs.chat_configs import TITLE_CONTENT_RATIO

@@ -800,12 +801,14 @@ class VespaIndex(DocumentIndex):
         hybrid_alpha: float,
         time_decay_multiplier: float,
         num_to_retrieve: int,
+        ranking_profile_type: QueryExpansionType,
         offset: int = 0,
         title_content_ratio: float | None = TITLE_CONTENT_RATIO,
     ) -> list[InferenceChunkUncleaned]:
         vespa_where_clauses = build_vespa_filters(filters)
         # Needs to be at least as much as the value set in Vespa schema config
         target_hits = max(10 * num_to_retrieve, 1000)
+
         yql = (
             YQL_BASE.format(index_name=self.index_name)
             + vespa_where_clauses

@@ -817,6 +820,11 @@ class VespaIndex(DocumentIndex):
 
         final_query = " ".join(final_keywords) if final_keywords else query
 
+        if ranking_profile_type == QueryExpansionType.KEYWORD:
+            ranking_profile = f"hybrid_search_keyword_base_{len(query_embedding)}"
+        else:
+            ranking_profile = f"hybrid_search_semantic_base_{len(query_embedding)}"
+
         logger.debug(f"Query YQL: {yql}")
 
         params: dict[str, str | int | float] = {

@@ -832,7 +840,7 @@ class VespaIndex(DocumentIndex):
             ),
             "hits": num_to_retrieve,
             "offset": offset,
-            "ranking.profile": f"hybrid_search{len(query_embedding)}",
+            "ranking.profile": ranking_profile,
             "timeout": VESPA_TIMEOUT,
         }
 
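
The profile name is resolved from the expansion type plus the embedding dimension, matching the VARIABLE_DIM templating in the schema hunk above. For a hypothetical 768-dimension embedding:

query_embedding = [0.0] * 768  # hypothetical 768-dim embedding
for kind in ("keyword", "semantic"):
    print(f"hybrid_search_{kind}_base_{len(query_embedding)}")
# hybrid_search_keyword_base_768
# hybrid_search_semantic_base_768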
@@ -246,3 +246,75 @@ Please give a short succinct summary of the entire document. Answer only with the
 summary and nothing else. """
 
 DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
+
+
+QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT = """
+Please rephrase the following user question/query as a semantic query that would be appropriate for a \
+search engine.
+
+Note:
+- do not change the meaning of the question! Specifically, if the query is an instruction, keep it \
+as an instruction!
+
+Here is the user question/query:
+{question}
+
+Respond with EXACTLY and ONLY one rephrased question/query.
+
+Rephrased question/query for search engine:
+""".strip()
+
+
+QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT = """
+Following a previous message history, a user created a follow-up question/query.
+Please rephrase that question/query as a semantic query \
+that would be appropriate for a SEARCH ENGINE. Only use the information provided \
+from the history that is relevant to provide the relevant context for the search query, \
+meaning that the rephrased search query should be a suitable stand-alone search query.
+
+Note:
+- do not change the meaning of the question! Specifically, if the query is an instruction, keep it \
+as an instruction!
+
+Here is the relevant previous message history:
+{history}
+
+Here is the user question:
+{question}
+
+Respond with EXACTLY and ONLY one rephrased query.
+
+Rephrased query for search engine:
+""".strip()
+
+
+QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT = """
+Please rephrase the following user question as a keyword query that would be appropriate for a \
+search engine.
+
+Here is the user question:
+{question}
+
+Respond with EXACTLY and ONLY one rephrased query.
+
+Rephrased query for search engine:
+""".strip()
+
+
+QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT = """
+Following a previous message history, a user created a follow-up question/query.
+Please rephrase that question/query as a keyword query \
+that would be appropriate for a SEARCH ENGINE. Only use the information provided \
+from the history that is relevant to provide the relevant context for the search query, \
+meaning that the rephrased search query should be a suitable stand-alone search query.
+
+Here is the relevant previous message history:
+{history}
+
+Here is the user question:
+{question}
+
+Respond with EXACTLY and ONLY one rephrased query.
+
+Rephrased query for search engine:
+""".strip()
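
The templates are plain str.format strings; _expand_query fills them in like this (assuming the onyx package is importable):

from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT

rendered = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT.format(
    history="User:\n where are the Q3 OKRs?\n",
    question="and what about Q4?",
)
print(rendered)  # ends with "Rephrased query for search engine:"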
@@ -11,6 +11,7 @@ from onyx.configs.constants import DocumentSource
 from onyx.context.search.enums import SearchType
 from onyx.context.search.models import IndexFilters
 from onyx.context.search.models import InferenceSection
+from onyx.context.search.models import QueryExpansions
 from shared_configs.model_server_models import Embedding
 
 

@@ -79,6 +80,7 @@ class SearchToolOverrideKwargs(BaseModel):
     )
     document_sources: list[DocumentSource] | None = None
     time_cutoff: datetime | None = None
+    expanded_queries: QueryExpansions | None = None
 
     class Config:
         arbitrary_types_allowed = True

@@ -295,6 +295,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
         ordering_only = False
         document_sources = None
         time_cutoff = None
+        expanded_queries = None
         if override_kwargs:
             force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
             alternate_db_session = override_kwargs.alternate_db_session

@@ -307,6 +308,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
             ordering_only = use_alt_not_None(override_kwargs.ordering_only, False)
             document_sources = override_kwargs.document_sources
             time_cutoff = override_kwargs.time_cutoff
+            expanded_queries = override_kwargs.expanded_queries
 
         # Fast path for ordering-only search
         if ordering_only:

@@ -391,6 +393,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
                 precomputed_query_embedding=precomputed_query_embedding,
                 precomputed_is_keyword=precomputed_is_keyword,
                 precomputed_keywords=precomputed_keywords,
+                # add expanded queries
+                expanded_queries=expanded_queries,
             ),
             user=self.user,
             llm=self.llm,
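
End to end, choose_tool packs the two rewrites into SearchToolOverrideKwargs.expanded_queries and SearchTool copies them onto the SearchRequest it builds. Constructing the payload (assuming onyx is importable; both fields default to None per the models hunk above):

from onyx.context.search.models import QueryExpansions

expansions = QueryExpansions(
    keywords_expansions=["q3 okr document"],  # from the KEYWORD rewrite
    semantic_expansions=["Where can I find the Q3 OKRs?"],  # from the SEMANTIC rewrite
)
# choose_tool:  override_kwargs.expanded_queries = expansions
# SearchTool:   SearchRequest(..., expanded_queries=expanded_queries)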
@@ -5,6 +5,7 @@ RUN THIS AFTER SEED_DUMMY_DOCS.PY
 import random
 import time
 
+from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
 from onyx.configs.constants import DocumentSource
 from onyx.configs.model_configs import DOC_EMBEDDING_DIM
 from onyx.context.search.models import IndexFilters

@@ -96,6 +97,7 @@ def test_hybrid_retrieval_times(
         hybrid_alpha=0.5,
         time_decay_multiplier=1.0,
         num_to_retrieve=50,
+        ranking_profile_type=QueryExpansionType.SEMANTIC,
         offset=0,
         title_content_ratio=0.5,
     )