Migrate all standard answer models to ee

This commit is contained in:
danswer-trial 2024-09-16 13:10:27 -07:00
parent 6c29ad8d59
commit 6eeee2a96d
7 changed files with 122 additions and 112 deletions

View File

@ -62,7 +62,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
@ -186,9 +185,6 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()

View File

@ -1,6 +1,4 @@
import re
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.db.models import User
from danswer.search.models import SavedSearchSettings
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from ee.danswer.server.manage.models import StandardAnswerCategory
if TYPE_CHECKING:
@ -119,95 +116,6 @@ class HiddenUpdateRequest(BaseModel):
hidden: bool
class StandardAnswerCategoryCreationRequest(BaseModel):
name: str
class StandardAnswerCategory(BaseModel):
id: int
name: str
@classmethod
def from_model(
cls, standard_answer_category: StandardAnswerCategoryModel
) -> "StandardAnswerCategory":
return cls(
id=standard_answer_category.id,
name=standard_answer_category.name,
)
class StandardAnswer(BaseModel):
id: int
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
return cls(
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
],
)
class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
"At least one category must be attached to a standard answer"
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)
class SlackBotTokens(BaseModel):
bot_token: str
app_token: str
@ -257,6 +165,7 @@ class SlackBotConfig(BaseModel):
persona: PersonaSnapshot | None
channel_config: ChannelConfig
response_type: SlackBotResponseType
# XXX this is going away soon
standard_answer_categories: list[StandardAnswerCategory]
enable_auto_filters: bool

View File

@ -21,11 +21,11 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from ee.danswer.db.standard_answer import find_matching_standard_answers
from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
logger = setup_logger()

View File

@ -0,0 +1,98 @@
import re
from typing import Any
from pydantic import BaseModel
from pydantic import field_validator
from pydantic import model_validator
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
class StandardAnswerCategoryCreationRequest(BaseModel):
name: str
class StandardAnswerCategory(BaseModel):
id: int
name: str
@classmethod
def from_model(
cls, standard_answer_category: StandardAnswerCategoryModel
) -> "StandardAnswerCategory":
return cls(
id=standard_answer_category.id,
name=standard_answer_category.name,
)
class StandardAnswer(BaseModel):
id: int
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
return cls(
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
],
)
class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
"At least one category must be attached to a standard answer"
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)

View File

@ -6,19 +6,19 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.standard_answer import fetch_standard_answer
from danswer.db.standard_answer import fetch_standard_answer_categories
from danswer.db.standard_answer import fetch_standard_answer_category
from danswer.db.standard_answer import fetch_standard_answers
from danswer.db.standard_answer import insert_standard_answer
from danswer.db.standard_answer import insert_standard_answer_category
from danswer.db.standard_answer import remove_standard_answer
from danswer.db.standard_answer import update_standard_answer
from danswer.db.standard_answer import update_standard_answer_category
from danswer.server.manage.models import StandardAnswer
from danswer.server.manage.models import StandardAnswerCategory
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
from danswer.server.manage.models import StandardAnswerCreationRequest
from ee.danswer.db.standard_answer import fetch_standard_answer
from ee.danswer.db.standard_answer import fetch_standard_answer_categories
from ee.danswer.db.standard_answer import fetch_standard_answer_category
from ee.danswer.db.standard_answer import fetch_standard_answers
from ee.danswer.db.standard_answer import insert_standard_answer
from ee.danswer.db.standard_answer import insert_standard_answer_category
from ee.danswer.db.standard_answer import remove_standard_answer
from ee.danswer.db.standard_answer import update_standard_answer
from ee.danswer.db.standard_answer import update_standard_answer_category
from ee.danswer.server.manage.models import StandardAnswer
from ee.danswer.server.manage.models import StandardAnswerCategory
from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest
from ee.danswer.server.manage.models import StandardAnswerCreationRequest
router = APIRouter(prefix="/manage")

View File

@ -8,7 +8,7 @@ from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.server.manage.models import StandardAnswer
from ee.danswer.server.manage.models import StandardAnswer
class StandardAnswerRequest(BaseModel):

View File

@ -13,6 +13,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.models import Settings
from danswer.server.settings.store import store_settings as store_base_settings
from danswer.utils.logger import setup_logger
from ee.danswer.db.standard_answer import (
create_initial_default_standard_answer_category,
)
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.store import store_analytics_script
@ -21,6 +24,7 @@ from ee.danswer.server.enterprise_settings.store import (
)
from ee.danswer.server.enterprise_settings.store import upload_logo
logger = setup_logger()
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
@ -146,3 +150,6 @@ def seed_db() -> None:
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)