diff --git a/.gitignore b/.gitignore index b97fb309d796..24739991f22f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,6 @@ .vscode/ *.sw? /backend/tests/regression/answer_quality/search_test_config.yaml -/web/test-results/ \ No newline at end of file +/web/test-results/ +backend/onyx/agent_search/main/test_data.json +backend/tests/regression/answer_quality/test_data.json diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index 21b13baa6ea1..496731b34941 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -52,3 +52,9 @@ BING_API_KEY= # Enable the full set of Danswer Enterprise Edition features # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development) ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False + +# Agent Search configs # TODO: Remove give proper namings +AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort +AGENT_RERANKING_STATS=True +AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20 +AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20 diff --git a/backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py b/backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py new file mode 100644 index 000000000000..62db727f9778 --- /dev/null +++ b/backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py @@ -0,0 +1,29 @@ +"""agent_doc_result_col + +Revision ID: 1adf5ea20d2b +Revises: e9cf2bd7baed +Create Date: 2025-01-05 13:14:58.344316 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "1adf5ea20d2b" +down_revision = "e9cf2bd7baed" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add the new column with JSONB type + op.add_column( + "sub_question", + sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True), + ) + + +def downgrade() -> None: + # Drop the column + op.drop_column("sub_question", "sub_question_doc_results") diff --git a/backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py b/backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py new file mode 100644 index 000000000000..6bf5016084b3 --- /dev/null +++ b/backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py @@ -0,0 +1,35 @@ +"""agent_metric_col_rename__s + +Revision ID: 925b58bd75b6 +Revises: 9787be927e58 +Create Date: 2025-01-06 11:20:26.752441 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. 
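# Orientation note: this revision follows the agent__ table rename (revision 9787be927e58) and
# brings the two duration columns on agent__search_metrics into the same double-underscore
# naming style used for the renamed agent__ tables: base_duration_s -> base_duration__s and
# full_duration_s -> full_duration__s (see the alter_column calls below).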
+revision = "925b58bd75b6" +down_revision = "9787be927e58" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Rename columns using PostgreSQL syntax + op.alter_column( + "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s" + ) + op.alter_column( + "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s" + ) + + +def downgrade() -> None: + # Revert the column renames + op.alter_column( + "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s" + ) + op.alter_column( + "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s" + ) diff --git a/backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py b/backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py new file mode 100644 index 000000000000..2b605f5b3d2a --- /dev/null +++ b/backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py @@ -0,0 +1,25 @@ +"""agent_metric_table_renames__agent__ + +Revision ID: 9787be927e58 +Revises: bceb76d618ec +Create Date: 2025-01-06 11:01:44.210160 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "9787be927e58" +down_revision = "bceb76d618ec" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Rename table from agent_search_metrics to agent__search_metrics + op.rename_table("agent_search_metrics", "agent__search_metrics") + + +def downgrade() -> None: + # Rename table back from agent__search_metrics to agent_search_metrics + op.rename_table("agent__search_metrics", "agent_search_metrics") diff --git a/backend/alembic/versions/98a5008d8711_agent_tracking.py b/backend/alembic/versions/98a5008d8711_agent_tracking.py new file mode 100644 index 000000000000..f164d49d5cb0 --- /dev/null +++ b/backend/alembic/versions/98a5008d8711_agent_tracking.py @@ -0,0 +1,42 @@ +"""agent_tracking + +Revision ID: 98a5008d8711 +Revises: 027381bce97c +Create Date: 2025-01-04 14:41:52.732238 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "98a5008d8711" +down_revision = "027381bce97c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "agent_search_metrics", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("persona_id", sa.Integer(), nullable=True), + sa.Column("agent_type", sa.String(), nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("base_duration_s", sa.Float(), nullable=False), + sa.Column("full_duration_s", sa.Float(), nullable=False), + sa.Column("base_metrics", postgresql.JSONB(), nullable=True), + sa.Column("refined_metrics", postgresql.JSONB(), nullable=True), + sa.Column("all_metrics", postgresql.JSONB(), nullable=True), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("agent_search_metrics") diff --git a/backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py b/backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py new file mode 100644 index 000000000000..1c1cb2e0d846 --- /dev/null +++ b/backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py @@ -0,0 +1,84 @@ +"""agent_table_renames__agent__ + +Revision ID: bceb76d618ec +Revises: c0132518a25b +Create Date: 2025-01-06 10:50:48.109285 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "bceb76d618ec" +down_revision = "c0132518a25b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_constraint( + "sub_query__search_doc_sub_query_id_fkey", + "sub_query__search_doc", + type_="foreignkey", + ) + op.drop_constraint( + "sub_query__search_doc_search_doc_id_fkey", + "sub_query__search_doc", + type_="foreignkey", + ) + # Rename tables + op.rename_table("sub_query", "agent__sub_query") + op.rename_table("sub_question", "agent__sub_question") + op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc") + + # Update both foreign key constraints for agent__sub_query__search_doc + + # Create new foreign keys with updated names + op.create_foreign_key( + "agent__sub_query__search_doc_sub_query_id_fkey", + "agent__sub_query__search_doc", + "agent__sub_query", + ["sub_query_id"], + ["id"], + ) + op.create_foreign_key( + "agent__sub_query__search_doc_search_doc_id_fkey", + "agent__sub_query__search_doc", + "search_doc", # This table name doesn't change + ["search_doc_id"], + ["id"], + ) + + +def downgrade() -> None: + # Update foreign key constraints for sub_query__search_doc + op.drop_constraint( + "agent__sub_query__search_doc_sub_query_id_fkey", + "agent__sub_query__search_doc", + type_="foreignkey", + ) + op.drop_constraint( + "agent__sub_query__search_doc_search_doc_id_fkey", + "agent__sub_query__search_doc", + type_="foreignkey", + ) + + # Rename tables back + op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc") + op.rename_table("agent__sub_question", "sub_question") + op.rename_table("agent__sub_query", "sub_query") + + op.create_foreign_key( + "sub_query__search_doc_sub_query_id_fkey", + "sub_query__search_doc", + "sub_query", + ["sub_query_id"], + ["id"], + ) + op.create_foreign_key( + "sub_query__search_doc_search_doc_id_fkey", + "sub_query__search_doc", + "search_doc", # This table name doesn't change + ["search_doc_id"], + ["id"], + ) diff --git 
a/backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py b/backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py new file mode 100644 index 000000000000..e845380991f3 --- /dev/null +++ b/backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py @@ -0,0 +1,40 @@ +"""agent_table_changes_rename_level + +Revision ID: c0132518a25b +Revises: 1adf5ea20d2b +Create Date: 2025-01-05 16:38:37.660152 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "c0132518a25b" +down_revision = "1adf5ea20d2b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add level and level_question_nr columns with NOT NULL constraint + op.add_column( + "sub_question", + sa.Column("level", sa.Integer(), nullable=False, server_default="0"), + ) + op.add_column( + "sub_question", + sa.Column( + "level_question_nr", sa.Integer(), nullable=False, server_default="0" + ), + ) + + # Remove the server_default after the columns are created + op.alter_column("sub_question", "level", server_default=None) + op.alter_column("sub_question", "level_question_nr", server_default=None) + + +def downgrade() -> None: + # Remove the columns + op.drop_column("sub_question", "level_question_nr") + op.drop_column("sub_question", "level") diff --git a/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py b/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py new file mode 100644 index 000000000000..6ed9d783f502 --- /dev/null +++ b/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py @@ -0,0 +1,68 @@ +"""create pro search persistence tables + +Revision ID: e9cf2bd7baed +Revises: 98a5008d8711 +Create Date: 2025-01-02 17:55:56.544246 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. 
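The three tables created in upgrade() below are written through the helpers added in backend/onyx/agents/agent_search/db_operations.py later in this diff. A minimal usage sketch, assuming an open SQLAlchemy Session plus an existing chat session and chat message (the literal query strings are illustrative only):

from onyx.agents.agent_search.db_operations import create_sub_query
from onyx.agents.agent_search.db_operations import create_sub_question
from onyx.agents.agent_search.db_operations import get_sub_questions_for_message

# db_session, chat_session_id and primary_message_id are assumed to already exist
sub_question = create_sub_question(
    db_session=db_session,
    chat_session_id=chat_session_id,
    primary_message_id=primary_message_id,
    sub_question="What connectors does Onyx support?",  # illustrative
    sub_answer="...",
)
create_sub_query(
    db_session=db_session,
    chat_session_id=chat_session_id,
    parent_question_id=sub_question.id,
    sub_query="onyx supported connectors",  # illustrative
)
db_session.commit()  # the helpers only flush; committing is left to the caller

# rows hang off the chat message: chat_message 1->N sub_question 1->N sub_query N<->N search_doc
assert get_sub_questions_for_message(db_session, primary_message_id)[0].id == sub_question.id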
+revision = "e9cf2bd7baed" +down_revision = "98a5008d8711" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create sub_question table + op.create_table( + "sub_question", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")), + sa.Column( + "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id") + ), + sa.Column("sub_question", sa.Text), + sa.Column( + "time_created", sa.DateTime(timezone=True), server_default=sa.func.now() + ), + sa.Column("sub_answer", sa.Text), + ) + + # Create sub_query table + op.create_table( + "sub_query", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")), + sa.Column( + "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id") + ), + sa.Column("sub_query", sa.Text), + sa.Column( + "time_created", sa.DateTime(timezone=True), server_default=sa.func.now() + ), + ) + + # Create sub_query__search_doc association table + op.create_table( + "sub_query__search_doc", + sa.Column( + "sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True + ), + sa.Column( + "search_doc_id", + sa.Integer, + sa.ForeignKey("search_doc.id"), + primary_key=True, + ), + ) + + +def downgrade() -> None: + op.drop_table("sub_query__search_doc") + op.drop_table("sub_query") + op.drop_table("sub_question") diff --git a/backend/ee/onyx/server/query_and_chat/chat_backend.py b/backend/ee/onyx/server/query_and_chat/chat_backend.py index 0a29ed003443..5a3ba29024e3 100644 --- a/backend/ee/onyx/server/query_and_chat/chat_backend.py +++ b/backend/ee/onyx/server/query_and_chat/chat_backend.py @@ -179,6 +179,7 @@ def handle_simplified_chat_message( chunks_below=0, full_doc=chat_message_req.full_doc, structured_response_format=chat_message_req.structured_response_format, + use_agentic_search=chat_message_req.use_agentic_search, ) packets = stream_chat_message_objects( @@ -301,6 +302,7 @@ def handle_send_message_simple_with_history( chunks_below=0, full_doc=req.full_doc, structured_response_format=req.structured_response_format, + use_agentic_search=req.use_agentic_search, ) packets = stream_chat_message_objects( diff --git a/backend/ee/onyx/server/query_and_chat/models.py b/backend/ee/onyx/server/query_and_chat/models.py index 4726236e01f8..eb17489fa713 100644 --- a/backend/ee/onyx/server/query_and_chat/models.py +++ b/backend/ee/onyx/server/query_and_chat/models.py @@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext): # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None + # If True, uses agentic search instead of basic search + use_agentic_search: bool = False + class BasicCreateChatMessageWithHistoryRequest(ChunkContext): # Last element is the new query. All previous elements are historical context @@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext): # only works if using an OpenAI model. 
See the following for more details: # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None + # If True, uses agentic search instead of basic search + use_agentic_search: bool = False class SimpleDoc(BaseModel): @@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext): # If True, skips generative an AI response to the search query skip_gen_ai_answer_generation: bool = False + # If True, uses pro search instead of basic search + use_agentic_search: bool = False + @model_validator(mode="after") def check_persona_fields(self) -> "OneShotQARequest": if self.persona_override_config is None and self.persona_id is None: diff --git a/backend/ee/onyx/server/query_and_chat/query_backend.py b/backend/ee/onyx/server/query_and_chat/query_backend.py index b8e7abd3e4a3..cb9003b7eff5 100644 --- a/backend/ee/onyx/server/query_and_chat/query_backend.py +++ b/backend/ee/onyx/server/query_and_chat/query_backend.py @@ -196,6 +196,7 @@ def get_answer_stream( retrieval_details=query_request.retrieval_options, rerank_settings=query_request.rerank_settings, db_session=db_session, + use_agentic_search=query_request.use_agentic_search, ) packets = stream_chat_message_objects( diff --git a/backend/onyx/agents/agent_search/basic/graph_builder.py b/backend/onyx/agents/agent_search/basic/graph_builder.py new file mode 100644 index 000000000000..4114ef036a77 --- /dev/null +++ b/backend/onyx/agents/agent_search/basic/graph_builder.py @@ -0,0 +1,103 @@ +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.runnables.config import RunnableConfig +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agents.agent_search.basic.states import BasicInput +from onyx.agents.agent_search.basic.states import BasicOutput +from onyx.agents.agent_search.basic.states import BasicState +from onyx.agents.agent_search.basic.states import BasicStateUpdate +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.chat.stream_processing.utils import ( + map_document_id_order, +) +from onyx.tools.tool_implementations.search.search_tool import SearchTool + + +def basic_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=BasicState, + input=BasicInput, + output=BasicOutput, + ) + + ### Add nodes ### + + graph.add_node( + node="get_response", + action=get_response, + ) + + ### Add edges ### + + graph.add_edge(start_key=START, end_key="get_response") + + graph.add_conditional_edges("get_response", should_continue, ["get_response", END]) + graph.add_edge( + start_key="get_response", + end_key=END, + ) + + return graph + + +def should_continue(state: BasicState) -> str: + return ( + END if state["last_llm_call"] is None or state["calls"] > 1 else "get_response" + ) + + +def get_response(state: BasicState, config: RunnableConfig) -> BasicStateUpdate: + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + llm = agent_a_config.primary_llm + current_llm_call = state["last_llm_call"] + if current_llm_call is None: + raise ValueError("last_llm_call is None") + structured_response_format = agent_a_config.structured_response_format + response_handler_manager = state["response_handler_manager"] + # DEBUG: good breakpoint + stream = llm.stream( + # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM + # may choose to not call any tools and just generate the answer, in 
which case the task prompt is needed. + prompt=current_llm_call.prompt_builder.build(), + tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, + tool_choice=( + "required" + if current_llm_call.tools and current_llm_call.force_use_tool.force_use + else None + ), + structured_response_format=structured_response_format, + ) + + for response in response_handler_manager.handle_llm_response(stream): + dispatch_custom_event( + "basic_response", + response, + ) + + next_call = response_handler_manager.next_llm_call(current_llm_call) + if next_call is not None: + final_search_results, displayed_search_results = SearchTool.get_search_result( + next_call + ) or ([], []) + else: + final_search_results, displayed_search_results = [], [] + + response_handler_manager.answer_handler.update( + ( + final_search_results, + map_document_id_order(final_search_results), + map_document_id_order(displayed_search_results), + ) + ) + return BasicStateUpdate( + last_llm_call=next_call, + calls=state["calls"] + 1, + ) + + +if __name__ == "__main__": + pass diff --git a/backend/onyx/agents/agent_search/basic/states.py b/backend/onyx/agents/agent_search/basic/states.py new file mode 100644 index 000000000000..683c307f7460 --- /dev/null +++ b/backend/onyx/agents/agent_search/basic/states.py @@ -0,0 +1,38 @@ +from typing import TypedDict + +from onyx.chat.llm_response_handler import LLMResponseHandlerManager +from onyx.chat.prompt_builder.build import LLMCall + +## Update States + + +## Graph Input State + + +class BasicInput(TypedDict): + base_question: str + last_llm_call: LLMCall | None + response_handler_manager: LLMResponseHandlerManager + calls: int + + +## Graph Output State + + +class BasicOutput(TypedDict): + pass + + +class BasicStateUpdate(TypedDict): + last_llm_call: LLMCall | None + calls: int + + +## Graph State + + +class BasicState( + BasicInput, + BasicOutput, +): + pass diff --git a/backend/onyx/agents/agent_search/core_state.py b/backend/onyx/agents/agent_search/core_state.py new file mode 100644 index 000000000000..693356b51391 --- /dev/null +++ b/backend/onyx/agents/agent_search/core_state.py @@ -0,0 +1,20 @@ +from operator import add +from typing import Annotated +from typing import TypedDict + + +class CoreState(TypedDict, total=False): + """ + This is the core state that is shared across all subgraphs. + """ + + base_question: str + log_messages: Annotated[list[str], add] + + +class SubgraphCoreState(TypedDict, total=False): + """ + This is the core state that is shared across all subgraphs. 
+ """ + + log_messages: Annotated[list[str], add] diff --git a/backend/onyx/agents/agent_search/db_operations.py b/backend/onyx/agents/agent_search/db_operations.py new file mode 100644 index 000000000000..3df137b1140f --- /dev/null +++ b/backend/onyx/agents/agent_search/db_operations.py @@ -0,0 +1,66 @@ +from uuid import UUID + +from sqlalchemy.orm import Session + +from onyx.db.models import AgentSubQuery +from onyx.db.models import AgentSubQuestion + + +def create_sub_question( + db_session: Session, + chat_session_id: UUID, + primary_message_id: int, + sub_question: str, + sub_answer: str, +) -> AgentSubQuestion: + """Create a new sub-question record in the database.""" + sub_q = AgentSubQuestion( + chat_session_id=chat_session_id, + primary_question_id=primary_message_id, + sub_question=sub_question, + sub_answer=sub_answer, + ) + db_session.add(sub_q) + db_session.flush() + return sub_q + + +def create_sub_query( + db_session: Session, + chat_session_id: UUID, + parent_question_id: int, + sub_query: str, +) -> AgentSubQuery: + """Create a new sub-query record in the database.""" + sub_q = AgentSubQuery( + chat_session_id=chat_session_id, + parent_question_id=parent_question_id, + sub_query=sub_query, + ) + db_session.add(sub_q) + db_session.flush() + return sub_q + + +def get_sub_questions_for_message( + db_session: Session, + primary_message_id: int, +) -> list[AgentSubQuestion]: + """Get all sub-questions for a given primary message.""" + return ( + db_session.query(AgentSubQuestion) + .filter(AgentSubQuestion.primary_question_id == primary_message_id) + .all() + ) + + +def get_sub_queries_for_question( + db_session: Session, + sub_question_id: int, +) -> list[AgentSubQuery]: + """Get all sub-queries for a given sub-question.""" + return ( + db_session.query(AgentSubQuery) + .filter(AgentSubQuery.parent_question_id == sub_question_id) + .all() + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py new file mode 100644 index 000000000000..a67e508f7739 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/edges.py @@ -0,0 +1,26 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: + logger.debug("sending to expanded retrieval via edge") + + return Send( + "initial_sub_question_expanded_retrieval", + ExpandedRetrievalInput( + question=state["question"], + base_search=False, + sub_question_id=state["question_id"], + ), + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py new file mode 100644 index 000000000000..5af71933facd --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/graph_builder.py @@ -0,0 +1,125 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.edges import ( + 
send_to_expanded_retrieval, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import ( + answer_check, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import ( + answer_generation, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import ( + format_answer, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import ( + ingest_retrieval, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) +from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def answer_query_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=AnswerQuestionState, + input=AnswerQuestionInput, + output=AnswerQuestionOutput, + ) + + ### Add nodes ### + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="initial_sub_question_expanded_retrieval", + action=expanded_retrieval, + ) + graph.add_node( + node="answer_check", + action=answer_check, + ) + graph.add_node( + node="answer_generation", + action=answer_generation, + ) + graph.add_node( + node="format_answer", + action=format_answer, + ) + graph.add_node( + node="ingest_retrieval", + action=ingest_retrieval, + ) + + ### Add edges ### + + graph.add_conditional_edges( + source=START, + path=send_to_expanded_retrieval, + path_map=["initial_sub_question_expanded_retrieval"], + ) + graph.add_edge( + start_key="initial_sub_question_expanded_retrieval", + end_key="ingest_retrieval", + ) + graph.add_edge( + start_key="ingest_retrieval", + end_key="answer_generation", + ) + graph.add_edge( + start_key="answer_generation", + end_key="answer_check", + ) + graph.add_edge( + start_key="answer_check", + end_key="format_answer", + ) + graph.add_edge( + start_key="format_answer", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = answer_query_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="what can you do with onyx or danswer?", + ) + with get_session_context_manager() as db_session: + agent_search_config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) + inputs = AnswerQuestionInput( + question="what can you do with onyx?", + question_id="0_0", + ) + for thing in compiled_graph.stream( + input=inputs, + config={"configurable": {"config": agent_search_config}}, + # debug=True, + # subgraphs=True, + ): + logger.debug(thing) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/models.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/models.py new file mode 100644 index 000000000000..361169e59bec --- /dev/null +++ 
b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/models.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +### Models ### + + +class AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py new file mode 100644 index 000000000000..6fe5e7a9e690 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_check.py @@ -0,0 +1,45 @@ +from typing import cast + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + QACheckUpdate, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO +from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT +from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER + + +def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate: + if state["answer"] == UNKNOWN_ANSWER: + return QACheckUpdate( + answer_quality=SUB_CHECK_NO, + ) + msg = [ + HumanMessage( + content=SUB_CHECK_PROMPT.format( + question=state["question"], + base_answer=state["answer"], + ) + ) + ] + + agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"]) + fast_llm = agent_searchch_config.fast_llm + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + quality_str = merge_message_runs(response, chunk_separator="")[0].content + + return QACheckUpdate( + answer_quality=quality_str, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py new file mode 100644 index 000000000000..fa482f58a4b0 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py @@ -0,0 +1,111 @@ +import datetime +from typing import Any +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.messages import merge_message_runs +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + QAGenerationUpdate, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( + build_sub_question_answer_prompt, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_DEFAULT, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_PERSONA, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER +from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id +from 
onyx.chat.models import AgentAnswerPiece +from onyx.chat.models import StreamStopInfo +from onyx.chat.models import StreamStopReason +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def answer_generation( + state: AnswerQuestionState, config: RunnableConfig +) -> QAGenerationUpdate: + now_start = datetime.datetime.now() + logger.debug(f"--------{now_start}--------START ANSWER GENERATION---") + + agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = state["question"] + docs = state["documents"] + level, question_nr = parse_question_id(state["question_id"]) + persona_prompt = get_persona_prompt(agent_search_config.search_request.persona) + + if len(docs) == 0: + answer_str = UNKNOWN_ANSWER + dispatch_custom_event( + "sub_answers", + AgentAnswerPiece( + answer_piece=answer_str, + level=level, + level_question_nr=question_nr, + answer_type="agent_sub_answer", + ), + ) + else: + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) + + logger.debug(f"Number of verified retrieval docs: {len(docs)}") + + fast_llm = agent_search_config.fast_llm + msg = build_sub_question_answer_prompt( + question=question, + original_question=agent_search_config.search_request.query, + docs=docs, + persona_specification=persona_specification, + config=fast_llm.config, + ) + + response: list[str | list[str | dict[str, Any]]] = [] + for message in fast_llm.stream( + prompt=msg, + ): + # TODO: in principle, the answer here COULD contain images, but we don't support that yet + content = message.content + if not isinstance(content, str): + raise ValueError( + f"Expected content to be a string, but got {type(content)}" + ) + dispatch_custom_event( + "sub_answers", + AgentAnswerPiece( + answer_piece=content, + level=level, + level_question_nr=question_nr, + answer_type="agent_sub_answer", + ), + ) + response.append(content) + + answer_str = merge_message_runs(response, chunk_separator="")[0].content + + stop_event = StreamStopInfo( + stop_reason=StreamStopReason.FINISHED, + stream_type="sub_answer", + level=level, + level_question_nr=question_nr, + ) + dispatch_custom_event("stream_finished", stop_event) + + return QAGenerationUpdate( + answer=answer_str, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py new file mode 100644 index 000000000000..a82cf83bdaa0 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/format_answer.py @@ -0,0 +1,25 @@ +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agents.agent_search.shared_graph_utils.models import ( + QuestionAnswerResults, +) + + +def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: + return AnswerQuestionOutput( + answer_results=[ + QuestionAnswerResults( + question=state["question"], + question_id=state["question_id"], + quality=state.get("answer_quality", "No"), + answer=state["answer"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + sub_question_retrieval_stats=state["sub_question_retrieval_stats"], + )
+ ], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py new file mode 100644 index 000000000000..3b4c305c271d --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py @@ -0,0 +1,23 @@ +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + RetrievalIngestionUpdate, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalOutput, +) +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats + + +def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: + sub_question_retrieval_stats = state[ + "expanded_retrieval_result" + ].sub_question_retrieval_stats + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = AgentChunkStats() + + return RetrievalIngestionUpdate( + expanded_retrieval_results=state[ + "expanded_retrieval_result" + ].expanded_queries_results, + documents=state["expanded_retrieval_result"].all_documents, + sub_question_retrieval_stats=sub_question_retrieval_stats, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py new file mode 100644 index 000000000000..98f464dceec8 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/states.py @@ -0,0 +1,65 @@ +from operator import add +from typing import Annotated +from typing import TypedDict + +from onyx.agents.agent_search.core_state import SubgraphCoreState +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.agents.agent_search.shared_graph_utils.models import ( + QuestionAnswerResults, +) +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) +from onyx.context.search.models import InferenceSection + + +## Update States +class QACheckUpdate(TypedDict): + answer_quality: str + + +class QAGenerationUpdate(TypedDict): + answer: str + # answer_stat: AnswerStats + + +class RetrievalIngestionUpdate(TypedDict): + expanded_retrieval_results: list[QueryResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: AgentChunkStats + + +## Graph Input State + + +class AnswerQuestionInput(SubgraphCoreState): + question: str + question_id: str # 0_0 is original question, everything else is <level>_<question_num>. + # level 0 is original question and first decomposition, level 1 is follow up, etc + # question_num is a unique number per original question per level. + + +## Graph State + + +class AnswerQuestionState( + AnswerQuestionInput, + QAGenerationUpdate, + QACheckUpdate, + RetrievalIngestionUpdate, +): + pass + + +## Graph Output State + + +class AnswerQuestionOutput(TypedDict): + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together.
+ """ + + answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py new file mode 100644 index 000000000000..2a5fdc148d05 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/edges.py @@ -0,0 +1,26 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable: + logger.debug("sending to expanded retrieval for follow up question via edge") + + return Send( + "refined_sub_question_expanded_retrieval", + ExpandedRetrievalInput( + question=state["question"], + sub_question_id=state["question_id"], + base_search=False, + ), + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py new file mode 100644 index 000000000000..3598c4dc689a --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/graph_builder.py @@ -0,0 +1,122 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import ( + answer_check, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import ( + answer_generation, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import ( + format_answer, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import ( + ingest_retrieval, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.edges import ( + send_to_expanded_refined_retrieval, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def answer_refined_query_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=AnswerQuestionState, + input=AnswerQuestionInput, + output=AnswerQuestionOutput, + ) + + ### Add nodes ### + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="refined_sub_question_expanded_retrieval", + action=expanded_retrieval, + ) + graph.add_node( + node="refined_sub_answer_check", + action=answer_check, + ) + graph.add_node( + node="refined_sub_answer_generation", + action=answer_generation, + ) + graph.add_node( + node="format_refined_sub_answer", + action=format_answer, + ) + graph.add_node( + node="ingest_refined_retrieval", + action=ingest_retrieval, + 
) + + ### Add edges ### + + graph.add_conditional_edges( + source=START, + path=send_to_expanded_refined_retrieval, + path_map=["refined_sub_question_expanded_retrieval"], + ) + graph.add_edge( + start_key="refined_sub_question_expanded_retrieval", + end_key="ingest_refined_retrieval", + ) + graph.add_edge( + start_key="ingest_refined_retrieval", + end_key="refined_sub_answer_generation", + ) + graph.add_edge( + start_key="refined_sub_answer_generation", + end_key="refined_sub_answer_check", + ) + graph.add_edge( + start_key="refined_sub_answer_check", + end_key="format_refined_sub_answer", + ) + graph.add_edge( + start_key="format_refined_sub_answer", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = answer_refined_query_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="what can you do with onyx or danswer?", + ) + with get_session_context_manager() as db_session: + inputs = AnswerQuestionInput( + question="what can you do with onyx?", + question_id="0_0", + ) + for thing in compiled_graph.stream( + input=inputs, + # debug=True, + # subgraphs=True, + ): + logger.debug(thing) + # output = compiled_graph.invoke(inputs) + # logger.debug(output) diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/models.py b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/models.py new file mode 100644 index 000000000000..5ef251eb0a02 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_refinement_sub_question/models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.context.search.models import InferenceSection + +### Models ### + + +class AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] + + +class QuestionAnswerResults(BaseModel): + question: str + answer: str + quality: str + # expanded_retrieval_results: list[QueryResult] + documents: list[InferenceSection] + sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/graph_builder.py new file mode 100644 index 000000000000..2d18e4f94368 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/graph_builder.py @@ -0,0 +1,76 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.format_raw_search_results import ( + format_raw_search_results, +) +from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.generate_raw_search_data import ( + generate_raw_search_data, +) +from onyx.agents.agent_search.deep_search_a.base_raw_search.states import ( + BaseRawSearchInput, +) +from onyx.agents.agent_search.deep_search_a.base_raw_search.states import ( + BaseRawSearchOutput, +) +from onyx.agents.agent_search.deep_search_a.base_raw_search.states import ( + BaseRawSearchState, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) + + +def base_raw_search_graph_builder() -> StateGraph: + graph = StateGraph( + 
state_schema=BaseRawSearchState, + input=BaseRawSearchInput, + output=BaseRawSearchOutput, + ) + + ### Add nodes ### + + graph.add_node( + node="generate_raw_search_data", + action=generate_raw_search_data, + ) + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="expanded_retrieval_base_search", + action=expanded_retrieval, + ) + graph.add_node( + node="format_raw_search_results", + action=format_raw_search_results, + ) + + ### Add edges ### + + graph.add_edge(start_key=START, end_key="generate_raw_search_data") + + graph.add_edge( + start_key="generate_raw_search_data", + end_key="expanded_retrieval_base_search", + ) + graph.add_edge( + start_key="expanded_retrieval_base_search", + end_key="format_raw_search_results", + ) + + # graph.add_edge( + # start_key="expanded_retrieval_base_search", + # end_key=END, + # ) + + graph.add_edge( + start_key="format_raw_search_results", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + pass diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/models.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/models.py new file mode 100644 index 000000000000..49f496b038df --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/models.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.context.search.models import InferenceSection + +### Models ### + + +class AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] + + +class QuestionAnswerResults(BaseModel): + question: str + answer: str + quality: str + expanded_retrieval_results: list[QueryResult] + documents: list[InferenceSection] + sub_question_retrieval_stats: list[AgentChunkStats] diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py new file mode 100644 index 000000000000..527a01323492 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/format_raw_search_results.py @@ -0,0 +1,18 @@ +from onyx.agents.agent_search.deep_search_a.base_raw_search.states import ( + BaseRawSearchOutput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalOutput, +) +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput: + logger.debug("format_raw_search_results") + return BaseRawSearchOutput( + base_expanded_retrieval_result=state["expanded_retrieval_result"], + # base_retrieval_results=[state["expanded_retrieval_result"]], + # base_search_documents=[], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py new file mode 100644 index 000000000000..22e69eee983b --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/nodes/generate_raw_search_data.py @@ -0,0 +1,24 @@ +from typing import cast + +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.core_state import CoreState +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states 
import ( + ExpandedRetrievalInput, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def generate_raw_search_data( + state: CoreState, config: RunnableConfig +) -> ExpandedRetrievalInput: + logger.debug("generate_raw_search_data") + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + return ExpandedRetrievalInput( + question=agent_a_config.search_request.query, + base_search=True, + sub_question_id=None, # This graph is always and only used for the original question + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py new file mode 100644 index 000000000000..90676e77e3eb --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/base_raw_search/states.py @@ -0,0 +1,43 @@ +from typing import TypedDict + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) + + +## Update States + + +## Graph Input State + + +class BaseRawSearchInput(ExpandedRetrievalInput): + pass + + +## Graph Output State + + +class BaseRawSearchOutput(TypedDict): + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. + """ + + # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections] + # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] + base_expanded_retrieval_result: ExpandedRetrievalResult + + +## Graph State + + +class BaseRawSearchState( + BaseRawSearchInput, + BaseRawSearchOutput, +): + pass diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py new file mode 100644 index 000000000000..6a2db1402a27 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/edges.py @@ -0,0 +1,34 @@ +from collections.abc import Hashable +from typing import cast + +from langchain_core.runnables.config import RunnableConfig +from langgraph.types import Send + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + RetrievalInput, +) +from onyx.agents.agent_search.models import AgentSearchConfig + + +def parallel_retrieval_edge( + state: ExpandedRetrievalState, config: RunnableConfig +) -> list[Send | Hashable]: + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = state.get("question", agent_a_config.search_request.query) + + query_expansions = state.get("expanded_queries", []) + [question] + return [ + Send( + "doc_retrieval", + RetrievalInput( + query_to_retrieve=query, + question=question, + base_search=False, + sub_question_id=state.get("sub_question_id"), + ), + ) + for query in query_expansions + ] diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py new file mode 100644 index 000000000000..1f0d88ececba --- /dev/null +++ 
b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/graph_builder.py @@ -0,0 +1,134 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.edges import ( + parallel_retrieval_edge, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_reranking import ( + doc_reranking, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retrieval import ( + doc_retrieval, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import ( + doc_verification, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import ( + expand_queries, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.format_results import ( + format_results, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.verification_kickoff import ( + verification_kickoff, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalOutput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def expanded_retrieval_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=ExpandedRetrievalState, + input=ExpandedRetrievalInput, + output=ExpandedRetrievalOutput, + ) + + ### Add nodes ### + + graph.add_node( + node="expand_queries", + action=expand_queries, + ) + + graph.add_node( + node="doc_retrieval", + action=doc_retrieval, + ) + graph.add_node( + node="verification_kickoff", + action=verification_kickoff, + ) + graph.add_node( + node="doc_verification", + action=doc_verification, + ) + graph.add_node( + node="doc_reranking", + action=doc_reranking, + ) + graph.add_node( + node="format_results", + action=format_results, + ) + + ### Add edges ### + graph.add_edge( + start_key=START, + end_key="expand_queries", + ) + + graph.add_conditional_edges( + source="expand_queries", + path=parallel_retrieval_edge, + path_map=["doc_retrieval"], + ) + graph.add_edge( + start_key="doc_retrieval", + end_key="verification_kickoff", + ) + graph.add_edge( + start_key="doc_verification", + end_key="doc_reranking", + ) + graph.add_edge( + start_key="doc_reranking", + end_key="format_results", + ) + graph.add_edge( + start_key="format_results", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = expanded_retrieval_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="what can you do with onyx or danswer?", + ) + + with get_session_context_manager() as db_session: + agent_a_config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) + inputs = ExpandedRetrievalInput( + question="what can you do with onyx?", + base_search=False, + sub_question_id=None, + ) + for thing in compiled_graph.stream( + input=inputs, + config={"configurable": {"config": agent_a_config}}, + # debug=True, + subgraphs=True, + ): + 
logger.debug(thing) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py new file mode 100644 index 000000000000..139f3311a099 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/models.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.context.search.models import InferenceSection + + +class ExpandedRetrievalResult(BaseModel): + expanded_queries_results: list[QueryResult] + all_documents: list[InferenceSection] + sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py new file mode 100644 index 000000000000..03a50af61580 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py @@ -0,0 +1,69 @@ +from typing import cast + +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + DocRerankingUpdate, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_RERANKING_STATS +from onyx.context.search.models import InferenceSection +from onyx.context.search.models import SearchRequest +from onyx.context.search.pipeline import retrieval_preprocessing +from onyx.context.search.postprocessing.postprocessing import rerank_sections +from onyx.db.engine import get_session_context_manager + + +def doc_reranking( + state: ExpandedRetrievalState, config: RunnableConfig +) -> DocRerankingUpdate: + verified_documents = state["verified_documents"] + + # Rerank post retrieval and verification. 
First, create a search query + # then create the list of reranked sections + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = state.get("question", agent_a_config.search_request.query) + with get_session_context_manager() as db_session: + _search_query = retrieval_preprocessing( + search_request=SearchRequest(query=question), + user=agent_a_config.search_tool.user, # bit of a hack + llm=agent_a_config.fast_llm, + db_session=db_session, + ) + + # skip section filtering + + if ( + _search_query.rerank_settings + and _search_query.rerank_settings.rerank_model_name + and _search_query.rerank_settings.num_rerank > 0 + ): + reranked_documents = rerank_sections( + _search_query, + verified_documents, + ) + else: + logger.warning("No reranking settings found, using unranked documents") + reranked_documents = verified_documents + + if AGENT_RERANKING_STATS: + fit_scores = get_fit_scores(verified_documents, reranked_documents) + else: + fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) + + # TODO: stream deduped docs here, or decide to use search tool ranking/verification + + return DocRerankingUpdate( + reranked_documents=[ + doc for doc in reranked_documents if type(doc) == InferenceSection + ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS], + sub_question_retrieval_stats=fit_scores, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py new file mode 100644 index 000000000000..899f1d677a4b --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py @@ -0,0 +1,93 @@ +from typing import cast + +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + DocRetrievalUpdate, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + RetrievalInput, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS +from onyx.context.search.models import InferenceSection +from onyx.db.engine import get_session_context_manager +from onyx.tools.models import SearchQueryInfo +from onyx.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary + + +def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate: + """ + Retrieve documents + + Args: + state (RetrievalInput): Primary state + the query to retrieve + config (RunnableConfig): Configuration containing ProSearchConfig + + Updates: + expanded_retrieval_results: list[ExpandedRetrievalResult] + retrieved_documents: list[InferenceSection] + """ + query_to_retrieve = state["query_to_retrieve"] + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + search_tool = agent_a_config.search_tool + + retrieved_docs: list[InferenceSection] = [] + if not query_to_retrieve.strip(): + logger.warning("Empty query, skipping retrieval") + return DocRetrievalUpdate( + 
expanded_retrieval_results=[], + retrieved_documents=[], + ) + + query_info = None + # new db session to avoid concurrency issues + with get_session_context_manager() as db_session: + for tool_response in search_tool.run( + query=query_to_retrieve, + force_no_rerank=True, + alternate_db_session=db_session, + ): + # get retrieved docs to send to the rest of the graph + if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: + response = cast(SearchResponseSummary, tool_response.response) + retrieved_docs = response.top_sections + query_info = SearchQueryInfo( + predicted_search=response.predicted_search, + final_filters=response.final_filters, + recency_bias_multiplier=response.recency_bias_multiplier, + ) + break + + retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] + pre_rerank_docs = retrieved_docs + if search_tool.search_pipeline is not None: + pre_rerank_docs = ( + search_tool.search_pipeline._retrieved_sections or retrieved_docs + ) + + if AGENT_RETRIEVAL_STATS: + fit_scores = get_fit_scores( + pre_rerank_docs, + retrieved_docs, + ) + else: + fit_scores = None + + expanded_retrieval_result = QueryResult( + query=query_to_retrieve, + search_results=retrieved_docs, + stats=fit_scores, + query_info=query_info, + ) + return DocRetrievalUpdate( + expanded_retrieval_results=[expanded_retrieval_result], + retrieved_documents=retrieved_docs, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py new file mode 100644 index 000000000000..fbd10b30d714 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_verification.py @@ -0,0 +1,60 @@ +from typing import cast + +from langchain_core.messages import HumanMessage +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + DocVerificationInput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + DocVerificationUpdate, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( + trim_prompt_piece, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT + + +def doc_verification( + state: DocVerificationInput, config: RunnableConfig +) -> DocVerificationUpdate: + """ + Check whether the document is relevant for the original user question + + Args: + state (DocVerificationInput): The current state + config (RunnableConfig): Configuration containing ProSearchConfig + + Updates: + verified_documents: list[InferenceSection] + """ + + question = state["question"] + doc_to_verify = state["doc_to_verify"] + document_content = doc_to_verify.combined_content + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + fast_llm = agent_a_config.fast_llm + + document_content = trim_prompt_piece( + fast_llm.config, document_content, VERIFIER_PROMPT + question + ) + + msg = [ + HumanMessage( + content=VERIFIER_PROMPT.format( + question=question, document_content=document_content + ) + ) + ] + + response = fast_llm.invoke(msg) + + verified_documents = [] + if isinstance(response.content, str) and "yes" in response.content.lower(): + verified_documents.append(doc_to_verify) + + return DocVerificationUpdate( + verified_documents=verified_documents, + ) diff --git 
a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py new file mode 100644 index 000000000000..b0b168d4c72d --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/expand_queries.py @@ -0,0 +1,59 @@ +from typing import cast + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import ( + dispatch_subquery, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + QueryExpansionUpdate, +) +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + REWRITE_PROMPT_MULTI_ORIGINAL, +) +from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated +from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id + + +def expand_queries( + state: ExpandedRetrievalInput, config: RunnableConfig +) -> QueryExpansionUpdate: + # Sometimes we want to expand the original question, sometimes we want to expand a sub-question. + # When we are running this node on the original question, no question is explicitly passed in. + # Instead, we use the original question from the search request. + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = state.get("question", agent_a_config.search_request.query) + llm = agent_a_config.fast_llm + chat_session_id = agent_a_config.chat_session_id + sub_question_id = state.get("sub_question_id") + if sub_question_id is None: + level, question_nr = 0, 0 + else: + level, question_nr = parse_question_id(sub_question_id) + + if chat_session_id is None: + raise ValueError("chat_session_id must be provided for agent search") + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), + ) + ] + + llm_response_list = dispatch_separated( + llm.stream(prompt=msg), dispatch_subquery(level, question_nr) + ) + + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + rewritten_queries = llm_response.split("\n") + + return QueryExpansionUpdate( + expanded_queries=rewritten_queries, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py new file mode 100644 index 000000000000..b3ebf5de99ea --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py @@ -0,0 +1,83 @@ +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.runnables.config import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import ( + calculate_sub_question_retrieval_stats, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalUpdate, +) +from onyx.agents.agent_search.models import
AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id +from onyx.chat.models import ExtendedToolResponse +from onyx.tools.tool_implementations.search.search_tool import yield_search_responses + + +def format_results( + state: ExpandedRetrievalState, config: RunnableConfig +) -> ExpandedRetrievalUpdate: + level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0") + query_infos = [ + result.query_info + for result in state["expanded_retrieval_results"] + if result.query_info is not None + ] + if len(query_infos) == 0: + raise ValueError("No query info found") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + # main question docs will be sent later after aggregation and deduping with sub-question docs + if not (level == 0 and question_nr == 0): + if len(state["reranked_documents"]) > 0: + stream_documents = state["reranked_documents"] + else: + # The sub-question is used as the last query. If no verified documents are found, stream + # the top 3 for that one. We may want to revisit this. + stream_documents = state["expanded_retrieval_results"][-1].search_results[ + :3 + ] + for tool_response in yield_search_responses( + query=state["question"], + reranked_sections=state[ + "retrieved_documents" + ], # TODO: rename params. this one is supposed to be the sections pre-merging + final_context_sections=stream_documents, + search_query_info=query_infos[0], # TODO: handle differing query infos? + get_section_relevance=lambda: None, # TODO: add relevance + search_tool=agent_a_config.search_tool, + ): + dispatch_custom_event( + "tool_response", + ExtendedToolResponse( + id=tool_response.id, + response=tool_response.response, + level=level, + level_question_nr=question_nr, + ), + ) + sub_question_retrieval_stats = calculate_sub_question_retrieval_stats( + verified_documents=state["verified_documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + ) + + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = AgentChunkStats() + # else: + # sub_question_retrieval_stats = [sub_question_retrieval_stats] + + return ExpandedRetrievalUpdate( + expanded_retrieval_result=ExpandedRetrievalResult( + expanded_queries_results=state["expanded_retrieval_results"], + all_documents=state["reranked_documents"], + sub_question_retrieval_stats=sub_question_retrieval_stats, + ), + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py new file mode 100644 index 000000000000..744242b726dc --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/verification_kickoff.py @@ -0,0 +1,39 @@ +from typing import cast +from typing import Literal + +from langchain_core.runnables.config import RunnableConfig +from langgraph.types import Command +from langgraph.types import Send + +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + DocVerificationInput, +) +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agents.agent_search.models import AgentSearchConfig + + +def verification_kickoff( + state: ExpandedRetrievalState, + config: RunnableConfig, +) -> Command[Literal["doc_verification"]]: + documents = state["retrieved_documents"] + 
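# Fan out one doc_verification task per retrieved document via the Send commands below +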
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + verification_question = state.get("question", agent_a_config.search_request.query) + sub_question_id = state.get("sub_question_id") + return Command( + update={}, + goto=[ + Send( + node="doc_verification", + arg=DocVerificationInput( + doc_to_verify=doc, + question=verification_question, + base_search=False, + sub_question_id=sub_question_id, + ), + ) + for doc in documents + ], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/operations.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/operations.py new file mode 100644 index 000000000000..86434a669192 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/operations.py @@ -0,0 +1,97 @@ +from collections import defaultdict +from collections.abc import Callable + +import numpy as np +from langchain_core.callbacks.manager import dispatch_custom_event + +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.chat.models import SubQueryPiece +from onyx.context.search.models import InferenceSection +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]: + def helper(token: str, num: int) -> None: + dispatch_custom_event( + "subqueries", + SubQueryPiece( + sub_query=token, + level=level, + level_question_nr=question_nr, + query_id=num, + ), + ) + + return helper + + +def calculate_sub_question_retrieval_stats( + verified_documents: list[InferenceSection], + expanded_retrieval_results: list[QueryResult], +) -> AgentChunkStats: + chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict( + lambda: defaultdict(list) + ) + + for expanded_retrieval_result in expanded_retrieval_results: + for doc in expanded_retrieval_result.search_results: + doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" + if doc.center_chunk.score is not None: + chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) + + verified_doc_chunk_ids = [ + f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}" + for verified_document in verified_documents + ] + dismissed_doc_chunk_ids = [] + + raw_chunk_stats_counts: dict[str, int] = defaultdict(int) + raw_chunk_stats_scores: dict[str, float] = defaultdict(float) + for doc_chunk_id, chunk_data in chunk_scores.items(): + if doc_chunk_id in verified_doc_chunk_ids: + raw_chunk_stats_counts["verified_count"] += 1 + + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + raw_chunk_stats_scores["verified_scores"] += float( + np.mean(valid_chunk_scores) + ) + else: + raw_chunk_stats_counts["rejected_count"] += 1 + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + raw_chunk_stats_scores["rejected_scores"] += float( + np.mean(valid_chunk_scores) + ) + dismissed_doc_chunk_ids.append(doc_chunk_id) + + if raw_chunk_stats_counts["verified_count"] == 0: + verified_avg_scores = 0.0 + else: + verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float( + raw_chunk_stats_counts["verified_count"] + ) + + rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None) + if rejected_scores is not None: + rejected_avg_scores = rejected_scores / float( + raw_chunk_stats_counts["rejected_count"] + ) + else: + 
rejected_avg_scores = None + + chunk_stats = AgentChunkStats( + verified_count=raw_chunk_stats_counts["verified_count"], + verified_avg_scores=verified_avg_scores, + rejected_count=raw_chunk_stats_counts["rejected_count"], + rejected_avg_scores=rejected_avg_scores, + verified_doc_chunk_ids=verified_doc_chunk_ids, + dismissed_doc_chunk_ids=dismissed_doc_chunk_ids, + ) + + return chunk_stats diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py new file mode 100644 index 000000000000..2572b83906ba --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/states.py @@ -0,0 +1,84 @@ +from operator import add +from typing import Annotated +from typing import TypedDict + +from onyx.agents.agent_search.core_state import SubgraphCoreState +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) +from onyx.context.search.models import InferenceSection + + +### States ### + +## Graph Input State + + +class ExpandedRetrievalInput(SubgraphCoreState): + question: str + base_search: bool + sub_question_id: str | None + + +## Update/Return States + + +class QueryExpansionUpdate(TypedDict): + expanded_queries: list[str] + + +class DocVerificationUpdate(TypedDict): + verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class DocRetrievalUpdate(TypedDict): + expanded_retrieval_results: Annotated[list[QueryResult], add] + retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class DocRerankingUpdate(TypedDict): + reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: RetrievalFitStats | None + + +class ExpandedRetrievalUpdate(TypedDict): + expanded_retrieval_result: ExpandedRetrievalResult + + +## Graph Output State + + +class ExpandedRetrievalOutput(TypedDict): + expanded_retrieval_result: ExpandedRetrievalResult + base_expanded_retrieval_result: ExpandedRetrievalResult + + +## Graph State + + +class ExpandedRetrievalState( + # This includes the core state + ExpandedRetrievalInput, + QueryExpansionUpdate, + DocRetrievalUpdate, + DocVerificationUpdate, + DocRerankingUpdate, + ExpandedRetrievalOutput, +): + pass + + +## Conditional Input States + + +class DocVerificationInput(ExpandedRetrievalInput): + doc_to_verify: InferenceSection + + +class RetrievalInput(ExpandedRetrievalInput): + query_to_retrieve: str diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/edges.py b/backend/onyx/agents/agent_search/deep_search_a/main/edges.py new file mode 100644 index 000000000000..ed7fca7f0dbd --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/edges.py @@ -0,0 +1,91 @@ +from collections.abc import Hashable +from typing import Literal + +from langgraph.types import Send + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from 
onyx.agents.agent_search.deep_search_a.main.states import ( + RequireRefinedAnswerUpdate, +) +from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def parallelize_initial_sub_question_answering( + state: MainState, +) -> list[Send | Hashable]: + if len(state["initial_decomp_questions"]) > 0: + # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] + # if len(state["sub_question_records"]) == 0: + # if state["config"].use_persistence: + # raise ValueError("No sub-questions found for initial decompozed questions") + # else: + # # in this case, we are doing retrieval on the original question. + # # to make all the logic consistent, we create a new sub-question + # # with the same content as the original question + # sub_question_record_ids = [1] * len(state["initial_decomp_questions"]) + + return [ + Send( + "answer_query_subgraph", + AnswerQuestionInput( + question=question, + question_id=make_question_id(0, question_nr + 1), + ), + ) + for question_nr, question in enumerate(state["initial_decomp_questions"]) + ] + + else: + return [ + Send( + "ingest_answers", + AnswerQuestionOutput( + answer_results=[], + ), + ) + ] + + +# Define the function that determines whether to continue or not +def continue_to_refined_answer_or_end( + state: RequireRefinedAnswerUpdate, +) -> Literal["refined_sub_question_creation", "logging_node"]: + if state["require_refined_answer"]: + return "refined_sub_question_creation" + else: + return "logging_node" + + +def parallelize_refined_sub_question_answering( + state: MainState, +) -> list[Send | Hashable]: + if len(state["refined_sub_questions"]) > 0: + return [ + Send( + "answer_refined_question", + AnswerQuestionInput( + question=question_data.sub_question, + question_id=make_question_id(1, question_nr), + ), + ) + for question_nr, question_data in state["refined_sub_questions"].items() + ] + + else: + return [ + Send( + "ingest_refined_sub_answers", + AnswerQuestionOutput( + answer_results=[], + ), + ) + ] diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py new file mode 100644 index 000000000000..d99c8d33d6c6 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py @@ -0,0 +1,339 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import ( + answer_query_graph_builder, +) +from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.graph_builder import ( + answer_refined_query_graph_builder, +) +from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import ( + base_raw_search_graph_builder, +) +from onyx.agents.agent_search.deep_search_a.main.edges import ( + continue_to_refined_answer_or_end, +) +from onyx.agents.agent_search.deep_search_a.main.edges import ( + parallelize_initial_sub_question_answering, +) +from onyx.agents.agent_search.deep_search_a.main.edges import ( + parallelize_refined_sub_question_answering, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import ( + agent_logging, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_decision import ( + agent_path_decision, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_routing import ( + 
agent_path_routing, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import ( + agent_search_start, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import ( + direct_llm_handling, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import ( + entity_term_extraction_llm, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.generate_initial_answer import ( + generate_initial_answer, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import ( + generate_refined_answer, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_base_retrieval import ( + ingest_initial_base_retrieval, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_sub_question_answers import ( + ingest_initial_sub_question_answers, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import ( + ingest_refined_answers, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.initial_answer_quality_check import ( + initial_answer_quality_check, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.initial_sub_question_creation import ( + initial_sub_question_creation, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import ( + refined_answer_decision, +) +from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import ( + refined_sub_question_creation, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainInput +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +test_mode = False + + +def main_graph_builder(test_mode: bool = False) -> StateGraph: + graph = StateGraph( + state_schema=MainState, + input=MainInput, + ) + + graph.add_node( + node="agent_path_decision", + action=agent_path_decision, + ) + + graph.add_node( + node="agent_path_routing", + action=agent_path_routing, + ) + + graph.add_node( + node="LLM", + action=direct_llm_handling, + ) + + graph.add_node( + node="agent_search_start", + action=agent_search_start, + ) + + graph.add_node( + node="initial_sub_question_creation", + action=initial_sub_question_creation, + ) + answer_query_subgraph = answer_query_graph_builder().compile() + graph.add_node( + node="answer_query_subgraph", + action=answer_query_subgraph, + ) + + base_raw_search_subgraph = base_raw_search_graph_builder().compile() + graph.add_node( + node="base_raw_search_subgraph", + action=base_raw_search_subgraph, + ) + + # refined_answer_subgraph = refined_answers_graph_builder().compile() + # graph.add_node( + # node="refined_answer_subgraph", + # action=refined_answer_subgraph, + # ) + + graph.add_node( + node="refined_sub_question_creation", + action=refined_sub_question_creation, + ) + + answer_refined_question = answer_refined_query_graph_builder().compile() + graph.add_node( + node="answer_refined_question", + action=answer_refined_question, + ) + + graph.add_node( + node="ingest_refined_answers", + action=ingest_refined_answers, + ) + + graph.add_node( + node="generate_refined_answer", + action=generate_refined_answer, + ) + + # graph.add_node( + # node="check_refined_answer", + # action=check_refined_answer, + # ) + + graph.add_node( + node="ingest_initial_retrieval", + action=ingest_initial_base_retrieval, + ) + graph.add_node( + 
node="ingest_initial_sub_question_answers", + action=ingest_initial_sub_question_answers, + ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) + + graph.add_node( + node="initial_answer_quality_check", + action=initial_answer_quality_check, + ) + + graph.add_node( + node="entity_term_extraction_llm", + action=entity_term_extraction_llm, + ) + graph.add_node( + node="refined_answer_decision", + action=refined_answer_decision, + ) + + graph.add_node( + node="logging_node", + action=agent_logging, + ) + # if test_mode: + # graph.add_node( + # node="generate_initial_base_answer", + # action=generate_initial_base_answer, + # ) + + ### Add edges ### + + # raph.add_edge(start_key=START, end_key="base_raw_search_subgraph") + + graph.add_edge( + start_key=START, + end_key="agent_path_decision", + ) + + graph.add_edge( + start_key="agent_path_decision", + end_key="agent_path_routing", + ) + + graph.add_edge( + start_key="agent_search_start", + end_key="base_raw_search_subgraph", + ) + + graph.add_edge( + start_key="agent_search_start", + end_key="initial_sub_question_creation", + ) + + graph.add_edge( + start_key="base_raw_search_subgraph", + end_key="ingest_initial_retrieval", + ) + + graph.add_edge( + start_key="LLM", + end_key=END, + ) + + # graph.add_edge( + # start_key=START, + # end_key="initial_sub_question_creation", + # ) + + graph.add_conditional_edges( + source="initial_sub_question_creation", + path=parallelize_initial_sub_question_answering, + path_map=["answer_query_subgraph"], + ) + graph.add_edge( + start_key="answer_query_subgraph", + end_key="ingest_initial_sub_question_answers", + ) + + graph.add_edge( + start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"], + end_key="generate_initial_answer", + ) + + graph.add_edge( + start_key="generate_initial_answer", + end_key="entity_term_extraction_llm", + ) + + graph.add_edge( + start_key="generate_initial_answer", + end_key="initial_answer_quality_check", + ) + + graph.add_edge( + start_key=["initial_answer_quality_check", "entity_term_extraction_llm"], + end_key="refined_answer_decision", + ) + + graph.add_conditional_edges( + source="refined_answer_decision", + path=continue_to_refined_answer_or_end, + path_map=["refined_sub_question_creation", "logging_node"], + ) + + graph.add_conditional_edges( + source="refined_sub_question_creation", # DONE + path=parallelize_refined_sub_question_answering, + path_map=["answer_refined_question"], + ) + graph.add_edge( + start_key="answer_refined_question", # HERE + end_key="ingest_refined_answers", + ) + + graph.add_edge( + start_key="ingest_refined_answers", + end_key="generate_refined_answer", + ) + + # graph.add_conditional_edges( + # source="refined_answer_decision", + # path=continue_to_refined_answer_or_end, + # path_map=["refined_answer_subgraph", END], + # ) + + # graph.add_edge( + # start_key="refined_answer_subgraph", + # end_key="generate_refined_answer", + # ) + + graph.add_edge( + start_key="generate_refined_answer", + end_key="logging_node", + ) + + graph.add_edge( + start_key="logging_node", + end_key=END, + ) + + # graph.add_edge( + # start_key="generate_refined_answer", + # end_key="check_refined_answer", + # ) + + # graph.add_edge( + # start_key="check_refined_answer", + # end_key=END, + # ) + + return graph + + +if __name__ == "__main__": + pass + + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph 
= main_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + + with get_session_context_manager() as db_session: + search_request = SearchRequest(query="Who created Excel?") + agent_a_config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) + + inputs = MainInput() + + for thing in compiled_graph.stream( + input=inputs, + config={"configurable": {"config": agent_a_config}}, + # stream_mode="debug", + # debug=True, + subgraphs=True, + ): + logger.debug(thing) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/models.py b/backend/onyx/agents/agent_search/deep_search_a/main/models.py new file mode 100644 index 000000000000..2bb487cb89d2 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/models.py @@ -0,0 +1,36 @@ +from pydantic import BaseModel + + +class FollowUpSubQuestion(BaseModel): + sub_question: str + sub_question_id: str + verified: bool + answered: bool + answer: str + + +class AgentTimings(BaseModel): + base_duration__s: float | None + refined_duration__s: float | None + full_duration__s: float | None + + +class AgentBaseMetrics(BaseModel): + num_verified_documents_total: int | None + num_verified_documents_core: int | None + verified_avg_score_core: float | None + num_verified_documents_base: int | float | None + verified_avg_score_base: float | None + base_doc_boost_factor: float | None + support_boost_factor: float | None + duration__s: float | None + + +class AgentRefinedMetrics(BaseModel): + refined_doc_boost_factor: float | None + refined_question_boost_factor: float | None + duration__s: float | None + + +class AgentAdditionalMetrics(BaseModel): + pass diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py new file mode 100644 index 000000000000..d17f4f45bb35 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py @@ -0,0 +1,109 @@ +from datetime import datetime +from typing import cast + +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics +from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import MainOutput +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics +from onyx.db.chat import log_agent_metrics +from onyx.db.chat import log_agent_sub_question_results + + +def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------LOGGING NODE---") + + agent_start_time = state["agent_start_time"] + agent_base_end_time = state["agent_base_end_time"] + agent_refined_start_time = state["agent_refined_start_time"] or None + agent_refined_end_time = state["agent_refined_end_time"] or None + agent_end_time = agent_refined_end_time or agent_base_end_time + + agent_base_duration = None + if agent_base_end_time: + agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds() + + agent_refined_duration = None + if agent_refined_start_time and agent_refined_end_time: + agent_refined_duration = ( + agent_refined_end_time 
- agent_refined_start_time + ).total_seconds() + + agent_full_duration = None + if agent_end_time: + agent_full_duration = (agent_end_time - agent_start_time).total_seconds() + + agent_type = "refined" if agent_refined_duration else "base" + + agent_base_metrics = state["agent_base_metrics"] + agent_refined_metrics = state["agent_refined_metrics"] + + combined_agent_metrics = CombinedAgentMetrics( + timings=AgentTimings( + base_duration__s=agent_base_duration, + refined_duration__s=agent_refined_duration, + full_duration__s=agent_full_duration, + ), + base_metrics=agent_base_metrics, + refined_metrics=agent_refined_metrics, + additional_metrics=AgentAdditionalMetrics(), + ) + + persona_id = None + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + if agent_a_config.search_request.persona: + persona_id = agent_a_config.search_request.persona.id + + user_id = None + user = agent_a_config.search_tool.user + if user: + user_id = user.id + + # log the agent metrics + if agent_a_config.db_session is not None: + log_agent_metrics( + db_session=agent_a_config.db_session, + user_id=user_id, + persona_id=persona_id, + agent_type=agent_type, + start_time=agent_start_time, + agent_metrics=combined_agent_metrics, + ) + + if agent_a_config.use_persistence: + # Persist the sub-answer in the database + db_session = agent_a_config.db_session + chat_session_id = agent_a_config.chat_session_id + primary_message_id = agent_a_config.message_id + sub_question_answer_results = state["decomp_answer_results"] + + log_agent_sub_question_results( + db_session=db_session, + chat_session_id=chat_session_id, + primary_message_id=primary_message_id, + sub_question_answer_results=sub_question_answer_results, + ) + + # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None: + # create_sub_answer( + # db_session=db_session, + # chat_session_id=chat_session_id, + # primary_message_id=primary_message_id, + # sub_question_id=sub_question_id, + # answer=answer_str, + # # ) + # pass + + main_output = MainOutput() + + now_end = datetime.now() + + logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---") + + return main_output diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py new file mode 100644 index 000000000000..3019d209d7ca --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_decision.py @@ -0,0 +1,87 @@ +from datetime import datetime +from typing import cast + +from langchain_core.messages import HumanMessage +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.prompts import AGENT_DECISION_PROMPT +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + AGENT_DECISION_PROMPT_AFTER_SEARCH, +) +from onyx.context.search.models import InferenceSection +from onyx.db.engine import get_session_context_manager +from onyx.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary + + +def agent_path_decision(state: 
MainState, config: RunnableConfig) -> RoutingDecision: + now_start = datetime.now() + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = agent_a_config.search_request.query + perform_initial_search_path_decision = ( + agent_a_config.perform_initial_search_path_decision + ) + + logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---") + + if perform_initial_search_path_decision: + search_tool = agent_a_config.search_tool + retrieved_docs: list[InferenceSection] = [] + + # new db session to avoid concurrency issues + with get_session_context_manager() as db_session: + for tool_response in search_tool.run( + query=question, + force_no_rerank=True, + alternate_db_session=db_session, + ): + # get retrieved docs to send to the rest of the graph + if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: + response = cast(SearchResponseSummary, tool_response.response) + retrieved_docs = response.top_sections + break + + sample_doc_str = "\n\n".join( + [doc.combined_content for _, doc in enumerate(retrieved_docs[:3])] + ) + + agent_decision_prompt = AGENT_DECISION_PROMPT_AFTER_SEARCH.format( + question=question, sample_doc_str=sample_doc_str + ) + + else: + sample_doc_str = "" + agent_decision_prompt = AGENT_DECISION_PROMPT.format(question=question) + + msg = [HumanMessage(content=agent_decision_prompt)] + + # Get the rewritten queries in a defined format + model = agent_a_config.fast_llm + + # no need to stream this + resp = model.invoke(msg) + + if isinstance(resp.content, str) and "research" in resp.content.lower(): + routing = "agent_search" + else: + routing = "LLM" + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---" + ) + + return RoutingDecision( + # Decide which route to take + routing=routing, + sample_doc_str=sample_doc_str, + log_messages=[f"Path decision: {routing}, Time taken: {now_end - now_start}"], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py new file mode 100644 index 000000000000..74c9d78f9ccf --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_path_routing.py @@ -0,0 +1,23 @@ +from typing import Literal + +from langgraph.types import Command + +from onyx.agents.agent_search.deep_search_a.main.states import MainState + + +def agent_path_routing( + state: MainState, +) -> Command[Literal["agent_search_start", "LLM"]]: + routing = state.get("routing", "agent_search") + + if routing == "agent_search": + agent_path = "agent_search_start" + else: + agent_path = "LLM" + + return Command( + # state update + update={"log_messages": [f"Path routing: {agent_path}"]}, + # control flow + goto=agent_path, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py new file mode 100644 index 000000000000..3ba00c3f02dd --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py @@ -0,0 +1,7 @@ +from onyx.agents.agent_search.core_state import CoreState + + +def agent_search_start(state: CoreState) -> CoreState: + return CoreState( + log_messages=["Agent search start"], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py 
b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py new file mode 100644 index 000000000000..712978f55a57 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py @@ -0,0 +1,87 @@ +from datetime import datetime +from typing import Any +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_content +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_DEFAULT, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_PERSONA, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT +from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.chat.models import AgentAnswerPiece + + +def direct_llm_handling( + state: MainState, config: RunnableConfig +) -> InitialAnswerUpdate: + now_start = datetime.now() + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = agent_a_config.search_request.query + persona_prompt = get_persona_prompt(agent_a_config.search_request.persona) + + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) + + logger.debug(f"--------{now_start}--------LLM HANDLING START---") + + model = agent_a_config.fast_llm + + msg = [ + HumanMessage( + content=DIRECT_LLM_PROMPT.format( + persona_specification=persona_specification, question=question + ) + ) + ] + + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + + for message in model.stream(msg): + # TODO: in principle, the answer here COULD contain images, but we don't support that yet + content = message.content + if not isinstance(content, str): + raise ValueError( + f"Expected content to be a string, but got {type(content)}" + ) + dispatch_custom_event( + "initial_agent_answer", + AgentAnswerPiece( + answer_piece=content, + level=0, + level_question_nr=0, + answer_type="agent_level_answer", + ), + ) + streamed_tokens.append(content) + + response = merge_content(*streamed_tokens) + answer = cast(str, response) + + now_end = datetime.now() + + logger.debug(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---") + + return InitialAnswerUpdate( + initial_answer=answer, + initial_agent_stats=None, + generated_sub_questions=[], + agent_base_end_time=now_end, + agent_base_metrics=None, + log_messages=[f"LLM handling: {now_end - now_start}"], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py new file mode 100644 index 000000000000..381b9e48fed0 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/entity_term_extraction_llm.py @@ -0,0 +1,131 @@ +import json +import re +from datetime import datetime +from typing import cast + +from langchain_core.messages import HumanMessage +from 
langchain_core.messages import merge_message_runs +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import ( + EntityTermExtractionUpdate, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( + trim_prompt_piece, +) +from onyx.agents.agent_search.shared_graph_utils.models import Entity +from onyx.agents.agent_search.shared_graph_utils.models import ( + EntityRelationshipTermExtraction, +) +from onyx.agents.agent_search.shared_graph_utils.models import Relationship +from onyx.agents.agent_search.shared_graph_utils.models import Term +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT +from onyx.agents.agent_search.shared_graph_utils.utils import format_docs + + +def entity_term_extraction_llm( + state: MainState, config: RunnableConfig +) -> EntityTermExtractionUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------GENERATE ENTITIES & TERMS---") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + if not agent_a_config.allow_refinement: + return EntityTermExtractionUpdate( + entity_retlation_term_extractions=EntityRelationshipTermExtraction( + entities=[], + relationships=[], + terms=[], + ) + ) + + # first four lines duplicates from generate_initial_answer + question = agent_a_config.search_request.query + sub_question_docs = state["documents"] + all_original_question_documents = state["all_original_question_documents"] + relevant_docs = dedup_inference_sections( + sub_question_docs, all_original_question_documents + ) + + # start with the entity/term/extraction + + doc_context = format_docs(relevant_docs) + + doc_context = trim_prompt_piece( + agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question + ) + msg = [ + HumanMessage( + content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), + ) + ] + fast_llm = agent_a_config.fast_llm + # Grader + llm_response_list = list( + fast_llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + cleaned_response = re.sub(r"```json\n|\n```", "", llm_response) + parsed_response = json.loads(cleaned_response) + + entities = [] + relationships = [] + terms = [] + for entity in parsed_response.get("retrieved_entities_relationships", {}).get( + "entities", {} + ): + entity_name = entity.get("entity_name", "") + entity_type = entity.get("entity_type", "") + entities.append(Entity(entity_name=entity_name, entity_type=entity_type)) + + for relationship in parsed_response.get("retrieved_entities_relationships", {}).get( + "relationships", {} + ): + relationship_name = relationship.get("relationship_name", "") + relationship_type = relationship.get("relationship_type", "") + relationship_entities = relationship.get("relationship_entities", []) + relationships.append( + Relationship( + relationship_name=relationship_name, + relationship_type=relationship_type, + relationship_entities=relationship_entities, + ) + ) + + for term in parsed_response.get("retrieved_entities_relationships", {}).get( + "terms", {} + ): + term_name = term.get("term_name", "") + term_type = term.get("term_type", 
"") + term_similar_to = term.get("term_similar_to", []) + terms.append( + Term( + term_name=term_name, + term_type=term_type, + term_similar_to=term_similar_to, + ) + ) + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------ENTITY TERM EXTRACTION END---" + ) + + return EntityTermExtractionUpdate( + entity_retlation_term_extractions=EntityRelationshipTermExtraction( + entities=entities, + relationships=relationships, + terms=terms, + ) + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py new file mode 100644 index 000000000000..0075b202a8e0 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py @@ -0,0 +1,252 @@ +from datetime import datetime +from typing import Any +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_content +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics +from onyx.agents.agent_search.deep_search_a.main.operations import ( + calculate_initial_agent_stats, +) +from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.operations import ( + remove_document_citations, +) +from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( + trim_prompt_piece, +) +from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_DEFAULT, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_PERSONA, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + SUB_QUESTION_ANSWER_TEMPLATE, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER +from onyx.agents.agent_search.shared_graph_utils.utils import format_docs +from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id +from onyx.chat.models import AgentAnswerPiece +from onyx.chat.models import ExtendedToolResponse +from onyx.tools.tool_implementations.search.search_tool import yield_search_responses + + +def generate_initial_answer( + state: MainState, config: RunnableConfig +) -> InitialAnswerUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------GENERATE INITIAL---") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = agent_a_config.search_request.query + persona_prompt = get_persona_prompt(agent_a_config.search_request.persona) + sub_question_docs = 
state["documents"] + all_original_question_documents = state["all_original_question_documents"] + + relevant_docs = dedup_inference_sections( + sub_question_docs, all_original_question_documents + ) + decomp_questions = [] + + if len(relevant_docs) == 0: + dispatch_custom_event( + "initial_agent_answer", + AgentAnswerPiece( + answer_piece=UNKNOWN_ANSWER, + level=0, + level_question_nr=0, + answer_type="agent_level_answer", + ), + ) + + answer = UNKNOWN_ANSWER + initial_agent_stats = InitialAgentResultStats( + sub_questions={}, + original_question={}, + agent_effectiveness={}, + ) + + else: + # Use the query info from the base document retrieval + query_info = get_query_info(state["original_question_retrieval_results"]) + + for tool_response in yield_search_responses( + query=question, + reranked_sections=relevant_docs, + final_context_sections=relevant_docs, + search_query_info=query_info, + get_section_relevance=lambda: None, # TODO: add relevance + search_tool=agent_a_config.search_tool, + ): + dispatch_custom_event( + "tool_response", + ExtendedToolResponse( + id=tool_response.id, + response=tool_response.response, + level=0, + level_question_nr=0, # 0, 0 is the base question + ), + ) + + net_new_original_question_docs = [] + for all_original_question_doc in all_original_question_documents: + if all_original_question_doc not in sub_question_docs: + net_new_original_question_docs.append(all_original_question_doc) + + decomp_answer_results = state["decomp_answer_results"] + + good_qa_list: list[str] = [] + + sub_question_nr = 1 + + for decomp_answer_result in decomp_answer_results: + decomp_questions.append(decomp_answer_result.question) + _, question_nr = parse_question_id(decomp_answer_result.question_id) + if ( + decomp_answer_result.quality.lower().startswith("yes") + and len(decomp_answer_result.answer) > 0 + and decomp_answer_result.answer != UNKNOWN_ANSWER + ): + good_qa_list.append( + SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=decomp_answer_result.question, + sub_answer=decomp_answer_result.answer, + sub_question_nr=sub_question_nr, + ) + ) + sub_question_nr += 1 + + if len(good_qa_list) > 0: + sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) + else: + sub_question_answer_str = "" + + # Determine which persona-specification prompt to use + + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) + + # Determine which base prompt to use given the sub-question information + if len(good_qa_list) > 0: + base_prompt = INITIAL_RAG_PROMPT + else: + base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS + + model = agent_a_config.fast_llm + + doc_context = format_docs(relevant_docs) + doc_context = trim_prompt_piece( + model.config, + doc_context, + base_prompt + sub_question_answer_str + persona_specification, + ) + + msg = [ + HumanMessage( + content=base_prompt.format( + question=question, + answered_sub_questions=remove_document_citations( + sub_question_answer_str + ), + relevant_docs=format_docs(relevant_docs), + persona_specification=persona_specification, + ) + ) + ] + + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + for message in model.stream(msg): + # TODO: in principle, the answer here COULD contain images, but we don't support that yet + content = message.content + if not isinstance(content, str): + raise ValueError( + f"Expected content to be a string, but got {type(content)}" + ) + 
dispatch_custom_event( + "initial_agent_answer", + AgentAnswerPiece( + answer_piece=content, + level=0, + level_question_nr=0, + answer_type="agent_level_answer", + ), + ) + streamed_tokens.append(content) + + response = merge_content(*streamed_tokens) + answer = cast(str, response) + + initial_agent_stats = calculate_initial_agent_stats( + state["decomp_answer_results"], state["original_question_retrieval_stats"] + ) + + logger.debug( + f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n" + ) + + if initial_agent_stats: + logger.debug(initial_agent_stats.original_question) + logger.debug(initial_agent_stats.sub_questions) + logger.debug(initial_agent_stats.agent_effectiveness) + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n" + ) + + agent_base_end_time = datetime.now() + + agent_base_metrics = AgentBaseMetrics( + num_verified_documents_total=len(relevant_docs), + num_verified_documents_core=state[ + "original_question_retrieval_stats" + ].verified_count, + verified_avg_score_core=state[ + "original_question_retrieval_stats" + ].verified_avg_scores, + num_verified_documents_base=initial_agent_stats.sub_questions.get( + "num_verified_documents", None + ), + verified_avg_score_base=initial_agent_stats.sub_questions.get( + "verified_avg_score", None + ), + base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get( + "utilized_chunk_ratio", None + ), + support_boost_factor=initial_agent_stats.agent_effectiveness.get( + "support_ratio", None + ), + duration__s=(agent_base_end_time - state["agent_start_time"]).total_seconds(), + ) + + return InitialAnswerUpdate( + initial_answer=answer, + initial_agent_stats=initial_agent_stats, + generated_sub_questions=decomp_questions, + agent_base_end_time=agent_base_end_time, + agent_base_metrics=agent_base_metrics, + log_messages=[f"Initial answer generation: {now_end - now_start}"], + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py new file mode 100644 index 000000000000..c1afbef2f194 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_base_search_only_answer.py @@ -0,0 +1,56 @@ +from datetime import datetime +from typing import cast + +from langchain_core.messages import HumanMessage +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerBASEUpdate +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( + trim_prompt_piece, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT +from onyx.agents.agent_search.shared_graph_utils.utils import format_docs + + +def generate_initial_base_search_only_answer( + state: MainState, + config: RunnableConfig, +) -> InitialAnswerBASEUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = agent_a_config.search_request.query + original_question_docs = state["all_original_question_documents"] + + model = 
agent_a_config.fast_llm + + doc_context = format_docs(original_question_docs) + doc_context = trim_prompt_piece( + model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question + ) + + msg = [ + HumanMessage( + content=INITIAL_RAG_BASE_PROMPT.format( + question=question, + context=doc_context, + ) + ) + ] + + # Grader + response = model.invoke(msg) + answer = response.pretty_repr() + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n" + ) + + return InitialAnswerBASEUpdate(initial_base_answer=answer) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py new file mode 100644 index 000000000000..c0d850750203 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py @@ -0,0 +1,317 @@ +from datetime import datetime +from typing import Any +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_content +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics +from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.operations import ( + remove_document_citations, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.deep_search_a.main.states import RefinedAnswerUpdate +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( + trim_prompt_piece, +) +from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_DEFAULT, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + ASSISTANT_SYSTEM_PROMPT_PERSONA, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + REVISED_RAG_PROMPT_NO_SUB_QUESTIONS, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + SUB_QUESTION_ANSWER_TEMPLATE, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER +from onyx.agents.agent_search.shared_graph_utils.utils import format_docs +from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id +from onyx.chat.models import AgentAnswerPiece +from onyx.chat.models import ExtendedToolResponse +from onyx.tools.tool_implementations.search.search_tool import yield_search_responses + + +def generate_refined_answer( + state: MainState, config: RunnableConfig +) -> RefinedAnswerUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------GENERATE REFINED ANSWER---") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = agent_a_config.search_request.query + persona_prompt = get_persona_prompt(agent_a_config.search_request.persona) + + initial_documents = 
state["documents"] + revised_documents = state["refined_documents"] + + combined_documents = dedup_inference_sections(initial_documents, revised_documents) + + query_info = get_query_info(state["original_question_retrieval_results"]) + # stream refined answer docs + for tool_response in yield_search_responses( + query=question, + reranked_sections=combined_documents, + final_context_sections=combined_documents, + search_query_info=query_info, + get_section_relevance=lambda: None, # TODO: add relevance + search_tool=agent_a_config.search_tool, + ): + dispatch_custom_event( + "tool_response", + ExtendedToolResponse( + id=tool_response.id, + response=tool_response.response, + level=1, + level_question_nr=0, # 0, 0 is the base question + ), + ) + + if len(initial_documents) > 0: + revision_doc_effectiveness = len(combined_documents) / len(initial_documents) + elif len(revised_documents) == 0: + revision_doc_effectiveness = 0.0 + else: + revision_doc_effectiveness = 10.0 + + decomp_answer_results = state["decomp_answer_results"] + # revised_answer_results = state["refined_decomp_answer_results"] + + good_qa_list: list[str] = [] + decomp_questions = [] + + initial_good_sub_questions: list[str] = [] + new_revised_good_sub_questions: list[str] = [] + + sub_question_nr = 1 + + for decomp_answer_result in decomp_answer_results: + question_level, question_nr = parse_question_id( + decomp_answer_result.question_id + ) + + decomp_questions.append(decomp_answer_result.question) + if ( + decomp_answer_result.quality.lower().startswith("yes") + and len(decomp_answer_result.answer) > 0 + and decomp_answer_result.answer != UNKNOWN_ANSWER + ): + good_qa_list.append( + SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=decomp_answer_result.question, + sub_answer=decomp_answer_result.answer, + sub_question_nr=sub_question_nr, + ) + ) + if question_level == 0: + initial_good_sub_questions.append(decomp_answer_result.question) + else: + new_revised_good_sub_questions.append(decomp_answer_result.question) + + sub_question_nr += 1 + + initial_good_sub_questions = list(set(initial_good_sub_questions)) + new_revised_good_sub_questions = list(set(new_revised_good_sub_questions)) + total_good_sub_questions = list( + set(initial_good_sub_questions + new_revised_good_sub_questions) + ) + if len(initial_good_sub_questions) > 0: + revision_question_efficiency: float = len(total_good_sub_questions) / len( + initial_good_sub_questions + ) + elif len(new_revised_good_sub_questions) > 0: + revision_question_efficiency = 10.0 + else: + revision_question_efficiency = 1.0 + + sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list))) + + # original answer + + initial_answer = state["initial_answer"] + + # Determine which persona-specification prompt to use + + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) + + # Determine which base prompt to use given the sub-question information + if len(good_qa_list) > 0: + base_prompt = REVISED_RAG_PROMPT + else: + base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS + + model = agent_a_config.fast_llm + relevant_docs = format_docs(combined_documents) + relevant_docs = trim_prompt_piece( + model.config, + relevant_docs, + base_prompt + + question + + sub_question_answer_str + + relevant_docs + + initial_answer + + persona_specification, + ) + + msg = [ + HumanMessage( + content=base_prompt.format( + question=question, + 
answered_sub_questions=remove_document_citations( + sub_question_answer_str + ), + relevant_docs=relevant_docs, + initial_answer=remove_document_citations(initial_answer), + persona_specification=persona_specification, + ) + ) + ] + + # Grader + + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + for message in model.stream(msg): + # TODO: in principle, the answer here COULD contain images, but we don't support that yet + content = message.content + if not isinstance(content, str): + raise ValueError( + f"Expected content to be a string, but got {type(content)}" + ) + dispatch_custom_event( + "refined_agent_answer", + AgentAnswerPiece( + answer_piece=content, + level=1, + level_question_nr=0, + answer_type="agent_level_answer", + ), + ) + streamed_tokens.append(content) + + response = merge_content(*streamed_tokens) + answer = cast(str, response) + + # refined_agent_stats = _calculate_refined_agent_stats( + # state["decomp_answer_results"], state["original_question_retrieval_stats"] + # ) + + initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions))) + new_revised_good_sub_questions_str = "\n".join( + list(set(new_revised_good_sub_questions)) + ) + + refined_agent_stats = RefinedAgentStats( + revision_doc_efficiency=revision_doc_effectiveness, + revision_question_efficiency=revision_question_efficiency, + ) + + logger.debug( + f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}" + ) + logger.debug("-" * 10) + logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + + logger.debug("-" * 100) + logger.debug(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n") + logger.debug("-" * 10) + logger.debug( + f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n" + ) + + logger.debug("-" * 100) + + logger.debug( + f"\n\nINITAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStas:\n\n" + ) + + logger.debug("-" * 100) + + if state["initial_agent_stats"]: + initial_doc_boost_factor = state["initial_agent_stats"].agent_effectiveness.get( + "utilized_chunk_ratio", "--" + ) + initial_support_boost_factor = state[ + "initial_agent_stats" + ].agent_effectiveness.get("support_ratio", "--") + num_initial_verified_docs = state["initial_agent_stats"].original_question.get( + "num_verified_documents", "--" + ) + initial_verified_docs_avg_score = state[ + "initial_agent_stats" + ].original_question.get("verified_avg_score", "--") + initial_sub_questions_verified_docs = state[ + "initial_agent_stats" + ].sub_questions.get("num_verified_documents", "--") + + logger.debug("INITIAL AGENT STATS") + logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}") + logger.debug(f"Support Boost Factor: {initial_support_boost_factor}") + logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}") + logger.debug( + f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}" + ) + logger.debug( + f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}" + ) + if refined_agent_stats: + logger.debug("-" * 10) + logger.debug("REFINED AGENT STATS") + logger.debug( + f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}" + ) + logger.debug( + f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}" + ) + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n" + ) + + agent_refined_end_time = datetime.now() + if 
state["agent_refined_start_time"]: + agent_refined_duration = ( + agent_refined_end_time - state["agent_refined_start_time"] + ).total_seconds() + else: + agent_refined_duration = None + + agent_refined_metrics = AgentRefinedMetrics( + refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency, + refined_question_boost_factor=refined_agent_stats.revision_question_efficiency, + duration__s=agent_refined_duration, + ) + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER UPDATE END---" + ) + + return RefinedAnswerUpdate( + refined_answer=answer, + refined_answer_quality=True, # TODO: replace this with the actual check value + refined_agent_stats=refined_agent_stats, + agent_refined_end_time=agent_refined_end_time, + agent_refined_metrics=agent_refined_metrics, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py new file mode 100644 index 000000000000..0ee387376ef5 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_base_retrieval.py @@ -0,0 +1,40 @@ +from datetime import datetime + +from onyx.agents.agent_search.deep_search_a.base_raw_search.states import ( + BaseRawSearchOutput, +) +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import ExpandedRetrievalUpdate +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats + + +def ingest_initial_base_retrieval( + state: BaseRawSearchOutput, +) -> ExpandedRetrievalUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---") + + sub_question_retrieval_stats = state[ + "base_expanded_retrieval_result" + ].sub_question_retrieval_stats + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = AgentChunkStats() + else: + sub_question_retrieval_stats = sub_question_retrieval_stats + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---" + ) + + return ExpandedRetrievalUpdate( + original_question_retrieval_results=state[ + "base_expanded_retrieval_result" + ].expanded_queries_results, + all_original_question_documents=state[ + "base_expanded_retrieval_result" + ].all_documents, + original_question_retrieval_stats=sub_question_retrieval_stats, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py new file mode 100644 index 000000000000..90a21c206983 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_initial_sub_question_answers.py @@ -0,0 +1,35 @@ +from datetime import datetime + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) + + +def ingest_initial_sub_question_answers( + state: AnswerQuestionOutput, +) -> DecompAnswersUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------INGEST ANSWERS---") + documents = [] + 
answer_results = state.get("answer_results", []) + for answer_result in answer_results: + documents.extend(answer_result.documents) + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---" + ) + + return DecompAnswersUpdate( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + documents=dedup_inference_sections(documents, []), + decomp_answer_results=answer_results, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py new file mode 100644 index 000000000000..2a3384f6d99d --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/ingest_refined_answers.py @@ -0,0 +1,36 @@ +from datetime import datetime + +from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) + + +def ingest_refined_answers( + state: AnswerQuestionOutput, +) -> DecompAnswersUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---") + + documents = [] + answer_results = state.get("answer_results", []) + for answer_result in answer_results: + documents.extend(answer_result.documents) + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---" + ) + + return DecompAnswersUpdate( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + documents=dedup_inference_sections(documents, []), + decomp_answer_results=answer_results, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_answer_quality_check.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_answer_quality_check.py new file mode 100644 index 000000000000..fed6dfd56f45 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_answer_quality_check.py @@ -0,0 +1,35 @@ +from datetime import datetime + +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import ( + InitialAnswerQualityUpdate, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainState + + +def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate: + """ + Check whether the final output satisfies the original user question + + Args: + state (messages): The current state + + Returns: + InitialAnswerQualityUpdate + """ + + now_start = datetime.now() + + logger.debug( + f"--------{now_start}--------Checking for base answer validity - for not set True/False manually" + ) + + verdict = True + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---" + ) + + return InitialAnswerQualityUpdate(initial_answer_quality=verdict) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py new file mode 100644 index 000000000000..3ddf2e4d6d14 --- /dev/null +++ 
b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/initial_sub_question_creation.py @@ -0,0 +1,146 @@ +from datetime import datetime +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_content +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics +from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + INITIAL_DECOMPOSITION_PROMPT_QUESTIONS, +) +from onyx.agents.agent_search.shared_graph_utils.prompts import ( + INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH, +) +from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated +from onyx.chat.models import StreamStopInfo +from onyx.chat.models import StreamStopReason +from onyx.chat.models import SubQuestionPiece +from onyx.context.search.models import InferenceSection +from onyx.db.engine import get_session_context_manager +from onyx.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary + + +def initial_sub_question_creation( + state: MainState, config: RunnableConfig +) -> BaseDecompUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------BASE DECOMP START---") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + question = agent_a_config.search_request.query + chat_session_id = agent_a_config.chat_session_id + primary_message_id = agent_a_config.message_id + perform_initial_search_decomposition = ( + agent_a_config.perform_initial_search_decomposition + ) + perform_initial_search_path_decision = ( + agent_a_config.perform_initial_search_path_decision + ) + + # Use the initial search results to inform the decomposition + sample_doc_str = state.get("sample_doc_str", "") + + if not chat_session_id or not primary_message_id: + raise ValueError( + "chat_session_id and message_id must be provided for agent search" + ) + agent_start_time = datetime.now() + + # Initial search to inform decomposition. 
Just get top 3 fits + + if perform_initial_search_decomposition: + if not perform_initial_search_path_decision: + search_tool = agent_a_config.search_tool + retrieved_docs: list[InferenceSection] = [] + + # new db session to avoid concurrency issues + with get_session_context_manager() as db_session: + for tool_response in search_tool.run( + query=question, + force_no_rerank=True, + alternate_db_session=db_session, + ): + # get retrieved docs to send to the rest of the graph + if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: + response = cast(SearchResponseSummary, tool_response.response) + retrieved_docs = response.top_sections + break + + sample_doc_str = "\n\n".join( + [doc.combined_content for _, doc in enumerate(retrieved_docs[:3])] + ) + + decomposition_prompt = ( + INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format( + question=question, sample_doc_str=sample_doc_str + ) + ) + + else: + decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format( + question=question + ) + + # Start decomposition + + msg = [HumanMessage(content=decomposition_prompt)] + + # Get the rewritten queries in a defined format + model = agent_a_config.fast_llm + + # Send the initial question as a subquestion with number 0 + dispatch_custom_event( + "decomp_qs", + SubQuestionPiece( + sub_question=question, + level=0, + level_question_nr=0, + ), + ) + # dispatches custom events for subquestion tokens, adding in subquestion ids. + streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0)) + + stop_event = StreamStopInfo( + stop_reason=StreamStopReason.FINISHED, + stream_type="sub_questions", + level=0, + ) + dispatch_custom_event("stream_finished", stop_event) + + deomposition_response = merge_content(*streamed_tokens) + + # this call should only return strings. 
Commenting out for efficiency + # assert [type(tok) == str for tok in streamed_tokens] + + # use no-op cast() instead of str() which runs code + # list_of_subquestions = clean_and_parse_list_string(cast(str, response)) + list_of_subqs = cast(str, deomposition_response).split("\n") + + decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""] + + now_end = datetime.now() + + logger.debug(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---") + + return BaseDecompUpdate( + initial_decomp_questions=decomp_list, + agent_start_time=agent_start_time, + agent_refined_start_time=None, + agent_refined_end_time=None, + agent_refined_metrics=AgentRefinedMetrics( + refined_doc_boost_factor=None, + refined_question_boost_factor=None, + duration__s=None, + ), + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py new file mode 100644 index 000000000000..8cad703f3003 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_answer_decision.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import cast + +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.deep_search_a.main.states import ( + RequireRefinedAnswerUpdate, +) +from onyx.agents.agent_search.models import AgentSearchConfig + + +def refined_answer_decision( + state: MainState, config: RunnableConfig +) -> RequireRefinedAnswerUpdate: + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------REFINED ANSWER DECISION---") + + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + if "?" 
in agent_a_config.search_request.query: + decision = False + else: + decision = True + + decision = True + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---" + ) + + if not agent_a_config.allow_refinement: + return RequireRefinedAnswerUpdate(require_refined_answer=decision) + + else: + return RequireRefinedAnswerUpdate(require_refined_answer=not decision) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py new file mode 100644 index 000000000000..6ec399a9e79b --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/refined_sub_question_creation.py @@ -0,0 +1,112 @@ +from datetime import datetime +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_content +from langchain_core.runnables import RunnableConfig + +from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion +from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion +from onyx.agents.agent_search.deep_search_a.main.operations import logger +from onyx.agents.agent_search.deep_search_a.main.states import ( + FollowUpSubQuestionsUpdate, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainState +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT +from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated +from onyx.agents.agent_search.shared_graph_utils.utils import ( + format_entity_term_extraction, +) +from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id +from onyx.tools.models import ToolCallKickoff + + +def refined_sub_question_creation( + state: MainState, config: RunnableConfig +) -> FollowUpSubQuestionsUpdate: + """ """ + agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) + dispatch_custom_event( + "start_refined_answer_creation", + ToolCallKickoff( + tool_name="agent_search_1", + tool_args={ + "query": agent_a_config.search_request.query, + "answer": state["initial_answer"], + }, + ), + ) + + now_start = datetime.now() + + logger.debug(f"--------{now_start}--------FOLLOW UP DECOMPOSE---") + + agent_refined_start_time = datetime.now() + + question = agent_a_config.search_request.query + base_answer = state["initial_answer"] + + # get the entity term extraction dict and properly format it + entity_retlation_term_extractions = state["entity_retlation_term_extractions"] + + entity_term_extraction_str = format_entity_term_extraction( + entity_retlation_term_extractions + ) + + initial_question_answers = state["decomp_answer_results"] + + addressed_question_list = [ + x.question for x in initial_question_answers if "yes" in x.quality.lower() + ] + + failed_question_list = [ + x.question for x in initial_question_answers if "no" in x.quality.lower() + ] + + msg = [ + HumanMessage( + content=DEEP_DECOMPOSE_PROMPT.format( + question=question, + entity_term_extraction_str=entity_term_extraction_str, + base_answer=base_answer, + answered_sub_questions="\n - ".join(addressed_question_list), + failed_sub_questions="\n - ".join(failed_question_list), + ), + ) + ] + + # Grader + model = agent_a_config.fast_llm + + streamed_tokens = 
dispatch_separated(model.stream(msg), dispatch_subquestion(1)) + response = merge_content(*streamed_tokens) + + if isinstance(response, str): + parsed_response = [q for q in response.split("\n") if q.strip() != ""] + else: + raise ValueError("LLM response is not a string") + + refined_sub_question_dict = {} + for sub_question_nr, sub_question in enumerate(parsed_response): + refined_sub_question = FollowUpSubQuestion( + sub_question=sub_question, + sub_question_id=make_question_id(1, sub_question_nr + 1), + verified=False, + answered=False, + answer="", + ) + + refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question + + now_end = datetime.now() + + logger.debug( + f"--------{now_end}--{now_end - now_start}--------FOLLOW UP DECOMPOSE END---" + ) + + return FollowUpSubQuestionsUpdate( + refined_sub_questions=refined_sub_question_dict, + agent_refined_start_time=agent_refined_start_time, + ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/operations.py b/backend/onyx/agents/agent_search/deep_search_a/main/operations.py new file mode 100644 index 000000000000..221d630bc80c --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/operations.py @@ -0,0 +1,145 @@ +import re +from collections.abc import Callable + +from langchain_core.callbacks.manager import dispatch_custom_event + +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.agents.agent_search.shared_graph_utils.models import ( + QuestionAnswerResults, +) +from onyx.chat.models import SubQuestionPiece +from onyx.tools.models import SearchQueryInfo +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def remove_document_citations(text: str) -> str: + """ + Removes citation expressions of format '[[D1]]()' from text. + The number after D can vary. 
+ + Args: + text: Input text containing citations + + Returns: + Text with citations removed + """ + # Pattern explanation: + # \[\[D\d+\]\]\(\) matches: + # \[\[ - literal [[ characters + # D - literal D character + # \d+ - one or more digits + # \]\] - literal ]] characters + # \(\) - literal () characters + return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text) + + +def dispatch_subquestion(level: int) -> Callable[[str, int], None]: + def _helper(sub_question_part: str, num: int) -> None: + dispatch_custom_event( + "decomp_qs", + SubQuestionPiece( + sub_question=sub_question_part, + level=level, + level_question_nr=num, + ), + ) + + return _helper + + +def calculate_initial_agent_stats( + decomp_answer_results: list[QuestionAnswerResults], + original_question_stats: AgentChunkStats, +) -> InitialAgentResultStats: + initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats( + sub_questions={}, + original_question={}, + agent_effectiveness={}, + ) + + orig_verified = original_question_stats.verified_count + orig_support_score = original_question_stats.verified_avg_scores + + verified_document_chunk_ids = [] + support_scores = 0.0 + + for decomp_answer_result in decomp_answer_results: + verified_document_chunk_ids += ( + decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids + ) + if ( + decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores + is not None + ): + support_scores += ( + decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores + ) + + verified_document_chunk_ids = list(set(verified_document_chunk_ids)) + + # Calculate sub-question stats + if ( + verified_document_chunk_ids + and len(verified_document_chunk_ids) > 0 + and support_scores is not None + ): + sub_question_stats: dict[str, float | int | None] = { + "num_verified_documents": len(verified_document_chunk_ids), + "verified_avg_score": float(support_scores / len(decomp_answer_results)), + } + else: + sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None} + + initial_agent_result_stats.sub_questions.update(sub_question_stats) + + # Get original question stats + initial_agent_result_stats.original_question.update( + { + "num_verified_documents": original_question_stats.verified_count, + "verified_avg_score": original_question_stats.verified_avg_scores, + } + ) + + # Calculate chunk utilization ratio + sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"] + + chunk_ratio: float | None = None + if sub_verified is not None and orig_verified is not None and orig_verified > 0: + chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0 + elif sub_verified is not None and sub_verified > 0: + chunk_ratio = 10.0 + + initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio + + if ( + orig_support_score is None + or orig_support_score == 0.0 + and initial_agent_result_stats.sub_questions["verified_avg_score"] is None + ): + initial_agent_result_stats.agent_effectiveness["support_ratio"] = None + elif orig_support_score is None or orig_support_score == 0.0: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10 + elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0 + else: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = ( + initial_agent_result_stats.sub_questions["verified_avg_score"] + / orig_support_score + ) + + return 
initial_agent_result_stats + + +def get_query_info(results: list[QueryResult]) -> SearchQueryInfo: + # Use the query info from the base document retrieval + # TODO: see if this is the right way to do this + query_infos = [ + result.query_info for result in results if result.query_info is not None + ] + if len(query_infos) == 0: + raise ValueError("No query info found") + return query_infos[0] diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/states.py b/backend/onyx/agents/agent_search/deep_search_a/main/states.py new file mode 100644 index 000000000000..9081f4a881e4 --- /dev/null +++ b/backend/onyx/agents/agent_search/deep_search_a/main/states.py @@ -0,0 +1,171 @@ +from datetime import datetime +from operator import add +from typing import Annotated +from typing import TypedDict + +from onyx.agents.agent_search.core_state import CoreState +from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) +from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics +from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics +from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion +from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agents.agent_search.shared_graph_utils.models import ( + EntityRelationshipTermExtraction, +) +from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats +from onyx.agents.agent_search.shared_graph_utils.models import QueryResult +from onyx.agents.agent_search.shared_graph_utils.models import ( + QuestionAnswerResults, +) +from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_inference_sections, +) +from onyx.agents.agent_search.shared_graph_utils.operators import ( + dedup_question_answer_results, +) +from onyx.context.search.models import InferenceSection + + +### States ### + +## Update States + + +class RefinedAgentStartStats(TypedDict): + agent_refined_start_time: datetime | None + + +class RefinedAgentEndStats(TypedDict): + agent_refined_end_time: datetime | None + agent_refined_metrics: AgentRefinedMetrics + + +class BaseDecompUpdateBase(TypedDict): + agent_start_time: datetime + initial_decomp_questions: list[str] + + +class RoutingDecisionBase(TypedDict): + routing: str + sample_doc_str: str + + +class RoutingDecision(RoutingDecisionBase): + log_messages: list[str] + + +class BaseDecompUpdate( + RefinedAgentStartStats, RefinedAgentEndStats, BaseDecompUpdateBase +): + pass + + +class InitialAnswerBASEUpdate(TypedDict): + initial_base_answer: str + + +class InitialAnswerUpdateBase(TypedDict): + initial_answer: str + initial_agent_stats: InitialAgentResultStats | None + generated_sub_questions: list[str] + agent_base_end_time: datetime + agent_base_metrics: AgentBaseMetrics | None + + +class InitialAnswerUpdate(InitialAnswerUpdateBase): + log_messages: list[str] + + +class RefinedAnswerUpdateBase(TypedDict): + refined_answer: str + refined_agent_stats: RefinedAgentStats | None + refined_answer_quality: bool + + +class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase): + pass + + +class InitialAnswerQualityUpdate(TypedDict): + initial_answer_quality: bool + + +class RequireRefinedAnswerUpdate(TypedDict): + require_refined_answer: bool + + +class DecompAnswersUpdate(TypedDict): + documents: Annotated[list[InferenceSection], 
dedup_inference_sections] + decomp_answer_results: Annotated[ + list[QuestionAnswerResults], dedup_question_answer_results + ] + + +class FollowUpDecompAnswersUpdate(TypedDict): + refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] + refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] + + +class ExpandedRetrievalUpdate(TypedDict): + all_original_question_documents: Annotated[ + list[InferenceSection], dedup_inference_sections + ] + original_question_retrieval_results: list[QueryResult] + original_question_retrieval_stats: AgentChunkStats + + +class EntityTermExtractionUpdate(TypedDict): + entity_retlation_term_extractions: EntityRelationshipTermExtraction + + +class FollowUpSubQuestionsUpdateBase(TypedDict): + refined_sub_questions: dict[int, FollowUpSubQuestion] + + +class FollowUpSubQuestionsUpdate( + RefinedAgentStartStats, FollowUpSubQuestionsUpdateBase +): + pass + + +## Graph Input State +## Graph Input State + + +class MainInput(CoreState): + pass + + +## Graph State + + +class MainState( + # This includes the core state + MainInput, + BaseDecompUpdateBase, + InitialAnswerUpdateBase, + InitialAnswerBASEUpdate, + DecompAnswersUpdate, + ExpandedRetrievalUpdate, + EntityTermExtractionUpdate, + InitialAnswerQualityUpdate, + RequireRefinedAnswerUpdate, + FollowUpSubQuestionsUpdateBase, + FollowUpDecompAnswersUpdate, + RefinedAnswerUpdateBase, + RefinedAgentStartStats, + RefinedAgentEndStats, + RoutingDecisionBase, +): + # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] + base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add] + + +## Graph Output State - presently not used + + +class MainOutput(TypedDict): + pass diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py new file mode 100644 index 000000000000..c3857dbb673f --- /dev/null +++ b/backend/onyx/agents/agent_search/models.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from onyx.context.search.models import SearchRequest +from onyx.llm.interfaces import LLM +from onyx.llm.models import PreviousMessage +from onyx.tools.tool_implementations.search.search_tool import SearchTool + + +@dataclass +class AgentSearchConfig: + """ + Configuration for the Agent Search feature. + """ + + # The search request that was used to generate the Pro Search + search_request: SearchRequest + + primary_llm: LLM + fast_llm: LLM + search_tool: SearchTool + use_agentic_search: bool = False + + # For persisting agent search data + chat_session_id: UUID | None = None + + # The message ID of the user message that triggered the Pro Search + message_id: int | None = None + + # Whether to persistence data for the Pro Search (turned off for testing) + use_persistence: bool = True + + # The database session for the Pro Search + db_session: Session | None = None + + # Whether to perform initial search to inform decomposition + perform_initial_search_path_decision: bool = False + + # Whether to perform initial search to inform decomposition + perform_initial_search_decomposition: bool = False + + # Whether to allow creation of refinement questions (and entity extraction, etc.) 
+ allow_refinement: bool = False + + # Message history for the current chat session + message_history: list[PreviousMessage] | None = None + + structured_response_format: dict | None = None + + +class AgentDocumentCitations(BaseModel): + document_id: str + document_title: str + link: str diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py new file mode 100644 index 000000000000..a95c377b8255 --- /dev/null +++ b/backend/onyx/agents/agent_search/run_graph.py @@ -0,0 +1,280 @@ +import asyncio +from asyncio import AbstractEventLoop +from collections.abc import AsyncIterable +from collections.abc import Iterable +from datetime import datetime +from typing import cast + +from langchain_core.runnables.schema import StreamEvent +from langgraph.graph.state import CompiledStateGraph + +from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder +from onyx.agents.agent_search.basic.states import BasicInput +from onyx.agents.agent_search.deep_search_a.main.graph_builder import ( + main_graph_builder as main_graph_builder_a, +) +from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config +from onyx.chat.llm_response_handler import LLMResponseHandlerManager +from onyx.chat.models import AgentAnswerPiece +from onyx.chat.models import AnswerPacket +from onyx.chat.models import AnswerStream +from onyx.chat.models import ExtendedToolResponse +from onyx.chat.models import StreamStopInfo +from onyx.chat.models import SubQueryPiece +from onyx.chat.models import SubQuestionPiece +from onyx.chat.models import ToolResponse +from onyx.chat.prompt_builder.build import LLMCall +from onyx.configs.dev_configs import GRAPH_NAME +from onyx.context.search.models import SearchRequest +from onyx.db.engine import get_session_context_manager +from onyx.tools.tool_runner import ToolCallKickoff +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +_COMPILED_GRAPH: CompiledStateGraph | None = None + + +def _set_combined_token_value( + combined_token: str, parsed_object: AgentAnswerPiece +) -> AgentAnswerPiece: + parsed_object.answer_piece = combined_token + + return parsed_object + + +def _parse_agent_event( + event: StreamEvent, +) -> AnswerPacket | None: + """ + Parse the event into a typed object. + Return None if we are not interested in the event. + """ + + event_type = event["event"] + + # We always just yield the event data, but this piece is useful for two development reasons: + # 1. It's a list of the names of every place we dispatch a custom event + # 2. 
We maintain the intended types yielded by each event + if event_type == "on_custom_event": + # TODO: different AnswerStream types for different events + if event["name"] == "decomp_qs": + return cast(SubQuestionPiece, event["data"]) + elif event["name"] == "subqueries": + return cast(SubQueryPiece, event["data"]) + elif event["name"] == "sub_answers": + return cast(AgentAnswerPiece, event["data"]) + elif event["name"] == "stream_finished": + return cast(StreamStopInfo, event["data"]) + elif event["name"] == "initial_agent_answer": + return cast(AgentAnswerPiece, event["data"]) + elif event["name"] == "refined_agent_answer": + return cast(AgentAnswerPiece, event["data"]) + elif event["name"] == "start_refined_answer_creation": + return cast(ToolCallKickoff, event["data"]) + elif event["name"] == "tool_response": + return cast(ToolResponse, event["data"]) + elif event["name"] == "basic_response": + return cast(AnswerPacket, event["data"]) + return None + + +async def tear_down(event_loop: AbstractEventLoop) -> None: + # Collect all tasks and cancel those that are not 'done'. + tasks = asyncio.all_tasks(event_loop) + for task in tasks: + task.cancel() + + # Wait for all tasks to complete, ignoring any CancelledErrors + try: + await asyncio.wait(tasks) + except asyncio.exceptions.CancelledError: + pass + + +def _manage_async_event_streaming( + compiled_graph: CompiledStateGraph, + config: AgentSearchConfig | None, + graph_input: MainInput_a | BasicInput, +) -> Iterable[StreamEvent]: + async def _run_async_event_stream( + loop: AbstractEventLoop, + ) -> AsyncIterable[StreamEvent]: + try: + message_id = config.message_id if config else None + async for event in compiled_graph.astream_events( + input=graph_input, + config={"metadata": {"config": config, "thread_id": str(message_id)}}, + # debug=True, + # indicating v2 here deserves further scrutiny + version="v2", + ): + yield event + finally: + await tear_down(loop) + + # This might be able to be simplified + def _yield_async_to_sync() -> Iterable[StreamEvent]: + loop = asyncio.new_event_loop() + try: + # Get the async generator + async_gen = _run_async_event_stream(loop) + # Convert to AsyncIterator + async_iter = async_gen.__aiter__() + while True: + try: + # Create a coroutine by calling anext with the async iterator + next_coro = anext(async_iter) + # Run the coroutine to get the next event + event = loop.run_until_complete(next_coro) + yield event + except StopAsyncIteration: + break + finally: + loop.close() + + return _yield_async_to_sync() + + +def run_graph( + compiled_graph: CompiledStateGraph, + config: AgentSearchConfig, + input: BasicInput | MainInput_a, +) -> AnswerStream: + input["base_question"] = config.search_request.query if config else "" + # TODO: add these to the environment + config.perform_initial_search_path_decision = True + config.perform_initial_search_decomposition = True + + for event in _manage_async_event_streaming( + compiled_graph=compiled_graph, config=config, graph_input=input + ): + if not (parsed_object := _parse_agent_event(event)): + continue + + yield parsed_object + + +# TODO: call this once on startup, TBD where and if it should be gated based +# on dev mode or not +def load_compiled_graph(graph_name: str) -> CompiledStateGraph: + main_graph_builder = ( + main_graph_builder_a if graph_name == "a" else main_graph_builder_a + ) + global _COMPILED_GRAPH + if _COMPILED_GRAPH is None: + graph = main_graph_builder() + _COMPILED_GRAPH = graph.compile() + return _COMPILED_GRAPH + + +def run_main_graph( + 
config: AgentSearchConfig, + graph_name: str = "a", +) -> AnswerStream: + compiled_graph = load_compiled_graph(graph_name) + if graph_name == "a": + input = MainInput_a() + else: + input = MainInput_a() + + # Agent search is not a Tool per se, but this is helpful for the frontend + yield ToolCallKickoff( + tool_name="agent_search_0", + tool_args={"query": config.search_request.query}, + ) + yield from run_graph(compiled_graph, config, input) + + +# TODO: unify input types, especially prosearchconfig +def run_basic_graph( + config: AgentSearchConfig, + last_llm_call: LLMCall | None, + response_handler_manager: LLMResponseHandlerManager, +) -> AnswerStream: + graph = basic_graph_builder() + compiled_graph = graph.compile() + # TODO: unify basic input + input = BasicInput( + base_question="", + last_llm_call=last_llm_call, + response_handler_manager=response_handler_manager, + calls=0, + ) + return run_graph(compiled_graph, config, input) + + +if __name__ == "__main__": + from onyx.llm.factory import get_default_llms + + now_start = datetime.now() + logger.debug(f"Start at {now_start}") + + if GRAPH_NAME == "a": + graph = main_graph_builder_a() + else: + graph = main_graph_builder_a() + compiled_graph = graph.compile() + now_end = datetime.now() + logger.debug(f"Graph compiled in {now_end - now_start} seconds") + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + # query="what can you do with gitlab?", + # query="What are the guiding principles behind the development of cockroachDB", + # query="What are the temperatures in Munich, Hawaii, and New York?", + # query="When was Washington born?", + query="What is Onyx?", + ) + # Joachim custom persona + + with get_session_context_manager() as db_session: + config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) + # search_request.persona = get_persona_by_id(1, None, db_session) + config.use_persistence = True + config.perform_initial_search_path_decision = True + config.perform_initial_search_decomposition = True + if GRAPH_NAME == "a": + input = MainInput_a() + else: + input = MainInput_a() + # with open("output.txt", "w") as f: + tool_responses: list = [] + for output in run_graph(compiled_graph, config, input): + # pass + + if isinstance(output, ToolCallKickoff): + pass + elif isinstance(output, ExtendedToolResponse): + tool_responses.append(output.response) + logger.info( + f" ---- ET {output.level} - {output.level_question_nr} | " + ) + elif isinstance(output, SubQueryPiece): + logger.info( + f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | " + ) + elif isinstance(output, SubQuestionPiece): + logger.info( + f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | " + ) + elif ( + isinstance(output, AgentAnswerPiece) + and output.answer_type == "agent_sub_answer" + ): + logger.info( + f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | " + ) + elif ( + isinstance(output, AgentAnswerPiece) + and output.answer_type == "agent_level_answer" + ): + logger.info( + f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | " + ) + + # for tool_response in tool_responses: + # logger.debug(tool_response) diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py b/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py new file mode 100644 index 000000000000..71326452ac36 --- /dev/null +++ 
b/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py @@ -0,0 +1,65 @@ +from langchain.schema import AIMessage +from langchain.schema import HumanMessage +from langchain.schema import SystemMessage +from langchain_core.messages.tool import ToolMessage + +from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2 +from onyx.context.search.models import InferenceSection +from onyx.llm.interfaces import LLMConfig +from onyx.llm.utils import get_max_input_tokens +from onyx.natural_language_processing.utils import get_tokenizer +from onyx.natural_language_processing.utils import tokenizer_trim_content + + +def build_sub_question_answer_prompt( + question: str, + original_question: str, + docs: list[InferenceSection], + persona_specification: str, + config: LLMConfig, +) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]: + system_message = SystemMessage( + content=persona_specification, + ) + + docs_format_list = [ + f"""Document Number: [D{doc_nr + 1}]\n + Content: {doc.combined_content}\n\n""" + for doc_nr, doc in enumerate(docs) + ] + + docs_str = "\n\n".join(docs_format_list) + + docs_str = trim_prompt_piece( + config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + ) + human_message = HumanMessage( + content=BASE_RAG_PROMPT_v2.format( + question=question, original_question=original_question, context=docs_str + ) + ) + + return [system_message, human_message] + + +def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str: + # TODO: this truncating might add latency. We could do a rougher + faster check + # first to determine whether truncation is needed + + # TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times? + llm_tokenizer = get_tokenizer( + provider_type=config.model_provider, + model_name=config.model_name, + ) + + max_tokens = get_max_input_tokens( + model_provider=config.model_provider, + model_name=config.model_name, + ) + + # slightly conservative trimming + return tokenizer_trim_content( + content=prompt_piece, + desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)), + tokenizer=llm_tokenizer, + ) diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agents/agent_search/shared_graph_utils/calculations.py new file mode 100644 index 000000000000..36b5f2975756 --- /dev/null +++ b/backend/onyx/agents/agent_search/shared_graph_utils/calculations.py @@ -0,0 +1,98 @@ +import numpy as np + +from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics +from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.chat.models import SectionRelevancePiece +from onyx.context.search.models import InferenceSection +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def unique_chunk_id(doc: InferenceSection) -> str: + return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" + + +def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float: + shift = 0 + for rank_first, doc_id in enumerate(list1[:top_n], 1): + try: + rank_second = list2.index(doc_id) + 1 + except ValueError: + rank_second = len(list2) # Document not found in second list + + shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second) + + return shift / top_n + + +def get_fit_scores( + pre_reranked_results: list[InferenceSection], + post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece], +) 
-> RetrievalFitStats | None: + """ + Calculate retrieval metrics for search purposes + """ + + if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0: + return None + + ranked_sections = { + "initial": pre_reranked_results, + "reranked": post_reranked_results, + } + + fit_eval: RetrievalFitStats = RetrievalFitStats( + fit_score_lift=0, + rerank_effect=0, + fit_scores={ + "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]), + "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]), + }, + ) + + for rank_type, docs in ranked_sections.items(): + logger.debug(f"rank_type: {rank_type}") + + for i in [1, 5, 10]: + fit_eval.fit_scores[rank_type].scores[str(i)] = ( + sum( + [ + float(doc.center_chunk.score) + for doc in docs[:i] + if type(doc) == InferenceSection + and doc.center_chunk.score is not None + ] + ) + / i + ) + + fit_eval.fit_scores[rank_type].scores["fit_score"] = ( + 1 + / 3 + * ( + fit_eval.fit_scores[rank_type].scores["1"] + + fit_eval.fit_scores[rank_type].scores["5"] + + fit_eval.fit_scores[rank_type].scores["10"] + ) + ) + + fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[ + rank_type + ].scores["1"] + + fit_eval.fit_scores[rank_type].chunk_ids = [ + unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection + ] + + fit_eval.fit_score_lift = ( + fit_eval.fit_scores["reranked"].scores["fit_score"] + / fit_eval.fit_scores["initial"].scores["fit_score"] + ) + + fit_eval.rerank_effect = calculate_rank_shift( + fit_eval.fit_scores["initial"].chunk_ids, + fit_eval.fit_scores["reranked"].chunk_ids, + ) + + return fit_eval diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/models.py b/backend/onyx/agents/agent_search/shared_graph_utils/models.py new file mode 100644 index 000000000000..c38c3db821b1 --- /dev/null +++ b/backend/onyx/agents/agent_search/shared_graph_utils/models.py @@ -0,0 +1,112 @@ +from typing import Literal + +from pydantic import BaseModel + +from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics +from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics +from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics +from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings +from onyx.context.search.models import InferenceSection +from onyx.tools.models import SearchQueryInfo + + +# Pydantic models for structured outputs +class RewrittenQueries(BaseModel): + rewritten_queries: list[str] + + +class BinaryDecision(BaseModel): + decision: Literal["yes", "no"] + + +class BinaryDecisionWithReasoning(BaseModel): + reasoning: str + decision: Literal["yes", "no"] + + +class RetrievalFitScoreMetrics(BaseModel): + scores: dict[str, float] + chunk_ids: list[str] + + +class RetrievalFitStats(BaseModel): + fit_score_lift: float + rerank_effect: float + fit_scores: dict[str, RetrievalFitScoreMetrics] + + +class AgentChunkScores(BaseModel): + scores: dict[str, dict[str, list[int | float]]] + + +class AgentChunkStats(BaseModel): + verified_count: int | None + verified_avg_scores: float | None + rejected_count: int | None + rejected_avg_scores: float | None + verified_doc_chunk_ids: list[str] + dismissed_doc_chunk_ids: list[str] + + +class InitialAgentResultStats(BaseModel): + sub_questions: dict[str, float | int | None] + original_question: dict[str, float | int | None] + agent_effectiveness: dict[str, float | int | None] + + +class RefinedAgentStats(BaseModel): + revision_doc_efficiency: float | None + 
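+    # Ratio of good sub-questions after refinement to good initial sub-questions (computed in generate_refined_answer)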
revision_question_efficiency: float | None + + +class Term(BaseModel): + term_name: str + term_type: str + term_similar_to: list[str] + + +### Models ### + + +class Entity(BaseModel): + entity_name: str + entity_type: str + + +class Relationship(BaseModel): + relationship_name: str + relationship_type: str + relationship_entities: list[str] + + +class EntityRelationshipTermExtraction(BaseModel): + entities: list[Entity] + relationships: list[Relationship] + terms: list[Term] + + +### Models ### + + +class QueryResult(BaseModel): + query: str + search_results: list[InferenceSection] + stats: RetrievalFitStats | None + query_info: SearchQueryInfo | None + + +class QuestionAnswerResults(BaseModel): + question: str + question_id: str + answer: str + quality: str + expanded_retrieval_results: list[QueryResult] + documents: list[InferenceSection] + sub_question_retrieval_stats: AgentChunkStats + + +class CombinedAgentMetrics(BaseModel): + timings: AgentTimings + base_metrics: AgentBaseMetrics | None + refined_metrics: AgentRefinedMetrics + additional_metrics: AgentAdditionalMetrics diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/operators.py b/backend/onyx/agents/agent_search/shared_graph_utils/operators.py new file mode 100644 index 000000000000..303593fc6266 --- /dev/null +++ b/backend/onyx/agents/agent_search/shared_graph_utils/operators.py @@ -0,0 +1,31 @@ +from onyx.agents.agent_search.shared_graph_utils.models import ( + QuestionAnswerResults, +) +from onyx.chat.prune_and_merge import _merge_sections +from onyx.context.search.models import InferenceSection + + +def dedup_inference_sections( + list1: list[InferenceSection], list2: list[InferenceSection] +) -> list[InferenceSection]: + deduped = _merge_sections(list1 + list2) + return deduped + + +def dedup_question_answer_results( + question_answer_results_1: list[QuestionAnswerResults], + question_answer_results_2: list[QuestionAnswerResults], +) -> list[QuestionAnswerResults]: + deduped_question_answer_results: list[ + QuestionAnswerResults + ] = question_answer_results_1 + utilized_question_ids: set[str] = set( + [x.question_id for x in question_answer_results_1] + ) + + for question_answer_result in question_answer_results_2: + if question_answer_result.question_id not in utilized_question_ids: + deduped_question_answer_results.append(question_answer_result) + utilized_question_ids.add(question_answer_result.question_id) + + return deduped_question_answer_results diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agents/agent_search/shared_graph_utils/prompts.py new file mode 100644 index 000000000000..00c6366e57f4 --- /dev/null +++ b/backend/onyx/agents/agent_search/shared_graph_utils/prompts.py @@ -0,0 +1,912 @@ +UNKNOWN_ANSWER = "I do not have enough information to answer this question." + +NO_RECOVERED_DOCS = "No relevant documents recovered" + +REWRITE_PROMPT_MULTI_ORIGINAL = """ \n + Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a + document store. Particularly, try to think about resolving ambiguities and make the search queries more specific, + enabling the system to search more broadly. + Also, try to make the search queries not redundant, i.e. not too similar! \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Formulate the queries separated by newlines (Do not say 'Query 1: ...', just write the querytext) as follows: + + +... 
+ queries: """

+REWRITE_PROMPT_MULTI = """ \n
+ Please create a list of 2-3 sample documents that could answer an original question. Each document
+ should be about as long as the original question. \n
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+ Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
+
+# The prompt is only used if there is no persona prompt, so the placeholder is ''
+BASE_RAG_PROMPT = (
+ """ \n
+ {persona_specification}
+ Use the context provided below - and only the
+ provided context - to answer the given question. (Note that the answer is in service of answering a broader
+ question, given below as 'motivation'.)
+
+ Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
+ question based on the context, say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """. It is a matter of life and death that you do NOT
+ use your internal knowledge, just the provided information!
+
+ Make sure that you keep all relevant information, specifically as it concerns the ultimate goal.
+ (But keep other details as well.)
+
+ \nContext:\n {context} \n
+
+ Motivation:\n {original_question} \n\n
+ \n\n
+ And here is the question I want you to answer based on the context above (with the motivation in mind):
+ \n--\n {question} \n--\n
+ """
+)
+
+BASE_RAG_PROMPT_v2 = (
+ """ \n
+ Use the context provided below - and only the
+ provided context - to answer the given question. (Note that the answer is in service of answering a broader
+ question, given below as 'motivation'.)
+
+ Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
+ question based on the context, say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """. It is a matter of life and death that you do NOT
+ use your internal knowledge, just the provided information!
+
+ Make sure that you keep all relevant information, specifically as it concerns the ultimate goal.
+ (But keep other details as well.)
+
+ Please remember to provide inline citations in the format [[D1]](), [[D2]](), [[D3]](), etc.
+ Proper citations are very important to the user!\n\n\n
+
+ For your general information, here is the ultimate motivation:
+ \n--\n {original_question} \n--\n
+ \n\n
+ And here is the actual question I want you to answer based on the context above (with the motivation in mind):
+ \n--\n {question} \n--\n
+
+ Here is the context:
+ \n\n\n--\n {context} \n--\n
+ """
+)
+
+SUB_CHECK_YES = "yes"
+SUB_CHECK_NO = "no"
+
+SUB_CHECK_PROMPT = (
+ """
+ Your task is to see whether a given answer addresses a given question.
+ Please do not use any internal knowledge you may have - just focus on whether the answer
+ as given seems to largely address the question as given, or at least addresses part of the question.
+ Here is the question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+ Here is the suggested answer:
+ \n ------- \n
+ {base_answer}
+ \n ------- \n
+ Does the suggested answer address the question? Please answer with """
+ + f'"{SUB_CHECK_YES}" or "{SUB_CHECK_NO}".'
+)
+
+
+BASE_CHECK_PROMPT = """ \n
+ Please check whether 1) the suggested answer seems to fully address the original question AND 2) the
+ original question requests a simple, factual answer, and there are no ambiguities, judgements,
+ aggregations, or any other complications that may require extra context.
(I.e., if the question is
+ somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
+
+ Please only answer with 'yes' or 'no' \n
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+ Here is the proposed answer:
+ \n ------- \n
+ {initial_answer}
+ \n ------- \n
+ Please answer with yes or no:"""
+
+VERIFIER_PROMPT = """
+You are supposed to judge whether a document text contains data or information that is potentially relevant for a question.
+
+Here is a document text that you can take as a fact:
+--
+DOCUMENT INFORMATION:
+{document_content}
+--
+
+Do you think that this information is useful and relevant to answer the following question?
+(Other documents may supply additional information, so do not worry if the provided information
+is not enough to answer the question, but it needs to be relevant to the question.)
+--
+QUESTION:
+{question}
+--
+
+Please answer with 'yes' or 'no':
+
+Answer:
+
+"""
+
+INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
+If you think it is helpful, please decompose an initial user question into not more
+than 4 appropriate sub-questions that help to answer the original question.
+The purpose for this decomposition is to isolate individual entities
+(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+ for us'), etc. Each sub-question should realistically be answerable by a good RAG system.
+
+Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too.
+
+Here is the initial question:
+\n ------- \n
+{question}
+\n ------- \n
+
+Please formulate your answer as a list of subquestions:
+
+Answer:
+"""
+
+REWRITE_PROMPT_SINGLE = """ \n
+ Please convert an initial user question into a more appropriate search query for retrieval from a
+ document store. \n
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+
+ Formulate the query: """
+
+MODIFIED_RAG_PROMPT = (
+ """You are an assistant for question-answering tasks. Use the context provided below
+ - and only this context - to answer the question. It is a matter of life and death that you do NOT
+ use your internal knowledge, just the provided information!
+ If you don't have enough information to generate an answer, just say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """.
+ Use three sentences maximum and keep the answer concise.
+ Also pay particular attention to the sub-questions and their answers, as they may enrich the answer.
+ Again, only use the provided context and do not use your internal knowledge!
+
+ \nQuestion: {question}
+ \nContext: {combined_context} \n
+
+ Answer:"""
+)
+
+ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
+ An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
+ good enough. Also, some sub-questions had been answered and this information has been used to provide
+ the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they
+ were not directly answerable. Also, some entities, relationships and terms are given to you so that
+ you have an idea of what the available data looks like.
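Side note on the templates above: they are plain Python str.format templates, which is why literal braces in the JSON examples are doubled. A minimal, illustrative rendering of the verifier prompt; section and user_question are hypothetical stand-ins, not names introduced by this change:

    # Illustrative only: fill the verifier template for one retrieved section.
    rendered_prompt = VERIFIER_PROMPT.format(
        document_content=section.combined_content,  # hypothetical InferenceSection
        question=user_question,                      # hypothetical user question string
    )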
+
+ Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
+ considering:
+
+ 1) The initial question
+ 2) The initial answer that was found to be unsatisfactory
+ 3) The sub-questions that were answered
+ 4) The sub-questions that were suggested but not answered
+ 5) The entities, relationships and terms that were extracted from the context
+
+ The individual questions should be answerable by a good RAG system.
+ So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
+ question for different entities that may be involved in the original question, but in a way that does
+ not duplicate questions that were already tried.
+
+ Additional Guidelines:
+ - The sub-questions should be specific to the question and provide richer context for the question,
+ resolve ambiguities, or address shortcomings of the initial answer
+ - Each sub-question - when answered - should be relevant for the answer to the original question
+ - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
+ other complications that may require extra context.
+ - The sub-questions MUST have the full context of the original question so that they can be executed by
+ a RAG system independently without the original question available
+ (Example:
+ - initial question: "What is the capital of France?"
+ - bad sub-question: "What is the name of the river there?"
+ - good sub-question: "What is the name of the river that flows through Paris?")
+ - For each sub-question, please provide a short explanation for why it is a good sub-question. So
+ generate a list of dictionaries with the following format:
+ [{{"sub_question": , "explanation": , "search_term": }}, ...]
+
+ \n\n
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+
+ Here is the initial sub-optimal answer:
+ \n ------- \n
+ {base_answer}
+ \n ------- \n
+
+ Here are the sub-questions that were answered:
+ \n ------- \n
+ {answered_sub_questions}
+ \n ------- \n
+
+ Here are the sub-questions that were suggested but not answered:
+ \n ------- \n
+ {failed_sub_questions}
+ \n ------- \n
+
+ And here are the entities, relationships and terms extracted from the context:
+ \n ------- \n
+ {entity_term_extraction_str}
+ \n ------- \n
+
+ Please generate the list of good, fully contextualized sub-questions that would help to address the
+ main question. Again, please find questions that are NOT overlapping too much with the already answered
+ sub-questions or those that already were suggested and failed.
+ In other words - what can we try in addition to what has been tried so far?
+
+ Please think through it step by step and then generate the list of json dictionaries with the following
+ format:
+
+ {{"sub_questions": [{{"sub_question": ,
+ "explanation": ,
+ "search_term": }},
+ ...]}} """
+
+DEEP_DECOMPOSE_PROMPT = """ \n
+ An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
+ good enough. Also, some sub-questions had been answered and this information has been used to provide
+ the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they
+ were not directly answerable. Also, some entities, relationships and terms are given to you so that
+ you have an idea of what the available data looks like.
+
+ Your role is to generate 2-4 new sub-questions that would help to answer the initial question,
+ considering:
+
+ 1) The initial question
+ 2) The initial answer that was found to be unsatisfactory
+ 3) The sub-questions that were answered
+ 4) The sub-questions that were suggested but not answered
+ 5) The entities, relationships and terms that were extracted from the context
+
+ The individual questions should be answerable by a good RAG system.
+ So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
+ question for different entities that may be involved in the original question, but in a way that does
+ not duplicate questions that were already tried.
+
+ Additional Guidelines:
+ - The sub-questions should be specific to the question and provide richer context for the question,
+ resolve ambiguities, or address shortcomings of the initial answer
+ - Each sub-question - when answered - should be relevant for the answer to the original question
+ - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
+ other complications that may require extra context.
+ - The sub-questions MUST have the full context of the original question so that they can be executed by
+ a RAG system independently without the original question available
+ (Example:
+ - initial question: "What is the capital of France?"
+ - bad sub-question: "What is the name of the river there?"
+ - good sub-question: "What is the name of the river that flows through Paris?")
+ - For each sub-question, please also provide a search term that can be used to retrieve relevant
+ documents from a document store.
+ - Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not
+ answerable with the available context, and you should not ask similar questions.
+ \n\n
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+
+ Here is the initial sub-optimal answer:
+ \n ------- \n
+ {base_answer}
+ \n ------- \n
+
+ Here are the sub-questions that were answered:
+ \n ------- \n
+ {answered_sub_questions}
+ \n ------- \n
+
+ Here are the sub-questions that were suggested but not answered:
+ \n ------- \n
+ {failed_sub_questions}
+ \n ------- \n
+
+ And here are the entities, relationships and terms extracted from the context:
+ \n ------- \n
+ {entity_term_extraction_str}
+ \n ------- \n
+
+ Please generate the list of good, fully contextualized sub-questions that would help to address the
+ main question.
+
+ Specifically pay attention also to the entities, relationships and terms extracted, as these indicate what type of
+ objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not
+ mentioned in the 'entities, relationships and terms' section.
+
+ Again, please find questions that are NOT overlapping too much with the already answered
+ sub-questions or those that already were suggested and failed.
+ In other words - what can we try in addition to what has been tried so far?
+
+ Generate the list of questions separated by one new line like this:
+
+
+
+ ...
+ """
+
+DECOMPOSE_PROMPT = """ \n
+ For an initial user question, please generate 5-10 individual sub-questions whose answers would help
+ \n to answer the initial question. The individual questions should be answerable by a good RAG system.
+ So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
+ question for different entities that may be involved in the original question.
+
+ In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
+ document store, expressed as entities, relationships and terms. You can also think about the types
+ mentioned in brackets.
+
+ Guidelines:
+ - The sub-questions should be specific to the question and provide richer context for the question,
+ and/or resolve ambiguities
+ - Each sub-question - when answered - should be relevant for the answer to the original question
+ - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
+ other complications that may require extra context.
+ - The sub-questions MUST have the full context of the original question so that they can be executed by
+ a RAG system independently without the original question available
+ (Example:
+ - initial question: "What is the capital of France?"
+ - bad sub-question: "What is the name of the river there?"
+ - good sub-question: "What is the name of the river that flows through Paris?")
+ - For each sub-question, please provide a short explanation for why it is a good sub-question. So
+ generate a list of dictionaries with the following format:
+ [{{"sub_question": , "explanation": }}, ...]
+
+ \n\n
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+
+ And here are the entities, relationships and terms extracted from the context:
+ \n ------- \n
+ {entity_term_extraction_str}
+ \n ------- \n
+
+ Please generate the list of good, fully contextualized sub-questions that would help to address the
+ main question. Don't be too specific unless the original question is specific.
+ Please think through it step by step and then generate the list of json dictionaries with the following
+ format:
+ {{"sub_questions": [{{"sub_question": ,
+ "explanation": ,
+ "search_term": }},
+ ...]}} """
+
+#### Consolidations
+COMBINED_CONTEXT = """-------
+ Below you will find useful information to answer the original question. First, you see a number of
+ sub-questions with their answers. This information should be considered to be more focused and
+ somewhat more specific to the original question as it tries to contextualize facts.
+ After that you will see the documents that were considered to be relevant to answer the original question.
+
+ Here are the sub-questions and their answers:
+ \n\n {deep_answer_context} \n\n
+ \n\n Here are the documents that were considered to be relevant to answer the original question:
+ \n\n {formated_docs} \n\n
+ ----------------
+ """
+
+SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
+ Below you will find a question that we ultimately want to answer (the original question) and a list of
+ motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
+ original question. The motivations are formatted as : .
+ (Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
+ motivation and 2 is less relevant.)
+
+ Please rank the motivations in order of relevance for answering the original question. Also, try to
+ ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
+ Ultimately, create a list with the motivation numbers where the number of the most relevant
+ motivations comes first.
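Aside: the decomposition prompts above ask the model for a small JSON object. Downstream, a helper such as clean_and_parse_json_string (added later in this change) can turn the raw completion into Python data. A rough usage sketch, where raw_completion is a hypothetical LLM response string:

    # Illustrative only: extract the proposed sub-questions from the model's JSON answer.
    parsed = clean_and_parse_json_string(raw_completion)
    sub_questions = [entry["sub_question"] for entry in parsed.get("sub_questions", [])]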
+
+ Here is the original question:
+ \n\n {original_question} \n\n
+ \n\n Here is the list of sub-question motivations:
+ \n\n {sub_question_explanations} \n\n
+ ----------------
+
+ Please think step by step and then generate the ranked list of motivations.
+
+ Please format your answer as a json object in the following format:
+ {{"reasonning": ,
+ "ranked_motivations": }}
+ """
+
+
+INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """
+If you think it is helpful, please decompose an initial user question into no more than 3 appropriate sub-questions that help to
+answer the original question. The purpose for this decomposition may be to
+ 1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
+ 'what are sales for company B'])
+ 2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
+ 'what is our market share with company A', 'is company A a reference customer for us', etc.])
+ 3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
+ familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
+ (i.e., 'what do we do to improve product X' -> ['what do we do to improve scalability of product X',
+ 'what do we do to improve stability of product X', ...])
+ 4) research an area that could really help to answer the question. (But clarifications or disambiguations are more important.)
+
+If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too.
+
+Here is the initial question:
+-------
+{question}
+-------
+Please formulate your answer as a newline-separated list of questions like so:
+
+
+
+
+Answer:"""
+
+INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH = """
+If you think it is helpful, please decompose an initial user question into no more than 3 appropriate sub-questions that help to
+answer the original question. The purpose for this decomposition may be to
+ 1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
+ 'what are sales for company B'])
+ 2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
+ 'what is our market share with company A', 'is company A a reference customer for us', etc.])
+ 3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
+ familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
+ (i.e., 'what do we do to improve product X' -> ['what do we do to improve scalability of product X',
+ 'what do we do to improve stability of product X', ...])
+ 4) research an area that could really help to answer the question. (But clarifications or disambiguations are more important.)
+
+Here are some other rules:
+
+1) To give you some context, you will also see below some documents that relate to the question. Please only
+use this information to learn what the question is approximately asking about, but do not focus on the details
+to construct the sub-questions.
+2) If you think that a decomposition is not needed or helpful, please just return an empty string. That is very much ok too.
+
+Here are the sample docs to give you some context:
+-------
+{sample_doc_str}
+-------
+
+And here is the initial question that you should think about decomposing:
+-------
+{question}
+-------
+
+
+Please formulate your answer as a newline-separated list of questions like so:
+
+
+
+
+Answer:"""
+
+INITIAL_DECOMPOSITION_PROMPT = """ \n
+ Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
+ answer the original question. The purpose for this decomposition is to isolate individual entities
+ (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+ for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+ sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+ for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
+
+ For each sub-question, please also create one search term that can be used to retrieve relevant
+ documents from a document store.
+
+ Here is the initial question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+
+ Please formulate your answer as a list of json objects with the following format:
+
+ [{{"sub_question": , "search_term": }}, ...]
+
+ Answer:
+ """
+
+INITIAL_RAG_BASE_PROMPT = (
+ """ \n
+You are an assistant for question-answering tasks. Use the information provided below - and only the
+provided information - to answer the provided question.
+
+The information provided below consists of a number of documents that were deemed relevant for the question.
+
+IMPORTANT RULES:
+- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
+You may give some additional facts you learned, but do not try to invent an answer.
+- If the information is empty or irrelevant, just say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """.
+- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
+
+Try to keep your answer concise.
+
+Here is the contextual information from the document store:
+\n ------- \n
+{context} \n\n\n
+\n ------- \n
+And here is the question I want you to answer based on the context above (with the motivation in mind):
+\n--\n {question} \n--\n
+Answer:"""
+)
+
+
+AGENT_DECISION_PROMPT = """
+You are a large language model assistant helping users address their information needs. You are tasked with deciding
+whether to use a thorough agent search ('research') of a document store to answer a question or request, or whether you want to
+address the question or request yourself as an LLM.
+
+Here are some rules:
+- If you think that a thorough search through a document store will help answer the question
+or address the request, you should choose the 'research' option.
+- If the question asks you to do something ('please create...', 'write for me...', etc.), you should choose the 'LLM' option.
+- If you think the question is very general and does not refer to the contents of a document store, you should choose
+the 'LLM' option.
+- Otherwise, you should choose the 'research' option.
+
+Here is the initial question:
+-------
+{question}
+-------
+
+Please decide whether to use the agent search or the LLM to answer the question. Choose from two choices,
+'research' or 'LLM'.
+
+Answer:"""
+
+AGENT_DECISION_PROMPT_AFTER_SEARCH = """
+You are a large language model assistant helping users address their information needs.
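Aside: the decision prompts above expect a bare 'research' or 'LLM' token back. A minimal sketch of how a caller might branch on that output; llm_output and the normalization step are illustrative, not part of this diff:

    # Illustrative only: route to agent search or a direct LLM answer.
    decision = llm_output.strip().lower()
    use_agentic_search = decision.startswith("research")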
You are given an initial question
+or request and a very small sample of documents that a preliminary and fast search of a document store returned.
+You are tasked with deciding whether to use a thorough agent search ('research') of the document store to answer a question
+or request, or whether you want to address the question or request yourself as an LLM.
+
+Here are some rules:
+- If based on the retrieved documents you think there may be useful information in the document
+store to answer or materially help with the request, you should choose the 'research' option.
+- If you think that the retrieved documents do not help to answer the question or do not help with the request, AND
+you know the answer/can handle the request, you should choose the 'LLM' option.
+- If the question asks you to do something ('please create...', 'write for me...', etc.), you should choose the 'LLM' option.
+- If in doubt, choose the 'research' option.
+
+Here is the initial question:
+-------
+{question}
+-------
+
+Here is the sample of documents that were retrieved from a document store:
+-------
+{sample_doc_str}
+-------
+
+Please decide whether to use the agent search ('research') or the LLM to answer the question. Choose from two choices,
+'research' or 'LLM'.
+
+Answer:"""
+
+### ANSWER GENERATION PROMPTS
+
+# Persona specification
+ASSISTANT_SYSTEM_PROMPT_DEFAULT = """
+You are an assistant for question-answering tasks."""
+
+ASSISTANT_SYSTEM_PROMPT_PERSONA = """
+You are an assistant for question-answering tasks. Here is more information about you:
+\n ------- \n
+{persona_prompt}
+\n ------- \n
+"""
+
+SUB_QUESTION_ANSWER_TEMPLATE = """
+ Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
+ """
+
+SUB_QUESTION_ANSWER_TEMPLATE_REVISED = """
+ Sub-Question: Q{sub_question_nr}\n Type: {level_type}\n Sub-Question:\n
+- \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
+ """
+
+SUB_QUESTION_SEARCH_RESULTS_TEMPLATE = """
+ Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nRelevant Documents:\n
+ -\n {formatted_sub_question_docs}\n\n
+ """
+
+INITIAL_RAG_PROMPT_SUB_QUESTION_SEARCH = (
+ """ \n
+{persona_specification}
+
+Use the information provided below - and only the
+provided information - to answer the provided question.
+
+The information provided below consists of:
+ 1) a number of sub-questions and supporting document information that would help answer them.
+ 2) a broader collection of documents that were deemed relevant for the question. These documents contain information
+ that was also provided in the sub-questions and often more.
+
+IMPORTANT RULES:
+ - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
+ You may give some additional facts you learned, but do not try to invent an answer.
+ - If the information is empty or irrelevant, just say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """.
+ - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
+
+Please provide inline citations of documents in the format [[D1]](), [[D2]](), [[D3]](), etc. If you have multiple citations,
+please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Feel free to cite documents in addition
+to the sub-questions! Proper citations are important for the final answer to be verifiable! \n\n\n
+
+Again, you should be sure that the answer is supported by the information provided!
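Aside: the inline citation convention used throughout these prompts is [[D1]](), [[Q2]](), and so on. For illustration, a small regex that would pick such markers out of a generated answer; this helper is hypothetical and not added by this change:

    import re

    # Illustrative only: find citation markers such as [[D3]]() or [[Q1]]() in an answer string.
    CITATION_MARKER = re.compile(r"\[\[([DQ])(\d+)\]\]\(\)")

    def find_citation_markers(answer: str) -> list[tuple[str, int]]:
        return [(kind, int(num)) for kind, num in CITATION_MARKER.findall(answer)]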
+ +Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones, +or assumptions you made. + +Here is the contextual information: +\n-------\n +*Answered Sub-questions (these should really matter!): +{answered_sub_questions} + +And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n + +{relevant_docs} + +\n-------\n +\n +And here is the question I want you to answer based on the information above: +\n--\n +{question} +\n--\n\n +Answer:""" +) + + +DIRECT_LLM_PROMPT = """ \n +{persona_specification} + +Please answer the following question/address the request: +\n--\n +{question} +\n--\n\n +Answer:""" + +INITIAL_RAG_PROMPT = ( + """ \n +{persona_specification} + +Use the information provided below - and only the +provided information - to answer the provided question. + +The information provided below consists of: + 1) a number of answered sub-questions - these are very important(!) and definitely should be + considered to answer the question. + 2) a number of documents that were also deemed relevant for the question. + +IMPORTANT RULES: + - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is empty or irrelevant, just say """ + + f'"{UNKNOWN_ANSWER}"' + + """. + - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. + +Remember to provide inline citations of documents in the format [[D1]](), [[D2]](), [[D3]](), etc., and [[Q1]](), [[Q2]](),... if +you want to cite the answer to a sub-question. If you have multiple citations, please cite for example +as [[D1]]()[[Q3]](), or [[D2]]()[[D4]](), etc. Feel free to cite sub-questions in addition to documents, but make sure that you +have docuemnt citations ([[D7]]() etc.) if possible! + +Again, you should be sure that the answer is supported by the information provided! + +Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones, +or assumptions you made. + +Here is the contextual information: +\n-------\n +*Answered Sub-questions (these should really matter!): +{answered_sub_questions} + +And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n + +{relevant_docs} + +\n-------\n +\n +And here is the question I want you to answer based on the information above: +\n--\n +{question} +\n--\n\n +Answer:""" +) + +# sub_question_answer_str is empty +INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = ( + """{answered_sub_questions} +{persona_specification} +Use the information provided below +- and only the provided information - to answer the provided question. +The information provided below consists of a number of documents that were deemed relevant for the question. + +IMPORTANT RULES: + - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is irrelevant, just say """ + + f'"{UNKNOWN_ANSWER}"' + + """. + - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. + +Again, you should be sure that the answer is supported by the information provided! 
+
+Remember to provide inline citations of documents in the format [[D1]](), [[D2]](), [[D3]](), etc.!
+
+Try to keep your answer concise.
+
+Here is the relevant context information:
+\n-------\n
+{relevant_docs}
+\n-------\n
+
+And here is the question I want you to answer based on the context above
+\n--\n
+{question}
+\n--\n
+
+Answer:"""
+)
+
+REVISED_RAG_PROMPT = (
+ """\n
+{persona_specification}
+Use the information provided below - and only the
+provided information - to answer the provided question.
+
+The information provided below consists of:
+ 1) an initial answer that was given but found to be lacking in some way.
+ 2) a number of answered sub-questions - these are very important(!) and definitely should be
+ considered to answer the question. Note that the sub-questions have a type, 'initial' or 'revised'. The 'initial'
+ ones were available for the initial answer, and the 'revised' ones were not. So please use the information from the
+ 'revised' sub-questions in particular to update/extend/correct the initial answer!
+ 3) a number of documents that were also deemed relevant for the question.
+
+IMPORTANT RULES:
+ - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
+ You may give some additional facts you learned, but do not try to invent an answer.
+ - If the information is empty or irrelevant, just say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """.
+ - If the information is relevant but not fully conclusive, provide an answer to the extent you can but also
+ specify that the information is not conclusive and why.
+- Ignore the existing citations within the answered sub-questions, like [[D1]]()... and [[Q2]]()!
+The citations you use need to refer to the documents and sub-questions that you are explicitly
+presented with below!
+
+Again, you should be sure that the answer is supported by the information provided!
+
+Remember to provide inline citations of documents in the format [[D1]](), [[D2]](), [[D3]](), etc., and [[Q1]](), [[Q2]](),... if
+you want to cite the answer to a sub-question. If you have multiple citations, please cite for example
+as [[D1]]()[[Q3]](), or [[D2]]()[[D4]](), etc. Feel free to cite sub-questions in addition to documents, but make sure that you
+have document citations ([[D7]]() etc.) if possible!
+Proper citations are important for the final answer to be verifiable! \n\n\n
+
+Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
+or assumptions you made.
+
+Here is the contextual information:
+\n-------\n
+
+*Initial Answer that was found to be lacking:
+{initial_answer}
+
+*Answered Sub-questions (these should really matter! They also contain questions/answers that were not available when the original
+answer was constructed):
+{answered_sub_questions}
+
+And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
+
+{relevant_docs}
+
+\n-------\n
+\n
+Lastly, here is the question I want you to answer based on the information above:
+\n--\n
+{question}
+\n--\n\n
+Answer:"""
+)
+
+# sub_question_answer_str is empty
+REVISED_RAG_PROMPT_NO_SUB_QUESTIONS = (
+ """{answered_sub_questions}\n
+{persona_specification}
+Use the information provided below - and only the
+provided information - to answer the provided question.
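Aside: the {answered_sub_questions} block consumed by the revised prompt is assembled from the per-question template defined earlier (SUB_QUESTION_ANSWER_TEMPLATE_REVISED). A rough sketch of that assembly, where results is a hypothetical list of (number, level_type, question, answer) tuples:

    # Illustrative only: render each answered sub-question with its 'initial'/'revised' type.
    answered_sub_questions = "\n".join(
        SUB_QUESTION_ANSWER_TEMPLATE_REVISED.format(
            sub_question_nr=nr,
            level_type=level_type,  # "initial" or "revised"
            sub_question=question,
            sub_answer=answer,
        )
        for nr, level_type, question, answer in results
    )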
+
+The information provided below consists of:
+ 1) an initial answer that was given but found to be lacking in some way.
+ 2) a number of documents that were also deemed relevant for the question.
+
+IMPORTANT RULES:
+ - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
+ You may give some additional facts you learned, but do not try to invent an answer.
+ - If the information is empty or irrelevant, just say """
+ + f'"{UNKNOWN_ANSWER}"'
+ + """.
+ - If the information is relevant but not fully conclusive, provide an answer to the extent you can but also
+ specify that the information is not conclusive and why.
+
+Again, you should be sure that the answer is supported by the information provided!
+
+Remember to provide inline citations of documents in the format [[D1]](), [[D2]](), [[D3]](), etc.
+
+Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
+or assumptions you made.
+
+Here is the contextual information:
+\n-------\n
+
+*Initial Answer that was found to be lacking:
+{initial_answer}
+
+And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
+
+{relevant_docs}
+
+\n-------\n
+\n
+Lastly, here is the question I want you to answer based on the information above:
+\n--\n
+{question}
+\n--\n\n
+Answer:"""
+)
+
+
+ENTITY_TERM_PROMPT = """ \n
+ Based on the original question and the context retrieved from a dataset, please generate a list of
+ entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
+ (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
+
+ \n\n
+ Here is the original question:
+ \n ------- \n
+ {question}
+ \n ------- \n
+ And here is the context retrieved:
+ \n ------- \n
+ {context}
+ \n ------- \n
+
+ Please format your answer as a json object in the following format:
+
+ {{"retrieved_entities_relationships": {{
+ "entities": [{{
+ "entity_name": ,
+ "entity_type": 
+ }}],
+ "relationships": [{{
+ "relationship_name": ,
+ "relationship_type": ,
+ "relationship_entities": [, , ...]
+ }}], + "terms": [{{ + "term_name": , + "term_type": , + "term_similar_to": + }}] + }} + }} + """ diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py new file mode 100644 index 000000000000..3fe2acc48c27 --- /dev/null +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -0,0 +1,266 @@ +import ast +import json +import re +from collections.abc import Callable +from collections.abc import Iterator +from collections.abc import Sequence +from datetime import datetime +from datetime import timedelta +from typing import Any +from typing import cast +from uuid import UUID + +from langchain_core.messages import BaseMessage +from sqlalchemy.orm import Session + +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.shared_graph_utils.models import ( + EntityRelationshipTermExtraction, +) +from onyx.chat.models import AnswerStyleConfig +from onyx.chat.models import CitationConfig +from onyx.chat.models import DocumentPruningConfig +from onyx.chat.models import PromptConfig +from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE +from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT +from onyx.configs.constants import DEFAULT_PERSONA_ID +from onyx.context.search.enums import LLMEvaluationType +from onyx.context.search.models import InferenceSection +from onyx.context.search.models import RetrievalDetails +from onyx.context.search.models import SearchRequest +from onyx.db.persona import get_persona_by_id +from onyx.db.persona import Persona +from onyx.llm.interfaces import LLM +from onyx.tools.tool_constructor import SearchToolConfig +from onyx.tools.tool_implementations.search.search_tool import SearchTool + + +def normalize_whitespace(text: str) -> str: + """Normalize whitespace in text to single spaces and strip leading/trailing whitespace.""" + import re + + return re.sub(r"\s+", " ", text.strip()) + + +# Post-processing +def format_docs(docs: Sequence[InferenceSection]) -> str: + formatted_doc_list = [] + + for doc_nr, doc in enumerate(docs): + formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}") + + return "\n\n".join(formatted_doc_list) + + +def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str: + formatted_doc_list = [] + + for _, doc in enumerate(docs): + formatted_doc_list.append(f"\n...{doc.combined_content}\n") + + return "\n\n".join(formatted_doc_list) + + +def clean_and_parse_list_string(json_string: str) -> list[dict]: + # Remove any prefixes/labels before the actual JSON content + json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL) + + # Remove markdown code block markers and any newline prefixes + cleaned_string = re.sub(r"```json\n|\n```", "", json_string) + cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ") + cleaned_string = " ".join(cleaned_string.split()) + + # Try parsing with json.loads first, fall back to ast.literal_eval + try: + return json.loads(cleaned_string) + except json.JSONDecodeError: + try: + return ast.literal_eval(cleaned_string) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e + + +def clean_and_parse_json_string(json_string: str) -> dict[str, Any]: + # Remove markdown code block markers and any newline prefixes + cleaned_string = re.sub(r"```json\n|\n```", "", json_string) + cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ") + cleaned_string 
= " ".join(cleaned_string.split()) + # Parse the cleaned string into a Python dictionary + return json.loads(cleaned_string) + + +def format_entity_term_extraction( + entity_term_extraction_dict: EntityRelationshipTermExtraction, +) -> str: + entities = entity_term_extraction_dict.entities + terms = entity_term_extraction_dict.terms + relationships = entity_term_extraction_dict.relationships + + entity_strs = ["\nEntities:\n"] + for entity in entities: + entity_str = f"{entity.entity_name} ({entity.entity_type})" + entity_strs.append(entity_str) + + entity_str = "\n - ".join(entity_strs) + + relationship_strs = ["\n\nRelationships:\n"] + for relationship in relationships: + relationship_name = relationship.relationship_name + relationship_type = relationship.relationship_type + relationship_entities = relationship.relationship_entities + relationship_str = ( + f"""{relationship_name} ({relationship_type}): {relationship_entities}""" + ) + relationship_strs.append(relationship_str) + + relationship_str = "\n - ".join(relationship_strs) + + term_strs = ["\n\nTerms:\n"] + for term in terms: + term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}" + term_strs.append(term_str) + + term_str = "\n - ".join(term_strs) + + return "\n".join(entity_strs + relationship_strs + term_strs) + + +def _format_time_delta(time: timedelta) -> str: + seconds_from_start = f"{((time).seconds):03d}" + microseconds_from_start = f"{((time).microseconds):06d}" + return f"{seconds_from_start}.{microseconds_from_start}" + + +def generate_log_message( + message: str, + node_start_time: datetime, + graph_start_time: datetime | None = None, +) -> str: + current_time = datetime.now() + + if graph_start_time is not None: + graph_time_str = _format_time_delta(current_time - graph_start_time) + else: + graph_time_str = "N/A" + + node_time_str = _format_time_delta(current_time - node_start_time) + + return f"{graph_time_str} ({node_time_str} s): {message}" + + +def get_test_config( + db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest +) -> tuple[AgentSearchConfig, SearchTool]: + persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else MAX_CHUNKS_FED_TO_CHAT + ), + max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE, + ) + + answer_style_config = AnswerStyleConfig( + citation_config=CitationConfig( + # The docs retrieved by this flow are already relevance-filtered + all_docs_useful=True + ), + document_pruning_config=document_pruning_config, + structured_response_format=None, + ) + + search_tool_config = SearchToolConfig( + answer_style_config=answer_style_config, + document_pruning_config=document_pruning_config, + retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True + rerank_settings=None, # Can use this to change reranking model + selected_sections=None, + latest_query_files=None, + bypass_acl=False, + ) + + prompt_config = PromptConfig.from_model(persona.prompts[0]) + + search_tool = SearchTool( + db_session=db_session, + user=None, + persona=persona, + retrieval_options=search_tool_config.retrieval_options, + prompt_config=prompt_config, + llm=primary_llm, + fast_llm=fast_llm, + pruning_config=search_tool_config.document_pruning_config, + answer_style_config=search_tool_config.answer_style_config, + selected_sections=search_tool_config.selected_sections, + 
chunks_above=search_tool_config.chunks_above, + chunks_below=search_tool_config.chunks_below, + full_doc=search_tool_config.full_doc, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + rerank_settings=search_tool_config.rerank_settings, + bypass_acl=search_tool_config.bypass_acl, + ) + + config = AgentSearchConfig( + search_request=search_request, + # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), + chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim + # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan + message_id=1, + use_persistence=True, + primary_llm=primary_llm, + fast_llm=fast_llm, + search_tool=search_tool, + ) + + return config, search_tool + + +def get_persona_prompt(persona: Persona | None) -> str: + if persona is None: + return "" + else: + return "\n".join([x.system_prompt for x in persona.prompts]) + + +def make_question_id(level: int, question_nr: int) -> str: + return f"{level}_{question_nr}" + + +def parse_question_id(question_id: str) -> tuple[int, int]: + level, question_nr = question_id.split("_") + return int(level), int(question_nr) + + +def _dispatch_nonempty( + content: str, dispatch_event: Callable[[str, int], None], num: int +) -> None: + if content != "": + dispatch_event(content, num) + + +def dispatch_separated( + token_itr: Iterator[BaseMessage], + dispatch_event: Callable[[str, int], None], + sep: str = "\n", +) -> list[str | list[str | dict[str, Any]]]: + num = 1 + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + for message in token_itr: + content = cast(str, message.content) + if sep in content: + sub_question_parts = content.split(sep) + _dispatch_nonempty(sub_question_parts[0], dispatch_event, num) + num += 1 + _dispatch_nonempty( + "".join(sub_question_parts[1:]).strip(), dispatch_event, num + ) + else: + _dispatch_nonempty(content, dispatch_event, num) + streamed_tokens.append(content) + + return streamed_tokens diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index ff211cbf3075..3d6e2f64d0aa 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -1,13 +1,19 @@ +from collections import defaultdict from collections.abc import Callable -from collections.abc import Iterator from uuid import uuid4 from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk from langchain_core.messages import ToolCall +from sqlalchemy.orm import Session +from onyx.agents.agent_search.models import AgentSearchConfig +from onyx.agents.agent_search.run_graph import run_basic_graph +from onyx.agents.agent_search.run_graph import run_main_graph from onyx.chat.llm_response_handler import LLMResponseHandlerManager -from onyx.chat.models import AnswerQuestionPossibleReturn +from onyx.chat.models import AgentAnswerPiece +from onyx.chat.models import AnswerPacket +from onyx.chat.models import AnswerStream from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationInfo from onyx.chat.models import OnyxAnswerPiece @@ -19,32 +25,25 @@ from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall from onyx.chat.stream_processing.answer_response_handler import ( CitationResponseHandler, ) -from onyx.chat.stream_processing.answer_response_handler import ( - DummyAnswerResponseHandler, -) from onyx.chat.stream_processing.utils import ( map_document_id_order, ) +from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name from 
onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler +from onyx.configs.constants import BASIC_KEY from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM from onyx.llm.models import PreviousMessage from onyx.natural_language_processing.utils import get_tokenizer from onyx.tools.force import ForceUseTool -from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import SearchTool -from onyx.tools.tool_runner import ToolCallKickoff from onyx.tools.utils import explicit_tool_calling_supported from onyx.utils.logger import setup_logger - logger = setup_logger() -AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse] - - class Answer: def __init__( self, @@ -53,13 +52,13 @@ class Answer: llm: LLM, prompt_config: PromptConfig, force_use_tool: ForceUseTool, + pro_search_config: AgentSearchConfig, # must be the same length as `docs`. If None, all docs are considered "relevant" message_history: list[PreviousMessage] | None = None, single_message_history: str | None = None, # newly passed in files to include as part of this question # TODO THIS NEEDS TO BE HANDLED latest_query_files: list[InMemoryChatFile] | None = None, - files: list[InMemoryChatFile] | None = None, tools: list[Tool] | None = None, # NOTE: for native tool-calling, this is only supported by OpenAI atm, # but we only support them anyways @@ -69,6 +68,8 @@ class Answer: return_contexts: bool = False, skip_gen_ai_answer_generation: bool = False, is_connected: Callable[[], bool] | None = None, + fast_llm: LLM | None = None, + db_session: Session | None = None, ) -> None: if single_message_history and message_history: raise ValueError( @@ -79,7 +80,6 @@ class Answer: self.is_connected: Callable[[], bool] | None = is_connected self.latest_query_files = latest_query_files or [] - self.file_id_to_file = {file.file_id: file for file in (files or [])} self.tools = tools or [] self.force_use_tool = force_use_tool @@ -92,6 +92,7 @@ class Answer: self.prompt_config = prompt_config self.llm = llm + self.fast_llm = fast_llm self.llm_tokenizer = get_tokenizer( provider_type=llm.config.model_provider, model_name=llm.config.model_name, @@ -100,9 +101,7 @@ class Answer: self._final_prompt: list[BaseMessage] | None = None self._streamed_output: list[str] | None = None - self._processed_stream: ( - list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None - ) = None + self._processed_stream: (list[AnswerPacket] | None) = None self._return_contexts = return_contexts self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation @@ -115,55 +114,28 @@ class Answer: and not skip_explicit_tool_calling ) + self.pro_search_config = pro_search_config + self.db_session = db_session + def _get_tools_list(self) -> list[Tool]: if not self.force_use_tool.force_use: return self.tools - tool = next( - (t for t in self.tools if t.name == self.force_use_tool.tool_name), None - ) - if tool is None: - raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found") + tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name) - logger.info( - f"Forcefully using tool='{tool.name}'" - + ( - f" with args='{self.force_use_tool.args}'" - if self.force_use_tool.args is not None - else "" - ) + args_str = ( + f" with args='{self.force_use_tool.args}'" + if self.force_use_tool.args + else "" ) + logger.info(f"Forcefully using tool='{tool.name}'{args_str}") return [tool] - def 
_handle_specified_tool_call( - self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict - ) -> AnswerStream: - current_llm_call = llm_calls[-1] - - # make a dummy tool handler - tool_handler = ToolResponseHandler([tool]) - - dummy_tool_call_chunk = AIMessageChunk(content="") - dummy_tool_call_chunk.tool_calls = [ - ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) - ] - - response_handler_manager = LLMResponseHandlerManager( - tool_handler, DummyAnswerResponseHandler(), self.is_cancelled - ) - yield from response_handler_manager.handle_llm_response( - iter([dummy_tool_call_chunk]) - ) - - new_llm_call = response_handler_manager.next_llm_call(current_llm_call) - if new_llm_call: - yield from self._get_response(llm_calls + [new_llm_call]) - else: - raise RuntimeError("Tool call handler did not return a new LLM call") - + # TODO: delete the function and move the full body to processed_streamed_output def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: current_llm_call = llm_calls[-1] + tool, tool_args = None, None # handle the case where no decision has to be made; we simply run the tool if ( current_llm_call.force_use_tool.force_use @@ -173,17 +145,10 @@ class Answer: current_llm_call.force_use_tool.tool_name, current_llm_call.force_use_tool.args, ) - tool = next( - (t for t in current_llm_call.tools if t.name == tool_name), None - ) - if not tool: - raise RuntimeError(f"Tool '{tool_name}' not found") - - yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) - return + tool = get_tool_by_name(current_llm_call.tools, tool_name) # special pre-logic for non-tool calling LLM case - if not self.using_tool_calling_llm and current_llm_call.tools: + elif not self.using_tool_calling_llm and current_llm_call.tools: chosen_tool_and_args = ( ToolResponseHandler.get_tool_call_for_non_tool_calling_llm( current_llm_call, self.llm @@ -191,8 +156,24 @@ class Answer: ) if chosen_tool_and_args: tool, tool_args = chosen_tool_and_args - yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) - return + + if tool and tool_args: + dummy_tool_call_chunk = AIMessageChunk(content="") + dummy_tool_call_chunk.tool_calls = [ + ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) + ] + + response_handler_manager = LLMResponseHandlerManager( + ToolResponseHandler([tool]), None, self.is_cancelled + ) + yield from response_handler_manager.handle_llm_response( + iter([dummy_tool_call_chunk]) + ) + + tmp_call = response_handler_manager.next_llm_call(current_llm_call) + if tmp_call is None: + return # no more LLM calls to process + current_llm_call = tmp_call # if we're skipping gen ai answer generation, we should break # out unless we're forcing a tool call. If we don't, we might generate an @@ -212,16 +193,51 @@ class Answer: current_llm_call ) or ([], []) + # NEXT: we still want to handle the LLM response stream, but it is now: + # 1. handle the tool call requests + # 2. feed back the processed results + # 3. handle the citations + answer_handler = CitationResponseHandler( context_docs=final_search_results, final_doc_id_to_rank_map=map_document_id_order(final_search_results), display_doc_id_to_rank_map=map_document_id_order(displayed_search_results), ) + # At the moment, this wrapper class passes streamed stuff through citation and tool handlers. + # In the future, we'll want to handle citations and tool calls in the langgraph graph. 
response_handler_manager = LLMResponseHandlerManager( tool_call_handler, answer_handler, self.is_cancelled ) + # In langgraph, whether we do the basic thing (call llm stream) or pro search + # is based on a flag in the pro search config + + if self.pro_search_config.use_agentic_search: + if self.pro_search_config.search_request is None: + raise ValueError("Search request must be provided for pro search") + + if self.db_session is None: + raise ValueError("db_session must be provided for pro search") + if self.fast_llm is None: + raise ValueError("fast_llm must be provided for pro search") + + stream = run_main_graph( + config=self.pro_search_config, + ) + else: + stream = run_basic_graph( + config=self.pro_search_config, + last_llm_call=current_llm_call, + response_handler_manager=response_handler_manager, + ) + + processed_stream = [] + for packet in stream: + processed_stream.append(packet) + yield packet + self._processed_stream = processed_stream + return # DEBUG: good breakpoint stream = self.llm.stream( # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM @@ -283,20 +299,56 @@ class Answer: def llm_answer(self) -> str: answer = "" for packet in self.processed_streamed_output: - if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece: + # handle basic answer flow, plus level 0 agent answer flow + # since level 0 is the first answer the user sees and therefore the + # child message of the user message in the db (so it is handled + # like a basic flow answer) + if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or ( + isinstance(packet, AgentAnswerPiece) + and packet.answer_piece + and packet.answer_type == "agent_level_answer" + and packet.level == 0 + ): answer += packet.answer_piece return answer + def llm_answer_by_level(self) -> dict[int, str]: + answer_by_level: dict[int, str] = defaultdict(str) + for packet in self.processed_streamed_output: + if ( + isinstance(packet, AgentAnswerPiece) + and packet.answer_piece + and packet.answer_type == "agent_level_answer" + ): + answer_by_level[packet.level] += packet.answer_piece + elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece: + answer_by_level[BASIC_KEY[0]] += packet.answer_piece + return answer_by_level + @property def citations(self) -> list[CitationInfo]: citations: list[CitationInfo] = [] for packet in self.processed_streamed_output: - if isinstance(packet, CitationInfo): + if isinstance(packet, CitationInfo) and packet.level is None: citations.append(packet) return citations + def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]: + citations_by_subquestion: dict[ + tuple[int, int], list[CitationInfo] + ] = defaultdict(list) + for packet in self.processed_streamed_output: + if isinstance(packet, CitationInfo): + if packet.level_question_nr is not None and packet.level is not None: + citations_by_subquestion[ + (packet.level, packet.level_question_nr) + ].append(packet) + elif packet.level is None: + citations_by_subquestion[BASIC_KEY].append(packet) + return citations_by_subquestion + def is_cancelled(self) -> bool: if self._is_cancelled: return True diff --git a/backend/onyx/chat/chat_utils.py b/backend/onyx/chat/chat_utils.py index b14a005f386c..526241187aa7 100644 --- a/backend/onyx/chat/chat_utils.py +++ b/backend/onyx/chat/chat_utils.py @@ -48,6 +48,7 @@ def prepare_chat_message_request( retrieval_details: RetrievalDetails | None, rerank_settings: RerankingDetails | None, db_session: Session, + 
use_agentic_search: bool = False, ) -> CreateChatMessageRequest: # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases new_chat_session = create_chat_session( @@ -72,6 +73,7 @@ def prepare_chat_message_request( search_doc_ids=None, retrieval_options=retrieval_details, rerank_settings=rerank_settings, + use_agentic_search=use_agentic_search, ) diff --git a/backend/onyx/chat/llm_response_handler.py b/backend/onyx/chat/llm_response_handler.py index 7c9c8ee71311..2bf3e8476753 100644 --- a/backend/onyx/chat/llm_response_handler.py +++ b/backend/onyx/chat/llm_response_handler.py @@ -9,25 +9,37 @@ from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopReason from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler +from onyx.chat.stream_processing.answer_response_handler import ( + DummyAnswerResponseHandler, +) from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler class LLMResponseHandlerManager: + """ + This class is responsible for postprocessing the LLM response stream. + In particular, we: + 1. handle the tool call requests + 2. handle citations + 3. pass through answers generated by the LLM + 4. Stop yielding if the client disconnects + """ + def __init__( self, - tool_handler: ToolResponseHandler, - answer_handler: AnswerResponseHandler, + tool_handler: ToolResponseHandler | None, + answer_handler: AnswerResponseHandler | None, is_cancelled: Callable[[], bool], ): - self.tool_handler = tool_handler - self.answer_handler = answer_handler + self.tool_handler = tool_handler or ToolResponseHandler([]) + self.answer_handler = answer_handler or DummyAnswerResponseHandler() self.is_cancelled = is_cancelled def handle_llm_response( self, stream: Iterator[BaseMessage], ) -> Generator[ResponsePart, None, None]: - all_messages: list[BaseMessage] = [] + all_messages: list[BaseMessage | str] = [] for message in stream: if self.is_cancelled(): yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 2c5426045683..b532e7383f8b 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -3,6 +3,7 @@ from collections.abc import Iterator from datetime import datetime from enum import Enum from typing import Any +from typing import Literal from typing import TYPE_CHECKING from pydantic import BaseModel @@ -48,6 +49,8 @@ class QADocsResponse(RetrievalDocs): applied_source_filters: list[DocumentSource] | None applied_time_cutoff: datetime | None recency_bias_multiplier: float + level: int | None = None + level_question_nr: int | None = None def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore @@ -61,11 +64,17 @@ class QADocsResponse(RetrievalDocs): class StreamStopReason(Enum): CONTEXT_LENGTH = "context_length" CANCELLED = "cancelled" + FINISHED = "finished" class StreamStopInfo(BaseModel): stop_reason: StreamStopReason + stream_type: Literal["", "sub_questions", "sub_answer"] = "" + # used to identify the stream that was stopped for agent search + level: int | None = None + level_question_nr: int | None = None + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore data = super().model_dump(mode="json", *args, **kwargs) # type: ignore data["stop_reason"] = self.stop_reason.name 
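To make the new stream identifiers concrete, a hypothetical emitter for an agent sub-answer stream could look like the snippet below; the level and question numbers are illustrative values, not something produced by this diff.

from onyx.chat.models import StreamStopInfo, StreamStopReason

stop = StreamStopInfo(
    stop_reason=StreamStopReason.FINISHED,
    stream_type="sub_answer",  # which kind of agent stream just ended
    level=0,                   # sub-question level the stream belongs to
    level_question_nr=2,       # sub-question index within that level
)
# model_dump() serializes the enum by name, e.g. {"stop_reason": "FINISHED", ...}
payload = stop.model_dump()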
@@ -108,6 +117,8 @@ class OnyxAnswerPiece(BaseModel): class CitationInfo(BaseModel): citation_num: int document_id: str + level: int | None = None + level_question_nr: int | None = None class AllCitations(BaseModel): @@ -299,6 +310,40 @@ class PromptConfig(BaseModel): model_config = ConfigDict(frozen=True) +class SubQueryPiece(BaseModel): + sub_query: str + level: int + level_question_nr: int + query_id: int + + +class AgentAnswerPiece(BaseModel): + answer_piece: str + level: int + level_question_nr: int + answer_type: Literal["agent_sub_answer", "agent_level_answer"] + + +class SubQuestionPiece(BaseModel): + sub_question: str + level: int + level_question_nr: int + + +class ExtendedToolResponse(ToolResponse): + level: int + level_question_nr: int + + +ProSearchPacket = ( + SubQuestionPiece | AgentAnswerPiece | SubQueryPiece | ExtendedToolResponse +) + +AnswerPacket = ( + AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse +) + + ResponsePart = ( OnyxAnswerPiece | CitationInfo @@ -306,4 +351,7 @@ ResponsePart = ( | ToolResponse | ToolCallFinalResult | StreamStopInfo + | ProSearchPacket ) + +AnswerStream = Iterator[AnswerPacket] diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 9478cfdf8033..da19e19b2a85 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -1,11 +1,14 @@ import traceback +from collections import defaultdict from collections.abc import Callable from collections.abc import Iterator +from dataclasses import dataclass from functools import partial from typing import cast from sqlalchemy.orm import Session +from onyx.agents.agent_search.models import AgentSearchConfig from onyx.chat.answer import Answer from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import create_temporary_persona @@ -16,6 +19,7 @@ from onyx.chat.models import CitationConfig from onyx.chat.models import CitationInfo from onyx.chat.models import CustomToolResponse from onyx.chat.models import DocumentPruningConfig +from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import FileChatDisplay from onyx.chat.models import FinalUsedContextDocsResponse from onyx.chat.models import LLMRelevanceFilterResponse @@ -24,20 +28,26 @@ from onyx.chat.models import MessageSpecificCitations from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig +from onyx.chat.models import ProSearchPacket from onyx.chat.models import QADocsResponse from onyx.chat.models import StreamingError from onyx.chat.models import StreamStopInfo +from onyx.chat.models import StreamStopReason from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT +from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY +from onyx.configs.constants import BASIC_KEY from onyx.configs.constants import MessageType from onyx.configs.constants import MilestoneRecordType from onyx.configs.constants import NO_AUTH_USER_ID +from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import OptionalSearchSetting from onyx.context.search.enums import QueryFlow from onyx.context.search.enums import SearchType from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDetails +from onyx.context.search.models import SearchRequest from 
onyx.context.search.retrieval.search_runner import inference_sections_from_ids from onyx.context.search.utils import chunks_or_sections_to_search_docs from onyx.context.search.utils import dedupe_documents @@ -159,12 +169,15 @@ def _handle_search_tool_response_summary( ) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]: response_sumary = cast(SearchResponseSummary, packet.response) + is_extended = isinstance(packet, ExtendedToolResponse) dropped_inds = None if not selected_search_docs: top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections) deduped_docs = top_docs - if dedupe_docs: + if ( + dedupe_docs and not is_extended + ): # Extended tool responses are already deduped deduped_docs, dropped_inds = dedupe_documents(top_docs) reference_db_search_docs = [ @@ -178,6 +191,10 @@ def _handle_search_tool_response_summary( translate_db_search_doc_to_server_search_doc(db_search_doc) for db_search_doc in reference_db_search_docs ] + + level, question_nr = None, None + if isinstance(packet, ExtendedToolResponse): + level, question_nr = packet.level, packet.level_question_nr return ( QADocsResponse( rephrased_query=response_sumary.rephrased_query, @@ -187,6 +204,8 @@ def _handle_search_tool_response_summary( applied_source_filters=response_sumary.final_filters.source_type, applied_time_cutoff=response_sumary.final_filters.time_cutoff, recency_bias_multiplier=response_sumary.recency_bias_multiplier, + level=level, + level_question_nr=question_nr, ), reference_db_search_docs, dropped_inds, @@ -282,10 +301,22 @@ ChatPacket = ( | MessageSpecificCitations | MessageResponseIDInfo | StreamStopInfo + | ProSearchPacket ) ChatPacketStream = Iterator[ChatPacket] +# can't store a DbSearchDoc in a Pydantic BaseModel +@dataclass +class AnswerPostInfo: + ai_message_files: list[FileDescriptor] + qa_docs_response: QADocsResponse | None = None + reference_db_search_docs: list[DbSearchDoc] | None = None + dropped_indices: list[int] | None = None + tool_result: ToolCallFinalResult | None = None + message_specific_citations: MessageSpecificCitations | None = None + + def stream_chat_message_objects( new_msg_req: CreateChatMessageRequest, user: User | None, @@ -324,6 +355,7 @@ def stream_chat_message_objects( new_msg_req.chunks_above = 0 new_msg_req.chunks_below = 0 + llm = None try: user_id = user.id if user is not None else None @@ -679,6 +711,58 @@ def stream_chat_message_objects( for tool_list in tool_dict.values(): tools.extend(tool_list) + message_history = [ + PreviousMessage.from_chat_message(msg, files) for msg in history_msgs + ] + + search_request = SearchRequest( + query=final_msg.message, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + human_selected_filters=( + retrieval_options.filters if retrieval_options else None + ), + persona=persona, + offset=(retrieval_options.offset if retrieval_options else None), + limit=retrieval_options.limit if retrieval_options else None, + rerank_settings=new_msg_req.rerank_settings, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + enable_auto_detect_filters=( + retrieval_options.enable_auto_detect_filters + if retrieval_options + else None + ), + ) + # TODO: Since we're deleting the current main path in Answer, + # we should construct this unconditionally inside Answer instead + # Leaving it here for the time being to avoid breaking changes + search_tools = [tool for tool in tools if isinstance(tool, 
SearchTool)] + if len(search_tools) == 0: + raise ValueError("No search tool found") + elif len(search_tools) > 1: + # TODO: handle multiple search tools + raise ValueError("Multiple search tools found") + search_tool = search_tools[0] + pro_search_config = AgentSearchConfig( + use_agentic_search=new_msg_req.use_agentic_search, + search_request=search_request, + chat_session_id=chat_session_id, + message_id=reserved_message_id, + message_history=message_history, + primary_llm=llm, + fast_llm=fast_llm, + search_tool=search_tool, + structured_response_format=new_msg_req.structured_response_format, + db_session=db_session, + ) + + # TODO: add previous messages, answer style config, tools, etc. + # LLM prompt building, response capturing, etc. answer = Answer( is_connected=is_connected, @@ -698,28 +782,40 @@ def stream_chat_message_objects( ) ) ), - message_history=[ - PreviousMessage.from_chat_message(msg, files) for msg in history_msgs - ], + fast_llm=fast_llm, + message_history=message_history, tools=tools, force_use_tool=_get_force_search_settings(new_msg_req, tools), single_message_history=single_message_history, + pro_search_config=pro_search_config, + db_session=db_session, ) - reference_db_search_docs = None - qa_docs_response = None - # any files to associate with the AI message e.g. dall-e generated images - ai_message_files = [] - dropped_indices = None - tool_result = None + # reference_db_search_docs = None + # qa_docs_response = None + # # any files to associate with the AI message e.g. dall-e generated images + # ai_message_files = [] + # dropped_indices = None + # tool_result = None + # TODO: different channels for stored info when it's coming from the agent flow + info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict( + lambda: AnswerPostInfo(ai_message_files=[]) + ) for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): + level, level_question_nr = ( + (packet.level, packet.level_question_nr) + if isinstance(packet, ExtendedToolResponse) + else BASIC_KEY + ) + info = info_by_subq[(level, level_question_nr)] + # TODO: don't need to dedupe here when we do it in agent flow if packet.id == SEARCH_RESPONSE_SUMMARY_ID: ( - qa_docs_response, - reference_db_search_docs, - dropped_indices, + info.qa_docs_response, + info.reference_db_search_docs, + info.dropped_indices, ) = _handle_search_tool_response_summary( packet=packet, db_session=db_session, @@ -731,29 +827,34 @@ def stream_chat_message_objects( else False ), ) - yield qa_docs_response + yield info.qa_docs_response elif packet.id == SECTION_RELEVANCE_LIST_ID: relevance_sections = packet.response - if reference_db_search_docs is not None: - llm_indices = relevant_sections_to_indices( - relevance_sections=relevance_sections, - items=[ - translate_db_search_doc_to_server_search_doc(doc) - for doc in reference_db_search_docs - ], + if info.reference_db_search_docs is None: + logger.warning( + "No reference docs found for relevance filtering" + ) + continue + + llm_indices = relevant_sections_to_indices( + relevance_sections=relevance_sections, + items=[ + translate_db_search_doc_to_server_search_doc(doc) + for doc in info.reference_db_search_docs + ], + ) + + if info.dropped_indices: + llm_indices = drop_llm_indices( + llm_indices=llm_indices, + search_docs=info.reference_db_search_docs, + dropped_indices=info.dropped_indices, ) - if dropped_indices: - llm_indices = drop_llm_indices( - llm_indices=llm_indices, - search_docs=reference_db_search_docs, - dropped_indices=dropped_indices, - ) - 
- yield LLMRelevanceFilterResponse( - llm_selected_doc_indices=llm_indices - ) + yield LLMRelevanceFilterResponse( + llm_selected_doc_indices=llm_indices + ) elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: yield FinalUsedContextDocsResponse( final_context_docs=packet.response @@ -773,22 +874,24 @@ def stream_chat_message_objects( ], tenant_id=tenant_id, ) - ai_message_files = [ - FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) - for file_id in file_ids - ] + info.ai_message_files.extend( + [ + FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) + for file_id in file_ids + ] + ) yield FileChatDisplay( file_ids=[str(file_id) for file_id in file_ids] ) elif packet.id == INTERNET_SEARCH_RESPONSE_ID: ( - qa_docs_response, - reference_db_search_docs, + info.qa_docs_response, + info.reference_db_search_docs, ) = _handle_internet_search_tool_response_summary( packet=packet, db_session=db_session, ) - yield qa_docs_response + yield info.qa_docs_response elif packet.id == CUSTOM_TOOL_RESPONSE_ID: custom_tool_response = cast(CustomToolCallSummary, packet.response) @@ -797,7 +900,7 @@ def stream_chat_message_objects( or custom_tool_response.response_type == "csv" ): file_ids = custom_tool_response.tool_result.file_ids - ai_message_files.extend( + info.ai_message_files.extend( [ FileDescriptor( id=str(file_id), @@ -822,10 +925,18 @@ def stream_chat_message_objects( yield cast(OnyxContexts, packet.response) elif isinstance(packet, StreamStopInfo): - pass + if packet.stop_reason == StreamStopReason.FINISHED: + yield packet else: if isinstance(packet, ToolCallFinalResult): - tool_result = packet + level, level_question_nr = ( + (packet.level, packet.level_question_nr) + if packet.level is not None + and packet.level_question_nr is not None + else BASIC_KEY + ) + info = info_by_subq[(level, level_question_nr)] + info.tool_result = packet yield cast(ChatPacket, packet) logger.debug("Reached end of stream") except ValueError as e: @@ -841,59 +952,98 @@ def stream_chat_message_objects( error_msg = str(e) stack_trace = traceback.format_exc() - client_error_msg = litellm_exception_to_error_msg(e, llm) - if llm.config.api_key and len(llm.config.api_key) > 2: - error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]") - stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]") + if llm: + client_error_msg = litellm_exception_to_error_msg(e, llm) + if llm.config.api_key and len(llm.config.api_key) > 2: + error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]") + stack_trace = stack_trace.replace( + llm.config.api_key, "[REDACTED_API_KEY]" + ) - yield StreamingError(error=client_error_msg, stack_trace=stack_trace) + yield StreamingError(error=client_error_msg, stack_trace=stack_trace) db_session.rollback() return # Post-LLM answer processing try: - logger.debug("Post-LLM answer processing") - message_specific_citations: MessageSpecificCitations | None = None - if reference_db_search_docs: - message_specific_citations = _translate_citations( - citations_list=answer.citations, - db_docs=reference_db_search_docs, - ) - if not answer.is_cancelled(): - yield AllCitations(citations=answer.citations) - - # Saving Gen AI answer and responding with message info tool_name_to_tool_id: dict[str, int] = {} for tool_id, tool_list in tool_dict.items(): for tool in tool_list: tool_name_to_tool_id[tool.name] = tool_id + subq_citations = answer.citations_by_subquestion() + for pair in subq_citations: + level, level_question_nr = pair + info = info_by_subq[(level, 
level_question_nr)] + logger.debug("Post-LLM answer processing") + if info.reference_db_search_docs: + info.message_specific_citations = _translate_citations( + citations_list=subq_citations[pair], + db_docs=info.reference_db_search_docs, + ) + + # TODO: AllCitations should contain subq info? + if not answer.is_cancelled(): + yield AllCitations(citations=subq_citations[pair]) + + # Saving Gen AI answer and responding with message info + + info = ( + info_by_subq[BASIC_KEY] + if BASIC_KEY in info_by_subq + else info_by_subq[AGENT_SEARCH_INITIAL_KEY] + ) gen_ai_response_message = partial_response( message=answer.llm_answer, rephrased_query=( - qa_docs_response.rephrased_query if qa_docs_response else None + info.qa_docs_response.rephrased_query if info.qa_docs_response else None ), - reference_docs=reference_db_search_docs, - files=ai_message_files, + reference_docs=info.reference_db_search_docs, + files=info.ai_message_files, token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=( - message_specific_citations.citation_map - if message_specific_citations + info.message_specific_citations.citation_map + if info.message_specific_citations else None ), error=None, tool_call=( ToolCall( - tool_id=tool_name_to_tool_id[tool_result.tool_name], - tool_name=tool_result.tool_name, - tool_arguments=tool_result.tool_args, - tool_result=tool_result.tool_result, + tool_id=tool_name_to_tool_id[info.tool_result.tool_name], + tool_name=info.tool_result.tool_name, + tool_arguments=info.tool_result.tool_args, + tool_result=info.tool_result.tool_result, ) - if tool_result + if info.tool_result else None ), ) + # TODO: add answers for levels >= 1, where each level has the previous as its parent. Use + # the answer_by_level method in answer.py to get the answers for each level + next_level = 1 + prev_message = gen_ai_response_message + agent_answers = answer.llm_answer_by_level() + while next_level in agent_answers: + next_answer = agent_answers[next_level] + info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])] + next_answer_message = create_new_chat_message( + chat_session_id=chat_session_id, + parent_message=prev_message, + message=next_answer, + prompt_id=None, + token_count=len(llm_tokenizer_encode_func(next_answer)), + message_type=MessageType.ASSISTANT, + db_session=db_session, + files=info.ai_message_files, + reference_docs=info.reference_db_search_docs, + citations=info.message_specific_citations.citation_map + if info.message_specific_citations + else None, + ) + next_level += 1 + prev_message = next_answer_message + logger.debug("Committing messages") db_session.commit() # actually save user / assistant message diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py index 759ebc721a63..9395582ee048 100644 --- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py +++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py @@ -174,6 +174,7 @@ class AnswerPromptBuilder: ) +# TODO: rename this? AnswerConfig maybe? 
class LLMCall(BaseModel__v1): prompt_builder: AnswerPromptBuilder tools: list[Tool] diff --git a/backend/onyx/chat/stream_processing/answer_response_handler.py b/backend/onyx/chat/stream_processing/answer_response_handler.py index 87098c3f1776..96bf77a2e3fd 100644 --- a/backend/onyx/chat/stream_processing/answer_response_handler.py +++ b/backend/onyx/chat/stream_processing/answer_response_handler.py @@ -1,5 +1,7 @@ import abc from collections.abc import Generator +from typing import Any +from typing import cast from langchain_core.messages import BaseMessage @@ -17,21 +19,28 @@ class AnswerResponseHandler(abc.ABC): @abc.abstractmethod def handle_response_part( self, - response_item: BaseMessage | None, - previous_response_items: list[BaseMessage], + response_item: BaseMessage | str | None, + previous_response_items: list[BaseMessage | str], ) -> Generator[ResponsePart, None, None]: raise NotImplementedError + @abc.abstractmethod + def update(self, state_update: Any) -> None: + raise NotImplementedError + class DummyAnswerResponseHandler(AnswerResponseHandler): def handle_response_part( self, - response_item: BaseMessage | None, - previous_response_items: list[BaseMessage], + response_item: BaseMessage | str | None, + previous_response_items: list[BaseMessage | str], ) -> Generator[ResponsePart, None, None]: # This is a dummy handler that returns nothing yield from [] + def update(self, state_update: Any) -> None: + pass + class CitationResponseHandler(AnswerResponseHandler): def __init__( @@ -56,43 +65,121 @@ class CitationResponseHandler(AnswerResponseHandler): def handle_response_part( self, - response_item: BaseMessage | None, - previous_response_items: list[BaseMessage], + response_item: BaseMessage | str | None, + previous_response_items: list[BaseMessage | str], ) -> Generator[ResponsePart, None, None]: if response_item is None: return content = ( - response_item.content if isinstance(response_item.content, str) else "" + response_item.content + if isinstance(response_item, BaseMessage) + else response_item ) + # Ensure content is a string + if not isinstance(content, str): + logger.warning(f"Received non-string content: {type(content)}") + content = str(content) if content is not None else "" + # Process the new content through the citation processor yield from self.citation_processor.process_token(content) + def update(self, state_update: Any) -> None: + state = cast( + tuple[list[LlmDoc], DocumentIdOrderMapping, DocumentIdOrderMapping], + state_update, + ) + self.context_docs = state[0] + self.final_doc_id_to_rank_map = state[1] + self.display_doc_id_to_rank_map = state[2] + self.citation_processor = CitationProcessor( + context_docs=self.context_docs, + final_doc_id_to_rank_map=self.final_doc_id_to_rank_map, + display_doc_id_to_rank_map=self.display_doc_id_to_rank_map, + ) -# No longer in use, remove later -# class QuotesResponseHandler(AnswerResponseHandler): -# def __init__( + +def BaseMessage_to_str(message: BaseMessage) -> str: + content = message.content if isinstance(message, BaseMessage) else message + if not isinstance(content, str): + logger.warning(f"Received non-string content: {type(content)}") + content = str(content) if content is not None else "" + return content + + +# class CitationMultiResponseHandler(AnswerResponseHandler): +# def __init__(self) -> None: +# self.channel_processors: dict[str, CitationProcessor] = {} +# self._default_channel = "__default__" + +# def register_default_channel( # self, # context_docs: list[LlmDoc], -# is_json_prompt: bool = True, -# 
): -# self.quotes_processor = QuotesProcessor( +# final_doc_id_to_rank_map: DocumentIdOrderMapping, +# display_doc_id_to_rank_map: DocumentIdOrderMapping, +# ) -> None: +# """Register the default channel with its associated documents and ranking maps.""" +# self.register_channel( +# channel_id=self._default_channel, # context_docs=context_docs, -# is_json_prompt=is_json_prompt, +# final_doc_id_to_rank_map=final_doc_id_to_rank_map, +# display_doc_id_to_rank_map=display_doc_id_to_rank_map, +# ) + +# def register_channel( +# self, +# channel_id: str, +# context_docs: list[LlmDoc], +# final_doc_id_to_rank_map: DocumentIdOrderMapping, +# display_doc_id_to_rank_map: DocumentIdOrderMapping, +# ) -> None: +# """Register a new channel with its associated documents and ranking maps.""" +# self.channel_processors[channel_id] = CitationProcessor( +# context_docs=context_docs, +# final_doc_id_to_rank_map=final_doc_id_to_rank_map, +# display_doc_id_to_rank_map=display_doc_id_to_rank_map, # ) # def handle_response_part( # self, -# response_item: BaseMessage | None, -# previous_response_items: list[BaseMessage], +# response_item: BaseMessage | str | None, +# previous_response_items: list[BaseMessage | str], # ) -> Generator[ResponsePart, None, None]: +# """Default implementation that uses the default channel.""" + +# yield from self.handle_channel_response( +# response_item=content, +# previous_response_items=previous_response_items, +# channel_id=self._default_channel, +# ) + +# def handle_channel_response( +# self, +# response_item: ResponsePart | str | None, +# previous_response_items: list[ResponsePart | str], +# channel_id: str, +# ) -> Generator[ResponsePart, None, None]: +# """Process a response part for a specific channel.""" +# if channel_id not in self.channel_processors: +# raise ValueError(f"Attempted to process response for unregistered channel {channel_id}") + # if response_item is None: -# yield from self.quotes_processor.process_token(None) # return # content = ( -# response_item.content if isinstance(response_item.content, str) else "" +# response_item.content if isinstance(response_item, BaseMessage) else response_item # ) -# yield from self.quotes_processor.process_token(content) +# # Ensure content is a string +# if not isinstance(content, str): +# logger.warning(f"Received non-string content: {type(content)}") +# content = str(content) if content is not None else "" + +# # Process the new content through the channel's citation processor +# yield from self.channel_processors[channel_id].multi_process_token(content) + +# def remove_channel(self, channel_id: str) -> None: +# """Remove a channel and its associated processor.""" +# if channel_id in self.channel_processors: +# del self.channel_processors[channel_id] diff --git a/backend/onyx/chat/stream_processing/citation_processing.py b/backend/onyx/chat/stream_processing/citation_processing.py index 071b28c34579..6f844646acc5 100644 --- a/backend/onyx/chat/stream_processing/citation_processing.py +++ b/backend/onyx/chat/stream_processing/citation_processing.py @@ -4,6 +4,7 @@ from collections.abc import Generator from onyx.chat.models import CitationInfo from onyx.chat.models import LlmDoc from onyx.chat.models import OnyxAnswerPiece +from onyx.chat.models import ResponsePart from onyx.chat.stream_processing.utils import DocumentIdOrderMapping from onyx.configs.chat_configs import STOP_STREAM_PAT from onyx.prompts.constants import TRIPLE_BACKTICK @@ -40,6 +41,164 @@ class CitationProcessor: self.current_citations: list[int] = [] 
self.past_cite_count = 0 + # TODO: should reference previous citation processing, rework previous, or completely use new one? + def multi_process_token( + self, parsed_object: ResponsePart + ) -> Generator[ResponsePart, None, None]: + # if isinstance(parsed_object,OnyxAnswerPiece): + # # standard citation processing + # yield from self.process_token(parsed_object.answer_piece) + + # elif isinstance(parsed_object, AgentAnswerPiece): + # # citation processing for agent answer pieces + # for token in self.process_token(parsed_object.answer_piece): + # if isinstance(token, CitationInfo): + # yield token + # else: + # yield AgentAnswerPiece(answer_piece=token.answer_piece or '', + # answer_type=parsed_object.answer_type, level=parsed_object.level, + # level_question_nr=parsed_object.level_question_nr) + + # level = getattr(parsed_object, "level", None) + # level_question_nr = getattr(parsed_object, "level_question_nr", None) + + # if isinstance(parsed_object, (AgentAnswerPiece, OnyxAnswerPiece)): + # # logger.debug(f"FA {parsed_object.answer_piece}") + # if isinstance(parsed_object, AgentAnswerPiece): + # token = parsed_object.answer_piece + # level = parsed_object.level + # level_question_nr = parsed_object.level_question_nr + # else: + # yield parsed_object + # return + # # raise ValueError( + # # f"Invalid parsed object type: {type(parsed_object)}" + # # ) + + # if not citation_potential[level][level_question_nr] and token: + # if token.startswith(" ["): + # citation_potential[level][level_question_nr] = True + # current_yield_components[level][level_question_nr] = [token] + # else: + # yield parsed_object + # elif token and citation_potential[level][level_question_nr]: + # current_yield_components[level][level_question_nr].append(token) + # current_yield_str[level][level_question_nr] = "".join( + # current_yield_components[level][level_question_nr] + # ) + + # if current_yield_str[level][level_question_nr].strip().startswith( + # "[D" + # ) or current_yield_str[level][level_question_nr].strip().startswith( + # "[Q" + # ): + # citation_potential[level][level_question_nr] = True + + # else: + # citation_potential[level][level_question_nr] = False + # parsed_object = _set_combined_token_value( + # current_yield_str[level][level_question_nr], parsed_object + # ) + # yield parsed_object + + # if ( + # len(current_yield_components[level][level_question_nr]) > 15 + # ): # ??? 15? + # citation_potential[level][level_question_nr] = False + # parsed_object = _set_combined_token_value( + # current_yield_str[level][level_question_nr], parsed_object + # ) + # yield parsed_object + # elif "]" in current_yield_str[level][level_question_nr]: + # section_split = current_yield_str[level][level_question_nr].split( + # "]" + # ) + # section_split[0] + "]" # dead code? 
+ # start_of_next_section = "]".join(section_split[1:]) + # citation_string = current_yield_str[level][level_question_nr][ + # : -len(start_of_next_section) + # ] + # if "[D" in citation_string: + # cite_open_bracket_marker, cite_close_bracket_marker = ( + # "[", + # "]", + # ) + # cite_identifyer = "D" + + # try: + # cited_document = int( + # citation_string[level][level_question_nr][2:-1] + # ) + # if level and level_question_nr: + # link = agent_document_citations[int(level)][ + # int(level_question_nr) + # ][cited_document].link + # else: + # link = "" + # except (ValueError, IndexError): + # link = "" + # elif "[Q" in citation_string: + # cite_open_bracket_marker, cite_close_bracket_marker = ( + # "{", + # "}", + # ) + # cite_identifyer = "Q" + # else: + # pass + + # citation_string = citation_string.replace( + # "[" + cite_identifyer, + # cite_open_bracket_marker * 2, + # ).replace("]", cite_close_bracket_marker * 2) + + # if cite_identifyer == "D": + # citation_string += f"({link})" + + # parsed_object = _set_combined_token_value( + # citation_string, parsed_object + # ) + + # yield parsed_object + + # current_yield_components[level][level_question_nr] = [ + # start_of_next_section + # ] + # if not start_of_next_section.strip().startswith("["): + # citation_potential[level][level_question_nr] = False + + # elif isinstance(parsed_object, ExtendedToolResponse): + # if parsed_object.id == "search_response_summary": + # level = parsed_object.level + # level_question_nr = parsed_object.level_question_nr + # for inference_section in parsed_object.response.top_sections: + # doc_link = inference_section.center_chunk.source_links[0] + # doc_title = inference_section.center_chunk.title + # doc_id = inference_section.center_chunk.document_id + + # if ( + # doc_id + # not in agent_question_citations_used_docs[level][ + # level_question_nr + # ] + # ): + # if level not in agent_document_citations: + # agent_document_citations[level] = {} + # if level_question_nr not in agent_document_citations[level]: + # agent_document_citations[level][level_question_nr] = [] + + # agent_document_citations[level][level_question_nr].append( + # AgentDocumentCitations( + # document_id=doc_id, + # document_title=doc_title, + # link=doc_link, + # ) + # ) + # agent_question_citations_used_docs[level][ + # level_question_nr + # ].append(doc_id) + + yield parsed_object + def process_token( self, token: str | None ) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]: diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py index 0c17693a20d6..21359f8272fa 100644 --- a/backend/onyx/chat/tool_handling/tool_response_handler.py +++ b/backend/onyx/chat/tool_handling/tool_response_handler.py @@ -25,6 +25,13 @@ from onyx.utils.logger import setup_logger logger = setup_logger() +def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool: + for tool in tools: + if tool.name == tool_name: + return tool + raise RuntimeError(f"Tool '{tool_name}' not found") + + class ToolResponseHandler: def __init__(self, tools: list[Tool]): self.tools = tools @@ -45,18 +52,7 @@ class ToolResponseHandler: ) -> tuple[Tool, dict] | None: if llm_call.force_use_tool.force_use: # if we are forcing a tool, we don't need to check which tools to run - tool = next( - ( - t - for t in llm_call.tools - if t.name == llm_call.force_use_tool.tool_name - ), - None, - ) - if not tool: - raise RuntimeError( - f"Tool '{llm_call.force_use_tool.tool_name}' not found" - ) + tool = 
get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name) tool_args = ( llm_call.force_use_tool.args @@ -118,20 +114,17 @@ class ToolResponseHandler: tool for tool in self.tools if tool.name == tool_call_request["name"] ] - if not known_tools_by_name: - logger.error( - "Tool call requested with unknown name field. \n" - f"self.tools: {self.tools}" - f"tool_call_request: {tool_call_request}" - ) - continue - else: + if known_tools_by_name: selected_tool = known_tools_by_name[0] selected_tool_call_request = tool_call_request - - if selected_tool and selected_tool_call_request: break + logger.error( + "Tool call requested with unknown name field. \n" + f"self.tools: {self.tools}" + f"tool_call_request: {tool_call_request}" + ) + if not selected_tool or not selected_tool_call_request: return @@ -157,8 +150,8 @@ class ToolResponseHandler: def handle_response_part( self, - response_item: BaseMessage | None, - previous_response_items: list[BaseMessage], + response_item: BaseMessage | str | None, + previous_response_items: list[BaseMessage | str], ) -> Generator[ResponsePart, None, None]: if response_item is None: yield from self._handle_tool_call() @@ -171,8 +164,6 @@ class ToolResponseHandler: else: self.tool_call_chunk += response_item # type: ignore - return - def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None: if ( self.tool_runner is None diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index e18a5ee3e7a1..52350865538e 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -38,6 +38,9 @@ DEFAULT_PERSONA_ID = 0 DEFAULT_CC_PAIR_ID = 1 +# subquestion level and question number for basic flow +BASIC_KEY = (-1, -1) +AGENT_SEARCH_INITIAL_KEY = (0, 0) # Postgres connection constants for application_name POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" diff --git a/backend/onyx/configs/dev_configs.py b/backend/onyx/configs/dev_configs.py new file mode 100644 index 000000000000..49b6cdef7127 --- /dev/null +++ b/backend/onyx/configs/dev_configs.py @@ -0,0 +1,59 @@ +import os + +from .chat_configs import NUM_RETURNED_HITS + + +##### +# Agent Configs +##### + +agent_retrieval_stats_os: bool | str | None = os.environ.get( + "AGENT_RETRIEVAL_STATS", False +) + +AGENT_RETRIEVAL_STATS: bool = False +if isinstance(agent_retrieval_stats_os, str) and agent_retrieval_stats_os == "True": + AGENT_RETRIEVAL_STATS = True +elif isinstance(agent_retrieval_stats_os, bool) and agent_retrieval_stats_os: + AGENT_RETRIEVAL_STATS = True + +agent_max_query_retrieval_results_os: int | str = os.environ.get( + "AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS +) + +AGENT_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS +try: + atmqrr = int(agent_max_query_retrieval_results_os) + AGENT_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr +except ValueError: + raise ValueError( + f"AGENT_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {agent_max_query_retrieval_results_os}" + ) + + +# Reranking agent configs
agent_reranking_stats_os: bool | str | None = os.environ.get( + "AGENT_RERANKING_STATS", False +) +AGENT_RERANKING_STATS: bool = False +if isinstance(agent_reranking_stats_os, str) and agent_reranking_stats_os == "True": + AGENT_RERANKING_STATS = True +elif isinstance(agent_reranking_stats_os, bool) and agent_reranking_stats_os: + AGENT_RERANKING_STATS = True + + +agent_reranking_max_query_retrieval_results_os: int | str = os.environ.get( + "AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS +) + 
+AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS + +GRAPH_NAME: str = "a" + +try: + atmqrr = int(agent_reranking_max_query_retrieval_results_os) + AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr +except ValueError: + raise ValueError( + f"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {agent_reranking_max_query_retrieval_results_os}" + ) diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index 844949881ef5..2388ea53bf4c 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -403,8 +403,18 @@ class SearchPipeline: @property def section_relevance_list(self) -> list[bool]: - llm_indices = relevant_sections_to_indices( - relevance_sections=self.section_relevance, - items=self.final_context_sections, + return section_relevance_list_impl( + section_relevance=self.section_relevance, + final_context_sections=self.final_context_sections, ) - return [ind in llm_indices for ind in range(len(self.final_context_sections))] + + +def section_relevance_list_impl( + section_relevance: list[SectionRelevancePiece] | None, + final_context_sections: list[InferenceSection], +) -> list[bool]: + llm_indices = relevant_sections_to_indices( + relevance_sections=section_relevance, + items=final_context_sections, + ) + return [ind in llm_indices for ind in range(len(final_context_sections))] diff --git a/backend/onyx/context/search/utils.py b/backend/onyx/context/search/utils.py index 8a25ad1b783e..4b42bb080808 100644 --- a/backend/onyx/context/search/utils.py +++ b/backend/onyx/context/search/utils.py @@ -80,7 +80,7 @@ def drop_llm_indices( search_docs: Sequence[DBSearchDoc | SavedSearchDoc], dropped_indices: list[int], ) -> list[int]: - llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))] + llm_bools = [i in llm_indices for i in range(len(search_docs))] if dropped_indices: llm_bools = [ val for ind, val in enumerate(llm_bools) if ind not in dropped_indices diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index 3601fdc67cf9..67239ddfe566 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -1,6 +1,8 @@ from collections.abc import Sequence from datetime import datetime from datetime import timedelta +from typing import Any +from typing import cast from uuid import UUID from fastapi import HTTPException @@ -15,13 +17,22 @@ from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session +from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics +from onyx.agents.agent_search.shared_graph_utils.models import ( + QuestionAnswerResults, +) from onyx.auth.schemas import UserRole from onyx.chat.models import DocumentRelevance from onyx.configs.chat_configs import HARD_DELETE_CHATS from onyx.configs.constants import MessageType from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDocs from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SearchDoc as ServerSearchDoc +from onyx.context.search.utils import chunks_or_sections_to_search_docs +from onyx.db.models import AgentSearchMetrics +from onyx.db.models import AgentSubQuery +from onyx.db.models import AgentSubQuestion from onyx.db.models import ChatMessage from onyx.db.models import ChatMessage__SearchDoc from onyx.db.models import ChatSession @@ -37,9 +48,11 @@ from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride from onyx.llm.override_models import PromptOverride from onyx.server.query_and_chat.models import ChatMessageDetail +from onyx.server.query_and_chat.models import SubQueryDetail +from onyx.server.query_and_chat.models import SubQuestionDetail from onyx.tools.tool_runner import ToolCallFinalResult from onyx.utils.logger import setup_logger - +from onyx.utils.special_types import JSON_ro logger = setup_logger() @@ -496,6 +509,7 @@ def get_chat_messages_by_session( prefetch_tool_calls: bool = False, ) -> list[ChatMessage]: if not skip_permission_check: + # bug if we ever call this expecting the permission check to not be skipped get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session ) @@ -507,7 +521,12 @@ def get_chat_messages_by_session( ) if prefetch_tool_calls: - stmt = stmt.options(joinedload(ChatMessage.tool_call)) + stmt = stmt.options( + joinedload(ChatMessage.tool_call), + joinedload(ChatMessage.sub_questions).joinedload( + AgentSubQuestion.sub_queries + ), + ) result = db_session.scalars(stmt).unique().all() else: result = db_session.scalars(stmt).all() @@ -837,14 +856,54 @@ def translate_db_search_doc_to_server_search_doc( ) -def get_retrieval_docs_from_chat_message( - chat_message: ChatMessage, remove_doc_content: bool = False +def translate_db_sub_questions_to_server_objects( + db_sub_questions: list[AgentSubQuestion], +) -> list[SubQuestionDetail]: + sub_questions = [] + for sub_question in db_sub_questions: + sub_queries = [] + docs: dict[str, SearchDoc] = {} + doc_results = cast( + list[dict[str, JSON_ro]], sub_question.sub_question_doc_results + ) + verified_doc_ids = [x["document_id"] for x in doc_results] + for sub_query in sub_question.sub_queries: + doc_ids = [doc.id for doc in sub_query.search_docs] + sub_queries.append( + SubQueryDetail( + query=sub_query.sub_query, + query_id=sub_query.id, + doc_ids=doc_ids, + ) + ) + for doc in sub_query.search_docs: + docs[doc.document_id] = doc + + verified_docs = [ + docs[cast(str, doc_id)] for doc_id in verified_doc_ids if doc_id in docs + ] + + sub_questions.append( + SubQuestionDetail( + level=sub_question.level, + level_question_nr=sub_question.level_question_nr, + question=sub_question.sub_question, + answer=sub_question.sub_answer, + sub_queries=sub_queries, + context_docs=get_retrieval_docs_from_search_docs(verified_docs), + ) + ) + return sub_questions + + +def get_retrieval_docs_from_search_docs( + search_docs: list[SearchDoc], remove_doc_content: bool = False ) -> RetrievalDocs: top_documents = [ translate_db_search_doc_to_server_search_doc( db_doc, remove_doc_content=remove_doc_content ) - for db_doc in chat_message.search_docs + for db_doc in search_docs ] top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore return RetrievalDocs(top_documents=top_documents) @@ -861,8 +920,8 @@ def translate_db_message_to_chat_message_detail( latest_child_message=chat_message.latest_child_message, message=chat_message.message, rephrased_query=chat_message.rephrased_query, - context_docs=get_retrieval_docs_from_chat_message( - chat_message, remove_doc_content=remove_doc_content + context_docs=get_retrieval_docs_from_search_docs( + chat_message.search_docs, remove_doc_content=remove_doc_content ), message_type=chat_message.message_type, time_sent=chat_message.time_sent, @@ -877,6 +936,118 @@ def translate_db_message_to_chat_message_detail( else None, alternate_assistant_id=chat_message.alternate_assistant_id, 
overridden_model=chat_message.overridden_model, + sub_questions=translate_db_sub_questions_to_server_objects( + chat_message.sub_questions + ), ) return chat_msg_detail + + +def log_agent_metrics( + db_session: Session, + user_id: UUID | None, + persona_id: int | None, # Can be none if temporary persona is used + agent_type: str, + start_time: datetime, + agent_metrics: CombinedAgentMetrics, +) -> AgentSearchMetrics: + agent_timings = agent_metrics.timings + agent_base_metrics = agent_metrics.base_metrics + agent_refined_metrics = agent_metrics.refined_metrics + agent_additional_metrics = agent_metrics.additional_metrics + + agent_metric_tracking = AgentSearchMetrics( + user_id=user_id, + persona_id=persona_id, + agent_type=agent_type, + start_time=start_time, + base_duration__s=agent_timings.base_duration__s, + full_duration__s=agent_timings.full_duration__s, + base_metrics=vars(agent_base_metrics), + refined_metrics=vars(agent_refined_metrics), + all_metrics=vars(agent_additional_metrics), + ) + + db_session.add(agent_metric_tracking) + db_session.flush() + + return agent_metric_tracking + + +def log_agent_sub_question_results( + db_session: Session, + chat_session_id: UUID | None, + primary_message_id: int | None, + sub_question_answer_results: list[QuestionAnswerResults], +) -> None: + def _create_citation_format_list( + document_citations: list[InferenceSection], + ) -> list[dict[str, Any]]: + citation_list: list[dict[str, Any]] = [] + for document_citation in document_citations: + document_citation_dict = { + "link": "", + "blurb": document_citation.center_chunk.blurb, + "content": document_citation.center_chunk.content, + "metadata": document_citation.center_chunk.metadata, + "updated_at": str(document_citation.center_chunk.updated_at), + "document_id": document_citation.center_chunk.document_id, + "source_type": "file", + "source_links": document_citation.center_chunk.source_links, + "match_highlights": document_citation.center_chunk.match_highlights, + "semantic_identifier": document_citation.center_chunk.semantic_identifier, + } + + citation_list.append(document_citation_dict) + + return citation_list + + now = datetime.now() + + for sub_question_answer_result in sub_question_answer_results: + level, level_question_nr = [ + int(x) for x in sub_question_answer_result.question_id.split("_") + ] + sub_question = sub_question_answer_result.question + sub_answer = sub_question_answer_result.answer + sub_document_results = _create_citation_format_list( + sub_question_answer_result.documents + ) + + sub_question_object = AgentSubQuestion( + chat_session_id=chat_session_id, + primary_question_id=primary_message_id, + level=level, + level_question_nr=level_question_nr, + sub_question=sub_question, + sub_answer=sub_answer, + sub_question_doc_results=sub_document_results, + ) + + db_session.add(sub_question_object) + db_session.commit() + # db_session.flush() + + sub_question_id = sub_question_object.id + + for sub_query in sub_question_answer_result.expanded_retrieval_results: + sub_query_object = AgentSubQuery( + parent_question_id=sub_question_id, + chat_session_id=chat_session_id, + sub_query=sub_query.query, + time_created=now, + ) + + db_session.add(sub_query_object) + db_session.commit() + # db_session.flush() + + search_docs = chunks_or_sections_to_search_docs(sub_query.search_results) + for doc in search_docs: + db_doc = create_db_search_doc(doc, db_session) + db_session.add(db_doc) + sub_query_object.search_docs.append(db_doc) + db_session.commit() + + return None diff --git 
a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 83747eef806f..f6a94bfc4ea8 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -325,6 +325,17 @@ class ChatMessage__SearchDoc(Base): ) +class AgentSubQuery__SearchDoc(Base): + __tablename__ = "agent__sub_query__search_doc" + + sub_query_id: Mapped[int] = mapped_column( + ForeignKey("agent__sub_query.id"), primary_key=True + ) + search_doc_id: Mapped[int] = mapped_column( + ForeignKey("search_doc.id"), primary_key=True + ) + + class Document__Tag(Base): __tablename__ = "document__tag" @@ -1048,6 +1059,11 @@ class SearchDoc(Base): secondary=ChatMessage__SearchDoc.__table__, back_populates="search_docs", ) + sub_queries = relationship( + "AgentSubQuery", + secondary=AgentSubQuery__SearchDoc.__table__, + back_populates="search_docs", + ) class ToolCall(Base): @@ -1214,6 +1230,11 @@ class ChatMessage(Base): uselist=False, ) + sub_questions: Mapped[list["AgentSubQuestion"]] = relationship( + "AgentSubQuestion", + back_populates="primary_message", + ) + standard_answers: Mapped[list["StandardAnswer"]] = relationship( "StandardAnswer", secondary=ChatMessage__StandardAnswer.__table__, @@ -1248,6 +1269,71 @@ class ChatFolder(Base): return self.display_priority < other.display_priority +class AgentSubQuestion(Base): + """ + A sub-question is a question that is asked of the LLM to gather supporting + information to answer a primary question. + """ + + __tablename__ = "agent__sub_question" + + id: Mapped[int] = mapped_column(primary_key=True) + primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_session_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("chat_session.id") + ) + sub_question: Mapped[str] = mapped_column(Text) + level: Mapped[int] = mapped_column(Integer) + level_question_nr: Mapped[int] = mapped_column(Integer) + time_created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + sub_answer: Mapped[str] = mapped_column(Text) + sub_question_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB()) + + # Relationships + primary_message: Mapped["ChatMessage"] = relationship( + "ChatMessage", + foreign_keys=[primary_question_id], + back_populates="sub_questions", + ) + chat_session: Mapped["ChatSession"] = relationship("ChatSession") + sub_queries: Mapped[list["AgentSubQuery"]] = relationship( + "AgentSubQuery", back_populates="parent_question" + ) + + +class AgentSubQuery(Base): + """ + A sub-query is a vector DB query that gathers supporting information to answer a sub-question. 
+ """ + + __tablename__ = "agent__sub_query" + + id: Mapped[int] = mapped_column(primary_key=True) + parent_question_id: Mapped[int] = mapped_column( + ForeignKey("agent__sub_question.id") + ) + chat_session_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("chat_session.id") + ) + sub_query: Mapped[str] = mapped_column(Text) + time_created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + # Relationships + parent_question: Mapped["AgentSubQuestion"] = relationship( + "AgentSubQuestion", back_populates="sub_queries" + ) + chat_session: Mapped["ChatSession"] = relationship("ChatSession") + search_docs: Mapped[list["SearchDoc"]] = relationship( + "SearchDoc", + secondary=AgentSubQuery__SearchDoc.__table__, + back_populates="sub_queries", + ) + + """ Feedback, Logging, Metrics Tables """ @@ -1751,6 +1837,25 @@ class PGFileStore(Base): lobj_oid: Mapped[int] = mapped_column(Integer, nullable=False) +class AgentSearchMetrics(Base): + __tablename__ = "agent__search_metrics" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) + persona_id: Mapped[int | None] = mapped_column( + ForeignKey("persona.id"), nullable=True + ) + agent_type: Mapped[str] = mapped_column(String) + start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) + base_duration__s: Mapped[float] = mapped_column(Float) + full_duration__s: Mapped[float] = mapped_column(Float) + base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + + """ ************************************************************************ Enterprise Edition Models diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index b107460bfc62..ef34498450ae 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -134,6 +134,10 @@ class CreateChatMessageRequest(ChunkContext): # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None + # If true, ignores most of the search options and uses pro search instead. 
+ # TODO: decide how many of the above options we want to pass through to pro search + use_agentic_search: bool = False + @model_validator(mode="after") def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest": if self.search_doc_ids is None and self.retrieval_options is None: @@ -200,6 +204,22 @@ class SearchFeedbackRequest(BaseModel): return self +class SubQueryDetail(BaseModel): + query: str + query_id: int + # TODO: store these to enable per-query doc selection + doc_ids: list[int] | None = None + + +class SubQuestionDetail(BaseModel): + level: int + level_question_nr: int + question: str + answer: str + sub_queries: list[SubQueryDetail] | None = None + context_docs: RetrievalDocs | None = None + + class ChatMessageDetail(BaseModel): message_id: int parent_message: int | None = None @@ -211,9 +231,10 @@ class ChatMessageDetail(BaseModel): time_sent: datetime overridden_model: str | None alternate_assistant_id: int | None = None - # Dict mapping citation number to db_doc_id chat_session_id: UUID | None = None + # Dict mapping citation number to db_doc_id citations: dict[int, int] | None = None + sub_questions: list[SubQuestionDetail] | None = None files: list[FileDescriptor] tool_call: ToolCallFinalResult | None diff --git a/backend/onyx/tools/message.py b/backend/onyx/tools/message.py index d55901116236..659f38731e38 100644 --- a/backend/onyx/tools/message.py +++ b/backend/onyx/tools/message.py @@ -25,6 +25,11 @@ class ToolCallSummary(BaseModel__v1): tool_call_request: AIMessage tool_call_result: ToolMessage + # This is a workaround to allow arbitrary types in the model + # TODO: Remove this once we have a better solution + class Config: + arbitrary_types_allowed = True + def tool_call_tokens( tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer diff --git a/backend/onyx/tools/models.py b/backend/onyx/tools/models.py index 4f56aecd3729..23f80aee3933 100644 --- a/backend/onyx/tools/models.py +++ b/backend/onyx/tools/models.py @@ -4,6 +4,9 @@ from uuid import UUID from pydantic import BaseModel from pydantic import model_validator +from onyx.context.search.enums import SearchType +from onyx.context.search.models import IndexFilters + class ToolResponse(BaseModel): id: str | None = None @@ -38,6 +41,9 @@ class ToolCallFinalResult(ToolCallKickoff): tool_result: Any = ( None # we would like to use JSON_ro, but can't due to its recursive nature ) + # agentic additions; only need to set during agentic tool calls + level: int | None = None + level_question_nr: int | None = None class DynamicSchemaInfo(BaseModel): @@ -45,5 +51,11 @@ class DynamicSchemaInfo(BaseModel): message_id: int | None +class SearchQueryInfo(BaseModel): + predicted_search: SearchType | None + final_filters: IndexFilters + recency_bias_multiplier: float + + CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID" MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID" diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index ed8af4c9cb89..f139e114f482 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -1,9 +1,9 @@ import json +from collections.abc import Callable from collections.abc import Generator from typing import Any from typing import cast -from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.chat.chat_utils import llm_doc_from_inference_section @@ -25,13 +25,13 @@ from onyx.configs.chat_configs import 
CONTEXT_CHUNKS_BELOW from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import QueryFlow -from onyx.context.search.enums import SearchType from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceSection from onyx.context.search.models import RerankingDetails from onyx.context.search.models import RetrievalDetails from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline +from onyx.context.search.pipeline import section_relevance_list_impl from onyx.db.models import Persona from onyx.db.models import User from onyx.llm.interfaces import LLM @@ -39,6 +39,7 @@ from onyx.llm.models import PreviousMessage from onyx.secondary_llm_flows.choose_search import check_if_need_search from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase from onyx.tools.message import ToolCallSummary +from onyx.tools.models import SearchQueryInfo from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict @@ -62,13 +63,10 @@ SECTION_RELEVANCE_LIST_ID = "section_relevance_list" SEARCH_EVALUATION_ID = "llm_doc_eval" -class SearchResponseSummary(BaseModel): +class SearchResponseSummary(SearchQueryInfo): top_sections: list[InferenceSection] rephrased_query: str | None = None predicted_flow: QueryFlow | None - predicted_search: SearchType | None - final_filters: IndexFilters - recency_bias_multiplier: float SEARCH_TOOL_DESCRIPTION = """ @@ -117,6 +115,8 @@ class SearchTool(Tool): self.fast_llm = fast_llm self.evaluation_type = evaluation_type + self.search_pipeline: SearchPipeline | None = None + self.selected_sections = selected_sections self.full_doc = full_doc @@ -281,8 +281,10 @@ class SearchTool(Tool): yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) - def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["query"]) + force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False)) + alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None)) if self.selected_sections: yield from self._build_response_for_specified_sections(query) @@ -291,7 +293,9 @@ class SearchTool(Tool): search_pipeline = SearchPipeline( search_request=SearchRequest( query=query, - evaluation_type=self.evaluation_type, + evaluation_type=LLMEvaluationType.SKIP + if force_no_rerank + else self.evaluation_type, human_selected_filters=( self.retrieval_options.filters if self.retrieval_options else None ), @@ -300,7 +304,16 @@ class SearchTool(Tool): self.retrieval_options.offset if self.retrieval_options else None ), limit=self.retrieval_options.limit if self.retrieval_options else None, - rerank_settings=self.rerank_settings, + rerank_settings=RerankingDetails( + rerank_model_name=None, + rerank_api_url=None, + rerank_provider_type=None, + rerank_api_key=None, + num_rerank=0, + disable_rerank_for_streaming=True, + ) + if force_no_rerank + else self.rerank_settings, chunks_above=self.chunks_above, chunks_below=self.chunks_below, full_doc=self.full_doc, @@ -314,57 +327,25 @@ class SearchTool(Tool): llm=self.llm, fast_llm=self.fast_llm, bypass_acl=self.bypass_acl, - db_session=self.db_session, + db_session=alternate_db_session or self.db_session, prompt_config=self.prompt_config, ) + 
self.search_pipeline = search_pipeline # used for agent_search metrics - yield ToolResponse( - id=SEARCH_RESPONSE_SUMMARY_ID, - response=SearchResponseSummary( - rephrased_query=query, - top_sections=search_pipeline.final_context_sections, - predicted_flow=search_pipeline.predicted_flow, - predicted_search=search_pipeline.predicted_search_type, - final_filters=search_pipeline.search_query.filters, - recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, - ), + search_query_info = SearchQueryInfo( + predicted_search=search_pipeline.search_query.search_type, + final_filters=search_pipeline.search_query.filters, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) - - yield ToolResponse( - id=SEARCH_DOC_CONTENT_ID, - response=OnyxContexts( - contexts=[ - OnyxContext( - content=section.combined_content, - document_id=section.center_chunk.document_id, - semantic_identifier=section.center_chunk.semantic_identifier, - blurb=section.center_chunk.blurb, - ) - for section in search_pipeline.reranked_sections - ] - ), + yield from yield_search_responses( + query, + search_pipeline.reranked_sections, + search_pipeline.final_context_sections, + search_query_info, + lambda: search_pipeline.section_relevance, + self, ) - yield ToolResponse( - id=SECTION_RELEVANCE_LIST_ID, - response=search_pipeline.section_relevance, - ) - - pruned_sections = prune_sections( - sections=search_pipeline.final_context_sections, - section_relevance_list=search_pipeline.section_relevance_list, - prompt_config=self.prompt_config, - llm_config=self.llm.config, - question=query, - contextual_pruning_config=self.contextual_pruning_config, - ) - - llm_docs = [ - llm_doc_from_inference_section(section) for section in pruned_sections - ] - - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) - def final_result(self, *args: ToolResponse) -> JSON_ro: final_docs = cast( list[LlmDoc], @@ -425,3 +406,64 @@ class SearchTool(Tool): initial_search_results = cast(list[LlmDoc], initial_search_results) return final_search_results, initial_search_results + + +# Allows yielding the same responses as a SearchTool without being a SearchTool. +# SearchTool passed in to allow for access to SearchTool properties. +# We can't just call SearchTool methods in the graph because we're operating on +# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run. 
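+# Minimal usage sketch (hypothetical caller; variable names are illustrative only):
+# once the agent graph has produced its own reranked/deduplicated sections, it can
+# reuse an existing SearchTool instance to emit the standard tool responses:
+#
+#     for tool_response in yield_search_responses(
+#         query=question,
+#         reranked_sections=agent_reranked_sections,
+#         final_context_sections=agent_deduped_sections,
+#         search_query_info=query_info,
+#         get_section_relevance=lambda: None,
+#         search_tool=search_tool,
+#     ):
+#         handle_tool_response(tool_response)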
+def yield_search_responses( + query: str, + reranked_sections: list[InferenceSection], + final_context_sections: list[InferenceSection], + search_query_info: SearchQueryInfo, + get_section_relevance: Callable[[], list[SectionRelevancePiece] | None], + search_tool: SearchTool, +) -> Generator[ToolResponse, None, None]: + yield ToolResponse( + id=SEARCH_RESPONSE_SUMMARY_ID, + response=SearchResponseSummary( + rephrased_query=query, + top_sections=final_context_sections, + predicted_flow=QueryFlow.QUESTION_ANSWER, + predicted_search=search_query_info.predicted_search, + final_filters=search_query_info.final_filters, + recency_bias_multiplier=search_query_info.recency_bias_multiplier, + ), + ) + + yield ToolResponse( + id=SEARCH_DOC_CONTENT_ID, + response=OnyxContexts( + contexts=[ + OnyxContext( + content=section.combined_content, + document_id=section.center_chunk.document_id, + semantic_identifier=section.center_chunk.semantic_identifier, + blurb=section.center_chunk.blurb, + ) + for section in reranked_sections + ] + ), + ) + + section_relevance = get_section_relevance() + yield ToolResponse( + id=SECTION_RELEVANCE_LIST_ID, + response=section_relevance, + ) + + pruned_sections = prune_sections( + sections=final_context_sections, + section_relevance_list=section_relevance_list_impl( + section_relevance, final_context_sections + ), + prompt_config=search_tool.prompt_config, + llm_config=search_tool.llm.config, + question=query, + contextual_pruning_config=search_tool.contextual_pruning_config, + ) + + llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections] + + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d32255d9f65d..e43ebd62e6dd 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -3,6 +3,7 @@ plugins = "sqlalchemy.ext.mypy.plugin" mypy_path = "$MYPY_CONFIG_FILE_DIR" explicit_package_bases = true disallow_untyped_defs = true +enable_error_code = ["possibly-undefined"] [[tool.mypy.overrides]] module = "alembic.versions.*" diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index db5ad54e3db9..0ce81897417f 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -29,9 +29,14 @@ inflection==0.5.1 jira==3.5.1 jsonref==1.1.0 trafilatura==1.12.2 -langchain==0.1.17 -langchain-core==0.1.50 -langchain-text-splitters==0.0.1 +langchain==0.3.7 +langchain-core==0.3.24 +langchain-openai==0.2.9 +langchain-text-splitters==0.3.2 +langchainhub==0.1.21 +langgraph==0.2.59 +langgraph-checkpoint==2.0.5 +langgraph-sdk==0.1.44 litellm==1.55.4 lxml==5.3.0 lxml_html_clean==0.2.2 diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test.py new file mode 100644 index 000000000000..25b19964efd4 --- /dev/null +++ b/backend/tests/regression/answer_quality/agent_test.py @@ -0,0 +1,127 @@ +import csv +import datetime +import json +import os + +import yaml + +from onyx.agents.agent_search.deep_search_a.main.graph_builder import main_graph_builder +from onyx.agents.agent_search.deep_search_a.main.states import MainInput +from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config +from onyx.context.search.models import SearchRequest +from onyx.db.engine import get_session_context_manager +from onyx.llm.factory import get_default_llms + + +cwd = os.getcwd() +CONFIG = yaml.safe_load( + 
open(f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml") +) +INPUT_DIR = CONFIG["agent_test_input_folder"] +OUTPUT_DIR = CONFIG["agent_test_output_folder"] + + +graph = main_graph_builder(test_mode=True) +compiled_graph = graph.compile() +primary_llm, fast_llm = get_default_llms() + +# create a local json test data file and use it here + + +input_file_object = open( + f"{INPUT_DIR}/agent_test_data.json", +) +output_file = f"{OUTPUT_DIR}/agent_test_output.csv" + + +test_data = json.load(input_file_object) +example_data = test_data["examples"] +example_ids = test_data["example_ids"] + +with get_session_context_manager() as db_session: + output_data = [] + + for example in example_data: + example_id = example["id"] + if len(example_ids) > 0 and example_id not in example_ids: + continue + + example_question = example["question"] + target_sub_questions = example.get("target_sub_questions", []) + num_target_sub_questions = len(target_sub_questions) + search_request = SearchRequest(query=example_question) + + config, search_tool = get_test_config( + db_session=db_session, + primary_llm=primary_llm, + fast_llm=fast_llm, + search_request=search_request, + ) + + inputs = MainInput() + + start_time = datetime.datetime.now() + + question_result = compiled_graph.invoke( + input=inputs, config={"metadata": {"config": config}} + ) + end_time = datetime.datetime.now() + + duration = end_time - start_time + if num_target_sub_questions > 0: + chunk_expansion_ratio = ( + question_result["initial_agent_stats"] + .get("agent_effectiveness", {}) + .get("utilized_chunk_ratio", None) + ) + support_effectiveness_ratio = ( + question_result["initial_agent_stats"] + .get("agent_effectiveness", {}) + .get("support_ratio", None) + ) + else: + chunk_expansion_ratio = None + support_effectiveness_ratio = None + + generated_sub_questions = question_result.get("generated_sub_questions", []) + num_generated_sub_questions = len(generated_sub_questions) + base_answer = question_result["initial_base_answer"].split("==")[-1] + agent_answer = question_result["initial_answer"].split("==")[-1] + + output_point = { + "example_id": example_id, + "question": example_question, + "duration": duration, + "target_sub_questions": target_sub_questions, + "generated_sub_questions": generated_sub_questions, + "num_target_sub_questions": num_target_sub_questions, + "num_generated_sub_questions": num_generated_sub_questions, + "chunk_expansion_ratio": chunk_expansion_ratio, + "support_effectiveness_ratio": support_effectiveness_ratio, + "base_answer": base_answer, + "agent_answer": agent_answer, + } + + output_data.append(output_point) + + +with open(output_file, "w", newline="") as csvfile: + fieldnames = [ + "example_id", + "question", + "duration", + "target_sub_questions", + "generated_sub_questions", + "num_target_sub_questions", + "num_generated_sub_questions", + "chunk_expansion_ratio", + "support_effectiveness_ratio", + "base_answer", + "agent_answer", + ] + + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter="\t") + writer.writeheader() + writer.writerows(output_data) + +print("DONE") diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 49ea64315273..b1895337990e 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -187,6 +187,8 @@ export function ChatPage({ const enterpriseSettings = settings?.enterpriseSettings; const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false); + const [filtersToggled, setFiltersToggled] = 
useState(false); + const [langgraphEnabled, setLanggraphEnabled] = useState(false); const [userSettingsToggled, setUserSettingsToggled] = useState(false); @@ -1275,6 +1277,7 @@ export function ChatPage({ systemPromptOverride: searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined, useExistingUserMessage: isSeededChat, + useLanggraph: langgraphEnabled, }); const delay = (ms: number) => { @@ -2258,6 +2261,17 @@ export function ChatPage({ hideUserDropdown={user?.is_anonymous_user} /> )} +
+ +
{documentSidebarInitialWidth !== undefined && isReady ? ( { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; @@ -206,6 +208,7 @@ export async function* sendMessage({ } : null, use_existing_user_message: useExistingUserMessage, + use_agentic_search: useLanggraph, }); const response = await fetch(`/api/chat/send-message`, {