import json
import os
import subprocess
import sys
from threading import Thread
from typing import IO

from retry import retry


def _run_command(command: str, stream_output: bool = False) -> tuple[str, str]:
    process = subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1,
    )

    stdout_lines: list[str] = []
    stderr_lines: list[str] = []

    def process_stream(stream: IO[str], lines: list[str]) -> None:
        for line in stream:
            lines.append(line)
            if stream_output:
                print(
                    line,
                    end="",
                    file=sys.stdout if stream == process.stdout else sys.stderr,
                )

    stdout_thread = Thread(target=process_stream, args=(process.stdout, stdout_lines))
    stderr_thread = Thread(target=process_stream, args=(process.stderr, stderr_lines))

    stdout_thread.start()
    stderr_thread.start()

    stdout_thread.join()
    stderr_thread.join()

    process.wait()

    if process.returncode != 0:
        raise RuntimeError(f"Command failed with error: {''.join(stderr_lines)}")

    return "".join(stdout_lines), "".join(stderr_lines)


def get_current_commit_sha() -> str:
    print("Getting current commit SHA...")
    stdout, _ = _run_command("git rev-parse HEAD")
    sha = stdout.strip()
    print(f"Current commit SHA: {sha}")
    return sha


def switch_to_branch(branch: str) -> None:
    print(f"Switching to branch: {branch}...")
    _run_command(f"git checkout {branch}")
    _run_command("git pull")
    print(f"Successfully switched to branch: {branch}")
    print("Repository updated successfully.")


def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> str:
    # Use the user's home directory as the base path
    target_path = os.path.join(os.path.expanduser(base_path), f"test{suffix}")
    directories = {
        "DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"),
        "DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"),
    }
    if not use_cloud_gpu:
        directories["DANSWER_INDEX_MODEL_CACHE_DIR"] = os.path.join(
            target_path, "index_model_cache/"
        )
        directories["DANSWER_INFERENCE_MODEL_CACHE_DIR"] = os.path.join(
            target_path, "inference_model_cache/"
        )

    # Create directories if they don't exist
    for env_var, directory in directories.items():
        os.makedirs(directory, exist_ok=True)
        os.environ[env_var] = directory
        print(f"Set {env_var} to: {directory}")

    relari_output_path = os.path.join(target_path, "relari_output/")
    os.makedirs(relari_output_path, exist_ok=True)
    return relari_output_path


def set_env_variables(
    remote_server_ip: str,
    remote_server_port: str,
    use_cloud_gpu: bool,
    llm_config: dict,
) -> None:
    env_vars: dict = {}
    env_vars["ENV_SEED_CONFIGURATION"] = json.dumps({"llms": [llm_config]})
    env_vars["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "true"
    if use_cloud_gpu:
        env_vars["MODEL_SERVER_HOST"] = remote_server_ip
        env_vars["MODEL_SERVER_PORT"] = remote_server_port
        env_vars["INDEXING_MODEL_SERVER_HOST"] = remote_server_ip

    for env_var_name, env_var in env_vars.items():
        os.environ[env_var_name] = env_var
        print(f"Set {env_var_name} to: {env_var}")


def start_docker_compose(
    run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool
) -> None:
    print("Starting Docker Compose...")
    os.chdir(os.path.dirname(__file__))
    os.chdir("../../../../deployment/docker_compose/")

    command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack{run_suffix} up -d"
    command += " --build"
    command += " --force-recreate"
    if not launch_web_ui:
        command += " --scale web_server=0"
        command += " --scale nginx=0"
    if use_cloud_gpu:
        command += " --scale indexing_model_server=0"
        command += " --scale inference_model_server=0"

    print("Docker Command:\n", command)

    _run_command(command, stream_output=True)
    print("Containers have been launched")


def cleanup_docker(run_suffix: str) -> None:
    print(
        f"Deleting Docker containers, volumes, and networks for project suffix: {run_suffix}"
    )

    stdout, _ = _run_command("docker ps -a --format '{{json .}}'")
    containers = [json.loads(line) for line in stdout.splitlines()]

    project_name = f"danswer-stack{run_suffix}"
    containers_to_delete = [
        c for c in containers if c["Names"].startswith(project_name)
    ]

    if not containers_to_delete:
        print(f"No containers found for project: {project_name}")
    else:
        container_ids = " ".join([c["ID"] for c in containers_to_delete])
        _run_command(f"docker rm -f {container_ids}")
        print(
            f"Successfully deleted {len(containers_to_delete)} containers for project: {project_name}"
        )

    stdout, _ = _run_command("docker volume ls --format '{{.Name}}'")
    volumes = stdout.splitlines()
    volumes_to_delete = [v for v in volumes if v.startswith(project_name)]

    if not volumes_to_delete:
        print(f"No volumes found for project: {project_name}")
    else:
        # Delete filtered volumes
        volume_names = " ".join(volumes_to_delete)
        _run_command(f"docker volume rm {volume_names}")
        print(
            f"Successfully deleted {len(volumes_to_delete)} volumes for project: {project_name}"
        )

    stdout, _ = _run_command("docker network ls --format '{{.Name}}'")
    networks = stdout.splitlines()
    networks_to_delete = [n for n in networks if run_suffix in n]

    if not networks_to_delete:
        print(f"No networks found containing suffix: {run_suffix}")
    else:
        network_names = " ".join(networks_to_delete)
        _run_command(f"docker network rm {network_names}")
        print(
            f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
        )


@retry(tries=5, delay=5, backoff=2)
def get_api_server_host_port(suffix: str) -> str:
    """
    Lists all Docker containers, finds the one whose name contains both
    "api_server" and the provided suffix, and parses that container's port
    mappings to return the relevant host port.
    """
    container_name = "api_server"

    stdout, _ = _run_command("docker ps -a --format '{{json .}}'")
    containers = [json.loads(line) for line in stdout.splitlines()]

    server_jsons = []
    for container in containers:
        if container_name in container["Names"] and suffix in container["Names"]:
            server_jsons.append(container)

    if not server_jsons:
        raise RuntimeError(
            f"No container found containing: {container_name} and {suffix}"
        )
    elif len(server_jsons) > 1:
        raise RuntimeError(
            f"Too many containers matching {container_name} found, please indicate a suffix"
        )
    server_json = server_jsons[0]

    # This is in case the api_server has multiple ports
    client_port = "8080"
    ports = server_json.get("Ports", "")
    port_infos = ports.split(",") if ports else []

    # Map each container port (e.g. "8080/tcp") to the host port it is published on
    port_dict = {}
    for port_info in port_infos:
        port_arr = port_info.split(":")[-1].split("->") if port_info else []
        if len(port_arr) == 2:
            port_dict[port_arr[1]] = port_arr[0]

    # Find the host port where client_port is in the key
    matching_ports = [value for key, value in port_dict.items() if client_port in key]

    if len(matching_ports) > 1:
        raise RuntimeError(f"Too many ports matching {client_port} found")
    if not matching_ports:
        raise RuntimeError(
            f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}"
        )

    return matching_ports[0]
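

# Illustrative walk-through of the parsing above (hypothetical values): a
# "Ports" field such as "0.0.0.0:32789->8080/tcp" produces
# port_dict == {"8080/tcp": "32789"}, so the function returns "32789", the
# host port published for the container's 8080.
#
#   api_port = get_api_server_host_port("-test")
#   api_url = f"http://localhost:{api_port}"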