alembic heads, basic citations, search pipeline state

This commit is contained in:
Evan Lohn
2025-02-01 14:15:45 -08:00
parent 5a95a5c9fd
commit 29440f5482
5 changed files with 25 additions and 14 deletions

View File

@@ -1,7 +1,7 @@
"""agent_tracking """agent_tracking
Revision ID: 98a5008d8711 Revision ID: 98a5008d8711
Revises: 4d58345da04a Revises: 33ea50e88f24
Create Date: 2025-01-29 17:00:00.000001 Create Date: 2025-01-29 17:00:00.000001
""" """
@@ -12,7 +12,7 @@ from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "98a5008d8711" revision = "98a5008d8711"
down_revision = "4d58345da04a" down_revision = "33ea50e88f24"
branch_labels = None branch_labels = None
depends_on = None depends_on = None

View File

@@ -68,12 +68,16 @@ def retrieve_documents(
query_info = None query_info = None
if search_tool is None: if search_tool is None:
raise ValueError("search_tool must be provided for agentic search") raise ValueError("search_tool must be provided for agentic search")
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues # new db session to avoid concurrency issues
with get_session_context_manager() as db_session: with get_session_context_manager() as db_session:
for tool_response in search_tool.run( for tool_response in search_tool.run(
query=query_to_retrieve, query=query_to_retrieve,
force_no_rerank=True, force_no_rerank=True,
alternate_db_session=db_session, alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
): ):
# get retrieved docs to send to the rest of the graph # get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
@@ -87,13 +91,9 @@ def retrieve_documents(
break break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
pre_rerank_docs = retrieved_docs
if search_tool.search_pipeline is not None:
pre_rerank_docs = (
search_tool.search_pipeline._retrieved_sections or retrieved_docs
)
if AGENT_RETRIEVAL_STATS: if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores( fit_scores = get_fit_scores(
pre_rerank_docs, pre_rerank_docs,
retrieved_docs, retrieved_docs,

View File

@@ -982,9 +982,10 @@ def stream_chat_message_objects(
# Saving Gen AI answer and responding with message info # Saving Gen AI answer and responding with message info
basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
info = ( info = (
info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])] info_by_subq[basic_key]
if BASIC_KEY in info_by_subq if basic_key in info_by_subq
else info_by_subq[ else info_by_subq[
SubQuestionKey( SubQuestionKey(
level=AGENT_SEARCH_INITIAL_KEY[0], level=AGENT_SEARCH_INITIAL_KEY[0],
@@ -1018,7 +1019,7 @@ def stream_chat_message_objects(
), ),
) )
# TODO: add answers for levels >= 1, where each level has the previous as its parent. Use # add answers for levels >= 1, where each level has the previous as its parent. Use
# the answer_by_level method in answer.py to get the answers for each level # the answer_by_level method in answer.py to get the answers for each level
next_level = 1 next_level = 1
prev_message = gen_ai_response_message prev_message = gen_ai_response_message

View File

@@ -56,6 +56,8 @@ class SearchPipeline:
retrieval_metrics_callback: ( retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None Callable[[RetrievalMetricsContainer], None] | None
) = None, ) = None,
retrieved_sections_callback: Callable[[list[InferenceSection]], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None, prompt_config: PromptConfig | None = None,
): ):
@@ -80,6 +82,8 @@ class SearchPipeline:
self._retrieved_chunks: list[InferenceChunk] | None = None self._retrieved_chunks: list[InferenceChunk] | None = None
# Another call made to the document index to get surrounding sections # Another call made to the document index to get surrounding sections
self._retrieved_sections: list[InferenceSection] | None = None self._retrieved_sections: list[InferenceSection] | None = None
self.retrieved_sections_callback = retrieved_sections_callback
# Reranking and LLM section selection can be run together # Reranking and LLM section selection can be run together
# If only LLM selection is on, the reranked chunks are yielded immediately # If only LLM selection is on, the reranked chunks are yielded immediately
self._reranked_sections: list[InferenceSection] | None = None self._reranked_sections: list[InferenceSection] | None = None
@@ -328,9 +332,13 @@ class SearchPipeline:
if self._reranked_sections is not None: if self._reranked_sections is not None:
return self._reranked_sections return self._reranked_sections
retrieved_sections = self._get_sections()
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)
self._postprocessing_generator = search_postprocessing( self._postprocessing_generator = search_postprocessing(
search_query=self.search_query, search_query=self.search_query,
retrieved_sections=self._get_sections(), retrieved_sections=retrieved_sections,
llm=self.fast_llm, llm=self.fast_llm,
rerank_metrics_callback=self.rerank_metrics_callback, rerank_metrics_callback=self.rerank_metrics_callback,
) )

View File

@@ -112,8 +112,6 @@ class SearchTool(Tool):
self.fast_llm = fast_llm self.fast_llm = fast_llm
self.evaluation_type = evaluation_type self.evaluation_type = evaluation_type
self.search_pipeline: SearchPipeline | None = None
self.selected_sections = selected_sections self.selected_sections = selected_sections
self.full_doc = full_doc self.full_doc = full_doc
@@ -282,6 +280,10 @@ class SearchTool(Tool):
query = cast(str, kwargs["query"]) query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False)) force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None)) alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
if self.selected_sections: if self.selected_sections:
yield from self._build_response_for_specified_sections(query) yield from self._build_response_for_specified_sections(query)
@@ -326,8 +328,8 @@ class SearchTool(Tool):
bypass_acl=self.bypass_acl, bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session, db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config, prompt_config=self.prompt_config,
retrieved_sections_callback=retrieved_sections_callback,
) )
self.search_pipeline = search_pipeline # used for agent_search metrics
search_query_info = SearchQueryInfo( search_query_info = SearchQueryInfo(
predicted_search=search_pipeline.search_query.search_type, predicted_search=search_pipeline.search_query.search_type,