Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-18 07:40:05 +02:00)

parent 0a7775860c
commit 51ec2517cb
--- a/backend/danswer/direct_qa/__init__.py
+++ b/backend/danswer/direct_qa/__init__.py
@@ -9,8 +9,6 @@ from danswer.configs.constants import ModelHostType
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@@ -62,12 +60,10 @@ def get_default_qa_handler(model: str) -> QAHandler:
 
 def get_default_qa_model(
     internal_model: str = INTERNAL_MODEL_VERSION,
-    model_version: str = GEN_AI_MODEL_VERSION,
     endpoint: str | None = GEN_AI_ENDPOINT,
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
-    max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:
@@ -79,16 +75,7 @@ def get_default_qa_model(
     try:
         # un-used arguments will be ignored by the underlying `LLM` class
         # if any args are missing, a `TypeError` will be thrown
-        llm = get_default_llm(
-            model=internal_model,
-            api_key=api_key,
-            model_version=model_version,
-            endpoint=endpoint,
-            model_host_type=model_host_type,
-            timeout=timeout,
-            max_output_tokens=max_output_tokens,
-            **kwargs,
-        )
+        llm = get_default_llm()
         qa_handler = get_default_qa_handler(model=internal_model)
 
         return QABlock(
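Note: with this change, get_default_qa_model builds its LLM with a bare get_default_llm() call; the model, key, and endpoint settings are resolved inside danswer.llm.build (diff further below) instead of being threaded through. A minimal sketch of the resulting call shape, assuming the defaults shown above:

    # api_key is still validated above, and the remaining keyword arguments
    # presumably still feed the non-QABlock model branches
    qa_model = get_default_qa_model(internal_model=INTERNAL_MODEL_VERSION)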
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -3,10 +3,7 @@ from collections.abc import Iterator
 from copy import copy
 
 import tiktoken
-from langchain.schema.messages import AIMessage
 from langchain.schema.messages import BaseMessage
-from langchain.schema.messages import HumanMessage
-from langchain.schema.messages import SystemMessage
 
 from danswer.chunking.models import InferenceChunk
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
@@ -19,32 +16,8 @@ from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.llm.llm import LLM
-
-
-def _dict_based_prompt_to_langchain_prompt(
-    messages: list[dict[str, str]]
-) -> list[BaseMessage]:
-    prompt: list[BaseMessage] = []
-    for message in messages:
-        role = message.get("role")
-        content = message.get("content")
-        if not role:
-            raise ValueError(f"Message missing `role`: {message}")
-        if not content:
-            raise ValueError(f"Message missing `content`: {message}")
-        elif role == "user":
-            prompt.append(HumanMessage(content=content))
-        elif role == "system":
-            prompt.append(SystemMessage(content=content))
-        elif role == "assistant":
-            prompt.append(AIMessage(content=content))
-        else:
-            raise ValueError(f"Unknown role: {role}")
-    return prompt
-
-
-def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
-    return [HumanMessage(content=message)]
+from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.llm.utils import str_prompt_to_langchain_prompt
 
 
 class QAHandler(abc.ABC):
@@ -69,7 +42,7 @@ class JsonChatQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        return _dict_based_prompt_to_langchain_prompt(
+        return dict_based_prompt_to_langchain_prompt(
             JsonChatProcessor.fill_prompt(
                 question=query, chunks=context_chunks, include_metadata=False
             )
@@ -91,7 +64,7 @@ class SimpleChatQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
-        return _str_prompt_to_langchain_prompt(
+        return str_prompt_to_langchain_prompt(
             WeakModelFreeformProcessor.fill_prompt(
                 question=query,
                 chunks=context_chunks,
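Note: the two underscore-prefixed helpers were promoted to shared utilities in danswer.llm.utils (diff further below) so that modules outside qa_block can reuse them. A small sketch of what the dict-based converter produces, using the same langchain message types as the diff:

    from langchain.schema.messages import HumanMessage
    from langchain.schema.messages import SystemMessage

    from danswer.llm.utils import dict_based_prompt_to_langchain_prompt

    prompt = dict_based_prompt_to_langchain_prompt(
        [
            {"role": "system", "content": "You answer questions about Danswer."},
            {"role": "user", "content": "What is Danswer?"},
        ]
    )
    # -> [SystemMessage(...), HumanMessage(...)]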
--- a/backend/danswer/llm/build.py
+++ b/backend/danswer/llm/build.py
@@ -1,16 +1,36 @@
 from typing import Any
 
+from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.model_configs import API_TYPE_OPENAI
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.llm.azure import AzureGPT
 from danswer.llm.llm import LLM
 from danswer.llm.openai import OpenAIGPT
 
 
-def get_default_llm(model: str, **kwargs: Any) -> LLM:
+def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
     if model == DanswerGenAIModel.OPENAI_CHAT.value:
         if API_TYPE_OPENAI == "azure":
             return AzureGPT(**kwargs)
         return OpenAIGPT(**kwargs)
 
     raise ValueError(f"Unknown LLM model: {model}")
+
+
+def get_default_llm(**kwargs: Any) -> LLM:
+    return get_llm_from_model(
+        model=INTERNAL_MODEL_VERSION,
+        api_key=GEN_AI_API_KEY,
+        model_version=GEN_AI_MODEL_VERSION,
+        endpoint=GEN_AI_ENDPOINT,
+        model_host_type=GEN_AI_HOST_TYPE,
+        timeout=QA_TIMEOUT,
+        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
+        **kwargs,
+    )
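Note: the split leaves get_llm_from_model as the explicit dispatch point while get_default_llm becomes the single place that reads the GEN_AI_* settings, which is what lets call sites like get_default_qa_model above shrink to a no-argument call. A hedged usage sketch (the key value is a placeholder, and the underlying LLM class decides which kwargs it actually requires):

    from danswer.configs.constants import DanswerGenAIModel
    from danswer.llm.build import get_default_llm
    from danswer.llm.build import get_llm_from_model

    # default path: model choice and credentials come from danswer.configs
    llm = get_default_llm()

    # explicit path: pick the model and supply the kwargs yourself
    llm = get_llm_from_model(
        model=DanswerGenAIModel.OPENAI_CHAT.value,
        api_key="<openai-api-key>",
    )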
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -2,15 +2,43 @@ from collections.abc import Iterator
 
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import (
-    PromptValue,
-)
+from langchain.schema import PromptValue
 from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.messages import AIMessage
+from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import BaseMessageChunk
+from langchain.schema.messages import HumanMessage
+from langchain.schema.messages import SystemMessage
 
 from danswer.configs.app_configs import LOG_LEVEL
 
 
+def dict_based_prompt_to_langchain_prompt(
+    messages: list[dict[str, str]]
+) -> list[BaseMessage]:
+    prompt: list[BaseMessage] = []
+    for message in messages:
+        role = message.get("role")
+        content = message.get("content")
+        if not role:
+            raise ValueError(f"Message missing `role`: {message}")
+        if not content:
+            raise ValueError(f"Message missing `content`: {message}")
+        elif role == "user":
+            prompt.append(HumanMessage(content=content))
+        elif role == "system":
+            prompt.append(SystemMessage(content=content))
+        elif role == "assistant":
+            prompt.append(AIMessage(content=content))
+        else:
+            raise ValueError(f"Unknown role: {role}")
+    return prompt
+
+
+def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
+    return [HumanMessage(content=message)]
+
+
 def message_generator_to_string_generator(
     messages: Iterator[BaseMessageChunk],
 ) -> Iterator[str]:
@@ -18,21 +46,21 @@ def message_generator_to_string_generator(
         yield message.content
 
 
-def convert_input(input: LanguageModelInput) -> str:
+def convert_input(lm_input: LanguageModelInput) -> str:
     """Heavily inspired by:
     https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
     """
     prompt_value = None
-    if isinstance(input, PromptValue):
-        prompt_value = input
-    elif isinstance(input, str):
-        prompt_value = StringPromptValue(text=input)
-    elif isinstance(input, list):
-        prompt_value = ChatPromptValue(messages=input)
+    if isinstance(lm_input, PromptValue):
+        prompt_value = lm_input
+    elif isinstance(lm_input, str):
+        prompt_value = StringPromptValue(text=lm_input)
+    elif isinstance(lm_input, list):
+        prompt_value = ChatPromptValue(messages=lm_input)
 
     if prompt_value is None:
         raise ValueError(
-            f"Invalid input type {type(input)}. "
+            f"Invalid input type {type(lm_input)}. "
             "Must be a PromptValue, str, or list of BaseMessages."
         )
 
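Note: renaming input to lm_input stops the parameter from shadowing Python's builtin input(); behavior is unchanged. For reference, the shapes convert_input accepts:

    from langchain.schema.messages import HumanMessage

    from danswer.llm.utils import convert_input

    convert_input("plain question")                    # str
    convert_input([HumanMessage(content="question")])  # list[BaseMessage]
    # a PromptValue also passes through; anything else raises:
    convert_input(42)  # ValueError: Invalid input type <class 'int'>. ...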
--- /dev/null
+++ b/backend/danswer/secondary_llm_flows/__init__.py
(new, empty file)

--- /dev/null
+++ b/backend/danswer/secondary_llm_flows/query_validation.py
(new file, 107 lines)
@@ -0,0 +1,107 @@
+import re
+from collections.abc import Iterator
+from dataclasses import asdict
+
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.llm.build import get_default_llm
+from danswer.server.models import QueryValidationResponse
+from danswer.server.utils import get_json_line
+
+REASONING_PAT = "REASONING: "
+ANSWERABLE_PAT = "ANSWERABLE: "
+COT_PAT = "\nLet's think step by step"
+
+
+def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
+    messages = [
+        {
+            "role": "system",
+            "content": f"You are a helper tool to determine if a query is answerable using retrieval augmented "
+            f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
+            f"documents found from search. Sources contain both up to date and proprietary information for "
+            f"the specific team. For named or unknown entities, assume the search will always find "
+            f"consistent knowledge about the entity. Determine if that system should attempt to answer. "
+            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
+        },
+        {"role": "user", "content": "What is this Slack channel about?"},
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to. "
+            f"By fetching 5 documents related to Slack channel contents, it is not possible to determine "
+            f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
+        },
+        {
+            "role": "user",
+            "content": f"Danswer is unreachable.{COT_PAT}",
+        },
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}The system searches documents related to Danswer being "
+            f"unreachable. Assuming the documents from search contain situations where Danswer is not "
+            f"reachable and include a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
+        },
+        {"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}Assuming the searched documents contain customer acquisition information "
+            f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
+        },
+        {"role": "user", "content": user_query + COT_PAT},
+    ]
+
+    return messages
+
+
+def extract_answerability_reasoning(model_raw: str) -> str:
+    reasoning_match = re.search(
+        f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL
+    )
+    reasoning_text = reasoning_match.group(1).strip() if reasoning_match else ""
+    return reasoning_text
+
+
+def extract_answerability_bool(model_raw: str) -> bool:
+    answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw)
+    answerable_text = answerable_match.group(1).strip() if answerable_match else ""
+    answerable = True if answerable_text.strip().lower() in ["true", "yes"] else False
+    return answerable
+
+
+def get_query_answerability(user_query: str) -> tuple[str, bool]:
+    messages = get_query_validation_messages(user_query)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+
+    reasoning = extract_answerability_reasoning(model_output)
+    answerable = extract_answerability_bool(model_output)
+
+    return reasoning, answerable
+
+
+def stream_query_answerability(user_query: str) -> Iterator[str]:
+    messages = get_query_validation_messages(user_query)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    tokens = get_default_llm().stream(filled_llm_prompt)
+    reasoning_pat_found = False
+    model_output = ""
+    for token in tokens:
+        model_output = model_output + token
+
+        if not reasoning_pat_found and REASONING_PAT in model_output:
+            reasoning_pat_found = True
+            remaining = model_output[len(REASONING_PAT) :]
+            if remaining:
+                yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
+            continue
+
+        if reasoning_pat_found:
+            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
+
+    reasoning = extract_answerability_reasoning(model_output)
+    answerable = extract_answerability_bool(model_output)
+
+    yield get_json_line(
+        QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
+    )
+    return
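Note: the extract_* helpers depend only on the REASONING:/ANSWERABLE: markers, so they can be exercised without a model call. A quick illustration:

    sample = (
        "REASONING: Assuming the indexed documents cover the topic, "
        "the query can be answered.\n"
        "ANSWERABLE: True"
    )
    extract_answerability_reasoning(sample)
    # -> 'Assuming the indexed documents cover the topic, the query can be answered.'
    extract_answerability_bool(sample)  # -> True
    extract_answerability_bool("ANSWERABLE: No")  # -> False ("no" is not in ["true", "yes"])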
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -151,6 +151,11 @@ class SearchFeedbackRequest(BaseModel):
     search_feedback: SearchFeedbackType
 
 
+class QueryValidationResponse(BaseModel):
+    reasoning: str
+    answerable: bool
+
+
 class SearchResponse(BaseModel):
     # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
     top_ranked_docs: list[SearchDoc] | None
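Note: QueryValidationResponse is the response model of the new /query-validation endpoint below and also the final JSON line emitted by the streaming variant, e.g.:

    QueryValidationResponse(
        reasoning="The searched documents likely cover this.", answerable=True
    ).dict()
    # -> {'reasoning': 'The searched documents likely cover this.', 'answerable': True}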
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -1,4 +1,3 @@
-import json
 from collections.abc import Generator
 from dataclasses import asdict
 
@@ -29,12 +28,16 @@ from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.query_validation import get_query_answerability
+from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.server.models import HelperResponse
 from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
+from danswer.server.models import QueryValidationResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
+from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time
 
@@ -43,10 +46,6 @@ logger = setup_logger()
 router = APIRouter()
 
 
-def get_json_line(json_dict: dict) -> str:
-    return json.dumps(json_dict) + "\n"
-
-
 @router.get("/search-intent")
 def get_search_type(
     question: QuestionRequest = Depends(), _: User = Depends(current_user)
@@ -56,6 +55,25 @@ def get_search_type(
     return recommend_search_flow(query, use_keyword)
 
 
+@router.get("/query-validation")
+def query_validation(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> QueryValidationResponse:
+    query = question.query
+    reasoning, answerable = get_query_answerability(query)
+    return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
+
+
+@router.get("/stream-query-validation")
+def stream_query_validation(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> StreamingResponse:
+    query = question.query
+    return StreamingResponse(
+        stream_query_answerability(query), media_type="application/json"
+    )
+
+
 @router.post("/semantic-search")
 def semantic_search(
     question: QuestionRequest,
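Note: both new routes bind QuestionRequest via Depends, so they read the same query parameters as /search-intent and require an authenticated user. A hedged client sketch; the host, port, and session-cookie name are assumptions about a local deployment:

    import requests

    cookies = {"fastapiusersauth": "<session-cookie>"}  # hypothetical auth cookie
    resp = requests.get(
        "http://localhost:8080/query-validation",
        params={"query": "How many customers do we have?"},
        cookies=cookies,
    )
    print(resp.json())  # {"reasoning": "...", "answerable": true}

    # /stream-query-validation instead streams newline-delimited JSON:
    # DanswerAnswerPiece lines while the reasoning is generated,
    # then a single QueryValidationResponse line.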
--- a/backend/danswer/server/utils.py
+++ b/backend/danswer/server/utils.py
@@ -1,6 +1,11 @@
+import json
 from typing import Any
 
 
+def get_json_line(json_dict: dict) -> str:
+    return json.dumps(json_dict) + "\n"
+
+
 def mask_string(sensitive_str: str) -> str:
     return "****...**" + sensitive_str[-4:]
 
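Note: hoisting get_json_line into server/utils.py lets both the search router and the secondary-LLM flows emit newline-delimited JSON without importing from each other (query_validation.py above already imports it from here). Its output, for reference:

    from danswer.server.utils import get_json_line

    get_json_line({"answer_piece": "Hel"})  # -> '{"answer_piece": "Hel"}\n'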