diff --git a/backend/danswer/db/tag.py b/backend/danswer/db/tag.py index 66418b948..688b8a112 100644 --- a/backend/danswer/db/tag.py +++ b/backend/danswer/db/tag.py @@ -1,5 +1,6 @@ from sqlalchemy import delete from sqlalchemy import func +from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.orm import Session @@ -107,18 +108,28 @@ def create_or_add_document_tag_list( def get_tags_by_value_prefix_for_source_types( + tag_key_prefix: str | None, tag_value_prefix: str | None, sources: list[DocumentSource] | None, + limit: int | None, db_session: Session, ) -> list[Tag]: query = select(Tag) - if tag_value_prefix: - query = query.where(Tag.tag_value.startswith(tag_value_prefix)) + if tag_key_prefix or tag_value_prefix: + conditions = [] + if tag_key_prefix: + conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%")) + if tag_value_prefix: + conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%")) + query = query.where(or_(*conditions)) if sources: query = query.where(Tag.source.in_(sources)) + if limit: + query = query.limit(limit) + result = db_session.execute(query) tags = result.scalars().all() diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 43192211b..a7505178b 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -88,6 +88,7 @@ def get_tags( # If this is empty or None, then tags for all sources are considered sources: list[DocumentSource] | None = None, allow_prefix: bool = True, # This is currently the only option + limit: int = 50, _: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> TagResponse: @@ -95,8 +96,10 @@ def get_tags( raise NotImplementedError("Cannot disable prefix match for now") db_tags = get_tags_by_value_prefix_for_source_types( + tag_key_prefix=match_pattern, tag_value_prefix=match_pattern, sources=sources, + limit=limit, db_session=db_session, ) server_tags = [ diff --git a/web/src/app/chat/modal/configuration/FiltersTab.tsx b/web/src/app/chat/modal/configuration/FiltersTab.tsx index 581c15cce..66e559ae5 100644 --- a/web/src/app/chat/modal/configuration/FiltersTab.tsx +++ b/web/src/app/chat/modal/configuration/FiltersTab.tsx @@ -1,7 +1,7 @@ import { useChatContext } from "@/components/context/ChatContext"; import { FilterManager } from "@/lib/hooks"; import { listSourceMetadata } from "@/lib/sources"; -import { useRef, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { DateRangePicker, DateRangePickerItem, @@ -12,23 +12,46 @@ import { getXDaysAgo } from "@/lib/dateUtils"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { Bubble } from "@/components/Bubble"; import { FiX } from "react-icons/fi"; +import { getValidTags } from "@/lib/tags/tagUtils"; +import debounce from "lodash/debounce"; +import { Tag } from "@/lib/types"; export function FiltersTab({ filterManager, }: { filterManager: FilterManager; }): JSX.Element { - const [filterValue, setFilterValue] = useState(""); - const inputRef = useRef(null); - const { availableSources, availableDocumentSets, availableTags } = useChatContext(); + const [filterValue, setFilterValue] = useState(""); + const [filteredTags, setFilteredTags] = useState(availableTags); + const inputRef = useRef(null); + const allSources = listSourceMetadata(); const availableSourceMetadata = allSources.filter((source) => availableSources.includes(source.internalName) ); + const debouncedFetchTags = useRef( + debounce(async (value: string) => { + if (value) { + const fetchedTags = await getValidTags(value); + setFilteredTags(fetchedTags); + } else { + setFilteredTags(availableTags); + } + }, 50) + ).current; + + useEffect(() => { + debouncedFetchTags(filterValue); + + return () => { + debouncedFetchTags.cancel(); + }; + }, [filterValue, availableTags, debouncedFetchTags]); + return (
@@ -210,17 +233,15 @@ export function FiltersTab({
- {availableTags.length > 0 ? ( - availableTags + {filteredTags.length > 0 ? ( + filteredTags .filter( (tag) => !filterManager.selectedTags.some( (selectedTag) => selectedTag.tag_key === tag.tag_key && selectedTag.tag_value === tag.tag_value - ) && - (tag.tag_key.includes(filterValue) || - tag.tag_value.includes(filterValue)) + ) ) .slice(0, 12) .map((tag) => ( diff --git a/web/src/components/search/filtering/TagFilter.tsx b/web/src/components/search/filtering/TagFilter.tsx index 73b17f5b8..ec3a7f38b 100644 --- a/web/src/components/search/filtering/TagFilter.tsx +++ b/web/src/components/search/filtering/TagFilter.tsx @@ -2,6 +2,8 @@ import { containsObject, objectsAreEquivalent } from "@/lib/contains"; import { Tag } from "@/lib/types"; import { useEffect, useRef, useState } from "react"; import { FiTag, FiX } from "react-icons/fi"; +import debounce from "lodash/debounce"; +import { getValidTags } from "@/lib/tags/tagUtils"; export function TagFilter({ tags, @@ -14,6 +16,7 @@ export function TagFilter({ }) { const [filterValue, setFilterValue] = useState(""); const [tagOptionsAreVisible, setTagOptionsAreVisible] = useState(false); + const [filteredTags, setFilteredTags] = useState(tags); const inputRef = useRef(null); const popupRef = useRef(null); @@ -45,14 +48,28 @@ export function TagFilter({ }; }, []); - const filterValueLower = filterValue.toLowerCase(); - const filteredTags = filterValueLower - ? tags.filter( - (tags) => - tags.tag_value.toLowerCase().startsWith(filterValueLower) || - tags.tag_key.toLowerCase().startsWith(filterValueLower) - ) - : tags; + const debouncedFetchTags = useRef( + debounce(async (value: string) => { + if (value) { + const fetchedTags = await getValidTags(value); + setFilteredTags(fetchedTags); + } else { + setFilteredTags(tags); + } + }, 50) + ).current; + + useEffect(() => { + debouncedFetchTags(filterValue); + + return () => { + debouncedFetchTags.cancel(); + }; + }, [filterValue, tags, debouncedFetchTags]); + + const handleFilterChange = (event: React.ChangeEvent) => { + setFilterValue(event.target.value); + }; return (
@@ -61,7 +78,7 @@ export function TagFilter({ className="w-full border border-border py-0.5 px-2 rounded text-sm h-8" placeholder="Find a tag" value={filterValue} - onChange={(event) => setFilterValue(event.target.value)} + onChange={handleFilterChange} onFocus={() => setTagOptionsAreVisible(true)} /> {selectedTags.length > 0 && ( diff --git a/web/src/lib/tags/tagUtils.ts b/web/src/lib/tags/tagUtils.ts new file mode 100644 index 000000000..6f2ac5188 --- /dev/null +++ b/web/src/lib/tags/tagUtils.ts @@ -0,0 +1,23 @@ +import { Tag } from "../types"; + +export async function getValidTags( + matchPattern: string | null = null, + sources: string[] | null = null +): Promise { + const params = new URLSearchParams(); + if (matchPattern) params.append("match_pattern", matchPattern); + if (sources) sources.forEach((source) => params.append("sources", source)); + + const response = await fetch(`/api/query/valid-tags?${params.toString()}`, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + throw new Error("Failed to fetch valid tags"); + } + + return (await response.json()).tags; +}