mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 12:30:49 +02:00
Merge branch 'main' into add-user-on-slack-bot-invoke
This commit is contained in:
commit
261c4b7021
@ -75,8 +75,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('wordnet', quiet=True); \
|
||||
nltk.download('punkt', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
@ -0,0 +1,26 @@
|
||||
"""add support for litellm proxy in reranking
|
||||
|
||||
Revision ID: ba98eba0f66a
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-06 10:36:04.507332
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba98eba0f66a"
|
||||
down_revision = "bceb1e139447"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
@ -88,3 +88,6 @@ HARD_DELETE_CHATS = False
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def extract_text_from_content(content: dict) -> str:
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if "content" in content:
|
||||
for block in content["content"]:
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
@ -72,18 +77,15 @@ def _get_comment_strs(
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
if hasattr(comment, "body"):
|
||||
body_text = extract_text_from_content(comment.raw["body"])
|
||||
elif hasattr(comment, "raw"):
|
||||
body = comment.raw.get("body", "No body content available")
|
||||
body_text = (
|
||||
extract_text_from_content(body) if isinstance(body, dict) else body
|
||||
)
|
||||
else:
|
||||
body_text = "No body attribute found"
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
@ -126,11 +128,14 @@ def fetch_jira_issues_batch(
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = (
|
||||
f"{jira.fields.description}\n"
|
||||
if jira.fields.description
|
||||
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
@ -327,7 +327,7 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
export_mime_type = "text/plain"
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
export_mime_type = mime_type
|
||||
|
||||
|
@ -237,6 +237,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
continue
|
||||
|
||||
if result_type == "external_object_instance_page":
|
||||
logger.warning(
|
||||
f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': "
|
||||
f"Notion API does not currently support reading external blocks (as of 24/07/03) "
|
||||
f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)"
|
||||
)
|
||||
continue
|
||||
|
||||
cur_result_text_arr = []
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
|
@ -25,7 +25,6 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@ -137,7 +136,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
.execute_query()
|
||||
]
|
||||
else:
|
||||
sites = self.graph_client.sites.get().execute_query()
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
self.site_data = [
|
||||
SiteData(url=None, folder=None, sites=sites, driveitems=[])
|
||||
]
|
||||
|
@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
import io
|
||||
import ipaddress
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@ -203,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]:
|
||||
return urls
|
||||
|
||||
|
||||
def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None:
|
||||
try:
|
||||
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
class WebConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@ -288,6 +299,7 @@ class WebConnector(LoadConnector):
|
||||
page_text, metadata = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
@ -296,12 +308,22 @@ class WebConnector(LoadConnector):
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
)
|
||||
if last_modified
|
||||
else None,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
@ -337,6 +359,11 @@ class WebConnector(LoadConnector):
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
)
|
||||
if last_modified
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -61,7 +61,7 @@ from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
@ -450,7 +450,7 @@ class Document(Base):
|
||||
)
|
||||
tags = relationship(
|
||||
"Tag",
|
||||
secondary="document__tag",
|
||||
secondary=Document__Tag.__table__,
|
||||
back_populates="documents",
|
||||
)
|
||||
|
||||
@ -467,7 +467,7 @@ class Tag(Base):
|
||||
|
||||
documents = relationship(
|
||||
"Document",
|
||||
secondary="document__tag",
|
||||
secondary=Document__Tag.__table__,
|
||||
back_populates="tags",
|
||||
)
|
||||
|
||||
@ -578,6 +578,8 @@ class SearchSettings(Base):
|
||||
Enum(RerankerProvider, native_enum=False), nullable=True
|
||||
)
|
||||
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
|
||||
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
@ -816,7 +818,7 @@ class SearchDoc(Base):
|
||||
|
||||
chat_messages = relationship(
|
||||
"ChatMessage",
|
||||
secondary="chat_message__search_doc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="search_docs",
|
||||
)
|
||||
|
||||
@ -959,7 +961,7 @@ class ChatMessage(Base):
|
||||
)
|
||||
search_docs: Mapped[list["SearchDoc"]] = relationship(
|
||||
"SearchDoc",
|
||||
secondary="chat_message__search_doc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
|
@ -240,9 +240,43 @@ def read_pdf_file(
|
||||
|
||||
|
||||
def docx_to_text(file: IO[Any]) -> str:
|
||||
def is_simple_table(table: docx.table.Table) -> bool:
|
||||
for row in table.rows:
|
||||
# No omitted cells
|
||||
if row.grid_cols_before > 0 or row.grid_cols_after > 0:
|
||||
return False
|
||||
|
||||
# No nested tables
|
||||
if any(cell.tables for cell in row.cells):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extract_cell_text(cell: docx.table._Cell) -> str:
|
||||
cell_paragraphs = [para.text.strip() for para in cell.paragraphs]
|
||||
return " ".join(p for p in cell_paragraphs if p) or "N/A"
|
||||
|
||||
paragraphs = []
|
||||
doc = docx.Document(file)
|
||||
full_text = [para.text for para in doc.paragraphs]
|
||||
return TEXT_SECTION_SEPARATOR.join(full_text)
|
||||
for item in doc.iter_inner_content():
|
||||
if isinstance(item, docx.text.paragraph.Paragraph):
|
||||
paragraphs.append(item.text)
|
||||
|
||||
elif isinstance(item, docx.table.Table):
|
||||
if not item.rows or not is_simple_table(item):
|
||||
continue
|
||||
|
||||
# Every row is a new line, joined with a single newline
|
||||
table_content = "\n".join(
|
||||
[
|
||||
",\t".join(extract_cell_text(cell) for cell in row.cells)
|
||||
for row in item.rows
|
||||
]
|
||||
)
|
||||
paragraphs.append(table_content)
|
||||
|
||||
# Docx already has good spacing between paragraphs
|
||||
return "\n".join(paragraphs)
|
||||
|
||||
|
||||
def pptx_to_text(file: IO[Any]) -> str:
|
||||
|
@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
logger.notice(
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
|
||||
if search_settings.rerank_model_name and not search_settings.provider_type:
|
||||
if (
|
||||
search_settings.rerank_model_name
|
||||
and not search_settings.provider_type
|
||||
and not search_settings.rerank_provider_type
|
||||
):
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
|
@ -24,6 +24,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
@ -240,6 +242,7 @@ class RerankingModel:
|
||||
model_name: str,
|
||||
provider_type: RerankerProvider | None,
|
||||
api_key: str | None,
|
||||
api_url: str | None,
|
||||
model_server_host: str = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
@ -248,6 +251,7 @@ class RerankingModel:
|
||||
self.model_name = model_name
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[float]:
|
||||
rerank_request = RerankRequest(
|
||||
@ -256,6 +260,7 @@ class RerankingModel:
|
||||
model_name=self.model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@ -301,6 +306,37 @@ class QueryAnalysisModel:
|
||||
return response_model.is_keyword, response_model.keywords
|
||||
|
||||
|
||||
class ConnectorClassificationModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_server_host: str = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
):
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.connector_classification_endpoint = (
|
||||
model_server_url + "/custom/connector-classification"
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
query: str,
|
||||
available_connectors: list[str],
|
||||
) -> list[str]:
|
||||
connector_classification_request = ConnectorClassificationRequest(
|
||||
available_connectors=available_connectors,
|
||||
query=query,
|
||||
)
|
||||
response = requests.post(
|
||||
self.connector_classification_endpoint,
|
||||
json=connector_classification_request.dict(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_model = ConnectorClassificationResponse(**response.json())
|
||||
|
||||
return response_model.connectors
|
||||
|
||||
|
||||
def warm_up_retry(
|
||||
func: Callable[..., Any],
|
||||
tries: int = 20,
|
||||
@ -367,6 +403,7 @@ def warm_up_cross_encoder(
|
||||
reranking_model = RerankingModel(
|
||||
model_name=rerank_model_name,
|
||||
provider_type=None,
|
||||
api_url=None,
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
|
@ -26,6 +26,7 @@ MAX_METRICS_CONTENT = (
|
||||
class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
rerank_api_url: str | None
|
||||
rerank_provider_type: RerankerProvider | None
|
||||
rerank_api_key: str | None = None
|
||||
|
||||
@ -42,6 +43,7 @@ class RerankingDetails(BaseModel):
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
rerank_api_key=search_settings.rerank_api_key,
|
||||
num_rerank=search_settings.num_rerank,
|
||||
rerank_api_url=search_settings.rerank_api_url,
|
||||
)
|
||||
|
||||
|
||||
@ -81,7 +83,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
num_rerank=search_settings.num_rerank,
|
||||
# Multilingual Expansion
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
api_url=search_settings.api_url,
|
||||
rerank_api_url=search_settings.rerank_api_url,
|
||||
)
|
||||
|
||||
|
||||
|
@ -100,6 +100,7 @@ def semantic_reranking(
|
||||
model_name=rerank_settings.rerank_model_name,
|
||||
provider_type=rerank_settings.rerank_provider_type,
|
||||
api_key=rerank_settings.rerank_api_key,
|
||||
api_url=rerank_settings.rerank_api_url,
|
||||
)
|
||||
|
||||
passages = [
|
||||
|
@ -3,7 +3,6 @@ from collections.abc import Callable
|
||||
|
||||
import nltk # type:ignore
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.stem import WordNetLemmatizer # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -40,7 +39,7 @@ logger = setup_logger()
|
||||
def download_nltk_data() -> None:
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
"wordnet": "corpora/wordnet",
|
||||
# "wordnet": "corpora/wordnet", # Not in use
|
||||
"punkt": "tokenizers/punkt",
|
||||
}
|
||||
|
||||
@ -58,15 +57,16 @@ def download_nltk_data() -> None:
|
||||
|
||||
|
||||
def lemmatize_text(keywords: list[str]) -> list[str]:
|
||||
try:
|
||||
query = " ".join(keywords)
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
word_tokens = word_tokenize(query)
|
||||
lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens]
|
||||
combined_keywords = list(set(keywords + lemmatized_words))
|
||||
return combined_keywords
|
||||
except Exception:
|
||||
return keywords
|
||||
raise NotImplementedError("Lemmatization should not be used currently")
|
||||
# try:
|
||||
# query = " ".join(keywords)
|
||||
# lemmatizer = WordNetLemmatizer()
|
||||
# word_tokens = word_tokenize(query)
|
||||
# lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens]
|
||||
# combined_keywords = list(set(keywords + lemmatized_words))
|
||||
# return combined_keywords
|
||||
# except Exception:
|
||||
# return keywords
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
|
@ -3,12 +3,16 @@ import random
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_unique_document_sources
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.natural_language_processing.search_nlp_models import (
|
||||
ConnectorClassificationModel,
|
||||
)
|
||||
from danswer.prompts.constants import SOURCES_KEY
|
||||
from danswer.prompts.filter_extration import FILE_SOURCE_WARNING
|
||||
from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT
|
||||
@ -42,11 +46,38 @@ def _sample_document_sources(
|
||||
return random.sample(valid_sources, num_sample)
|
||||
|
||||
|
||||
def _sample_documents_using_custom_connector_classifier(
|
||||
query: str,
|
||||
valid_sources: list[DocumentSource],
|
||||
) -> list[DocumentSource] | None:
|
||||
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
|
||||
available_connectors = list(
|
||||
filter(
|
||||
lambda conn: conn.lower() in query_joined,
|
||||
[item.value for item in valid_sources],
|
||||
)
|
||||
)
|
||||
|
||||
if not available_connectors:
|
||||
return None
|
||||
|
||||
connectors = ConnectorClassificationModel().predict(query, available_connectors)
|
||||
|
||||
return strings_to_document_sources(connectors) if connectors else None
|
||||
|
||||
|
||||
def extract_source_filter(
|
||||
query: str, llm: LLM, db_session: Session
|
||||
) -> list[DocumentSource] | None:
|
||||
"""Returns a list of valid sources for search or None if no specific sources were detected"""
|
||||
|
||||
valid_sources = fetch_unique_document_sources(db_session)
|
||||
if not valid_sources:
|
||||
return None
|
||||
|
||||
if ENABLE_CONNECTOR_CLASSIFIER:
|
||||
return _sample_documents_using_custom_connector_classifier(query, valid_sources)
|
||||
|
||||
def _get_source_filter_messages(
|
||||
query: str,
|
||||
valid_sources: list[DocumentSource],
|
||||
@ -146,10 +177,6 @@ def extract_source_filter(
|
||||
logger.warning("LLM failed to provide a valid Source Filter output")
|
||||
return None
|
||||
|
||||
valid_sources = fetch_unique_document_sources(db_session)
|
||||
if not valid_sources:
|
||||
return None
|
||||
|
||||
messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = message_to_string(llm.invoke(filled_llm_prompt))
|
||||
|
@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@ -346,8 +347,18 @@ class GoogleServiceAccountKey(BaseModel):
|
||||
|
||||
|
||||
class GoogleServiceAccountCredentialRequest(BaseModel):
|
||||
google_drive_delegated_user: str | None # email of user to impersonate
|
||||
gmail_delegated_user: str | None # email of user to impersonate
|
||||
google_drive_delegated_user: str | None = None # email of user to impersonate
|
||||
gmail_delegated_user: str | None = None # email of user to impersonate
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest":
|
||||
if (self.google_drive_delegated_user is None) == (
|
||||
self.gmail_delegated_user is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Exactly one of google_drive_delegated_user or gmail_delegated_user must be set"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class FileUploadResponse(BaseModel):
|
||||
|
@ -3,15 +3,21 @@ import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.danswer_torch_model import ConnectorClassifier
|
||||
from model_server.danswer_torch_model import HybridClassifier
|
||||
from model_server.utils import simple_log_function_time
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import INTENT_MODEL_TAG
|
||||
from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
@ -19,10 +25,55 @@ logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
|
||||
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
_INTENT_TOKENIZER: AutoTokenizer | None = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> AutoTokenizer:
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
|
||||
"distilbert-base-uncased"
|
||||
)
|
||||
return _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
|
||||
|
||||
def get_local_connector_classifier(
|
||||
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
|
||||
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
|
||||
) -> ConnectorClassifier:
|
||||
global _CONNECTOR_CLASSIFIER_MODEL
|
||||
if _CONNECTOR_CLASSIFIER_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.info(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
def get_intent_model_tokenizer() -> AutoTokenizer:
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
@ -61,6 +112,74 @@ def get_local_intent_model(
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
connector_token_end_id: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
|
||||
|
||||
The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
|
||||
token and then the user query.
|
||||
"""
|
||||
|
||||
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
|
||||
|
||||
for connector in connectors:
|
||||
connector_token_ids = tokenizer(
|
||||
connector,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = torch.cat(
|
||||
(
|
||||
input_ids,
|
||||
connector_token_ids["input_ids"].squeeze(dim=0),
|
||||
torch.tensor([connector_token_end_id], dtype=torch.long),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
query_token_ids = tokenizer(
|
||||
query,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = torch.cat(
|
||||
(
|
||||
input_ids,
|
||||
query_token_ids["input_ids"].squeeze(dim=0),
|
||||
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
|
||||
|
||||
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
|
||||
|
||||
|
||||
def warm_up_connector_classifier_model() -> None:
|
||||
logger.info(
|
||||
f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
|
||||
)
|
||||
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
|
||||
connector_classifier = get_local_connector_classifier()
|
||||
|
||||
input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
["GitHub"],
|
||||
"danswer classifier query google doc",
|
||||
connector_classifier_tokenizer,
|
||||
connector_classifier.connector_end_token_id,
|
||||
)
|
||||
input_ids = input_ids.to(connector_classifier.device)
|
||||
attention_mask = attention_mask.to(connector_classifier.device)
|
||||
|
||||
connector_classifier(input_ids, attention_mask)
|
||||
|
||||
|
||||
def warm_up_intent_model() -> None:
|
||||
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
|
||||
intent_tokenizer = get_intent_model_tokenizer()
|
||||
@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]:
|
||||
return cleaned_words
|
||||
|
||||
|
||||
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
|
||||
tokenizer = get_connector_classifier_tokenizer()
|
||||
model = get_local_connector_classifier()
|
||||
|
||||
connector_names = req.available_connectors
|
||||
|
||||
input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
connector_names,
|
||||
req.query,
|
||||
tokenizer,
|
||||
model.connector_end_token_id,
|
||||
)
|
||||
input_ids = input_ids.to(model.device)
|
||||
attention_mask = attention_mask.to(model.device)
|
||||
|
||||
global_confidence, classifier_confidence = model(input_ids, attention_mask)
|
||||
|
||||
if global_confidence.item() < 0.5:
|
||||
return []
|
||||
|
||||
passed_connectors = []
|
||||
|
||||
for i, connector_name in enumerate(connector_names):
|
||||
if classifier_confidence.view(-1)[i].item() > 0.5:
|
||||
passed_connectors.append(connector_name)
|
||||
|
||||
return passed_connectors
|
||||
|
||||
|
||||
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
|
||||
tokenizer = get_intent_model_tokenizer()
|
||||
model_input = tokenizer(
|
||||
@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
|
||||
return is_keyword_sequence, cleaned_keywords
|
||||
|
||||
|
||||
@router.post("/connector-classification")
|
||||
async def process_connector_classification_request(
|
||||
classification_request: ConnectorClassificationRequest,
|
||||
) -> ConnectorClassificationResponse:
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError(
|
||||
"Indexing model server should not call connector classification endpoint"
|
||||
)
|
||||
|
||||
if len(classification_request.available_connectors) == 0:
|
||||
return ConnectorClassificationResponse(connectors=[])
|
||||
|
||||
connectors = run_connector_classification(classification_request)
|
||||
return ConnectorClassificationResponse(connectors=connectors)
|
||||
|
||||
|
||||
@router.post("/query-analysis")
|
||||
async def process_analysis_request(
|
||||
intent_request: IntentRequest,
|
||||
|
@ -4,7 +4,8 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
from transformers import DistilBertModel
|
||||
from transformers import DistilBertModel # type: ignore
|
||||
from transformers import DistilBertTokenizer # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
@ -21,7 +22,6 @@ class HybridClassifier(nn.Module):
|
||||
self.distilbert.config.dim, self.distilbert.config.dim
|
||||
)
|
||||
self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
|
||||
self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
@ -36,8 +36,7 @@ class HybridClassifier(nn.Module):
|
||||
# Intent classification on the CLS token
|
||||
cls_token_state = sequence_output[:, 0, :]
|
||||
pre_classifier_out = self.pre_classifier(cls_token_state)
|
||||
dropout_out = self.dropout(pre_classifier_out)
|
||||
intent_logits = self.intent_classifier(dropout_out)
|
||||
intent_logits = self.intent_classifier(pre_classifier_out)
|
||||
|
||||
# Keyword classification on all tokens
|
||||
token_logits = self.keyword_classifier(sequence_output)
|
||||
@ -72,3 +71,70 @@ class HybridClassifier(nn.Module):
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ConnectorClassifier(nn.Module):
|
||||
def __init__(self, config: DistilBertConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.distilbert = DistilBertModel(config)
|
||||
self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
|
||||
self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
|
||||
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
# Token indicating end of connector name, and on which classifier is used
|
||||
self.connector_end_token_id = self.tokenizer.get_vocab()[
|
||||
self.config.connector_end_token
|
||||
]
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.distilbert(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
|
||||
cls_hidden_states = hidden_states[
|
||||
:, 0, :
|
||||
] # Take leap of faith that first token is always [CLS]
|
||||
global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
|
||||
global_confidence = torch.sigmoid(global_logits).view(-1)
|
||||
|
||||
connector_end_position_ids = input_ids == self.connector_end_token_id
|
||||
connector_end_hidden_states = hidden_states[connector_end_position_ids]
|
||||
classifier_output = self.connector_match_classifier(connector_end_hidden_states)
|
||||
classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
|
||||
|
||||
return global_confidence, classifier_confidence
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
|
||||
device = (
|
||||
torch.device("cuda")
|
||||
if torch.cuda.is_available()
|
||||
else torch.device("mps")
|
||||
if torch.backends.mps.is_available()
|
||||
else torch.device("cpu")
|
||||
)
|
||||
state_dict = torch.load(
|
||||
os.path.join(repo_dir, "pytorch_model.pt"),
|
||||
map_location=device,
|
||||
weights_only=True,
|
||||
)
|
||||
|
||||
model = cls(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(device)
|
||||
model.device = device
|
||||
model.eval()
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
|
@ -362,6 +362,28 @@ def cohere_rerank(
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [
|
||||
item["relevance_score"]
|
||||
for item in sorted(result["results"], key=lambda x: x["index"])
|
||||
]
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
elif rerank_request.provider_type == RerankerProvider.LITELLM:
|
||||
if rerank_request.api_url is None:
|
||||
raise ValueError("API URL is required for LiteLLM reranking.")
|
||||
|
||||
sim_scores = litellm_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
api_url=rerank_request.api_url,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
|
@ -1,4 +1,4 @@
|
||||
aiohttp==3.9.4
|
||||
aiohttp==3.10.2
|
||||
alembic==1.10.4
|
||||
asyncpg==0.27.0
|
||||
atlassian-python-api==3.37.0
|
||||
@ -26,13 +26,12 @@ huggingface-hub==0.20.1
|
||||
jira==3.5.1
|
||||
jsonref==1.1.0
|
||||
langchain==0.1.17
|
||||
langchain-community==0.0.36
|
||||
langchain-core==0.1.50
|
||||
langchain-text-splitters==0.0.1
|
||||
litellm==1.43.18
|
||||
llama-index==0.9.45
|
||||
Mako==1.2.4
|
||||
msal==1.26.0
|
||||
msal==1.28.0
|
||||
nltk==3.8.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
@ -50,7 +49,7 @@ python-pptx==0.6.23
|
||||
pypdf==3.17.0
|
||||
pytest-mock==3.12.0
|
||||
pytest-playwright==0.3.2
|
||||
python-docx==1.1.0
|
||||
python-docx==1.1.2
|
||||
python-dotenv==1.0.0
|
||||
python-multipart==0.0.7
|
||||
pywikibot==9.0.0
|
||||
|
@ -8,7 +8,7 @@ pydantic==2.8.2
|
||||
retry==0.9.2
|
||||
safetensors==0.4.2
|
||||
sentence-transformers==2.6.1
|
||||
torch==2.0.1
|
||||
torch==2.2.0
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
|
@ -16,9 +16,12 @@ INDEXING_MODEL_SERVER_PORT = int(
|
||||
)
|
||||
|
||||
# Danswer custom Deep Learning Models
|
||||
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
|
||||
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
|
||||
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
|
||||
INTENT_MODEL_TAG = "v1.0.3"
|
||||
|
||||
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
|
||||
|
@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum):
|
||||
|
||||
class RerankerProvider(str, Enum):
|
||||
COHERE = "cohere"
|
||||
LITELLM = "litellm"
|
||||
|
||||
|
||||
class EmbedTextType(str, Enum):
|
||||
|
@ -7,6 +7,15 @@ from shared_configs.enums import RerankerProvider
|
||||
Embedding = list[float]
|
||||
|
||||
|
||||
class ConnectorClassificationRequest(BaseModel):
|
||||
available_connectors: list[str]
|
||||
query: str
|
||||
|
||||
|
||||
class ConnectorClassificationResponse(BaseModel):
|
||||
connectors: list[str]
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
|
||||
texts: list[str]
|
||||
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||
@ -34,6 +43,7 @@ class RerankRequest(BaseModel):
|
||||
model_name: str
|
||||
provider_type: RerankerProvider | None = None
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
@ -1,4 +1,3 @@
|
||||
version: '3'
|
||||
services:
|
||||
api_server:
|
||||
image: danswer/danswer-backend:${IMAGE_TAG:-latest}
|
||||
|
@ -1,4 +1,3 @@
|
||||
version: '3'
|
||||
services:
|
||||
api_server:
|
||||
image: danswer/danswer-backend:${IMAGE_TAG:-latest}
|
||||
|
@ -1,4 +1,3 @@
|
||||
version: '3'
|
||||
services:
|
||||
api_server:
|
||||
image: danswer/danswer-backend:${IMAGE_TAG:-latest}
|
||||
|
@ -1,4 +1,3 @@
|
||||
version: '3'
|
||||
services:
|
||||
api_server:
|
||||
image: danswer/danswer-backend:${IMAGE_TAG:-latest}
|
||||
|
@ -1,4 +1,3 @@
|
||||
version: '3'
|
||||
services:
|
||||
api_server:
|
||||
image: danswer/danswer-backend:${IMAGE_TAG:-latest}
|
||||
|
@ -1,33 +1,8 @@
|
||||
# This env template shows how to configure Danswer for multilingual use
|
||||
# In this case, it is configured for French and English
|
||||
# To use it, copy it to .env in the docker_compose directory.
|
||||
# Feel free to combine it with the other templates to suit your needs
|
||||
# This env template shows how to configure Danswer for custom multilingual use
|
||||
# Note that for most use cases it will be enough to configure Danswer multilingual purely through the UI
|
||||
# See "Search Settings" -> "Advanced" for UI options.
|
||||
# To use it, copy it to .env in the docker_compose directory (or the equivalent environment settings file for your deployment)
|
||||
|
||||
|
||||
# Rephrase the user query in specified languages using LLM, use comma separated values
|
||||
MULTILINGUAL_QUERY_EXPANSION="English, French"
|
||||
# Change the below to suit your specific needs, can be more explicit about the language of the response
|
||||
LANGUAGE_HINT="IMPORTANT: Respond in the same language as my query!"
|
||||
# The following is included with the user prompt. Here's one example but feel free to customize it to your needs:
|
||||
LANGUAGE_HINT="IMPORTANT: ALWAYS RESPOND IN FRENCH! Even if the documents and the user query are in English, your response must be in French."
|
||||
LANGUAGE_CHAT_NAMING_HINT="The name of the conversation must be in the same language as the user query."
|
||||
|
||||
# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small
|
||||
DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small"
|
||||
|
||||
# The model above is trained with the following prefix for queries and passages to improve retrieval
|
||||
# by letting the model know which of the two type is currently being embedded
|
||||
ASYM_QUERY_PREFIX="query: "
|
||||
ASYM_PASSAGE_PREFIX="passage: "
|
||||
|
||||
# Depends model by model, the one shown above is tuned with this as True
|
||||
NORMALIZE_EMBEDDINGS="True"
|
||||
|
||||
# Use LLM to determine if chunks are relevant to the query
|
||||
# May not work well for languages that do not have much training data in the LLM training set
|
||||
# If using a common language like Spanish, French, Chinese, etc. this can be kept turned on
|
||||
DISABLE_LLM_DOC_RELEVANCE="True"
|
||||
|
||||
# Enables fine-grained embeddings for better retrieval
|
||||
# At the cost of indexing speed (~5x slower), query time is same speed
|
||||
# Since reranking is turned off and multilingual retrieval is generally harder
|
||||
# it is advised to turn this one on
|
||||
ENABLE_MULTIPASS_INDEXING="True"
|
||||
|
@ -8,47 +8,6 @@ const version = env_version || package_version;
|
||||
const nextConfig = {
|
||||
output: "standalone",
|
||||
swcMinify: true,
|
||||
rewrites: async () => {
|
||||
// In production, something else (nginx in the one box setup) should take
|
||||
// care of this rewrite. TODO (chris): better support setups where
|
||||
// web_server and api_server are on different machines.
|
||||
if (process.env.NODE_ENV === "production") return [];
|
||||
|
||||
return [
|
||||
{
|
||||
source: "/api/:path*",
|
||||
destination: "http://127.0.0.1:8080/:path*", // Proxy to Backend
|
||||
},
|
||||
];
|
||||
},
|
||||
redirects: async () => {
|
||||
// In production, something else (nginx in the one box setup) should take
|
||||
// care of this redirect. TODO (chris): better support setups where
|
||||
// web_server and api_server are on different machines.
|
||||
const defaultRedirects = [];
|
||||
|
||||
if (process.env.NODE_ENV === "production") return defaultRedirects;
|
||||
|
||||
return defaultRedirects.concat([
|
||||
{
|
||||
source: "/api/chat/send-message:params*",
|
||||
destination: "http://127.0.0.1:8080/chat/send-message:params*", // Proxy to Backend
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/api/query/stream-answer-with-quote:params*",
|
||||
destination:
|
||||
"http://127.0.0.1:8080/query/stream-answer-with-quote:params*", // Proxy to Backend
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/api/query/stream-query-validation:params*",
|
||||
destination:
|
||||
"http://127.0.0.1:8080/query/stream-query-validation:params*", // Proxy to Backend
|
||||
permanent: true,
|
||||
},
|
||||
]);
|
||||
},
|
||||
publicRuntimeConfig: {
|
||||
version,
|
||||
},
|
||||
|
14
web/package-lock.json
generated
14
web/package-lock.json
generated
@ -2555,11 +2555,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/braces": {
|
||||
"version": "3.0.2",
|
||||
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
|
||||
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
|
||||
"version": "3.0.3",
|
||||
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
|
||||
"integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
|
||||
"dependencies": {
|
||||
"fill-range": "^7.0.1"
|
||||
"fill-range": "^7.1.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
@ -4061,9 +4061,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/fill-range": {
|
||||
"version": "7.0.1",
|
||||
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
|
||||
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
|
||||
"version": "7.1.1",
|
||||
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
|
||||
"integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
|
||||
"dependencies": {
|
||||
"to-regex-range": "^5.0.1"
|
||||
},
|
||||
|
@ -31,11 +31,13 @@ export function LLMProviderUpdateForm({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setPopup,
|
||||
hideAdvanced,
|
||||
}: {
|
||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
|
||||
onClose: () => void;
|
||||
existingLlmProvider?: FullLLMProvider;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
hideAdvanced?: boolean;
|
||||
setPopup?: (popup: PopupSpec) => void;
|
||||
}) {
|
||||
const { mutate } = useSWRConfig();
|
||||
@ -52,7 +54,7 @@ export function LLMProviderUpdateForm({
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name: existingLlmProvider?.name ?? "",
|
||||
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? "",
|
||||
api_version: existingLlmProvider?.api_version ?? "",
|
||||
@ -218,17 +220,20 @@ export function LLMProviderUpdateForm({
|
||||
}}
|
||||
>
|
||||
{({ values, setFieldValue }) => (
|
||||
<Form className="gap-y-6 items-stretch mt-8">
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Display Name"
|
||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||
placeholder="Display Name"
|
||||
disabled={existingLlmProvider ? true : false}
|
||||
/>
|
||||
<Form className="gap-y-4 items-stretch mt-6">
|
||||
{!hideAdvanced && (
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Display Name"
|
||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||
placeholder="Display Name"
|
||||
disabled={existingLlmProvider ? true : false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{llmProviderDescriptor.api_key_required && (
|
||||
<TextFormField
|
||||
small={hideAdvanced}
|
||||
name="api_key"
|
||||
label="API Key"
|
||||
placeholder="API Key"
|
||||
@ -238,6 +243,7 @@ export function LLMProviderUpdateForm({
|
||||
|
||||
{llmProviderDescriptor.api_base_required && (
|
||||
<TextFormField
|
||||
small={hideAdvanced}
|
||||
name="api_base"
|
||||
label="API Base"
|
||||
placeholder="API Base"
|
||||
@ -246,6 +252,7 @@ export function LLMProviderUpdateForm({
|
||||
|
||||
{llmProviderDescriptor.api_version_required && (
|
||||
<TextFormField
|
||||
small={hideAdvanced}
|
||||
name="api_version"
|
||||
label="API Version"
|
||||
placeholder="API Version"
|
||||
@ -255,6 +262,7 @@ export function LLMProviderUpdateForm({
|
||||
{llmProviderDescriptor.custom_config_keys?.map((customConfigKey) => (
|
||||
<div key={customConfigKey.name}>
|
||||
<TextFormField
|
||||
small={hideAdvanced}
|
||||
name={`custom_config.${customConfigKey.name}`}
|
||||
label={
|
||||
customConfigKey.is_required
|
||||
@ -266,134 +274,144 @@ export function LLMProviderUpdateForm({
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Divider />
|
||||
|
||||
{llmProviderDescriptor.llm_names.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
options={llmProviderDescriptor.llm_names.map((name) => ({
|
||||
name: getDisplayNameForModel(name),
|
||||
value: name,
|
||||
}))}
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
)}
|
||||
|
||||
{llmProviderDescriptor.llm_names.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
options={llmProviderDescriptor.llm_names.map((name) => ({
|
||||
name: getDisplayNameForModel(name),
|
||||
value: name,
|
||||
}))}
|
||||
includeDefault
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
)}
|
||||
|
||||
<Divider />
|
||||
|
||||
{llmProviderDescriptor.name != "azure" && (
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showAdvancedOptions && (
|
||||
{!hideAdvanced && (
|
||||
<>
|
||||
{llmProviderDescriptor.llm_names.length > 0 && (
|
||||
<div className="w-full">
|
||||
<MultiSelectField
|
||||
selectedInitially={values.display_model_names}
|
||||
name="display_model_names"
|
||||
label="Display Models"
|
||||
subtext="Select the models to make available to users. Unselected models will not be available."
|
||||
options={llmProviderDescriptor.llm_names.map((name) => ({
|
||||
value: name,
|
||||
label: getDisplayNameForModel(name),
|
||||
}))}
|
||||
onChange={(selected) =>
|
||||
setFieldValue("display_model_names", selected)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Divider />
|
||||
|
||||
{llmProviderDescriptor.llm_names.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
options={llmProviderDescriptor.llm_names.map((name) => ({
|
||||
name: getDisplayNameForModel(name),
|
||||
value: name,
|
||||
}))}
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
)}
|
||||
|
||||
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
small
|
||||
removeIndent
|
||||
alignTop
|
||||
name="is_public"
|
||||
label="Is Public?"
|
||||
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||
/>
|
||||
{llmProviderDescriptor.llm_names.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
options={llmProviderDescriptor.llm_names.map((name) => ({
|
||||
name: getDisplayNameForModel(name),
|
||||
value: name,
|
||||
}))}
|
||||
includeDefault
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
)}
|
||||
|
||||
{userGroups && userGroups.length > 0 && !values.is_public && (
|
||||
<div>
|
||||
<Text>
|
||||
Select which User Groups should have access to this LLM
|
||||
Provider.
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{userGroups.map((userGroup) => {
|
||||
const isSelected = values.groups.includes(
|
||||
userGroup.id
|
||||
);
|
||||
return (
|
||||
<Bubble
|
||||
key={userGroup.id}
|
||||
isSelected={isSelected}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
setFieldValue(
|
||||
"groups",
|
||||
values.groups.filter(
|
||||
(id) => id !== userGroup.id
|
||||
)
|
||||
);
|
||||
} else {
|
||||
setFieldValue("groups", [
|
||||
...values.groups,
|
||||
userGroup.id,
|
||||
]);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex">
|
||||
<GroupsIcon />
|
||||
<div className="ml-1">{userGroup.name}</div>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<Divider />
|
||||
|
||||
{llmProviderDescriptor.name != "azure" && (
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{llmProviderDescriptor.llm_names.length > 0 && (
|
||||
<div className="w-full">
|
||||
<MultiSelectField
|
||||
selectedInitially={values.display_model_names}
|
||||
name="display_model_names"
|
||||
label="Display Models"
|
||||
subtext="Select the models to make available to users. Unselected models will not be available."
|
||||
options={llmProviderDescriptor.llm_names.map(
|
||||
(name) => ({
|
||||
value: name,
|
||||
label: getDisplayNameForModel(name),
|
||||
})
|
||||
)}
|
||||
onChange={(selected) =>
|
||||
setFieldValue("display_model_names", selected)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
small
|
||||
removeIndent
|
||||
alignTop
|
||||
name="is_public"
|
||||
label="Is Public?"
|
||||
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||
/>
|
||||
|
||||
{userGroups &&
|
||||
userGroups.length > 0 &&
|
||||
!values.is_public && (
|
||||
<div>
|
||||
<Text>
|
||||
Select which User Groups should have access to
|
||||
this LLM Provider.
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{userGroups.map((userGroup) => {
|
||||
const isSelected = values.groups.includes(
|
||||
userGroup.id
|
||||
);
|
||||
return (
|
||||
<Bubble
|
||||
key={userGroup.id}
|
||||
isSelected={isSelected}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
setFieldValue(
|
||||
"groups",
|
||||
values.groups.filter(
|
||||
(id) => id !== userGroup.id
|
||||
)
|
||||
);
|
||||
} else {
|
||||
setFieldValue("groups", [
|
||||
...values.groups,
|
||||
userGroup.id,
|
||||
]);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex">
|
||||
<GroupsIcon />
|
||||
<div className="ml-1">
|
||||
{userGroup.name}
|
||||
</div>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
@ -432,6 +450,27 @@ export function LLMProviderUpdateForm({
|
||||
return;
|
||||
}
|
||||
|
||||
// If the deleted provider was the default, set the first remaining provider as default
|
||||
const remainingProvidersResponse = await fetch(
|
||||
LLM_PROVIDERS_ADMIN_URL
|
||||
);
|
||||
if (remainingProvidersResponse.ok) {
|
||||
const remainingProviders =
|
||||
await remainingProvidersResponse.json();
|
||||
|
||||
if (remainingProviders.length > 0) {
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!setDefaultResponse.ok) {
|
||||
console.error("Failed to set new default provider");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
}}
|
||||
|
@ -17,6 +17,8 @@ import {
|
||||
GoogleDriveServiceAccountCredentialJson,
|
||||
} from "@/lib/connectors/credentials";
|
||||
|
||||
import { Button as TremorButton } from "@tremor/react";
|
||||
|
||||
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
export const DriveJsonUpload = ({
|
||||
@ -344,7 +346,7 @@ export const DriveOAuthSection = ({
|
||||
if (serviceAccountKeyData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<p className="text-sm mb-2">
|
||||
<p className="text-sm mb-6">
|
||||
When using a Google Drive Service Account, you can either have Danswer
|
||||
act as the service account itself OR you can specify an account for
|
||||
the service account to impersonate.
|
||||
@ -356,70 +358,59 @@ export const DriveOAuthSection = ({
|
||||
the documents you want to index with the service account.
|
||||
</p>
|
||||
|
||||
<Card>
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_drive_delegated_user: "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_drive_delegated_user: Yup.string().optional(),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_drive_delegated_user:
|
||||
values.google_drive_delegated_user,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_drive_delegated_user: "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_drive_delegated_user: Yup.string().optional(),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_drive_delegated_user:
|
||||
values.google_drive_delegated_user,
|
||||
}),
|
||||
}
|
||||
refreshCredentials();
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_drive_delegated_user"
|
||||
label="[Optional] User email to impersonate:"
|
||||
subtext="If left blank, Danswer will use the service account itself."
|
||||
/>
|
||||
<div className="flex">
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className={
|
||||
"bg-slate-500 hover:bg-slate-700 text-white " +
|
||||
"font-bold py-2 px-4 rounded focus:outline-none " +
|
||||
"focus:shadow-outline w-full max-w-sm mx-auto"
|
||||
}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</Card>
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
refreshCredentials();
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_drive_delegated_user"
|
||||
label="[Optional] User email to impersonate:"
|
||||
subtext="If left blank, Danswer will use the service account itself."
|
||||
/>
|
||||
<div className="flex">
|
||||
<TremorButton type="submit" disabled={isSubmitting}>
|
||||
Create Credential
|
||||
</TremorButton>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -8,8 +8,6 @@ import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { ConnectorIndexingStatus } from "@/lib/types";
|
||||
import { getCurrentUser } from "@/lib/user";
|
||||
import { User, UserRole } from "@/lib/types";
|
||||
import { usePublicCredentials } from "@/lib/hooks";
|
||||
import { Title } from "@tremor/react";
|
||||
import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential";
|
||||
@ -109,6 +107,7 @@ const GDriveMain = ({}: {}) => {
|
||||
| undefined = credentialsData.find(
|
||||
(credential) => credential.credential_json?.google_drive_service_account_key
|
||||
);
|
||||
|
||||
const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<
|
||||
GoogleDriveConfig,
|
||||
GoogleDriveCredentialJson
|
||||
|
@ -7,7 +7,11 @@ import {
|
||||
rerankingModels,
|
||||
} from "./interfaces";
|
||||
import { FiExternalLink } from "react-icons/fi";
|
||||
import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons";
|
||||
import {
|
||||
CohereIcon,
|
||||
LiteLLMIcon,
|
||||
MixedBreadIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button } from "@tremor/react";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
ref
|
||||
) => {
|
||||
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
|
||||
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
|
||||
useState(false);
|
||||
|
||||
return (
|
||||
<Formik
|
||||
@ -48,13 +54,17 @@ const RerankingDetailsForm = forwardRef<
|
||||
.optional(),
|
||||
api_key: Yup.string().nullable(),
|
||||
num_rerank: Yup.number().min(1, "Must be at least 1"),
|
||||
rerank_api_url: Yup.string()
|
||||
.url("Must be a valid URL")
|
||||
.matches(/^https?:\/\//, "URL must start with http:// or https://")
|
||||
.nullable(),
|
||||
})}
|
||||
onSubmit={async (_, { setSubmitting }) => {
|
||||
setSubmitting(false);
|
||||
}}
|
||||
enableReinitialize={true}
|
||||
>
|
||||
{({ values, setFieldValue }) => {
|
||||
{({ values, setFieldValue, resetForm }) => {
|
||||
const resetRerankingValues = () => {
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef<
|
||||
)
|
||||
: rerankingModels.filter(
|
||||
(modelCard) =>
|
||||
modelCard.modelName ==
|
||||
originalRerankingDetails.rerank_model_name
|
||||
(modelCard.modelName ==
|
||||
originalRerankingDetails.rerank_model_name &&
|
||||
modelCard.rerank_provider_type ==
|
||||
originalRerankingDetails.rerank_provider_type) ||
|
||||
(modelCard.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM &&
|
||||
originalRerankingDetails.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM)
|
||||
)
|
||||
).map((card) => {
|
||||
const isSelected =
|
||||
values.rerank_provider_type ===
|
||||
card.rerank_provider_type &&
|
||||
values.rerank_model_name === card.modelName;
|
||||
(card.modelName == null ||
|
||||
values.rerank_model_name === card.modelName);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${card.rerank_provider_type}-${card.modelName}`}
|
||||
@ -148,26 +166,39 @@ const RerankingDetailsForm = forwardRef<
|
||||
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
|
||||
}`}
|
||||
onClick={() => {
|
||||
if (card.rerank_provider_type) {
|
||||
if (
|
||||
card.rerank_provider_type == RerankerProvider.COHERE
|
||||
) {
|
||||
setIsApiKeyModalOpen(true);
|
||||
} else if (
|
||||
card.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM
|
||||
) {
|
||||
setShowLiteLLMConfigurationModal(true);
|
||||
}
|
||||
|
||||
if (!isSelected) {
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_provider_type: card.rerank_provider_type!,
|
||||
rerank_model_name: card.modelName || null,
|
||||
rerank_api_key: null,
|
||||
rerank_api_url: null,
|
||||
});
|
||||
setFieldValue(
|
||||
"rerank_provider_type",
|
||||
card.rerank_provider_type
|
||||
);
|
||||
setFieldValue("rerank_model_name", card.modelName);
|
||||
}
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_provider_type: card.rerank_provider_type!,
|
||||
rerank_model_name: card.modelName,
|
||||
rerank_api_key: null,
|
||||
});
|
||||
setFieldValue(
|
||||
"rerank_provider_type",
|
||||
card.rerank_provider_type
|
||||
);
|
||||
setFieldValue("rerank_model_name", card.modelName);
|
||||
}}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<div className="flex items-center">
|
||||
{card.rerank_provider_type ===
|
||||
RerankerProvider.COHERE ? (
|
||||
RerankerProvider.LITELLM ? (
|
||||
<LiteLLMIcon size={24} className="mr-2" />
|
||||
) : RerankerProvider.COHERE ? (
|
||||
<CohereIcon size={24} className="mr-2" />
|
||||
) : (
|
||||
<MixedBreadIcon size={24} className="mr-2" />
|
||||
@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef<
|
||||
})}
|
||||
</div>
|
||||
|
||||
{showLiteLLMConfigurationModal && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
resetForm();
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
width="w-[800px]"
|
||||
title="API Key Configuration"
|
||||
>
|
||||
<div className="w-full flex flex-col gap-y-4 px-4">
|
||||
<TextFormField
|
||||
subtext="Set the URL at which your LiteLLM Proxy is hosted"
|
||||
placeholder={values.rerank_api_url || undefined}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_api_url: value,
|
||||
});
|
||||
setFieldValue("rerank_api_url", value);
|
||||
}}
|
||||
type="text"
|
||||
label="LiteLLM Proxy URL"
|
||||
name="rerank_api_url"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
subtext="Set the key to access your LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_api_key: value,
|
||||
});
|
||||
setFieldValue("rerank_api_key", value);
|
||||
}}
|
||||
type="password"
|
||||
label="LiteLLM Proxy Key"
|
||||
name="rerank_api_key"
|
||||
optional
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
subtext="Set the model name to use for LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_model_name
|
||||
? "*".repeat(values.rerank_model_name.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_model_name: value,
|
||||
});
|
||||
setFieldValue("rerank_model_name", value);
|
||||
}}
|
||||
label="LiteLLM Model Name"
|
||||
name="rerank_model_name"
|
||||
optional
|
||||
/>
|
||||
|
||||
<div className="flex w-full justify-end mt-4">
|
||||
<Button
|
||||
onClick={() => {
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
color="blue"
|
||||
size="xs"
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
{isApiKeyModalOpen && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef<
|
||||
>
|
||||
<div className="w-full px-4">
|
||||
<TextFormField
|
||||
placeholder={values.rerank_api_key || undefined}
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
|
@ -5,11 +5,13 @@ export interface RerankingDetails {
|
||||
rerank_model_name: string | null;
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
rerank_api_key: string | null;
|
||||
rerank_api_url: string | null;
|
||||
num_rerank: number;
|
||||
}
|
||||
|
||||
export enum RerankerProvider {
|
||||
COHERE = "cohere",
|
||||
LITELLM = "litellm",
|
||||
}
|
||||
export interface AdvancedSearchConfiguration {
|
||||
model_name: string;
|
||||
@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails {
|
||||
|
||||
export interface RerankingModel {
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
modelName: string;
|
||||
modelName?: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
link: string;
|
||||
@ -48,6 +50,13 @@ export interface RerankingModel {
|
||||
}
|
||||
|
||||
export const rerankingModels: RerankingModel[] = [
|
||||
{
|
||||
rerank_provider_type: RerankerProvider.LITELLM,
|
||||
cloud: true,
|
||||
displayName: "LiteLLM",
|
||||
description: "Host your own reranker or router with LiteLLM proxy",
|
||||
link: "https://docs.litellm.ai/docs/proxy",
|
||||
},
|
||||
{
|
||||
rerank_provider_type: null,
|
||||
cloud: false,
|
||||
|
@ -4,7 +4,7 @@ import * as Yup from "yup";
|
||||
import CredentialSubText from "@/components/credentials/CredentialFields";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { FaPlus } from "react-icons/fa";
|
||||
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
|
||||
import { AdvancedSearchConfiguration } from "../interfaces";
|
||||
import { BooleanFormField } from "@/components/admin/connectors/Field";
|
||||
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
|
||||
|
||||
|
@ -10,7 +10,7 @@ import {
|
||||
CloudEmbeddingModel,
|
||||
EmbeddingProvider,
|
||||
HostedEmbeddingModel,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
} from "@/components/embedding/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import useSWR, { mutate } from "swr";
|
||||
@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
|
||||
import {
|
||||
AdvancedSearchConfiguration,
|
||||
RerankerProvider,
|
||||
RerankingDetails,
|
||||
SavedSearchSettings,
|
||||
} from "../interfaces";
|
||||
@ -49,6 +48,7 @@ export default function EmbeddingForm() {
|
||||
num_rerank: 0,
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
});
|
||||
|
||||
const updateAdvancedEmbeddingDetails = (
|
||||
@ -124,6 +124,7 @@ export default function EmbeddingForm() {
|
||||
num_rerank: searchSettings.num_rerank,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
});
|
||||
}
|
||||
}, [searchSettings]);
|
||||
@ -134,12 +135,14 @@ export default function EmbeddingForm() {
|
||||
num_rerank: searchSettings.num_rerank,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
}
|
||||
: {
|
||||
rerank_api_key: "",
|
||||
num_rerank: 0,
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
|
116
web/src/app/api/[...path]/route.ts
Normal file
116
web/src/app/api/[...path]/route.ts
Normal file
@ -0,0 +1,116 @@
|
||||
import { INTERNAL_URL } from "@/lib/constants";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
/* NextJS is annoying and makes use use a separate function for
|
||||
each request type >:( */
|
||||
|
||||
export async function GET(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
export async function POST(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
export async function PUT(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
export async function PATCH(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
export async function DELETE(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
export async function HEAD(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
export async function OPTIONS(
|
||||
request: NextRequest,
|
||||
{ params }: { params: { path: string[] } }
|
||||
) {
|
||||
return handleRequest(request, params.path);
|
||||
}
|
||||
|
||||
async function handleRequest(request: NextRequest, path: string[]) {
|
||||
if (process.env.NODE_ENV !== "development") {
|
||||
return NextResponse.json(
|
||||
{
|
||||
message:
|
||||
"This API is only available in development mode. In production, something else (e.g. nginx) should handle this.",
|
||||
},
|
||||
{ status: 404 }
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const backendUrl = new URL(`${INTERNAL_URL}/${path.join("/")}`);
|
||||
|
||||
// Get the URL parameters from the request
|
||||
const urlParams = new URLSearchParams(request.url.split("?")[1]);
|
||||
|
||||
// Append the URL parameters to the backend URL
|
||||
urlParams.forEach((value, key) => {
|
||||
backendUrl.searchParams.append(key, value);
|
||||
});
|
||||
|
||||
const response = await fetch(backendUrl, {
|
||||
method: request.method,
|
||||
headers: request.headers,
|
||||
body: request.body,
|
||||
// @ts-ignore
|
||||
duplex: "half",
|
||||
});
|
||||
|
||||
// Check if the response is a stream
|
||||
if (
|
||||
response.headers.get("Transfer-Encoding") === "chunked" ||
|
||||
response.headers.get("Content-Type")?.includes("stream")
|
||||
) {
|
||||
// If it's a stream, create a TransformStream to pass the data through
|
||||
const { readable, writable } = new TransformStream();
|
||||
response.body?.pipeTo(writable);
|
||||
|
||||
return new NextResponse(readable, {
|
||||
status: response.status,
|
||||
headers: response.headers,
|
||||
});
|
||||
} else {
|
||||
return new NextResponse(response.body, {
|
||||
status: response.status,
|
||||
headers: response.headers,
|
||||
});
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
console.error("Proxy error:", error);
|
||||
return NextResponse.json(
|
||||
{
|
||||
message: "Proxy error",
|
||||
error:
|
||||
error instanceof Error ? error.message : "An unknown error occurred",
|
||||
},
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}
|
@ -98,6 +98,7 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
|
||||
import { SEARCH_TOOL_NAME } from "./tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@ -106,12 +107,10 @@ const SYSTEM_MESSAGE_ID = -3;
|
||||
export function ChatPage({
|
||||
toggle,
|
||||
documentSidebarInitialWidth,
|
||||
defaultSelectedAssistantId,
|
||||
toggledSidebar,
|
||||
}: {
|
||||
toggle: (toggled?: boolean) => void;
|
||||
documentSidebarInitialWidth?: number;
|
||||
defaultSelectedAssistantId?: number;
|
||||
toggledSidebar: boolean;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
@ -126,8 +125,13 @@ export function ChatPage({
|
||||
folders,
|
||||
openedFolders,
|
||||
userInputPrompts,
|
||||
defaultAssistantId,
|
||||
shouldShowWelcomeModal,
|
||||
refreshChatSessions,
|
||||
} = useChatContext();
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
|
||||
|
||||
const { user, refreshUser, isLoadingUser } = useUser();
|
||||
|
||||
// chat session
|
||||
@ -162,9 +166,9 @@ export function ChatPage({
|
||||
? availableAssistants.find(
|
||||
(assistant) => assistant.id === existingChatSessionAssistantId
|
||||
)
|
||||
: defaultSelectedAssistantId !== undefined
|
||||
: defaultAssistantId !== undefined
|
||||
? availableAssistants.find(
|
||||
(assistant) => assistant.id === defaultSelectedAssistantId
|
||||
(assistant) => assistant.id === defaultAssistantId
|
||||
)
|
||||
: undefined
|
||||
);
|
||||
@ -327,8 +331,8 @@ export function ChatPage({
|
||||
async function initialSessionFetch() {
|
||||
if (existingChatSessionId === null) {
|
||||
setIsFetchingChatMessages(false);
|
||||
if (defaultSelectedAssistantId !== undefined) {
|
||||
setSelectedAssistantFromId(defaultSelectedAssistantId);
|
||||
if (defaultAssistantId !== undefined) {
|
||||
setSelectedAssistantFromId(defaultAssistantId);
|
||||
} else {
|
||||
setSelectedAssistant(undefined);
|
||||
}
|
||||
@ -402,7 +406,7 @@ export function ChatPage({
|
||||
// force re-name if the chat session doesn't have one
|
||||
if (!chatSession.description) {
|
||||
await nameChatSession(existingChatSessionId, seededMessage);
|
||||
router.refresh(); // need to refresh to update name on sidebar
|
||||
refreshChatSessions();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -676,12 +680,10 @@ export function ChatPage({
|
||||
useEffect(() => {
|
||||
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
|
||||
setSelectedAssistant(
|
||||
filteredAssistants.find(
|
||||
(persona) => persona.id === defaultSelectedAssistantId
|
||||
)
|
||||
filteredAssistants.find((persona) => persona.id === defaultAssistantId)
|
||||
);
|
||||
}
|
||||
}, [defaultSelectedAssistantId]);
|
||||
}, [defaultAssistantId]);
|
||||
|
||||
const [
|
||||
selectedDocuments,
|
||||
@ -1111,6 +1113,12 @@ export function ChatPage({
|
||||
console.error(
|
||||
"First packet should contain message response info "
|
||||
);
|
||||
if (Object.hasOwn(packet, "error")) {
|
||||
const error = (packet as StreamingError).error;
|
||||
setLoadingError(error);
|
||||
updateChatState("input");
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1330,6 +1338,7 @@ export function ChatPage({
|
||||
if (!searchParamBasedChatSessionName) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
await nameChatSession(currChatSessionId, currMessage);
|
||||
refreshChatSessions();
|
||||
}
|
||||
|
||||
// NOTE: don't switch pages if the user has navigated away from the chat
|
||||
@ -1465,6 +1474,7 @@ export function ChatPage({
|
||||
|
||||
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
|
||||
const [untoggled, setUntoggled] = useState(false);
|
||||
const [loadingError, setLoadingError] = useState<string | null>(null);
|
||||
|
||||
const explicitlyUntoggle = () => {
|
||||
setShowDocSidebar(false);
|
||||
@ -1588,6 +1598,11 @@ export function ChatPage({
|
||||
return (
|
||||
<>
|
||||
<HealthCheckBanner />
|
||||
|
||||
{showApiKeyModal && !shouldShowWelcomeModal && (
|
||||
<ApiKeyModal hide={() => setShowApiKeyModal(false)} />
|
||||
)}
|
||||
|
||||
{/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit.
|
||||
Only used in the EE version of the app. */}
|
||||
{popup}
|
||||
@ -1760,7 +1775,6 @@ export function ChatPage({
|
||||
className={`h-full w-full relative flex-auto transition-margin duration-300 overflow-x-auto mobile:pb-12 desktop:pb-[100px]`}
|
||||
{...getRootProps()}
|
||||
>
|
||||
{/* <input {...getInputProps()} /> */}
|
||||
<div
|
||||
className={`w-full h-full flex flex-col overflow-y-auto include-scrollbar overflow-x-hidden relative`}
|
||||
ref={scrollableDivRef}
|
||||
@ -1770,7 +1784,8 @@ export function ChatPage({
|
||||
|
||||
{messageHistory.length === 0 &&
|
||||
!isFetchingChatMessages &&
|
||||
currentSessionChatState == "input" && (
|
||||
currentSessionChatState == "input" &&
|
||||
!loadingError && (
|
||||
<ChatIntro
|
||||
availableSources={finalAvailableSources}
|
||||
selectedPersona={liveAssistant}
|
||||
@ -2078,16 +2093,17 @@ export function ChatPage({
|
||||
}
|
||||
})}
|
||||
|
||||
{currentSessionChatState == "loading" &&
|
||||
!currentSessionRegenerationState?.regenerating &&
|
||||
messageHistory[messageHistory.length - 1]?.type !=
|
||||
"user" && (
|
||||
<HumanMessage
|
||||
key={-2}
|
||||
messageId={-1}
|
||||
content={submittedMessage}
|
||||
/>
|
||||
)}
|
||||
{currentSessionChatState == "loading" ||
|
||||
(loadingError &&
|
||||
!currentSessionRegenerationState?.regenerating &&
|
||||
messageHistory[messageHistory.length - 1]
|
||||
?.type != "user" && (
|
||||
<HumanMessage
|
||||
key={-2}
|
||||
messageId={-1}
|
||||
content={submittedMessage}
|
||||
/>
|
||||
))}
|
||||
|
||||
{currentSessionChatState == "loading" && (
|
||||
<div
|
||||
@ -2116,6 +2132,20 @@ export function ChatPage({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{loadingError && (
|
||||
<div key={-1}>
|
||||
<AIMessage
|
||||
currentPersona={liveAssistant}
|
||||
messageId={-1}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{loadingError}
|
||||
</p>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{currentPersona &&
|
||||
currentPersona.starter_messages &&
|
||||
currentPersona.starter_messages.length > 0 &&
|
||||
@ -2177,6 +2207,9 @@ export function ChatPage({
|
||||
</div>
|
||||
)}
|
||||
<ChatInputBar
|
||||
showConfigureAPIKey={() =>
|
||||
setShowApiKeyModal(true)
|
||||
}
|
||||
chatState={currentSessionChatState}
|
||||
stopGenerating={stopGenerating}
|
||||
openModelSettings={() => setSettingsToggled(true)}
|
||||
|
@ -3,21 +3,15 @@ import { ChatPage } from "./ChatPage";
|
||||
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";
|
||||
|
||||
export default function WrappedChat({
|
||||
defaultAssistantId,
|
||||
initiallyToggled,
|
||||
}: {
|
||||
defaultAssistantId?: number;
|
||||
initiallyToggled: boolean;
|
||||
}) {
|
||||
return (
|
||||
<FunctionalWrapper
|
||||
initiallyToggled={initiallyToggled}
|
||||
content={(toggledSidebar, toggle) => (
|
||||
<ChatPage
|
||||
toggle={toggle}
|
||||
defaultSelectedAssistantId={defaultAssistantId}
|
||||
toggledSidebar={toggledSidebar}
|
||||
/>
|
||||
<ChatPage toggle={toggle} toggledSidebar={toggledSidebar} />
|
||||
)}
|
||||
/>
|
||||
);
|
||||
|
@ -33,12 +33,15 @@ import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
import { Hoverable } from "@/components/Hoverable";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { ChatState } from "../types";
|
||||
import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText";
|
||||
import { useSearchContext } from "@/components/context/SearchContext";
|
||||
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
|
||||
export function ChatInputBar({
|
||||
openModelSettings,
|
||||
showDocs,
|
||||
showConfigureAPIKey,
|
||||
selectedDocuments,
|
||||
message,
|
||||
setMessage,
|
||||
@ -62,6 +65,7 @@ export function ChatInputBar({
|
||||
chatSessionId,
|
||||
inputPrompts,
|
||||
}: {
|
||||
showConfigureAPIKey: () => void;
|
||||
openModelSettings: () => void;
|
||||
chatState: ChatState;
|
||||
stopGenerating: () => void;
|
||||
@ -111,6 +115,7 @@ export function ChatInputBar({
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const settings = useContext(SettingsContext);
|
||||
|
||||
const { llmProviders } = useChatContext();
|
||||
@ -364,6 +369,9 @@ export function ChatInputBar({
|
||||
<div>
|
||||
<SelectedFilterDisplay filterManager={filterManager} />
|
||||
</div>
|
||||
|
||||
<UnconfiguredProviderText showConfigureAPIKey={showConfigureAPIKey} />
|
||||
|
||||
<div
|
||||
className="
|
||||
opacity-100
|
||||
|
@ -2,10 +2,10 @@ import { redirect } from "next/navigation";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import { WelcomeModal } from "@/components/initialSetup/welcome/WelcomeModalWrapper";
|
||||
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
|
||||
import { ChatProvider } from "@/components/context/ChatContext";
|
||||
import { fetchChatData } from "@/lib/chat/fetchChatData";
|
||||
import WrappedChat from "./WrappedChat";
|
||||
import { ProviderContextProvider } from "@/components/chat_search/ProviderContext";
|
||||
|
||||
export default async function Page({
|
||||
searchParams,
|
||||
@ -23,7 +23,6 @@ export default async function Page({
|
||||
const {
|
||||
user,
|
||||
chatSessions,
|
||||
ccPairs,
|
||||
availableSources,
|
||||
documentSets,
|
||||
assistants,
|
||||
@ -33,9 +32,7 @@ export default async function Page({
|
||||
toggleSidebar,
|
||||
openedFolders,
|
||||
defaultAssistantId,
|
||||
finalDocumentSidebarInitialWidth,
|
||||
shouldShowWelcomeModal,
|
||||
shouldDisplaySourcesIncompleteModal,
|
||||
userInputPrompts,
|
||||
} = data;
|
||||
|
||||
@ -43,9 +40,7 @@ export default async function Page({
|
||||
<>
|
||||
<InstantSSRAutoRefresh />
|
||||
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
|
||||
{!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && (
|
||||
<ApiKeyModal user={user} />
|
||||
)}
|
||||
|
||||
<ChatProvider
|
||||
value={{
|
||||
chatSessions,
|
||||
@ -57,12 +52,13 @@ export default async function Page({
|
||||
folders,
|
||||
openedFolders,
|
||||
userInputPrompts,
|
||||
shouldShowWelcomeModal,
|
||||
defaultAssistantId,
|
||||
}}
|
||||
>
|
||||
<WrappedChat
|
||||
defaultAssistantId={defaultAssistantId}
|
||||
initiallyToggled={toggleSidebar}
|
||||
/>
|
||||
<ProviderContextProvider>
|
||||
<WrappedChat initiallyToggled={toggleSidebar} />
|
||||
</ProviderContextProvider>
|
||||
</ChatProvider>
|
||||
</>
|
||||
);
|
||||
|
@ -6,6 +6,7 @@ import {
|
||||
} from "@/components/settings/lib";
|
||||
import {
|
||||
CUSTOM_ANALYTICS_ENABLED,
|
||||
EE_ENABLED,
|
||||
SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
|
||||
} from "@/lib/constants";
|
||||
import { SettingsProvider } from "@/components/settings/SettingsProvider";
|
||||
@ -53,6 +54,7 @@ export default async function RootLayout({
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const combinedSettings = await fetchSettingsSS();
|
||||
|
||||
if (!combinedSettings) {
|
||||
// Just display a simple full page error if fetching fails.
|
||||
|
||||
@ -72,8 +74,34 @@ export default async function RootLayout({
|
||||
<h1 className="text-2xl font-bold mb-4 text-error">Error</h1>
|
||||
<p className="text-text-500">
|
||||
Your Danswer instance was not configured properly and your
|
||||
settings could not be loaded. Please contact your admin to fix
|
||||
this error.
|
||||
settings could not be loaded. This could be due to an admin
|
||||
configuration issue or an incomplete setup.
|
||||
</p>
|
||||
<p className="mt-4">
|
||||
If you're an admin, please check{" "}
|
||||
<a
|
||||
className="text-link"
|
||||
href="https://docs.danswer.dev/introduction?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
our docs
|
||||
</a>{" "}
|
||||
to see how to configure Danswer properly. If you're a user,
|
||||
please contact your admin to fix this error.
|
||||
</p>
|
||||
<p className="mt-4">
|
||||
For additional support and guidance, you can reach out to our
|
||||
community on{" "}
|
||||
<a
|
||||
className="text-link"
|
||||
href="https://danswer.ai?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
Slack
|
||||
</a>
|
||||
.
|
||||
</p>
|
||||
</Card>
|
||||
</div>
|
||||
|
@ -1,31 +1,12 @@
|
||||
"use client";
|
||||
import { SearchSection } from "@/components/search/SearchSection";
|
||||
import FunctionalWrapper from "../chat/shared_chat_search/FunctionalWrapper";
|
||||
import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types";
|
||||
import { Persona } from "../admin/assistants/interfaces";
|
||||
import { ChatSession } from "../chat/interfaces";
|
||||
|
||||
export default function WrappedSearch({
|
||||
querySessions,
|
||||
ccPairs,
|
||||
documentSets,
|
||||
personas,
|
||||
searchTypeDefault,
|
||||
tags,
|
||||
user,
|
||||
agenticSearchEnabled,
|
||||
initiallyToggled,
|
||||
disabledAgentic,
|
||||
}: {
|
||||
disabledAgentic: boolean;
|
||||
querySessions: ChatSession[];
|
||||
ccPairs: CCPairBasicInfo[];
|
||||
documentSets: DocumentSet[];
|
||||
personas: Persona[];
|
||||
searchTypeDefault: string;
|
||||
tags: Tag[];
|
||||
user: User | null;
|
||||
agenticSearchEnabled: boolean;
|
||||
initiallyToggled: boolean;
|
||||
}) {
|
||||
return (
|
||||
@ -33,16 +14,8 @@ export default function WrappedSearch({
|
||||
initiallyToggled={initiallyToggled}
|
||||
content={(toggledSidebar, toggle) => (
|
||||
<SearchSection
|
||||
disabledAgentic={disabledAgentic}
|
||||
agenticSearchEnabled={agenticSearchEnabled}
|
||||
toggle={toggle}
|
||||
toggledSidebar={toggledSidebar}
|
||||
querySessions={querySessions}
|
||||
user={user}
|
||||
ccPairs={ccPairs}
|
||||
documentSets={documentSets}
|
||||
personas={personas}
|
||||
tags={tags}
|
||||
defaultSearchType={searchTypeDefault}
|
||||
/>
|
||||
)}
|
||||
|
@ -5,7 +5,6 @@ import {
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types";
|
||||
import { cookies } from "next/headers";
|
||||
@ -34,6 +33,8 @@ import {
|
||||
DISABLE_LLM_DOC_RELEVANCE,
|
||||
} from "@/lib/constants";
|
||||
import WrappedSearch from "./WrappedSearch";
|
||||
import { SearchProvider } from "@/components/context/SearchContext";
|
||||
import { ProviderContextProvider } from "@/components/chat_search/ProviderContext";
|
||||
|
||||
export default async function Home() {
|
||||
// Disable caching so we always get the up to date connector / document set / persona info
|
||||
@ -185,10 +186,6 @@ export default async function Home() {
|
||||
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
|
||||
<InstantSSRAutoRefresh />
|
||||
|
||||
{!shouldShowWelcomeModal &&
|
||||
!shouldDisplayNoSourcesModal &&
|
||||
!shouldDisplaySourcesIncompleteModal && <ApiKeyModal user={user} />}
|
||||
|
||||
{shouldDisplayNoSourcesModal && <NoSourcesModal />}
|
||||
|
||||
{shouldDisplaySourcesIncompleteModal && (
|
||||
@ -199,18 +196,27 @@ export default async function Home() {
|
||||
Only used in the EE version of the app. */}
|
||||
<ChatPopup />
|
||||
|
||||
<WrappedSearch
|
||||
disabledAgentic={DISABLE_LLM_DOC_RELEVANCE}
|
||||
initiallyToggled={toggleSidebar}
|
||||
querySessions={querySessions}
|
||||
user={user}
|
||||
ccPairs={ccPairs}
|
||||
documentSets={documentSets}
|
||||
personas={assistants}
|
||||
tags={tags}
|
||||
searchTypeDefault={searchTypeDefault}
|
||||
agenticSearchEnabled={agenticSearchEnabled}
|
||||
/>
|
||||
<SearchProvider
|
||||
value={{
|
||||
querySessions,
|
||||
ccPairs,
|
||||
documentSets,
|
||||
assistants,
|
||||
tags,
|
||||
agenticSearchEnabled,
|
||||
disabledAgentic: DISABLE_LLM_DOC_RELEVANCE,
|
||||
initiallyToggled: toggleSidebar,
|
||||
shouldShowWelcomeModal,
|
||||
shouldDisplayNoSources: shouldDisplayNoSourcesModal,
|
||||
}}
|
||||
>
|
||||
<ProviderContextProvider>
|
||||
<WrappedSearch
|
||||
initiallyToggled={toggleSidebar}
|
||||
searchTypeDefault={searchTypeDefault}
|
||||
/>
|
||||
</ProviderContextProvider>
|
||||
</SearchProvider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ export function TextFormField({
|
||||
rounded-lg
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
px-3
|
||||
mt-1
|
||||
placeholder:font-description
|
||||
placeholder:text-base
|
||||
|
70
web/src/components/chat_search/ProviderContext.tsx
Normal file
70
web/src/components/chat_search/ProviderContext.tsx
Normal file
@ -0,0 +1,70 @@
|
||||
"use client";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import React, { createContext, useContext, useState, useEffect } from "react";
|
||||
import { useUser } from "../user/UserProvider";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { checkLlmProvider } from "../initialSetup/welcome/lib";
|
||||
|
||||
interface ProviderContextType {
|
||||
shouldShowConfigurationNeeded: boolean;
|
||||
providerOptions: WellKnownLLMProviderDescriptor[];
|
||||
refreshProviderInfo: () => Promise<void>; // Add this line
|
||||
}
|
||||
|
||||
const ProviderContext = createContext<ProviderContextType | undefined>(
|
||||
undefined
|
||||
);
|
||||
|
||||
export function ProviderContextProvider({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const { user } = useUser();
|
||||
const router = useRouter();
|
||||
|
||||
const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
|
||||
const [providerOptions, setProviderOptions] = useState<
|
||||
WellKnownLLMProviderDescriptor[]
|
||||
>([]);
|
||||
|
||||
const fetchProviderInfo = async () => {
|
||||
const { providers, options, defaultCheckSuccessful } =
|
||||
await checkLlmProvider(user);
|
||||
setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
|
||||
setProviderOptions(options);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchProviderInfo();
|
||||
}, [router, user]);
|
||||
|
||||
const shouldShowConfigurationNeeded =
|
||||
!validProviderExists && providerOptions.length > 0;
|
||||
|
||||
const refreshProviderInfo = async () => {
|
||||
await fetchProviderInfo();
|
||||
};
|
||||
|
||||
return (
|
||||
<ProviderContext.Provider
|
||||
value={{
|
||||
shouldShowConfigurationNeeded,
|
||||
providerOptions,
|
||||
refreshProviderInfo, // Add this line
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</ProviderContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useProviderStatus() {
|
||||
const context = useContext(ProviderContext);
|
||||
if (context === undefined) {
|
||||
throw new Error(
|
||||
"useProviderStatus must be used within a ProviderContextProvider"
|
||||
);
|
||||
}
|
||||
return context;
|
||||
}
|
27
web/src/components/chat_search/UnconfiguredProviderText.tsx
Normal file
27
web/src/components/chat_search/UnconfiguredProviderText.tsx
Normal file
@ -0,0 +1,27 @@
|
||||
import { useProviderStatus } from "./ProviderContext";
|
||||
|
||||
export default function CredentialNotConfigured({
|
||||
showConfigureAPIKey,
|
||||
}: {
|
||||
showConfigureAPIKey: () => void;
|
||||
}) {
|
||||
const { shouldShowConfigurationNeeded } = useProviderStatus();
|
||||
|
||||
if (!shouldShowConfigurationNeeded) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<p className="text-base text-center w-full text-subtle">
|
||||
Please note that you have not yet configured an LLM provider. You can
|
||||
configure one{" "}
|
||||
<button
|
||||
onClick={showConfigureAPIKey}
|
||||
className="text-link hover:underline cursor-pointer"
|
||||
>
|
||||
here
|
||||
</button>
|
||||
.
|
||||
</p>
|
||||
);
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
import { Dispatch, SetStateAction, useEffect, useRef, useState } from "react";
|
||||
import { Dispatch, SetStateAction, useEffect, useRef } from "react";
|
||||
|
||||
interface UseSidebarVisibilityProps {
|
||||
toggledSidebar: boolean;
|
||||
|
@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import React, { createContext, useContext } from "react";
|
||||
import React, { createContext, useContext, useState } from "react";
|
||||
import { DocumentSet, Tag, User, ValidSources } from "@/lib/types";
|
||||
import { ChatSession } from "@/app/chat/interfaces";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
@ -18,15 +18,40 @@ interface ChatContextProps {
|
||||
folders: Folder[];
|
||||
openedFolders: Record<string, boolean>;
|
||||
userInputPrompts: InputPrompt[];
|
||||
shouldShowWelcomeModal?: boolean;
|
||||
shouldDisplaySourcesIncompleteModal?: boolean;
|
||||
defaultAssistantId?: number;
|
||||
refreshChatSessions: () => Promise<void>;
|
||||
}
|
||||
|
||||
const ChatContext = createContext<ChatContextProps | undefined>(undefined);
|
||||
|
||||
// We use Omit to exclude 'refreshChatSessions' from the value prop type
|
||||
// because we're defining it within the component
|
||||
export const ChatProvider: React.FC<{
|
||||
value: ChatContextProps;
|
||||
value: Omit<ChatContextProps, "refreshChatSessions">;
|
||||
children: React.ReactNode;
|
||||
}> = ({ value, children }) => {
|
||||
return <ChatContext.Provider value={value}>{children}</ChatContext.Provider>;
|
||||
const [chatSessions, setChatSessions] = useState(value?.chatSessions || []);
|
||||
|
||||
const refreshChatSessions = async () => {
|
||||
try {
|
||||
const response = await fetch("/api/chat/get-user-chat-sessions");
|
||||
if (!response.ok) throw new Error("Failed to fetch chat sessions");
|
||||
const { sessions } = await response.json();
|
||||
setChatSessions(sessions);
|
||||
} catch (error) {
|
||||
console.error("Error refreshing chat sessions:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<ChatContext.Provider
|
||||
value={{ ...value, chatSessions, refreshChatSessions }}
|
||||
>
|
||||
{children}
|
||||
</ChatContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const useChatContext = (): ChatContextProps => {
|
||||
|
38
web/src/components/context/SearchContext.tsx
Normal file
38
web/src/components/context/SearchContext.tsx
Normal file
@ -0,0 +1,38 @@
|
||||
"use client";
|
||||
|
||||
import React, { createContext, useContext } from "react";
|
||||
import { CCPairBasicInfo, DocumentSet, Tag } from "@/lib/types";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { ChatSession } from "@/app/chat/interfaces";
|
||||
|
||||
interface SearchContextProps {
|
||||
querySessions: ChatSession[];
|
||||
ccPairs: CCPairBasicInfo[];
|
||||
documentSets: DocumentSet[];
|
||||
assistants: Persona[];
|
||||
tags: Tag[];
|
||||
agenticSearchEnabled: boolean;
|
||||
disabledAgentic: boolean;
|
||||
initiallyToggled: boolean;
|
||||
shouldShowWelcomeModal: boolean;
|
||||
shouldDisplayNoSources: boolean;
|
||||
}
|
||||
|
||||
const SearchContext = createContext<SearchContextProps | undefined>(undefined);
|
||||
|
||||
export const SearchProvider: React.FC<{
|
||||
value: SearchContextProps;
|
||||
children: React.ReactNode;
|
||||
}> = ({ value, children }) => {
|
||||
return (
|
||||
<SearchContext.Provider value={value}>{children}</SearchContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const useSearchContext = (): SearchContextProps => {
|
||||
const context = useContext(SearchContext);
|
||||
if (!context) {
|
||||
throw new Error("useSearchContext must be used within a SearchProvider");
|
||||
}
|
||||
return context;
|
||||
};
|
@ -27,13 +27,11 @@ function UsageTypeSection({
|
||||
title,
|
||||
description,
|
||||
callToAction,
|
||||
icon,
|
||||
onClick,
|
||||
}: {
|
||||
title: string;
|
||||
description: string | JSX.Element;
|
||||
callToAction: string;
|
||||
icon?: React.ElementType;
|
||||
onClick: () => void;
|
||||
}) {
|
||||
return (
|
||||
@ -243,7 +241,6 @@ export function _WelcomeModal({ user }: { user: User | null }) {
|
||||
this is the option for you!
|
||||
</Text>
|
||||
}
|
||||
icon={FiMessageSquare}
|
||||
callToAction="Get Started"
|
||||
onClick={() => {
|
||||
setSelectedFlow("chat");
|
||||
|
@ -55,6 +55,7 @@ export const ApiKeyForm = ({
|
||||
return (
|
||||
<TabPanel key={provider.name}>
|
||||
<LLMProviderUpdateForm
|
||||
hideAdvanced
|
||||
llmProviderDescriptor={provider}
|
||||
onClose={() => onSuccess()}
|
||||
shouldMarkAsDefault
|
||||
|
@ -1,60 +1,38 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { ApiKeyForm } from "./ApiKeyForm";
|
||||
import { Modal } from "../Modal";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { checkLlmProvider } from "../initialSetup/welcome/lib";
|
||||
import { User } from "@/lib/types";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useProviderStatus } from "../chat_search/ProviderContext";
|
||||
|
||||
export const ApiKeyModal = ({ user }: { user: User | null }) => {
|
||||
export const ApiKeyModal = ({ hide }: { hide: () => void }) => {
|
||||
const router = useRouter();
|
||||
|
||||
const [forceHidden, setForceHidden] = useState<boolean>(false);
|
||||
const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
|
||||
const [providerOptions, setProviderOptions] = useState<
|
||||
WellKnownLLMProviderDescriptor[]
|
||||
>([]);
|
||||
const {
|
||||
shouldShowConfigurationNeeded,
|
||||
providerOptions,
|
||||
refreshProviderInfo,
|
||||
} = useProviderStatus();
|
||||
|
||||
useEffect(() => {
|
||||
async function fetchProviderInfo() {
|
||||
const { providers, options, defaultCheckSuccessful } =
|
||||
await checkLlmProvider(user);
|
||||
setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
|
||||
setProviderOptions(options);
|
||||
}
|
||||
|
||||
fetchProviderInfo();
|
||||
}, []);
|
||||
|
||||
// don't show if
|
||||
// (1) a valid provider has been setup or
|
||||
// (2) there are no provider options (e.g. user isn't an admin)
|
||||
// (3) user explicitly hides the modal
|
||||
if (validProviderExists || !providerOptions.length || forceHidden) {
|
||||
if (!shouldShowConfigurationNeeded) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title="LLM Key Setup"
|
||||
className="max-w-4xl"
|
||||
onOutsideClick={() => setForceHidden(true)}
|
||||
title="Set an API Key!"
|
||||
className="max-w-3xl"
|
||||
onOutsideClick={() => hide()}
|
||||
>
|
||||
<div className="max-h-[75vh] overflow-y-auto flex flex-col px-4">
|
||||
<div>
|
||||
<div className="mb-5 text-sm">
|
||||
Please setup an LLM below in order to start using Danswer Search or
|
||||
Danswer Chat. Don't worry, you can always change this later in
|
||||
the Admin Panel.
|
||||
Please provide an API Key below in order to start using
|
||||
Danswer – you can always change this later.
|
||||
<br />
|
||||
<br />
|
||||
Or if you'd rather look around first,{" "}
|
||||
<strong
|
||||
onClick={() => setForceHidden(true)}
|
||||
className="text-link cursor-pointer"
|
||||
>
|
||||
If you'd rather look around first, you can
|
||||
<strong onClick={() => hide()} className="text-link cursor-pointer">
|
||||
{" "}
|
||||
skip this step
|
||||
</strong>
|
||||
.
|
||||
@ -63,7 +41,8 @@ export const ApiKeyModal = ({ user }: { user: User | null }) => {
|
||||
<ApiKeyForm
|
||||
onSuccess={() => {
|
||||
router.refresh();
|
||||
setForceHidden(true);
|
||||
refreshProviderInfo();
|
||||
hide();
|
||||
}}
|
||||
providerOptions={providerOptions}
|
||||
/>
|
||||
|
@ -37,6 +37,10 @@ import { FeedbackModal } from "@/app/chat/modal/FeedbackModal";
|
||||
import { deleteChatSession, handleChatFeedback } from "@/app/chat/lib";
|
||||
import SearchAnswer from "./SearchAnswer";
|
||||
import { DeleteEntityModal } from "../modals/DeleteEntityModal";
|
||||
import { ApiKeyModal } from "../llm/ApiKeyModal";
|
||||
import { useSearchContext } from "../context/SearchContext";
|
||||
import { useUser } from "../user/UserProvider";
|
||||
import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText";
|
||||
|
||||
export type searchState =
|
||||
| "input"
|
||||
@ -58,33 +62,28 @@ const VALID_QUESTION_RESPONSE_DEFAULT: ValidQuestionResponse = {
|
||||
};
|
||||
|
||||
interface SearchSectionProps {
|
||||
disabledAgentic: boolean;
|
||||
ccPairs: CCPairBasicInfo[];
|
||||
documentSets: DocumentSet[];
|
||||
personas: Persona[];
|
||||
tags: Tag[];
|
||||
toggle: () => void;
|
||||
querySessions: ChatSession[];
|
||||
defaultSearchType: SearchType;
|
||||
user: User | null;
|
||||
toggledSidebar: boolean;
|
||||
agenticSearchEnabled: boolean;
|
||||
}
|
||||
|
||||
export const SearchSection = ({
|
||||
ccPairs,
|
||||
toggle,
|
||||
disabledAgentic,
|
||||
documentSets,
|
||||
agenticSearchEnabled,
|
||||
personas,
|
||||
user,
|
||||
tags,
|
||||
querySessions,
|
||||
toggledSidebar,
|
||||
defaultSearchType,
|
||||
}: SearchSectionProps) => {
|
||||
// Search Bar
|
||||
const {
|
||||
querySessions,
|
||||
ccPairs,
|
||||
documentSets,
|
||||
assistants,
|
||||
tags,
|
||||
shouldShowWelcomeModal,
|
||||
agenticSearchEnabled,
|
||||
disabledAgentic,
|
||||
shouldDisplayNoSources,
|
||||
} = useSearchContext();
|
||||
|
||||
const [query, setQuery] = useState<string>("");
|
||||
const [comments, setComments] = useState<any>(null);
|
||||
const [contentEnriched, setContentEnriched] = useState(false);
|
||||
@ -100,6 +99,8 @@ export const SearchSection = ({
|
||||
messageId: null,
|
||||
});
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
|
||||
|
||||
const [agentic, setAgentic] = useState(agenticSearchEnabled);
|
||||
|
||||
const toggleAgentic = () => {
|
||||
@ -147,7 +148,7 @@ export const SearchSection = ({
|
||||
useState<SearchType>(defaultSearchType);
|
||||
|
||||
const [selectedPersona, setSelectedPersona] = useState<number>(
|
||||
personas[0]?.id || 0
|
||||
assistants[0]?.id || 0
|
||||
);
|
||||
|
||||
// Used for search state display
|
||||
@ -158,8 +159,8 @@ export const SearchSection = ({
|
||||
const availableSources = ccPairs.map((ccPair) => ccPair.source);
|
||||
const [finalAvailableSources, finalAvailableDocumentSets] =
|
||||
computeAvailableFilters({
|
||||
selectedPersona: personas.find(
|
||||
(persona) => persona.id === selectedPersona
|
||||
selectedPersona: assistants.find(
|
||||
(assistant) => assistant.id === selectedPersona
|
||||
),
|
||||
availableSources: availableSources,
|
||||
availableDocumentSets: documentSets,
|
||||
@ -362,6 +363,7 @@ export const SearchSection = ({
|
||||
setSearchState("input");
|
||||
}
|
||||
};
|
||||
const { user } = useUser();
|
||||
const [searchAnswerExpanded, setSearchAnswerExpanded] = useState(false);
|
||||
|
||||
const resetInput = (finalized?: boolean) => {
|
||||
@ -403,8 +405,8 @@ export const SearchSection = ({
|
||||
documentSets: filterManager.selectedDocumentSets,
|
||||
timeRange: filterManager.timeRange,
|
||||
tags: filterManager.selectedTags,
|
||||
persona: personas.find(
|
||||
(persona) => persona.id === selectedPersona
|
||||
persona: assistants.find(
|
||||
(assistant) => assistant.id === selectedPersona
|
||||
) as Persona,
|
||||
updateCurrentAnswer: cancellable({
|
||||
cancellationToken: lastSearchCancellationToken.current,
|
||||
@ -595,6 +597,12 @@ export const SearchSection = ({
|
||||
<div className="flex relative pr-[8px] h-full text-default">
|
||||
{popup}
|
||||
|
||||
{!shouldDisplayNoSources &&
|
||||
showApiKeyModal &&
|
||||
!shouldShowWelcomeModal && (
|
||||
<ApiKeyModal hide={() => setShowApiKeyModal(false)} />
|
||||
)}
|
||||
|
||||
{deletingChatSession && (
|
||||
<DeleteEntityModal
|
||||
entityType="search"
|
||||
@ -747,6 +755,11 @@ export const SearchSection = ({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<UnconfiguredProviderText
|
||||
showConfigureAPIKey={() => setShowApiKeyModal(true)}
|
||||
/>
|
||||
|
||||
<FullSearchBar
|
||||
toggleAgentic={
|
||||
disabledAgentic ? undefined : toggleAgentic
|
||||
|
@ -44,7 +44,6 @@ interface FetchChatDataResult {
|
||||
toggleSidebar: boolean;
|
||||
finalDocumentSidebarInitialWidth?: number;
|
||||
shouldShowWelcomeModal: boolean;
|
||||
shouldDisplaySourcesIncompleteModal: boolean;
|
||||
userInputPrompts: InputPrompt[];
|
||||
}
|
||||
|
||||
@ -242,7 +241,6 @@ export async function fetchChatData(searchParams: {
|
||||
finalDocumentSidebarInitialWidth,
|
||||
toggleSidebar,
|
||||
shouldShowWelcomeModal,
|
||||
shouldDisplaySourcesIncompleteModal,
|
||||
userInputPrompts,
|
||||
};
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user