LLM to validate user Query (#365)

Backend Only
This commit is contained in:
Yuhong Sun 2023-08-31 15:33:39 -07:00 committed by GitHub
parent 0a7775860c
commit 51ec2517cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 205 additions and 62 deletions

View File

@ -9,8 +9,6 @@ from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@ -62,12 +60,10 @@ def get_default_qa_handler(model: str) -> QAHandler:
def get_default_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
model_version: str = GEN_AI_MODEL_VERSION,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
**kwargs: Any,
) -> QAModel:
if not api_key:
@ -79,16 +75,7 @@ def get_default_qa_model(
try:
# un-used arguments will be ignored by the underlying `LLM` class
# if any args are missing, a `TypeError` will be thrown
llm = get_default_llm(
model=internal_model,
api_key=api_key,
model_version=model_version,
endpoint=endpoint,
model_host_type=model_host_type,
timeout=timeout,
max_output_tokens=max_output_tokens,
**kwargs,
)
llm = get_default_llm()
qa_handler = get_default_qa_handler(model=internal_model)
return QABlock(

View File

@ -3,10 +3,7 @@ from collections.abc import Iterator
from copy import copy
import tiktoken
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import AnswerQuestionReturn
@ -19,32 +16,8 @@ from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.llm.llm import LLM
def _dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:
prompt: list[BaseMessage] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if not role:
raise ValueError(f"Message missing `role`: {message}")
if not content:
raise ValueError(f"Message missing `content`: {message}")
elif role == "user":
prompt.append(HumanMessage(content=content))
elif role == "system":
prompt.append(SystemMessage(content=content))
elif role == "assistant":
prompt.append(AIMessage(content=content))
else:
raise ValueError(f"Unknown role: {role}")
return prompt
def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import str_prompt_to_langchain_prompt
class QAHandler(abc.ABC):
@ -69,7 +42,7 @@ class JsonChatQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
return _dict_based_prompt_to_langchain_prompt(
return dict_based_prompt_to_langchain_prompt(
JsonChatProcessor.fill_prompt(
question=query, chunks=context_chunks, include_metadata=False
)
@ -91,7 +64,7 @@ class SimpleChatQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
return _str_prompt_to_langchain_prompt(
return str_prompt_to_langchain_prompt(
WeakModelFreeformProcessor.fill_prompt(
question=query,
chunks=context_chunks,

View File

@ -1,16 +1,36 @@
from typing import Any
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.llm.azure import AzureGPT
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT
def get_default_llm(model: str, **kwargs: Any) -> LLM:
def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
    """Instantiate the LLM wrapper matching an internal model identifier.

    Args:
        model: Internal model name; only the value of
            `DanswerGenAIModel.OPENAI_CHAT` is recognised here.
        **kwargs: Forwarded unchanged to the selected LLM class constructor.

    Returns:
        An `AzureGPT` when `API_TYPE_OPENAI` is "azure", otherwise an
        `OpenAIGPT`.

    Raises:
        ValueError: If `model` is not a recognised identifier.
    """
    if model != DanswerGenAIModel.OPENAI_CHAT.value:
        raise ValueError(f"Unknown LLM model: {model}")

    # The same internal model is served by two backends; the configured API
    # type decides which wrapper class gets built.
    llm_cls = AzureGPT if API_TYPE_OPENAI == "azure" else OpenAIGPT
    return llm_cls(**kwargs)
def get_default_llm(**kwargs: Any) -> LLM:
    """Build the LLM described by the application's configuration constants.

    Args:
        **kwargs: Optional overrides / extra arguments. Any key given here
            replaces the configured default of the same name (for example
            `timeout` or `max_output_tokens`); other keys are passed through
            to `get_llm_from_model` as-is.

    Returns:
        The LLM instance produced by `get_llm_from_model`.

    Raises:
        ValueError: Propagated from `get_llm_from_model` when the resolved
            model name is not recognised.
    """
    llm_kwargs: dict[str, Any] = {
        "model": INTERNAL_MODEL_VERSION,
        "api_key": GEN_AI_API_KEY,
        "model_version": GEN_AI_MODEL_VERSION,
        "endpoint": GEN_AI_ENDPOINT,
        "model_host_type": GEN_AI_HOST_TYPE,
        "timeout": QA_TIMEOUT,
        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
    }
    # Merge instead of passing the defaults as explicit keywords next to
    # `**kwargs`: a duplicated key in a call raises
    # "TypeError: got multiple values for keyword argument", which made it
    # impossible for callers to override any configured default.
    llm_kwargs.update(kwargs)
    return get_llm_from_model(**llm_kwargs)

View File

@ -2,15 +2,43 @@ from collections.abc import Iterator
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
PromptValue,
)
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessageChunk
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.configs.app_configs import LOG_LEVEL
def dict_based_prompt_to_langchain_prompt(
    messages: list[dict[str, str]]
) -> list[BaseMessage]:
    """Convert role/content message dicts into langchain message objects.

    Args:
        messages: Dicts that each carry a non-empty "role" and "content".
            Recognised roles are "user", "system" and "assistant".

    Returns:
        One langchain message per input dict, in the same order.

    Raises:
        ValueError: If a message has no (or an empty) role or content, or if
            its role is not one of the recognised values.
    """
    message_cls_by_role = {
        "user": HumanMessage,
        "system": SystemMessage,
        "assistant": AIMessage,
    }

    langchain_messages: list[BaseMessage] = []
    for raw_message in messages:
        role = raw_message.get("role")
        content = raw_message.get("content")

        # Validate presence first so the error names the missing field
        # rather than complaining about an unknown role.
        if not role:
            raise ValueError(f"Message missing `role`: {raw_message}")
        if not content:
            raise ValueError(f"Message missing `content`: {raw_message}")

        message_cls = message_cls_by_role.get(role)
        if message_cls is None:
            raise ValueError(f"Unknown role: {role}")
        langchain_messages.append(message_cls(content=content))

    return langchain_messages
def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
    """Wrap a plain prompt string as a single-element list holding one HumanMessage."""
    human_message = HumanMessage(content=message)
    return [human_message]
def message_generator_to_string_generator(
messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
@ -18,21 +46,21 @@ def message_generator_to_string_generator(
yield message.content
def convert_input(input: LanguageModelInput) -> str:
def convert_input(lm_input: LanguageModelInput) -> str:
"""Heavily inspired by:
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
"""
prompt_value = None
if isinstance(input, PromptValue):
prompt_value = input
elif isinstance(input, str):
prompt_value = StringPromptValue(text=input)
elif isinstance(input, list):
prompt_value = ChatPromptValue(messages=input)
if isinstance(lm_input, PromptValue):
prompt_value = lm_input
elif isinstance(lm_input, str):
prompt_value = StringPromptValue(text=lm_input)
elif isinstance(lm_input, list):
prompt_value = ChatPromptValue(messages=lm_input)
if prompt_value is None:
raise ValueError(
f"Invalid input type {type(input)}. "
f"Invalid input type {type(lm_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)

View File

@ -0,0 +1,107 @@
import re
from collections.abc import Iterator
from dataclasses import asdict
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line
REASONING_PAT = "REASONING: "
ANSWERABLE_PAT = "ANSWERABLE: "
COT_PAT = "\nLet's think step by step"
def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
    """Build the few-shot chat prompt asking whether a query is answerable.

    The prompt is a system instruction, three worked user/assistant examples
    whose answers use the `REASONING_PAT` / `ANSWERABLE_PAT` layout, and
    finally the real query with `COT_PAT` appended.

    Args:
        user_query: The raw query text typed by the user.

    Returns:
        Role/content dicts in conversation order, ready to be converted to
        langchain messages.
    """
    # Adjacent string literals are concatenated verbatim, so every fragment
    # that continues a sentence must end with an explicit space. Several
    # fragments previously lacked it, producing run-together words such as
    # "determinewhich" in the text sent to the model.
    messages = [
        {
            "role": "system",
            "content": "You are a helper tool to determine if a query is answerable using retrieval augmented "
            "generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
            "documents found from search. Sources contain both up to date and proprietary information for "
            "the specific team. For named or unknown entities, assume the search will always find "
            "consistent knowledge about the entity. Determine if that system should attempt to answer. "
            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
        },
        # NOTE(review): unlike the other examples and the real query, this
        # example does not append COT_PAT - confirm whether that is intended.
        {"role": "user", "content": "What is this Slack channel about?"},
        {
            "role": "assistant",
            "content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to. "
            "By fetching 5 documents related to Slack channel contents, it is not possible to determine "
            f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
        },
        {
            "role": "user",
            "content": f"Danswer is unreachable.{COT_PAT}",
        },
        {
            "role": "assistant",
            "content": f"{REASONING_PAT}The system searches documents related to Danswer being "
            "unreachable. Assuming the documents from search contains situations where Danswer is not "
            f"reachable and contains a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
        },
        {"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
        {
            "role": "assistant",
            "content": f"{REASONING_PAT}Assuming the searched documents contains customer acquisition information "
            f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
        },
        {"role": "user", "content": user_query + COT_PAT},
    ]

    return messages
def extract_answerability_reasoning(model_raw: str) -> str:
    """Pull the reasoning text out of the raw model output.

    The reasoning is whatever sits between `REASONING_PAT` and the following
    `ANSWERABLE_PAT`, possibly spanning several lines.

    Args:
        model_raw: Complete text returned by the LLM.

    Returns:
        The whitespace-stripped reasoning, or an empty string when the two
        markers are not both present in that order.
    """
    pattern = f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}"
    # DOTALL lets the non-greedy group cross newlines inside the reasoning.
    found = re.search(pattern, model_raw, re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()
def extract_answerability_bool(model_raw: str) -> bool:
    """Read the answerable verdict from the raw model output.

    Args:
        model_raw: Complete text returned by the LLM.

    Returns:
        True only when the text following `ANSWERABLE_PAT` on the same line
        is "true" or "yes" (case-insensitive, surrounding whitespace
        ignored); False otherwise, including when the marker is absent.
    """
    found = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw)
    if found is None:
        return False
    verdict = found.group(1).strip().lower()
    return verdict in ["true", "yes"]
def get_query_answerability(user_query: str) -> tuple[str, bool]:
    """Ask the default LLM whether a query is answerable, in one blocking call.

    Args:
        user_query: The raw query text typed by the user.

    Returns:
        A `(reasoning, answerable)` pair parsed from the model output.
    """
    prompt = dict_based_prompt_to_langchain_prompt(
        get_query_validation_messages(user_query)
    )
    raw_output = get_default_llm().invoke(prompt)

    return (
        extract_answerability_reasoning(raw_output),
        extract_answerability_bool(raw_output),
    )
def stream_query_answerability(user_query: str) -> Iterator[str]:
    """Stream the LLM's answerability assessment as JSON lines.

    Nothing is emitted until `REASONING_PAT` has appeared in the accumulated
    output. From then on each streamed token is emitted as a
    `DanswerAnswerPiece` JSON line, and once the stream ends a final
    `QueryValidationResponse` JSON line carries the parsed result.

    Args:
        user_query: The raw query text typed by the user.

    Yields:
        Newline-terminated JSON strings produced by `get_json_line`.
    """
    messages = get_query_validation_messages(user_query)
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    tokens = get_default_llm().stream(filled_llm_prompt)

    reasoning_pat_found = False
    model_output = ""
    for token in tokens:
        model_output += token

        if not reasoning_pat_found:
            # Split on the marker rather than slicing off len(REASONING_PAT)
            # characters: the model may emit whitespace or other text before
            # the marker, and a fixed-offset slice would then leak part of
            # the marker or drop the start of the reasoning.
            _, marker, remaining = model_output.partition(REASONING_PAT)
            if marker:
                reasoning_pat_found = True
                if remaining:
                    yield get_json_line(
                        asdict(DanswerAnswerPiece(answer_piece=remaining))
                    )
            # The token that completed the marker was handled above (or the
            # marker has not been seen yet); never emit it a second time.
            continue

        yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))

    # Parse the complete output once so the final line carries the same
    # values the non-streaming path would return.
    reasoning = extract_answerability_reasoning(model_output)
    answerable = extract_answerability_bool(model_output)

    yield get_json_line(
        QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
    )

View File

@ -151,6 +151,11 @@ class SearchFeedbackRequest(BaseModel):
search_feedback: SearchFeedbackType
class QueryValidationResponse(BaseModel):
    """Response body for the query-validation endpoints."""

    # Reasoning text returned alongside the verdict
    reasoning: str
    # Whether the query was judged answerable
    answerable: bool
class SearchResponse(BaseModel):
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
top_ranked_docs: list[SearchDoc] | None

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from dataclasses import asdict
@ -29,12 +28,16 @@ from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.models import HelperResponse
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@ -43,10 +46,6 @@ logger = setup_logger()
router = APIRouter()
def get_json_line(json_dict: dict) -> str:
return json.dumps(json_dict) + "\n"
@router.get("/search-intent")
def get_search_type(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
@ -56,6 +55,25 @@ def get_search_type(
return recommend_search_flow(query, use_keyword)
@router.get("/query-validation")
def query_validation(
    question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> QueryValidationResponse:
    """Return, in a single response, whether the submitted query is answerable.

    The query text is taken from `question.query` and evaluated with
    `get_query_answerability`; the `current_user` dependency is resolved but
    its value is not used.
    """
    model_reasoning, is_answerable = get_query_answerability(question.query)
    return QueryValidationResponse(
        reasoning=model_reasoning, answerable=is_answerable
    )
@router.get("/stream-query-validation")
def stream_query_validation(
    question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> StreamingResponse:
    """Stream the answerability assessment for the submitted query.

    The response body is the generator returned by
    `stream_query_answerability` for `question.query`, served with the
    "application/json" media type.
    """
    answerability_stream = stream_query_answerability(question.query)
    return StreamingResponse(answerability_stream, media_type="application/json")
@router.post("/semantic-search")
def semantic_search(
question: QuestionRequest,

View File

@ -1,6 +1,11 @@
import json
from typing import Any
def get_json_line(json_dict: dict) -> str:
    """Serialize a dict to JSON text terminated by a single newline."""
    serialized = json.dumps(json_dict)
    return "".join((serialized, "\n"))
def mask_string(sensitive_str: str) -> str:
    """Return a display form of a secret: a fixed mask plus its last 4 characters.

    Strings of four characters or fewer are appended in full after the mask.
    """
    visible_tail = sensitive_str[-4:]
    return "".join(("****...**", visible_tail))