mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 13:05:49 +02:00
EE movement followup for Standard Answers (#2467)
* Move StandardAnswer to EE section of danswer/db/models * Move StandardAnswer DB layer to EE * Add EERequiredError for distinct error handling here * Handle EE fallback for slack bot config * Migrate all standard answer models to ee * Flagging categories for removal * Add missing versioned impl for update_slack_bot_config --------- Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
This commit is contained in:
@@ -1347,55 +1347,6 @@ class ChannelConfig(TypedDict):
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
@@ -1421,7 +1372,7 @@ class SlackBotConfig(Base):
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="slack_bot_configs",
|
||||
@@ -1651,6 +1602,55 @@ class TokenRateLimit__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Permission Sync"""
|
||||
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -14,8 +15,11 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.errors import EERequiredError
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def _build_persona_name(channel_names: list[str]) -> str:
|
||||
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
|
||||
return persona
|
||||
|
||||
|
||||
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def insert_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
if len(existing_standard_answer_categories) == 0:
|
||||
raise EERequiredError(
|
||||
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona_id,
|
||||
@@ -117,9 +140,18 @@ def update_slack_bot_config(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
|
@@ -1,202 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import StandardAnswer
|
||||
from danswer.db.models import StandardAnswerCategory
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
match_any_keywords=match_any_keywords,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
standard_answer.match_any_keywords = match_any_keywords
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
@@ -62,7 +62,6 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
@@ -186,9 +185,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls()
|
||||
|
@@ -1,6 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -119,95 +116,6 @@ class HiddenUpdateRequest(BaseModel):
|
||||
hidden: bool
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SlackBotTokens(BaseModel):
|
||||
bot_token: str
|
||||
app_token: str
|
||||
@@ -233,6 +141,7 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
# list of user emails
|
||||
follow_up_tags: list[str] | None = None
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[int] = Field(default_factory=list)
|
||||
|
||||
@field_validator("answer_filters", mode="before")
|
||||
@@ -257,6 +166,7 @@ class SlackBotConfig(BaseModel):
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
|
||||
@@ -275,6 +185,7 @@ class SlackBotConfig(BaseModel):
|
||||
),
|
||||
channel_config=slack_bot_config_model.channel_config,
|
||||
response_type=slack_bot_config_model.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in slack_bot_config_model.standard_answer_categories
|
||||
|
@@ -108,6 +108,7 @@ def create_slack_bot_config(
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_bot_config_creation_request.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
|
||||
|
3
backend/danswer/utils/errors.py
Normal file
3
backend/danswer/utils/errors.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class EERequiredError(Exception):
|
||||
"""This error is thrown if an Enterprise Edition feature or API is
|
||||
requested but the Enterprise Edition flag is not set."""
|
@@ -21,11 +21,11 @@ from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from ee.danswer.db.standard_answer import find_matching_standard_answers
|
||||
from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
@@ -12,6 +12,198 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
match_any_keywords=match_any_keywords,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
standard_answer.match_any_keywords = match_any_keywords
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_names(
|
||||
standard_answer_category_names: list[str],
|
||||
db_session: Session,
|
||||
|
98
backend/ee/danswer/server/manage/models.py
Normal file
98
backend/ee/danswer/server/manage/models.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
@@ -6,19 +6,19 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.standard_answer import fetch_standard_answer
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories
|
||||
from danswer.db.standard_answer import fetch_standard_answer_category
|
||||
from danswer.db.standard_answer import fetch_standard_answers
|
||||
from danswer.db.standard_answer import insert_standard_answer
|
||||
from danswer.db.standard_answer import insert_standard_answer_category
|
||||
from danswer.db.standard_answer import remove_standard_answer
|
||||
from danswer.db.standard_answer import update_standard_answer
|
||||
from danswer.db.standard_answer import update_standard_answer_category
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.server.manage.models import StandardAnswerCategory
|
||||
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
|
||||
from danswer.server.manage.models import StandardAnswerCreationRequest
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_categories
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_category
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answers
|
||||
from ee.danswer.db.standard_answer import insert_standard_answer
|
||||
from ee.danswer.db.standard_answer import insert_standard_answer_category
|
||||
from ee.danswer.db.standard_answer import remove_standard_answer
|
||||
from ee.danswer.db.standard_answer import update_standard_answer
|
||||
from ee.danswer.db.standard_answer import update_standard_answer_category
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest
|
||||
from ee.danswer.server.manage.models import StandardAnswerCreationRequest
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
@@ -8,7 +8,7 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
|
@@ -13,6 +13,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.models import Settings
|
||||
from danswer.server.settings.store import store_settings as store_base_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.standard_answer import (
|
||||
create_initial_default_standard_answer_category,
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
||||
@@ -21,6 +24,7 @@ from ee.danswer.server.enterprise_settings.store import (
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.store import upload_logo
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
|
||||
@@ -146,3 +150,6 @@ def seed_db() -> None:
|
||||
_seed_logo(db_session, seed_config.seeded_logo_path)
|
||||
_seed_enterprise_settings(seed_config)
|
||||
_seed_analytics_script(seed_config)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
Reference in New Issue
Block a user