Strict Tenant ID Enforcement (#3871)

* strict tenant id enforcement

* k

* k

* nit

* merge

* nit

* k
This commit is contained in:
pablonyx
2025-02-18 16:52:56 -08:00
committed by GitHub
parent 2013beb9e0
commit 47fd4fa233
68 changed files with 390 additions and 357 deletions

View File

@@ -47,7 +47,7 @@ def get_user_id(user_email: str) -> tuple[UUID, str]:
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)
with get_session_with_tenant(tenant_id) as session:
with get_session_with_tenant(tenant_id=tenant_id) as session:
user = get_user_by_email(user_email, session)
if user is None:
raise ValueError(f"User not found for email: {user_email}")

View File

@@ -41,6 +41,7 @@ from sqlalchemy import and_
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
@@ -64,6 +65,7 @@ from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -473,7 +475,7 @@ def get_document_acls(
def get_current_chunk_count(
document_id: str, index_name: str, tenant_id: str
) -> int | None:
with get_session_with_tenant(tenant_id=tenant_id) as session:
with get_session_with_current_tenant() as session:
return (
session.query(Document.chunk_count)
.filter(Document.id == document_id)
@@ -513,7 +515,7 @@ class VespaDebugging:
# Sample random documents and compare chunk counts
mismatches = []
no_chunks = []
with get_session_with_tenant(tenant_id=self.tenant_id) as session:
with get_session_with_current_tenant() as session:
# Get a sample of random documents
from sqlalchemy import func
@@ -796,6 +798,7 @@ def main() -> None:
args = parser.parse_args()
vespa_debug = VespaDebugging(args.tenant_id)
CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
if args.action == "delete-all-documents":
if not args.tenant_id:
parser.error("--tenant-id is required for delete-all-documents action")