Quote loading UI + adding back period to end of answer + adding custom logo (#55)

* Logo

* Add spinners + some small housekeeping on the backend
This commit is contained in:
Chris Weaver
2023-05-16 20:14:06 -07:00
committed by GitHub
parent 821df50fa9
commit 494514dc68
9 changed files with 117 additions and 56 deletions

View File

@@ -1,4 +1,5 @@
import abc import abc
from collections.abc import Generator
from typing import Any from typing import Any
from danswer.chunking.models import InferenceChunk from danswer.chunking.models import InferenceChunk
@@ -18,5 +19,5 @@ class QAModel:
self, self,
query: str, query: str,
context_docs: list[InferenceChunk], context_docs: list[InferenceChunk],
) -> Any: ) -> Generator[dict[str, Any] | None, None, None]:
raise NotImplementedError raise NotImplementedError

View File

@@ -4,6 +4,7 @@ import re
from collections.abc import Callable from collections.abc import Callable
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
from typing import cast
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
@@ -241,12 +242,12 @@ class OpenAICompletionQA(QAModel):
stream=True, stream=True,
) )
model_output = "" model_output: str = ""
found_answer_start = False found_answer_start = False
found_answer_end = False found_answer_end = False
# iterate through the stream of events # iterate through the stream of events
for event in response: for event in response:
event_text = event["choices"][0]["text"] event_text = cast(str, event["choices"][0]["text"])
model_previous = model_output model_previous = model_output
model_output += event_text model_output += event_text
@@ -259,6 +260,7 @@ class OpenAICompletionQA(QAModel):
if found_answer_start and not found_answer_end: if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text): if stream_answer_end(model_previous, event_text):
found_answer_end = True found_answer_end = True
yield {"answer_finished": True}
continue continue
yield {"answer_data": event_text} yield {"answer_data": event_text}
@@ -343,11 +345,11 @@ class OpenAIChatCompletionQA(QAModel):
stream=True, stream=True,
) )
model_output = "" model_output: str = ""
found_answer_start = False found_answer_start = False
found_answer_end = False found_answer_end = False
for event in response: for event in response:
event_dict = event["choices"][0]["delta"] event_dict = cast(str, event["choices"][0]["delta"])
if ( if (
"content" not in event_dict "content" not in event_dict
): # could be a role message or empty termination ): # could be a role message or empty termination
@@ -365,6 +367,7 @@ class OpenAIChatCompletionQA(QAModel):
if found_answer_start and not found_answer_end: if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text): if stream_answer_end(model_previous, event_text):
found_answer_end = True found_answer_end = True
yield {"answer_finished": True}
continue continue
yield {"answer_data": event_text} yield {"answer_data": event_text}

View File

@@ -16,7 +16,6 @@
# Specifically the sentence-transformers/all-distilroberta-v1 and cross-encoder/ms-marco-MiniLM-L-6-v2 models # Specifically the sentence-transformers/all-distilroberta-v1 and cross-encoder/ms-marco-MiniLM-L-6-v2 models
# The original authors can be found at https://www.sbert.net/ # The original authors can be found at https://www.sbert.net/
import json import json
from typing import List
from danswer.chunking.models import InferenceChunk from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.app_configs import NUM_RETURNED_HITS
@@ -65,8 +64,8 @@ def warm_up_models() -> None:
@log_function_time() @log_function_time()
def semantic_reranking( def semantic_reranking(
query: str, query: str,
chunks: List[InferenceChunk], chunks: list[InferenceChunk],
) -> List[InferenceChunk]: ) -> list[InferenceChunk]:
cross_encoder = get_default_reranking_model() cross_encoder = get_default_reranking_model()
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
scored_results = list(zip(sim_scores, chunks)) scored_results = list(zip(sim_scores, chunks))
@@ -84,7 +83,7 @@ def retrieve_ranked_documents(
filters: list[DatastoreFilter] | None, filters: list[DatastoreFilter] | None,
datastore: Datastore, datastore: Datastore,
num_hits: int = NUM_RETURNED_HITS, num_hits: int = NUM_RETURNED_HITS,
) -> List[InferenceChunk] | None: ) -> list[InferenceChunk] | None:
top_chunks = datastore.semantic_retrieval(query, filters, num_hits) top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
if not top_chunks: if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "") filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")

View File

@@ -1,15 +1,18 @@
import time import time
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
from typing import TypeVar
from danswer.utils.logging import setup_logger from danswer.utils.logging import setup_logger
logger = setup_logger() logger = setup_logger()
F = TypeVar("F", bound=Callable)
def log_function_time( def log_function_time(
func_name: str | None = None, func_name: str | None = None,
) -> Callable[[Callable], Callable]: ) -> Callable[[F], F]:
"""Build a timing wrapper for a function. Logs how long the function took to run. """Build a timing wrapper for a function. Logs how long the function took to run.
Use like: Use like:

