diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index d18d99dbc29c..da13905fe92c 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1347,55 +1347,6 @@ class ChannelConfig(TypedDict): follow_up_tags: NotRequired[list[str]] -class StandardAnswerCategory(Base): - __tablename__ = "standard_answer_category" - - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(String, unique=True) - standard_answers: Mapped[list["StandardAnswer"]] = relationship( - "StandardAnswer", - secondary=StandardAnswer__StandardAnswerCategory.__table__, - back_populates="categories", - ) - slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship( - "SlackBotConfig", - secondary=SlackBotConfig__StandardAnswerCategory.__table__, - back_populates="standard_answer_categories", - ) - - -class StandardAnswer(Base): - __tablename__ = "standard_answer" - - id: Mapped[int] = mapped_column(primary_key=True) - keyword: Mapped[str] = mapped_column(String) - answer: Mapped[str] = mapped_column(String) - active: Mapped[bool] = mapped_column(Boolean) - match_regex: Mapped[bool] = mapped_column(Boolean) - match_any_keywords: Mapped[bool] = mapped_column(Boolean) - - __table_args__ = ( - Index( - "unique_keyword_active", - keyword, - active, - unique=True, - postgresql_where=(active == True), # noqa: E712 - ), - ) - - categories: Mapped[list[StandardAnswerCategory]] = relationship( - "StandardAnswerCategory", - secondary=StandardAnswer__StandardAnswerCategory.__table__, - back_populates="standard_answers", - ) - chat_messages: Mapped[list[ChatMessage]] = relationship( - "ChatMessage", - secondary=ChatMessage__StandardAnswer.__table__, - back_populates="standard_answers", - ) - - class SlackBotResponseType(str, PyEnum): QUOTES = "quotes" CITATIONS = "citations" @@ -1421,7 +1372,7 @@ class SlackBotConfig(Base): ) persona: Mapped[Persona | None] = relationship("Persona") - standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship( + standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship( "StandardAnswerCategory", secondary=SlackBotConfig__StandardAnswerCategory.__table__, back_populates="slack_bot_configs", @@ -1651,6 +1602,55 @@ class TokenRateLimit__UserGroup(Base): ) +class StandardAnswerCategory(Base): + __tablename__ = "standard_answer_category" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True) + standard_answers: Mapped[list["StandardAnswer"]] = relationship( + "StandardAnswer", + secondary=StandardAnswer__StandardAnswerCategory.__table__, + back_populates="categories", + ) + slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship( + "SlackBotConfig", + secondary=SlackBotConfig__StandardAnswerCategory.__table__, + back_populates="standard_answer_categories", + ) + + +class StandardAnswer(Base): + __tablename__ = "standard_answer" + + id: Mapped[int] = mapped_column(primary_key=True) + keyword: Mapped[str] = mapped_column(String) + answer: Mapped[str] = mapped_column(String) + active: Mapped[bool] = mapped_column(Boolean) + match_regex: Mapped[bool] = mapped_column(Boolean) + match_any_keywords: Mapped[bool] = mapped_column(Boolean) + + __table_args__ = ( + Index( + "unique_keyword_active", + keyword, + active, + unique=True, + postgresql_where=(active == True), # noqa: E712 + ), + ) + + categories: Mapped[list[StandardAnswerCategory]] = relationship( + "StandardAnswerCategory", + secondary=StandardAnswer__StandardAnswerCategory.__table__, + back_populates="standard_answers", + ) + chat_messages: Mapped[list[ChatMessage]] = relationship( + "ChatMessage", + secondary=ChatMessage__StandardAnswer.__table__, + back_populates="standard_answers", + ) + + """Tables related to Permission Sync""" diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 322dc4c4ed9c..a37bd18c0ec7 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,8 +15,11 @@ from danswer.db.models import User from danswer.db.persona import get_default_prompt from danswer.db.persona import mark_persona_as_deleted from danswer.db.persona import upsert_persona -from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids from danswer.search.enums import RecencyBiasSetting +from danswer.utils.errors import EERequiredError +from danswer.utils.variable_functionality import ( + fetch_versioned_implementation_with_fallback, +) def _build_persona_name(channel_names: list[str]) -> str: @@ -70,6 +74,10 @@ def create_slack_bot_persona( return persona +def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list: + return [] + + def insert_slack_bot_config( persona_id: int | None, channel_config: ChannelConfig, @@ -78,14 +86,29 @@ def insert_slack_bot_config( enable_auto_filters: bool, db_session: Session, ) -> SlackBotConfig: - existing_standard_answer_categories = fetch_standard_answer_categories_by_ids( - standard_answer_category_ids=standard_answer_category_ids, - db_session=db_session, - ) - if len(existing_standard_answer_categories) != len(standard_answer_category_ids): - raise ValueError( - f"Some or all categories with ids {standard_answer_category_ids} do not exist" + versioned_fetch_standard_answer_categories_by_ids = ( + fetch_versioned_implementation_with_fallback( + "danswer.db.standard_answer", + "fetch_standard_answer_categories_by_ids", + _no_ee_standard_answer_categories, ) + ) + existing_standard_answer_categories = ( + versioned_fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=standard_answer_category_ids, + db_session=db_session, + ) + ) + + if len(existing_standard_answer_categories) != len(standard_answer_category_ids): + if len(existing_standard_answer_categories) == 0: + raise EERequiredError( + "Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories" + ) + else: + raise ValueError( + f"Some or all categories with ids {standard_answer_category_ids} do not exist" + ) slack_bot_config = SlackBotConfig( persona_id=persona_id, @@ -117,9 +140,18 @@ def update_slack_bot_config( f"Unable to find slack bot config with ID {slack_bot_config_id}" ) - existing_standard_answer_categories = fetch_standard_answer_categories_by_ids( - standard_answer_category_ids=standard_answer_category_ids, - db_session=db_session, + versioned_fetch_standard_answer_categories_by_ids = ( + fetch_versioned_implementation_with_fallback( + "danswer.db.standard_answer", + "fetch_standard_answer_categories_by_ids", + _no_ee_standard_answer_categories, + ) + ) + existing_standard_answer_categories = ( + versioned_fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=standard_answer_category_ids, + db_session=db_session, + ) ) if len(existing_standard_answer_categories) != len(standard_answer_category_ids): raise ValueError( diff --git a/backend/danswer/db/standard_answer.py b/backend/danswer/db/standard_answer.py deleted file mode 100644 index 85d5d922889a..000000000000 --- a/backend/danswer/db/standard_answer.py +++ /dev/null @@ -1,202 +0,0 @@ -from collections.abc import Sequence - -from sqlalchemy import select -from sqlalchemy.orm import Session - -from danswer.db.models import StandardAnswer -from danswer.db.models import StandardAnswerCategory -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def check_category_validity(category_name: str) -> bool: - """If a category name is too long, it should not be used (it will cause an error in Postgres - as the unique constraint can only apply to entries that are less than 2704 bytes). - - Additionally, extremely long categories are not really usable / useful.""" - if len(category_name) > 255: - logger.error( - f"Category with name '{category_name}' is too long, cannot be used" - ) - return False - - return True - - -def insert_standard_answer_category( - category_name: str, db_session: Session -) -> StandardAnswerCategory: - if not check_category_validity(category_name): - raise ValueError(f"Invalid category name: {category_name}") - standard_answer_category = StandardAnswerCategory(name=category_name) - db_session.add(standard_answer_category) - db_session.commit() - - return standard_answer_category - - -def insert_standard_answer( - keyword: str, - answer: str, - category_ids: list[int], - match_regex: bool, - match_any_keywords: bool, - db_session: Session, -) -> StandardAnswer: - existing_categories = fetch_standard_answer_categories_by_ids( - standard_answer_category_ids=category_ids, - db_session=db_session, - ) - if len(existing_categories) != len(category_ids): - raise ValueError(f"Some or all categories with ids {category_ids} do not exist") - - standard_answer = StandardAnswer( - keyword=keyword, - answer=answer, - categories=existing_categories, - active=True, - match_regex=match_regex, - match_any_keywords=match_any_keywords, - ) - db_session.add(standard_answer) - db_session.commit() - return standard_answer - - -def update_standard_answer( - standard_answer_id: int, - keyword: str, - answer: str, - category_ids: list[int], - match_regex: bool, - match_any_keywords: bool, - db_session: Session, -) -> StandardAnswer: - standard_answer = db_session.scalar( - select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) - ) - if standard_answer is None: - raise ValueError(f"No standard answer with id {standard_answer_id}") - - existing_categories = fetch_standard_answer_categories_by_ids( - standard_answer_category_ids=category_ids, - db_session=db_session, - ) - if len(existing_categories) != len(category_ids): - raise ValueError(f"Some or all categories with ids {category_ids} do not exist") - - standard_answer.keyword = keyword - standard_answer.answer = answer - standard_answer.categories = list(existing_categories) - standard_answer.match_regex = match_regex - standard_answer.match_any_keywords = match_any_keywords - - db_session.commit() - - return standard_answer - - -def remove_standard_answer( - standard_answer_id: int, - db_session: Session, -) -> None: - standard_answer = db_session.scalar( - select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) - ) - if standard_answer is None: - raise ValueError(f"No standard answer with id {standard_answer_id}") - - standard_answer.active = False - db_session.commit() - - -def update_standard_answer_category( - standard_answer_category_id: int, - category_name: str, - db_session: Session, -) -> StandardAnswerCategory: - standard_answer_category = db_session.scalar( - select(StandardAnswerCategory).where( - StandardAnswerCategory.id == standard_answer_category_id - ) - ) - if standard_answer_category is None: - raise ValueError( - f"No standard answer category with id {standard_answer_category_id}" - ) - - if not check_category_validity(category_name): - raise ValueError(f"Invalid category name: {category_name}") - - standard_answer_category.name = category_name - - db_session.commit() - - return standard_answer_category - - -def fetch_standard_answer_category( - standard_answer_category_id: int, - db_session: Session, -) -> StandardAnswerCategory | None: - return db_session.scalar( - select(StandardAnswerCategory).where( - StandardAnswerCategory.id == standard_answer_category_id - ) - ) - - -def fetch_standard_answer_categories_by_ids( - standard_answer_category_ids: list[int], - db_session: Session, -) -> Sequence[StandardAnswerCategory]: - return db_session.scalars( - select(StandardAnswerCategory).where( - StandardAnswerCategory.id.in_(standard_answer_category_ids) - ) - ).all() - - -def fetch_standard_answer_categories( - db_session: Session, -) -> Sequence[StandardAnswerCategory]: - return db_session.scalars(select(StandardAnswerCategory)).all() - - -def fetch_standard_answer( - standard_answer_id: int, - db_session: Session, -) -> StandardAnswer | None: - return db_session.scalar( - select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) - ) - - -def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]: - return db_session.scalars( - select(StandardAnswer).where(StandardAnswer.active.is_(True)) - ).all() - - -def create_initial_default_standard_answer_category(db_session: Session) -> None: - default_category_id = 0 - default_category_name = "General" - default_category = fetch_standard_answer_category( - standard_answer_category_id=default_category_id, - db_session=db_session, - ) - if default_category is not None: - if default_category.name != default_category_name: - raise ValueError( - "DB is not in a valid initial state. " - "Default standard answer category does not have expected name." - ) - return - - standard_answer_category = StandardAnswerCategory( - id=default_category_id, - name=default_category_name, - ) - db_session.add(standard_answer_category) - db_session.commit() diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 78b44c52b5dd..9a681c39a13a 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -62,7 +62,6 @@ from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_secondary_search_settings from danswer.db.search_settings import update_current_search_settings from danswer.db.search_settings import update_secondary_search_settings -from danswer.db.standard_answer import create_initial_default_standard_answer_category from danswer.db.swap_index import check_index_swap from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex @@ -186,9 +185,6 @@ def setup_postgres(db_session: Session) -> None: create_initial_default_connector(db_session) associate_default_cc_pair(db_session) - logger.notice("Verifying default standard answer category exists.") - create_initial_default_standard_answer_category(db_session) - logger.notice("Loading default Prompts and Personas") delete_old_default_personas(db_session) load_chat_yamls() diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index e4618c45658d..7b0a3813a82e 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -1,6 +1,4 @@ -import re from datetime import datetime -from typing import Any from typing import TYPE_CHECKING from pydantic import BaseModel @@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters from danswer.db.models import ChannelConfig from danswer.db.models import SlackBotConfig as SlackBotConfigModel from danswer.db.models import SlackBotResponseType -from danswer.db.models import StandardAnswer as StandardAnswerModel -from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel from danswer.db.models import User from danswer.search.models import SavedSearchSettings from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.models import FullUserSnapshot from danswer.server.models import InvitedUserSnapshot +from ee.danswer.server.manage.models import StandardAnswerCategory if TYPE_CHECKING: @@ -119,95 +116,6 @@ class HiddenUpdateRequest(BaseModel): hidden: bool -class StandardAnswerCategoryCreationRequest(BaseModel): - name: str - - -class StandardAnswerCategory(BaseModel): - id: int - name: str - - @classmethod - def from_model( - cls, standard_answer_category: StandardAnswerCategoryModel - ) -> "StandardAnswerCategory": - return cls( - id=standard_answer_category.id, - name=standard_answer_category.name, - ) - - -class StandardAnswer(BaseModel): - id: int - keyword: str - answer: str - categories: list[StandardAnswerCategory] - match_regex: bool - match_any_keywords: bool - - @classmethod - def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer": - return cls( - id=standard_answer_model.id, - keyword=standard_answer_model.keyword, - answer=standard_answer_model.answer, - match_regex=standard_answer_model.match_regex, - match_any_keywords=standard_answer_model.match_any_keywords, - categories=[ - StandardAnswerCategory.from_model(standard_answer_category_model) - for standard_answer_category_model in standard_answer_model.categories - ], - ) - - -class StandardAnswerCreationRequest(BaseModel): - keyword: str - answer: str - categories: list[int] - match_regex: bool - match_any_keywords: bool - - @field_validator("categories", mode="before") - @classmethod - def validate_categories(cls, value: list[int]) -> list[int]: - if len(value) < 1: - raise ValueError( - "At least one category must be attached to a standard answer" - ) - return value - - @model_validator(mode="after") - def validate_only_match_any_if_not_regex(self) -> Any: - if self.match_regex and self.match_any_keywords: - raise ValueError( - "Can only match any keywords in keyword mode, not regex mode" - ) - - return self - - @model_validator(mode="after") - def validate_keyword_if_regex(self) -> Any: - if not self.match_regex: - # no validation for keywords - return self - - try: - re.compile(self.keyword) - return self - except re.error as err: - if isinstance(err.pattern, bytes): - raise ValueError( - f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}' - ) - else: - pattern = f'r"{err.pattern}"' if err.pattern is not None else "" - raise ValueError( - " ".join( - ["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"] - ) - ) - - class SlackBotTokens(BaseModel): bot_token: str app_token: str @@ -233,6 +141,7 @@ class SlackBotConfigCreationRequest(BaseModel): # list of user emails follow_up_tags: list[str] | None = None response_type: SlackBotResponseType + # XXX this is going away soon standard_answer_categories: list[int] = Field(default_factory=list) @field_validator("answer_filters", mode="before") @@ -257,6 +166,7 @@ class SlackBotConfig(BaseModel): persona: PersonaSnapshot | None channel_config: ChannelConfig response_type: SlackBotResponseType + # XXX this is going away soon standard_answer_categories: list[StandardAnswerCategory] enable_auto_filters: bool @@ -275,6 +185,7 @@ class SlackBotConfig(BaseModel): ), channel_config=slack_bot_config_model.channel_config, response_type=slack_bot_config_model.response_type, + # XXX this is going away soon standard_answer_categories=[ StandardAnswerCategory.from_model(standard_answer_category_model) for standard_answer_category_model in slack_bot_config_model.standard_answer_categories diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index 0fb1459072bc..9a06b225cce0 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -108,6 +108,7 @@ def create_slack_bot_config( persona_id=persona_id, channel_config=channel_config, response_type=slack_bot_config_creation_request.response_type, + # XXX this is going away soon standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories, db_session=db_session, enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters, diff --git a/backend/danswer/utils/errors.py b/backend/danswer/utils/errors.py new file mode 100644 index 000000000000..86b9d4252f31 --- /dev/null +++ b/backend/danswer/utils/errors.py @@ -0,0 +1,3 @@ +class EERequiredError(Exception): + """This error is thrown if an Enterprise Edition feature or API is + requested but the Enterprise Edition flag is not set.""" diff --git a/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py b/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py index 96c72187a67a..6807e77135a4 100644 --- a/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py +++ b/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py @@ -21,11 +21,11 @@ from danswer.db.chat import get_or_create_root_message from danswer.db.models import Prompt from danswer.db.models import SlackBotConfig from danswer.db.models import StandardAnswer as StandardAnswerModel -from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer from danswer.utils.logger import DanswerLoggingAdapter from danswer.utils.logger import setup_logger from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names from ee.danswer.db.standard_answer import find_matching_standard_answers +from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer logger = setup_logger() diff --git a/backend/ee/danswer/db/standard_answer.py b/backend/ee/danswer/db/standard_answer.py index 2887a487b5eb..0fa074e36a7f 100644 --- a/backend/ee/danswer/db/standard_answer.py +++ b/backend/ee/danswer/db/standard_answer.py @@ -12,6 +12,198 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +def check_category_validity(category_name: str) -> bool: + """If a category name is too long, it should not be used (it will cause an error in Postgres + as the unique constraint can only apply to entries that are less than 2704 bytes). + + Additionally, extremely long categories are not really usable / useful.""" + if len(category_name) > 255: + logger.error( + f"Category with name '{category_name}' is too long, cannot be used" + ) + return False + + return True + + +def insert_standard_answer_category( + category_name: str, db_session: Session +) -> StandardAnswerCategory: + if not check_category_validity(category_name): + raise ValueError(f"Invalid category name: {category_name}") + standard_answer_category = StandardAnswerCategory(name=category_name) + db_session.add(standard_answer_category) + db_session.commit() + + return standard_answer_category + + +def insert_standard_answer( + keyword: str, + answer: str, + category_ids: list[int], + match_regex: bool, + match_any_keywords: bool, + db_session: Session, +) -> StandardAnswer: + existing_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=category_ids, + db_session=db_session, + ) + if len(existing_categories) != len(category_ids): + raise ValueError(f"Some or all categories with ids {category_ids} do not exist") + + standard_answer = StandardAnswer( + keyword=keyword, + answer=answer, + categories=existing_categories, + active=True, + match_regex=match_regex, + match_any_keywords=match_any_keywords, + ) + db_session.add(standard_answer) + db_session.commit() + return standard_answer + + +def update_standard_answer( + standard_answer_id: int, + keyword: str, + answer: str, + category_ids: list[int], + match_regex: bool, + match_any_keywords: bool, + db_session: Session, +) -> StandardAnswer: + standard_answer = db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + if standard_answer is None: + raise ValueError(f"No standard answer with id {standard_answer_id}") + + existing_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=category_ids, + db_session=db_session, + ) + if len(existing_categories) != len(category_ids): + raise ValueError(f"Some or all categories with ids {category_ids} do not exist") + + standard_answer.keyword = keyword + standard_answer.answer = answer + standard_answer.categories = list(existing_categories) + standard_answer.match_regex = match_regex + standard_answer.match_any_keywords = match_any_keywords + + db_session.commit() + + return standard_answer + + +def remove_standard_answer( + standard_answer_id: int, + db_session: Session, +) -> None: + standard_answer = db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + if standard_answer is None: + raise ValueError(f"No standard answer with id {standard_answer_id}") + + standard_answer.active = False + db_session.commit() + + +def update_standard_answer_category( + standard_answer_category_id: int, + category_name: str, + db_session: Session, +) -> StandardAnswerCategory: + standard_answer_category = db_session.scalar( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id == standard_answer_category_id + ) + ) + if standard_answer_category is None: + raise ValueError( + f"No standard answer category with id {standard_answer_category_id}" + ) + + if not check_category_validity(category_name): + raise ValueError(f"Invalid category name: {category_name}") + + standard_answer_category.name = category_name + + db_session.commit() + + return standard_answer_category + + +def fetch_standard_answer_category( + standard_answer_category_id: int, + db_session: Session, +) -> StandardAnswerCategory | None: + return db_session.scalar( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id == standard_answer_category_id + ) + ) + + +def fetch_standard_answer_categories_by_ids( + standard_answer_category_ids: list[int], + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id.in_(standard_answer_category_ids) + ) + ).all() + + +def fetch_standard_answer_categories( + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars(select(StandardAnswerCategory)).all() + + +def fetch_standard_answer( + standard_answer_id: int, + db_session: Session, +) -> StandardAnswer | None: + return db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + + +def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]: + return db_session.scalars( + select(StandardAnswer).where(StandardAnswer.active.is_(True)) + ).all() + + +def create_initial_default_standard_answer_category(db_session: Session) -> None: + default_category_id = 0 + default_category_name = "General" + default_category = fetch_standard_answer_category( + standard_answer_category_id=default_category_id, + db_session=db_session, + ) + if default_category is not None: + if default_category.name != default_category_name: + raise ValueError( + "DB is not in a valid initial state. " + "Default standard answer category does not have expected name." + ) + return + + standard_answer_category = StandardAnswerCategory( + id=default_category_id, + name=default_category_name, + ) + db_session.add(standard_answer_category) + db_session.commit() + + def fetch_standard_answer_categories_by_names( standard_answer_category_names: list[str], db_session: Session, diff --git a/backend/ee/danswer/server/manage/models.py b/backend/ee/danswer/server/manage/models.py new file mode 100644 index 000000000000..ae2c401a2fac --- /dev/null +++ b/backend/ee/danswer/server/manage/models.py @@ -0,0 +1,98 @@ +import re +from typing import Any + +from pydantic import BaseModel +from pydantic import field_validator +from pydantic import model_validator + +from danswer.db.models import StandardAnswer as StandardAnswerModel +from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel + + +class StandardAnswerCategoryCreationRequest(BaseModel): + name: str + + +class StandardAnswerCategory(BaseModel): + id: int + name: str + + @classmethod + def from_model( + cls, standard_answer_category: StandardAnswerCategoryModel + ) -> "StandardAnswerCategory": + return cls( + id=standard_answer_category.id, + name=standard_answer_category.name, + ) + + +class StandardAnswer(BaseModel): + id: int + keyword: str + answer: str + categories: list[StandardAnswerCategory] + match_regex: bool + match_any_keywords: bool + + @classmethod + def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer": + return cls( + id=standard_answer_model.id, + keyword=standard_answer_model.keyword, + answer=standard_answer_model.answer, + match_regex=standard_answer_model.match_regex, + match_any_keywords=standard_answer_model.match_any_keywords, + categories=[ + StandardAnswerCategory.from_model(standard_answer_category_model) + for standard_answer_category_model in standard_answer_model.categories + ], + ) + + +class StandardAnswerCreationRequest(BaseModel): + keyword: str + answer: str + categories: list[int] + match_regex: bool + match_any_keywords: bool + + @field_validator("categories", mode="before") + @classmethod + def validate_categories(cls, value: list[int]) -> list[int]: + if len(value) < 1: + raise ValueError( + "At least one category must be attached to a standard answer" + ) + return value + + @model_validator(mode="after") + def validate_only_match_any_if_not_regex(self) -> Any: + if self.match_regex and self.match_any_keywords: + raise ValueError( + "Can only match any keywords in keyword mode, not regex mode" + ) + + return self + + @model_validator(mode="after") + def validate_keyword_if_regex(self) -> Any: + if not self.match_regex: + # no validation for keywords + return self + + try: + re.compile(self.keyword) + return self + except re.error as err: + if isinstance(err.pattern, bytes): + raise ValueError( + f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}' + ) + else: + pattern = f'r"{err.pattern}"' if err.pattern is not None else "" + raise ValueError( + " ".join( + ["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"] + ) + ) diff --git a/backend/ee/danswer/server/manage/standard_answer.py b/backend/ee/danswer/server/manage/standard_answer.py index ea3ca0bc0dc7..e832fa190780 100644 --- a/backend/ee/danswer/server/manage/standard_answer.py +++ b/backend/ee/danswer/server/manage/standard_answer.py @@ -6,19 +6,19 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.db.engine import get_session from danswer.db.models import User -from danswer.db.standard_answer import fetch_standard_answer -from danswer.db.standard_answer import fetch_standard_answer_categories -from danswer.db.standard_answer import fetch_standard_answer_category -from danswer.db.standard_answer import fetch_standard_answers -from danswer.db.standard_answer import insert_standard_answer -from danswer.db.standard_answer import insert_standard_answer_category -from danswer.db.standard_answer import remove_standard_answer -from danswer.db.standard_answer import update_standard_answer -from danswer.db.standard_answer import update_standard_answer_category -from danswer.server.manage.models import StandardAnswer -from danswer.server.manage.models import StandardAnswerCategory -from danswer.server.manage.models import StandardAnswerCategoryCreationRequest -from danswer.server.manage.models import StandardAnswerCreationRequest +from ee.danswer.db.standard_answer import fetch_standard_answer +from ee.danswer.db.standard_answer import fetch_standard_answer_categories +from ee.danswer.db.standard_answer import fetch_standard_answer_category +from ee.danswer.db.standard_answer import fetch_standard_answers +from ee.danswer.db.standard_answer import insert_standard_answer +from ee.danswer.db.standard_answer import insert_standard_answer_category +from ee.danswer.db.standard_answer import remove_standard_answer +from ee.danswer.db.standard_answer import update_standard_answer +from ee.danswer.db.standard_answer import update_standard_answer_category +from ee.danswer.server.manage.models import StandardAnswer +from ee.danswer.server.manage.models import StandardAnswerCategory +from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest +from ee.danswer.server.manage.models import StandardAnswerCreationRequest router = APIRouter(prefix="/manage") diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index b1ea648c8f01..cc66c0efab91 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -8,7 +8,7 @@ from danswer.search.enums import SearchType from danswer.search.models import ChunkContext from danswer.search.models import RerankingDetails from danswer.search.models import RetrievalDetails -from danswer.server.manage.models import StandardAnswer +from ee.danswer.server.manage.models import StandardAnswer class StandardAnswerRequest(BaseModel): diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index 10dc1afb9721..ab6c4b017f9f 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -13,6 +13,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.settings.models import Settings from danswer.server.settings.store import store_settings as store_base_settings from danswer.utils.logger import setup_logger +from ee.danswer.db.standard_answer import ( + create_initial_default_standard_answer_category, +) from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload from ee.danswer.server.enterprise_settings.models import EnterpriseSettings from ee.danswer.server.enterprise_settings.store import store_analytics_script @@ -21,6 +24,7 @@ from ee.danswer.server.enterprise_settings.store import ( ) from ee.danswer.server.enterprise_settings.store import upload_logo + logger = setup_logger() _SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION" @@ -146,3 +150,6 @@ def seed_db() -> None: _seed_logo(db_session, seed_config.seeded_logo_path) _seed_enterprise_settings(seed_config) _seed_analytics_script(seed_config) + + logger.notice("Verifying default standard answer category exists.") + create_initial_default_standard_answer_category(db_session)