Prompt user for OpenAI key

This commit is contained in:
Weves 2023-05-20 15:51:34 -07:00 committed by Chris Weaver
parent 544ba8f50d
commit 0b8c69ceeb
17 changed files with 406 additions and 11 deletions

View File

@ -12,6 +12,7 @@ SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
NO_AUTH_USER = "FooBarUser" # TODO rework this temporary solution
OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
class DocumentSource(str, Enum):

View File

@ -1,15 +1,20 @@
from typing import Any
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
from danswer.direct_qa.question_answer import OpenAICompletionQA
from danswer.dynamic_configs import get_dynamic_config_store
def get_default_backend_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: dict[str, Any]
) -> QAModel:
if internal_model == "openai-completion":
return OpenAICompletionQA()
return OpenAICompletionQA(**kwargs)
elif internal_model == "openai-chat-completion":
return OpenAIChatCompletionQA()
return OpenAIChatCompletionQA(**kwargs)
else:
raise ValueError("Wrong internal QA model set.")

View File

@ -0,0 +1,21 @@
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from openai.error import AuthenticationError
def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
if not openai_api_key:
return False
qa_model = get_default_backend_qa_model(api_key=openai_api_key)
if not isinstance(qa_model, OpenAIQAModel):
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
try:
qa_model.answer_question("Do not respond", [])
return True
except AuthenticationError:
return False

View File

