Simplify passing in of file IDs for filtering (#4791)

* Simplify passing in of file IDs for filtering

* Address RK comments
This commit is contained in:
Chris Weaver
2025-05-29 22:08:21 -07:00
committed by trial-danswer
parent 5031096a2b
commit 6ad423efce
6 changed files with 137 additions and 117 deletions

View File

@@ -111,11 +111,14 @@ class BaseFilters(BaseModel):
document_set: list[str] | None = None
time_cutoff: datetime | None = None
tags: list[Tag] | None = None
class UserFileFilters(BaseModel):
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
class IndexFilters(BaseFilters):
class IndexFilters(BaseFilters, UserFileFilters):
access_control_list: list[str] | None
tenant_id: str | None = None
@@ -150,6 +153,7 @@ class SearchRequest(ChunkContext):
search_type: SearchType = SearchType.SEMANTIC
human_selected_filters: BaseFilters | None = None
user_file_filters: UserFileFilters | None = None
enable_auto_detect_filters: bool | None = None
persona: Persona | None = None

View File

@@ -164,14 +164,15 @@ def retrieval_preprocessing(
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
user_file_ids = preset_filters.user_file_ids or []
user_folder_ids = preset_filters.user_folder_ids or []
user_file_filters = search_request.user_file_filters
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
user_folder_ids = (
(user_file_filters.user_folder_ids or []) if user_file_filters else []
)
if persona and persona.user_files:
user_file_ids = user_file_ids + [
file.id
for file in persona.user_files
if file.id not in (preset_filters.user_file_ids or [])
]
user_file_ids = list(
set(user_file_ids) | set([file.id for file in persona.user_files])
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,

View File

@@ -1,4 +1,3 @@
import copy
import json
import time
from collections.abc import Callable
@@ -32,6 +31,7 @@ from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.models import UserFileFilters
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.pipeline import section_relevance_list_impl
from onyx.db.models import Persona
@@ -324,30 +324,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
yield from self._build_response_for_specified_sections(query)
return
# Create a copy of the retrieval options with user_file_ids if provided
retrieval_options = copy.deepcopy(self.retrieval_options)
if (user_file_ids or user_folder_ids) and retrieval_options:
# Create a copy to avoid modifying the original
filters = (
retrieval_options.filters.model_copy()
if retrieval_options.filters
else BaseFilters()
)
filters.user_file_ids = user_file_ids
retrieval_options = retrieval_options.model_copy(
update={"filters": filters}
)
elif user_file_ids or user_folder_ids:
# Create new retrieval options with user_file_ids
filters = BaseFilters(
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
)
retrieval_options = RetrievalDetails(filters=filters)
retrieval_options = self.retrieval_options or RetrievalDetails()
if document_sources or time_cutoff:
# Get retrieval_options and filters, or create if they don't exist
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
# if empty, just start with an empty filters object
if not retrieval_options.filters:
retrieval_options.filters = BaseFilters()
# Handle document sources
if document_sources:
@@ -370,6 +351,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
),
user_file_filters=UserFileFilters(
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
),
persona=self.persona,
offset=(retrieval_options.offset if retrieval_options else None),
limit=retrieval_options.limit if retrieval_options else None,

View File

@@ -1,11 +1,12 @@
import json
from typing import Any
from typing import Dict
import requests
API_SERVER_URL = "http://localhost:3000" # Adjust this to your Onyx server URL
HEADERS = {"Content-Type": "application/json"}
API_SERVER_URL = "http://localhost:3000"
API_KEY = "onyx-api-key" # API key here, if auth is enabled
HEADERS = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
def create_connector(
@@ -15,23 +16,29 @@ def create_connector(
connector_specific_config: Dict[str, Any],
is_public: bool = True,
groups: list[int] | None = None,
access_type: str = "public",
) -> Dict[str, Any]:
connector_update_request = {
"name": name,
"name": name + " Connector",
"source": source,
"input_type": input_type,
"connector_specific_config": connector_specific_config,
"is_public": is_public,
"groups": groups or [],
"access_type": access_type,
}
try:
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/admin/connector",
json=connector_update_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/admin/connector",
json=connector_update_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
print(f"Error response body: {e.response.text}")
raise
def create_credential(
@@ -42,20 +49,25 @@ def create_credential(
groups: list[int] | None = None,
) -> Dict[str, Any]:
credential_request = {
"name": name,
"name": name + " Credential",
"source": source,
"credential_json": credential_json,
"admin_public": is_public,
"groups": groups or [],
}
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/credential",
json=credential_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
try:
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/credential",
json=credential_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
print(f"Error response body: {e.response.text}")
raise
def create_cc_pair(
@@ -71,77 +83,53 @@ def create_cc_pair(
"groups": groups or [],
}
response = requests.put(
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
json=cc_pair_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
try:
response = requests.put(
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
json=cc_pair_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
print(f"Error response body: {e.response.text}")
raise
def main() -> None:
# Create a Web connector
web_connector = create_connector(
name="Example Web Connector",
source="web",
input_type="load_state",
connector_specific_config={
"base_url": "https://example.com",
"web_connector_type": "recursive",
},
)
print(f"Created Web Connector: {web_connector}")
# Parse the JSON file that contains the connector creation requests
with open("creation_request/connector_creation_template.json", "r") as file:
connector_creation_requests = json.load(file)
# Create a credential for the Web connector
web_credential = create_credential(
name="Example Web Credential",
source="web",
credential_json={}, # Web connectors typically don't need credentials
is_public=True,
)
print(f"Created Web Credential: {web_credential}")
for connector_creation_request in connector_creation_requests:
connector_response = create_connector(
name=connector_creation_request["name"],
source=connector_creation_request["source"],
input_type=connector_creation_request["input_type"],
connector_specific_config=dict(
connector_creation_request.get("connector_specific_config", {})
),
access_type=connector_creation_request.get("access_type", "public"),
)
# Create CC pair for Web connector
web_cc_pair = create_cc_pair(
connector_id=web_connector["id"],
credential_id=web_credential["id"],
name="Example Web CC Pair",
access_type="public",
)
print(f"Created Web CC Pair: {web_cc_pair}")
credential_id = connector_creation_request.get("credential_id")
# If a credential id is provided, reuse credential rather than creating a new one
if not credential_id:
credential_response = create_credential(
name=connector_creation_request["name"],
source=connector_creation_request["source"],
credential_json=connector_creation_request.get("credential_json", {}),
)
credential_id = credential_response.get("id")
# Create a GitHub connector
github_connector = create_connector(
name="Example GitHub Connector",
source="github",
input_type="poll",
connector_specific_config={
"repo_owner": "example-owner",
"repo_name": "example-repo",
"include_prs": True,
"include_issues": True,
},
)
print(f"Created GitHub Connector: {github_connector}")
create_cc_pair(
connector_id=connector_response.get("id"),
credential_id=credential_id,
name=connector_creation_request["name"],
access_type=connector_creation_request.get("access_type", "public"),
)
# Create a credential for the GitHub connector
github_credential = create_credential(
name="Example GitHub Credential",
source="github",
credential_json={"github_access_token": "your_github_access_token_here"},
is_public=True,
)
print(f"Created GitHub Credential: {github_credential}")
# Create CC pair for GitHub connector
github_cc_pair = create_cc_pair(
connector_id=github_connector["id"],
credential_id=github_credential["id"],
name="Example GitHub CC Pair",
access_type="public",
)
print(f"Created GitHub CC Pair: {github_cc_pair}")
print(f"Created connector: {connector_creation_request['name']}")
if __name__ == "__main__":

View File

@@ -0,0 +1,45 @@
[
{
"name": "Example Web",
"source": "web",
"input_type": "load_state",
"connector_specific_config": {
"base_url": "https://example.com",
"web_connector_type": "recursive"
},
"credential_json": {},
"is_public": true,
"credential_id": null
},
{
"name": "Example GitHub",
"source": "github",
"input_type": "poll",
"connector_specific_config": {
"repo_owner": "example-owner",
"repositories": "example-repo",
"include_prs": true,
"include_issues": true
},
"credential_json": {
"github_access_token": "your_github_access_token_here"
},
"is_public": true,
"credential_id": null
},
{
"name": "Example Google Drive",
"source": "google_drive",
"input_type": "poll",
"is_public": true,
"connector_specific_config": {
"include_shared_drives": true,
"include_my_drives": true,
"include_files_shared_with_me": true
},
"credential_json": {
"google_primary_admin": "Admin Email"
},
"credential_id": 1
}
]

View File

@@ -1451,9 +1451,7 @@ export function ChatPage({
filterManager.selectedSources,
filterManager.selectedDocumentSets,
filterManager.timeRange,
filterManager.selectedTags,
selectedFiles.map((file) => file.id)
// selectedFolders.map((folder) => folder.id)
filterManager.selectedTags
),
selectedDocumentIds: selectedDocuments
.filter(