diff --git a/backend/open_webui/apps/retrieval/loader/main.py b/backend/open_webui/apps/retrieval/loader/main.py
new file mode 100644
index 000000000..f4c948b43
--- /dev/null
+++ b/backend/open_webui/apps/retrieval/loader/main.py
@@ -0,0 +1,183 @@
+import requests
+import logging
+
+from langchain_community.document_loaders import (
+    BSHTMLLoader,
+    CSVLoader,
+    Docx2txtLoader,
+    OutlookMessageLoader,
+    PyPDFLoader,
+    TextLoader,
+    UnstructuredEPubLoader,
+    UnstructuredExcelLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredPowerPointLoader,
+    UnstructuredRSTLoader,
+    UnstructuredXMLLoader,
+    YoutubeLoader,
+)
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+known_source_ext = [
+    "go",
+    "py",
+    "java",
+    "sh",
+    "bat",
+    "ps1",
+    "cmd",
+    "js",
+    "ts",
+    "css",
+    "cpp",
+    "hpp",
+    "h",
+    "c",
+    "cs",
+    "sql",
+    "log",
+    "ini",
+    "pl",
+    "pm",
+    "r",
+    "dart",
+    "dockerfile",
+    "env",
+    "php",
+    "hs",
+    "hsc",
+    "lua",
+    "nginxconf",
+    "conf",
+    "m",
+    "mm",
+    "plsql",
+    "perl",
+    "rb",
+    "rs",
+    "db2",
+    "scala",
+    "bash",
+    "swift",
+    "vue",
+    "svelte",
+    "msg",
+    "ex",
+    "exs",
+    "erl",
+    "tsx",
+    "jsx",
+    "hs",
+    "lhs",
+]
+
+
+class TikaLoader:
+    def __init__(self, url, file_path, mime_type=None):
+        self.url = url
+        self.file_path = file_path
+        self.mime_type = mime_type
+
+    def load(self) -> list[Document]:
+        with open(self.file_path, "rb") as f:
+            data = f.read()
+
+        if self.mime_type is not None:
+            headers = {"Content-Type": self.mime_type}
+        else:
+            headers = {}
+
+        endpoint = self.url
+        if not endpoint.endswith("/"):
+            endpoint += "/"
+        endpoint += "tika/text"
+
+        r = requests.put(endpoint, data=data, headers=headers)
+
+        if r.ok:
+            raw_metadata = r.json()
+            text = raw_metadata.get("X-TIKA:content", "")
+
+            if "Content-Type" in raw_metadata:
+                headers["Content-Type"] = raw_metadata["Content-Type"]
+
+            log.info("Tika extracted text: %s", text)
+
+            return [Document(page_content=text, metadata=headers)]
+        else:
+            raise Exception(f"Error calling Tika: {r.reason}")
+
+
+class Loader:
+    def __init__(self, engine: str = "", **kwargs):
+        self.engine = engine
+        self.kwargs = kwargs
+
+    def load(
+        self, filename: str, file_content_type: str, file_path: str
+    ) -> list[Document]:
+        loader = self._get_loader(filename, file_content_type, file_path)
+        return loader.load()
+
+    def _get_loader(self, filename: str, file_content_type: str, file_path: str):
+        file_ext = filename.split(".")[-1].lower()
+
+        if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
+            if file_ext in known_source_ext or (
+                file_content_type and file_content_type.find("text/") >= 0
+            ):
+                loader = TextLoader(file_path, autodetect_encoding=True)
+            else:
+                loader = TikaLoader(
+                    url=self.kwargs.get("TIKA_SERVER_URL"),
+                    file_path=file_path,
+                    mime_type=file_content_type,
+                )
+        else:
+            if file_ext == "pdf":
+                loader = PyPDFLoader(
+                    file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
+                )
+            elif file_ext == "csv":
+                loader = CSVLoader(file_path)
+            elif file_ext == "rst":
+                loader = UnstructuredRSTLoader(file_path, mode="elements")
+            elif file_ext == "xml":
+                loader = UnstructuredXMLLoader(file_path)
+            elif file_ext in ["htm", "html"]:
+                loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
+            elif file_ext == "md":
+                loader = UnstructuredMarkdownLoader(file_path)
+            elif file_content_type == "application/epub+zip":
+                loader = UnstructuredEPubLoader(file_path)
+            elif (
file_content_type + == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + or file_ext == "docx" + ): + loader = Docx2txtLoader(file_path) + elif file_content_type in [ + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ] or file_ext in ["xls", "xlsx"]: + loader = UnstructuredExcelLoader(file_path) + elif file_content_type in [ + "application/vnd.ms-powerpoint", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ] or file_ext in ["ppt", "pptx"]: + loader = UnstructuredPowerPointLoader(file_path) + elif file_ext == "msg": + loader = OutlookMessageLoader(file_path) + elif file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ): + loader = TextLoader(file_path, autodetect_encoding=True) + else: + loader = TextLoader(file_path, autodetect_encoding=True) + + return loader diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 8f23ea2c5..3e1ec8854 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -3,34 +3,39 @@ import logging import mimetypes import os import shutil -import socket -import urllib.parse + import uuid from datetime import datetime from pathlib import Path from typing import Iterator, Optional, Sequence, Union - -import numpy as np -import torch -import requests -import validators - from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from open_webui.apps.retrieval.search.main import SearchResult -from open_webui.apps.retrieval.search.brave import search_brave -from open_webui.apps.retrieval.search.duckduckgo import search_duckduckgo -from open_webui.apps.retrieval.search.google_pse import search_google_pse -from open_webui.apps.retrieval.search.jina_search import search_jina -from open_webui.apps.retrieval.search.searchapi import search_searchapi -from open_webui.apps.retrieval.search.searxng import search_searxng -from open_webui.apps.retrieval.search.serper import search_serper -from open_webui.apps.retrieval.search.serply import search_serply -from open_webui.apps.retrieval.search.serpstack import search_serpstack -from open_webui.apps.retrieval.search.tavily import search_tavily +from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT + +# Information retrieval models +from open_webui.apps.retrieval.model.colbert import ColBERT + +# Document loaders +from open_webui.apps.retrieval.loader.main import Loader + +# Web search engines +from open_webui.apps.retrieval.web.main import SearchResult +from open_webui.apps.retrieval.web.utils import get_web_loader +from open_webui.apps.retrieval.web.brave import search_brave +from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.apps.retrieval.web.google_pse import search_google_pse +from open_webui.apps.retrieval.web.jina_search import search_jina +from open_webui.apps.retrieval.web.searchapi import search_searchapi +from open_webui.apps.retrieval.web.searxng import search_searxng +from open_webui.apps.retrieval.web.serper import search_serper +from open_webui.apps.retrieval.web.serply import search_serply +from open_webui.apps.retrieval.web.serpstack import search_serpstack +from open_webui.apps.retrieval.web.tavily import search_tavily + + from open_webui.apps.retrieval.utils import ( get_embedding_function, get_model_path, 
@@ -39,6 +44,7 @@ from open_webui.apps.retrieval.utils import ( query_doc, query_doc_with_hybrid_search, ) + from open_webui.apps.webui.models.documents import DocumentForm, Documents from open_webui.apps.webui.models.files import Files from open_webui.config import ( @@ -98,28 +104,13 @@ from open_webui.utils.misc import ( sanitize_filename, ) from open_webui.utils.utils import get_admin_user, get_verified_user -from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( - BSHTMLLoader, - CSVLoader, - Docx2txtLoader, - OutlookMessageLoader, - PyPDFLoader, - TextLoader, - UnstructuredEPubLoader, - UnstructuredExcelLoader, - UnstructuredMarkdownLoader, - UnstructuredPowerPointLoader, - UnstructuredRSTLoader, - UnstructuredXMLLoader, - WebBaseLoader, YoutubeLoader, ) from langchain_core.documents import Document -from colbert.infra import ColBERTConfig -from colbert.modeling.checkpoint import Checkpoint + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -200,83 +191,6 @@ def update_reranking_model( ): if reranking_model: if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]): - - class ColBERT: - def __init__(self, name) -> None: - print("ColBERT: Loading model", name) - self.device = "cuda" if torch.cuda.is_available() else "cpu" - - if DOCKER: - # This is a workaround for the issue with the docker container - # where the torch extension is not loaded properly - # and the following error is thrown: - # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory - - lock_file = "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock" - if os.path.exists(lock_file): - os.remove(lock_file) - - self.ckpt = Checkpoint( - name, - colbert_config=ColBERTConfig(model_name=name), - ).to(self.device) - pass - - def calculate_similarity_scores( - self, query_embeddings, document_embeddings - ): - - query_embeddings = query_embeddings.to(self.device) - document_embeddings = document_embeddings.to(self.device) - - # Validate dimensions to ensure compatibility - if query_embeddings.dim() != 3: - raise ValueError( - f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}." - ) - if document_embeddings.dim() != 3: - raise ValueError( - f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}." - ) - if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: - raise ValueError( - "There should be either one query or queries equal to the number of documents." 
- ) - - # Transpose the query embeddings to align for matrix multiplication - transposed_query_embeddings = query_embeddings.permute(0, 2, 1) - # Compute similarity scores using batch matrix multiplication - computed_scores = torch.matmul( - document_embeddings, transposed_query_embeddings - ) - # Apply max pooling to extract the highest semantic similarity across each document's sequence - maximum_scores = torch.max(computed_scores, dim=1).values - - # Sum up the maximum scores across features to get the overall document relevance scores - final_scores = maximum_scores.sum(dim=1) - - normalized_scores = torch.softmax(final_scores, dim=0) - - return normalized_scores.detach().cpu().numpy().astype(np.float32) - - def predict(self, sentences): - - query = sentences[0][0] - docs = [i[1] for i in sentences] - - # Embedding the documents - embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0] - # Embedding the queries - embedded_queries = self.ckpt.queryFromText([query], bsize=32) - embedded_query = embedded_queries[0] - - # Calculate retrieval scores for the query against all documents - scores = self.calculate_similarity_scores( - embedded_query.unsqueeze(0), embedded_docs - ) - - return scores - try: app.state.sentence_transformer_rf = ColBERT( get_model_path(reranking_model, auto_update) @@ -707,89 +621,319 @@ async def update_query_settings( } -class QueryDocForm(BaseModel): - collection_name: str - query: str - k: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None +#################################### +# +# Document process and retrieval +# +#################################### -@app.post("/query/doc") -def query_doc_handler( - form_data: QueryDocForm, +def store_data_in_vector_db( + data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False +) -> bool: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + + docs = text_splitter.split_documents(data) + + if len(docs) > 0: + log.info(f"store_data_in_vector_db {docs}") + return store_docs_in_vector_db(docs, collection_name, metadata, overwrite) + else: + raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) + + +def store_text_in_vector_db( + text, metadata, collection_name, overwrite: bool = False +) -> bool: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.config.CHUNK_SIZE, + chunk_overlap=app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + docs = text_splitter.create_documents([text], metadatas=[metadata]) + return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite) + + +def store_docs_in_vector_db( + docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False +) -> bool: + log.info(f"store_docs_in_vector_db {docs} {collection_name}") + + texts = [doc.page_content for doc in docs] + metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] + + # ChromaDB does not like datetime formats + # for meta-data so convert them to string. 
+ for metadata in metadatas: + for key, value in metadata.items(): + if isinstance(value, datetime): + metadata[key] = str(value) + + try: + if overwrite: + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): + log.info(f"deleting existing collection {collection_name}") + VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) + + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): + log.info(f"collection {collection_name} already exists") + return True + else: + embedding_function = get_embedding_function( + app.state.config.RAG_EMBEDDING_ENGINE, + app.state.config.RAG_EMBEDDING_MODEL, + app.state.sentence_transformer_ef, + app.state.config.OPENAI_API_KEY, + app.state.config.OPENAI_API_BASE_URL, + app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, + ) + + embeddings = embedding_function( + list(map(lambda x: x.replace("\n", " "), texts)) + ) + + VECTOR_DB_CLIENT.insert( + collection_name=collection_name, + items=[ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embeddings[idx], + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ], + ) + + return True + except Exception as e: + log.exception(e) + return False + + +@app.post("/doc") +def store_doc( + collection_name: Optional[str] = Form(None), + file: UploadFile = File(...), user=Depends(get_verified_user), ): + # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" + + log.info(f"file.content_type: {file.content_type}") try: - if app.state.config.ENABLE_RAG_HYBRID_SEARCH: - return query_doc_with_hybrid_search( - collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, - reranking_function=app.state.sentence_transformer_rf, - r=( - form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD - ), - ) - else: - return query_doc( - collection_name=form_data.collection_name, - query=form_data.query, - embedding_function=app.state.EMBEDDING_FUNCTION, - k=form_data.k if form_data.k else app.state.config.TOP_K, + unsanitized_filename = file.filename + filename = os.path.basename(unsanitized_filename) + + file_path = f"{UPLOAD_DIR}/{filename}" + + contents = file.file.read() + with open(file_path, "wb") as f: + f.write(contents) + f.close() + + f = open(file_path, "rb") + if collection_name is None: + collection_name = calculate_sha256(f)[:63] + f.close() + + loader = Loader( + engine=app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + ) + data = loader.load(filename, file.content_type, file_path) + + try: + result = store_data_in_vector_db(data, collection_name) + + if result: + return { + "status": True, + "collection_name": collection_name, + "filename": filename, + } + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=e, ) except Exception as e: log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) + if "No pandoc was found" in str(e): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) -class QueryCollectionsForm(BaseModel): - collection_names: list[str] - query: str - k: Optional[int] = None - r: Optional[float] = None - hybrid: 
Optional[bool] = None
+class ProcessFileForm(BaseModel):
+    file_id: str
+    collection_name: Optional[str] = None


-@app.post("/query/collection")
-def query_collection_handler(
-    form_data: QueryCollectionsForm,
+@app.post("/process/file")
+def process_file(
+    form_data: ProcessFileForm,
     user=Depends(get_verified_user),
 ):
     try:
-        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-            return query_collection_with_hybrid_search(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
-                reranking_function=app.state.sentence_transformer_rf,
-                r=(
-                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
-                ),
-            )
-        else:
-            return query_collection(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.config.TOP_K,
+        file = Files.get_file_by_id(form_data.file_id)
+        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
+
+        loader = Loader(
+            engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
+            TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
+            PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
+        )
+        data = loader.load(file.filename, file.meta.get("content_type"), file_path)
+
+        f = open(file_path, "rb")
+        collection_name = form_data.collection_name
+        if collection_name is None:
+            collection_name = calculate_sha256(f)[:63]
+        f.close()
+
+        try:
+            result = store_data_in_vector_db(
+                data,
+                collection_name,
+                {
+                    "file_id": form_data.file_id,
+                    "name": file.meta.get("name", file.filename),
+                },
             )
+            if result:
+
+                return {
+                    "status": True,
+                    "collection_name": collection_name,
+                    "filename": file.meta.get("name", file.filename),
+                }
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=e,
+            )
     except Exception as e:
         log.exception(e)
+        if "No pandoc was found" in str(e):
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+
+
+class TextRAGForm(BaseModel):
+    name: str
+    content: str
+    collection_name: Optional[str] = None
+
+
+@app.post("/text")
+def store_text(
+    form_data: TextRAGForm,
+    user=Depends(get_verified_user),
+):
+    collection_name = form_data.collection_name
+    if collection_name is None:
+        collection_name = calculate_sha256_string(form_data.content)
+
+    result = store_text_in_vector_db(
+        form_data.content,
+        metadata={"name": form_data.name, "created_by": user.id},
+        collection_name=collection_name,
+    )
+
+    if result:
+        return {"status": True, "collection_name": collection_name}
+    else:
         raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )


-@app.post("/youtube")
-def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
+@app.get("/process/dir")
+def process_docs_dir(user=Depends(get_admin_user)):
+    for path in Path(DOCS_DIR).rglob("./**/*"):
+        try:
+            if path.is_file() and not path.name.startswith("."):
+                tags = extract_folders_after_data_docs(path)
+                filename = path.name
+                file_content_type = mimetypes.guess_type(path)
+
+                f = open(path, "rb")
+                collection_name = calculate_sha256(f)[:63]
+                f.close()
+
+                loader = 
Loader( + engine=app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + ) + data = loader.load(filename, file_content_type[0], str(path)) + + try: + result = store_data_in_vector_db(data, collection_name) + + if result: + sanitized_filename = sanitize_filename(filename) + doc = Documents.get_doc_by_name(sanitized_filename) + + if doc is None: + doc = Documents.insert_new_doc( + user.id, + DocumentForm( + **{ + "name": sanitized_filename, + "title": filename, + "collection_name": collection_name, + "filename": filename, + "content": ( + json.dumps( + { + "tags": list( + map( + lambda name: {"name": name}, + tags, + ) + ) + } + ) + if len(tags) + else "{}" + ), + } + ), + ) + except Exception as e: + log.exception(e) + pass + + except Exception as e: + log.exception(e) + + return True + + +@app.post("/process/youtube") +def process_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): try: loader = YoutubeLoader.from_youtube_url( form_data.url, @@ -817,13 +961,14 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)): ) -@app.post("/web") -def store_web(form_data: UrlForm, user=Depends(get_verified_user)): +@app.post("/process/web") +def process_web(form_data: UrlForm, user=Depends(get_verified_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" try: loader = get_web_loader( form_data.url, verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, ) data = loader.load() @@ -845,53 +990,6 @@ def store_web(form_data: UrlForm, user=Depends(get_verified_user)): ) -def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): - # Check if the URL is valid - if not validate_url(url): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - return SafeWebBaseLoader( - url, - verify_ssl=verify_ssl, - requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS, - continue_on_failure=True, - ) - - -def validate_url(url: Union[str, Sequence[str]]): - if isinstance(url, str): - if isinstance(validators.url(url), validators.ValidationError): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - if not ENABLE_RAG_LOCAL_WEB_FETCH: - # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses - parsed_url = urllib.parse.urlparse(url) - # Get IPv4 and IPv6 addresses - ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) - # Check if any of the resolved addresses are private - # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader - for ip in ipv4_addresses: - if validators.ipv4(ip, private=True): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - for ip in ipv6_addresses: - if validators.ipv6(ip, private=True): - raise ValueError(ERROR_MESSAGES.INVALID_URL) - return True - elif isinstance(url, Sequence): - return all(validate_url(u) for u in url) - else: - return False - - -def resolve_hostname(hostname): - # Get address information - addr_info = socket.getaddrinfo(hostname, None) - - # Extract IP addresses from address information - ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] - ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] - - return ipv4_addresses, ipv6_addresses - - def search_web(engine: str, query: str) -> list[SearchResult]: """Search the web using a search engine and return the results as a list of SearchResult 
objects. Will look for a search engine API key in environment variables in the following order: @@ -1007,8 +1105,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]: raise Exception("No search engine API key found in environment variables") -@app.post("/web/search") -def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): +@app.post("/process/web/search") +def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)): try: logging.info( f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}" @@ -1048,450 +1146,92 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)): ) -def store_data_in_vector_db( - data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False -) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - - docs = text_splitter.split_documents(data) - - if len(docs) > 0: - log.info(f"store_data_in_vector_db {docs}") - return store_docs_in_vector_db(docs, collection_name, metadata, overwrite) - else: - raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) +class QueryDocForm(BaseModel): + collection_name: str + query: str + k: Optional[int] = None + r: Optional[float] = None + hybrid: Optional[bool] = None -def store_text_in_vector_db( - text, metadata, collection_name, overwrite: bool = False -) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.config.CHUNK_SIZE, - chunk_overlap=app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - docs = text_splitter.create_documents([text], metadatas=[metadata]) - return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite) - - -def store_docs_in_vector_db( - docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False -) -> bool: - log.info(f"store_docs_in_vector_db {docs} {collection_name}") - - texts = [doc.page_content for doc in docs] - metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs] - - # ChromaDB does not like datetime formats - # for meta-data so convert them to string. 
- for metadata in metadatas: - for key, value in metadata.items(): - if isinstance(value, datetime): - metadata[key] = str(value) - - try: - if overwrite: - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"deleting existing collection {collection_name}") - VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) - - if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): - log.info(f"collection {collection_name} already exists") - return True - else: - embedding_function = get_embedding_function( - app.state.config.RAG_EMBEDDING_ENGINE, - app.state.config.RAG_EMBEDDING_MODEL, - app.state.sentence_transformer_ef, - app.state.config.OPENAI_API_KEY, - app.state.config.OPENAI_API_BASE_URL, - app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE, - ) - - embedding_texts = embedding_function( - list(map(lambda x: x.replace("\n", " "), texts)) - ) - - VECTOR_DB_CLIENT.insert( - collection_name=collection_name, - items=[ - { - "id": str(uuid.uuid4()), - "text": text, - "vector": embedding_texts[idx], - "metadata": metadatas[idx], - } - for idx, text in enumerate(texts) - ], - ) - - return True - except Exception as e: - log.exception(e) - return False - - -class TikaLoader: - def __init__(self, file_path, mime_type=None): - self.file_path = file_path - self.mime_type = mime_type - - def load(self) -> list[Document]: - with open(self.file_path, "rb") as f: - data = f.read() - - if self.mime_type is not None: - headers = {"Content-Type": self.mime_type} - else: - headers = {} - - endpoint = app.state.config.TIKA_SERVER_URL - if not endpoint.endswith("/"): - endpoint += "/" - endpoint += "tika/text" - - r = requests.put(endpoint, data=data, headers=headers) - - if r.ok: - raw_metadata = r.json() - text = raw_metadata.get("X-TIKA:content", "") - - if "Content-Type" in raw_metadata: - headers["Content-Type"] = raw_metadata["Content-Type"] - - log.info("Tika extracted text: %s", text) - - return [Document(page_content=text, metadata=headers)] - else: - raise Exception(f"Error calling Tika: {r.reason}") - - -def get_loader(filename: str, file_content_type: str, file_path: str): - file_ext = filename.split(".")[-1].lower() - known_type = True - - known_source_ext = [ - "go", - "py", - "java", - "sh", - "bat", - "ps1", - "cmd", - "js", - "ts", - "css", - "cpp", - "hpp", - "h", - "c", - "cs", - "sql", - "log", - "ini", - "pl", - "pm", - "r", - "dart", - "dockerfile", - "env", - "php", - "hs", - "hsc", - "lua", - "nginxconf", - "conf", - "m", - "mm", - "plsql", - "perl", - "rb", - "rs", - "db2", - "scala", - "bash", - "swift", - "vue", - "svelte", - "msg", - "ex", - "exs", - "erl", - "tsx", - "jsx", - "hs", - "lhs", - ] - - if ( - app.state.config.CONTENT_EXTRACTION_ENGINE == "tika" - and app.state.config.TIKA_SERVER_URL - ): - if file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): - loader = TextLoader(file_path, autodetect_encoding=True) - else: - loader = TikaLoader(file_path, file_content_type) - else: - if file_ext == "pdf": - loader = PyPDFLoader( - file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES - ) - elif file_ext == "csv": - loader = CSVLoader(file_path) - elif file_ext == "rst": - loader = UnstructuredRSTLoader(file_path, mode="elements") - elif file_ext == "xml": - loader = UnstructuredXMLLoader(file_path) - elif file_ext in ["htm", "html"]: - loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") - elif file_ext == "md": - loader = UnstructuredMarkdownLoader(file_path) - 
elif file_content_type == "application/epub+zip": - loader = UnstructuredEPubLoader(file_path) - elif ( - file_content_type - == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - or file_ext == "docx" - ): - loader = Docx2txtLoader(file_path) - elif file_content_type in [ - "application/vnd.ms-excel", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ] or file_ext in ["xls", "xlsx"]: - loader = UnstructuredExcelLoader(file_path) - elif file_content_type in [ - "application/vnd.ms-powerpoint", - "application/vnd.openxmlformats-officedocument.presentationml.presentation", - ] or file_ext in ["ppt", "pptx"]: - loader = UnstructuredPowerPointLoader(file_path) - elif file_ext == "msg": - loader = OutlookMessageLoader(file_path) - elif file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): - loader = TextLoader(file_path, autodetect_encoding=True) - else: - loader = TextLoader(file_path, autodetect_encoding=True) - known_type = False - - return loader, known_type - - -@app.post("/doc") -def store_doc( - collection_name: Optional[str] = Form(None), - file: UploadFile = File(...), - user=Depends(get_verified_user), -): - # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" - - log.info(f"file.content_type: {file.content_type}") - try: - unsanitized_filename = file.filename - filename = os.path.basename(unsanitized_filename) - - file_path = f"{UPLOAD_DIR}/{filename}" - - contents = file.file.read() - with open(file_path, "wb") as f: - f.write(contents) - f.close() - - f = open(file_path, "rb") - if collection_name is None: - collection_name = calculate_sha256(f)[:63] - f.close() - - loader, known_type = get_loader(filename, file.content_type, file_path) - data = loader.load() - - try: - result = store_data_in_vector_db(data, collection_name) - - if result: - return { - "status": True, - "collection_name": collection_name, - "filename": filename, - "known_type": known_type, - } - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=e, - ) - except Exception as e: - log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class ProcessFileForm(BaseModel): - file_id: str - collection_name: Optional[str] = None - - -@app.post("/process/file") -def process_file( - form_data: ProcessFileForm, +@app.post("/query/doc") +def query_doc_handler( + form_data: QueryDocForm, user=Depends(get_verified_user), ): try: - file = Files.get_file_by_id(form_data.file_id) - file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}") - - f = open(file_path, "rb") - - collection_name = form_data.collection_name - if collection_name is None: - collection_name = calculate_sha256(f)[:63] - f.close() - - loader, known_type = get_loader( - file.filename, file.meta.get("content_type"), file_path - ) - data = loader.load() - - try: - result = store_data_in_vector_db( - data, - collection_name, - { - "file_id": form_data.file_id, - "name": file.meta.get("name", file.filename), - }, + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + return query_doc_with_hybrid_search( + collection_name=form_data.collection_name, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else 
app.state.config.TOP_K, + reranking_function=app.state.sentence_transformer_rf, + r=( + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + ), ) - - if result: - - return { - "status": True, - "collection_name": collection_name, - "known_type": known_type, - "filename": file.meta.get("name", file.filename), - } - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=e, + else: + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, ) except Exception as e: log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - -class TextRAGForm(BaseModel): - name: str - content: str - collection_name: Optional[str] = None - - -@app.post("/text") -def store_text( - form_data: TextRAGForm, - user=Depends(get_verified_user), -): - collection_name = form_data.collection_name - if collection_name is None: - collection_name = calculate_sha256_string(form_data.content) - - result = store_text_in_vector_db( - form_data.content, - metadata={"name": form_data.name, "created_by": user.id}, - collection_name=collection_name, - ) - - if result: - return {"status": True, "collection_name": collection_name} - else: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DEFAULT(), + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), ) -@app.get("/scan") -def scan_docs_dir(user=Depends(get_admin_user)): - for path in Path(DOCS_DIR).rglob("./**/*"): - try: - if path.is_file() and not path.name.startswith("."): - tags = extract_folders_after_data_docs(path) - filename = path.name - file_content_type = mimetypes.guess_type(path) +class QueryCollectionsForm(BaseModel): + collection_names: list[str] + query: str + k: Optional[int] = None + r: Optional[float] = None + hybrid: Optional[bool] = None - f = open(path, "rb") - collection_name = calculate_sha256(f)[:63] - f.close() - loader, known_type = get_loader( - filename, file_content_type[0], str(path) - ) - data = loader.load() +@app.post("/query/collection") +def query_collection_handler( + form_data: QueryCollectionsForm, + user=Depends(get_verified_user), +): + try: + if app.state.config.ENABLE_RAG_HYBRID_SEARCH: + return query_collection_with_hybrid_search( + collection_names=form_data.collection_names, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, + reranking_function=app.state.sentence_transformer_rf, + r=( + form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD + ), + ) + else: + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + embedding_function=app.state.EMBEDDING_FUNCTION, + k=form_data.k if form_data.k else app.state.config.TOP_K, + ) - try: - result = store_data_in_vector_db(data, collection_name) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) - if result: - sanitized_filename = sanitize_filename(filename) - doc = Documents.get_doc_by_name(sanitized_filename) - if doc is None: - doc = 
Documents.insert_new_doc( - user.id, - DocumentForm( - **{ - "name": sanitized_filename, - "title": filename, - "collection_name": collection_name, - "filename": filename, - "content": ( - json.dumps( - { - "tags": list( - map( - lambda name: {"name": name}, - tags, - ) - ) - } - ) - if len(tags) - else "{}" - ), - } - ), - ) - except Exception as e: - log.exception(e) - pass - - except Exception as e: - log.exception(e) - - return True +#################################### +# +# Vector DB operations +# +#################################### @app.post("/reset/db") @@ -1544,33 +1284,6 @@ def reset(user=Depends(get_admin_user)) -> bool: return True -class SafeWebBaseLoader(WebBaseLoader): - """WebBaseLoader with enhanced error handling for URLs.""" - - def lazy_load(self) -> Iterator[Document]: - """Lazy load text from the url(s) in web_path with error handling.""" - for path in self.web_paths: - try: - soup = self._scrape(path, bs_kwargs=self.bs_kwargs) - text = soup.get_text(**self.bs_get_text_kwargs) - - # Build metadata - metadata = {"source": path} - if title := soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get( - "content", "No description found." - ) - if html := soup.find("html"): - metadata["language"] = html.get("lang", "No language found.") - - yield Document(page_content=text, metadata=metadata) - except Exception as e: - # Log the error and continue with the next URL - log.error(f"Error loading {path}: {e}") - - if ENV == "dev": @app.get("/ef") diff --git a/backend/open_webui/apps/retrieval/model/colbert.py b/backend/open_webui/apps/retrieval/model/colbert.py new file mode 100644 index 000000000..ea3204cb8 --- /dev/null +++ b/backend/open_webui/apps/retrieval/model/colbert.py @@ -0,0 +1,81 @@ +import os +import torch +import numpy as np +from colbert.infra import ColBERTConfig +from colbert.modeling.checkpoint import Checkpoint + + +class ColBERT: + def __init__(self, name, **kwargs) -> None: + print("ColBERT: Loading model", name) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + DOCKER = kwargs.get("env") == "docker" + if DOCKER: + # This is a workaround for the issue with the docker container + # where the torch extension is not loaded properly + # and the following error is thrown: + # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory + + lock_file = ( + "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock" + ) + if os.path.exists(lock_file): + os.remove(lock_file) + + self.ckpt = Checkpoint( + name, + colbert_config=ColBERTConfig(model_name=name), + ).to(self.device) + pass + + def calculate_similarity_scores(self, query_embeddings, document_embeddings): + + query_embeddings = query_embeddings.to(self.device) + document_embeddings = document_embeddings.to(self.device) + + # Validate dimensions to ensure compatibility + if query_embeddings.dim() != 3: + raise ValueError( + f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}." + ) + if document_embeddings.dim() != 3: + raise ValueError( + f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}." + ) + if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: + raise ValueError( + "There should be either one query or queries equal to the number of documents." 
+ ) + + # Transpose the query embeddings to align for matrix multiplication + transposed_query_embeddings = query_embeddings.permute(0, 2, 1) + # Compute similarity scores using batch matrix multiplication + computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings) + # Apply max pooling to extract the highest semantic similarity across each document's sequence + maximum_scores = torch.max(computed_scores, dim=1).values + + # Sum up the maximum scores across features to get the overall document relevance scores + final_scores = maximum_scores.sum(dim=1) + + normalized_scores = torch.softmax(final_scores, dim=0) + + return normalized_scores.detach().cpu().numpy().astype(np.float32) + + def predict(self, sentences): + + query = sentences[0][0] + docs = [i[1] for i in sentences] + + # Embedding the documents + embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0] + # Embedding the queries + embedded_queries = self.ckpt.queryFromText([query], bsize=32) + embedded_query = embedded_queries[0] + + # Calculate retrieval scores for the query against all documents + scores = self.calculate_similarity_scores( + embedded_query.unsqueeze(0), embedded_docs + ) + + return scores diff --git a/backend/open_webui/apps/retrieval/search/brave.py b/backend/open_webui/apps/retrieval/web/brave.py similarity index 93% rename from backend/open_webui/apps/retrieval/search/brave.py rename to backend/open_webui/apps/retrieval/web/brave.py index 11a2938b2..f988b3b08 100644 --- a/backend/open_webui/apps/retrieval/search/brave.py +++ b/backend/open_webui/apps/retrieval/web/brave.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/duckduckgo.py b/backend/open_webui/apps/retrieval/web/duckduckgo.py similarity index 94% rename from backend/open_webui/apps/retrieval/search/duckduckgo.py rename to backend/open_webui/apps/retrieval/web/duckduckgo.py index 82558ba37..11e512296 100644 --- a/backend/open_webui/apps/retrieval/search/duckduckgo.py +++ b/backend/open_webui/apps/retrieval/web/duckduckgo.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from duckduckgo_search import DDGS from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/apps/retrieval/search/google_pse.py b/backend/open_webui/apps/retrieval/web/google_pse.py similarity index 94% rename from backend/open_webui/apps/retrieval/search/google_pse.py rename to backend/open_webui/apps/retrieval/web/google_pse.py index c42851f47..61b919583 100644 --- a/backend/open_webui/apps/retrieval/search/google_pse.py +++ b/backend/open_webui/apps/retrieval/web/google_pse.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/jina_search.py b/backend/open_webui/apps/retrieval/web/jina_search.py similarity index 94% rename from 
backend/open_webui/apps/retrieval/search/jina_search.py rename to backend/open_webui/apps/retrieval/web/jina_search.py index f44f10d5c..487bbc948 100644 --- a/backend/open_webui/apps/retrieval/search/jina_search.py +++ b/backend/open_webui/apps/retrieval/web/jina_search.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.search.main import SearchResult +from open_webui.apps.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS from yarl import URL diff --git a/backend/open_webui/apps/retrieval/search/main.py b/backend/open_webui/apps/retrieval/web/main.py similarity index 100% rename from backend/open_webui/apps/retrieval/search/main.py rename to backend/open_webui/apps/retrieval/web/main.py diff --git a/backend/open_webui/apps/retrieval/search/searchapi.py b/backend/open_webui/apps/retrieval/web/searchapi.py similarity index 93% rename from backend/open_webui/apps/retrieval/search/searchapi.py rename to backend/open_webui/apps/retrieval/web/searchapi.py index a648d6600..412dc6b69 100644 --- a/backend/open_webui/apps/retrieval/search/searchapi.py +++ b/backend/open_webui/apps/retrieval/web/searchapi.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/searxng.py b/backend/open_webui/apps/retrieval/web/searxng.py similarity index 97% rename from backend/open_webui/apps/retrieval/search/searxng.py rename to backend/open_webui/apps/retrieval/web/searxng.py index 14b6b40b5..cb1eaf91d 100644 --- a/backend/open_webui/apps/retrieval/search/searxng.py +++ b/backend/open_webui/apps/retrieval/web/searxng.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/serper.py b/backend/open_webui/apps/retrieval/web/serper.py similarity index 93% rename from backend/open_webui/apps/retrieval/search/serper.py rename to backend/open_webui/apps/retrieval/web/serper.py index afebe8097..436fa167e 100644 --- a/backend/open_webui/apps/retrieval/search/serper.py +++ b/backend/open_webui/apps/retrieval/web/serper.py @@ -3,7 +3,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/serply.py b/backend/open_webui/apps/retrieval/web/serply.py similarity index 95% rename from backend/open_webui/apps/retrieval/search/serply.py rename to backend/open_webui/apps/retrieval/web/serply.py index 266fd666a..1c2521c47 100644 --- a/backend/open_webui/apps/retrieval/search/serply.py +++ b/backend/open_webui/apps/retrieval/web/serply.py @@ -3,7 +3,7 @@ from typing import Optional from urllib.parse import urlencode import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from 
open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/serpstack.py b/backend/open_webui/apps/retrieval/web/serpstack.py similarity index 94% rename from backend/open_webui/apps/retrieval/search/serpstack.py rename to backend/open_webui/apps/retrieval/web/serpstack.py index 236fb5181..b655934de 100644 --- a/backend/open_webui/apps/retrieval/search/serpstack.py +++ b/backend/open_webui/apps/retrieval/web/serpstack.py @@ -2,7 +2,7 @@ import logging from typing import Optional import requests -from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results +from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/tavily.py b/backend/open_webui/apps/retrieval/web/tavily.py similarity index 93% rename from backend/open_webui/apps/retrieval/search/tavily.py rename to backend/open_webui/apps/retrieval/web/tavily.py index 00f5b15c4..03b0be75a 100644 --- a/backend/open_webui/apps/retrieval/search/tavily.py +++ b/backend/open_webui/apps/retrieval/web/tavily.py @@ -1,7 +1,7 @@ import logging import requests -from open_webui.apps.retrieval.search.main import SearchResult +from open_webui.apps.retrieval.web.main import SearchResult from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/retrieval/search/testdata/brave.json b/backend/open_webui/apps/retrieval/web/testdata/brave.json similarity index 100% rename from backend/open_webui/apps/retrieval/search/testdata/brave.json rename to backend/open_webui/apps/retrieval/web/testdata/brave.json diff --git a/backend/open_webui/apps/retrieval/search/testdata/google_pse.json b/backend/open_webui/apps/retrieval/web/testdata/google_pse.json similarity index 100% rename from backend/open_webui/apps/retrieval/search/testdata/google_pse.json rename to backend/open_webui/apps/retrieval/web/testdata/google_pse.json diff --git a/backend/open_webui/apps/retrieval/search/testdata/searchapi.json b/backend/open_webui/apps/retrieval/web/testdata/searchapi.json similarity index 100% rename from backend/open_webui/apps/retrieval/search/testdata/searchapi.json rename to backend/open_webui/apps/retrieval/web/testdata/searchapi.json diff --git a/backend/open_webui/apps/retrieval/search/testdata/searxng.json b/backend/open_webui/apps/retrieval/web/testdata/searxng.json similarity index 100% rename from backend/open_webui/apps/retrieval/search/testdata/searxng.json rename to backend/open_webui/apps/retrieval/web/testdata/searxng.json diff --git a/backend/open_webui/apps/retrieval/search/testdata/serper.json b/backend/open_webui/apps/retrieval/web/testdata/serper.json similarity index 100% rename from backend/open_webui/apps/retrieval/search/testdata/serper.json rename to backend/open_webui/apps/retrieval/web/testdata/serper.json diff --git a/backend/open_webui/apps/retrieval/search/testdata/serply.json b/backend/open_webui/apps/retrieval/web/testdata/serply.json similarity index 100% rename from backend/open_webui/apps/retrieval/search/testdata/serply.json rename to backend/open_webui/apps/retrieval/web/testdata/serply.json diff --git a/backend/open_webui/apps/retrieval/search/testdata/serpstack.json b/backend/open_webui/apps/retrieval/web/testdata/serpstack.json similarity index 100% rename from 
backend/open_webui/apps/retrieval/search/testdata/serpstack.json rename to backend/open_webui/apps/retrieval/web/testdata/serpstack.json diff --git a/backend/open_webui/apps/retrieval/web/utils.py b/backend/open_webui/apps/retrieval/web/utils.py new file mode 100644 index 000000000..2df98b33c --- /dev/null +++ b/backend/open_webui/apps/retrieval/web/utils.py @@ -0,0 +1,97 @@ +import socket +import urllib.parse +import validators +from typing import Union, Sequence, Iterator + +from langchain_community.document_loaders import ( + WebBaseLoader, +) +from langchain_core.documents import Document + + +from open_webui.constants import ERROR_MESSAGES +from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH +from open_webui.env import SRC_LOG_LEVELS + +import logging + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def validate_url(url: Union[str, Sequence[str]]): + if isinstance(url, str): + if isinstance(validators.url(url), validators.ValidationError): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + if not ENABLE_RAG_LOCAL_WEB_FETCH: + # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses + parsed_url = urllib.parse.urlparse(url) + # Get IPv4 and IPv6 addresses + ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) + # Check if any of the resolved addresses are private + # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader + for ip in ipv4_addresses: + if validators.ipv4(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + for ip in ipv6_addresses: + if validators.ipv6(ip, private=True): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + return True + elif isinstance(url, Sequence): + return all(validate_url(u) for u in url) + else: + return False + + +def resolve_hostname(hostname): + # Get address information + addr_info = socket.getaddrinfo(hostname, None) + + # Extract IP addresses from address information + ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET] + ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6] + + return ipv4_addresses, ipv6_addresses + + +class SafeWebBaseLoader(WebBaseLoader): + """WebBaseLoader with enhanced error handling for URLs.""" + + def lazy_load(self) -> Iterator[Document]: + """Lazy load text from the url(s) in web_path with error handling.""" + for path in self.web_paths: + try: + soup = self._scrape(path, bs_kwargs=self.bs_kwargs) + text = soup.get_text(**self.bs_get_text_kwargs) + + # Build metadata + metadata = {"source": path} + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get( + "content", "No description found." 
+ ) + if html := soup.find("html"): + metadata["language"] = html.get("lang", "No language found.") + + yield Document(page_content=text, metadata=metadata) + except Exception as e: + # Log the error and continue with the next URL + log.error(f"Error loading {path}: {e}") + + +def get_web_loader( + url: Union[str, Sequence[str]], + verify_ssl: bool = True, + requests_per_second: int = 2, +): + # Check if the URL is valid + if not validate_url(url): + raise ValueError(ERROR_MESSAGES.INVALID_URL) + return SafeWebBaseLoader( + url, + verify_ssl=verify_ssl, + requests_per_second=requests_per_second, + continue_on_failure=True, + ) diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index ce3a0c0a5..cf86e951c 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -170,284 +170,6 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings return res; }; -export const processFile = async (token: string, file_id: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/process/file`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - file_id: file_id - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { - const data = new FormData(); - data.append('file', file); - data.append('collection_name', collection_name); - - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/doc`, { - method: 'POST', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - }, - body: data - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const uploadWebToVectorDB = async (token: string, collection_name: string, url: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/web`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - url: url, - collection_name: collection_name - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const uploadYoutubeTranscriptionToVectorDB = async (token: string, url: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/youtube`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - url: url - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const queryDoc = async ( - token: string, - collection_name: string, - query: string, - k: number | null = null -) => { - let error = null; - - const res = await 
fetch(`${RAG_API_BASE_URL}/query/doc`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - collection_name: collection_name, - query: query, - k: k - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const queryCollection = async ( - token: string, - collection_names: string, - query: string, - k: number | null = null -) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - collection_names: collection_names, - query: query, - k: k - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const scanDocs = async (token: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/scan`, { - method: 'GET', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const resetUploadDir = async (token: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/reset/uploads`, { - method: 'POST', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - -export const resetVectorDB = async (token: string) => { - let error = null; - - const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, { - method: 'POST', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - export const getEmbeddingConfig = async (token: string) => { let error = null; @@ -578,14 +300,140 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod return res; }; -export const runWebSearch = async ( +export interface SearchDocument { + status: boolean; + collection_name: string; + filenames: string[]; +} + +export const processFile = async (token: string, file_id: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/process/file`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + file_id: file_id + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processDocsDir = async (token: string) => { + let error = null; + + const res = await 
fetch(`${RAG_API_BASE_URL}/process/dir`, { + method: 'GET', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processYoutubeVideo = async (token: string, url: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/process/youtube`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url: url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processWeb = async (token: string, collection_name: string, url: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/process/web`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + url: url, + collection_name: collection_name + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const processWebSearch = async ( token: string, query: string, collection_name?: string ): Promise => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/web/search`, { + const res = await fetch(`${RAG_API_BASE_URL}/process/web/search`, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -613,8 +461,128 @@ export const runWebSearch = async ( return res; }; -export interface SearchDocument { - status: boolean; - collection_name: string; - filenames: string[]; -} +export const queryDoc = async ( + token: string, + collection_name: string, + query: string, + k: number | null = null +) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_name: collection_name, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const queryCollection = async ( + token: string, + collection_names: string, + query: string, + k: number | null = null +) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_names: collection_names, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const resetUploadDir = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/reset/uploads`, { + method: 
'POST', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const resetVectorDB = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, { + method: 'POST', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 84f60847e..c10b60aa0 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -7,7 +7,7 @@ import { deleteAllFiles, deleteFileById } from '$lib/apis/files'; import { getQuerySettings, - scanDocs, + processDocsDir, updateQuerySettings, resetVectorDB, getEmbeddingConfig, @@ -63,7 +63,7 @@ const scanHandler = async () => { scanDirLoading = true; - const res = await scanDocs(localStorage.token); + const res = await processDocsDir(localStorage.token); scanDirLoading = false; if (res) { diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index e196936a6..f60f0ede3 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -52,7 +52,7 @@ updateChatById } from '$lib/apis/chats'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; - import { runWebSearch } from '$lib/apis/retrieval'; + import { processWebSearch } from '$lib/apis/retrieval'; import { createOpenAITextStream } from '$lib/apis/streaming'; import { queryMemory } from '$lib/apis/memories'; import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users'; @@ -1737,7 +1737,7 @@ }); history.messages[responseMessageId] = responseMessage; - const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => { + const results = await processWebSearch(localStorage.token, searchQuery).catch((error) => { console.log(error); toast.error(error); diff --git a/src/lib/components/chat/Controls/Controls.svelte b/src/lib/components/chat/Controls/Controls.svelte index 35184f385..50d5a5648 100644 --- a/src/lib/components/chat/Controls/Controls.svelte +++ b/src/lib/components/chat/Controls/Controls.svelte @@ -46,6 +46,9 @@ chatFiles.splice(fileIdx, 1); chatFiles = chatFiles; }} + on:click={() => { + console.log(file); + }} /> {/each} diff --git a/src/lib/components/chat/MessageInput/Commands.svelte b/src/lib/components/chat/MessageInput/Commands.svelte index ed0e7dfb6..d1f85d458 100644 --- a/src/lib/components/chat/MessageInput/Commands.svelte +++ b/src/lib/components/chat/MessageInput/Commands.svelte @@ -9,7 +9,7 @@ import Models from './Commands/Models.svelte'; import { removeLastWordFromString } from '$lib/utils'; - import { uploadWebToVectorDB, uploadYoutubeTranscriptionToVectorDB } from '$lib/apis/retrieval'; + import { processWeb, processYoutubeVideo } from '$lib/apis/retrieval'; export let prompt = ''; export let files = []; @@ -41,7 +41,7 @@ try { files = [...files, doc]; - const res = await uploadWebToVectorDB(localStorage.token, '', url); + const res = await processWeb(localStorage.token, '', url); if (res) 
{ doc.status = 'processed'; @@ -69,7 +69,7 @@ try { files = [...files, doc]; - const res = await uploadYoutubeTranscriptionToVectorDB(localStorage.token, url); + const res = await processYoutubeVideo(localStorage.token, url); if (res) { doc.status = 'processed'; diff --git a/src/lib/components/common/FileItem.svelte b/src/lib/components/common/FileItem.svelte index f4dfd27e7..7e8592ab9 100644 --- a/src/lib/components/common/FileItem.svelte +++ b/src/lib/components/common/FileItem.svelte @@ -8,8 +8,6 @@ export let colorClassName = 'bg-white dark:bg-gray-800'; export let url: string | null = null; - export let clickHandler: Function | null = null; - export let dismissible = false; export let status = 'processed'; @@ -17,7 +15,7 @@ export let type: string; export let size: number; - function formatSize(size) { + const formatSize = (size) => { if (size == null) return 'Unknown size'; if (typeof size !== 'number' || size < 0) return 'Invalid size'; if (size === 0) return '0 B'; @@ -29,7 +27,7 @@ unitIndex++; } return `${size.toFixed(1)} ${units[unitIndex]}`; - } + };
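For reference, a minimal standalone sketch of the size formatting that the hunk above converts to an arrow function. The units list, the starting unitIndex and the 1024 step live in context lines the diff elides, so they are assumptions here rather than copies of the source; the component keeps its own local helper.

```ts
// Standalone sketch of formatSize, for illustration only. The `units` array,
// the initial unitIndex and the 1024 threshold are assumed values filling in
// the elided context between the two hunks above.
const formatSize = (size: unknown): string => {
	if (size == null) return 'Unknown size';
	if (typeof size !== 'number' || size < 0) return 'Invalid size';
	if (size === 0) return '0 B';

	const units = ['B', 'KB', 'MB', 'GB', 'TB'];
	let value = size;
	let unitIndex = 0;
	while (value >= 1024 && unitIndex < units.length - 1) {
		value /= 1024;
		unitIndex++;
	}
	return `${value.toFixed(1)} ${units[unitIndex]}`;
};

// formatSize(0) === '0 B'; formatSize(2048) === '2.0 KB'; formatSize(5 * 1024 ** 2) === '5.0 MB'
```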
@@ -37,17 +35,7 @@ class="h-14 {className} flex items-center space-x-3 {colorClassName} rounded-xl border border-gray-100 dark:border-gray-800 text-left" type="button" on:click={async () => { - if (clickHandler === null) { - if (url) { - if (type === 'file') { - window.open(`${url}/content`, '_blank').focus(); - } else { - window.open(`${url}`, '_blank').focus(); - } - } - } else { - clickHandler(); - } + dispatch('click'); }} >
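With the clickHandler prop removed, FileItem only forwards a plain 'click' event and no longer opens the URL itself, so a consumer that still wants the old open-in-new-tab behaviour has to provide it, as the Controls.svelte hunk above now does with an on:click listener. A rough parent-side sketch, assuming a file object with optional url and type fields; the handler name and shape are illustrative, not part of this diff.

```ts
// Hypothetical parent-side handler replicating what FileItem used to do on
// click when no clickHandler was supplied: open the file's /content endpoint
// for uploaded files, or the raw URL otherwise, in a new tab.
const onFileItemClick = (file: { url?: string | null; type?: string }) => {
	if (!file.url) return;
	const target = file.type === 'file' ? `${file.url}/content` : file.url;
	window.open(target, '_blank')?.focus();
};
```

Wired up the same way as the Controls.svelte change, e.g. on:click={() => onFileItemClick(file)} on the FileItem instance.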