mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-28 12:58:41 +02:00
Support regex in standard answers (#2377)
* Support regex in standard answers * fix mypy * Add match_regex boolean column to StandardAnswer * Add match_regex flag and validation to Pydantic models * GET /manage/admin/standard-answer: add match_regex to create_standard_answer * PATCH /manage/admin/standard-answer/:id: add match_regex to update_standard_answer * Add "Match Regex" toggle to standard answer form * Decode error pattern in case it's bytes * Refactor regex support to use match_regex flag instead of supplemental tuple * Better error handling for invalid regexes * Show "match regex" in table and style keywords appropriately * Fix stale UI copy for non-"match_regex" branch * Fix stale docstring in find_matching_standard_answers * Update down_revision to reflect most recent migration * Update UI copy * Initial implementation of match group display * Fix pydantic StandardAnswer vs SQLAlchemy StandardAnswer model usage * Update docstring return type * Fix missing key prop --------- Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local> Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 0ebb1d516877
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "efb35676026c"
|
||||
down_revision = "0ebb1d516877"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the non-nullable boolean ``match_regex`` column to ``standard_answer``.

    The column is created with a server-side default of false so that the
    database itself can populate rows that already exist. A plain
    ``default=`` on ``sa.Column`` is only applied by SQLAlchemy when it builds
    INSERT statements; it is not part of the DDL emitted by ``op.add_column``,
    so adding a NOT NULL column that way fails on a populated table on
    databases such as PostgreSQL.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "standard_answer",
        sa.Column(
            "match_regex",
            sa.Boolean(),
            nullable=False,
            # Existing standard answers keep keyword (non-regex) matching.
            server_default=sa.false(),
        ),
    )
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration by dropping ``match_regex`` from ``standard_answer``."""
    op.drop_column(
        table_name="standard_answer",
        column_name="match_regex",
    )
|
@@ -16,9 +16,10 @@ from danswer.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from danswer.db.standard_answer import find_matching_standard_answers
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -29,7 +30,7 @@ def oneoff_standard_answers(
|
||||
message: str,
|
||||
slack_bot_categories: list[str],
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
) -> list[PydanticStandardAnswer]:
|
||||
"""
|
||||
Respond to the user message if it matches any configured standard answers.
|
||||
|
||||
@@ -50,7 +51,8 @@ def oneoff_standard_answers(
|
||||
)
|
||||
|
||||
server_standard_answers = [
|
||||
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
|
||||
PydanticStandardAnswer.from_model(standard_answer_model)
|
||||
for (standard_answer_model, _) in matching_standard_answers
|
||||
]
|
||||
return server_standard_answers
|
||||
|
||||
@@ -114,14 +116,15 @@ def handle_standard_answers(
|
||||
usable_standard_answers = configured_standard_answers.difference(
|
||||
used_standard_answer_ids
|
||||
)
|
||||
|
||||
matching_standard_answers: list[tuple[StandardAnswerModel, str]] = []
|
||||
if usable_standard_answers:
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=query_msg.message,
|
||||
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
matching_standard_answers = []
|
||||
|
||||
if matching_standard_answers:
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
@@ -149,12 +152,12 @@ def handle_standard_answers(
|
||||
)
|
||||
|
||||
formatted_answers = []
|
||||
for standard_answer in matching_standard_answers:
|
||||
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
|
||||
formatted_answer = (
|
||||
f'Since you mentioned _"{standard_answer.keyword}"_, '
|
||||
f"I thought this might be useful: \n\n{block_quotified_answer}"
|
||||
for standard_answer, match_str in matching_standard_answers:
|
||||
since_you_mentioned_pretext = (
|
||||
f'Since your question contained "_{match_str}_"'
|
||||
)
|
||||
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
|
||||
formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}"
|
||||
formatted_answers.append(formatted_answer)
|
||||
answer_message = "\n\n".join(formatted_answers)
|
||||
|
||||
|
@@ -1371,6 +1371,7 @@ class StandardAnswer(Base):
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
|
||||
@@ -41,6 +42,7 @@ def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
@@ -55,6 +57,7 @@ def insert_standard_answer(
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
@@ -66,6 +69,7 @@ def update_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
@@ -84,6 +88,7 @@ def update_standard_answer(
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -181,31 +186,51 @@ def find_matching_standard_answers(
|
||||
id_in: list[int],
|
||||
query: str,
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
) -> list[tuple[StandardAnswer, str]]:
|
||||
"""
|
||||
Returns a list of tuples, where each tuple is a StandardAnswer definition matching
|
||||
the query and a string representing the match (either the regex match group or the
|
||||
set of keywords).
|
||||
|
||||
If `answer_instance.match_regex` is true, the definition is considered "matched"
|
||||
if the query matches the `answer_instance.keyword` using `re.search`.
|
||||
|
||||
Otherwise, the definition is considered "matched" if each space-delimited token
|
||||
in `keyword` exists in `query`.
|
||||
"""
|
||||
stmt = (
|
||||
select(StandardAnswer)
|
||||
.where(StandardAnswer.active.is_(True))
|
||||
.where(StandardAnswer.id.in_(id_in))
|
||||
)
|
||||
possible_standard_answers = db_session.scalars(stmt).all()
|
||||
possible_standard_answers: Sequence[StandardAnswer] = db_session.scalars(stmt).all()
|
||||
|
||||
matching_standard_answers: list[StandardAnswer] = []
|
||||
matching_standard_answers: list[tuple[StandardAnswer, str]] = []
|
||||
for standard_answer in possible_standard_answers:
|
||||
# Remove punctuation and split the keyword into individual words
|
||||
keyword_words = "".join(
|
||||
char
|
||||
for char in standard_answer.keyword.lower()
|
||||
if char not in string.punctuation
|
||||
).split()
|
||||
if standard_answer.match_regex:
|
||||
maybe_matches = re.search(standard_answer.keyword, query, re.IGNORECASE)
|
||||
if maybe_matches is not None:
|
||||
match_group = maybe_matches.group(0)
|
||||
matching_standard_answers.append((standard_answer, match_group))
|
||||
|
||||
# Remove punctuation and split the query into individual words
|
||||
query_words = "".join(
|
||||
char for char in query.lower() if char not in string.punctuation
|
||||
).split()
|
||||
else:
|
||||
# Remove punctuation and split the keyword into individual words
|
||||
keyword_words = "".join(
|
||||
char
|
||||
for char in standard_answer.keyword.lower()
|
||||
if char not in string.punctuation
|
||||
).split()
|
||||
|
||||
# Check if all of the keyword words are in the query words
|
||||
if all(word in query_words for word in keyword_words):
|
||||
matching_standard_answers.append(standard_answer)
|
||||
# Remove punctuation and split the query into individual words
|
||||
query_words = "".join(
|
||||
char for char in query.lower() if char not in string.punctuation
|
||||
).split()
|
||||
|
||||
# Check if all of the keyword words are in the query words
|
||||
if all(word in query_words for word in keyword_words):
|
||||
matching_standard_answers.append(
|
||||
(standard_answer, standard_answer.keyword)
|
||||
)
|
||||
|
||||
return matching_standard_answers
|
||||
|
||||
|
@@ -1,4 +1,6 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -140,6 +142,7 @@ class StandardAnswer(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
@@ -147,6 +150,7 @@ class StandardAnswer(BaseModel):
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
@@ -158,6 +162,7 @@ class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
@@ -168,6 +173,28 @@ class StandardAnswerCreationRequest(BaseModel):
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
    """Ensure `keyword` compiles as a regex when `match_regex` is enabled.

    Returns:
        The model instance unchanged when `match_regex` is false or when
        `keyword` compiles successfully with `re.compile`.

    Raises:
        ValueError: if `match_regex` is true and `keyword` is not a valid
            regular expression. The message includes the offending pattern
            (when `re.error` reports one) and the compiler's error message.
    """
    if not self.match_regex:
        # Plain keywords are free-form text; nothing to validate.
        return self

    try:
        re.compile(self.keyword)
    except re.error as err:
        # `re.error.pattern` may be str, bytes, or None; normalize to text.
        pattern = err.pattern
        if isinstance(pattern, bytes):
            pattern = pattern.decode()

        # Build the message from parts so a missing pattern does not leave
        # a doubled space in the middle of the sentence.
        message_parts = ["invalid regex pattern"]
        if pattern is not None:
            message_parts.append(f'r"{pattern}"')
        message_parts.append(f"in `keyword`: {err.msg}")
        raise ValueError(" ".join(message_parts)) from err

    return self
|
||||
|
||||
|
||||
class SlackBotTokens(BaseModel):
|
||||
bot_token: str
|
||||
|
@@ -33,6 +33,7 @@ def create_standard_answer(
|
||||
keyword=standard_answer_creation_request.keyword,
|
||||
answer=standard_answer_creation_request.answer,
|
||||
category_ids=standard_answer_creation_request.categories,
|
||||
match_regex=standard_answer_creation_request.match_regex,
|
||||
db_session=db_session,
|
||||
)
|
||||
return StandardAnswer.from_model(standard_answer_model)
|
||||
@@ -70,6 +71,7 @@ def patch_standard_answer(
|
||||
keyword=standard_answer_creation_request.keyword,
|
||||
answer=standard_answer_creation_request.answer,
|
||||
category_ids=standard_answer_creation_request.categories,
|
||||
match_regex=standard_answer_creation_request.match_regex,
|
||||
db_session=db_session,
|
||||
)
|
||||
return StandardAnswer.from_model(standard_answer_model)
|
||||
|
@@ -14,6 +14,7 @@ import {
|
||||
import {
|
||||
TextFormField,
|
||||
MarkdownFormField,
|
||||
BooleanFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import MultiSelectDropdown from "@/components/MultiSelectDropdown";
|
||||
|
||||
@@ -41,10 +42,13 @@ export const StandardAnswerCreationForm = ({
|
||||
categories: existingStandardAnswer
|
||||
? existingStandardAnswer.categories
|
||||
: [],
|
||||
matchRegex: existingStandardAnswer
|
||||
? existingStandardAnswer.match_regex
|
||||
: false,
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
keyword: Yup.string()
|
||||
.required("Keyword or phrase is required")
|
||||
.required("Keywords or pattern is required")
|
||||
.max(255)
|
||||
.min(1),
|
||||
answer: Yup.string().required("Answer is required").min(1),
|
||||
@@ -86,18 +90,34 @@ export const StandardAnswerCreationForm = ({
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="keyword"
|
||||
label="Keywords"
|
||||
tooltip="If all specified keywords are in the question, then we will respond with the answer below"
|
||||
placeholder="e.g. Wifi Password"
|
||||
autoCompleteDisabled={true}
|
||||
{values.matchRegex ? (
|
||||
<TextFormField
|
||||
name="keyword"
|
||||
label="Regex pattern"
|
||||
isCode
|
||||
tooltip="Triggers if the question matches this regex pattern (using Python `re.search()`)"
|
||||
placeholder="(?:it|support)\s*ticket"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="keyword"
|
||||
label="Keywords"
|
||||
tooltip="Triggers if the question contains all of these keywords, in any order."
|
||||
placeholder="it ticket"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
)}
|
||||
<BooleanFormField
|
||||
subtext="Match a regex pattern instead of an exact keyword"
|
||||
optional
|
||||
label="Match regex"
|
||||
name="matchRegex"
|
||||
/>
|
||||
<div className="w-full">
|
||||
<MarkdownFormField
|
||||
name="answer"
|
||||
label="Answer"
|
||||
placeholder="The answer in markdown"
|
||||
placeholder="The answer in Markdown. Example: If you need any help from the IT team, please email internalsupport@company.com"
|
||||
/>
|
||||
</div>
|
||||
<div className="w-4/12">
|
||||
|
@@ -6,6 +6,7 @@ export interface StandardAnswerCreationRequest {
|
||||
keyword: string;
|
||||
answer: string;
|
||||
categories: number[];
|
||||
matchRegex: boolean;
|
||||
}
|
||||
|
||||
const buildRequestBodyFromStandardAnswerCategoryCreationRequest = (
|
||||
@@ -48,6 +49,7 @@ const buildRequestBodyFromStandardAnswerCreationRequest = (
|
||||
keyword: request.keyword,
|
||||
answer: request.answer,
|
||||
categories: request.categories,
|
||||
match_regex: request.matchRegex,
|
||||
});
|
||||
};
|
||||
|
||||
|
@@ -26,6 +26,7 @@ import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
|
||||
import { FiTag } from "react-icons/fi";
|
||||
import { SelectedBubble } from "@/components/search/filtering/Filters";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
import { CustomCheckbox } from "@/components/CustomCheckbox";
|
||||
|
||||
const NUM_RESULTS_PER_PAGE = 10;
|
||||
|
||||
@@ -36,15 +37,23 @@ const RowTemplate = ({
|
||||
entries,
|
||||
}: {
|
||||
id: number;
|
||||
entries: [Displayable, Displayable, Displayable, Displayable, Displayable];
|
||||
entries: [
|
||||
Displayable,
|
||||
Displayable,
|
||||
Displayable,
|
||||
Displayable,
|
||||
Displayable,
|
||||
Displayable,
|
||||
];
|
||||
}) => {
|
||||
return (
|
||||
<TableRow key={id}>
|
||||
<TableCell className="w-1/24">{entries[0]}</TableCell>
|
||||
<TableCell className="w-2/12">{entries[1]}</TableCell>
|
||||
<TableCell className="w-2/12">{entries[2]}</TableCell>
|
||||
<TableCell className="w-7/12 overflow-auto">{entries[3]}</TableCell>
|
||||
<TableCell className="w-1/24">{entries[4]}</TableCell>
|
||||
<TableCell className="w-1/24">{entries[3]}</TableCell>
|
||||
<TableCell className="w-7/12 overflow-auto">{entries[4]}</TableCell>
|
||||
<TableCell className="w-1/24">{entries[5]}</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
};
|
||||
@@ -108,7 +117,15 @@ const StandardAnswersTableRow = ({
|
||||
<CategoryBubble key={category.id} name={category.name} />
|
||||
))}
|
||||
</div>,
|
||||
standardAnswer.keyword,
|
||||
<ReactMarkdown key={`keyword-${standardAnswer.id}`}>
|
||||
{standardAnswer.match_regex
|
||||
? `\`${standardAnswer.keyword}\``
|
||||
: standardAnswer.keyword}
|
||||
</ReactMarkdown>,
|
||||
<CustomCheckbox
|
||||
key={`match_regex-${standardAnswer.id}`}
|
||||
checked={standardAnswer.match_regex}
|
||||
/>,
|
||||
<ReactMarkdown
|
||||
key={`answer-${standardAnswer.id}`}
|
||||
className="prose"
|
||||
@@ -147,13 +164,15 @@ const StandardAnswersTable = ({
|
||||
const columns = [
|
||||
{ name: "", key: "edit" },
|
||||
{ name: "Categories", key: "category" },
|
||||
{ name: "Keyword/Phrase", key: "keyword" },
|
||||
{ name: "Keywords/Pattern", key: "keyword" },
|
||||
{ name: "Match regex?", key: "match_regex" },
|
||||
{ name: "Answer", key: "answer" },
|
||||
{ name: "", key: "delete" },
|
||||
];
|
||||
|
||||
const filteredStandardAnswers = standardAnswers.filter((standardAnswer) => {
|
||||
const { answer, id, categories, ...fieldsToSearch } = standardAnswer;
|
||||
const { answer, id, categories, match_regex, ...fieldsToSearch } =
|
||||
standardAnswer;
|
||||
const cleanedQuery = query.toLowerCase();
|
||||
const searchMatch = Object.values(fieldsToSearch).some((value) => {
|
||||
return value.toLowerCase().includes(cleanedQuery);
|
||||
@@ -285,7 +304,7 @@ const StandardAnswersTable = ({
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<RowTemplate id={0} entries={["", "", "", "", ""]} />
|
||||
<RowTemplate id={0} entries={["", "", "", "", "", ""]} />
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
@@ -161,6 +161,7 @@ export interface StandardAnswer {
|
||||
id: number;
|
||||
keyword: string;
|
||||
answer: string;
|
||||
match_regex: boolean;
|
||||
categories: StandardAnswerCategory[];
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user