@ -3,6 +3,7 @@ import math
import re
from collections.abc import Callable
from collections.abc import Generator
from functools import wraps
from typing import Any
from typing import cast
from typing import Dict
@ -17,6 +18,7 @@ from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
@ -29,15 +31,19 @@ from danswer.direct_qa.qa_prompts import json_chat_processor
from danswer.direct_qa.qa_prompts import json_processor
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logging import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
from danswer.utils.timing import log_function_time
from openai.error import AuthenticationError
logger = setup_logger()
openai.api_key = OPENAI_API_KEY
def get_openai_api_key():
return OPENAI_API_KEY or get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
def get_json_line(json_dict: dict) -> str:
@ -181,16 +187,23 @@ def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
return False
class OpenAICompletionQA(QAModel):
# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel):
pass
class OpenAICompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: Callable[[str, list[str]], str] = json_processor,
model_version: str = OPENAI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.api_key = api_key or get_openai_api_key()
@log_function_time()
def answer_question(
@ -202,6 +215,7 @@ class OpenAICompletionQA(QAModel):
try:
response = openai.Completion.create(
api_key=self.api_key,
prompt=filled_prompt,
temperature=0,
top_p=1,
@ -214,6 +228,9 @@ class OpenAICompletionQA(QAModel):
logger.info(
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
)
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Exception as e:
logger.exception(e)
model_output = "Model Failure"
@ -232,6 +249,7 @@ class OpenAICompletionQA(QAModel):
try:
response = openai.Completion.create(
api_key=self.api_key,
prompt=filled_prompt,
temperature=0,
top_p=1,
@ -263,7 +281,9 @@ class OpenAICompletionQA(QAModel):
yield {"answer_finished": True}
continue
yield {"answer_data": event_text}
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Exception as e:
logger.exception(e)
model_output = "Model Failure"
@ -276,7 +296,7 @@ class OpenAICompletionQA(QAModel):
yield quotes_dict
class OpenAIChatCompletionQA(QAModel):
class OpenAIChatCompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: Callable[
@ -285,11 +305,13 @@ class OpenAIChatCompletionQA(QAModel):
model_version: str = OPENAI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
reflexion_try_count: int = 0,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.reflexion_try_count = reflexion_try_count
self.api_key = api_key or get_openai_api_key()
@log_function_time()
def answer_question(
@ -302,6 +324,7 @@ class OpenAIChatCompletionQA(QAModel):
for _ in range(self.reflexion_try_count + 1):
try:
response = openai.ChatCompletion.create(
api_key=self.api_key,
messages=messages,
temperature=0,
top_p=1,
@ -316,6 +339,9 @@ class OpenAIChatCompletionQA(QAModel):
logger.info(
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
)
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Exception as e:
logger.exception(e)
logger.warning(f"Model failure for query: {query}")
@ -335,6 +361,7 @@ class OpenAIChatCompletionQA(QAModel):
try:
response = openai.ChatCompletion.create(
api_key=self.api_key,
messages=messages,
temperature=0,
top_p=1,
@ -370,7 +397,9 @@ class OpenAIChatCompletionQA(QAModel):
yield {"answer_finished": True}
continue
yield {"answer_data": event_text}
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Exception as e:
logger.exception(e)
model_output = "Model Failure"

View File

@ -1,4 +1,5 @@
import json
import os
from pathlib import Path
from typing import cast
@ -36,3 +37,11 @@ class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(self.dir_path / key) as f:
return cast(JSON_ro, json.load(f))
def delete(self, key: str) -> None:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
os.remove(file_path)

View File

@ -21,3 +21,7 @@ class DynamicConfigStore:
@abc.abstractmethod
def load(self, key: str) -> JSON_ro:
raise NotImplementedError
@abc.abstractmethod
def delete(self, key: str) -> None:
raise NotImplementedError

View File

@ -1,8 +1,10 @@
from typing import Any
from typing import cast
from danswer.auth.users import current_admin_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import NO_AUTH_USER
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.connectors.factory import build_connector
from danswer.connectors.google_drive.connector_auth import get_auth_url
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
@ -17,7 +19,13 @@ from danswer.db.index_attempt import insert_index_attempt
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.direct_qa.key_validation import (
check_openai_api_key_is_valid,
)
from danswer.direct_qa.question_answer import get_openai_api_key
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
from danswer.server.models import AuthStatus
from danswer.server.models import AuthUrl
from danswer.server.models import GDriveCallback
@ -27,6 +35,7 @@ from danswer.server.models import ListIndexAttemptsResponse
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
router = APIRouter(prefix="/admin")
@ -140,3 +149,59 @@ def list_all_index_attempts(
for index_attempt in index_attempts
]
)
@router.head("/openai-api-key/validate")
def validate_existing_openai_api_key(
_: User = Depends(current_admin_user),
) -> None:
is_valid = False
try:
openai_api_key = get_openai_api_key()
is_valid = check_openai_api_key_is_valid(openai_api_key)
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
if not is_valid:
raise HTTPException(status_code=400, detail="Invalid API key provided")
@router.get("/openai-api-key")
def get_openai_api_key_from_dynamic_config_store(
_: User = Depends(current_admin_user),
) -> ApiKey:
"""
NOTE: Only gets value from dynamic config store as to not expose env variables.
"""
try:
# only get last 4 characters of key to not expose full key
return ApiKey(
api_key=cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
)[-4:]
)
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
@router.post("/openai-api-key")
def store_openai_api_key(
request: ApiKey,
_: User = Depends(current_admin_user),
) -> None:
try:
is_valid = check_openai_api_key_is_valid(request.api_key)
if not is_valid:
raise HTTPException(400, "Invalid API key provided")
get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
except RuntimeError as e:
raise HTTPException(400, str(e))
@router.delete("/openai-api-key")
def delete_openai_api_key(
_: User = Depends(current_admin_user),
) -> None:
get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY)

View File

@ -73,3 +73,7 @@ class IndexAttemptSnapshot(BaseModel):
class ListIndexAttemptsResponse(BaseModel):
index_attempts: list[IndexAttemptSnapshot]
class ApiKey(BaseModel):
api_key: str

View File

@ -0,0 +1,76 @@
"use client";
import { LoadingAnimation } from "@/components/Loading";
import { KeyIcon, TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { OPENAI_API_KEY_URL } from "@/components/openai/constants";
import { fetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
const ExistingKeys = () => {
const { data, isLoading, error } = useSWR<{ api_key: string }>(
OPENAI_API_KEY_URL,
fetcher
);
if (isLoading) {
return <LoadingAnimation text="Loading" />;
}
if (error) {
return <div className="text-red-600">Error loading existing keys</div>;
}
if (!data?.api_key) {
return null;
}
return (
<div>
<h2 className="text-lg font-bold mb-2">Existing Key</h2>
<div className="flex mb-1">
<p className="text-sm italic my-auto">sk- ...{data?.api_key}</p>
<button
className="ml-1 my-auto hover:bg-gray-700 rounded-full p-1"
onClick={async () => {
await fetch(OPENAI_API_KEY_URL, {
method: "DELETE",
});
window.location.reload();
}}
>
<TrashIcon />
</button>
</div>
</div>
);
};
const Page = () => {
return (
<div>
<div className="border-solid border-gray-600 border-b pb-2 mb-4 flex">
<KeyIcon size="32" />
<h1 className="text-3xl font-bold pl-2">OpenAI Keys</h1>
</div>
<ExistingKeys />
<h2 className="text-lg font-bold mb-2">Update Key</h2>
<p className="text-sm mb-2">
Specify an OpenAI API key and click the &quot;Submit&quot; button.
</p>
<div className="border rounded-md border-gray-700 p-3">
<ApiKeyForm
handleResponse={(response) => {
if (response.ok) {
mutate(OPENAI_API_KEY_URL);
}
}}
/>
</div>
</div>
);
};
export default Page;

View File

@ -6,6 +6,7 @@ import {
GlobeIcon,
GoogleDriveIcon,
SlackIcon,
KeyIcon,
} from "@/components/icons/icons";
import { DISABLE_AUTH } from "@/lib/constants";
import { getCurrentUserSS } from "@/lib/userSS";
@ -89,6 +90,20 @@ export default async function AdminLayout({
},
],
},
{
name: "Keys",
items: [
{
name: (
<div className="flex">
<KeyIcon size="16" />
<div className="ml-1">OpenAI</div>
</div>
),
link: "/admin/keys/openai",
},
],
},
]}
/>
<div className="px-12 min-h-screen bg-gray-900 text-gray-100 w-full">

View File

@ -19,7 +19,7 @@ export default async function RootLayout({
}) {
return (
<html lang="en">
<body className={`${inter.variable} font-sans bg-gray-900`}>
<body className={`${inter.variable} font-sans bg-gray-900 text-gray-100`}>
{children}
</body>
</html>

View File

@ -4,6 +4,7 @@ import { getCurrentUserSS } from "@/lib/userSS";
import { redirect } from "next/navigation";
import { DISABLE_AUTH } from "@/lib/constants";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
export default async function Home() {
let user = null;
@ -13,12 +14,14 @@ export default async function Home() {
return redirect("/auth/login");
}
}
return (
<>
<Header user={user} />
<div className="m-3">
<HealthCheckBanner />
</div>
<ApiKeyModal />
<div className="px-24 pt-10 flex flex-col items-center min-h-screen bg-gray-900 text-gray-100">
<div className="max-w-[800px] w-full">
<SearchSection />

View File

@ -3,16 +3,21 @@ import { ErrorMessage, Field } from "formik";
interface TextFormFieldProps {
name: string;
label: string;
type?: string;
}
export const TextFormField = ({ name, label }: TextFormFieldProps) => {
export const TextFormField = ({
name,
label,
type = "text",
}: TextFormFieldProps) => {
return (
<div className="mb-4">
<label htmlFor={name} className="block mb-1">
{label}
</label>
<Field
type="text"
type={type}
name={name}
id={name}
className="border bg-slate-700 text-gray-200 border-gray-300 rounded w-full py-2 px-3"

View File

@ -7,6 +7,8 @@ import {
GithubLogo,
GoogleDriveLogo,
Notebook,
Key,
Trash,
} from "@phosphor-icons/react";
interface IconProps {
@ -23,6 +25,20 @@ export const NotebookIcon = ({
return <Notebook size={size} className={className} />;
};
export const KeyIcon = ({
size = "16",
className = defaultTailwindCSS,
}: IconProps) => {
return <Key size={size} className={className} />;
};
export const TrashIcon = ({
size = "16",
className = defaultTailwindCSS,
}: IconProps) => {
return <Trash size={size} className={className} />;
};
export const GlobeIcon = ({
size = "16",
className = defaultTailwindCSS,

View File

@ -0,0 +1,86 @@
import { Form, Formik } from "formik";
import { Popup } from "../admin/connectors/Popup";
import { useState } from "react";
import { TextFormField } from "../admin/connectors/Field";
import { OPENAI_API_KEY_URL } from "./constants";
import { LoadingAnimation } from "../Loading";
interface Props {
handleResponse?: (response: Response) => void;
}
export const ApiKeyForm = ({ handleResponse }: Props) => {
const [popup, setPopup] = useState<{
message: string;
type: "success" | "error";
} | null>(null);
return (
<div>
{popup && <Popup message={popup.message} type={popup.type} />}
<Formik
initialValues={{ apiKey: "" }}
onSubmit={async ({ apiKey }, formikHelpers) => {
const response = await fetch(OPENAI_API_KEY_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ api_key: apiKey }),
});
if (handleResponse) {
handleResponse(response);
}
if (response.ok) {
setPopup({
message: "Updated API key!",
type: "success",
});
formikHelpers.resetForm();
} else {
const body = await response.json();
if (body.detail) {
setPopup({ message: body.detail, type: "error" });
} else {
setPopup({
message:
"Unable to set API key. Check if the provided key is valid.",
type: "error",
});
}
setTimeout(() => {
setPopup(null);
}, 3000);
}
}}
>
{({ isSubmitting }) =>
isSubmitting ? (
<LoadingAnimation text="Validating API key" />
) : (
<Form>
<TextFormField
name="apiKey"
type="password"
label="OpenAI API Key:"
/>
<div className="flex">
<button
type="submit"
disabled={isSubmitting}
className={
"bg-slate-500 hover:bg-slate-700 text-white " +
"font-bold py-2 px-4 rounded focus:outline-none " +
"focus:shadow-outline w-full mx-auto"
}
>
Submit
</button>
</div>
</Form>
)
}
</Formik>
</div>
);
};

View File

@ -0,0 +1,55 @@
"use client";
import { useState, useEffect } from "react";
import { ApiKeyForm } from "./ApiKeyForm";
export const ApiKeyModal = () => {
const [isOpen, setIsOpen] = useState(false);
useEffect(() => {
fetch("/api/admin/openai-api-key/validate", {
method: "HEAD",
}).then((res) => {
// show popup if either the API key is not set or the API key is invalid
if (!res.ok && (res.status === 404 || res.status === 400)) {
setIsOpen(true);
}
});
}, []);
return (
<div>
{isOpen && (
<div
className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50"
onClick={() => setIsOpen(false)}
>
<div
className="bg-gray-800 p-6 rounded border border-gray-700 shadow-lg relative w-1/2 text-sm"
onClick={(event) => event.stopPropagation()}
>
<p className="mb-2.5 font-bold">
Can&apos;t find a valid registered OpenAI API key. Please provide
one to be able to ask questions! Or if you&apos;d rather just look
around for now,{" "}
<strong
onClick={() => setIsOpen(false)}
className="text-blue-300 cursor-pointer"
>
skip this step
</strong>
.
</p>
<ApiKeyForm
handleResponse={(response) => {
if (response.ok) {
setIsOpen(false);
}
}}
/>
</div>
</div>
)}
</div>
);
};

View File

@ -0,0 +1 @@
export const OPENAI_API_KEY_URL = "/api/admin/openai-api-key";