Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-26 17:51:54 +01:00)
Parent: 0a7775860c
Commit: 51ec2517cb

@@ -9,8 +9,6 @@ from danswer.configs.constants import ModelHostType
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA

@@ -62,12 +60,10 @@ def get_default_qa_handler(model: str) -> QAHandler:

 def get_default_qa_model(
     internal_model: str = INTERNAL_MODEL_VERSION,
-    model_version: str = GEN_AI_MODEL_VERSION,
     endpoint: str | None = GEN_AI_ENDPOINT,
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
-    max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:

@@ -79,16 +75,7 @@ def get_default_qa_model(
     try:
         # un-used arguments will be ignored by the underlying `LLM` class
         # if any args are missing, a `TypeError` will be thrown
-        llm = get_default_llm(
-            model=internal_model,
-            api_key=api_key,
-            model_version=model_version,
-            endpoint=endpoint,
-            model_host_type=model_host_type,
-            timeout=timeout,
-            max_output_tokens=max_output_tokens,
-            **kwargs,
-        )
+        llm = get_default_llm()
         qa_handler = get_default_qa_handler(model=internal_model)

         return QABlock(

backend/danswer/direct_qa/qa_block.py

@@ -3,10 +3,7 @@ from collections.abc import Iterator
 from copy import copy

 import tiktoken
-from langchain.schema.messages import AIMessage
 from langchain.schema.messages import BaseMessage
-from langchain.schema.messages import HumanMessage
-from langchain.schema.messages import SystemMessage

 from danswer.chunking.models import InferenceChunk
 from danswer.direct_qa.interfaces import AnswerQuestionReturn

@@ -19,32 +16,8 @@ from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.llm.llm import LLM
-
-
-def _dict_based_prompt_to_langchain_prompt(
-    messages: list[dict[str, str]]
-) -> list[BaseMessage]:
-    prompt: list[BaseMessage] = []
-    for message in messages:
-        role = message.get("role")
-        content = message.get("content")
-        if not role:
-            raise ValueError(f"Message missing `role`: {message}")
-        if not content:
-            raise ValueError(f"Message missing `content`: {message}")
-        elif role == "user":
-            prompt.append(HumanMessage(content=content))
-        elif role == "system":
-            prompt.append(SystemMessage(content=content))
-        elif role == "assistant":
-            prompt.append(AIMessage(content=content))
-        else:
-            raise ValueError(f"Unknown role: {role}")
-    return prompt
-
-
-def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
-    return [HumanMessage(content=message)]
+from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.llm.utils import str_prompt_to_langchain_prompt


 class QAHandler(abc.ABC):

@@ -69,7 +42,7 @@ class JsonChatQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        return _dict_based_prompt_to_langchain_prompt(
+        return dict_based_prompt_to_langchain_prompt(
             JsonChatProcessor.fill_prompt(
                 question=query, chunks=context_chunks, include_metadata=False
             )

@@ -91,7 +64,7 @@ class SimpleChatQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        return _str_prompt_to_langchain_prompt(
+        return str_prompt_to_langchain_prompt(
             WeakModelFreeformProcessor.fill_prompt(
                 question=query,
                 chunks=context_chunks,

backend/danswer/llm/build.py

@@ -1,16 +1,36 @@
 from typing import Any

+from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.model_configs import API_TYPE_OPENAI
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.llm.azure import AzureGPT
 from danswer.llm.llm import LLM
 from danswer.llm.openai import OpenAIGPT


-def get_default_llm(model: str, **kwargs: Any) -> LLM:
+def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
     if model == DanswerGenAIModel.OPENAI_CHAT.value:
         if API_TYPE_OPENAI == "azure":
             return AzureGPT(**kwargs)
         return OpenAIGPT(**kwargs)

     raise ValueError(f"Unknown LLM model: {model}")
+
+
+def get_default_llm(**kwargs: Any) -> LLM:
+    return get_llm_from_model(
+        model=INTERNAL_MODEL_VERSION,
+        api_key=GEN_AI_API_KEY,
+        model_version=GEN_AI_MODEL_VERSION,
+        endpoint=GEN_AI_ENDPOINT,
+        model_host_type=GEN_AI_HOST_TYPE,
+        timeout=QA_TIMEOUT,
+        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
+        **kwargs,
+    )
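
Usage sketch for the refactored factory: with the zero-argument get_default_llm, call sites build a prompt and invoke directly. This is not part of the diff; the query text is illustrative, and the pattern mirrors the new query_validation.py below.

from danswer.llm.build import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt

# Every connection setting (model, key, endpoint, timeout) now comes from
# danswer.configs defaults instead of being threaded through by the caller.
prompt = dict_based_prompt_to_langchain_prompt(
    [{"role": "user", "content": "What is Danswer?"}]
)
raw_output = get_default_llm().invoke(prompt)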

backend/danswer/llm/utils.py

@@ -2,15 +2,43 @@ from collections.abc import Iterator

 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import (
-    PromptValue,
-)
+from langchain.schema import PromptValue
 from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.messages import AIMessage
+from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import BaseMessageChunk
+from langchain.schema.messages import HumanMessage
+from langchain.schema.messages import SystemMessage

 from danswer.configs.app_configs import LOG_LEVEL


+def dict_based_prompt_to_langchain_prompt(
+    messages: list[dict[str, str]]
+) -> list[BaseMessage]:
+    prompt: list[BaseMessage] = []
+    for message in messages:
+        role = message.get("role")
+        content = message.get("content")
+        if not role:
+            raise ValueError(f"Message missing `role`: {message}")
+        if not content:
+            raise ValueError(f"Message missing `content`: {message}")
+        elif role == "user":
+            prompt.append(HumanMessage(content=content))
+        elif role == "system":
+            prompt.append(SystemMessage(content=content))
+        elif role == "assistant":
+            prompt.append(AIMessage(content=content))
+        else:
+            raise ValueError(f"Unknown role: {role}")
+    return prompt
+
+
+def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
+    return [HumanMessage(content=message)]
+
+
 def message_generator_to_string_generator(
     messages: Iterator[BaseMessageChunk],
 ) -> Iterator[str]:
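
A quick sketch of the relocated helpers (reviewer note, not in the diff). Message contents are illustrative; only the three roles handled above are accepted.

from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import str_prompt_to_langchain_prompt

messages = [
    {"role": "system", "content": "You answer questions about internal docs."},
    {"role": "user", "content": "Where is the deploy runbook?"},
]
dict_based_prompt_to_langchain_prompt(messages)
# -> [SystemMessage(...), HumanMessage(...)]; an unrecognized role raises ValueError

str_prompt_to_langchain_prompt("Where is the deploy runbook?")
# -> [HumanMessage(...)]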

@@ -18,21 +46,21 @@ def message_generator_to_string_generator(
     yield message.content


-def convert_input(input: LanguageModelInput) -> str:
+def convert_input(lm_input: LanguageModelInput) -> str:
     """Heavily inspired by:
     https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
     """
     prompt_value = None
-    if isinstance(input, PromptValue):
-        prompt_value = input
-    elif isinstance(input, str):
-        prompt_value = StringPromptValue(text=input)
-    elif isinstance(input, list):
-        prompt_value = ChatPromptValue(messages=input)
+    if isinstance(lm_input, PromptValue):
+        prompt_value = lm_input
+    elif isinstance(lm_input, str):
+        prompt_value = StringPromptValue(text=lm_input)
+    elif isinstance(lm_input, list):
+        prompt_value = ChatPromptValue(messages=lm_input)

     if prompt_value is None:
         raise ValueError(
-            f"Invalid input type {type(input)}. "
+            f"Invalid input type {type(lm_input)}. "
             "Must be a PromptValue, str, or list of BaseMessages."
         )
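
For reference, the three input shapes convert_input accepts after the parameter rename (sketch, not in the diff):

from langchain.prompts.base import StringPromptValue
from langchain.schema.messages import HumanMessage

from danswer.llm.utils import convert_input

convert_input("plain string")                              # str -> StringPromptValue
convert_input([HumanMessage(content="hi")])                # list -> ChatPromptValue
convert_input(StringPromptValue(text="already wrapped"))   # PromptValue passes through
# any other type raises the ValueError above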

backend/danswer/secondary_llm_flows/__init__.py (new, empty file)

backend/danswer/secondary_llm_flows/query_validation.py (new file)
@@ -0,0 +1,107 @@
+import re
+from collections.abc import Iterator
+from dataclasses import asdict
+
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.llm.build import get_default_llm
+from danswer.server.models import QueryValidationResponse
+from danswer.server.utils import get_json_line
+
+REASONING_PAT = "REASONING: "
+ANSWERABLE_PAT = "ANSWERABLE: "
+COT_PAT = "\nLet's think step by step"
+
+
+def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
+    messages = [
+        {
+            "role": "system",
+            "content": f"You are a helper tool to determine if a query is answerable using retrieval augmented "
+            f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
+            f"documents found from search. Sources contain both up to date and proprietary information for "
+            f"the specific team. For named or unknown entities, assume the search will always find "
+            f"consistent knowledge about the entity. Determine if that system should attempt to answer. "
+            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
+        },
+        {"role": "user", "content": "What is this Slack channel about?"},
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to. "
+            f"By fetching 5 documents related to Slack channel contents, it is not possible to determine "
+            f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
+        },
+        {
+            "role": "user",
+            "content": f"Danswer is unreachable.{COT_PAT}",
+        },
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}The system searches documents related to Danswer being "
+            f"unreachable. Assuming the documents from search contain situations where Danswer is not "
+            f"reachable and a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
+        },
+        {"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}Assuming the searched documents contain customer acquisition information "
+            f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
+        },
+        {"role": "user", "content": user_query + COT_PAT},
+    ]
+
+    return messages
+
+
+def extract_answerability_reasoning(model_raw: str) -> str:
+    reasoning_match = re.search(
+        f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL
+    )
+    reasoning_text = reasoning_match.group(1).strip() if reasoning_match else ""
+    return reasoning_text
+
+
+def extract_answerability_bool(model_raw: str) -> bool:
+    answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw)
+    answerable_text = answerable_match.group(1).strip() if answerable_match else ""
+    answerable = True if answerable_text.strip().lower() in ["true", "yes"] else False
+    return answerable
+
+
+def get_query_answerability(user_query: str) -> tuple[str, bool]:
+    messages = get_query_validation_messages(user_query)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+
+    reasoning = extract_answerability_reasoning(model_output)
+    answerable = extract_answerability_bool(model_output)
+
+    return reasoning, answerable
+
+
+def stream_query_answerability(user_query: str) -> Iterator[str]:
+    messages = get_query_validation_messages(user_query)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    tokens = get_default_llm().stream(filled_llm_prompt)
+    reasoning_pat_found = False
+    model_output = ""
+    for token in tokens:
+        model_output = model_output + token
+
+        if not reasoning_pat_found and REASONING_PAT in model_output:
+            reasoning_pat_found = True
+            remaining = model_output[len(REASONING_PAT) :]
+            if remaining:
+                yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
+            continue
+
+        if reasoning_pat_found:
+            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
+
+    reasoning = extract_answerability_reasoning(model_output)
+    answerable = extract_answerability_bool(model_output)
+
+    yield get_json_line(
+        QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
+    )
+    return
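
Usage sketch for the new module, assuming a configured default LLM; the query and the raw completion below are illustrative, not real model output.

from danswer.secondary_llm_flows.query_validation import extract_answerability_bool
from danswer.secondary_llm_flows.query_validation import extract_answerability_reasoning
from danswer.secondary_llm_flows.query_validation import get_query_answerability

reasoning, answerable = get_query_answerability("Why is Danswer unreachable?")
if not answerable:
    print(f"Query rejected: {reasoning}")

# The extraction helpers can also be exercised directly:
raw = "REASONING: The docs likely cover this.\nANSWERABLE: True"
extract_answerability_reasoning(raw)  # -> "The docs likely cover this."
extract_answerability_bool(raw)       # -> True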

backend/danswer/server/models.py

@@ -151,6 +151,11 @@ class SearchFeedbackRequest(BaseModel):
     search_feedback: SearchFeedbackType


+class QueryValidationResponse(BaseModel):
+    reasoning: str
+    answerable: bool
+
+
 class SearchResponse(BaseModel):
     # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
     top_ranked_docs: list[SearchDoc] | None
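
The new response model serializes as sketched below; .dict() is the pydantic v1 call the streaming flow above uses.

from danswer.server.models import QueryValidationResponse

resp = QueryValidationResponse(reasoning="Docs likely cover this.", answerable=True)
resp.dict()  # -> {'reasoning': 'Docs likely cover this.', 'answerable': True}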

@@ -1,4 +1,3 @@
-import json
 from collections.abc import Generator
 from dataclasses import asdict

@@ -29,12 +28,16 @@ from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.query_validation import get_query_answerability
+from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.server.models import HelperResponse
 from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
+from danswer.server.models import QueryValidationResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
+from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time

@@ -43,10 +46,6 @@ logger = setup_logger()
 router = APIRouter()


-def get_json_line(json_dict: dict) -> str:
-    return json.dumps(json_dict) + "\n"
-
-
 @router.get("/search-intent")
 def get_search_type(
     question: QuestionRequest = Depends(), _: User = Depends(current_user)

@@ -56,6 +55,25 @@ def get_search_type(
     return recommend_search_flow(query, use_keyword)


+@router.get("/query-validation")
+def query_validation(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> QueryValidationResponse:
+    query = question.query
+    reasoning, answerable = get_query_answerability(query)
+    return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
+
+
+@router.get("/stream-query-validation")
+def stream_query_validation(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> StreamingResponse:
+    query = question.query
+    return StreamingResponse(
+        stream_query_answerability(query), media_type="application/json"
+    )
+
+
 @router.post("/semantic-search")
 def semantic_search(
     question: QuestionRequest,
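
A hypothetical client for the new streaming endpoint. The base URL and route prefix are assumptions, and requests stands in for any HTTP client; each line of the response body is one JSON object.

import json

import requests

with requests.get(
    "http://localhost:8080/stream-query-validation",  # base URL is an assumption
    params={"query": "Why is Danswer unreachable?"},
    stream=True,
) as response:
    for line in response.iter_lines():
        if line:
            # DanswerAnswerPiece chunks stream first, then one final
            # QueryValidationResponse object with reasoning + answerable.
            print(json.loads(line))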

backend/danswer/server/utils.py

@@ -1,6 +1,11 @@
+import json
 from typing import Any


+def get_json_line(json_dict: dict) -> str:
+    return json.dumps(json_dict) + "\n"
+
+
 def mask_string(sensitive_str: str) -> str:
     return "****...**" + sensitive_str[-4:]