This commit is contained in:
Weves 2024-11-25 08:15:45 -08:00 committed by Chris Weaver
parent b625ee32a7
commit 076ce2ebd0

View File

@ -6,11 +6,11 @@ from sqlalchemy import and_
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.models import SamlAccount from danswer.db.models import SamlAccount
from danswer.db.models import User
def upsert_saml_account( def upsert_saml_account(
@ -52,7 +52,7 @@ async def get_saml_account(
(which is necessarily async due to FastAPI Users)""" (which is necessarily async due to FastAPI Users)"""
stmt = ( stmt = (
select(SamlAccount) select(SamlAccount)
.join(User, User.id == SamlAccount.user_id) # type: ignore .options(selectinload(SamlAccount.user)) # Use selectinload for collections
.where( .where(
and_( and_(
SamlAccount.encrypted_cookie == cookie, SamlAccount.encrypted_cookie == cookie,
@ -62,7 +62,7 @@ async def get_saml_account(
) )
result = await async_db_session.execute(stmt) result = await async_db_session.execute(stmt)
return result.scalar_one_or_none() return result.scalars().unique().one_or_none()
async def expire_saml_account( async def expire_saml_account(