Match any/all keywords in Standard Answers (#2443)

* migration: add column "match_any_keywords" to StandardAnswer

* Implement any/all keyword matching for standard answers

* Add match_any_keywords to non-searchable fields

* Remove stray print

* Simplify Slack messages for any and all cases

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
This commit is contained in:
trial-danswer
2024-09-13 22:28:07 -07:00
committed by GitHub
parent 974f85da66
commit 430c9a47d7
11 changed files with 138 additions and 17 deletions

View File

@@ -0,0 +1,35 @@
"""match_any_keywords flag for standard answers
Revision ID: 5c7fdadae813
Revises: efb35676026c
Create Date: 2024-09-13 18:52:59.256478
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5c7fdadae813"
down_revision = "efb35676026c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_any_keywords",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_any_keywords")
# ### end Alembic commands ###

View File

@@ -1372,6 +1372,7 @@ class StandardAnswer(Base):
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(

View File

@@ -41,6 +41,7 @@ def insert_standard_answer(
answer: str,
category_ids: list[int],
match_regex: bool,
match_any_keywords: bool,
db_session: Session,
) -> StandardAnswer:
existing_categories = fetch_standard_answer_categories_by_ids(
@@ -56,6 +57,7 @@ def insert_standard_answer(
categories=existing_categories,
active=True,
match_regex=match_regex,
match_any_keywords=match_any_keywords,
)
db_session.add(standard_answer)
db_session.commit()
@@ -68,6 +70,7 @@ def update_standard_answer(
answer: str,
category_ids: list[int],
match_regex: bool,
match_any_keywords: bool,
db_session: Session,
) -> StandardAnswer:
standard_answer = db_session.scalar(
@@ -87,6 +90,7 @@ def update_standard_answer(
standard_answer.answer = answer
standard_answer.categories = list(existing_categories)
standard_answer.match_regex = match_regex
standard_answer.match_any_keywords = match_any_keywords
db_session.commit()

View File

@@ -143,6 +143,7 @@ class StandardAnswer(BaseModel):
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
@@ -151,6 +152,7 @@ class StandardAnswer(BaseModel):
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
@@ -163,6 +165,7 @@ class StandardAnswerCreationRequest(BaseModel):
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
@@ -173,6 +176,15 @@ class StandardAnswerCreationRequest(BaseModel):
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:

View File

@@ -174,7 +174,7 @@ def _handle_standard_answers(
formatted_answers = []
for standard_answer, match_str in matching_standard_answers:
since_you_mentioned_pretext = (
f'Since your question contained "_{match_str}_"'
f'Since your question contains "_{match_str}_"'
)
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}"

View File

@@ -36,8 +36,8 @@ def find_matching_standard_answers(
If `answer_instance.match_regex` is true, the definition is considered "matched"
if the query matches the `answer_instance.keyword` using `re.search`.
Otherwise, the definition is considered "matched" if each space-delimited token
in `keyword` exists in `query`.
Otherwise, the definition is considered "matched" if the space-delimited tokens
in `keyword` exists in `query`, depending on the state of `match_any_keywords`
"""
stmt = (
select(StandardAnswer)
@@ -56,11 +56,13 @@ def find_matching_standard_answers(
else:
# Remove punctuation and split the keyword into individual words
keyword_words = "".join(
char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
keyword_words = set(
"".join(
char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
)
# Remove punctuation and split the query into individual words
query_words = "".join(
@@ -68,9 +70,18 @@ def find_matching_standard_answers(
).split()
# Check if all of the keyword words are in the query words
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(
(standard_answer, standard_answer.keyword)
)
if standard_answer.match_any_keywords:
for word in query_words:
if word in keyword_words:
matching_standard_answers.append((standard_answer, word))
break
else:
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(
(
standard_answer,
re.sub(r"\s+?", ", ", standard_answer.keyword),
)
)
return matching_standard_answers

View File

@@ -34,6 +34,7 @@ def create_standard_answer(
answer=standard_answer_creation_request.answer,
category_ids=standard_answer_creation_request.categories,
match_regex=standard_answer_creation_request.match_regex,
match_any_keywords=standard_answer_creation_request.match_any_keywords,
db_session=db_session,
)
return StandardAnswer.from_model(standard_answer_model)
@@ -72,6 +73,7 @@ def patch_standard_answer(
answer=standard_answer_creation_request.answer,
category_ids=standard_answer_creation_request.categories,
match_regex=standard_answer_creation_request.match_regex,
match_any_keywords=standard_answer_creation_request.match_any_keywords,
db_session=db_session,
)
return StandardAnswer.from_model(standard_answer_model)

View File

@@ -9,15 +9,25 @@ import * as Yup from "yup";
import {
createStandardAnswer,
createStandardAnswerCategory,
StandardAnswerCreationRequest,
updateStandardAnswer,
} from "./lib";
import {
TextFormField,
MarkdownFormField,
BooleanFormField,
SelectorFormField,
} from "@/components/admin/connectors/Field";
import MultiSelectDropdown from "@/components/MultiSelectDropdown";
function mapKeywordSelectToMatchAny(keywordSelect: "any" | "all"): boolean {
return keywordSelect == "any";
}
function mapMatchAnyToKeywordSelect(matchAny: boolean): "any" | "all" {
return matchAny ? "any" : "all";
}
export const StandardAnswerCreationForm = ({
standardAnswerCategories,
existingStandardAnswer,
@@ -45,6 +55,11 @@ export const StandardAnswerCreationForm = ({
matchRegex: existingStandardAnswer
? existingStandardAnswer.match_regex
: false,
matchAnyKeywords: existingStandardAnswer
? mapMatchAnyToKeywordSelect(
existingStandardAnswer.match_any_keywords
)
: "all",
}}
validationSchema={Yup.object().shape({
keyword: Yup.string()
@@ -59,8 +74,11 @@ export const StandardAnswerCreationForm = ({
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const cleanedValues = {
const cleanedValues: StandardAnswerCreationRequest = {
...values,
matchAnyKeywords: mapKeywordSelectToMatchAny(
values.matchAnyKeywords
),
categories: values.categories.map((category) => category.id),
};
@@ -98,11 +116,19 @@ export const StandardAnswerCreationForm = ({
tooltip="Triggers if the question matches this regex pattern (using Python `re.search()`)"
placeholder="(?:it|support)\s*ticket"
/>
) : values.matchAnyKeywords == "any" ? (
<TextFormField
name="keyword"
label="Any of these keywords, separated by spaces"
tooltip="A question must match these keywords in order to trigger the answer."
placeholder="ticket problem issue"
autoCompleteDisabled={true}
/>
) : (
<TextFormField
name="keyword"
label="Keywords"
tooltip="Triggers if the question contains all of these keywords, in any order."
label="All of these keywords, in any order, separated by spaces"
tooltip="A question must match these keywords in order to trigger the answer."
placeholder="it ticket"
autoCompleteDisabled={true}
/>
@@ -113,6 +139,27 @@ export const StandardAnswerCreationForm = ({
label="Match regex"
name="matchRegex"
/>
{values.matchRegex ? null : (
<SelectorFormField
defaultValue={`all`}
label="Keyword detection strategy"
subtext="Choose whether to require the user's question to contain any or all of the keywords above to show this answer."
name="matchAnyKeywords"
options={[
{
name: "All keywords",
value: "all",
},
{
name: "Any keywords",
value: "any",
},
]}
onSelect={(selected) => {
setFieldValue("matchAnyKeywords", selected);
}}
/>
)}
<div className="w-full">
<MarkdownFormField
name="answer"

View File

@@ -7,6 +7,7 @@ export interface StandardAnswerCreationRequest {
answer: string;
categories: number[];
matchRegex: boolean;
matchAnyKeywords: boolean;
}
const buildRequestBodyFromStandardAnswerCategoryCreationRequest = (
@@ -50,6 +51,7 @@ const buildRequestBodyFromStandardAnswerCreationRequest = (
answer: request.answer,
categories: request.categories,
match_regex: request.matchRegex,
match_any_keywords: request.matchAnyKeywords,
});
};

View File

@@ -171,8 +171,14 @@ const StandardAnswersTable = ({
];
const filteredStandardAnswers = standardAnswers.filter((standardAnswer) => {
const { answer, id, categories, match_regex, ...fieldsToSearch } =
standardAnswer;
const {
answer,
id,
categories,
match_regex,
match_any_keywords,
...fieldsToSearch
} = standardAnswer;
const cleanedQuery = query.toLowerCase();
const searchMatch = Object.values(fieldsToSearch).some((value) => {
return value.toLowerCase().includes(cleanedQuery);

View File

@@ -162,6 +162,7 @@ export interface StandardAnswer {
keyword: string;
answer: string;
match_regex: boolean;
match_any_keywords: boolean;
categories: StandardAnswerCategory[];
}