Remove OnyxContext (#4376)

* Remove OnyxContext
* Fix UT
* Fix tests v2

@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
 from onyx.chat.models import AllCitations
 from onyx.chat.models import LLMRelevanceFilterResponse
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import StreamingError
 from onyx.chat.process_message import ChatPacketStream

@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
             response.llm_selected_doc_indices = packet.llm_selected_doc_indices
         elif isinstance(packet, AllCitations):
             response.citations = packet.citations
-        elif isinstance(packet, OnyxContexts):
-            response.contexts = packet
 
     if answer:
         response.answer = answer

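Note: the hunk above only deletes one branch of an isinstance-dispatch loop; every other packet type keeps its behavior. A minimal standalone sketch of that loop shape, using stand-in types rather than the real Onyx models:

from dataclasses import dataclass

@dataclass
class Citations:  # stand-in for AllCitations
    citations: list[str]

@dataclass
class QAResponse:  # stand-in for OneShotQAResponse (contexts field removed)
    citations: list[str] | None = None
    answer: str | None = None

def gather(packets: list[object]) -> QAResponse:
    response = QAResponse()
    answer = ""
    for packet in packets:
        if isinstance(packet, str):  # stand-in for OnyxAnswerPiece
            answer += packet
        elif isinstance(packet, Citations):
            response.citations = packet.citations
        # the `elif isinstance(packet, OnyxContexts)` branch is simply gone

    if answer:
        response.answer = answer
    return response

print(gather(["It is ", "42.", Citations(["doc-1"])]))
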
@@ -8,7 +8,6 @@ from pydantic import model_validator
 
 from ee.onyx.server.manage.models import StandardAnswer
 from onyx.chat.models import CitationInfo
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PersonaOverrideConfig
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import SubQuestionIdentifier

@@ -220,4 +219,3 @@ class OneShotQAResponse(BaseModel):
     llm_selected_doc_indices: list[int] | None = None
     error_msg: str | None = None
     chat_message_id: int | None = None
-    contexts: OnyxContexts | None = None

@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
 
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
 from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
 from onyx.chat.stream_processing.answer_response_handler import (

@@ -24,7 +23,7 @@ def process_llm_stream(
     should_stream_answer: bool,
     writer: StreamWriter,
     final_search_results: list[LlmDoc] | None = None,
-    displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
+    displayed_search_results: list[LlmDoc] | None = None,
 ) -> AIMessageChunk:
     tool_call_chunk = AIMessageChunk(content="")
 

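With OnyxContext gone, the displayed_search_results union collapses to a single list type. A small sketch of the narrowed parameter (illustrative shape only, not the real process_llm_stream signature):

from onyx.chat.models import LlmDoc

# Before: displayed_search_results: list[OnyxContext] | list[LlmDoc] | None
# After:  displayed_search_results: list[LlmDoc] | None
def displayed_ids(displayed_search_results: list[LlmDoc] | None = None) -> list[str]:
    # Only one element type remains, so no isinstance branching between
    # OnyxContext and LlmDoc is needed downstream.
    return [doc.document_id for doc in displayed_search_results or []]
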
@@ -156,7 +156,6 @@ def generate_initial_answer(
     for tool_response in yield_search_responses(
         query=question,
         get_retrieved_sections=lambda: answer_generation_documents.context_documents,
-        get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
         get_final_context_sections=lambda: answer_generation_documents.context_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,

@@ -183,7 +183,6 @@ def generate_validate_refined_answer(
     for tool_response in yield_search_responses(
         query=question,
         get_retrieved_sections=lambda: answer_generation_documents.context_documents,
-        get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
         get_final_context_sections=lambda: answer_generation_documents.context_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,

@@ -57,7 +57,6 @@ def format_results(
     for tool_response in yield_search_responses(
         query=state.question,
         get_retrieved_sections=lambda: reranked_documents,
-        get_reranked_sections=lambda: state.retrieved_documents,
         get_final_context_sections=lambda: reranked_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,

@@ -13,9 +13,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
     SEARCH_RESPONSE_SUMMARY_ID,
 )
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
-from onyx.tools.tool_implementations.search.search_utils import (
-    context_from_inference_section,
-)
+from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )

@@ -59,9 +57,7 @@ def basic_use_tool_response(
            search_response_summary = cast(SearchResponseSummary, yield_item.response)
            for section in search_response_summary.top_sections:
                if section.center_chunk.document_id not in initial_search_results:
-                    initial_search_results.append(
-                        context_from_inference_section(section)
-                    )
+                    initial_search_results.append(section_to_llm_doc(section))
 
    new_tool_call_chunk = AIMessageChunk(content="")
    if not agent_config.behavior.skip_gen_ai_answer_generation:

@@ -194,17 +194,6 @@ class StreamingError(BaseModel):
     stack_trace: str | None = None
 
 
-class OnyxContext(BaseModel):
-    content: str
-    document_id: str
-    semantic_identifier: str
-    blurb: str
-
-
-class OnyxContexts(BaseModel):
-    contexts: list[OnyxContext]
-
-
 class OnyxAnswer(BaseModel):
     answer: str | None
 

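The deleted OnyxContext carried only content, document_id, semantic_identifier, and blurb, each of which LlmDoc already exposes (see the section_to_llm_doc hunk further down), so consumers lose nothing by switching models. A minimal sketch, assuming the LlmDoc fields visible in this diff and DocumentSource.WEB as a placeholder source:

from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource

doc = LlmDoc(
    document_id="doc-1",               # was OnyxContext.document_id
    content="combined section text",   # was OnyxContext.content
    semantic_identifier="Design Doc",  # was OnyxContext.semantic_identifier
    blurb="short preview",             # was OnyxContext.blurb
    source_type=DocumentSource.WEB,    # extra context LlmDoc carries
    metadata={},
    updated_at=None,
    link="https://example.com/doc-1",
    source_links={0: "https://example.com/doc-1"},
    match_highlights=[],
)
print(doc.document_id, doc.blurb)
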
@@ -270,7 +259,6 @@ class PersonaOverrideConfig(BaseModel):
 AnswerQuestionPossibleReturn = (
     OnyxAnswerPiece
     | CitationInfo
-    | OnyxContexts
     | FileChatDisplay
     | CustomToolResponse
     | StreamingError

@@ -29,7 +29,6 @@ from onyx.chat.models import LLMRelevanceFilterResponse
 from onyx.chat.models import MessageResponseIDInfo
 from onyx.chat.models import MessageSpecificCitations
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import RefinedAnswerImprovement

@@ -131,7 +130,6 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
 from onyx.tools.tool_implementations.search.search_tool import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
-from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
 from onyx.tools.tool_implementations.search.search_tool import (
     SEARCH_RESPONSE_SUMMARY_ID,
 )

@@ -300,7 +298,6 @@ def _get_force_search_settings(
 ChatPacket = (
     StreamingError
     | QADocsResponse
-    | OnyxContexts
     | LLMRelevanceFilterResponse
     | FinalUsedContextDocsResponse
     | ChatMessageDetail

@@ -919,8 +916,6 @@ def stream_chat_message_objects(
                        response=custom_tool_response.tool_result,
                        tool_name=custom_tool_response.tool_name,
                    )
-                elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
-                    yield cast(OnyxContexts, packet.response)
 
            elif isinstance(packet, StreamStopInfo):
                if packet.stop_reason == StreamStopReason.FINISHED:

@@ -3,7 +3,6 @@ from collections.abc import Sequence
 from pydantic import BaseModel
 
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.context.search.models import InferenceChunk
 
 

@@ -12,7 +11,7 @@ class DocumentIdOrderMapping(BaseModel):
 
 
 def map_document_id_order(
-    chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
+    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
 ) -> DocumentIdOrderMapping:
     order_mapping = {}
     current = 1 if one_indexed else 0

@@ -415,6 +415,10 @@ class SearchPipeline:
            raise ValueError(
                "Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
            )
+        # NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
+        # since the property sets the generator. DO NOT REMOVE.
+        _ = self.final_context_sections
+
        self._section_relevance = next(
            cast(
                Iterator[list[SectionRelevancePiece]],

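The added NOTE protects an ordering dependency: final_context_sections is a property whose first access also sets self._postprocessing_generator, which the next(...) call underneath then consumes. A self-contained sketch of that pattern, with hypothetical names rather than the real pipeline:

class MiniPipeline:
    def __init__(self):
        self._postprocessing_generator = None

    @property
    def final_context_sections(self):
        # Side effect: first access also primes the generator a later step consumes.
        if self._postprocessing_generator is None:
            self._postprocessing_generator = iter([["relevance piece"]])
        return ["section"]

    def section_relevance(self):
        # TypeError ("'NoneType' object is not an iterator") if
        # final_context_sections was never accessed first.
        return next(self._postprocessing_generator)

pipeline = MiniPipeline()
_ = pipeline.final_context_sections  # prime the generator, as the NOTE requires
print(pipeline.section_relevance())  # ['relevance piece']
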
@@ -12,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ContextualPruningConfig
 from onyx.chat.models import DocumentPruningConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import SectionRelevancePiece
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder

@@ -42,9 +41,6 @@ from onyx.tools.models import SearchQueryInfo
 from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolResponse
 from onyx.tools.tool import Tool
-from onyx.tools.tool_implementations.search.search_utils import (
-    context_from_inference_section,
-)
 from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     build_next_prompt_for_search_like_tool,

@@ -58,7 +54,6 @@ from onyx.utils.special_types import JSON_ro
 logger = setup_logger()
 
 SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
-SEARCH_DOC_CONTENT_ID = "search_doc_content"
 SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
 SEARCH_EVALUATION_ID = "llm_doc_eval"
 QUERY_FIELD = "query"

@@ -357,13 +352,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
            recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
        )
        yield from yield_search_responses(
-            query,
-            lambda: search_pipeline.retrieved_sections,
-            lambda: search_pipeline.reranked_sections,
-            lambda: search_pipeline.final_context_sections,
-            search_query_info,
-            lambda: search_pipeline.section_relevance,
-            self,
+            query=query,
+            get_retrieved_sections=lambda: search_pipeline.retrieved_sections,
+            get_final_context_sections=lambda: search_pipeline.final_context_sections,
+            search_query_info=search_query_info,
+            get_section_relevance=lambda: search_pipeline.section_relevance,
+            search_tool=self,
        )
 
    def final_result(self, *args: ToolResponse) -> JSON_ro:

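The call site above switches from positional lambdas to keyword arguments while dropping the dead get_reranked_sections callback; with a row of bare lambdas, keywords make each callback's role explicit and turn a removed parameter into a loud, named error instead of a silent rebinding. A toy illustration with hypothetical signatures:

def old_sig(query, get_retrieved, get_reranked, get_final):
    return get_final()

def new_sig(query, get_retrieved, get_final):
    return get_final()

positional = ("q", lambda: ["retrieved"], lambda: ["reranked"], lambda: ["final"])
print(old_sig(*positional))  # ['final']

# new_sig(*positional) raises TypeError (too many positional arguments), and a
# stale keyword fails with an explicit "unexpected keyword argument" message:
print(new_sig("q", get_retrieved=lambda: ["retrieved"], get_final=lambda: ["final"]))
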
@@ -405,7 +399,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
 def yield_search_responses(
     query: str,
     get_retrieved_sections: Callable[[], list[InferenceSection]],
-    get_reranked_sections: Callable[[], list[InferenceSection]],
     get_final_context_sections: Callable[[], list[InferenceSection]],
     search_query_info: SearchQueryInfo,
     get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],

@@ -423,16 +416,6 @@ def yield_search_responses(
         ),
     )
 
-    yield ToolResponse(
-        id=SEARCH_DOC_CONTENT_ID,
-        response=OnyxContexts(
-            contexts=[
-                context_from_inference_section(section)
-                for section in get_reranked_sections()
-            ]
-        ),
-    )
-
     section_relevance = get_section_relevance()
     yield ToolResponse(
         id=SECTION_RELEVANCE_LIST_ID,

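This hunk removes one ToolResponse from the stream (the SEARCH_DOC_CONTENT_ID packet carrying OnyxContexts), which is why every output[n] index in the test hunks further down shifts by one. A sketch of the packet order the tests assert over, labels only rather than real objects:

stream_before = [
    "ToolCallKickoff",
    "ToolResponse(id=SEARCH_DOC_CONTENT_ID)",  # dropped by this commit
    "ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID)",
    "ToolCallFinalResult",
]
stream_after = [p for p in stream_before if "SEARCH_DOC_CONTENT_ID" not in p]
print(stream_after.index("ToolCallFinalResult"))  # 2 (was 3)
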
@@ -1,5 +1,4 @@
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.context.search.models import InferenceSection
 from onyx.prompts.prompt_utils import clean_up_source
 

@@ -32,10 +31,23 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
     return doc_dict
 
 
-def context_from_inference_section(section: InferenceSection) -> OnyxContext:
-    return OnyxContext(
-        content=section.combined_content,
+def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
+    possible_link_chunks = [section.center_chunk] + section.chunks
+    link: str | None = None
+    for chunk in possible_link_chunks:
+        if chunk.source_links:
+            link = list(chunk.source_links.values())[0]
+            break
+
+    return LlmDoc(
         document_id=section.center_chunk.document_id,
+        content=section.combined_content,
+        source_type=section.center_chunk.source_type,
         semantic_identifier=section.center_chunk.semantic_identifier,
+        metadata=section.center_chunk.metadata,
+        updated_at=section.center_chunk.updated_at,
         blurb=section.center_chunk.blurb,
+        link=link,
+        source_links=section.center_chunk.source_links,
+        match_highlights=section.center_chunk.match_highlights,
     )

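The new section_to_llm_doc picks a link by scanning the center chunk first and then the remaining chunks, taking the first source_links value it finds. The same selection logic re-run standalone, with mock chunks instead of real InferenceSection objects:

from types import SimpleNamespace

center_chunk = SimpleNamespace(source_links=None)
other_chunks = [SimpleNamespace(source_links={0: "https://example.com/page"})]

link = None
for chunk in [center_chunk] + other_chunks:
    if chunk.source_links:
        link = list(chunk.source_links.values())[0]
        break

print(link)  # https://example.com/page (falls back past the link-less center chunk)
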
@@ -9,8 +9,6 @@ from onyx.chat.chat_utils import llm_doc_from_inference_section
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import DocumentSource

@@ -19,7 +17,6 @@ from onyx.context.search.models import InferenceSection
 from onyx.llm.interfaces import LLM
 from onyx.llm.interfaces import LLMConfig
 from onyx.tools.models import ToolResponse
-from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,

@@ -120,24 +117,7 @@ def mock_search_results(
 
 
 @pytest.fixture
-def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
-    return OnyxContexts(
-        contexts=[
-            OnyxContext(
-                content=section.combined_content,
-                document_id=section.center_chunk.document_id,
-                semantic_identifier=section.center_chunk.semantic_identifier,
-                blurb=section.center_chunk.blurb,
-            )
-            for section in mock_inference_sections
-        ]
-    )
-
-
-@pytest.fixture
-def mock_search_tool(
-    mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
-) -> MagicMock:
+def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
     mock_tool = MagicMock(spec=SearchTool)
     mock_tool.name = "search"
     mock_tool.build_tool_message_content.return_value = "search_response"

@@ -146,7 +126,6 @@ def mock_search_tool(
         json.loads(doc.model_dump_json()) for doc in mock_search_results
     ]
     mock_tool.run.return_value = [
-        ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
         ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
     ]
     mock_tool.tool_definition.return_value = {

@@ -19,7 +19,6 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason

@@ -33,7 +32,6 @@ from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
-from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )

@@ -141,7 +139,6 @@ def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
 def test_answer_with_search_call(
     answer_instance: Answer,
     mock_search_results: list[LlmDoc],
-    mock_contexts: OnyxContexts,
     mock_search_tool: MagicMock,
     force_use_tool: ForceUseTool,
     expected_tool_args: dict,

@@ -197,25 +194,21 @@ def test_answer_with_search_call(
         tool_name="search", tool_args=expected_tool_args
     )
     assert output[1] == ToolResponse(
-        id=SEARCH_DOC_CONTENT_ID,
-        response=mock_contexts,
-    )
-    assert output[2] == ToolResponse(
         id="final_context_documents",
         response=mock_search_results,
     )
-    assert output[3] == ToolCallFinalResult(
+    assert output[2] == ToolCallFinalResult(
         tool_name="search",
         tool_args=expected_tool_args,
         tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
     )
-    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
     expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[5] == expected_citation
-    assert output[6] == OnyxAnswerPiece(
+    assert output[4] == expected_citation
+    assert output[5] == OnyxAnswerPiece(
         answer_piece="the answer is abc[[1]](https://example.com/doc1). "
     )
-    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
 
     expected_answer = (
         "Based on the search results, "

@@ -268,7 +261,6 @@ def test_answer_with_search_call(
 def test_answer_with_search_no_tool_calling(
     answer_instance: Answer,
     mock_search_results: list[LlmDoc],
-    mock_contexts: OnyxContexts,
     mock_search_tool: MagicMock,
 ) -> None:
     answer_instance.graph_config.tooling.tools = [mock_search_tool]

@@ -288,30 +280,26 @@ def test_answer_with_search_no_tool_calling(
     output = list(answer_instance.processed_streamed_output)
 
     # Assertions
-    assert len(output) == 8
+    assert len(output) == 7
     assert output[0] == ToolCallKickoff(
         tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
     )
     assert output[1] == ToolResponse(
-        id=SEARCH_DOC_CONTENT_ID,
-        response=mock_contexts,
-    )
-    assert output[2] == ToolResponse(
         id=FINAL_CONTEXT_DOCUMENTS_ID,
         response=mock_search_results,
     )
-    assert output[3] == ToolCallFinalResult(
+    assert output[2] == ToolCallFinalResult(
         tool_name="search",
         tool_args=DEFAULT_SEARCH_ARGS,
         tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
     )
-    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
     expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[5] == expected_citation
-    assert output[6] == OnyxAnswerPiece(
+    assert output[4] == expected_citation
+    assert output[5] == OnyxAnswerPiece(
         answer_piece="the answer is abc[[1]](https://example.com/doc1). "
     )
-    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
 
     expected_answer = (
         "Based on the search results, "

@@ -79,7 +79,7 @@ def test_skip_gen_ai_answer_generation_flag(
     for res in results:
         print(res)
 
-    expected_count = 4 if skip_gen_ai_answer_generation else 5
+    expected_count = 3 if skip_gen_ai_answer_generation else 4
     assert len(results) == expected_count
     if not skip_gen_ai_answer_generation:
         mock_llm.stream.assert_called_once()