BIN
web/public/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

View File

@@ -3,6 +3,7 @@
import { User } from "@/lib/types"; import { User } from "@/lib/types";
import { logout } from "@/lib/user"; import { logout } from "@/lib/user";
import { UserCircle } from "@phosphor-icons/react"; import { UserCircle } from "@phosphor-icons/react";
import Image from "next/image";
import Link from "next/link"; import Link from "next/link";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import React, { useEffect, useRef, useState } from "react"; import React, { useEffect, useRef, useState } from "react";
@@ -52,7 +53,12 @@ export const Header: React.FC<HeaderProps> = ({ user }) => {
<header className="bg-gray-800 text-gray-200 py-4"> <header className="bg-gray-800 text-gray-200 py-4">
<div className="mx-8 flex"> <div className="mx-8 flex">
<Link href="/"> <Link href="/">
<h1 className="text-2xl font-bold">danswer 💃</h1> <div className="flex">
<div className="h-[32px] w-[30px]">
<Image src="/logo.png" alt="Logo" width="1419" height="1520" />
</div>
<h1 className="flex text-2xl font-bold my-auto">Danswer</h1>
</div>
</Link> </Link>
<div <div

View File

@@ -3,9 +3,13 @@ import "./loading.css";
interface LoadingAnimationProps { interface LoadingAnimationProps {
text?: string; text?: string;
size?: "text-sm" | "text-md";
} }
export const LoadingAnimation: React.FC<LoadingAnimationProps> = ({ text }) => { export const LoadingAnimation: React.FC<LoadingAnimationProps> = ({
text,
size,
}) => {
const [dots, setDots] = useState("..."); const [dots, setDots] = useState("...");
useEffect(() => { useEffect(() => {
@@ -29,7 +33,7 @@ export const LoadingAnimation: React.FC<LoadingAnimationProps> = ({ text }) => {
return ( return (
<div className="loading-animation flex"> <div className="loading-animation flex">
<div className="mx-auto"> <div className={"mx-auto flex" + size ? ` ${size}` : ""}>
{text === undefined ? "Thinking" : text} {text === undefined ? "Thinking" : text}
<span className="dots">{dots}</span> <span className="dots">{dots}</span>
</div> </div>

View File

@@ -1,7 +1,22 @@
import React from "react"; import React from "react";
import { Quote, Document } from "./types"; import { Quote, Document } from "./types";
import { LoadingAnimation } from "../Loading";
import { getSourceIcon } from "../source"; import { getSourceIcon } from "../source";
import { LoadingAnimation } from "../Loading";
const removeDuplicateDocs = (documents: Document[]) => {
const seen = new Set<string>();
const output: Document[] = [];
documents.forEach((document) => {
if (
document.semantic_identifier &&
!seen.has(document.semantic_identifier)
) {
output.push(document);
seen.add(document.semantic_identifier);
}
});
return output;
};
interface SearchResultsDisplayProps { interface SearchResultsDisplayProps {
answer: string | null; answer: string | null;
@@ -18,7 +33,13 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
}) => { }) => {
if (!answer) { if (!answer) {
if (isFetching) { if (isFetching) {
return <LoadingAnimation />; return (
<div className="flex">
<div className="mx-auto">
<LoadingAnimation />
</div>
</div>
);
} }
return null; return null;
} }
@@ -41,28 +62,34 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
return ( return (
<> <>
<div className="p-4 border-2 rounded-md border-gray-700"> <div className="p-4 border-2 rounded-md border-gray-700">
<h2 className="text font-bold mb-2">AI Answer</h2> <div className="flex mb-1">
<h2 className="text font-bold my-auto">AI Answer</h2>
</div>
<p className="mb-4">{answer}</p> <p className="mb-4">{answer}</p>
{dedupedQuotes.length > 0 && ( {quotes !== null && (
<> <>
<h2 className="text-sm font-bold mb-2">Sources</h2> <h2 className="text-sm font-bold mb-2">Sources</h2>
<div className="flex"> {isFetching && dedupedQuotes.length === 0 ? (
{dedupedQuotes.map((quoteInfo) => ( <LoadingAnimation text="Finding quotes" size="text-sm" />
<a ) : (
key={quoteInfo.document_id} <div className="flex">
className="p-2 border border-gray-800 rounded-lg text-sm flex max-w-[230px] hover:bg-gray-800" {dedupedQuotes.map((quoteInfo) => (
href={quoteInfo.link} <a
target="_blank" key={quoteInfo.document_id}
rel="noopener noreferrer" className="p-2 border border-gray-800 rounded-lg text-sm flex max-w-[230px] hover:bg-gray-800"
> href={quoteInfo.link}
{getSourceIcon(quoteInfo.source_type, "20")} target="_blank"
<p className="truncate break-all"> rel="noopener noreferrer"
{quoteInfo.semantic_identifier || quoteInfo.document_id} >
</p> {getSourceIcon(quoteInfo.source_type, "20")}
</a> <p className="truncate break-all">
))} {quoteInfo.semantic_identifier || quoteInfo.document_id}
</div> </p>
</a>
))}
</div>
)}
</> </>
)} )}
</div> </div>
@@ -72,25 +99,27 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
<div className="font-bold border-b mb-4 pb-1 border-gray-800"> <div className="font-bold border-b mb-4 pb-1 border-gray-800">
Results Results
</div> </div>
{documents.slice(0, 5).map((doc) => ( {removeDuplicateDocs(documents)
<div .slice(0, 7)
key={doc.document_id} .map((doc) => (
className="text-sm border-b border-gray-800 mb-3" <div
> key={doc.document_id}
<a className="text-sm border-b border-gray-800 mb-3"
className="rounded-lg flex font-bold"
href={doc.link}
target="_blank"
rel="noopener noreferrer"
> >
{getSourceIcon(doc.source_type, "20")} <a
<p className="truncate break-all"> className="rounded-lg flex font-bold"
{doc.semantic_identifier || doc.document_id} href={doc.link}
</p> target="_blank"
</a> rel="noopener noreferrer"
<p className="pl-1 py-3 text-gray-200">{doc.blurb}</p> >
</div> {getSourceIcon(doc.source_type, "20")}
))} <p className="truncate break-all">
{doc.semantic_identifier || doc.document_id}
</p>
</a>
<p className="pl-1 py-3 text-gray-200">{doc.blurb}</p>
</div>
))}
</div> </div>
)} )}
</> </>

View File

@@ -63,6 +63,8 @@ const searchRequestStreamed = async (
url.search = params; url.search = params;
let answer = ""; let answer = "";
let quotes: Record<string, Quote> | null = null;
let relevantDocuments: Document[] | null = null;
try { try {
const response = await fetch(url); const response = await fetch(url);
const reader = response.body?.getReader(); const reader = response.body?.getReader();
@@ -96,12 +98,26 @@ const searchRequestStreamed = async (
if (answerChunk) { if (answerChunk) {
answer += answerChunk; answer += answerChunk;
updateCurrentAnswer(answer); updateCurrentAnswer(answer);
} else if (chunk.answer_finished) {
// set quotes as non-null to signify that the answer is finished and
// we're now looking for quotes
updateQuotes({});
if (
!answer.endsWith(".") &&
!answer.endsWith("?") &&
!answer.endsWith("!")
) {
answer += ".";
updateCurrentAnswer(answer);
}
} else { } else {
const docs = chunk.top_documents as any[]; const docs = chunk.top_documents as any[];
if (docs) { if (docs) {
updateDocs(docs.map((doc) => JSON.parse(doc) as Document)); relevantDocuments = docs.map((doc) => JSON.parse(doc) as Document);
updateDocs(relevantDocuments);
} else { } else {
updateQuotes(chunk as Record<string, Quote>); quotes = chunk as Record<string, Quote>;
updateQuotes(quotes);
} }
} }
}); });
@@ -109,7 +125,7 @@ const searchRequestStreamed = async (
} catch (err) { } catch (err) {
console.error("Fetch error:", err); console.error("Fetch error:", err);
} }
return answer; return { answer, quotes, relevantDocuments };
}; };
export const SearchSection: React.FC<{}> = () => { export const SearchSection: React.FC<{}> = () => {
@@ -123,11 +139,11 @@ export const SearchSection: React.FC<{}> = () => {
<SearchBar <SearchBar
onSearch={(query) => { onSearch={(query) => {
setIsFetching(true); setIsFetching(true);
setAnswer(""); setAnswer(null);
setQuotes(null); setQuotes(null);
setDocuments(null); setDocuments(null);
searchRequestStreamed(query, setAnswer, setQuotes, setDocuments).then( searchRequestStreamed(query, setAnswer, setQuotes, setDocuments).then(
() => { ({ quotes }) => {
setIsFetching(false); setIsFetching(false);
// if no quotes were given, set to empty object so that the SearchResultsDisplay // if no quotes were given, set to empty object so that the SearchResultsDisplay
// component knows that the search was successful but no quotes were found // component knows that the search was successful but no quotes were found