mirror of
https://github.com/open-webui/open-webui.git
synced 2025-07-01 19:20:40 +02:00
Add batching
This commit is contained in:
@ -2,16 +2,14 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from typing import List, Optional
|
||||||
from typing import Iterator, Optional, Sequence, Union
|
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
|
from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
|
|||||||
query_doc_with_hybrid_search,
|
query_doc_with_hybrid_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
from open_webui.apps.webui.models.files import Files
|
from open_webui.apps.webui.models.files import FileModel, Files
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
BRAVE_SEARCH_API_KEY,
|
BRAVE_SEARCH_API_KEY,
|
||||||
KAGI_SEARCH_API_KEY,
|
KAGI_SEARCH_API_KEY,
|
||||||
@ -64,7 +62,6 @@ from open_webui.config import (
|
|||||||
CONTENT_EXTRACTION_ENGINE,
|
CONTENT_EXTRACTION_ENGINE,
|
||||||
CORS_ALLOW_ORIGIN,
|
CORS_ALLOW_ORIGIN,
|
||||||
ENABLE_RAG_HYBRID_SEARCH,
|
ENABLE_RAG_HYBRID_SEARCH,
|
||||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
|
||||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||||
ENABLE_RAG_WEB_SEARCH,
|
ENABLE_RAG_WEB_SEARCH,
|
||||||
ENV,
|
ENV,
|
||||||
@ -86,7 +83,6 @@ from open_webui.config import (
|
|||||||
RAG_RERANKING_MODEL,
|
RAG_RERANKING_MODEL,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
DEFAULT_RAG_TEMPLATE,
|
|
||||||
RAG_TEMPLATE,
|
RAG_TEMPLATE,
|
||||||
RAG_TOP_K,
|
RAG_TOP_K,
|
||||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||||
@ -118,10 +114,7 @@ from open_webui.env import (
|
|||||||
DOCKER,
|
DOCKER,
|
||||||
)
|
)
|
||||||
from open_webui.utils.misc import (
|
from open_webui.utils.misc import (
|
||||||
calculate_sha256,
|
|
||||||
calculate_sha256_string,
|
calculate_sha256_string,
|
||||||
extract_folders_after_data_docs,
|
|
||||||
sanitize_filename,
|
|
||||||
)
|
)
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
|
||||||
@ -1047,6 +1040,106 @@ def process_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessFilesForm(BaseModel):
|
||||||
|
files: List[FileModel]
|
||||||
|
collection_name: str
|
||||||
|
|
||||||
|
class BatchProcessFilesResult(BaseModel):
|
||||||
|
file_id: str
|
||||||
|
status: str
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
class BatchProcessFilesResponse(BaseModel):
|
||||||
|
results: List[BatchProcessFilesResult]
|
||||||
|
errors: List[BatchProcessFilesResult]
|
||||||
|
|
||||||
|
@app.post("/process/files/batch")
|
||||||
|
def process_files_batch(
|
||||||
|
form_data: BatchProcessFilesForm,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
) -> BatchProcessFilesResponse:
|
||||||
|
"""
|
||||||
|
Process a batch of files and save them to the vector database.
|
||||||
|
"""
|
||||||
|
results: List[BatchProcessFilesResult] = []
|
||||||
|
errors: List[BatchProcessFilesResult] = []
|
||||||
|
collection_name = form_data.collection_name
|
||||||
|
|
||||||
|
|
||||||
|
# Prepare all documents first
|
||||||
|
all_docs: List[Document] = []
|
||||||
|
for file_request in form_data.files:
|
||||||
|
try:
|
||||||
|
file = Files.get_file_by_id(file_request.file_id)
|
||||||
|
if not file:
|
||||||
|
log.error(f"process_files_batch: File {file_request.file_id} not found")
|
||||||
|
raise ValueError(f"File {file_request.file_id} not found")
|
||||||
|
|
||||||
|
text_content = file_request.content
|
||||||
|
|
||||||
|
docs: List[Document] = [
|
||||||
|
Document(
|
||||||
|
page_content=text_content.replace("<br/>", "\n"),
|
||||||
|
metadata={
|
||||||
|
**file.meta,
|
||||||
|
"name": file_request.filename,
|
||||||
|
"created_by": file.user_id,
|
||||||
|
"file_id": file.id,
|
||||||
|
"source": file_request.filename,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
hash = calculate_sha256_string(text_content)
|
||||||
|
Files.update_file_hash_by_id(file.id, hash)
|
||||||
|
Files.update_file_data_by_id(file.id, {"content": text_content})
|
||||||
|
|
||||||
|
all_docs.extend(docs)
|
||||||
|
results.append(BatchProcessFilesResult(
|
||||||
|
file_id=file.id,
|
||||||
|
status="prepared"
|
||||||
|
))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"process_files_batch: Error processing file {file_request.file_id}: {str(e)}")
|
||||||
|
errors.append(BatchProcessFilesResult(
|
||||||
|
file_id=file_request.file_id,
|
||||||
|
status="failed",
|
||||||
|
error=str(e)
|
||||||
|
))
|
||||||
|
|
||||||
|
# Save all documents in one batch
|
||||||
|
if all_docs:
|
||||||
|
try:
|
||||||
|
save_docs_to_vector_db(
|
||||||
|
docs=all_docs,
|
||||||
|
collection_name=collection_name,
|
||||||
|
add=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update all files with collection name
|
||||||
|
for result in results:
|
||||||
|
Files.update_file_metadata_by_id(
|
||||||
|
result.file_id,
|
||||||
|
{"collection_name": collection_name}
|
||||||
|
)
|
||||||
|
result.status = "completed"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
|
||||||
|
for result in results:
|
||||||
|
result.status = "failed"
|
||||||
|
errors.append(BatchProcessFilesResult(
|
||||||
|
file_id=result.file_id,
|
||||||
|
error=str(e)
|
||||||
|
))
|
||||||
|
|
||||||
|
return BatchProcessFilesResponse(
|
||||||
|
results=results,
|
||||||
|
errors=errors
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProcessTextForm(BaseModel):
|
class ProcessTextForm(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
content: str
|
content: str
|
||||||
@ -1509,3 +1602,4 @@ if ENV == "dev":
|
|||||||
@app.get("/ef/{text}")
|
@app.get("/ef/{text}")
|
||||||
async def get_embeddings_text(text: str):
|
async def get_embeddings_text(text: str):
|
||||||
return {"result": app.state.EMBEDDING_FUNCTION(text)}
|
return {"result": app.state.EMBEDDING_FUNCTION(text)}
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import json
|
from typing import List, Optional
|
||||||
from typing import Optional, Union
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
import logging
|
import logging
|
||||||
@ -12,11 +11,11 @@ from open_webui.apps.webui.models.knowledge import (
|
|||||||
)
|
)
|
||||||
from open_webui.apps.webui.models.files import Files, FileModel
|
from open_webui.apps.webui.models.files import Files, FileModel
|
||||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||||
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
|
from open_webui.apps.retrieval.main import BatchProcessFilesForm, process_file, ProcessFileForm, process_files_batch
|
||||||
|
|
||||||
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
|
|
||||||
|
|
||||||
@ -508,3 +507,78 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
|
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
|
||||||
|
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# AddFilesToKnowledge
|
||||||
|
############################
|
||||||
|
|
||||||
|
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
|
||||||
|
def add_files_to_knowledge_batch(
|
||||||
|
id: str,
|
||||||
|
form_data: list[KnowledgeFileIdForm],
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add multiple files to a knowledge base
|
||||||
|
"""
|
||||||
|
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||||
|
if not knowledge:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
if knowledge.user_id != user.id and user.role != "admin":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get files content
|
||||||
|
print(f"files/batch/add - {len(form_data)} files")
|
||||||
|
files: List[FileModel] = []
|
||||||
|
for form in form_data:
|
||||||
|
file = Files.get_file_by_id(form.file_id)
|
||||||
|
if not file:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"File {form.file_id} not found",
|
||||||
|
)
|
||||||
|
files.append(file)
|
||||||
|
|
||||||
|
# Process files
|
||||||
|
result = process_files_batch(BatchProcessFilesForm(
|
||||||
|
files=files,
|
||||||
|
collection_name=id
|
||||||
|
))
|
||||||
|
|
||||||
|
# Add successful files to knowledge base
|
||||||
|
data = knowledge.data or {}
|
||||||
|
existing_file_ids = data.get("file_ids", [])
|
||||||
|
|
||||||
|
# Only add files that were successfully processed
|
||||||
|
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
|
||||||
|
for file_id in successful_file_ids:
|
||||||
|
if file_id not in existing_file_ids:
|
||||||
|
existing_file_ids.append(file_id)
|
||||||
|
|
||||||
|
data["file_ids"] = existing_file_ids
|
||||||
|
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
|
||||||
|
|
||||||
|
# If there were any errors, include them in the response
|
||||||
|
if result.errors:
|
||||||
|
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
|
||||||
|
return KnowledgeFilesResponse(
|
||||||
|
**knowledge.model_dump(),
|
||||||
|
files=Files.get_files_by_ids(existing_file_ids),
|
||||||
|
warnings={
|
||||||
|
"message": "Some files failed to process",
|
||||||
|
"errors": error_details
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return KnowledgeFilesResponse(
|
||||||
|
**knowledge.model_dump(),
|
||||||
|
files=Files.get_files_by_ids(existing_file_ids)
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user