mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Match any/all keywords in Standard Answers (#2443)
* migration: add column "match_any_keywords" to StandardAnswer * Implement any/all keyword matching for standard answers * Add match_any_keywords to non-searchable fields * Remove stray print * Simplify Slack messages for any and all cases --------- Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
"""match_any_keywords flag for standard answers
|
||||
|
||||
Revision ID: 5c7fdadae813
|
||||
Revises: efb35676026c
|
||||
Create Date: 2024-09-13 18:52:59.256478
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c7fdadae813"
|
||||
down_revision = "efb35676026c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the ``match_any_keywords`` boolean column to ``standard_answer``.

    The column is declared NOT NULL with a server-side default of false,
    which gives the database a value to use for rows that already exist.
    """
    match_any_keywords_column = sa.Column(
        "match_any_keywords",
        sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    )
    op.add_column("standard_answer", match_any_keywords_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``match_any_keywords`` column from ``standard_answer``."""
    op.drop_column(
        table_name="standard_answer",
        column_name="match_any_keywords",
    )
|
@@ -1372,6 +1372,7 @@ class StandardAnswer(Base):
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
|
@@ -41,6 +41,7 @@ def insert_standard_answer(
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
@@ -56,6 +57,7 @@ def insert_standard_answer(
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
match_any_keywords=match_any_keywords,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
@@ -68,6 +70,7 @@ def update_standard_answer(
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
@@ -87,6 +90,7 @@ def update_standard_answer(
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
standard_answer.match_any_keywords = match_any_keywords
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
@@ -143,6 +143,7 @@ class StandardAnswer(BaseModel):
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
@@ -151,6 +152,7 @@ class StandardAnswer(BaseModel):
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
@@ -163,6 +165,7 @@ class StandardAnswerCreationRequest(BaseModel):
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
@@ -173,6 +176,15 @@ class StandardAnswerCreationRequest(BaseModel):
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
    """Reject requests that enable both regex mode and match-any-keywords.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: If ``match_regex`` and ``match_any_keywords`` are both
            true.
    """
    regex_combined_with_match_any = self.match_regex and self.match_any_keywords
    if not regex_combined_with_match_any:
        return self

    raise ValueError("Can only match any keywords in keyword mode, not regex mode")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
|
@@ -174,7 +174,7 @@ def _handle_standard_answers(
|
||||
formatted_answers = []
|
||||
for standard_answer, match_str in matching_standard_answers:
|
||||
since_you_mentioned_pretext = (
|
||||
f'Since your question contained "_{match_str}_"'
|
||||
f'Since your question contains "_{match_str}_"'
|
||||
)
|
||||
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
|
||||
formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}"
|
||||
|
@@ -36,8 +36,8 @@ def find_matching_standard_answers(
|
||||
If `answer_instance.match_regex` is true, the definition is considered "matched"
|
||||
if the query matches the `answer_instance.keyword` using `re.search`.
|
||||
|
||||
Otherwise, the definition is considered "matched" if each space-delimited token
|
||||
in `keyword` exists in `query`.
|
||||
Otherwise, the definition is considered "matched" if the space-delimited tokens
|
||||
in `keyword` exist in `query`: any one token suffices when `match_any_keywords` is true, otherwise every token is required.
|
||||
"""
|
||||
stmt = (
|
||||
select(StandardAnswer)
|
||||
@@ -56,11 +56,13 @@ def find_matching_standard_answers(
|
||||
|
||||
else:
|
||||
# Remove punctuation and split the keyword into individual words
|
||||
keyword_words = "".join(
|
||||
char
|
||||
for char in standard_answer.keyword.lower()
|
||||
if char not in string.punctuation
|
||||
).split()
|
||||
keyword_words = set(
|
||||
"".join(
|
||||
char
|
||||
for char in standard_answer.keyword.lower()
|
||||
if char not in string.punctuation
|
||||
).split()
|
||||
)
|
||||
|
||||
# Remove punctuation and split the query into individual words
|
||||
query_words = "".join(
|
||||
@@ -68,9 +70,18 @@ def find_matching_standard_answers(
|
||||
).split()
|
||||
|
||||
# Check if all of the keyword words are in the query words
|
||||
if all(word in query_words for word in keyword_words):
|
||||
matching_standard_answers.append(
|
||||
(standard_answer, standard_answer.keyword)
|
||||
)
|
||||
if standard_answer.match_any_keywords:
|
||||
for word in query_words:
|
||||
if word in keyword_words:
|
||||
matching_standard_answers.append((standard_answer, word))
|
||||
break
|
||||
else:
|
||||
if all(word in query_words for word in keyword_words):
|
||||
matching_standard_answers.append(
|
||||
(
|
||||
standard_answer,
|
||||
re.sub(r"\s+?", ", ", standard_answer.keyword),
|
||||
)
|
||||
)
|
||||
|
||||
return matching_standard_answers
|
||||
|
@@ -34,6 +34,7 @@ def create_standard_answer(
|
||||
answer=standard_answer_creation_request.answer,
|
||||
category_ids=standard_answer_creation_request.categories,
|
||||
match_regex=standard_answer_creation_request.match_regex,
|
||||
match_any_keywords=standard_answer_creation_request.match_any_keywords,
|
||||
db_session=db_session,
|
||||
)
|
||||
return StandardAnswer.from_model(standard_answer_model)
|
||||
@@ -72,6 +73,7 @@ def patch_standard_answer(
|
||||
answer=standard_answer_creation_request.answer,
|
||||
category_ids=standard_answer_creation_request.categories,
|
||||
match_regex=standard_answer_creation_request.match_regex,
|
||||
match_any_keywords=standard_answer_creation_request.match_any_keywords,
|
||||
db_session=db_session,
|
||||
)
|
||||
return StandardAnswer.from_model(standard_answer_model)
|
||||
|
@@ -9,15 +9,25 @@ import * as Yup from "yup";
|
||||
import {
|
||||
createStandardAnswer,
|
||||
createStandardAnswerCategory,
|
||||
StandardAnswerCreationRequest,
|
||||
updateStandardAnswer,
|
||||
} from "./lib";
|
||||
import {
|
||||
TextFormField,
|
||||
MarkdownFormField,
|
||||
BooleanFormField,
|
||||
SelectorFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import MultiSelectDropdown from "@/components/MultiSelectDropdown";
|
||||
|
||||
/**
 * Convert the keyword-strategy selection used by the form into the boolean
 * `matchAnyKeywords` flag.
 *
 * @param keywordSelect - The selected strategy, either "any" or "all".
 * @returns `true` for "any", `false` for "all".
 */
function mapKeywordSelectToMatchAny(keywordSelect: "any" | "all"): boolean {
  // Strict equality: both operands are strings, so no type coercion is wanted.
  return keywordSelect === "any";
}
|
||||
|
||||
/**
 * Convert the boolean match-any flag into the keyword-strategy value used by
 * the form selector.
 *
 * @param matchAny - Whether matching any single keyword is enough.
 * @returns "any" when `matchAny` is true, otherwise "all".
 */
function mapMatchAnyToKeywordSelect(matchAny: boolean): "any" | "all" {
  if (matchAny) {
    return "any";
  }
  return "all";
}
|
||||
|
||||
export const StandardAnswerCreationForm = ({
|
||||
standardAnswerCategories,
|
||||
existingStandardAnswer,
|
||||
@@ -45,6 +55,11 @@ export const StandardAnswerCreationForm = ({
|
||||
matchRegex: existingStandardAnswer
|
||||
? existingStandardAnswer.match_regex
|
||||
: false,
|
||||
matchAnyKeywords: existingStandardAnswer
|
||||
? mapMatchAnyToKeywordSelect(
|
||||
existingStandardAnswer.match_any_keywords
|
||||
)
|
||||
: "all",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
keyword: Yup.string()
|
||||
@@ -59,8 +74,11 @@ export const StandardAnswerCreationForm = ({
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
const cleanedValues = {
|
||||
const cleanedValues: StandardAnswerCreationRequest = {
|
||||
...values,
|
||||
matchAnyKeywords: mapKeywordSelectToMatchAny(
|
||||
values.matchAnyKeywords
|
||||
),
|
||||
categories: values.categories.map((category) => category.id),
|
||||
};
|
||||
|
||||
@@ -98,11 +116,19 @@ export const StandardAnswerCreationForm = ({
|
||||
tooltip="Triggers if the question matches this regex pattern (using Python `re.search()`)"
|
||||
placeholder="(?:it|support)\s*ticket"
|
||||
/>
|
||||
) : values.matchAnyKeywords == "any" ? (
|
||||
<TextFormField
|
||||
name="keyword"
|
||||
label="Any of these keywords, separated by spaces"
|
||||
tooltip="A question must match these keywords in order to trigger the answer."
|
||||
placeholder="ticket problem issue"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="keyword"
|
||||
label="Keywords"
|
||||
tooltip="Triggers if the question contains all of these keywords, in any order."
|
||||
label="All of these keywords, in any order, separated by spaces"
|
||||
tooltip="A question must match these keywords in order to trigger the answer."
|
||||
placeholder="it ticket"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
@@ -113,6 +139,27 @@ export const StandardAnswerCreationForm = ({
|
||||
label="Match regex"
|
||||
name="matchRegex"
|
||||
/>
|
||||
{values.matchRegex ? null : (
|
||||
<SelectorFormField
|
||||
defaultValue={`all`}
|
||||
label="Keyword detection strategy"
|
||||
subtext="Choose whether to require the user's question to contain any or all of the keywords above to show this answer."
|
||||
name="matchAnyKeywords"
|
||||
options={[
|
||||
{
|
||||
name: "All keywords",
|
||||
value: "all",
|
||||
},
|
||||
{
|
||||
name: "Any keywords",
|
||||
value: "any",
|
||||
},
|
||||
]}
|
||||
onSelect={(selected) => {
|
||||
setFieldValue("matchAnyKeywords", selected);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<div className="w-full">
|
||||
<MarkdownFormField
|
||||
name="answer"
|
||||
|
@@ -7,6 +7,7 @@ export interface StandardAnswerCreationRequest {
|
||||
answer: string;
|
||||
categories: number[];
|
||||
matchRegex: boolean;
|
||||
matchAnyKeywords: boolean;
|
||||
}
|
||||
|
||||
const buildRequestBodyFromStandardAnswerCategoryCreationRequest = (
|
||||
@@ -50,6 +51,7 @@ const buildRequestBodyFromStandardAnswerCreationRequest = (
|
||||
answer: request.answer,
|
||||
categories: request.categories,
|
||||
match_regex: request.matchRegex,
|
||||
match_any_keywords: request.matchAnyKeywords,
|
||||
});
|
||||
};
|
||||
|
||||
|
@@ -171,8 +171,14 @@ const StandardAnswersTable = ({
|
||||
];
|
||||
|
||||
const filteredStandardAnswers = standardAnswers.filter((standardAnswer) => {
|
||||
const { answer, id, categories, match_regex, ...fieldsToSearch } =
|
||||
standardAnswer;
|
||||
const {
|
||||
answer,
|
||||
id,
|
||||
categories,
|
||||
match_regex,
|
||||
match_any_keywords,
|
||||
...fieldsToSearch
|
||||
} = standardAnswer;
|
||||
const cleanedQuery = query.toLowerCase();
|
||||
const searchMatch = Object.values(fieldsToSearch).some((value) => {
|
||||
return value.toLowerCase().includes(cleanedQuery);
|
||||
|
@@ -162,6 +162,7 @@ export interface StandardAnswer {
|
||||
keyword: string;
|
||||
answer: string;
|
||||
match_regex: boolean;
|
||||
match_any_keywords: boolean;
|
||||
categories: StandardAnswerCategory[];
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user