new model structure concept

This commit is contained in:
pablodanswer
2024-09-13 20:38:03 -07:00
parent 7b91beb3b2
commit 074342165b
8 changed files with 48 additions and 23 deletions

View File

@@ -181,8 +181,6 @@ def update_current_search_settings(
logger.warning("No current search settings found to update")
return
print("current settings", current_settings.__dict__)
print("search settings", search_settings.__dict__)
# Whenever we update the current search settings, we should ensure that the local reranking model is warmed up.
if (
search_settings.rerank_provider_type is None

View File

@@ -16,7 +16,6 @@ from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.enums import RerankerProvider
from shared_configs.utils import obfuscate_api_key
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
@@ -53,14 +52,8 @@ class InferenceSettings(RerankingDetails):
class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
@classmethod
def from_db_model(
cls, search_settings: SearchSettings
) -> "SearchSettingsCreationRequest":
inference_settings = InferenceSettings.from_db_model(search_settings)
indexing_setting = IndexingSetting.from_db_model(search_settings)
return cls(**inference_settings.dict(), **indexing_setting.dict())
api_key_set: bool
rerank_api_key_set: bool
class SavedSearchSettings(InferenceSettings, IndexingSetting):
@@ -88,13 +81,18 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
class SearchSettingsSnapshot(SavedSearchSettings):
rerank_api_key: None = None
api_key: None = None
rerank_api_key_set: bool
api_key_set: bool
@classmethod
def from_saved_settings(
cls, settings: SavedSearchSettings
) -> "SearchSettingsSnapshot":
data = settings.dict(exclude={"rerank_api_key", "api_key"})
data["rerank_api_key"] = obfuscate_api_key(settings.rerank_api_key)
data["api_key"] = obfuscate_api_key(settings.api_key)
data["rerank_api_key_set"] = bool(settings.rerank_api_key)
data["api_key_set"] = bool(settings.api_key)
return cls(**data)

View File

@@ -4,7 +4,6 @@ from pydantic import BaseModel
from pydantic import Field
from danswer.llm.llm_provider_options import fetch_models_for_provider
from shared_configs.utils import obfuscate_api_key
if TYPE_CHECKING:
from danswer.db.models import LLMProvider as LLMProviderModel
@@ -82,7 +81,7 @@ class FullLLMProvider(LLMProvider):
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
return cls(
api_key=obfuscate_api_key(llm_provider_model.api_key),
api_key=llm_provider_model.api_key,
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=llm_provider_model.provider,
@@ -104,10 +103,20 @@ class FullLLMProvider(LLMProvider):
class FullLLMProviderSnapshot(FullLLMProvider):
api_key: None = None
api_key_set: bool
@classmethod
def from_full_llm_provider(
cls, settings: FullLLMProvider
) -> "FullLLMProviderSnapshot":
data = settings.dict(exclude={"api_key"})
data["api_key"] = obfuscate_api_key(settings.api_key)
data["api_key_set"] = bool(settings.api_key)
return cls(**data)
@classmethod
def from_model(
cls, llm_provider_model: "LLMProviderModel"
) -> "FullLLMProviderSnapshot":
full_provider = FullLLMProvider.from_model(llm_provider_model)
return cls.from_full_llm_provider(full_provider)

View File

@@ -74,7 +74,9 @@ def set_new_search_settings(
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
new_search_settings_request = SavedSearchSettings(
**search_settings_new.dict().exclude("api_key_set", "rerank_api_key_set")
)
secondary_search_settings = get_secondary_search_settings(db_session)

View File

@@ -24,7 +24,7 @@ export interface EmbeddingDetails {
import { EmbeddingIcon } from "@/components/icons/icons";
import Link from "next/link";
import { SavedSearchSettings } from "../../embeddings/interfaces";
import { SearchSettingsSnapshot } from "../../embeddings/interfaces";
import UpgradingPage from "./UpgradingPage";
import { useContext } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
@@ -42,7 +42,7 @@ function Main() {
);
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
useSWR<SearchSettingsSnapshot | null>(
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds

View File

@@ -8,6 +8,10 @@ export interface RerankingDetails {
rerank_api_key: string | null;
rerank_api_url: string | null;
}
export interface RerankingDetailsSnapshot
extends Omit<RerankingDetails, "rerank_api_key"> {
rerank_api_key_set: boolean;
}
export enum RerankerProvider {
COHERE = "cohere",
@@ -34,6 +38,14 @@ export interface SavedSearchSettings
provider_type: EmbeddingProvider | null;
}
export interface SearchSettingsSnapshot
extends Omit<RerankingDetails, "rerank_api_key">,
Omit<AdvancedSearchConfiguration, "api_url"> {
provider_type: EmbeddingProvider | null;
rerank_api_key_set: boolean;
api_key_set: boolean;
}
export interface RerankingModel {
rerank_provider_type: RerankerProvider | null;
modelName?: string;

View File

@@ -20,6 +20,7 @@ import {
AdvancedSearchConfiguration,
RerankingDetails,
SavedSearchSettings,
SearchSettingsSnapshot,
} from "../interfaces";
import RerankingDetailsForm from "../RerankingFormPage";
import { useEmbeddingFormContext } from "@/components/context/EmbeddingContext";
@@ -98,7 +99,7 @@ export default function EmbeddingForm() {
>(currentEmbeddingModel!);
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
useSWR<SearchSettingsSnapshot | null>(
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
@@ -122,7 +123,7 @@ export default function EmbeddingForm() {
});
setRerankingDetails({
rerank_api_key: searchSettings.rerank_api_key,
rerank_api_key: null,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
@@ -132,7 +133,7 @@ export default function EmbeddingForm() {
const originalRerankingDetails: RerankingDetails = searchSettings
? {
rerank_api_key: searchSettings.rerank_api_key,
rerank_api_key: null,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
@@ -173,7 +174,7 @@ export default function EmbeddingForm() {
const response = await updateSearchSettings(values);
if (response.ok) {
setPopup({
message: "Updated search settings succesffuly",
message: "Updated search settings successfully",
type: "success",
});
mutate("/api/search-settings/get-current-search-settings");

View File

@@ -52,6 +52,11 @@ export interface EmbeddingModelDescriptor {
index_name: string | null;
}
export interface EmbeddingModelSnapshot
extends Omit<CloudEmbeddingModel, "api_key"> {
api_key_set: boolean;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
pricePerMillion: number;
enabled?: boolean;