diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index e4adca62a..81b123c84 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -213,7 +213,10 @@ def get_connector_indexing_status( for index_attempt in index_attempts: # don't consider index attempts where the connector has been deleted # or the credential has been deleted - if index_attempt.connector_id and index_attempt.credential_id: + if ( + index_attempt.connector_id is not None + and index_attempt.credential_id is not None + ): connector_credential_pair_to_index_attempts[ (index_attempt.connector_id, index_attempt.credential_id) ].append(index_attempt) diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 0cbf35f05..4b7493009 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -7,6 +7,7 @@ from typing import TypeVar from danswer.configs.constants import DocumentSource from danswer.connectors.models import InputType +from danswer.datastores.interfaces import IndexFilter from danswer.db.models import Connector from danswer.db.models import IndexingStatus from pydantic import BaseModel @@ -77,7 +78,7 @@ class QuestionRequest(BaseModel): query: str collection: str use_keyword: bool | None - filters: str | None # string of list[IndexFilter] + filters: list[IndexFilter] | None class SearchResponse(BaseModel): diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 199a67f47..977e3584b 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -27,13 +27,13 @@ logger = setup_logger() router = APIRouter() -@router.get("/semantic-search") +@router.post("/semantic-search") def semantic_search( - question: QuestionRequest = Depends(), user: User = Depends(current_user) + question: QuestionRequest, user: User = Depends(current_user) ) -> SearchResponse: query = question.query collection = question.collection - filters = json.loads(question.filters) if question.filters is not None else None + filters = question.filters logger.info(f"Received semantic search query: {query}") user_id = None if user is None else int(user.id) @@ -49,13 +49,13 @@ def semantic_search( return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=other_top_docs) -@router.get("/keyword-search", response_model=SearchResponse) +@router.post("/keyword-search") def keyword_search( - question: QuestionRequest = Depends(), user: User = Depends(current_user) + question: QuestionRequest, user: User = Depends(current_user) ) -> SearchResponse: query = question.query collection = question.collection - filters = json.loads(question.filters) if question.filters is not None else None + filters = question.filters logger.info(f"Received keyword search query: {query}") user_id = None if user is None else int(user.id) @@ -69,15 +69,15 @@ def keyword_search( return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=None) -@router.get("/direct-qa", response_model=QAResponse) +@router.post("/direct-qa") def direct_qa( - question: QuestionRequest = Depends(), user: User = Depends(current_user) + question: QuestionRequest, user: User = Depends(current_user) ) -> QAResponse: start_time = time.time() query = question.query collection = question.collection - filters = json.loads(question.filters) if question.filters is not None else None + filters = question.filters use_keyword = question.use_keyword logger.info(f"Received QA query: {query}") @@ -115,9 +115,9 @@ def direct_qa( ) -@router.get("/stream-direct-qa") +@router.post("/stream-direct-qa") def stream_direct_qa( - question: QuestionRequest = Depends(), user: User = Depends(current_user) + question: QuestionRequest, user: User = Depends(current_user) ) -> StreamingResponse: top_documents_key = "top_documents" unranked_top_docs_key = "unranked_top_documents" @@ -125,7 +125,7 @@ def stream_direct_qa( def stream_qa_portions() -> Generator[str, None, None]: query = question.query collection = question.collection - filters = json.loads(question.filters) if question.filters is not None else None + filters = question.filters use_keyword = question.use_keyword logger.info(f"Received QA query: {query}") diff --git a/web/src/app/page.tsx b/web/src/app/page.tsx index bf01562f9..aab5736ed 100644 --- a/web/src/app/page.tsx +++ b/web/src/app/page.tsx @@ -23,7 +23,7 @@ export default async function Home() {
-
+
diff --git a/web/src/components/search/Filters.tsx b/web/src/components/search/Filters.tsx new file mode 100644 index 000000000..3c6e82e70 --- /dev/null +++ b/web/src/components/search/Filters.tsx @@ -0,0 +1,59 @@ +import React from "react"; +import { Source } from "./interfaces"; +import { getSourceIcon } from "../source"; +import { Funnel } from "@phosphor-icons/react"; + +interface SourceSelectorProps { + selectedSources: Source[]; + setSelectedSources: React.Dispatch>; +} + +const sources: Source[] = [ + { displayName: "Google Drive", internalName: "google_drive" }, + { displayName: "Slack", internalName: "slack" }, + { displayName: "Confluence", internalName: "confluence" }, + { displayName: "Github PRs", internalName: "github" }, + { displayName: "Web", internalName: "web" }, +]; + +export function SourceSelector({ + selectedSources, + setSelectedSources, +}: SourceSelectorProps) { + const handleSelect = (source: Source) => { + setSelectedSources((prev: Source[]) => { + if (prev.includes(source)) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + return ( +
+
+

Filters

+ +
+ {sources.map((source) => ( +
handleSelect(source)} + > + {getSourceIcon(source.internalName, "16")} + + {source.displayName} + +
+ ))} +
+ ); +} diff --git a/web/src/components/search/SearchResultsDisplay.tsx b/web/src/components/search/SearchResultsDisplay.tsx index 52e73817a..107629d45 100644 --- a/web/src/components/search/SearchResultsDisplay.tsx +++ b/web/src/components/search/SearchResultsDisplay.tsx @@ -45,11 +45,7 @@ export const SearchResultsDisplay: React.FC = ({ } if (answer === null && documents === null && quotes === null) { - return ( -
- Something went wrong, please try again. -
- ); + return
No matching documents found.
; } const dedupedQuotes: Quote[] = []; diff --git a/web/src/components/search/SearchSection.tsx b/web/src/components/search/SearchSection.tsx index 0a2e80963..38aca87dc 100644 --- a/web/src/components/search/SearchSection.tsx +++ b/web/src/components/search/SearchSection.tsx @@ -4,6 +4,8 @@ import { useState } from "react"; import { SearchBar } from "./SearchBar"; import { SearchResultsDisplay } from "./SearchResultsDisplay"; import { Quote, Document, SearchResponse } from "./types"; +import { SourceSelector } from "./Filters"; +import { Source } from "./interfaces"; const initialSearchResponse: SearchResponse = { answer: null, @@ -55,24 +57,44 @@ const processRawChunkString = ( return [parsedChunkSections, currPartialChunk]; }; -const searchRequestStreamed = async ( - query: string, - updateCurrentAnswer: (val: string) => void, - updateQuotes: (quotes: Record) => void, - updateDocs: (docs: Document[]) => void -) => { - const url = new URL("/api/stream-direct-qa", window.location.origin); - const params = new URLSearchParams({ - query, - collection: "danswer_index", - }).toString(); - url.search = params; +interface SearchRequestStreamedArgs { + query: string; + sources: Source[]; + updateCurrentAnswer: (val: string) => void; + updateQuotes: (quotes: Record) => void; + updateDocs: (docs: Document[]) => void; +} +const searchRequestStreamed = async ({ + query, + sources, + updateCurrentAnswer, + updateQuotes, + updateDocs, +}: SearchRequestStreamedArgs) => { let answer = ""; let quotes: Record | null = null; let relevantDocuments: Document[] | null = null; try { - const response = await fetch(url); + const response = await fetch("/api/stream-direct-qa", { + method: "POST", + body: JSON.stringify({ + query, + collection: "danswer_index", + ...(sources.length > 0 + ? { + filters: [ + { + source_type: sources.map((source) => source.internalName), + }, + ], + } + : {}), + }), + headers: { + "Content-Type": "application/json", + }, + }); const reader = response.body?.getReader(); const decoder = new TextDecoder("utf-8"); @@ -139,49 +161,62 @@ const searchRequestStreamed = async ( }; export const SearchSection: React.FC<{}> = () => { + // Search const [searchResponse, setSearchResponse] = useState( null ); const [isFetching, setIsFetching] = useState(false); + // Filters + const [sources, setSources] = useState([]); + return ( - <> - { - setIsFetching(true); - setSearchResponse({ - answer: null, - quotes: null, - documents: null, - }); - searchRequestStreamed( - query, - (answer) => - setSearchResponse((prevState) => ({ - ...(prevState || initialSearchResponse), - answer, - })), - (quotes) => - setSearchResponse((prevState) => ({ - ...(prevState || initialSearchResponse), - quotes, - })), - (documents) => - setSearchResponse((prevState) => ({ - ...(prevState || initialSearchResponse), - documents, - })) - ).then(() => { - setIsFetching(false); - }); - }} - /> -
- +
+
- +
+ { + setIsFetching(true); + setSearchResponse({ + answer: null, + quotes: null, + documents: null, + }); + searchRequestStreamed({ + query, + sources, + updateCurrentAnswer: (answer) => + setSearchResponse((prevState) => ({ + ...(prevState || initialSearchResponse), + answer, + })), + updateQuotes: (quotes) => + setSearchResponse((prevState) => ({ + ...(prevState || initialSearchResponse), + quotes, + })), + updateDocs: (documents) => + setSearchResponse((prevState) => ({ + ...(prevState || initialSearchResponse), + documents, + })), + }).then(() => { + setIsFetching(false); + }); + }} + /> +
+ +
+
+
); }; diff --git a/web/src/components/search/interfaces.tsx b/web/src/components/search/interfaces.tsx new file mode 100644 index 000000000..4ce570c78 --- /dev/null +++ b/web/src/components/search/interfaces.tsx @@ -0,0 +1,6 @@ +import { ValidSources } from "@/lib/types"; + +export interface Source { + displayName: string; + internalName: ValidSources; +}