mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-05 17:53:54 +02:00
Simplify passing in of file IDs for filtering (#4791)
* Simplify passing in of file IDs for filtering * Address RK comments
This commit is contained in:
committed by
trial-danswer
parent
5031096a2b
commit
6ad423efce
@@ -111,11 +111,14 @@ class BaseFilters(BaseModel):
|
||||
document_set: list[str] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
tags: list[Tag] | None = None
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@@ -150,6 +153,7 @@ class SearchRequest(ChunkContext):
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
|
||||
human_selected_filters: BaseFilters | None = None
|
||||
user_file_filters: UserFileFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
persona: Persona | None = None
|
||||
|
||||
|
@@ -164,14 +164,15 @@ def retrieval_preprocessing(
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
user_file_ids = preset_filters.user_file_ids or []
|
||||
user_folder_ids = preset_filters.user_folder_ids or []
|
||||
user_file_filters = search_request.user_file_filters
|
||||
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
|
||||
user_folder_ids = (
|
||||
(user_file_filters.user_folder_ids or []) if user_file_filters else []
|
||||
)
|
||||
if persona and persona.user_files:
|
||||
user_file_ids = user_file_ids + [
|
||||
file.id
|
||||
for file in persona.user_files
|
||||
if file.id not in (preset_filters.user_file_ids or [])
|
||||
]
|
||||
user_file_ids = list(
|
||||
set(user_file_ids) | set([file.id for file in persona.user_files])
|
||||
)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -32,6 +31,7 @@ from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.models import UserFileFilters
|
||||
from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.pipeline import section_relevance_list_impl
|
||||
from onyx.db.models import Persona
|
||||
@@ -324,30 +324,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
yield from self._build_response_for_specified_sections(query)
|
||||
return
|
||||
|
||||
# Create a copy of the retrieval options with user_file_ids if provided
|
||||
retrieval_options = copy.deepcopy(self.retrieval_options)
|
||||
if (user_file_ids or user_folder_ids) and retrieval_options:
|
||||
# Create a copy to avoid modifying the original
|
||||
filters = (
|
||||
retrieval_options.filters.model_copy()
|
||||
if retrieval_options.filters
|
||||
else BaseFilters()
|
||||
)
|
||||
filters.user_file_ids = user_file_ids
|
||||
retrieval_options = retrieval_options.model_copy(
|
||||
update={"filters": filters}
|
||||
)
|
||||
elif user_file_ids or user_folder_ids:
|
||||
# Create new retrieval options with user_file_ids
|
||||
filters = BaseFilters(
|
||||
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
|
||||
)
|
||||
retrieval_options = RetrievalDetails(filters=filters)
|
||||
|
||||
retrieval_options = self.retrieval_options or RetrievalDetails()
|
||||
if document_sources or time_cutoff:
|
||||
# Get retrieval_options and filters, or create if they don't exist
|
||||
retrieval_options = retrieval_options or RetrievalDetails()
|
||||
retrieval_options.filters = retrieval_options.filters or BaseFilters()
|
||||
# if empty, just start with an empty filters object
|
||||
if not retrieval_options.filters:
|
||||
retrieval_options.filters = BaseFilters()
|
||||
|
||||
# Handle document sources
|
||||
if document_sources:
|
||||
@@ -370,6 +351,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
human_selected_filters=(
|
||||
retrieval_options.filters if retrieval_options else None
|
||||
),
|
||||
user_file_filters=UserFileFilters(
|
||||
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
|
||||
),
|
||||
persona=self.persona,
|
||||
offset=(retrieval_options.offset if retrieval_options else None),
|
||||
limit=retrieval_options.limit if retrieval_options else None,
|
||||
|
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
|
||||
API_SERVER_URL = "http://localhost:3000" # Adjust this to your Onyx server URL
|
||||
HEADERS = {"Content-Type": "application/json"}
|
||||
API_SERVER_URL = "http://localhost:3000"
|
||||
API_KEY = "onyx-api-key" # API key here, if auth is enabled
|
||||
HEADERS = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
|
||||
|
||||
|
||||
def create_connector(
|
||||
@@ -15,23 +16,29 @@ def create_connector(
|
||||
connector_specific_config: Dict[str, Any],
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
access_type: str = "public",
|
||||
) -> Dict[str, Any]:
|
||||
connector_update_request = {
|
||||
"name": name,
|
||||
"name": name + " Connector",
|
||||
"source": source,
|
||||
"input_type": input_type,
|
||||
"connector_specific_config": connector_specific_config,
|
||||
"is_public": is_public,
|
||||
"groups": groups or [],
|
||||
"access_type": access_type,
|
||||
}
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/api/manage/admin/connector",
|
||||
json=connector_update_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/api/manage/admin/connector",
|
||||
json=connector_update_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print(f"Error response body: {e.response.text}")
|
||||
raise
|
||||
|
||||
|
||||
def create_credential(
|
||||
@@ -42,20 +49,25 @@ def create_credential(
|
||||
groups: list[int] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
credential_request = {
|
||||
"name": name,
|
||||
"name": name + " Credential",
|
||||
"source": source,
|
||||
"credential_json": credential_json,
|
||||
"admin_public": is_public,
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/api/manage/credential",
|
||||
json=credential_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/api/manage/credential",
|
||||
json=credential_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print(f"Error response body: {e.response.text}")
|
||||
raise
|
||||
|
||||
|
||||
def create_cc_pair(
|
||||
@@ -71,77 +83,53 @@ def create_cc_pair(
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
|
||||
json=cc_pair_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
try:
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
|
||||
json=cc_pair_request,
|
||||
headers=HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print(f"Error response body: {e.response.text}")
|
||||
raise
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Create a Web connector
|
||||
web_connector = create_connector(
|
||||
name="Example Web Connector",
|
||||
source="web",
|
||||
input_type="load_state",
|
||||
connector_specific_config={
|
||||
"base_url": "https://example.com",
|
||||
"web_connector_type": "recursive",
|
||||
},
|
||||
)
|
||||
print(f"Created Web Connector: {web_connector}")
|
||||
# Parse the JSON file that contains the connector creation requests
|
||||
with open("creation_request/connector_creation_template.json", "r") as file:
|
||||
connector_creation_requests = json.load(file)
|
||||
|
||||
# Create a credential for the Web connector
|
||||
web_credential = create_credential(
|
||||
name="Example Web Credential",
|
||||
source="web",
|
||||
credential_json={}, # Web connectors typically don't need credentials
|
||||
is_public=True,
|
||||
)
|
||||
print(f"Created Web Credential: {web_credential}")
|
||||
for connector_creation_request in connector_creation_requests:
|
||||
connector_response = create_connector(
|
||||
name=connector_creation_request["name"],
|
||||
source=connector_creation_request["source"],
|
||||
input_type=connector_creation_request["input_type"],
|
||||
connector_specific_config=dict(
|
||||
connector_creation_request.get("connector_specific_config", {})
|
||||
),
|
||||
access_type=connector_creation_request.get("access_type", "public"),
|
||||
)
|
||||
|
||||
# Create CC pair for Web connector
|
||||
web_cc_pair = create_cc_pair(
|
||||
connector_id=web_connector["id"],
|
||||
credential_id=web_credential["id"],
|
||||
name="Example Web CC Pair",
|
||||
access_type="public",
|
||||
)
|
||||
print(f"Created Web CC Pair: {web_cc_pair}")
|
||||
credential_id = connector_creation_request.get("credential_id")
|
||||
# If a credential id is provided, reuse credential rather than creating a new one
|
||||
if not credential_id:
|
||||
credential_response = create_credential(
|
||||
name=connector_creation_request["name"],
|
||||
source=connector_creation_request["source"],
|
||||
credential_json=connector_creation_request.get("credential_json", {}),
|
||||
)
|
||||
credential_id = credential_response.get("id")
|
||||
|
||||
# Create a GitHub connector
|
||||
github_connector = create_connector(
|
||||
name="Example GitHub Connector",
|
||||
source="github",
|
||||
input_type="poll",
|
||||
connector_specific_config={
|
||||
"repo_owner": "example-owner",
|
||||
"repo_name": "example-repo",
|
||||
"include_prs": True,
|
||||
"include_issues": True,
|
||||
},
|
||||
)
|
||||
print(f"Created GitHub Connector: {github_connector}")
|
||||
create_cc_pair(
|
||||
connector_id=connector_response.get("id"),
|
||||
credential_id=credential_id,
|
||||
name=connector_creation_request["name"],
|
||||
access_type=connector_creation_request.get("access_type", "public"),
|
||||
)
|
||||
|
||||
# Create a credential for the GitHub connector
|
||||
github_credential = create_credential(
|
||||
name="Example GitHub Credential",
|
||||
source="github",
|
||||
credential_json={"github_access_token": "your_github_access_token_here"},
|
||||
is_public=True,
|
||||
)
|
||||
print(f"Created GitHub Credential: {github_credential}")
|
||||
|
||||
# Create CC pair for GitHub connector
|
||||
github_cc_pair = create_cc_pair(
|
||||
connector_id=github_connector["id"],
|
||||
credential_id=github_credential["id"],
|
||||
name="Example GitHub CC Pair",
|
||||
access_type="public",
|
||||
)
|
||||
print(f"Created GitHub CC Pair: {github_cc_pair}")
|
||||
print(f"Created connector: {connector_creation_request['name']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -0,0 +1,45 @@
|
||||
[
|
||||
{
|
||||
"name": "Example Web",
|
||||
"source": "web",
|
||||
"input_type": "load_state",
|
||||
"connector_specific_config": {
|
||||
"base_url": "https://example.com",
|
||||
"web_connector_type": "recursive"
|
||||
},
|
||||
"credential_json": {},
|
||||
"is_public": true,
|
||||
"credential_id": null
|
||||
},
|
||||
{
|
||||
"name": "Example GitHub",
|
||||
"source": "github",
|
||||
"input_type": "poll",
|
||||
"connector_specific_config": {
|
||||
"repo_owner": "example-owner",
|
||||
"repositories": "example-repo",
|
||||
"include_prs": true,
|
||||
"include_issues": true
|
||||
},
|
||||
"credential_json": {
|
||||
"github_access_token": "your_github_access_token_here"
|
||||
},
|
||||
"is_public": true,
|
||||
"credential_id": null
|
||||
},
|
||||
{
|
||||
"name": "Example Google Drive",
|
||||
"source": "google_drive",
|
||||
"input_type": "poll",
|
||||
"is_public": true,
|
||||
"connector_specific_config": {
|
||||
"include_shared_drives": true,
|
||||
"include_my_drives": true,
|
||||
"include_files_shared_with_me": true
|
||||
},
|
||||
"credential_json": {
|
||||
"google_primary_admin": "Admin Email"
|
||||
},
|
||||
"credential_id": 1
|
||||
}
|
||||
]
|
@@ -1451,9 +1451,7 @@ export function ChatPage({
|
||||
filterManager.selectedSources,
|
||||
filterManager.selectedDocumentSets,
|
||||
filterManager.timeRange,
|
||||
filterManager.selectedTags,
|
||||
selectedFiles.map((file) => file.id)
|
||||
// selectedFolders.map((folder) => folder.id)
|
||||
filterManager.selectedTags
|
||||
),
|
||||
selectedDocumentIds: selectedDocuments
|
||||
.filter(
|
||||
|
Reference in New Issue
Block a user