Merge branch 'main' into add-user-on-slack-bot-invoke

Hyeong Joon Suh 2024-09-09 11:52:17 -07:00
commit 261c4b7021
60 changed files with 1373 additions and 519 deletions

View File

@ -75,8 +75,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('wordnet', quiet=True); \
nltk.download('punkt', quiet=True);"
# nltk.download('wordnet', quiet=True); add this back if lemmatization is needed
# Set up application files
WORKDIR /app

View File

@ -0,0 +1,26 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("search_settings", "rerank_api_url")

View File

@ -88,3 +88,6 @@ HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)

View File

@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]:
return jira_base, jira_project
def extract_text_from_content(content: dict) -> str:
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if "content" in content:
for block in content["content"]:
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
@ -72,18 +77,15 @@ def _get_comment_strs(
comment_strs = []
for comment in jira.fields.comment.comments:
try:
if hasattr(comment, "body"):
body_text = extract_text_from_content(comment.raw["body"])
elif hasattr(comment, "raw"):
body = comment.raw.get("body", "No body content available")
body_text = (
extract_text_from_content(body) if isinstance(body, dict) else body
)
else:
body_text = "No body attribute found"
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
@ -126,11 +128,14 @@ def fetch_jira_issues_batch(
)
continue
description = (
jira.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
semantic_rep = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
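For reference, a standalone sketch of the two-level walk that extract_text_from_adf performs, using a hypothetical ADF payload. Note, per the docstring's warning, that more deeply nested structures such as list items are skipped:

# Hypothetical ADF payload; only the two-level paragraph/text shape is handled.
adf = {
    "content": [
        {
            "type": "paragraph",
            "content": [
                {"type": "text", "text": "Deploy is blocked on the rerank change."}
            ],
        },
        # A bulletList nests its text one level deeper, so the walk skips it.
        {"type": "bulletList", "content": [{"type": "listItem", "content": []}]},
    ]
}

texts = []
if adf is not None and "content" in adf:
    for block in adf["content"]:
        if "content" in block:
            for item in block["content"]:
                if item.get("type") == "text":
                    texts.append(item["text"])

print(" ".join(texts))  # -> Deploy is blocked on the rerank change.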

View File

@ -327,7 +327,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
export_mime_type = "text/plain"
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value
GDriveMimeType.MARKDOWN.value,
]:
export_mime_type = mime_type

View File

@ -237,6 +237,14 @@ class NotionConnector(LoadConnector, PollConnector):
)
continue
if result_type == "external_object_instance_page":
logger.warning(
f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': "
f"Notion API does not currently support reading external blocks (as of 24/07/03) "
f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)"
)
continue
cur_result_text_arr = []
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:

View File

@ -25,7 +25,6 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -137,7 +136,7 @@ class SharepointConnector(LoadConnector, PollConnector):
.execute_query()
]
else:
sites = self.graph_client.sites.get().execute_query()
sites = self.graph_client.sites.get_all().execute_query()
self.site_data = [
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]

View File

@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@ -1,6 +1,8 @@
import io
import ipaddress
import socket
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
@ -203,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]:
return urls
def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None:
try:
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
tzinfo=timezone.utc
)
except (ValueError, TypeError):
return None
class WebConnector(LoadConnector):
def __init__(
self,
@ -288,6 +299,7 @@ class WebConnector(LoadConnector):
page_text, metadata = read_pdf_file(
file=io.BytesIO(response.content)
)
last_modified = response.headers.get("Last-Modified")
doc_batch.append(
Document(
@ -296,12 +308,22 @@ class WebConnector(LoadConnector):
source=DocumentSource.WEB,
semantic_identifier=current_url.split("/")[-1],
metadata=metadata,
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
)
if last_modified
else None,
)
)
continue
page = context.new_page()
page_response = page.goto(current_url)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
else None
)
final_page = page.url
if final_page != current_url:
logger.info(f"Redirected to {final_page}")
@ -337,6 +359,11 @@ class WebConnector(LoadConnector):
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or current_url,
metadata={},
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
)
if last_modified
else None,
)
)
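A quick standalone check of the Last-Modified parsing added above; the format string matches RFC 7231 HTTP-dates, and anything else falls back to None:

from datetime import datetime, timezone

def parse_last_modified(last_modified: str) -> datetime | None:
    # Same logic as _get_datetime_from_last_modified_header above.
    try:
        return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
            tzinfo=timezone.utc
        )
    except (ValueError, TypeError):
        return None

print(parse_last_modified("Wed, 21 Oct 2015 07:28:00 GMT"))  # 2015-10-21 07:28:00+00:00
print(parse_last_modified("not-a-date"))  # None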

View File

@ -61,7 +61,7 @@ from shared_configs.enums import RerankerProvider
class Base(DeclarativeBase):
pass
__abstract__ = True
class EncryptedString(TypeDecorator):
@ -450,7 +450,7 @@ class Document(Base):
)
tags = relationship(
"Tag",
secondary="document__tag",
secondary=Document__Tag.__table__,
back_populates="documents",
)
@ -467,7 +467,7 @@ class Tag(Base):
documents = relationship(
"Document",
secondary="document__tag",
secondary=Document__Tag.__table__,
back_populates="tags",
)
@ -578,6 +578,8 @@ class SearchSettings(Base):
Enum(RerankerProvider, native_enum=False), nullable=True
)
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
@ -816,7 +818,7 @@ class SearchDoc(Base):
chat_messages = relationship(
"ChatMessage",
secondary="chat_message__search_doc",
secondary=ChatMessage__SearchDoc.__table__,
back_populates="search_docs",
)
@ -959,7 +961,7 @@ class ChatMessage(Base):
)
search_docs: Mapped[list["SearchDoc"]] = relationship(
"SearchDoc",
secondary="chat_message__search_doc",
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.

View File

@ -240,9 +240,43 @@ def read_pdf_file(
def docx_to_text(file: IO[Any]) -> str:
def is_simple_table(table: docx.table.Table) -> bool:
for row in table.rows:
# No omitted cells
if row.grid_cols_before > 0 or row.grid_cols_after > 0:
return False
# No nested tables
if any(cell.tables for cell in row.cells):
return False
return True
def extract_cell_text(cell: docx.table._Cell) -> str:
cell_paragraphs = [para.text.strip() for para in cell.paragraphs]
return " ".join(p for p in cell_paragraphs if p) or "N/A"
paragraphs = []
doc = docx.Document(file)
full_text = [para.text for para in doc.paragraphs]
return TEXT_SECTION_SEPARATOR.join(full_text)
for item in doc.iter_inner_content():
if isinstance(item, docx.text.paragraph.Paragraph):
paragraphs.append(item.text)
elif isinstance(item, docx.table.Table):
if not item.rows or not is_simple_table(item):
continue
# Each row becomes one line; rows are joined with a single newline
table_content = "\n".join(
[
",\t".join(extract_cell_text(cell) for cell in row.cells)
for row in item.rows
]
)
paragraphs.append(table_content)
# Docx already has good spacing between paragraphs
return "\n".join(paragraphs)
def pptx_to_text(file: IO[Any]) -> str:
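To make the flattened-table output concrete, a tiny standalone sketch applying the same join logic to hypothetical, already-extracted cell text (python-docx is not needed for the illustration):

# Hypothetical 2x2 table as rows of extracted cell text; empty cells become "N/A".
rows = [["Name", "Role"], ["Alice", ""]]

table_content = "\n".join(
    ",\t".join(cell.strip() or "N/A" for cell in row) for row in rows
)
print(repr(table_content))  # 'Name,\tRole\nAlice,\tN/A'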

View File

@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if search_settings.rerank_model_name and not search_settings.provider_type:
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")

View File

@ -24,6 +24,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
@ -240,6 +242,7 @@ class RerankingModel:
model_name: str,
provider_type: RerankerProvider | None,
api_key: str | None,
api_url: str | None,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
@ -248,6 +251,7 @@ class RerankingModel:
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
def predict(self, query: str, passages: list[str]) -> list[float]:
rerank_request = RerankRequest(
@ -256,6 +260,7 @@ class RerankingModel:
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)
response = requests.post(
@ -301,6 +306,37 @@ class QueryAnalysisModel:
return response_model.is_keyword, response_model.keywords
class ConnectorClassificationModel:
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
):
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.connector_classification_endpoint = (
model_server_url + "/custom/connector-classification"
)
def predict(
self,
query: str,
available_connectors: list[str],
) -> list[str]:
connector_classification_request = ConnectorClassificationRequest(
available_connectors=available_connectors,
query=query,
)
response = requests.post(
self.connector_classification_endpoint,
json=connector_classification_request.dict(),
)
response.raise_for_status()
response_model = ConnectorClassificationResponse(**response.json())
return response_model.connectors
def warm_up_retry(
func: Callable[..., Any],
tries: int = 20,
@ -367,6 +403,7 @@ def warm_up_cross_encoder(
reranking_model = RerankingModel(
model_name=rerank_model_name,
provider_type=None,
api_url=None,
api_key=None,
)

View File

@ -26,6 +26,7 @@ MAX_METRICS_CONTENT = (
class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
rerank_api_url: str | None
rerank_provider_type: RerankerProvider | None
rerank_api_key: str | None = None
@ -42,6 +43,7 @@ class RerankingDetails(BaseModel):
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
rerank_api_url=search_settings.rerank_api_url,
)
@ -81,7 +83,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
num_rerank=search_settings.num_rerank,
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
api_url=search_settings.api_url,
rerank_api_url=search_settings.rerank_api_url,
)

View File

@ -100,6 +100,7 @@ def semantic_reranking(
model_name=rerank_settings.rerank_model_name,
provider_type=rerank_settings.rerank_provider_type,
api_key=rerank_settings.rerank_api_key,
api_url=rerank_settings.rerank_api_url,
)
passages = [

View File

@ -3,7 +3,6 @@ from collections.abc import Callable
import nltk # type:ignore
from nltk.corpus import stopwords # type:ignore
from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
@ -40,7 +39,7 @@ logger = setup_logger()
def download_nltk_data() -> None:
resources = {
"stopwords": "corpora/stopwords",
"wordnet": "corpora/wordnet",
# "wordnet": "corpora/wordnet", # Not in use
"punkt": "tokenizers/punkt",
}
@ -58,15 +57,16 @@ def download_nltk_data() -> None:
def lemmatize_text(keywords: list[str]) -> list[str]:
try:
query = " ".join(keywords)
lemmatizer = WordNetLemmatizer()
word_tokens = word_tokenize(query)
lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens]
combined_keywords = list(set(keywords + lemmatized_words))
return combined_keywords
except Exception:
return keywords
raise NotImplementedError("Lemmatization should not be used currently")
# try:
# query = " ".join(keywords)
# lemmatizer = WordNetLemmatizer()
# word_tokens = word_tokenize(query)
# lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens]
# combined_keywords = list(set(keywords + lemmatized_words))
# return combined_keywords
# except Exception:
# return keywords
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:

View File

@ -3,12 +3,16 @@ import random
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_unique_document_sources
from danswer.db.engine import get_sqlalchemy_engine
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.natural_language_processing.search_nlp_models import (
ConnectorClassificationModel,
)
from danswer.prompts.constants import SOURCES_KEY
from danswer.prompts.filter_extration import FILE_SOURCE_WARNING
from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT
@ -42,11 +46,38 @@ def _sample_document_sources(
return random.sample(valid_sources, num_sample)
def _sample_documents_using_custom_connector_classifier(
query: str,
valid_sources: list[DocumentSource],
) -> list[DocumentSource] | None:
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
available_connectors = list(
filter(
lambda conn: conn.lower() in query_joined,
[item.value for item in valid_sources],
)
)
if not available_connectors:
return None
connectors = ConnectorClassificationModel().predict(query, available_connectors)
return strings_to_document_sources(connectors) if connectors else None
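A standalone illustration of the substring pre-filter above, with hypothetical inputs (the real code compares DocumentSource enum values against the alphanumeric-only query):

query = "What did the team decide in Slack?"
valid_source_values = ["slack", "github", "web"]  # hypothetical source set

# Lowercase the query and strip everything that isn't alphanumeric.
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
# -> "whatdidtheteamdecideinslack"

available_connectors = [s for s in valid_source_values if s.lower() in query_joined]
print(available_connectors)  # -> ['slack']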
def extract_source_filter(
query: str, llm: LLM, db_session: Session
) -> list[DocumentSource] | None:
"""Returns a list of valid sources for search or None if no specific sources were detected"""
valid_sources = fetch_unique_document_sources(db_session)
if not valid_sources:
return None
if ENABLE_CONNECTOR_CLASSIFIER:
return _sample_documents_using_custom_connector_classifier(query, valid_sources)
def _get_source_filter_messages(
query: str,
valid_sources: list[DocumentSource],
@ -146,10 +177,6 @@ def extract_source_filter(
logger.warning("LLM failed to provide a valid Source Filter output")
return None
valid_sources = fetch_unique_document_sources(db_session)
if not valid_sources:
return None
messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))

View File

@ -4,6 +4,7 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
@ -346,8 +347,18 @@ class GoogleServiceAccountKey(BaseModel):
class GoogleServiceAccountCredentialRequest(BaseModel):
google_drive_delegated_user: str | None # email of user to impersonate
gmail_delegated_user: str | None # email of user to impersonate
google_drive_delegated_user: str | None = None # email of user to impersonate
gmail_delegated_user: str | None = None # email of user to impersonate
@model_validator(mode="after")
def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
if (self.google_drive_delegated_user is None) == (
self.gmail_delegated_user is None
):
raise ValueError(
"Exactly one of google_drive_delegated_user or gmail_delegated_user must be set"
)
return self
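Usage sketch for the validator above (hypothetical email); a ValueError raised in a mode="after" validator surfaces as a pydantic ValidationError:

import pydantic

# Exactly one delegated user set: validates.
GoogleServiceAccountCredentialRequest(google_drive_delegated_user="admin@example.com")

try:
    GoogleServiceAccountCredentialRequest()  # neither field set
except pydantic.ValidationError as e:
    print(e)  # Exactly one of google_drive_delegated_user or gmail_delegated_user must be set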
class FileUploadResponse(BaseModel):

View File

@ -3,15 +3,21 @@ import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from danswer.utils.logger import setup_logger
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.danswer_torch_model import ConnectorClassifier
from model_server.danswer_torch_model import HybridClassifier
from model_server.utils import simple_log_function_time
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@ -19,10 +25,55 @@ logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
"distilbert-base-uncased"
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
def get_local_connector_classifier(
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
) -> ConnectorClassifier:
global _CONNECTOR_CLASSIFIER_MODEL
if _CONNECTOR_CLASSIFIER_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.warning(f"Failed to load model directly: {e}")
try:
# Attempt to download the model snapshot
logger.info(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.error(
f"Failed to load model even after attempted snapshot download: {e}"
)
raise
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> AutoTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
@ -61,6 +112,74 @@ def get_local_intent_model(
return _INTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: PreTrainedTokenizer,
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenize the connector names & user query into one prompt for the forward pass of ConnectorClassifier models.
The attention mask is all 1s. The prompt is [CLS] + each connector name suffixed with the connector-end
token, then the user query, then [SEP].
"""
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
for connector in connectors:
connector_token_ids = tokenizer(
connector,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
connector_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([connector_token_end_id], dtype=torch.long),
),
dim=-1,
)
query_token_ids = tokenizer(
query,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
query_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
),
dim=-1,
)
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
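Schematically, for connectors ["github", "slack"] and the query "what did we decide on slack" (hypothetical inputs; real ids come from the DistilBERT tokenizer, and [CONNECTOR_END] stands for connector_token_end_id), the assembled prompt looks like:

prompt_layout = (
    ["[CLS]"]
    + ["github", "[CONNECTOR_END]"]
    + ["slack", "[CONNECTOR_END]"]
    + ["what", "did", "we", "decide", "on", "slack"]
    + ["[SEP]"]
)
attention_mask = [1] * len(prompt_layout)  # every position is attended to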
def warm_up_connector_classifier_model() -> None:
logger.info(
f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
)
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
connector_classifier = get_local_connector_classifier()
input_ids, attention_mask = tokenize_connector_classification_query(
["GitHub"],
"danswer classifier query google doc",
connector_classifier_tokenizer,
connector_classifier.connector_end_token_id,
)
input_ids = input_ids.to(connector_classifier.device)
attention_mask = attention_mask.to(connector_classifier.device)
connector_classifier(input_ids, attention_mask)
def warm_up_intent_model() -> None:
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
intent_tokenizer = get_intent_model_tokenizer()
@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]:
return cleaned_words
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
tokenizer = get_connector_classifier_tokenizer()
model = get_local_connector_classifier()
connector_names = req.available_connectors
input_ids, attention_mask = tokenize_connector_classification_query(
connector_names,
req.query,
tokenizer,
model.connector_end_token_id,
)
input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)
global_confidence, classifier_confidence = model(input_ids, attention_mask)
if global_confidence.item() < 0.5:
return []
passed_connectors = []
for i, connector_name in enumerate(connector_names):
if classifier_confidence.view(-1)[i].item() > 0.5:
passed_connectors.append(connector_name)
return passed_connectors
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
tokenizer = get_intent_model_tokenizer()
model_input = tokenizer(
@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
return is_keyword_sequence, cleaned_keywords
@router.post("/connector-classification")
async def process_connector_classification_request(
classification_request: ConnectorClassificationRequest,
) -> ConnectorClassificationResponse:
if INDEXING_ONLY:
raise RuntimeError(
"Indexing model server should not call connector classification endpoint"
)
if len(classification_request.available_connectors) == 0:
return ConnectorClassificationResponse(connectors=[])
connectors = run_connector_classification(classification_request)
return ConnectorClassificationResponse(connectors=connectors)
@router.post("/query-analysis")
async def process_analysis_request(
intent_request: IntentRequest,

View File

@ -4,7 +4,8 @@ import os
import torch
import torch.nn as nn
from transformers import DistilBertConfig # type: ignore
from transformers import DistilBertModel
from transformers import DistilBertModel # type: ignore
from transformers import DistilBertTokenizer # type: ignore
class HybridClassifier(nn.Module):
@ -21,7 +22,6 @@ class HybridClassifier(nn.Module):
self.distilbert.config.dim, self.distilbert.config.dim
)
self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
self.device = torch.device("cpu")
@ -36,8 +36,7 @@ class HybridClassifier(nn.Module):
# Intent classification on the CLS token
cls_token_state = sequence_output[:, 0, :]
pre_classifier_out = self.pre_classifier(cls_token_state)
dropout_out = self.dropout(pre_classifier_out)
intent_logits = self.intent_classifier(dropout_out)
intent_logits = self.intent_classifier(pre_classifier_out)
# Keyword classification on all tokens
token_logits = self.keyword_classifier(sequence_output)
@ -72,3 +71,70 @@ class HybridClassifier(nn.Module):
param.requires_grad = False
return model
class ConnectorClassifier(nn.Module):
def __init__(self, config: DistilBertConfig) -> None:
super().__init__()
self.config = config
self.distilbert = DistilBertModel(config)
self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Token marking the end of each connector name; the match classifier runs on its hidden state
self.connector_end_token_id = self.tokenizer.get_vocab()[
self.config.connector_end_token
]
self.device = torch.device("cpu")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
cls_hidden_states = hidden_states[
:, 0, :
] # Take a leap of faith that the first token is always [CLS]
global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
global_confidence = torch.sigmoid(global_logits).view(-1)
connector_end_position_ids = input_ids == self.connector_end_token_id
connector_end_hidden_states = hidden_states[connector_end_position_ids]
classifier_output = self.connector_match_classifier(connector_end_hidden_states)
classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
return global_confidence, classifier_confidence
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
state_dict = torch.load(
os.path.join(repo_dir, "pytorch_model.pt"),
map_location=device,
weights_only=True,
)
model = cls(config)
model.load_state_dict(state_dict)
model.to(device)
model.device = device
model.eval()
for param in model.parameters():
param.requires_grad = False
return model

View File

@ -362,6 +362,28 @@ def cohere_rerank(
return [result.relevance_score for result in sorted_results]
def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
with httpx.Client() as client:
response = client.post(
api_url,
json={
"model": model_name,
"query": query,
"documents": docs,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [
item["relevance_score"]
for item in sorted(result["results"], key=lambda x: x["index"])
]
@router.post("/bi-encoder-embed")
async def process_embed_request(
embed_request: EmbedRequest,
@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.LITELLM:
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
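For reference, a sketch of the request/response shape litellm_rerank assumes (a Cohere-style rerank API as exposed by a LiteLLM proxy; all values hypothetical):

request_body = {
    "model": "rerank-english-v3.0",  # hypothetical model name
    "query": "reranking with a proxy",
    "documents": ["doc a", "doc b"],
}
# Hypothetical response: results arrive sorted by relevance, each tagged with the
# original document index.
response_body = {
    "results": [
        {"index": 1, "relevance_score": 0.91},
        {"index": 0, "relevance_score": 0.12},
    ]
}
# litellm_rerank re-orders scores back into document order by sorting on "index".
scores = [
    item["relevance_score"]
    for item in sorted(response_body["results"], key=lambda x: x["index"])
]
print(scores)  # -> [0.12, 0.91]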

View File

@ -1,4 +1,4 @@
aiohttp==3.9.4
aiohttp==3.10.2
alembic==1.10.4
asyncpg==0.27.0
atlassian-python-api==3.37.0
@ -26,13 +26,12 @@ huggingface-hub==0.20.1
jira==3.5.1
jsonref==1.1.0
langchain==0.1.17
langchain-community==0.0.36
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.43.18
llama-index==0.9.45
Mako==1.2.4
msal==1.26.0
msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
@ -50,7 +49,7 @@ python-pptx==0.6.23
pypdf==3.17.0
pytest-mock==3.12.0
pytest-playwright==0.3.2
python-docx==1.1.0
python-docx==1.1.2
python-dotenv==1.0.0
python-multipart==0.0.7
pywikibot==9.0.0

View File

@ -8,7 +8,7 @@ pydantic==2.8.2
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
torch==2.0.1
torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3

View File

@ -16,9 +16,12 @@ INDEXING_MODEL_SERVER_PORT = int(
)
# Danswer custom Deep Learning Models
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512

View File

@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum):
class RerankerProvider(str, Enum):
COHERE = "cohere"
LITELLM = "litellm"
class EmbedTextType(str, Enum):

View File

@ -7,6 +7,15 @@ from shared_configs.enums import RerankerProvider
Embedding = list[float]
class ConnectorClassificationRequest(BaseModel):
available_connectors: list[str]
query: str
class ConnectorClassificationResponse(BaseModel):
connectors: list[str]
class EmbedRequest(BaseModel):
texts: list[str]
# Can be None for cloud embedding model requests; error handling logic exists for other cases
@ -34,6 +43,7 @@ class RerankRequest(BaseModel):
model_name: str
provider_type: RerankerProvider | None = None
api_key: str | None = None
api_url: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@ -1,4 +1,3 @@
version: '3'
services:
api_server:
image: danswer/danswer-backend:${IMAGE_TAG:-latest}

View File

@ -1,4 +1,3 @@
version: '3'
services:
api_server:
image: danswer/danswer-backend:${IMAGE_TAG:-latest}

View File

@ -1,4 +1,3 @@
version: '3'
services:
api_server:
image: danswer/danswer-backend:${IMAGE_TAG:-latest}

View File

@ -1,4 +1,3 @@
version: '3'
services:
api_server:
image: danswer/danswer-backend:${IMAGE_TAG:-latest}

View File

@ -1,4 +1,3 @@
version: '3'
services:
api_server:
image: danswer/danswer-backend:${IMAGE_TAG:-latest}

View File

@ -1,33 +1,8 @@
# This env template shows how to configure Danswer for multilingual use
# In this case, it is configured for French and English
# To use it, copy it to .env in the docker_compose directory.
# Feel free to combine it with the other templates to suit your needs
# This env template shows how to configure Danswer for custom multilingual use
# Note that for most use cases it is enough to configure Danswer's multilingual support purely through the UI
# See "Search Settings" -> "Advanced" for UI options.
# To use it, copy it to .env in the docker_compose directory (or the equivalent environment settings file for your deployment)
# Rephrase the user query in specified languages using LLM, use comma separated values
MULTILINGUAL_QUERY_EXPANSION="English, French"
# Change the below to suit your specific needs, can be more explicit about the language of the response
LANGUAGE_HINT="IMPORTANT: Respond in the same language as my query!"
# The following is included with the user prompt. Here's one example but feel free to customize it to your needs:
LANGUAGE_HINT="IMPORTANT: ALWAYS RESPOND IN FRENCH! Even if the documents and the user query are in English, your response must be in French."
LANGUAGE_CHAT_NAMING_HINT="The name of the conversation must be in the same language as the user query."
# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small
DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small"
# The model above is trained with the following prefix for queries and passages to improve retrieval
# by letting the model know which of the two type is currently being embedded
ASYM_QUERY_PREFIX="query: "
ASYM_PASSAGE_PREFIX="passage: "
# Depends model by model, the one shown above is tuned with this as True
NORMALIZE_EMBEDDINGS="True"
# Use LLM to determine if chunks are relevant to the query
# May not work well for languages that do not have much training data in the LLM training set
# If using a common language like Spanish, French, Chinese, etc. this can be kept turned on
DISABLE_LLM_DOC_RELEVANCE="True"
# Enables fine-grained embeddings for better retrieval
# At the cost of indexing speed (~5x slower), query time is same speed
# Since reranking is turned off and multilingual retrieval is generally harder
# it is advised to turn this on
ENABLE_MULTIPASS_INDEXING="True"

View File

@ -8,47 +8,6 @@ const version = env_version || package_version;
const nextConfig = {
output: "standalone",
swcMinify: true,
rewrites: async () => {
// In production, something else (nginx in the one box setup) should take
// care of this rewrite. TODO (chris): better support setups where
// web_server and api_server are on different machines.
if (process.env.NODE_ENV === "production") return [];
return [
{
source: "/api/:path*",
destination: "http://127.0.0.1:8080/:path*", // Proxy to Backend
},
];
},
redirects: async () => {
// In production, something else (nginx in the one box setup) should take
// care of this redirect. TODO (chris): better support setups where
// web_server and api_server are on different machines.
const defaultRedirects = [];
if (process.env.NODE_ENV === "production") return defaultRedirects;
return defaultRedirects.concat([
{
source: "/api/chat/send-message:params*",
destination: "http://127.0.0.1:8080/chat/send-message:params*", // Proxy to Backend
permanent: true,
},
{
source: "/api/query/stream-answer-with-quote:params*",
destination:
"http://127.0.0.1:8080/query/stream-answer-with-quote:params*", // Proxy to Backend
permanent: true,
},
{
source: "/api/query/stream-query-validation:params*",
destination:
"http://127.0.0.1:8080/query/stream-query-validation:params*", // Proxy to Backend
permanent: true,
},
]);
},
publicRuntimeConfig: {
version,
},

web/package-lock.json (generated)
View File

@ -2555,11 +2555,11 @@
}
},
"node_modules/braces": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
"integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
"dependencies": {
"fill-range": "^7.0.1"
"fill-range": "^7.1.1"
},
"engines": {
"node": ">=8"
@ -4061,9 +4061,9 @@
}
},
"node_modules/fill-range": {
"version": "7.0.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
"version": "7.1.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
"integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
"dependencies": {
"to-regex-range": "^5.0.1"
},

View File

@ -31,11 +31,13 @@ export function LLMProviderUpdateForm({
existingLlmProvider,
shouldMarkAsDefault,
setPopup,
hideAdvanced,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
const { mutate } = useSWRConfig();
@ -52,7 +54,7 @@ export function LLMProviderUpdateForm({
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "",
@ -218,17 +220,20 @@ export function LLMProviderUpdateForm({
}}
>
{({ values, setFieldValue }) => (
<Form className="gap-y-6 items-stretch mt-8">
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
disabled={existingLlmProvider ? true : false}
/>
<Form className="gap-y-4 items-stretch mt-6">
{!hideAdvanced && (
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
disabled={existingLlmProvider ? true : false}
/>
)}
{llmProviderDescriptor.api_key_required && (
<TextFormField
small={hideAdvanced}
name="api_key"
label="API Key"
placeholder="API Key"
@ -238,6 +243,7 @@ export function LLMProviderUpdateForm({
{llmProviderDescriptor.api_base_required && (
<TextFormField
small={hideAdvanced}
name="api_base"
label="API Base"
placeholder="API Base"
@ -246,6 +252,7 @@ export function LLMProviderUpdateForm({
{llmProviderDescriptor.api_version_required && (
<TextFormField
small={hideAdvanced}
name="api_version"
label="API Version"
placeholder="API Version"
@ -255,6 +262,7 @@ export function LLMProviderUpdateForm({
{llmProviderDescriptor.custom_config_keys?.map((customConfigKey) => (
<div key={customConfigKey.name}>
<TextFormField
small={hideAdvanced}
name={`custom_config.${customConfigKey.name}`}
label={
customConfigKey.is_required
@ -266,134 +274,144 @@ export function LLMProviderUpdateForm({
</div>
))}
<Divider />
{llmProviderDescriptor.llm_names.length > 0 ? (
<SelectorFormField
name="default_model_name"
subtext="The model to use by default for this provider unless otherwise specified."
label="Default Model"
options={llmProviderDescriptor.llm_names.map((name) => ({
name: getDisplayNameForModel(name),
value: name,
}))}
maxHeight="max-h-56"
/>
) : (
<TextFormField
name="default_model_name"
subtext="The model to use by default for this provider unless otherwise specified."
label="Default Model"
placeholder="E.g. gpt-4"
/>
)}
{llmProviderDescriptor.llm_names.length > 0 ? (
<SelectorFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If \`Default\` is specified, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
options={llmProviderDescriptor.llm_names.map((name) => ({
name: getDisplayNameForModel(name),
value: name,
}))}
includeDefault
maxHeight="max-h-56"
/>
) : (
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If \`Default\` is specified, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
)}
<Divider />
{llmProviderDescriptor.name != "azure" && (
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
)}
{showAdvancedOptions && (
{!hideAdvanced && (
<>
{llmProviderDescriptor.llm_names.length > 0 && (
<div className="w-full">
<MultiSelectField
selectedInitially={values.display_model_names}
name="display_model_names"
label="Display Models"
subtext="Select the models to make available to users. Unselected models will not be available."
options={llmProviderDescriptor.llm_names.map((name) => ({
value: name,
label: getDisplayNameForModel(name),
}))}
onChange={(selected) =>
setFieldValue("display_model_names", selected)
}
/>
</div>
<Divider />
{llmProviderDescriptor.llm_names.length > 0 ? (
<SelectorFormField
name="default_model_name"
subtext="The model to use by default for this provider unless otherwise specified."
label="Default Model"
options={llmProviderDescriptor.llm_names.map((name) => ({
name: getDisplayNameForModel(name),
value: name,
}))}
maxHeight="max-h-56"
/>
) : (
<TextFormField
name="default_model_name"
subtext="The model to use by default for this provider unless otherwise specified."
label="Default Model"
placeholder="E.g. gpt-4"
/>
)}
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
removeIndent
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{llmProviderDescriptor.llm_names.length > 0 ? (
<SelectorFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If \`Default\` is specified, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
options={llmProviderDescriptor.llm_names.map((name) => ({
name: getDisplayNameForModel(name),
value: name,
}))}
includeDefault
maxHeight="max-h-56"
/>
) : (
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If \`Default\` is specified, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
)}
{userGroups && userGroups.length > 0 && !values.is_public && (
<div>
<Text>
Select which User Groups should have access to this LLM
Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
<Divider />
{llmProviderDescriptor.name != "azure" && (
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
)}
{showAdvancedOptions && (
<>
{llmProviderDescriptor.llm_names.length > 0 && (
<div className="w-full">
<MultiSelectField
selectedInitially={values.display_model_names}
name="display_model_names"
label="Display Models"
subtext="Select the models to make available to users. Unselected models will not be available."
options={llmProviderDescriptor.llm_names.map(
(name) => ({
value: name,
label: getDisplayNameForModel(name),
})
)}
onChange={(selected) =>
setFieldValue("display_model_names", selected)
}
/>
</div>
)}
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
removeIndent
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups &&
userGroups.length > 0 &&
!values.is_public && (
<div>
<Text>
Select which User Groups should have access to
this LLM Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">
{userGroup.name}
</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</>
)}
</>
@ -432,6 +450,27 @@ export function LLMProviderUpdateForm({
return;
}
// If the deleted provider was the default, set the first remaining provider as default
const remainingProvidersResponse = await fetch(
LLM_PROVIDERS_ADMIN_URL
);
if (remainingProvidersResponse.ok) {
const remainingProviders =
await remainingProvidersResponse.json();
if (remainingProviders.length > 0) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
{
method: "POST",
}
);
if (!setDefaultResponse.ok) {
console.error("Failed to set new default provider");
}
}
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}

View File

@ -17,6 +17,8 @@ import {
GoogleDriveServiceAccountCredentialJson,
} from "@/lib/connectors/credentials";
import { Button as TremorButton } from "@tremor/react";
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
export const DriveJsonUpload = ({
@ -344,7 +346,7 @@ export const DriveOAuthSection = ({
if (serviceAccountKeyData?.service_account_email) {
return (
<div>
<p className="text-sm mb-2">
<p className="text-sm mb-6">
When using a Google Drive Service Account, you can either have Danswer
act as the service account itself OR you can specify an account for
the service account to impersonate.
@ -356,70 +358,59 @@ export const DriveOAuthSection = ({
the documents you want to index with the service account.
</p>
<Card>
<Formik
initialValues={{
google_drive_delegated_user: "",
}}
validationSchema={Yup.object().shape({
google_drive_delegated_user: Yup.string().optional(),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_drive_delegated_user:
values.google_drive_delegated_user,
}),
}
);
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
<Formik
initialValues={{
google_drive_delegated_user: "",
}}
validationSchema={Yup.object().shape({
google_drive_delegated_user: Yup.string().optional(),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_drive_delegated_user:
values.google_drive_delegated_user,
}),
}
refreshCredentials();
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_drive_delegated_user"
label="[Optional] User email to impersonate:"
subtext="If left blank, Danswer will use the service account itself."
/>
<div className="flex">
<button
type="submit"
disabled={isSubmitting}
className={
"bg-slate-500 hover:bg-slate-700 text-white " +
"font-bold py-2 px-4 rounded focus:outline-none " +
"focus:shadow-outline w-full max-w-sm mx-auto"
}
>
Submit
</button>
</div>
</Form>
)}
</Formik>
</Card>
);
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
refreshCredentials();
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_drive_delegated_user"
label="[Optional] User email to impersonate:"
subtext="If left blank, Danswer will use the service account itself."
/>
<div className="flex">
<TremorButton type="submit" disabled={isSubmitting}>
Create Credential
</TremorButton>
</div>
</Form>
)}
</Formik>
</div>
);
}

View File

@ -8,8 +8,6 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import { LoadingAnimation } from "@/components/Loading";
import { usePopup } from "@/components/admin/connectors/Popup";
import { ConnectorIndexingStatus } from "@/lib/types";
import { getCurrentUser } from "@/lib/user";
import { User, UserRole } from "@/lib/types";
import { usePublicCredentials } from "@/lib/hooks";
import { Title } from "@tremor/react";
import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential";
@ -109,6 +107,7 @@ const GDriveMain = ({}: {}) => {
| undefined = credentialsData.find(
(credential) => credential.credential_json?.google_drive_service_account_key
);
const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<
GoogleDriveConfig,
GoogleDriveCredentialJson

View File

@ -7,7 +7,11 @@ import {
rerankingModels,
} from "./interfaces";
import { FiExternalLink } from "react-icons/fi";
import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons";
import {
CohereIcon,
LiteLLMIcon,
MixedBreadIcon,
} from "@/components/icons/icons";
import { Modal } from "@/components/Modal";
import { Button } from "@tremor/react";
import { TextFormField } from "@/components/admin/connectors/Field";
@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef<
ref
) => {
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
useState(false);
return (
<Formik
@ -48,13 +54,17 @@ const RerankingDetailsForm = forwardRef<
.optional(),
api_key: Yup.string().nullable(),
num_rerank: Yup.number().min(1, "Must be at least 1"),
rerank_api_url: Yup.string()
.url("Must be a valid URL")
.matches(/^https?:\/\//, "URL must start with http:// or https://")
.nullable(),
})}
onSubmit={async (_, { setSubmitting }) => {
setSubmitting(false);
}}
enableReinitialize={true}
>
{({ values, setFieldValue }) => {
{({ values, setFieldValue, resetForm }) => {
const resetRerankingValues = () => {
setRerankingDetails({
...values,
@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef<
)
: rerankingModels.filter(
(modelCard) =>
modelCard.modelName ==
originalRerankingDetails.rerank_model_name
(modelCard.modelName ==
originalRerankingDetails.rerank_model_name &&
modelCard.rerank_provider_type ==
originalRerankingDetails.rerank_provider_type) ||
(modelCard.rerank_provider_type ==
RerankerProvider.LITELLM &&
originalRerankingDetails.rerank_provider_type ==
RerankerProvider.LITELLM)
)
).map((card) => {
const isSelected =
values.rerank_provider_type ===
card.rerank_provider_type &&
values.rerank_model_name === card.modelName;
(card.modelName == null ||
values.rerank_model_name === card.modelName);
return (
<div
key={`${card.rerank_provider_type}-${card.modelName}`}
@ -148,26 +166,39 @@ const RerankingDetailsForm = forwardRef<
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
}`}
onClick={() => {
if (card.rerank_provider_type) {
if (
card.rerank_provider_type == RerankerProvider.COHERE
) {
setIsApiKeyModalOpen(true);
} else if (
card.rerank_provider_type ==
RerankerProvider.LITELLM
) {
setShowLiteLLMConfigurationModal(true);
}
if (!isSelected) {
setRerankingDetails({
...values,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName || null,
rerank_api_key: null,
rerank_api_url: null,
});
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}
setRerankingDetails({
...values,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName,
rerank_api_key: null,
});
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}}
>
<div className="flex items-center justify-between mb-3">
<div className="flex items-center">
{card.rerank_provider_type ===
RerankerProvider.COHERE ? (
RerankerProvider.LITELLM ? (
<LiteLLMIcon size={24} className="mr-2" />
) : RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />
@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef<
})}
</div>
{showLiteLLMConfigurationModal && (
<Modal
onOutsideClick={() => {
resetForm();
setShowLiteLLMConfigurationModal(false);
}}
width="w-[800px]"
title="API Key Configuration"
>
<div className="w-full flex flex-col gap-y-4 px-4">
<TextFormField
subtext="Set the URL at which your LiteLLM Proxy is hosted"
placeholder={values.rerank_api_url || undefined}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_api_url: value,
});
setFieldValue("rerank_api_url", value);
}}
type="text"
label="LiteLLM Proxy URL"
name="rerank_api_url"
/>
<TextFormField
subtext="Set the key to access your LiteLLM Proxy"
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_api_key: value,
});
setFieldValue("rerank_api_key", value);
}}
type="password"
label="LiteLLM Proxy Key"
name="rerank_api_key"
optional
/>
<TextFormField
subtext="Set the model name to use for LiteLLM Proxy"
placeholder={
values.rerank_model_name
? "*".repeat(values.rerank_model_name.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_model_name: value,
});
setFieldValue("rerank_model_name", value);
}}
label="LiteLLM Model Name"
name="rerank_model_name"
optional
/>
<div className="flex w-full justify-end mt-4">
<Button
onClick={() => {
setShowLiteLLMConfigurationModal(false);
}}
color="blue"
size="xs"
>
Update
</Button>
</div>
</div>
</Modal>
)}
{isApiKeyModalOpen && (
<Modal
onOutsideClick={() => {
@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef<
>
<div className="w-full px-4">
<TextFormField
placeholder={values.rerank_api_key || undefined}
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({

View File

@ -5,11 +5,13 @@ export interface RerankingDetails {
rerank_model_name: string | null;
rerank_provider_type: RerankerProvider | null;
rerank_api_key: string | null;
rerank_api_url: string | null;
num_rerank: number;
}
export enum RerankerProvider {
COHERE = "cohere",
LITELLM = "litellm",
}
export interface AdvancedSearchConfiguration {
model_name: string;
@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails {
export interface RerankingModel {
rerank_provider_type: RerankerProvider | null;
modelName: string;
modelName?: string;
displayName: string;
description: string;
link: string;
@ -48,6 +50,13 @@ export interface RerankingModel {
}
export const rerankingModels: RerankingModel[] = [
{
rerank_provider_type: RerankerProvider.LITELLM,
cloud: true,
displayName: "LiteLLM",
description: "Host your own reranker or router with LiteLLM proxy",
link: "https://docs.litellm.ai/docs/proxy",
},
{
rerank_provider_type: null,
cloud: false,

View File

@ -4,7 +4,7 @@ import * as Yup from "yup";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
import { AdvancedSearchConfiguration } from "../interfaces";
import { BooleanFormField } from "@/components/admin/connectors/Field";
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";

View File

@ -10,7 +10,7 @@ import {
CloudEmbeddingModel,
EmbeddingProvider,
HostedEmbeddingModel,
} from "../../../../components/embedding/interfaces";
} from "@/components/embedding/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import useSWR, { mutate } from "swr";
@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading";
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
import {
AdvancedSearchConfiguration,
RerankerProvider,
RerankingDetails,
SavedSearchSettings,
} from "../interfaces";
@ -49,6 +48,7 @@ export default function EmbeddingForm() {
num_rerank: 0,
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
});
const updateAdvancedEmbeddingDetails = (
@ -124,6 +124,7 @@ export default function EmbeddingForm() {
num_rerank: searchSettings.num_rerank,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
});
}
}, [searchSettings]);
@ -134,12 +135,14 @@ export default function EmbeddingForm() {
num_rerank: searchSettings.num_rerank,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
}
: {
rerank_api_key: "",
num_rerank: 0,
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
};
useEffect(() => {

View File

@ -0,0 +1,116 @@
import { INTERNAL_URL } from "@/lib/constants";
import { NextRequest, NextResponse } from "next/server";
/* NextJS is annoying and makes us use a separate function for
each request type >:( */
export async function GET(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
export async function POST(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
export async function PUT(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
export async function PATCH(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
export async function DELETE(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
export async function HEAD(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
export async function OPTIONS(
request: NextRequest,
{ params }: { params: { path: string[] } }
) {
return handleRequest(request, params.path);
}
async function handleRequest(request: NextRequest, path: string[]) {
if (process.env.NODE_ENV !== "development") {
return NextResponse.json(
{
message:
"This API is only available in development mode. In production, something else (e.g. nginx) should handle this.",
},
{ status: 404 }
);
}
try {
const backendUrl = new URL(`${INTERNAL_URL}/${path.join("/")}`);
// Get the URL parameters from the request
const urlParams = new URLSearchParams(request.url.split("?")[1]);
// Append the URL parameters to the backend URL
urlParams.forEach((value, key) => {
backendUrl.searchParams.append(key, value);
});
const response = await fetch(backendUrl, {
method: request.method,
headers: request.headers,
body: request.body,
      // "duplex" is required by Node 18+'s fetch when forwarding a streaming
      // request body; it's missing from the TS lib types, hence the ignore.
      // @ts-ignore
      duplex: "half",
});
// Check if the response is a stream
if (
response.headers.get("Transfer-Encoding") === "chunked" ||
response.headers.get("Content-Type")?.includes("stream")
) {
// If it's a stream, create a TransformStream to pass the data through
const { readable, writable } = new TransformStream();
response.body?.pipeTo(writable);
return new NextResponse(readable, {
status: response.status,
headers: response.headers,
});
} else {
return new NextResponse(response.body, {
status: response.status,
headers: response.headers,
});
}
} catch (error: unknown) {
console.error("Proxy error:", error);
return NextResponse.json(
{
message: "Proxy error",
error:
error instanceof Error ? error.message : "An unknown error occurred",
},
{ status: 500 }
);
}
}
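
A minimal sketch (hypothetical snippet; it assumes this new route file lives at app/api/[...path]/route.ts, a path this diff does not show) of how a request flows through the dev-only proxy:

// In development, a client-side call like this:
const res = await fetch("/api/chat/get-user-chat-sessions");
// is caught by the catch-all route above and forwarded to
// `${INTERNAL_URL}/chat/get-user-chat-sessions`, preserving the
// method, headers, query params, and (possibly streaming) body.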

View File

@ -98,6 +98,7 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { SEARCH_TOOL_NAME } from "./tools/constants";
import { useUser } from "@/components/user/UserProvider";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@ -106,12 +107,10 @@ const SYSTEM_MESSAGE_ID = -3;
export function ChatPage({
toggle,
documentSidebarInitialWidth,
defaultSelectedAssistantId,
toggledSidebar,
}: {
toggle: (toggled?: boolean) => void;
documentSidebarInitialWidth?: number;
defaultSelectedAssistantId?: number;
toggledSidebar: boolean;
}) {
const router = useRouter();
@ -126,8 +125,13 @@ export function ChatPage({
folders,
openedFolders,
userInputPrompts,
defaultAssistantId,
shouldShowWelcomeModal,
refreshChatSessions,
} = useChatContext();
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
const { user, refreshUser, isLoadingUser } = useUser();
// chat session
@ -162,9 +166,9 @@ export function ChatPage({
? availableAssistants.find(
(assistant) => assistant.id === existingChatSessionAssistantId
)
: defaultSelectedAssistantId !== undefined
: defaultAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === defaultSelectedAssistantId
(assistant) => assistant.id === defaultAssistantId
)
: undefined
);
@ -327,8 +331,8 @@ export function ChatPage({
async function initialSessionFetch() {
if (existingChatSessionId === null) {
setIsFetchingChatMessages(false);
if (defaultSelectedAssistantId !== undefined) {
setSelectedAssistantFromId(defaultSelectedAssistantId);
if (defaultAssistantId !== undefined) {
setSelectedAssistantFromId(defaultAssistantId);
} else {
setSelectedAssistant(undefined);
}
@ -402,7 +406,7 @@ export function ChatPage({
// force re-name if the chat session doesn't have one
if (!chatSession.description) {
await nameChatSession(existingChatSessionId, seededMessage);
router.refresh(); // need to refresh to update name on sidebar
refreshChatSessions();
}
}
}
@ -676,12 +680,10 @@ export function ChatPage({
useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedAssistant(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedAssistantId
)
filteredAssistants.find((persona) => persona.id === defaultAssistantId)
);
}
}, [defaultSelectedAssistantId]);
}, [defaultAssistantId]);
const [
selectedDocuments,
@ -1111,6 +1113,12 @@ export function ChatPage({
console.error(
"First packet should contain message response info "
);
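                          // If the first packet reports an error, surface it and unlock the input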
if (Object.hasOwn(packet, "error")) {
const error = (packet as StreamingError).error;
setLoadingError(error);
updateChatState("input");
return;
}
continue;
}
@ -1330,6 +1338,7 @@ export function ChatPage({
if (!searchParamBasedChatSessionName) {
await new Promise((resolve) => setTimeout(resolve, 200));
await nameChatSession(currChatSessionId, currMessage);
refreshChatSessions();
}
// NOTE: don't switch pages if the user has navigated away from the chat
@ -1465,6 +1474,7 @@ export function ChatPage({
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
const [untoggled, setUntoggled] = useState(false);
const [loadingError, setLoadingError] = useState<string | null>(null);
const explicitlyUntoggle = () => {
setShowDocSidebar(false);
@ -1588,6 +1598,11 @@ export function ChatPage({
return (
<>
<HealthCheckBanner />
{showApiKeyModal && !shouldShowWelcomeModal && (
<ApiKeyModal hide={() => setShowApiKeyModal(false)} />
)}
      {/* ChatPopup is a custom popup that displays an admin-specified message on initial user visit.
Only used in the EE version of the app. */}
{popup}
@ -1760,7 +1775,6 @@ export function ChatPage({
className={`h-full w-full relative flex-auto transition-margin duration-300 overflow-x-auto mobile:pb-12 desktop:pb-[100px]`}
{...getRootProps()}
>
{/* <input {...getInputProps()} /> */}
<div
className={`w-full h-full flex flex-col overflow-y-auto include-scrollbar overflow-x-hidden relative`}
ref={scrollableDivRef}
@ -1770,7 +1784,8 @@ export function ChatPage({
{messageHistory.length === 0 &&
!isFetchingChatMessages &&
currentSessionChatState == "input" && (
currentSessionChatState == "input" &&
!loadingError && (
<ChatIntro
availableSources={finalAvailableSources}
selectedPersona={liveAssistant}
@ -2078,16 +2093,17 @@ export function ChatPage({
}
})}
{currentSessionChatState == "loading" &&
!currentSessionRegenerationState?.regenerating &&
messageHistory[messageHistory.length - 1]?.type !=
"user" && (
<HumanMessage
key={-2}
messageId={-1}
content={submittedMessage}
/>
)}
{currentSessionChatState == "loading" ||
(loadingError &&
!currentSessionRegenerationState?.regenerating &&
messageHistory[messageHistory.length - 1]
?.type != "user" && (
<HumanMessage
key={-2}
messageId={-1}
content={submittedMessage}
/>
))}
{currentSessionChatState == "loading" && (
<div
@ -2116,6 +2132,20 @@ export function ChatPage({
</div>
)}
{loadingError && (
<div key={-1}>
<AIMessage
currentPersona={liveAssistant}
messageId={-1}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{loadingError}
</p>
}
/>
</div>
)}
{currentPersona &&
currentPersona.starter_messages &&
currentPersona.starter_messages.length > 0 &&
@ -2177,6 +2207,9 @@ export function ChatPage({
</div>
)}
<ChatInputBar
showConfigureAPIKey={() =>
setShowApiKeyModal(true)
}
chatState={currentSessionChatState}
stopGenerating={stopGenerating}
openModelSettings={() => setSettingsToggled(true)}

View File

@ -3,21 +3,15 @@ import { ChatPage } from "./ChatPage";
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";
export default function WrappedChat({
defaultAssistantId,
initiallyToggled,
}: {
defaultAssistantId?: number;
initiallyToggled: boolean;
}) {
return (
<FunctionalWrapper
initiallyToggled={initiallyToggled}
content={(toggledSidebar, toggle) => (
<ChatPage
toggle={toggle}
defaultSelectedAssistantId={defaultAssistantId}
toggledSidebar={toggledSidebar}
/>
<ChatPage toggle={toggle} toggledSidebar={toggledSidebar} />
)}
/>
);

View File

@ -33,12 +33,15 @@ import { Tooltip } from "@/components/tooltip/Tooltip";
import { Hoverable } from "@/components/Hoverable";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { ChatState } from "../types";
import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText";
import { useSearchContext } from "@/components/context/SearchContext";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
openModelSettings,
showDocs,
showConfigureAPIKey,
selectedDocuments,
message,
setMessage,
@ -62,6 +65,7 @@ export function ChatInputBar({
chatSessionId,
inputPrompts,
}: {
showConfigureAPIKey: () => void;
openModelSettings: () => void;
chatState: ChatState;
stopGenerating: () => void;
@ -111,6 +115,7 @@ export function ChatInputBar({
}
}
};
const settings = useContext(SettingsContext);
const { llmProviders } = useChatContext();
@ -364,6 +369,9 @@ export function ChatInputBar({
<div>
<SelectedFilterDisplay filterManager={filterManager} />
</div>
<UnconfiguredProviderText showConfigureAPIKey={showConfigureAPIKey} />
<div
className="
opacity-100

View File

@ -2,10 +2,10 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { WelcomeModal } from "@/components/initialSetup/welcome/WelcomeModalWrapper";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { ChatProvider } from "@/components/context/ChatContext";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import WrappedChat from "./WrappedChat";
import { ProviderContextProvider } from "@/components/chat_search/ProviderContext";
export default async function Page({
searchParams,
@ -23,7 +23,6 @@ export default async function Page({
const {
user,
chatSessions,
ccPairs,
availableSources,
documentSets,
assistants,
@ -33,9 +32,7 @@ export default async function Page({
toggleSidebar,
openedFolders,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
userInputPrompts,
} = data;
@ -43,9 +40,7 @@ export default async function Page({
<>
<InstantSSRAutoRefresh />
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
{!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && (
<ApiKeyModal user={user} />
)}
<ChatProvider
value={{
chatSessions,
@ -57,12 +52,13 @@ export default async function Page({
folders,
openedFolders,
userInputPrompts,
shouldShowWelcomeModal,
defaultAssistantId,
}}
>
<WrappedChat
defaultAssistantId={defaultAssistantId}
initiallyToggled={toggleSidebar}
/>
<ProviderContextProvider>
<WrappedChat initiallyToggled={toggleSidebar} />
</ProviderContextProvider>
</ChatProvider>
</>
);

View File

@ -6,6 +6,7 @@ import {
} from "@/components/settings/lib";
import {
CUSTOM_ANALYTICS_ENABLED,
EE_ENABLED,
SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
} from "@/lib/constants";
import { SettingsProvider } from "@/components/settings/SettingsProvider";
@ -53,6 +54,7 @@ export default async function RootLayout({
children: React.ReactNode;
}) {
const combinedSettings = await fetchSettingsSS();
if (!combinedSettings) {
// Just display a simple full page error if fetching fails.
@ -72,8 +74,34 @@ export default async function RootLayout({
<h1 className="text-2xl font-bold mb-4 text-error">Error</h1>
<p className="text-text-500">
Your Danswer instance was not configured properly and your
settings could not be loaded. Please contact your admin to fix
this error.
settings could not be loaded. This could be due to an admin
configuration issue or an incomplete setup.
</p>
<p className="mt-4">
If you&apos;re an admin, please check{" "}
<a
className="text-link"
href="https://docs.danswer.dev/introduction?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
target="_blank"
rel="noopener noreferrer"
>
our docs
</a>{" "}
to see how to configure Danswer properly. If you&apos;re a user,
please contact your admin to fix this error.
</p>
<p className="mt-4">
For additional support and guidance, you can reach out to our
community on{" "}
<a
className="text-link"
href="https://danswer.ai?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
target="_blank"
rel="noopener noreferrer"
>
Slack
</a>
.
</p>
</Card>
</div>

View File

@ -1,31 +1,12 @@
"use client";
import { SearchSection } from "@/components/search/SearchSection";
import FunctionalWrapper from "../chat/shared_chat_search/FunctionalWrapper";
import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types";
import { Persona } from "../admin/assistants/interfaces";
import { ChatSession } from "../chat/interfaces";
export default function WrappedSearch({
querySessions,
ccPairs,
documentSets,
personas,
searchTypeDefault,
tags,
user,
agenticSearchEnabled,
initiallyToggled,
disabledAgentic,
}: {
disabledAgentic: boolean;
querySessions: ChatSession[];
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
personas: Persona[];
searchTypeDefault: string;
tags: Tag[];
user: User | null;
agenticSearchEnabled: boolean;
initiallyToggled: boolean;
}) {
return (
@ -33,16 +14,8 @@ export default function WrappedSearch({
initiallyToggled={initiallyToggled}
content={(toggledSidebar, toggle) => (
<SearchSection
disabledAgentic={disabledAgentic}
agenticSearchEnabled={agenticSearchEnabled}
toggle={toggle}
toggledSidebar={toggledSidebar}
querySessions={querySessions}
user={user}
ccPairs={ccPairs}
documentSets={documentSets}
personas={personas}
tags={tags}
defaultSearchType={searchTypeDefault}
/>
)}

View File

@ -5,7 +5,6 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { fetchSS } from "@/lib/utilsSS";
import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types";
import { cookies } from "next/headers";
@ -34,6 +33,8 @@ import {
DISABLE_LLM_DOC_RELEVANCE,
} from "@/lib/constants";
import WrappedSearch from "./WrappedSearch";
import { SearchProvider } from "@/components/context/SearchContext";
import { ProviderContextProvider } from "@/components/chat_search/ProviderContext";
export default async function Home() {
// Disable caching so we always get the up to date connector / document set / persona info
@ -185,10 +186,6 @@ export default async function Home() {
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
<InstantSSRAutoRefresh />
{!shouldShowWelcomeModal &&
!shouldDisplayNoSourcesModal &&
!shouldDisplaySourcesIncompleteModal && <ApiKeyModal user={user} />}
{shouldDisplayNoSourcesModal && <NoSourcesModal />}
{shouldDisplaySourcesIncompleteModal && (
@ -199,18 +196,27 @@ export default async function Home() {
Only used in the EE version of the app. */}
<ChatPopup />
<WrappedSearch
disabledAgentic={DISABLE_LLM_DOC_RELEVANCE}
initiallyToggled={toggleSidebar}
querySessions={querySessions}
user={user}
ccPairs={ccPairs}
documentSets={documentSets}
personas={assistants}
tags={tags}
searchTypeDefault={searchTypeDefault}
agenticSearchEnabled={agenticSearchEnabled}
/>
<SearchProvider
value={{
querySessions,
ccPairs,
documentSets,
assistants,
tags,
agenticSearchEnabled,
disabledAgentic: DISABLE_LLM_DOC_RELEVANCE,
initiallyToggled: toggleSidebar,
shouldShowWelcomeModal,
shouldDisplayNoSources: shouldDisplayNoSourcesModal,
}}
>
<ProviderContextProvider>
<WrappedSearch
initiallyToggled={toggleSidebar}
searchTypeDefault={searchTypeDefault}
/>
</ProviderContextProvider>
</SearchProvider>
</>
);
}

View File

@ -198,7 +198,7 @@ export function TextFormField({
rounded-lg
w-full
py-2
px-3
px-3
mt-1
placeholder:font-description
placeholder:text-base

View File

@ -0,0 +1,70 @@
"use client";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import React, { createContext, useContext, useState, useEffect } from "react";
import { useUser } from "../user/UserProvider";
import { useRouter } from "next/navigation";
import { checkLlmProvider } from "../initialSetup/welcome/lib";
interface ProviderContextType {
shouldShowConfigurationNeeded: boolean;
providerOptions: WellKnownLLMProviderDescriptor[];
  refreshProviderInfo: () => Promise<void>;
}
const ProviderContext = createContext<ProviderContextType | undefined>(
undefined
);
export function ProviderContextProvider({
children,
}: {
children: React.ReactNode;
}) {
const { user } = useUser();
const router = useRouter();
const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
const [providerOptions, setProviderOptions] = useState<
WellKnownLLMProviderDescriptor[]
>([]);
const fetchProviderInfo = async () => {
const { providers, options, defaultCheckSuccessful } =
await checkLlmProvider(user);
setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
setProviderOptions(options);
};
useEffect(() => {
fetchProviderInfo();
}, [router, user]);
const shouldShowConfigurationNeeded =
!validProviderExists && providerOptions.length > 0;
const refreshProviderInfo = async () => {
await fetchProviderInfo();
};
return (
<ProviderContext.Provider
value={{
shouldShowConfigurationNeeded,
providerOptions,
        refreshProviderInfo,
}}
>
{children}
</ProviderContext.Provider>
);
}
export function useProviderStatus() {
const context = useContext(ProviderContext);
if (context === undefined) {
throw new Error(
"useProviderStatus must be used within a ProviderContextProvider"
);
}
return context;
}
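
A minimal consumer sketch (hypothetical component; the reworked ApiKeyModal below is the real equivalent):

function SetupNudge() {
  const { shouldShowConfigurationNeeded } = useProviderStatus();
  if (!shouldShowConfigurationNeeded) return null;
  return <p>Configure an LLM provider to get started.</p>;
}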

View File

@ -0,0 +1,27 @@
import { useProviderStatus } from "./ProviderContext";
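// NOTE: the default export below is imported elsewhere in this diff under its
// file-based name, UnconfiguredProviderText (see ChatInputBar and SearchSection).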
export default function CredentialNotConfigured({
showConfigureAPIKey,
}: {
showConfigureAPIKey: () => void;
}) {
const { shouldShowConfigurationNeeded } = useProviderStatus();
if (!shouldShowConfigurationNeeded) {
return null;
}
return (
<p className="text-base text-center w-full text-subtle">
Please note that you have not yet configured an LLM provider. You can
configure one{" "}
<button
onClick={showConfigureAPIKey}
className="text-link hover:underline cursor-pointer"
>
here
</button>
.
</p>
);
}

View File

@ -1,4 +1,4 @@
import { Dispatch, SetStateAction, useEffect, useRef, useState } from "react";
import { Dispatch, SetStateAction, useEffect, useRef } from "react";
interface UseSidebarVisibilityProps {
toggledSidebar: boolean;

View File

@ -1,6 +1,6 @@
"use client";
import React, { createContext, useContext } from "react";
import React, { createContext, useContext, useState } from "react";
import { DocumentSet, Tag, User, ValidSources } from "@/lib/types";
import { ChatSession } from "@/app/chat/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
@ -18,15 +18,40 @@ interface ChatContextProps {
folders: Folder[];
openedFolders: Record<string, boolean>;
userInputPrompts: InputPrompt[];
shouldShowWelcomeModal?: boolean;
shouldDisplaySourcesIncompleteModal?: boolean;
defaultAssistantId?: number;
refreshChatSessions: () => Promise<void>;
}
const ChatContext = createContext<ChatContextProps | undefined>(undefined);
// We use Omit to exclude 'refreshChatSessions' from the value prop type
// because we're defining it within the component
export const ChatProvider: React.FC<{
value: ChatContextProps;
value: Omit<ChatContextProps, "refreshChatSessions">;
children: React.ReactNode;
}> = ({ value, children }) => {
return <ChatContext.Provider value={value}>{children}</ChatContext.Provider>;
const [chatSessions, setChatSessions] = useState(value?.chatSessions || []);
const refreshChatSessions = async () => {
try {
const response = await fetch("/api/chat/get-user-chat-sessions");
if (!response.ok) throw new Error("Failed to fetch chat sessions");
const { sessions } = await response.json();
setChatSessions(sessions);
} catch (error) {
console.error("Error refreshing chat sessions:", error);
}
};
return (
<ChatContext.Provider
value={{ ...value, chatSessions, refreshChatSessions }}
>
{children}
</ChatContext.Provider>
);
};
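
A minimal consumer sketch (hypothetical, mirroring the ChatPage changes above): after naming a session, refresh the cached list so the sidebar updates without a full router.refresh():

const { refreshChatSessions } = useChatContext();
await nameChatSession(currChatSessionId, currMessage);
await refreshChatSessions();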
export const useChatContext = (): ChatContextProps => {

View File

@ -0,0 +1,38 @@
"use client";
import React, { createContext, useContext } from "react";
import { CCPairBasicInfo, DocumentSet, Tag } from "@/lib/types";
import { Persona } from "@/app/admin/assistants/interfaces";
import { ChatSession } from "@/app/chat/interfaces";
interface SearchContextProps {
querySessions: ChatSession[];
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
assistants: Persona[];
tags: Tag[];
agenticSearchEnabled: boolean;
disabledAgentic: boolean;
initiallyToggled: boolean;
shouldShowWelcomeModal: boolean;
shouldDisplayNoSources: boolean;
}
const SearchContext = createContext<SearchContextProps | undefined>(undefined);
export const SearchProvider: React.FC<{
value: SearchContextProps;
children: React.ReactNode;
}> = ({ value, children }) => {
return (
<SearchContext.Provider value={value}>{children}</SearchContext.Provider>
);
};
export const useSearchContext = (): SearchContextProps => {
const context = useContext(SearchContext);
if (!context) {
throw new Error("useSearchContext must be used within a SearchProvider");
}
return context;
};
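
A minimal consumer sketch (hypothetical component; SearchSection below reads the same context):

function SearchGate({ children }: { children: React.ReactNode }) {
  const { shouldDisplayNoSources } = useSearchContext();
  return shouldDisplayNoSources ? null : <>{children}</>;
}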

View File

@ -27,13 +27,11 @@ function UsageTypeSection({
title,
description,
callToAction,
icon,
onClick,
}: {
title: string;
description: string | JSX.Element;
callToAction: string;
icon?: React.ElementType;
onClick: () => void;
}) {
return (
@ -243,7 +241,6 @@ export function _WelcomeModal({ user }: { user: User | null }) {
this is the option for you!
</Text>
}
icon={FiMessageSquare}
callToAction="Get Started"
onClick={() => {
setSelectedFlow("chat");

View File

@ -55,6 +55,7 @@ export const ApiKeyForm = ({
return (
<TabPanel key={provider.name}>
<LLMProviderUpdateForm
hideAdvanced
llmProviderDescriptor={provider}
onClose={() => onSuccess()}
shouldMarkAsDefault

View File

@ -1,60 +1,38 @@
"use client";
import { useState, useEffect } from "react";
import { ApiKeyForm } from "./ApiKeyForm";
import { Modal } from "../Modal";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { checkLlmProvider } from "../initialSetup/welcome/lib";
import { User } from "@/lib/types";
import { useRouter } from "next/navigation";
import { useProviderStatus } from "../chat_search/ProviderContext";
export const ApiKeyModal = ({ user }: { user: User | null }) => {
export const ApiKeyModal = ({ hide }: { hide: () => void }) => {
const router = useRouter();
const [forceHidden, setForceHidden] = useState<boolean>(false);
const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
const [providerOptions, setProviderOptions] = useState<
WellKnownLLMProviderDescriptor[]
>([]);
const {
shouldShowConfigurationNeeded,
providerOptions,
refreshProviderInfo,
} = useProviderStatus();
useEffect(() => {
async function fetchProviderInfo() {
const { providers, options, defaultCheckSuccessful } =
await checkLlmProvider(user);
setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
setProviderOptions(options);
}
fetchProviderInfo();
}, []);
// don't show if
// (1) a valid provider has been setup or
// (2) there are no provider options (e.g. user isn't an admin)
// (3) user explicitly hides the modal
if (validProviderExists || !providerOptions.length || forceHidden) {
if (!shouldShowConfigurationNeeded) {
return null;
}
return (
<Modal
title="LLM Key Setup"
className="max-w-4xl"
onOutsideClick={() => setForceHidden(true)}
title="Set an API Key!"
className="max-w-3xl"
onOutsideClick={() => hide()}
>
<div className="max-h-[75vh] overflow-y-auto flex flex-col px-4">
<div>
<div className="mb-5 text-sm">
Please setup an LLM below in order to start using Danswer Search or
Danswer Chat. Don&apos;t worry, you can always change this later in
the Admin Panel.
            Please provide an API Key below in order to start using
            Danswer; you can always change this later.
<br />
<br />
Or if you&apos;d rather look around first,{" "}
<strong
onClick={() => setForceHidden(true)}
className="text-link cursor-pointer"
>
If you&apos;d rather look around first, you can
<strong onClick={() => hide()} className="text-link cursor-pointer">
{" "}
skip this step
</strong>
.
@ -63,7 +41,8 @@ export const ApiKeyModal = ({ user }: { user: User | null }) => {
<ApiKeyForm
onSuccess={() => {
router.refresh();
setForceHidden(true);
refreshProviderInfo();
hide();
}}
providerOptions={providerOptions}
/>

View File

@ -37,6 +37,10 @@ import { FeedbackModal } from "@/app/chat/modal/FeedbackModal";
import { deleteChatSession, handleChatFeedback } from "@/app/chat/lib";
import SearchAnswer from "./SearchAnswer";
import { DeleteEntityModal } from "../modals/DeleteEntityModal";
import { ApiKeyModal } from "../llm/ApiKeyModal";
import { useSearchContext } from "../context/SearchContext";
import { useUser } from "../user/UserProvider";
import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText";
export type searchState =
| "input"
@ -58,33 +62,28 @@ const VALID_QUESTION_RESPONSE_DEFAULT: ValidQuestionResponse = {
};
interface SearchSectionProps {
disabledAgentic: boolean;
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
personas: Persona[];
tags: Tag[];
toggle: () => void;
querySessions: ChatSession[];
defaultSearchType: SearchType;
user: User | null;
toggledSidebar: boolean;
agenticSearchEnabled: boolean;
}
export const SearchSection = ({
ccPairs,
toggle,
disabledAgentic,
documentSets,
agenticSearchEnabled,
personas,
user,
tags,
querySessions,
toggledSidebar,
defaultSearchType,
}: SearchSectionProps) => {
// Search Bar
const {
querySessions,
ccPairs,
documentSets,
assistants,
tags,
shouldShowWelcomeModal,
agenticSearchEnabled,
disabledAgentic,
shouldDisplayNoSources,
} = useSearchContext();
const [query, setQuery] = useState<string>("");
const [comments, setComments] = useState<any>(null);
const [contentEnriched, setContentEnriched] = useState(false);
@ -100,6 +99,8 @@ export const SearchSection = ({
messageId: null,
});
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
const [agentic, setAgentic] = useState(agenticSearchEnabled);
const toggleAgentic = () => {
@ -147,7 +148,7 @@ export const SearchSection = ({
useState<SearchType>(defaultSearchType);
const [selectedPersona, setSelectedPersona] = useState<number>(
personas[0]?.id || 0
assistants[0]?.id || 0
);
// Used for search state display
@ -158,8 +159,8 @@ export const SearchSection = ({
const availableSources = ccPairs.map((ccPair) => ccPair.source);
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
selectedPersona: personas.find(
(persona) => persona.id === selectedPersona
selectedPersona: assistants.find(
(assistant) => assistant.id === selectedPersona
),
availableSources: availableSources,
availableDocumentSets: documentSets,
@ -362,6 +363,7 @@ export const SearchSection = ({
setSearchState("input");
}
};
const { user } = useUser();
const [searchAnswerExpanded, setSearchAnswerExpanded] = useState(false);
const resetInput = (finalized?: boolean) => {
@ -403,8 +405,8 @@ export const SearchSection = ({
documentSets: filterManager.selectedDocumentSets,
timeRange: filterManager.timeRange,
tags: filterManager.selectedTags,
persona: personas.find(
(persona) => persona.id === selectedPersona
persona: assistants.find(
(assistant) => assistant.id === selectedPersona
) as Persona,
updateCurrentAnswer: cancellable({
cancellationToken: lastSearchCancellationToken.current,
@ -595,6 +597,12 @@ export const SearchSection = ({
<div className="flex relative pr-[8px] h-full text-default">
{popup}
{!shouldDisplayNoSources &&
showApiKeyModal &&
!shouldShowWelcomeModal && (
<ApiKeyModal hide={() => setShowApiKeyModal(false)} />
)}
{deletingChatSession && (
<DeleteEntityModal
entityType="search"
@ -747,6 +755,11 @@ export const SearchSection = ({
</div>
</div>
</div>
<UnconfiguredProviderText
showConfigureAPIKey={() => setShowApiKeyModal(true)}
/>
<FullSearchBar
toggleAgentic={
disabledAgentic ? undefined : toggleAgentic

View File

@ -44,7 +44,6 @@ interface FetchChatDataResult {
toggleSidebar: boolean;
finalDocumentSidebarInitialWidth?: number;
shouldShowWelcomeModal: boolean;
shouldDisplaySourcesIncompleteModal: boolean;
userInputPrompts: InputPrompt[];
}
@ -242,7 +241,6 @@ export async function fetchChatData(searchParams: {
finalDocumentSidebarInitialWidth,
toggleSidebar,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
userInputPrompts,
};
}