Support regex in standard answers (#2377)

* Support regex in standard answers

* fix mypy

* Add match_regex boolean column to StandardAnswer

* Add match_regex flag and validation to Pydantic models

* GET /manage/admin/standard-answer: add match_regex to create_standard_answer

* PATCH /manage/admin/standard-answer/🆔 add match_regex to update_standard_answer

* Add "Match Regex" toggle to standard answer form

* Decode error pattern in case it's bytes

* Refactor regex support to use match_regex flag instead of supplemental tuple

* Better error handling for invalid regexes

* Show "match regex" in table and style keywords appropriately

* Fix stale UI copy for non-"match_regex" branch

* Fix stale docstring in find_matching_standard_answers

* Update down_revision to reflect most recent migration

* Update UI copy

* Initial implementation of match group display

* Fix pydantic StandardAnswer vs SQLAlchemy StandardAnswer model usage

* Update docstring return type

* Fix missing key prop

---------

Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
This commit is contained in:
hj-danswer
2024-09-13 17:07:42 -07:00
committed by GitHub
parent da6e46ae75
commit 3cb00de6d4
10 changed files with 171 additions and 41 deletions

View File

@@ -0,0 +1,30 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 52a219fb5233
Create Date: 2024-09-11 13:55:46.101149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "efb35676026c"
down_revision = "0ebb1d516877"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_regex")
# ### end Alembic commands ###

View File

@@ -16,9 +16,10 @@ from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
@@ -29,7 +30,7 @@ def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
) -> list[PydanticStandardAnswer]:
"""
Respond to the user message if it matches any configured standard answers.
@@ -50,7 +51,8 @@ def oneoff_standard_answers(
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
PydanticStandardAnswer.from_model(standard_answer_model)
for (standard_answer_model, _) in matching_standard_answers
]
return server_standard_answers
@@ -114,14 +116,15 @@ def handle_standard_answers(
usable_standard_answers = configured_standard_answers.difference(
used_standard_answer_ids
)
matching_standard_answers: list[tuple[StandardAnswerModel, str]] = []
if usable_standard_answers:
matching_standard_answers = find_matching_standard_answers(
query=query_msg.message,
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
db_session=db_session,
)
else:
matching_standard_answers = []
if matching_standard_answers:
chat_session = create_chat_session(
db_session=db_session,
@@ -149,12 +152,12 @@ def handle_standard_answers(
)
formatted_answers = []
for standard_answer in matching_standard_answers:
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = (
f'Since you mentioned _"{standard_answer.keyword}"_, '
f"I thought this might be useful: \n\n{block_quotified_answer}"
for standard_answer, match_str in matching_standard_answers:
since_you_mentioned_pretext = (
f'Since your question contained "_{match_str}_"'
)
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}"
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)

View File

@@ -1371,6 +1371,7 @@ class StandardAnswer(Base):
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(

View File

@@ -1,3 +1,4 @@
import re
import string
from collections.abc import Sequence
@@ -41,6 +42,7 @@ def insert_standard_answer(
keyword: str,
answer: str,
category_ids: list[int],
match_regex: bool,
db_session: Session,
) -> StandardAnswer:
existing_categories = fetch_standard_answer_categories_by_ids(
@@ -55,6 +57,7 @@ def insert_standard_answer(
answer=answer,
categories=existing_categories,
active=True,
match_regex=match_regex,
)
db_session.add(standard_answer)
db_session.commit()
@@ -66,6 +69,7 @@ def update_standard_answer(
keyword: str,
answer: str,
category_ids: list[int],
match_regex: bool,
db_session: Session,
) -> StandardAnswer:
standard_answer = db_session.scalar(
@@ -84,6 +88,7 @@ def update_standard_answer(
standard_answer.keyword = keyword
standard_answer.answer = answer
standard_answer.categories = list(existing_categories)
standard_answer.match_regex = match_regex
db_session.commit()
@@ -181,31 +186,51 @@ def find_matching_standard_answers(
id_in: list[int],
query: str,
db_session: Session,
) -> list[StandardAnswer]:
) -> list[tuple[StandardAnswer, str]]:
"""
Returns a list of tuples, where each tuple is a StandardAnswer definition matching
the query and a string representing the match (either the regex match group or the
set of keywords).
If `answer_instance.match_regex` is true, the definition is considered "matched"
if the query matches the `answer_instance.keyword` using `re.search`.
Otherwise, the definition is considered "matched" if each space-delimited token
in `keyword` exists in `query`.
"""
stmt = (
select(StandardAnswer)
.where(StandardAnswer.active.is_(True))
.where(StandardAnswer.id.in_(id_in))
)
possible_standard_answers = db_session.scalars(stmt).all()
possible_standard_answers: Sequence[StandardAnswer] = db_session.scalars(stmt).all()
matching_standard_answers: list[StandardAnswer] = []
matching_standard_answers: list[tuple[StandardAnswer, str]] = []
for standard_answer in possible_standard_answers:
# Remove punctuation and split the keyword into individual words
keyword_words = "".join(
char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
if standard_answer.match_regex:
maybe_matches = re.search(standard_answer.keyword, query, re.IGNORECASE)
if maybe_matches is not None:
match_group = maybe_matches.group(0)
matching_standard_answers.append((standard_answer, match_group))
# Remove punctuation and split the query into individual words
query_words = "".join(
char for char in query.lower() if char not in string.punctuation
).split()
else:
# Remove punctuation and split the keyword into individual words
keyword_words = "".join(
char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
# Check if all of the keyword words are in the query words
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(standard_answer)
# Remove punctuation and split the query into individual words
query_words = "".join(
char for char in query.lower() if char not in string.punctuation
).split()
# Check if all of the keyword words are in the query words
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(
(standard_answer, standard_answer.keyword)
)
return matching_standard_answers

View File

@@ -1,4 +1,6 @@
import re
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -140,6 +142,7 @@ class StandardAnswer(BaseModel):
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
@@ -147,6 +150,7 @@ class StandardAnswer(BaseModel):
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
@@ -158,6 +162,7 @@ class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
@field_validator("categories", mode="before")
@classmethod
@@ -168,6 +173,28 @@ class StandardAnswerCreationRequest(BaseModel):
)
return value
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)
class SlackBotTokens(BaseModel):
bot_token: str

View File

@@ -33,6 +33,7 @@ def create_standard_answer(
keyword=standard_answer_creation_request.keyword,
answer=standard_answer_creation_request.answer,
category_ids=standard_answer_creation_request.categories,
match_regex=standard_answer_creation_request.match_regex,
db_session=db_session,
)
return StandardAnswer.from_model(standard_answer_model)
@@ -70,6 +71,7 @@ def patch_standard_answer(
keyword=standard_answer_creation_request.keyword,
answer=standard_answer_creation_request.answer,
category_ids=standard_answer_creation_request.categories,
match_regex=standard_answer_creation_request.match_regex,
db_session=db_session,
)
return StandardAnswer.from_model(standard_answer_model)

View File

@@ -14,6 +14,7 @@ import {
import {
TextFormField,
MarkdownFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import MultiSelectDropdown from "@/components/MultiSelectDropdown";
@@ -41,10 +42,13 @@ export const StandardAnswerCreationForm = ({
categories: existingStandardAnswer
? existingStandardAnswer.categories
: [],
matchRegex: existingStandardAnswer
? existingStandardAnswer.match_regex
: false,
}}
validationSchema={Yup.object().shape({
keyword: Yup.string()
.required("Keyword or phrase is required")
.required("Keywords or pattern is required")
.max(255)
.min(1),
answer: Yup.string().required("Answer is required").min(1),
@@ -86,18 +90,34 @@ export const StandardAnswerCreationForm = ({
>
{({ isSubmitting, values, setFieldValue }) => (
<Form>
<TextFormField
name="keyword"
label="Keywords"
tooltip="If all specified keywords are in the question, then we will respond with the answer below"
placeholder="e.g. Wifi Password"
autoCompleteDisabled={true}
{values.matchRegex ? (
<TextFormField
name="keyword"
label="Regex pattern"
isCode
tooltip="Triggers if the question matches this regex pattern (using Python `re.search()`)"
placeholder="(?:it|support)\s*ticket"
/>
) : (
<TextFormField
name="keyword"
label="Keywords"
tooltip="Triggers if the question contains all of these keywords, in any order."
placeholder="it ticket"
autoCompleteDisabled={true}
/>
)}
<BooleanFormField
subtext="Match a regex pattern instead of an exact keyword"
optional
label="Match regex"
name="matchRegex"
/>
<div className="w-full">
<MarkdownFormField
name="answer"
label="Answer"
placeholder="The answer in markdown"
placeholder="The answer in Markdown. Example: If you need any help from the IT team, please email internalsupport@company.com"
/>
</div>
<div className="w-4/12">

View File

@@ -6,6 +6,7 @@ export interface StandardAnswerCreationRequest {
keyword: string;
answer: string;
categories: number[];
matchRegex: boolean;
}
const buildRequestBodyFromStandardAnswerCategoryCreationRequest = (
@@ -48,6 +49,7 @@ const buildRequestBodyFromStandardAnswerCreationRequest = (
keyword: request.keyword,
answer: request.answer,
categories: request.categories,
match_regex: request.matchRegex,
});
};

View File

@@ -26,6 +26,7 @@ import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
import { FiTag } from "react-icons/fi";
import { SelectedBubble } from "@/components/search/filtering/Filters";
import { PageSelector } from "@/components/PageSelector";
import { CustomCheckbox } from "@/components/CustomCheckbox";
const NUM_RESULTS_PER_PAGE = 10;
@@ -36,15 +37,23 @@ const RowTemplate = ({
entries,
}: {
id: number;
entries: [Displayable, Displayable, Displayable, Displayable, Displayable];
entries: [
Displayable,
Displayable,
Displayable,
Displayable,
Displayable,
Displayable,
];
}) => {
return (
<TableRow key={id}>
<TableCell className="w-1/24">{entries[0]}</TableCell>
<TableCell className="w-2/12">{entries[1]}</TableCell>
<TableCell className="w-2/12">{entries[2]}</TableCell>
<TableCell className="w-7/12 overflow-auto">{entries[3]}</TableCell>
<TableCell className="w-1/24">{entries[4]}</TableCell>
<TableCell className="w-1/24">{entries[3]}</TableCell>
<TableCell className="w-7/12 overflow-auto">{entries[4]}</TableCell>
<TableCell className="w-1/24">{entries[5]}</TableCell>
</TableRow>
);
};
@@ -108,7 +117,15 @@ const StandardAnswersTableRow = ({
<CategoryBubble key={category.id} name={category.name} />
))}
</div>,
standardAnswer.keyword,
<ReactMarkdown key={`keyword-${standardAnswer.id}`}>
{standardAnswer.match_regex
? `\`${standardAnswer.keyword}\``
: standardAnswer.keyword}
</ReactMarkdown>,
<CustomCheckbox
key={`match_regex-${standardAnswer.id}`}
checked={standardAnswer.match_regex}
/>,
<ReactMarkdown
key={`answer-${standardAnswer.id}`}
className="prose"
@@ -147,13 +164,15 @@ const StandardAnswersTable = ({
const columns = [
{ name: "", key: "edit" },
{ name: "Categories", key: "category" },
{ name: "Keyword/Phrase", key: "keyword" },
{ name: "Keywords/Pattern", key: "keyword" },
{ name: "Match regex?", key: "match_regex" },
{ name: "Answer", key: "answer" },
{ name: "", key: "delete" },
];
const filteredStandardAnswers = standardAnswers.filter((standardAnswer) => {
const { answer, id, categories, ...fieldsToSearch } = standardAnswer;
const { answer, id, categories, match_regex, ...fieldsToSearch } =
standardAnswer;
const cleanedQuery = query.toLowerCase();
const searchMatch = Object.values(fieldsToSearch).some((value) => {
return value.toLowerCase().includes(cleanedQuery);
@@ -285,7 +304,7 @@ const StandardAnswersTable = ({
/>
))
) : (
<RowTemplate id={0} entries={["", "", "", "", ""]} />
<RowTemplate id={0} entries={["", "", "", "", "", ""]} />
)}
</TableBody>
</Table>

View File

@@ -161,6 +161,7 @@ export interface StandardAnswer {
id: number;
keyword: string;
answer: string;
match_regex: boolean;
categories: StandardAnswerCategory[];
}