add new standard answer test endpoint (#1789)

This commit is contained in:
pablodanswer 2024-07-12 10:06:30 -07:00 committed by GitHub
parent e90c66c1b6
commit c7af6a4601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 80 additions and 0 deletions

View File

@ -18,7 +18,42 @@ from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import setup_logger
logger = setup_logger()
def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
"""
Respond to the user message if it matches any configured standard answers.
Returns a list of matching StandardAnswers if found, otherwise None.
"""
configured_standard_answers = {
standard_answer
for category in fetch_standard_answer_categories_by_names(
slack_bot_categories, db_session=db_session
)
for standard_answer in category.standard_answers
}
matching_standard_answers = find_matching_standard_answers(
query=message,
id_in=[answer.id for answer in configured_standard_answers],
db_session=db_session,
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
]
return server_standard_answers
def handle_standard_answers(

View File

@ -140,6 +140,17 @@ def fetch_standard_answer_category(
)
def fetch_standard_answer_categories_by_names(
standard_answer_category_names: list[str],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.name.in_(standard_answer_category_names)
)
).all()
def fetch_standard_answer_categories_by_ids(
standard_answer_category_ids: list[int],
db_session: Session,

View File

@ -4,6 +4,16 @@ from danswer.configs.constants import DocumentSource
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
from danswer.server.manage.models import StandardAnswer
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = []
class DocumentSearchRequest(ChunkContext):

View File

@ -1,10 +1,14 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
oneoff_standard_answers,
)
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
@ -25,6 +29,8 @@ from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.utils.logger import setup_logger
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
logger = setup_logger()
@ -155,3 +161,21 @@ def get_answer_with_quote(
)
return answer_details
@basic_router.get("/standard-answer")
def get_standard_answer(
request: StandardAnswerRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> StandardAnswerResponse:
try:
standard_answers = oneoff_standard_answers(
message=request.message,
slack_bot_categories=request.slack_bot_categories,
db_session=db_session,
)
return StandardAnswerResponse(standard_answers=standard_answers)
except Exception as e:
logger.error(f"Error in get_standard_answer: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="An internal server error occurred")

0
backend/query_backend.py Normal file
View File