mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-12 09:00:53 +02:00
addressing PR comments
This commit is contained in:
parent
2b8cd63b34
commit
bb6d55783e
@ -85,7 +85,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
graph = basic_graph_builder()
|
graph = basic_graph_builder()
|
||||||
compiled_graph = graph.compile()
|
compiled_graph = graph.compile()
|
||||||
# TODO: unify basic input
|
|
||||||
input = BasicInput(logs="")
|
input = BasicInput(logs="")
|
||||||
primary_llm, fast_llm = get_default_llms()
|
primary_llm, fast_llm = get_default_llms()
|
||||||
with get_session_context_manager() as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
|
@ -16,8 +16,9 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
|||||||
|
|
||||||
|
|
||||||
class BasicInput(BaseModel):
|
class BasicInput(BaseModel):
|
||||||
# TODO: subclass global log update state
|
# Langgraph needs a nonempty input, but we pass in all static
|
||||||
logs: str = ""
|
# data through a RunnableConfig.
|
||||||
|
_unused: bool = True
|
||||||
|
|
||||||
|
|
||||||
## Graph Output State
|
## Graph Output State
|
||||||
|
@ -16,17 +16,9 @@ from onyx.utils.logger import setup_logger
|
|||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
# TODO: handle citations here; below is what was previously passed in
|
|
||||||
# see basic_use_tool_response.py for where these variables come from
|
|
||||||
# answer_handler = CitationResponseHandler(
|
|
||||||
# context_docs=final_search_results,
|
|
||||||
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
|
||||||
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
def process_llm_stream(
|
def process_llm_stream(
|
||||||
stream: Iterator[BaseMessage],
|
messages: Iterator[BaseMessage],
|
||||||
should_stream_answer: bool,
|
should_stream_answer: bool,
|
||||||
final_search_results: list[LlmDoc] | None = None,
|
final_search_results: list[LlmDoc] | None = None,
|
||||||
displayed_search_results: list[LlmDoc] | None = None,
|
displayed_search_results: list[LlmDoc] | None = None,
|
||||||
@ -46,20 +38,20 @@ def process_llm_stream(
|
|||||||
full_answer = ""
|
full_answer = ""
|
||||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||||
# the stream will contain AIMessageChunks with tool call information.
|
# the stream will contain AIMessageChunks with tool call information.
|
||||||
for response in stream:
|
for message in messages:
|
||||||
answer_piece = response.content
|
answer_piece = message.content
|
||||||
if not isinstance(answer_piece, str):
|
if not isinstance(answer_piece, str):
|
||||||
# TODO: handle non-string content
|
# this is only used for logging, so fine to
|
||||||
logger.warning(f"Received non-string content: {type(answer_piece)}")
|
# just add the string representation
|
||||||
answer_piece = str(answer_piece)
|
answer_piece = str(answer_piece)
|
||||||
full_answer += answer_piece
|
full_answer += answer_piece
|
||||||
|
|
||||||
if isinstance(response, AIMessageChunk) and (
|
if isinstance(message, AIMessageChunk) and (
|
||||||
response.tool_call_chunks or response.tool_calls
|
message.tool_call_chunks or message.tool_calls
|
||||||
):
|
):
|
||||||
tool_call_chunk += response # type: ignore
|
tool_call_chunk += message # type: ignore
|
||||||
elif should_stream_answer:
|
elif should_stream_answer:
|
||||||
for response_part in answer_handler.handle_response_part(response, []):
|
for response_part in answer_handler.handle_response_part(message, []):
|
||||||
dispatch_custom_event(
|
dispatch_custom_event(
|
||||||
"basic_response",
|
"basic_response",
|
||||||
response_part,
|
response_part,
|
||||||
|
@ -87,6 +87,7 @@ def generate_initial_answer(
|
|||||||
|
|
||||||
# Use the query info from the base document retrieval
|
# Use the query info from the base document retrieval
|
||||||
query_info = get_query_info(state.original_question_retrieval_results)
|
query_info = get_query_info(state.original_question_retrieval_results)
|
||||||
|
|
||||||
if agent_a_config.search_tool is None:
|
if agent_a_config.search_tool is None:
|
||||||
raise ValueError("search_tool must be provided for agentic search")
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
for tool_response in yield_search_responses(
|
for tool_response in yield_search_responses(
|
||||||
|
@ -10,6 +10,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
|
|||||||
QuestionAnswerResults,
|
QuestionAnswerResults,
|
||||||
)
|
)
|
||||||
from onyx.chat.models import SubQuestionPiece
|
from onyx.chat.models import SubQuestionPiece
|
||||||
|
from onyx.context.search.models import IndexFilters
|
||||||
from onyx.tools.models import SearchQueryInfo
|
from onyx.tools.models import SearchQueryInfo
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
|
||||||
@ -141,5 +142,10 @@ def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
|
|||||||
result.query_info for result in results if result.query_info is not None
|
result.query_info for result in results if result.query_info is not None
|
||||||
]
|
]
|
||||||
if len(query_infos) == 0:
|
if len(query_infos) == 0:
|
||||||
|
return SearchQueryInfo(
|
||||||
|
predicted_search=None,
|
||||||
|
final_filters=IndexFilters(access_control_list=None),
|
||||||
|
recency_bias_multiplier=1.0,
|
||||||
|
)
|
||||||
raise ValueError("No query info found")
|
raise ValueError("No query info found")
|
||||||
return query_infos[0]
|
return query_infos[0]
|
||||||
|
@ -26,6 +26,8 @@ from onyx.chat.models import StreamStopInfo
|
|||||||
from onyx.chat.models import SubQueryPiece
|
from onyx.chat.models import SubQueryPiece
|
||||||
from onyx.chat.models import SubQuestionPiece
|
from onyx.chat.models import SubQuestionPiece
|
||||||
from onyx.chat.models import ToolResponse
|
from onyx.chat.models import ToolResponse
|
||||||
|
from onyx.configs.agent_configs import ALLOW_REFINEMENT
|
||||||
|
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||||
from onyx.context.search.models import SearchRequest
|
from onyx.context.search.models import SearchRequest
|
||||||
from onyx.db.engine import get_session_context_manager
|
from onyx.db.engine import get_session_context_manager
|
||||||
from onyx.tools.tool_runner import ToolCallKickoff
|
from onyx.tools.tool_runner import ToolCallKickoff
|
||||||
@ -58,7 +60,6 @@ def _parse_agent_event(
|
|||||||
# 1. It's a list of the names of every place we dispatch a custom event
|
# 1. It's a list of the names of every place we dispatch a custom event
|
||||||
# 2. We maintain the intended types yielded by each event
|
# 2. We maintain the intended types yielded by each event
|
||||||
if event_type == "on_custom_event":
|
if event_type == "on_custom_event":
|
||||||
# TODO: different AnswerStream types for different events
|
|
||||||
if event["name"] == "decomp_qs":
|
if event["name"] == "decomp_qs":
|
||||||
return cast(SubQuestionPiece, event["data"])
|
return cast(SubQuestionPiece, event["data"])
|
||||||
elif event["name"] == "subqueries":
|
elif event["name"] == "subqueries":
|
||||||
@ -133,15 +134,30 @@ def _manage_async_event_streaming(
|
|||||||
return _yield_async_to_sync()
|
return _yield_async_to_sync()
|
||||||
|
|
||||||
|
|
||||||
|
def manage_sync_streaming(
|
||||||
|
compiled_graph: CompiledStateGraph,
|
||||||
|
config: AgentSearchConfig,
|
||||||
|
graph_input: BasicInput | MainInput_a,
|
||||||
|
) -> Iterable[StreamEvent]:
|
||||||
|
message_id = config.message_id if config else None
|
||||||
|
for event in compiled_graph.stream(
|
||||||
|
stream_mode="custom",
|
||||||
|
input=graph_input,
|
||||||
|
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||||
|
# debug=True,
|
||||||
|
):
|
||||||
|
print(event)
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def run_graph(
|
def run_graph(
|
||||||
compiled_graph: CompiledStateGraph,
|
compiled_graph: CompiledStateGraph,
|
||||||
config: AgentSearchConfig,
|
config: AgentSearchConfig,
|
||||||
input: BasicInput | MainInput_a,
|
input: BasicInput | MainInput_a,
|
||||||
) -> AnswerStream:
|
) -> AnswerStream:
|
||||||
# TODO: add these to the environment
|
config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||||
# config.perform_initial_search_path_decision = False
|
config.allow_refinement = ALLOW_REFINEMENT
|
||||||
config.perform_initial_search_decomposition = True
|
|
||||||
config.allow_refinement = True
|
|
||||||
|
|
||||||
for event in _manage_async_event_streaming(
|
for event in _manage_async_event_streaming(
|
||||||
compiled_graph=compiled_graph, config=config, graph_input=input
|
compiled_graph=compiled_graph, config=config, graph_input=input
|
||||||
@ -152,8 +168,8 @@ def run_graph(
|
|||||||
yield parsed_object
|
yield parsed_object
|
||||||
|
|
||||||
|
|
||||||
# TODO: call this once on startup, TBD where and if it should be gated based
|
# It doesn't actually take very long to load the graph, but we'd rather
|
||||||
# on dev mode or not
|
# not compile it again on every request.
|
||||||
def load_compiled_graph() -> CompiledStateGraph:
|
def load_compiled_graph() -> CompiledStateGraph:
|
||||||
global _COMPILED_GRAPH
|
global _COMPILED_GRAPH
|
||||||
if _COMPILED_GRAPH is None:
|
if _COMPILED_GRAPH is None:
|
||||||
@ -176,13 +192,11 @@ def run_main_graph(
|
|||||||
yield from run_graph(compiled_graph, config, input)
|
yield from run_graph(compiled_graph, config, input)
|
||||||
|
|
||||||
|
|
||||||
# TODO: unify input types, especially prosearchconfig
|
|
||||||
def run_basic_graph(
|
def run_basic_graph(
|
||||||
config: AgentSearchConfig,
|
config: AgentSearchConfig,
|
||||||
) -> AnswerStream:
|
) -> AnswerStream:
|
||||||
graph = basic_graph_builder()
|
graph = basic_graph_builder()
|
||||||
compiled_graph = graph.compile()
|
compiled_graph = graph.compile()
|
||||||
# TODO: unify basic input
|
|
||||||
input = BasicInput()
|
input = BasicInput()
|
||||||
return run_graph(compiled_graph, config, input)
|
return run_graph(compiled_graph, config, input)
|
||||||
|
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
|
||||||
|
ALLOW_REFINEMENT = True
|
||||||
|
|
||||||
AGENT_DEFAULT_RETRIEVAL_HITS = 15
|
AGENT_DEFAULT_RETRIEVAL_HITS = 15
|
||||||
AGENT_DEFAULT_RERANKING_HITS = 10
|
AGENT_DEFAULT_RERANKING_HITS = 10
|
||||||
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
|
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
|
||||||
|
Loading…
x
Reference in New Issue
Block a user