Support regex in standard answers (#2377)

* Support regex in standard answers

* fix mypy

* Add match_regex boolean column to StandardAnswer

* Add match_regex flag and validation to Pydantic models

* GET /manage/admin/standard-answer: add match_regex to create_standard_answer

* PATCH /manage/admin/standard-answer/🆔 add match_regex to update_standard_answer

* Add "Match Regex" toggle to standard answer form

* Decode error pattern in case it's bytes

* Refactor regex support to use match_regex flag instead of supplemental tuple

* Better error handling for invalid regexes

* Show "match regex" in table and style keywords appropriately

* Fix stale UI copy for non-"match_regex" branch

* Fix stale docstring in find_matching_standard_answers

* Update down_revision to reflect most recent migration

* Update UI copy

* Initial implementation of match group display

* Fix pydantic StandardAnswer vs SQLAlchemy StandardAnswer model usage

* Update docstring return type

* Fix missing key prop

---------

Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
This commit is contained in:
hj-danswer
2024-09-13 17:07:42 -07:00
committed by GitHub
parent da6e46ae75
commit 3cb00de6d4
10 changed files with 171 additions and 41 deletions

View File

@@ -0,0 +1,30 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 52a219fb5233
Create Date: 2024-09-11 13:55:46.101149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "efb35676026c"
down_revision = "0ebb1d516877"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_regex")
# ### end Alembic commands ###

View File

@@ -16,9 +16,10 @@ from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotConfig
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
from danswer.utils.logger import DanswerLoggingAdapter from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@@ -29,7 +30,7 @@ def oneoff_standard_answers(
message: str, message: str,
slack_bot_categories: list[str], slack_bot_categories: list[str],
db_session: Session, db_session: Session,
) -> list[StandardAnswer]: ) -> list[PydanticStandardAnswer]:
""" """
Respond to the user message if it matches any configured standard answers. Respond to the user message if it matches any configured standard answers.
@@ -50,7 +51,8 @@ def oneoff_standard_answers(
) )
server_standard_answers = [ server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers PydanticStandardAnswer.from_model(standard_answer_model)
for (standard_answer_model, _) in matching_standard_answers
] ]
return server_standard_answers return server_standard_answers
@@ -114,14 +116,15 @@ def handle_standard_answers(
usable_standard_answers = configured_standard_answers.difference( usable_standard_answers = configured_standard_answers.difference(
used_standard_answer_ids used_standard_answer_ids
) )
matching_standard_answers: list[tuple[StandardAnswerModel, str]] = []
if usable_standard_answers: if usable_standard_answers:
matching_standard_answers = find_matching_standard_answers( matching_standard_answers = find_matching_standard_answers(
query=query_msg.message, query=query_msg.message,
id_in=[standard_answer.id for standard_answer in usable_standard_answers], id_in=[standard_answer.id for standard_answer in usable_standard_answers],
db_session=db_session, db_session=db_session,
) )
else:
matching_standard_answers = []
if matching_standard_answers: if matching_standard_answers:
chat_session = create_chat_session( chat_session = create_chat_session(
db_session=db_session, db_session=db_session,
@@ -149,12 +152,12 @@ def handle_standard_answers(
) )
formatted_answers = [] formatted_answers = []
for standard_answer in matching_standard_answers: for standard_answer, match_str in matching_standard_answers:
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ") since_you_mentioned_pretext = (
formatted_answer = ( f'Since your question contained "_{match_str}_"'
f'Since you mentioned _"{standard_answer.keyword}"_, '
f"I thought this might be useful: \n\n{block_quotified_answer}"
) )
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}"
formatted_answers.append(formatted_answer) formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers) answer_message = "\n\n".join(formatted_answers)

View File

