mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 04:18:32 +02:00
Migrate all standard answer models to ee
This commit is contained in:
parent
6c29ad8d59
commit
6eeee2a96d
@ -62,7 +62,6 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
@ -186,9 +185,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls()
|
||||
|
@ -1,6 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -119,95 +116,6 @@ class HiddenUpdateRequest(BaseModel):
|
||||
hidden: bool
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SlackBotTokens(BaseModel):
|
||||
bot_token: str
|
||||
app_token: str
|
||||
@ -257,6 +165,7 @@ class SlackBotConfig(BaseModel):
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
|
||||
|
@ -21,11 +21,11 @@ from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from ee.danswer.db.standard_answer import find_matching_standard_answers
|
||||
from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
98
backend/ee/danswer/server/manage/models.py
Normal file
98
backend/ee/danswer/server/manage/models.py
Normal file
@ -0,0 +1,98 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
@ -6,19 +6,19 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.standard_answer import fetch_standard_answer
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories
|
||||
from danswer.db.standard_answer import fetch_standard_answer_category
|
||||
from danswer.db.standard_answer import fetch_standard_answers
|
||||
from danswer.db.standard_answer import insert_standard_answer
|
||||
from danswer.db.standard_answer import insert_standard_answer_category
|
||||
from danswer.db.standard_answer import remove_standard_answer
|
||||
from danswer.db.standard_answer import update_standard_answer
|
||||
from danswer.db.standard_answer import update_standard_answer_category
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.server.manage.models import StandardAnswerCategory
|
||||
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
|
||||
from danswer.server.manage.models import StandardAnswerCreationRequest
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_categories
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_category
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answers
|
||||
from ee.danswer.db.standard_answer import insert_standard_answer
|
||||
from ee.danswer.db.standard_answer import insert_standard_answer_category
|
||||
from ee.danswer.db.standard_answer import remove_standard_answer
|
||||
from ee.danswer.db.standard_answer import update_standard_answer
|
||||
from ee.danswer.db.standard_answer import update_standard_answer_category
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest
|
||||
from ee.danswer.server.manage.models import StandardAnswerCreationRequest
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
@ -8,7 +8,7 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
|
@ -13,6 +13,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.models import Settings
|
||||
from danswer.server.settings.store import store_settings as store_base_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.standard_answer import (
|
||||
create_initial_default_standard_answer_category,
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
||||
@ -21,6 +24,7 @@ from ee.danswer.server.enterprise_settings.store import (
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.store import upload_logo
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
|
||||
@ -146,3 +150,6 @@ def seed_db() -> None:
|
||||
_seed_logo(db_session, seed_config.seeded_logo_path)
|
||||
_seed_enterprise_settings(seed_config)
|
||||
_seed_analytics_script(seed_config)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
Loading…
x
Reference in New Issue
Block a user