mirror of https://github.com/danswer-ai/danswer.git
synced 2025-06-01 02:30:18 +02:00

addressing PR comments

parent 2b8cd63b34
commit bb6d55783e
@@ -85,7 +85,6 @@ if __name__ == "__main__":
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
-    # TODO: unify basic input
-    input = BasicInput(logs="")
+    input = BasicInput()
     primary_llm, fast_llm = get_default_llms()
     with get_session_context_manager() as db_session:
@@ -16,8 +16,9 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate


 class BasicInput(BaseModel):
-    # TODO: subclass global log update state
-    logs: str = ""
+    # Langgraph needs a nonempty input, but we pass in all static
+    # data through a RunnableConfig.
+    _unused: bool = True


 ## Graph Output State
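The `_unused` placeholder exists only because LangGraph rejects an empty input payload; all real data travels in the `RunnableConfig`. A minimal sketch of that pattern (the node and metadata key are illustrative, not taken from the repo):

```python
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel


class BasicInput(BaseModel):
    # Langgraph needs a nonempty input, but we pass in all static
    # data through a RunnableConfig.
    _unused: bool = True


def example_node(state: BasicInput, config: RunnableConfig) -> dict:
    # Static data rides along in config["metadata"], mirroring how
    # run_graph passes {"metadata": {"config": ..., "thread_id": ...}}.
    agent_config = config["metadata"]["config"]  # illustrative key
    return {}
```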
@@ -16,17 +16,9 @@ from onyx.utils.logger import setup_logger

 logger = setup_logger()

-# TODO: handle citations here; below is what was previously passed in
-# see basic_use_tool_response.py for where these variables come from
-# answer_handler = CitationResponseHandler(
-#     context_docs=final_search_results,
-#     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
-#     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
-# )

 def process_llm_stream(
-    stream: Iterator[BaseMessage],
+    messages: Iterator[BaseMessage],
     should_stream_answer: bool,
     final_search_results: list[LlmDoc] | None = None,
     displayed_search_results: list[LlmDoc] | None = None,
@@ -46,20 +38,20 @@ def process_llm_stream(
     full_answer = ""
     # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
     # the stream will contain AIMessageChunks with tool call information.
-    for response in stream:
-        answer_piece = response.content
+    for message in messages:
+        answer_piece = message.content
         if not isinstance(answer_piece, str):
             # TODO: handle non-string content
             logger.warning(f"Received non-string content: {type(answer_piece)}")
             # this is only used for logging, so fine to
             # just add the string representation
             answer_piece = str(answer_piece)
         full_answer += answer_piece

-        if isinstance(response, AIMessageChunk) and (
-            response.tool_call_chunks or response.tool_calls
+        if isinstance(message, AIMessageChunk) and (
+            message.tool_call_chunks or message.tool_calls
         ):
-            tool_call_chunk += response  # type: ignore
+            tool_call_chunk += message  # type: ignore
         elif should_stream_answer:
-            for response_part in answer_handler.handle_response_part(response, []):
+            for response_part in answer_handler.handle_response_part(message, []):
                 dispatch_custom_event(
                     "basic_response",
                     response_part,
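The `tool_call_chunk += message` accumulation relies on `AIMessageChunk` overloading `+` to merge streamed fragments. A standalone sketch, assuming `langchain-core` is installed (the values are made up):

```python
from langchain_core.messages import AIMessageChunk

# AIMessageChunk implements __add__, so streamed fragments fold into one
# message; content and tool_call_chunks are concatenated/merged.
merged = AIMessageChunk(content="")
for chunk in (AIMessageChunk(content="Hel"), AIMessageChunk(content="lo")):
    merged += chunk
print(merged.content)  # "Hello"
```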
@@ -87,6 +87,7 @@ def generate_initial_answer(

     # Use the query info from the base document retrieval
     query_info = get_query_info(state.original_question_retrieval_results)

+    if agent_a_config.search_tool is None:
+        raise ValueError("search_tool must be provided for agentic search")
     for tool_response in yield_search_responses(
@@ -10,6 +10,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
     QuestionAnswerResults,
 )
 from onyx.chat.models import SubQuestionPiece
+from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
 from onyx.utils.logger import setup_logger
@@ -141,5 +142,10 @@ def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
         result.query_info for result in results if result.query_info is not None
     ]
     if len(query_infos) == 0:
-        raise ValueError("No query info found")
+        return SearchQueryInfo(
+            predicted_search=None,
+            final_filters=IndexFilters(access_control_list=None),
+            recency_bias_multiplier=1.0,
+        )
     return query_infos[0]
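With this change, `get_query_info` degrades to a permissive default (no ACL filter, neutral recency bias) instead of raising when no result carries query info. Illustrative only, using the names visible in the diff:

```python
# Sketch: an empty result list now yields an open-filter default rather
# than a ValueError, so downstream search-response handling keeps working.
info = get_query_info([])
assert info.predicted_search is None
assert info.final_filters.access_control_list is None
assert info.recency_bias_multiplier == 1.0
```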
@@ -26,6 +26,8 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import SubQueryPiece
 from onyx.chat.models import SubQuestionPiece
 from onyx.chat.models import ToolResponse
+from onyx.configs.agent_configs import ALLOW_REFINEMENT
+from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
 from onyx.context.search.models import SearchRequest
 from onyx.db.engine import get_session_context_manager
 from onyx.tools.tool_runner import ToolCallKickoff
@@ -58,7 +60,6 @@ def _parse_agent_event(
     # 1. It's a list of the names of every place we dispatch a custom event
     # 2. We maintain the intended types yielded by each event
     if event_type == "on_custom_event":
-        # TODO: different AnswerStream types for different events
         if event["name"] == "decomp_qs":
             return cast(SubQuestionPiece, event["data"])
         elif event["name"] == "subqueries":
@@ -133,15 +134,30 @@ def _manage_async_event_streaming(
     return _yield_async_to_sync()


+def manage_sync_streaming(
+    compiled_graph: CompiledStateGraph,
+    config: AgentSearchConfig,
+    graph_input: BasicInput | MainInput_a,
+) -> Iterable[StreamEvent]:
+    message_id = config.message_id if config else None
+    for event in compiled_graph.stream(
+        stream_mode="custom",
+        input=graph_input,
+        config={"metadata": {"config": config, "thread_id": str(message_id)}},
+        # debug=True,
+    ):
+        print(event)
+
+    return []
+
+
 def run_graph(
     compiled_graph: CompiledStateGraph,
     config: AgentSearchConfig,
     input: BasicInput | MainInput_a,
 ) -> AnswerStream:
-    # TODO: add these to the environment
-    # config.perform_initial_search_path_decision = False
-    config.perform_initial_search_decomposition = True
-    config.allow_refinement = True
+    config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
+    config.allow_refinement = ALLOW_REFINEMENT

     for event in _manage_async_event_streaming(
         compiled_graph=compiled_graph, config=config, graph_input=input
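Both streaming paths hinge on LangChain's custom-event plumbing: nodes call `dispatch_custom_event(name, payload)`, and the consumer sees them as `on_custom_event` entries (the names `decomp_qs`, `subqueries`, `basic_response` above). A minimal end-to-end sketch, assuming langgraph >= 0.2 and langchain-core >= 0.2.15; the toy graph and state are not the repo's:

```python
import asyncio
from typing import TypedDict

from langchain_core.callbacks import adispatch_custom_event
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph


class State(TypedDict):
    logs: str


async def node(state: State, config: RunnableConfig) -> State:
    # Surfaces to consumers as an "on_custom_event" event named "basic_response".
    await adispatch_custom_event("basic_response", {"piece": "hello"}, config=config)
    return state


builder = StateGraph(State)
builder.add_node("node", node)
builder.set_entry_point("node")
builder.set_finish_point("node")
graph = builder.compile()


async def main() -> None:
    async for event in graph.astream_events({"logs": ""}, version="v2"):
        if event["event"] == "on_custom_event":
            print(event["name"], event["data"])


asyncio.run(main())
```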
@@ -152,8 +168,8 @@ def run_graph(
         yield parsed_object


-# TODO: call this once on startup, TBD where and if it should be gated based
-# on dev mode or not
+# It doesn't actually take very long to load the graph, but we'd rather
+# not compile it again on every request.
 def load_compiled_graph() -> CompiledStateGraph:
     global _COMPILED_GRAPH
     if _COMPILED_GRAPH is None:
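The hunk above shows a lazy, module-level cache for the compiled graph. A hedged reconstruction of the full pattern based only on the lines visible here (the builder name is an assumption):

```python
_COMPILED_GRAPH: CompiledStateGraph | None = None


def load_compiled_graph() -> CompiledStateGraph:
    # Compile once per process and reuse; compilation is cheap, but there
    # is no reason to repeat it on every request.
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is None:
        _COMPILED_GRAPH = main_graph_builder().compile()  # assumed builder
    return _COMPILED_GRAPH
```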
@@ -176,13 +192,11 @@ def run_main_graph(
     yield from run_graph(compiled_graph, config, input)


-# TODO: unify input types, especially prosearchconfig
 def run_basic_graph(
     config: AgentSearchConfig,
 ) -> AnswerStream:
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
-    # TODO: unify basic input
     input = BasicInput()
     return run_graph(compiled_graph, config, input)
@@ -1,5 +1,8 @@
 import os

+INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
+ALLOW_REFINEMENT = True
+
 AGENT_DEFAULT_RETRIEVAL_HITS = 15
 AGENT_DEFAULT_RERANKING_HITS = 10
 AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
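These flags are plain constants for now, but the TODO removed in `run_graph` ("add these to the environment") and the existing `import os` suggest env-var overrides are the intended end state. A hypothetical variant, not part of this commit (the variable names are invented for illustration):

```python
import os

# Hypothetical: read the new flags from the environment, keeping the
# current values as defaults.
INITIAL_SEARCH_DECOMPOSITION_ENABLED = (
    os.environ.get("INITIAL_SEARCH_DECOMPOSITION_ENABLED", "true").lower() == "true"
)
ALLOW_REFINEMENT = os.environ.get("ALLOW_REFINEMENT", "true").lower() == "true"
```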