check last_pruned instead of is_pruning (#2748)

* check last_pruned instead of is_pruning

* try using the ThreadingHTTPServer class for stability and avoiding blocking single-threaded behavior

* add startup delay to web server in test

* just explicitly return None if we can't parse the datetime

* switch to uvicorn for test stability
This commit is contained in:
rkuo-danswer 2024-10-16 11:52:27 -07:00 committed by GitHub
parent 1a9921f63e
commit 0a0215ceee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 18 deletions

View File

@ -1,4 +1,5 @@
import math import math
from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter from fastapi import APIRouter
@ -204,12 +205,12 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique") raise HTTPException(status_code=400, detail="Name must be unique")
@router.get("/admin/cc-pair/{cc_pair_id}/prune") @router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_latest_prune( def get_cc_pair_last_pruned(
cc_pair_id: int, cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user), user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> bool: ) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id( cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, cc_pair_id=cc_pair_id,
db_session=db_session, db_session=db_session,
@ -219,11 +220,10 @@ def get_cc_pair_latest_prune(
if not cc_pair: if not cc_pair:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Connection not found for current user's permissions", detail="cc_pair not found for current user's permissions",
) )
rcp = RedisConnectorPruning(cc_pair.id) return cc_pair.last_pruned
return rcp.is_pruning(db_session, get_redis_client())
@router.post("/admin/cc-pair/{cc_pair_id}/prune") @router.post("/admin/cc-pair/{cc_pair_id}/prune")

View File

@ -274,31 +274,40 @@ class CCPairManager:
result.raise_for_status() result.raise_for_status()
@staticmethod @staticmethod
def is_pruning( def last_pruned(
cc_pair: DATestCCPair, cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None, user_performing_action: DATestUser | None = None,
) -> bool: ) -> datetime | None:
response = requests.get( response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune", url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
headers=user_performing_action.headers headers=user_performing_action.headers
if user_performing_action if user_performing_action
else GENERAL_HEADERS, else GENERAL_HEADERS,
) )
response.raise_for_status() response.raise_for_status()
response_bool = response.json() response_str = response.json()
return response_bool
# If the response itself is a datetime string, parse it
if not isinstance(response_str, str):
return None
try:
return datetime.fromisoformat(response_str)
except ValueError:
return None
@staticmethod @staticmethod
def wait_for_prune( def wait_for_prune(
cc_pair: DATestCCPair, cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY, timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None, user_performing_action: DATestUser | None = None,
) -> None: ) -> None:
"""after: The task register time must be after this time.""" """after: The task register time must be after this time."""
start = time.monotonic() start = time.monotonic()
while True: while True:
result = CCPairManager.is_pruning(cc_pair, user_performing_action) last_pruned = CCPairManager.last_pruned(cc_pair, user_performing_action)
if not result: if last_pruned and last_pruned > after:
break break
elapsed = time.monotonic() - start elapsed = time.monotonic() - start

View File

@ -195,8 +195,9 @@ def test_slack_prune(
) )
# Prune the cc_pair # Prune the cc_pair
now = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair, user_performing_action=admin_user) CCPairManager.prune(cc_pair, user_performing_action=admin_user)
CCPairManager.wait_for_prune(cc_pair, user_performing_action=admin_user) CCPairManager.wait_for_prune(cc_pair, now, user_performing_action=admin_user)
# ----------------------------VERIFY THE CHANGES--------------------------- # ----------------------------VERIFY THE CHANGES---------------------------
# Ensure admin user can't see deleted messages # Ensure admin user can't see deleted messages

View File

@ -10,6 +10,10 @@ from datetime import timezone
from time import sleep from time import sleep
from typing import Any from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from danswer.server.documents.models import DocumentSource from danswer.server.documents.models import DocumentSource
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.api_key import APIKeyManager
@ -21,10 +25,50 @@ from tests.integration.common_utils.vespa import vespa_fixture
logger = setup_logger() logger = setup_logger()
# FastAPI server for serving files
def create_fastapi_app(directory: str) -> FastAPI:
app = FastAPI()
# Mount the directory to serve static files
app.mount("/", StaticFiles(directory=directory, html=True), name="static")
return app
# as far as we know, this doesn't hang when crawled. This is good.
@contextmanager
def fastapi_server_context(
directory: str, port: int = 8000
) -> Generator[None, None, None]:
app = create_fastapi_app(directory)
config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level="info")
server = uvicorn.Server(config)
# Create a thread to run the FastAPI server
server_thread = threading.Thread(target=server.run)
server_thread.daemon = (
True # Ensures the thread will exit when the main program exits
)
try:
# Start the server in the background
server_thread.start()
sleep(5) # Give it a few seconds to start
yield # Yield control back to the calling function (context manager in use)
finally:
# Shutdown the server
server.should_exit = True
server_thread.join()
# Leaving this here for posterity and experimentation, but the reason we're
# not using this is python's web servers hang frequently when crawled
# this is obviously not good for a unit test
@contextmanager @contextmanager
def http_server_context( def http_server_context(
directory: str, port: int = 8000 directory: str, port: int = 8000
) -> Generator[http.server.HTTPServer, None, None]: ) -> Generator[http.server.ThreadingHTTPServer, None, None]:
# Create a handler that serves files from the specified directory # Create a handler that serves files from the specified directory
def handler_class( def handler_class(
*args: Any, **kwargs: Any *args: Any, **kwargs: Any
@ -34,7 +78,7 @@ def http_server_context(
) )
# Create an HTTPServer instance # Create an HTTPServer instance
httpd = http.server.HTTPServer(("0.0.0.0", port), handler_class) httpd = http.server.ThreadingHTTPServer(("0.0.0.0", port), handler_class)
# Define a thread that runs the server in the background # Define a thread that runs the server in the background
server_thread = threading.Thread(target=httpd.serve_forever) server_thread = threading.Thread(target=httpd.serve_forever)
@ -45,6 +89,7 @@ def http_server_context(
try: try:
# Start the server in the background # Start the server in the background
server_thread.start() server_thread.start()
sleep(5) # give it a few seconds to start
yield httpd yield httpd
finally: finally:
# Shutdown the server and wait for the thread to finish # Shutdown the server and wait for the thread to finish
@ -70,7 +115,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
website_src = os.path.join(test_directory, "website") website_src = os.path.join(test_directory, "website")
website_tgt = os.path.join(temp_dir, "website") website_tgt = os.path.join(temp_dir, "website")
shutil.copytree(website_src, website_tgt) shutil.copytree(website_src, website_tgt)
with http_server_context(os.path.join(temp_dir, "website"), port): with fastapi_server_context(os.path.join(temp_dir, "website"), port):
sleep(1) # sleep a tiny bit before starting everything sleep(1) # sleep a tiny bit before starting everything
hostname = os.getenv("TEST_WEB_HOSTNAME", "localhost") hostname = os.getenv("TEST_WEB_HOSTNAME", "localhost")
@ -105,9 +150,10 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
logger.info("Removing courses.html.") logger.info("Removing courses.html.")
os.remove(os.path.join(website_tgt, "courses.html")) os.remove(os.path.join(website_tgt, "courses.html"))
now = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair_1, user_performing_action=admin_user) CCPairManager.prune(cc_pair_1, user_performing_action=admin_user)
CCPairManager.wait_for_prune( CCPairManager.wait_for_prune(
cc_pair_1, timeout=60, user_performing_action=admin_user cc_pair_1, now, timeout=60, user_performing_action=admin_user
) )
selected_cc_pair = CCPairManager.get_one( selected_cc_pair = CCPairManager.get_one(