mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
add new standard answer test endpoint (#1789)
This commit is contained in:
parent
e90c66c1b6
commit
c7af6a4601
@ -18,7 +18,42 @@ from danswer.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from danswer.db.standard_answer import find_matching_standard_answers
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def oneoff_standard_answers(
|
||||
message: str,
|
||||
slack_bot_categories: list[str],
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
"""
|
||||
Respond to the user message if it matches any configured standard answers.
|
||||
|
||||
Returns a list of matching StandardAnswers if found, otherwise None.
|
||||
"""
|
||||
configured_standard_answers = {
|
||||
standard_answer
|
||||
for category in fetch_standard_answer_categories_by_names(
|
||||
slack_bot_categories, db_session=db_session
|
||||
)
|
||||
for standard_answer in category.standard_answers
|
||||
}
|
||||
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=message,
|
||||
id_in=[answer.id for answer in configured_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
server_standard_answers = [
|
||||
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
|
||||
]
|
||||
return server_standard_answers
|
||||
|
||||
|
||||
def handle_standard_answers(
|
||||
|
@ -140,6 +140,17 @@ def fetch_standard_answer_category(
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_names(
|
||||
standard_answer_category_names: list[str],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.name.in_(standard_answer_category_names)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
|
@ -4,6 +4,16 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
message: str
|
||||
slack_bot_categories: list[str]
|
||||
|
||||
|
||||
class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = []
|
||||
|
||||
|
||||
class DocumentSearchRequest(ChunkContext):
|
||||
|
@ -1,10 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
oneoff_standard_answers,
|
||||
)
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
@ -25,6 +29,8 @@ from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -155,3 +161,21 @@ def get_answer_with_quote(
|
||||
)
|
||||
|
||||
return answer_details
|
||||
|
||||
|
||||
@basic_router.get("/standard-answer")
|
||||
def get_standard_answer(
|
||||
request: StandardAnswerRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> StandardAnswerResponse:
|
||||
try:
|
||||
standard_answers = oneoff_standard_answers(
|
||||
message=request.message,
|
||||
slack_bot_categories=request.slack_bot_categories,
|
||||
db_session=db_session,
|
||||
)
|
||||
return StandardAnswerResponse(standard_answers=standard_answers)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_standard_answer: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="An internal server error occurred")
|
||||
|
0
backend/query_backend.py
Normal file
0
backend/query_backend.py
Normal file
Loading…
x
Reference in New Issue
Block a user