alembic heads, basic citations, search pipeline state

Evan Lohn 2025-02-01 14:15:45 -08:00
parent 5a95a5c9fd
commit 29440f5482
5 changed files with 25 additions and 14 deletions

View File

@@ -1,7 +1,7 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 4d58345da04a
Revises: 33ea50e88f24
Create Date: 2025-01-29 17:00:00.000001
"""
@@ -12,7 +12,7 @@ from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "4d58345da04a"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None

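The re-parenting above is the "alembic heads" part of the commit title: with this migration and the 33ea50e88f24 revision both hanging off the same ancestry, Alembic reports multiple heads, and pointing down_revision at 33ea50e88f24 restores a single linear chain. A minimal sketch of the resulting order; only the revision identifiers come from the diff, while the file names and the assumption that 33ea50e88f24 sits directly on 4d58345da04a are illustrative:

# 4d58345da04a_*.py  (common ancestor, file name assumed)
revision = "4d58345da04a"

# 33ea50e88f24_*.py  (the other former head, shown as a direct child for brevity)
revision = "33ea50e88f24"
down_revision = "4d58345da04a"

# 98a5008d8711_agent_tracking.py  (this migration, re-parented by the commit)
revision = "98a5008d8711"
down_revision = "33ea50e88f24"  # was "4d58345da04a", which left two heads

After the change, running alembic heads should report a single head again.
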
View File

@@ -68,12 +68,16 @@ def retrieve_documents(
query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
@@ -87,13 +91,9 @@
break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
pre_rerank_docs = retrieved_docs
if search_tool.search_pipeline is not None:
pre_rerank_docs = (
search_tool.search_pipeline._retrieved_sections or retrieved_docs
)
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

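The callback_container above replaces the removed code that reached into search_tool.search_pipeline._retrieved_sections after the run: the caller hands list.append to the tool as retrieved_sections_callback, the pipeline invokes it with the pre-rerank sections, and the caller reads them back out when AGENT_RETRIEVAL_STATS needs them. A stripped-down sketch of the pattern; run_pipeline is a placeholder for search_tool.run and plain strings stand in for InferenceSection:

from collections.abc import Callable


def run_pipeline(
    retrieved_sections_callback: Callable[[list[str]], None] | None = None,
) -> list[str]:
    retrieved = ["section-a", "section-b", "section-c"]  # stand-in for retrieval
    if retrieved_sections_callback is not None:
        retrieved_sections_callback(retrieved)  # hand the pre-rerank state to the caller
    return retrieved[:2]  # stand-in for reranking / truncation


callback_container: list[list[str]] = []
final_sections = run_pipeline(retrieved_sections_callback=callback_container.append)
pre_rerank_sections = callback_container[0] if callback_container else final_sections

Passing state back through a caller-owned container also avoids sharing mutable attributes on the tool across runs, which lines up with the new per-call db session noted in the diff.
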
View File

@@ -982,9 +982,10 @@ def stream_chat_message_objects(
# Saving Gen AI answer and responding with message info
basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
info = (
info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])]
if BASIC_KEY in info_by_subq
info_by_subq[basic_key]
if basic_key in info_by_subq
else info_by_subq[
SubQuestionKey(
level=AGENT_SEARCH_INITIAL_KEY[0],
@@ -1018,7 +1019,7 @@ def stream_chat_message_objects(
),
)
# TODO: add answers for levels >= 1, where each level has the previous as its parent. Use
# add answers for levels >= 1, where each level has the previous as its parent. Use
# the answer_by_level method in answer.py to get the answers for each level
next_level = 1
prev_message = gen_ai_response_message

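The basic_key change above makes the membership test and the lookup use the same SubQuestionKey value; previously the test checked the raw BASIC_KEY pair against a dict keyed by SubQuestionKey while the lookup built a SubQuestionKey, so the two could disagree unless SubQuestionKey happens to compare equal to plain tuples. Reduced to a sketch with placeholder tuples in place of SubQuestionKey:

BASIC_KEY = (0, 0)  # placeholder (level, question_num)
AGENT_SEARCH_INITIAL_KEY = (1, 0)  # placeholder (level, question_num)

info_by_subq = {AGENT_SEARCH_INITIAL_KEY: "agent-search answer info"}

basic_key = BASIC_KEY
info = (
    info_by_subq[basic_key]
    if basic_key in info_by_subq
    else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
)
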
View File

@@ -56,6 +56,8 @@ class SearchPipeline:
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
retrieved_sections_callback: Callable[[list[InferenceSection]], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
):
@@ -80,6 +82,8 @@ class SearchPipeline:
self._retrieved_chunks: list[InferenceChunk] | None = None
# Another call made to the document index to get surrounding sections
self._retrieved_sections: list[InferenceSection] | None = None
self.retrieved_sections_callback = retrieved_sections_callback
# Reranking and LLM section selection can be run together
# If only LLM selection is on, the reranked chunks are yielded immediately
self._reranked_sections: list[InferenceSection] | None = None
@@ -328,9 +332,13 @@ class SearchPipeline:
if self._reranked_sections is not None:
return self._reranked_sections
retrieved_sections = self._get_sections()
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)
self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_sections=self._get_sections(),
retrieved_sections=retrieved_sections,
llm=self.fast_llm,
rerank_metrics_callback=self.rerank_metrics_callback,
)

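On the SearchPipeline side the hook is accepted in __init__, stored on the instance, and fired once, at the point where the pre-rerank sections exist but before they go into search_postprocessing (the hunk also reuses the local retrieved_sections instead of calling self._get_sections() a second time). A minimal sketch of that shape; the class name, the _get_sections body, and the reranking step are simplified stand-ins, not the real SearchPipeline API:

from collections.abc import Callable


class MiniPipeline:
    def __init__(
        self,
        retrieved_sections_callback: Callable[[list[str]], None] | None = None,
    ) -> None:
        self.retrieved_sections_callback = retrieved_sections_callback
        self._reranked_sections: list[str] | None = None

    def _get_sections(self) -> list[str]:
        return ["s1", "s2", "s3"]  # stand-in for the document-index calls

    @property
    def reranked_sections(self) -> list[str]:
        if self._reranked_sections is not None:
            return self._reranked_sections  # cached, callback not fired again
        retrieved = self._get_sections()
        if self.retrieved_sections_callback is not None:
            self.retrieved_sections_callback(retrieved)  # expose pre-rerank state
        self._reranked_sections = retrieved[:2]  # stand-in for postprocessing
        return self._reranked_sections


captured: list[list[str]] = []
pipeline = MiniPipeline(retrieved_sections_callback=captured.append)
_ = pipeline.reranked_sections
pre_rerank = captured[0]  # ["s1", "s2", "s3"]
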
View File

@@ -112,8 +112,6 @@ class SearchTool(Tool):
self.fast_llm = fast_llm
self.evaluation_type = evaluation_type
self.search_pipeline: SearchPipeline | None = None
self.selected_sections = selected_sections
self.full_doc = full_doc
@@ -282,6 +280,10 @@ class SearchTool(Tool):
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -326,8 +328,8 @@
bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,
retrieved_sections_callback=retrieved_sections_callback,
)
self.search_pipeline = search_pipeline # used for agent_search metrics
search_query_info = SearchQueryInfo(
predicted_search=search_pipeline.search_query.search_type,
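
In SearchTool.run the callback arrives as an optional kwarg and is forwarded straight into the SearchPipeline constructor, which is why the self.search_pipeline bookkeeping shown removed above is no longer needed. One wrinkle: the diff casts to the non-optional Callable type even though kwargs.get returns None when the caller passes nothing; the sketch below keeps the optional explicit. Everything here is a simplified stand-in, not the real tool:

from collections.abc import Callable
from typing import Any, cast


def run_search_tool(**kwargs: Any) -> list[str]:
    # the cast is a typing hint only; at runtime kwargs.get simply returns None if absent
    retrieved_sections_callback = cast(
        Callable[[list[str]], None] | None,
        kwargs.get("retrieved_sections_callback"),
    )
    sections = ["section-a", "section-b", "section-c"]  # stand-in for retrieval
    if retrieved_sections_callback is not None:
        retrieved_sections_callback(sections)  # forwarded hook fires with pre-rerank docs
    return sections[:2]  # stand-in for the reranked result


captured: list[list[str]] = []
run_search_tool(retrieved_sections_callback=captured.append)
# captured[0] now holds the pre-rerank sections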