add new standard answer test endpoint (#1789)

This commit is contained in:
pablodanswer 2024-07-12 10:06:30 -07:00 committed by GitHub
parent e90c66c1b6
commit c7af6a4601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 80 additions and 0 deletions

View File

@ -18,7 +18,42 @@ from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import setup_logger
logger = setup_logger()
def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
"""
Respond to the user message if it matches any configured standard answers.
Returns a list of matching StandardAnswers if found, otherwise None.
"""
configured_standard_answers = {
standard_answer
for category in fetch_standard_answer_categories_by_names(
slack_bot_categories, db_session=db_session
)
for standard_answer in category.standard_answers
}
matching_standard_answers = find_matching_standard_answers(
query=message,
id_in=[answer.id for answer in configured_standard_answers],
db_session=db_session,
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
]
return server_standard_answers
def handle_standard_answers( def handle_standard_answers(

View File

@ -140,6 +140,17 @@ def fetch_standard_answer_category(
) )
def fetch_standard_answer_categories_by_names(
standard_answer_category_names: list[str],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.name.in_(standard_answer_category_names)
)
).all()
def fetch_standard_answer_categories_by_ids( def fetch_standard_answer_categories_by_ids(
standard_answer_category_ids: list[int], standard_answer_category_ids: list[int],
db_session: Session, db_session: Session,

View File

@ -4,6 +4,16 @@ from danswer.configs.constants import DocumentSource
from danswer.search.enums import SearchType from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails from danswer.search.models import RetrievalDetails
from danswer.server.manage.models import StandardAnswer
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = []
class DocumentSearchRequest(ChunkContext): class DocumentSearchRequest(ChunkContext):

View File

@ -1,10 +1,14 @@
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import Depends from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.auth.users import current_user from danswer.auth.users import current_user
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
oneoff_standard_answers,
)
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.models import User from danswer.db.models import User
from danswer.db.persona import get_persona_by_id from danswer.db.persona import get_persona_by_id
@ -25,6 +29,8 @@ from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices from danswer.search.utils import drop_llm_indices
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
logger = setup_logger() logger = setup_logger()
@ -155,3 +161,21 @@ def get_answer_with_quote(
) )
return answer_details return answer_details
@basic_router.get("/standard-answer")
def get_standard_answer(
request: StandardAnswerRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> StandardAnswerResponse:
try:
standard_answers = oneoff_standard_answers(
message=request.message,
slack_bot_categories=request.slack_bot_categories,
db_session=db_session,
)
return StandardAnswerResponse(standard_answers=standard_answers)
except Exception as e:
logger.error(f"Error in get_standard_answer: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="An internal server error occurred")

0
backend/query_backend.py Normal file
View File