@@ -1371,6 +1371,7 @@ class StandardAnswer(Base):
keyword: Mapped[str] = mapped_column(String) keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String) answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean) active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
__table_args__ = ( __table_args__ = (
Index( Index(

View File

@@ -1,3 +1,4 @@
import re
import string import string
from collections.abc import Sequence from collections.abc import Sequence
@@ -41,6 +42,7 @@ def insert_standard_answer(
keyword: str, keyword: str,
answer: str, answer: str,
category_ids: list[int], category_ids: list[int],
match_regex: bool,
db_session: Session, db_session: Session,
) -> StandardAnswer: ) -> StandardAnswer:
existing_categories = fetch_standard_answer_categories_by_ids( existing_categories = fetch_standard_answer_categories_by_ids(
@@ -55,6 +57,7 @@ def insert_standard_answer(
answer=answer, answer=answer,
categories=existing_categories, categories=existing_categories,
active=True, active=True,
match_regex=match_regex,
) )
db_session.add(standard_answer) db_session.add(standard_answer)
db_session.commit() db_session.commit()
@@ -66,6 +69,7 @@ def update_standard_answer(
keyword: str, keyword: str,
answer: str, answer: str,
category_ids: list[int], category_ids: list[int],
match_regex: bool,
db_session: Session, db_session: Session,
) -> StandardAnswer: ) -> StandardAnswer:
standard_answer = db_session.scalar( standard_answer = db_session.scalar(
@@ -84,6 +88,7 @@ def update_standard_answer(
standard_answer.keyword = keyword standard_answer.keyword = keyword
standard_answer.answer = answer standard_answer.answer = answer
standard_answer.categories = list(existing_categories) standard_answer.categories = list(existing_categories)
standard_answer.match_regex = match_regex
db_session.commit() db_session.commit()
@@ -181,31 +186,51 @@ def find_matching_standard_answers(
id_in: list[int], id_in: list[int],
query: str, query: str,
db_session: Session, db_session: Session,
) -> list[StandardAnswer]: ) -> list[tuple[StandardAnswer, str]]:
"""
Returns a list of tuples, where each tuple is a StandardAnswer definition matching
the query and a string representing the match (either the regex match group or the
set of keywords).
If `answer_instance.match_regex` is true, the definition is considered "matched"
if the query matches the `answer_instance.keyword` using `re.search`.
Otherwise, the definition is considered "matched" if each space-delimited token
in `keyword` exists in `query`.
"""
stmt = ( stmt = (
select(StandardAnswer) select(StandardAnswer)
.where(StandardAnswer.active.is_(True)) .where(StandardAnswer.active.is_(True))
.where(StandardAnswer.id.in_(id_in)) .where(StandardAnswer.id.in_(id_in))
) )
possible_standard_answers = db_session.scalars(stmt).all() possible_standard_answers: Sequence[StandardAnswer] = db_session.scalars(stmt).all()
matching_standard_answers: list[StandardAnswer] = [] matching_standard_answers: list[tuple[StandardAnswer, str]] = []
for standard_answer in possible_standard_answers: for standard_answer in possible_standard_answers:
# Remove punctuation and split the keyword into individual words if standard_answer.match_regex:
keyword_words = "".join( maybe_matches = re.search(standard_answer.keyword, query, re.IGNORECASE)
char if maybe_matches is not None:
for char in standard_answer.keyword.lower() match_group = maybe_matches.group(0)
if char not in string.punctuation matching_standard_answers.append((standard_answer, match_group))
).split()
# Remove punctuation and split the query into individual words else:
query_words = "".join( # Remove punctuation and split the keyword into individual words
char for char in query.lower() if char not in string.punctuation keyword_words = "".join(
).split() char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
# Check if all of the keyword words are in the query words # Remove punctuation and split the query into individual words
if all(word in query_words for word in keyword_words): query_words = "".join(
matching_standard_answers.append(standard_answer) char for char in query.lower() if char not in string.punctuation
).split()
# Check if all of the keyword words are in the query words
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(
(standard_answer, standard_answer.keyword)
)
return matching_standard_answers return matching_standard_answers

View File

@@ -1,4 +1,6 @@
import re
from datetime import datetime from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from pydantic import BaseModel from pydantic import BaseModel
@@ -140,6 +142,7 @@ class StandardAnswer(BaseModel):
keyword: str keyword: str
answer: str answer: str
categories: list[StandardAnswerCategory] categories: list[StandardAnswerCategory]
match_regex: bool
@classmethod @classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer": def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
@@ -147,6 +150,7 @@ class StandardAnswer(BaseModel):
id=standard_answer_model.id, id=standard_answer_model.id,
keyword=standard_answer_model.keyword, keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer, answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
categories=[ categories=[
StandardAnswerCategory.from_model(standard_answer_category_model) StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories for standard_answer_category_model in standard_answer_model.categories
@@ -158,6 +162,7 @@ class StandardAnswerCreationRequest(BaseModel):
keyword: str keyword: str
answer: str answer: str
categories: list[int] categories: list[int]
match_regex: bool
@field_validator("categories", mode="before") @field_validator("categories", mode="before")
@classmethod @classmethod
@@ -168,6 +173,28 @@ class StandardAnswerCreationRequest(BaseModel):
) )
return value return value
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)
class SlackBotTokens(BaseModel): class SlackBotTokens(BaseModel):
bot_token: str bot_token: str

View File

@@ -33,6 +33,7 @@ def create_standard_answer(
keyword=standard_answer_creation_request.keyword, keyword=standard_answer_creation_request.keyword,
answer=standard_answer_creation_request.answer, answer=standard_answer_creation_request.answer,
category_ids=standard_answer_creation_request.categories, category_ids=standard_answer_creation_request.categories,
match_regex=standard_answer_creation_request.match_regex,
db_session=db_session, db_session=db_session,
) )
return StandardAnswer.from_model(standard_answer_model) return StandardAnswer.from_model(standard_answer_model)
@@ -70,6 +71,7 @@ def patch_standard_answer(
keyword=standard_answer_creation_request.keyword, keyword=standard_answer_creation_request.keyword,
answer=standard_answer_creation_request.answer, answer=standard_answer_creation_request.answer,
category_ids=standard_answer_creation_request.categories, category_ids=standard_answer_creation_request.categories,
match_regex=standard_answer_creation_request.match_regex,
db_session=db_session, db_session=db_session,
) )
return StandardAnswer.from_model(standard_answer_model) return StandardAnswer.from_model(standard_answer_model)

View File

@@ -14,6 +14,7 @@ import {
import { import {
TextFormField, TextFormField,
MarkdownFormField, MarkdownFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import MultiSelectDropdown from "@/components/MultiSelectDropdown"; import MultiSelectDropdown from "@/components/MultiSelectDropdown";
@@ -41,10 +42,13 @@ export const StandardAnswerCreationForm = ({
categories: existingStandardAnswer categories: existingStandardAnswer
? existingStandardAnswer.categories ? existingStandardAnswer.categories
: [], : [],
matchRegex: existingStandardAnswer
? existingStandardAnswer.match_regex
: false,
}} }}
validationSchema={Yup.object().shape({ validationSchema={Yup.object().shape({
keyword: Yup.string() keyword: Yup.string()
.required("Keyword or phrase is required") .required("Keywords or pattern is required")
.max(255) .max(255)
.min(1), .min(1),
answer: Yup.string().required("Answer is required").min(1), answer: Yup.string().required("Answer is required").min(1),
@@ -86,18 +90,34 @@ export const StandardAnswerCreationForm = ({
> >
{({ isSubmitting, values, setFieldValue }) => ( {({ isSubmitting, values, setFieldValue }) => (
<Form> <Form>
<TextFormField {values.matchRegex ? (
name="keyword" <TextFormField
label="Keywords" name="keyword"
tooltip="If all specified keywords are in the question, then we will respond with the answer below" label="Regex pattern"
placeholder="e.g. Wifi Password" isCode
autoCompleteDisabled={true} tooltip="Triggers if the question matches this regex pattern (using Python `re.search()`)"
placeholder="(?:it|support)\s*ticket"
/>
) : (
<TextFormField
name="keyword"
label="Keywords"
tooltip="Triggers if the question contains all of these keywords, in any order."
placeholder="it ticket"
autoCompleteDisabled={true}
/>
)}
<BooleanFormField
subtext="Match a regex pattern instead of an exact keyword"
optional
label="Match regex"
name="matchRegex"
/> />
<div className="w-full"> <div className="w-full">
<MarkdownFormField <MarkdownFormField
name="answer" name="answer"
label="Answer" label="Answer"
placeholder="The answer in markdown" placeholder="The answer in Markdown. Example: If you need any help from the IT team, please email internalsupport@company.com"
/> />
</div> </div>
<div className="w-4/12"> <div className="w-4/12">

View File

@@ -6,6 +6,7 @@ export interface StandardAnswerCreationRequest {
keyword: string; keyword: string;
answer: string; answer: string;
categories: number[]; categories: number[];
matchRegex: boolean;
} }
const buildRequestBodyFromStandardAnswerCategoryCreationRequest = ( const buildRequestBodyFromStandardAnswerCategoryCreationRequest = (
@@ -48,6 +49,7 @@ const buildRequestBodyFromStandardAnswerCreationRequest = (
keyword: request.keyword, keyword: request.keyword,
answer: request.answer, answer: request.answer,
categories: request.categories, categories: request.categories,
match_regex: request.matchRegex,
}); });
}; };

View File

@@ -26,6 +26,7 @@ import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
import { FiTag } from "react-icons/fi"; import { FiTag } from "react-icons/fi";
import { SelectedBubble } from "@/components/search/filtering/Filters"; import { SelectedBubble } from "@/components/search/filtering/Filters";
import { PageSelector } from "@/components/PageSelector"; import { PageSelector } from "@/components/PageSelector";
import { CustomCheckbox } from "@/components/CustomCheckbox";
const NUM_RESULTS_PER_PAGE = 10; const NUM_RESULTS_PER_PAGE = 10;
@@ -36,15 +37,23 @@ const RowTemplate = ({
entries, entries,
}: { }: {
id: number; id: number;
entries: [Displayable, Displayable, Displayable, Displayable, Displayable]; entries: [
Displayable,
Displayable,
Displayable,
Displayable,
Displayable,
Displayable,
];
}) => { }) => {
return ( return (
<TableRow key={id}> <TableRow key={id}>
<TableCell className="w-1/24">{entries[0]}</TableCell> <TableCell className="w-1/24">{entries[0]}</TableCell>
<TableCell className="w-2/12">{entries[1]}</TableCell> <TableCell className="w-2/12">{entries[1]}</TableCell>
<TableCell className="w-2/12">{entries[2]}</TableCell> <TableCell className="w-2/12">{entries[2]}</TableCell>
<TableCell className="w-7/12 overflow-auto">{entries[3]}</TableCell> <TableCell className="w-1/24">{entries[3]}</TableCell>
<TableCell className="w-1/24">{entries[4]}</TableCell> <TableCell className="w-7/12 overflow-auto">{entries[4]}</TableCell>
<TableCell className="w-1/24">{entries[5]}</TableCell>
</TableRow> </TableRow>
); );
}; };
@@ -108,7 +117,15 @@ const StandardAnswersTableRow = ({
<CategoryBubble key={category.id} name={category.name} /> <CategoryBubble key={category.id} name={category.name} />
))} ))}
</div>, </div>,
standardAnswer.keyword, <ReactMarkdown key={`keyword-${standardAnswer.id}`}>
{standardAnswer.match_regex
? `\`${standardAnswer.keyword}\``
: standardAnswer.keyword}
</ReactMarkdown>,
<CustomCheckbox
key={`match_regex-${standardAnswer.id}`}
checked={standardAnswer.match_regex}
/>,
<ReactMarkdown <ReactMarkdown
key={`answer-${standardAnswer.id}`} key={`answer-${standardAnswer.id}`}
className="prose" className="prose"
@@ -147,13 +164,15 @@ const StandardAnswersTable = ({
const columns = [ const columns = [
{ name: "", key: "edit" }, { name: "", key: "edit" },
{ name: "Categories", key: "category" }, { name: "Categories", key: "category" },
{ name: "Keyword/Phrase", key: "keyword" }, { name: "Keywords/Pattern", key: "keyword" },
{ name: "Match regex?", key: "match_regex" },
{ name: "Answer", key: "answer" }, { name: "Answer", key: "answer" },
{ name: "", key: "delete" }, { name: "", key: "delete" },
]; ];
const filteredStandardAnswers = standardAnswers.filter((standardAnswer) => { const filteredStandardAnswers = standardAnswers.filter((standardAnswer) => {
const { answer, id, categories, ...fieldsToSearch } = standardAnswer; const { answer, id, categories, match_regex, ...fieldsToSearch } =
standardAnswer;
const cleanedQuery = query.toLowerCase(); const cleanedQuery = query.toLowerCase();
const searchMatch = Object.values(fieldsToSearch).some((value) => { const searchMatch = Object.values(fieldsToSearch).some((value) => {
return value.toLowerCase().includes(cleanedQuery); return value.toLowerCase().includes(cleanedQuery);
@@ -285,7 +304,7 @@ const StandardAnswersTable = ({
/> />
)) ))
) : ( ) : (
<RowTemplate id={0} entries={["", "", "", "", ""]} /> <RowTemplate id={0} entries={["", "", "", "", "", ""]} />
)} )}
</TableBody> </TableBody>
</Table> </Table>

View File

@@ -161,6 +161,7 @@ export interface StandardAnswer {
id: number; id: number;
keyword: string; keyword: string;
answer: string; answer: string;
match_regex: boolean;
categories: StandardAnswerCategory[]; categories: StandardAnswerCategory[];
} }