Source Filter Extraction (#708)
@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import Any

 import yaml
@@ -11,19 +10,13 @@ from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import Persona
 from danswer.db.models import ToolInfo
+from danswer.prompts.prompt_utils import get_current_llm_day_time


 def build_system_text_from_persona(persona: Persona) -> str | None:
     text = (persona.system_text or "").strip()
     if persona.datetime_aware:
-        current_datetime = datetime.now()
-        # Format looks like: "October 16, 2023 14:30"
-        formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
-
-        text += (
-            "\n\nAdditional Information:\n"
-            f"\t- The current date and time is {formatted_datetime}."
-        )
+        text += "\n\nAdditional Information:\n" f"\t- {get_current_llm_day_time()}."

     return text or None
@@ -169,8 +169,8 @@ DOC_TIME_DECAY = float(
     os.environ.get("DOC_TIME_DECAY") or 0.5  # Hits limit at 2 years by default
 )
 FAVOR_RECENT_DECAY_MULTIPLIER = 2
-DISABLE_TIME_FILTER_EXTRACTION = (
-    os.environ.get("DISABLE_TIME_FILTER_EXTRACTION", "").lower() == "true"
+DISABLE_LLM_FILTER_EXTRACTION = (
+    os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
 )
 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
@@ -182,15 +182,22 @@ def build_qa_response_blocks(
     query_event_id: int,
     answer: str | None,
     quotes: list[DanswerQuote] | None,
+    source_filters: list[DocumentSource] | None,
     time_cutoff: datetime | None,
     favor_recent: bool,
 ) -> list[Block]:
     quotes_blocks: list[Block] = []

     ai_answer_header = HeaderBlock(text="AI Answer")

     filter_block: Block | None = None
-    if time_cutoff or favor_recent:
+    if time_cutoff or favor_recent or source_filters:
         filter_text = "Filters: "
+        if source_filters:
+            sources_str = ", ".join([s.value for s in source_filters])
+            filter_text += f"`Sources in [{sources_str}]`"
+            if time_cutoff or favor_recent:
+                filter_text += " and "
         if time_cutoff is not None:
             time_str = time_cutoff.strftime("%b %d, %Y")
             filter_text += f"`Docs Updated >= {time_str}` "
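For a sense of the rendered output: a response carrying both source and time filters would produce a filter line like the following sketch (all values invented for illustration):

# Illustrative only: how filter_text is assembled with hypothetical inputs
sources_str = ", ".join(["slack", "web"])
filter_text = f"Filters: `Sources in [{sources_str}]` and `Docs Updated >= Sep 01, 2023` "
# -> "Filters: `Sources in [slack, web]` and `Docs Updated >= Sep 01, 2023` "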
@@ -260,6 +260,7 @@ def handle_message(
         query_event_id=answer.query_event_id,
         answer=answer.answer,
         quotes=answer.quotes,
+        source_filters=answer.source_type,
         time_cutoff=answer.time_cutoff,
         favor_recent=answer.favor_recent,
     )
@@ -202,3 +202,11 @@ def fetch_latest_index_attempts_by_status(
         ),
     )
     return cast(list[IndexAttempt], query.all())
+
+
+def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
+    distinct_sources = db_session.query(Connector.source).distinct().all()
+
+    sources = [source[0] for source in distinct_sources]
+
+    return sources
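A note on the `source[0]` indexing above: `.all()` on a single-column query yields one-element row tuples, so each row is unpacked positionally. A toy sketch of the shape involved:

# Illustrative only: what .all() returns for one selected column
rows = [("slack",), ("web",)]
sources = [row[0] for row in rows]  # -> ["slack", "web"]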
@@ -22,12 +22,14 @@ from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
-from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
+from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
+from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import RerankedRetrievalDocs
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
 from danswer.utils.timing import log_generator_function_time
@@ -52,9 +54,22 @@ def answer_qa_query(
     offset_count = question.offset if question.offset is not None else 0
     logger.info(f"Received QA query: {query}")

-    time_cutoff, favor_recent = extract_question_time_filters(question)
+    functions_to_run: dict[Callable, tuple] = {
+        extract_question_time_filters: (question,),
+        extract_question_source_filters: (question, db_session),
+        query_intent: (query,),
+    }
+
+    parallel_results = run_functions_in_parallel(functions_to_run)
+
+    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+    source_filters = parallel_results["extract_question_source_filters"]
+    predicted_search, predicted_flow = parallel_results["query_intent"]
+
+    # Modifies the question object but nothing upstream uses it
     question.filters.time_cutoff = time_cutoff
     question.favor_recent = favor_recent
+    question.filters.source_type = source_filters

     ranked_chunks, unranked_chunks, query_event_id = danswer_search(
         question=question,
@@ -65,9 +80,6 @@ def answer_qa_query(
         rerank_metrics_callback=rerank_metrics_callback,
     )

-    # TODO retire this
-    predicted_search, predicted_flow = query_intent(query)
-
     if not ranked_chunks:
         return QAResponse(
             answer=None,
@@ -77,6 +89,7 @@ def answer_qa_query(
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
         )
@@ -96,6 +109,7 @@ def answer_qa_query(
             predicted_flow=QueryFlow.SEARCH,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
         )
@@ -113,6 +127,7 @@ def answer_qa_query(
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
             error_msg=str(e),
@@ -159,6 +174,7 @@ def answer_qa_query(
        predicted_search=predicted_search,
        eval_res_valid=True if valid else False,
        query_event_id=query_event_id,
+       source_type=source_filters,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
        error_msg=error_msg,
@@ -172,6 +188,7 @@ def answer_qa_query(
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        query_event_id=query_event_id,
+       source_type=source_filters,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
        error_msg=error_msg,
@@ -194,9 +211,22 @@ def answer_qa_query_stream(
     query = question.query
     offset_count = question.offset if question.offset is not None else 0

-    time_cutoff, favor_recent = extract_question_time_filters(question)
+    functions_to_run: dict[Callable, tuple] = {
+        extract_question_time_filters: (question,),
+        extract_question_source_filters: (question, db_session),
+        query_intent: (query,),
+    }
+
+    parallel_results = run_functions_in_parallel(functions_to_run)
+
+    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+    source_filters = parallel_results["extract_question_source_filters"]
+    predicted_search, predicted_flow = parallel_results["query_intent"]
+
+    # Modifies the question object but nothing upstream uses it
     question.filters.time_cutoff = time_cutoff
     question.favor_recent = favor_recent
+    question.filters.source_type = source_filters

     ranked_chunks, unranked_chunks, query_event_id = danswer_search(
         question=question,
@@ -205,9 +235,6 @@ def answer_qa_query_stream(
         document_index=get_default_document_index(),
     )

-    # TODO retire this
-    predicted_search, predicted_flow = query_intent(query)
-
     top_docs = chunks_to_search_docs(ranked_chunks)
     unranked_top_docs = chunks_to_search_docs(unranked_chunks)
@@ -335,7 +335,10 @@ def _build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
     # CAREFUL touching this one, currently there is no second ACL double-check post retrieval
     filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)

-    filter_str += _build_or_filters(SOURCE_TYPE, filters.source_type)
+    source_strs = (
+        [s.value for s in filters.source_type] if filters.source_type else None
+    )
+    filter_str += _build_or_filters(SOURCE_TYPE, source_strs)

     filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
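The new `source_strs` step is needed because `BaseFilters.source_type` now carries `DocumentSource` enum members (see the models change below) while the index stores plain strings. A minimal sketch of the conversion, assuming a str-backed enum shaped like danswer's:

from enum import Enum

class DocumentSource(str, Enum):  # assumption: mirrors the real enum's shape
    SLACK = "slack"
    WEB = "web"

source_strs = [s.value for s in [DocumentSource.SLACK, DocumentSource.WEB]]
# -> ["slack", "web"], safe to splice into the Vespa filter string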
@@ -9,3 +9,4 @@ UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"
 QUOTES_PAT_PLURAL = "Quotes:"
 INVALID_PAT = "Invalid:"
+SOURCES_KEY = "sources"
backend/danswer/prompts/prompt_utils.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from datetime import datetime
+
+
+def get_current_llm_day_time() -> str:
+    current_datetime = datetime.now()
+    # Format looks like: "October 16, 2023 14:30"
+    formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
+    day_of_week = current_datetime.strftime("%A")
+    return f"The current day and time is {day_of_week} {formatted_datetime}"
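For reference, the helper's output for a pinned moment (a sketch substituting a fixed datetime for datetime.now()):

from datetime import datetime

dt = datetime(2023, 10, 16, 14, 30)  # a Monday
print(f"The current day and time is {dt.strftime('%A')} {dt.strftime('%B %d, %Y %H:%M')}")
# -> The current day and time is Monday October 16, 2023 14:30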
@@ -2,6 +2,7 @@ from danswer.prompts.constants import ANSWER_PAT
 from danswer.prompts.constants import ANSWERABLE_PAT
 from danswer.prompts.constants import GENERAL_SEP_PAT
 from danswer.prompts.constants import QUESTION_PAT
+from danswer.prompts.constants import SOURCES_KEY
 from danswer.prompts.constants import THOUGHT_PAT
@@ -31,21 +32,6 @@ Hint: Remember, if ANY of the conditions are True, it is Invalid.
 """.strip()


-TIME_FILTER_PROMPT = """
-You are a tool to identify time filters to apply to a user query for a downstream search \
-application. The downstream application is able to use a recency bias or apply a hard cutoff to \
-remove all documents before the cutoff. Identify the correct filters to apply for the user query.
-
-Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \
-"value_multiple" and "date".
-
-The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive".
-The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year".
-The valid values for "value_multiple" is any number.
-The valid values for "date" is a date in format MM/DD/YYYY.
-""".strip()
-
-
 ANSWERABLE_PROMPT = f"""
 You are a helper tool to determine if a query is answerable using retrieval augmented generation.
 The main system will try to answer the user query based on ONLY the top 5 most relevant \
@@ -91,6 +77,61 @@ won't find an answer.
 """.strip()


+# Smaller followup prompts in time_filter.py
+TIME_FILTER_PROMPT = """
+You are a tool to identify time filters to apply to a user query for a downstream search \
+application. The downstream application is able to use a recency bias or apply a hard cutoff to \
+remove all documents before the cutoff. Identify the correct filters to apply for the user query.
+
+The current day and time is {current_day_time_str}.
+
+Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \
+"value_multiple" and "date".
+
+The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive".
+The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year".
+The valid values for "value_multiple" is any number.
+The valid values for "date" is a date in format MM/DD/YYYY, ALWAYS follow this format.
+""".strip()
+
+
+# Smaller followup prompts in source_filter.py
+# Known issue: LLMs like GPT-3.5 try to generalize. If the valid sources contains "web" but not
+# "confluence" and the user asks for confluence related things, the LLM will select "web" since
+# confluence is accessed as a website. This cannot be fixed without also reducing the capability
+# to match things like repository->github, website->web, etc.
+# This is generally not a big issue though as if the company has confluence, hopefully they add
+# a connector for it or the user is aware that confluence has not been added.
+SOURCE_FILTER_PROMPT = f"""
+Given a user query, extract relevant source filters for use in a downstream search tool.
+Respond with a json containing the source filters or null if no specific sources are referenced.
+ONLY extract sources when the user is explicitly limiting the scope of where information is \
+coming from.
+The user may provide invalid source filters, ignore those.
+
+The valid sources are:
+{{valid_sources}}
+{{web_source_warning}}
+{{file_source_warning}}
+
+
+ALWAYS answer with ONLY a json with the key "{SOURCES_KEY}". \
+The value for "{SOURCES_KEY}" must be null or a list of valid sources.
+
+Sample Response:
+{{sample_response}}
+""".strip()
+
+
+WEB_SOURCE_WARNING = """
+Note: The "web" source only applies to when the user specifies "website" in the query. \
+It does not apply to tools such as Confluence, GitHub, etc. which have a website.
+""".strip()
+
+
+FILE_SOURCE_WARNING = """
+Note: The "file" source only applies to when the user refers to uploaded files in the query.
+""".strip()


 # Use the following for easy viewing of prompts
 if __name__ == "__main__":
     print(ANSWERABLE_PROMPT)
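To make the two JSON contracts above concrete, well-formed model responses would look like the following (illustrative values only):

# TIME_FILTER_PROMPT, for a query like "what changed in the last two weeks?"
{"filter_type": "hard cutoff", "filter_value": "week", "value_multiple": 2, "date": "09/05/2023"}

# SOURCE_FILTER_PROMPT, for a query naming specific tools
{"sources": ["slack", "confluence"]}

# SOURCE_FILTER_PROMPT, when no source is named
{"sources": null}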
@@ -5,6 +5,7 @@ from pydantic import BaseModel

 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
+from danswer.configs.constants import DocumentSource
 from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -31,7 +32,7 @@ class Embedder:


 class BaseFilters(BaseModel):
-    source_type: list[str] | None = None
+    source_type: list[DocumentSource] | None = None
     document_set: list[str] | None = None
     time_cutoff: datetime | None = None
backend/danswer/secondary_llm_flows/source_filter.py (new file, 185 lines)
@@ -0,0 +1,185 @@
+import json
+import random
+
+from sqlalchemy.orm import Session
+
+from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
+from danswer.configs.constants import DocumentSource
+from danswer.db.connector import fetch_unique_document_sources
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.llm.factory import get_default_llm
+from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.prompts.constants import SOURCES_KEY
+from danswer.prompts.secondary_llm_flows import FILE_SOURCE_WARNING
+from danswer.prompts.secondary_llm_flows import SOURCE_FILTER_PROMPT
+from danswer.prompts.secondary_llm_flows import WEB_SOURCE_WARNING
+from danswer.server.models import QuestionRequest
+from danswer.utils.logger import setup_logger
+from danswer.utils.text_processing import extract_embedded_json
+from danswer.utils.timing import log_function_time
+
+logger = setup_logger()
+
+
+def strings_to_document_sources(source_strs: list[str]) -> list[DocumentSource]:
+    sources = []
+    for s in source_strs:
+        try:
+            sources.append(DocumentSource(s))
+        except ValueError:
+            logger.warning(f"Failed to translate {s} to a DocumentSource")
+    return sources
+
+
+def _sample_document_sources(
+    valid_sources: list[DocumentSource],
+    num_sample: int,
+    allow_less: bool = True,
+) -> list[DocumentSource]:
+    if len(valid_sources) < num_sample:
+        if not allow_less:
+            raise RuntimeError("Not enough sample Document Sources")
+        return random.sample(valid_sources, len(valid_sources))
+    else:
+        return random.sample(valid_sources, num_sample)
+
+
+@log_function_time()
+def extract_source_filter(
+    query: str, db_session: Session
+) -> list[DocumentSource] | None:
+    """Returns a list of valid sources for search or None if no specific sources were detected"""
+
+    def _get_source_filter_messages(
+        query: str,
+        valid_sources: list[DocumentSource],
+        # Seems the LLM performs similarly without examples
+        show_samples: bool = False,
+    ) -> list[dict[str, str]]:
+        sample_json = {
+            SOURCES_KEY: [
+                s.value
+                for s in _sample_document_sources(
+                    valid_sources=valid_sources, num_sample=2
+                )
+            ]
+        }
+
+        web_warning = WEB_SOURCE_WARNING if DocumentSource.WEB in valid_sources else ""
+        file_warning = (
+            FILE_SOURCE_WARNING if DocumentSource.FILE in valid_sources else ""
+        )
+
+        msg_1_sources = _sample_document_sources(
+            valid_sources=valid_sources, num_sample=2
+        )
+        msg_1_source_str = " and ".join([s.capitalize() for s in msg_1_sources])
+
+        msg_2_sources = _sample_document_sources(
+            valid_sources=valid_sources, num_sample=2
+        )
+
+        msg_2_real_source = msg_2_sources[0]
+        msg_2_fake_source_str = (
+            msg_2_sources[1].value.capitalize()
+            if len(msg_2_sources) > 1
+            else "Confluence"
+        )
+
+        messages = [
+            {
+                "role": "system",
+                "content": SOURCE_FILTER_PROMPT.format(
+                    valid_sources=[s.value for s in valid_sources],
+                    web_source_warning=web_warning,
+                    file_source_warning=file_warning,
+                    sample_response=json.dumps(sample_json),
+                ),
+            },
+            {
+                "role": "user",
+                "content": f"What documents in {msg_1_source_str} cover engineer onboarding",
+            },
+            {
+                "role": "assistant",
+                "content": json.dumps({SOURCES_KEY: msg_1_sources}),
+            },
+            {"role": "user", "content": "What's the latest on project Corgies?"},
+            {
+                "role": "assistant",
+                "content": json.dumps({SOURCES_KEY: None}),
+            },
+            {
+                "role": "user",
+                "content": f"What information from {msg_2_real_source.value.capitalize()} "
+                f"mentions {msg_2_fake_source_str}?",
+            },
+            {
+                "role": "assistant",
+                "content": json.dumps({SOURCES_KEY: [msg_2_real_source]}),
+            },
+            {
+                "role": "user",
+                "content": "What page from Danswer contains debugging instruction on segfault",
+            },
+            {
+                "role": "assistant",
+                "content": json.dumps({SOURCES_KEY: None}),
+            },
+            {"role": "user", "content": query},
+        ]
+
+        if show_samples:
+            return messages
+
+        # Only system prompt and latest user query
+        return [messages[0], messages[-1]]
+
+    def _extract_source_filters_from_llm_out(
+        model_out: str,
+    ) -> list[DocumentSource] | None:
+        try:
+            sources_dict = extract_embedded_json(model_out)
+            sources_list = sources_dict.get(SOURCES_KEY)
+            if not sources_list:
+                return None
+
+            return strings_to_document_sources(sources_list)
+        except ValueError:
+            logger.warning("LLM failed to provide a valid Source Filter output")
+            return None
+
+    valid_sources = fetch_unique_document_sources(db_session)
+    if not valid_sources:
+        return None
+
+    messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+    logger.debug(model_output)
+
+    return _extract_source_filters_from_llm_out(model_output)
+
+
+def extract_question_source_filters(
+    question: QuestionRequest,
+    db_session: Session,
+    disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
+) -> list[DocumentSource] | None:
+    # If specified in the question, don't update
+    if question.filters.source_type:
+        return question.filters.source_type
+
+    if not question.enable_auto_detect_filters or disable_llm_extraction:
+        return None
+
+    return extract_source_filter(question.query, db_session)
+
+
+if __name__ == "__main__":
+    # Just for testing purposes
+    with Session(get_sqlalchemy_engine()) as db_session:
+        while True:
+            user_input = input("Query to Extract Sources: ")
+            sources = extract_source_filter(user_input, db_session)
+            print(sources)
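A short sketch of the precedence encoded in extract_question_source_filters (hypothetical request objects, not a real test): an explicit filter on the request wins, disabling auto-detection short-circuits to None, and only otherwise does the LLM flow run:

# Illustrative only
question.filters.source_type = [DocumentSource.SLACK]
extract_question_source_filters(question, db_session)  # -> [DocumentSource.SLACK], no LLM call

question.filters.source_type = None
question.enable_auto_detect_filters = False
extract_question_source_filters(question, db_session)  # -> None, extraction skipped

question.enable_auto_detect_filters = True
extract_question_source_filters(question, db_session)  # -> whatever the LLM infers from the query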
@@ -5,9 +5,10 @@ from datetime import timezone

 from dateutil.parser import parse

-from danswer.configs.app_configs import DISABLE_TIME_FILTER_EXTRACTION
+from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.prompts.prompt_utils import get_current_llm_day_time
 from danswer.prompts.secondary_llm_flows import TIME_FILTER_PROMPT
 from danswer.server.models import QuestionRequest
 from danswer.utils.logger import setup_logger
@@ -51,7 +52,9 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
     messages = [
         {
             "role": "system",
-            "content": TIME_FILTER_PROMPT,
+            "content": TIME_FILTER_PROMPT.format(
+                current_day_time_str=get_current_llm_day_time()
+            ),
         },
         {
             "role": "user",
@@ -152,7 +155,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:

 def extract_question_time_filters(
     question: QuestionRequest,
-    disable_llm_extraction: bool = DISABLE_TIME_FILTER_EXTRACTION,
+    disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
 ) -> tuple[datetime | None, bool]:
     time_cutoff = question.filters.time_cutoff
     favor_recent = question.favor_recent
@@ -290,6 +290,7 @@ class SearchResponse(BaseModel):
     top_ranked_docs: list[SearchDoc] | None
     lower_ranked_docs: list[SearchDoc] | None
     query_event_id: int
+    source_type: list[DocumentSource] | None
    time_cutoff: datetime | None
     favor_recent: bool
@@ -1,3 +1,5 @@
+from collections.abc import Callable
+
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
@@ -20,9 +22,10 @@ from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.models import IndexFilters
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search
-from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
 from danswer.secondary_llm_flows.query_validation import get_query_answerability
 from danswer.secondary_llm_flows.query_validation import stream_query_answerability
+from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
+from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
 from danswer.server.models import HelperResponse
 from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
@@ -32,6 +35,7 @@ from danswer.server.models import SearchDoc
 from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
 from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import run_functions_in_parallel

 logger = setup_logger()
@@ -125,9 +129,19 @@ def handle_search_request(
     query = question.query
     logger.info(f"Received {question.search_type.value} " f"search query: {query}")

-    time_cutoff, favor_recent = extract_question_time_filters(question)
+    functions_to_run: dict[Callable, tuple] = {
+        extract_question_time_filters: (question,),
+        extract_question_source_filters: (question, db_session),
+    }
+
+    parallel_results = run_functions_in_parallel(functions_to_run)
+
+    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+    source_filters = parallel_results["extract_question_source_filters"]
+
     question.filters.time_cutoff = time_cutoff
     question.favor_recent = favor_recent
+    question.filters.source_type = source_filters

     ranked_chunks, unranked_chunks, query_event_id = danswer_search(
         question=question,
@@ -141,6 +155,7 @@ def handle_search_request(
             top_ranked_docs=None,
             lower_ranked_docs=None,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
         )
@@ -152,6 +167,7 @@ def handle_search_request(
         top_ranked_docs=top_docs,
         lower_ranked_docs=lower_top_docs or None,
         query_event_id=query_event_id,
+        source_type=source_filters,
         time_cutoff=time_cutoff,
         favor_recent=favor_recent,
     )
backend/danswer/utils/threadpool_concurrency.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from collections.abc import Callable
+from concurrent.futures import as_completed
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def run_functions_in_parallel(
+    functions_with_args: dict[Callable, tuple]
+) -> dict[str, Any]:
+    """
+    Executes multiple functions in parallel and returns a dictionary with the results.
+
+    Args:
+        functions_with_args (dict): A dictionary mapping functions to a tuple of arguments.
+
+    Returns:
+        dict: A dictionary mapping function names to their results or error messages.
+    """
+    results = {}
+    with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor:
+        future_to_function = {
+            executor.submit(func, *args): func.__name__
+            for func, args in functions_with_args.items()
+        }
+
+        for future in as_completed(future_to_function):
+            function_name = future_to_function[future]
+            try:
+                results[function_name] = future.result()
+            except Exception as e:
+                logger.exception(f"Function {function_name} failed due to {e}")
+                raise
+
+    return results
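A minimal usage sketch (hypothetical callables). Results are keyed by func.__name__, which is why the QA and search handlers above index parallel_results by the function's name, and why two distinct functions sharing a name would collide:

def double(x: int) -> int:
    return x * 2

def add(a: int, b: int) -> int:
    return a + b

out = run_functions_in_parallel({double: (21,), add: (1, 2)})
assert out == {"double": 42, "add": 3}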
@@ -2,6 +2,7 @@ import time
 from collections.abc import Callable
 from collections.abc import Generator
 from collections.abc import Iterator
+from functools import wraps
 from typing import Any
 from typing import cast
 from typing import TypeVar
@@ -14,53 +15,38 @@ F = TypeVar("F", bound=Callable)
 FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])


-def log_function_time(
-    func_name: str | None = None,
-) -> Callable[[F], F]:
-    """Build a timing wrapper for a function. Logs how long the function took to run.
-    Use like:
-
-    @log_function_time()
-    def my_func():
-        ...
-    """
-
-    def timing_wrapper(func: F) -> F:
+def log_function_time(func_name: str | None = None) -> Callable[[F], F]:
+    def decorator(func: F) -> F:
+        @wraps(func)
         def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             start_time = time.time()
             result = func(*args, **kwargs)
-            logger.info(
-                f"{func_name or func.__name__} took {time.time() - start_time} seconds"
-            )
+            elapsed_time = time.time() - start_time
+            logger.info(f"{func_name or func.__name__} took {elapsed_time} seconds")
             return result

         return cast(F, wrapped_func)

-    return timing_wrapper
+    return decorator


-def log_generator_function_time(
-    func_name: str | None = None,
-) -> Callable[[FG], FG]:
-    """Build a timing wrapper for a function which returns a generator.
-    Logs how long the function took to run.
-    Use like:
-
-    @log_generator_function_time()
-    def my_func():
-        ...
-        yield X
-        ...
-    """
-
-    def timing_wrapper(func: FG) -> FG:
+def log_generator_function_time(func_name: str | None = None) -> Callable[[FG], FG]:
+    def decorator(func: FG) -> FG:
+        @wraps(func)
         def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             start_time = time.time()
-            yield from func(*args, **kwargs)
-            logger.info(
-                f"{func_name or func.__name__} took {time.time() - start_time} seconds"
-            )
+            gen = func(*args, **kwargs)
+            try:
+                value = next(gen)
+                while True:
+                    yield value
+                    value = next(gen)
+            except StopIteration:
+                pass
+            finally:
+                elapsed_time = time.time() - start_time
+                logger.info(f"{func_name or func.__name__} took {elapsed_time} seconds")

         return cast(FG, wrapped_func)

-    return timing_wrapper
+    return decorator
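A quick check of the refactored decorator (a sketch; the logged elapsed time will vary per run). The newly added @wraps(func) is what preserves __name__ on the wrapped function, which both the log line and the name-keyed results of run_functions_in_parallel depend on:

@log_function_time()
def fetch_data() -> list[int]:
    return [1, 2, 3]

fetch_data()                # logs e.g.: "fetch_data took 2.1e-05 seconds"
print(fetch_data.__name__)  # -> fetch_data, thanks to functools.wraps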