Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-07 11:28:09 +02:00)

Commit a9834853ef: Merge branch 'main' into add-teams-connector
@@ -14,16 +14,16 @@ jobs:
      uses: actions/checkout@v2

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v1
      uses: docker/setup-buildx-action@v3

    - name: Login to Docker Hub
      uses: docker/login-action@v1
      uses: docker/login-action@v3
      with:
        username: ${{ secrets.DOCKER_USERNAME }}
        password: ${{ secrets.DOCKER_TOKEN }}

    - name: Backend Image Docker Build and Push
      uses: docker/build-push-action@v2
      uses: docker/build-push-action@v5
      with:
        context: ./backend
        file: ./backend/Dockerfile

@@ -38,5 +38,7 @@ jobs:
    - name: Run Trivy vulnerability scanner
      uses: aquasecurity/trivy-action@master
      with:
        # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
        image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
        severity: 'CRITICAL,HIGH'
        trivyignores: ./backend/.trivyignore
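The new `trivyignores` input points the CI scan at `backend/.trivyignore`. To reproduce the scan locally against the same ignore list, a sketch along these lines should work (the `--ignorefile` flag and the `latest` tag are assumptions, not taken from the workflow itself):

```bash
# Scan the published backend image locally, skipping the CVEs accepted in backend/.trivyignore
trivy image \
  --severity HIGH,CRITICAL \
  --ignorefile backend/.trivyignore \
  docker.io/danswer/danswer-backend:latest
```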
@@ -14,16 +14,16 @@ jobs:
      uses: actions/checkout@v2

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v1
      uses: docker/setup-buildx-action@v3

    - name: Login to Docker Hub
      uses: docker/login-action@v1
      uses: docker/login-action@v3
      with:
        username: ${{ secrets.DOCKER_USERNAME }}
        password: ${{ secrets.DOCKER_TOKEN }}

    - name: Model Server Image Docker Build and Push
      uses: docker/build-push-action@v2
      uses: docker/build-push-action@v5
      with:
        context: ./backend
        file: ./backend/Dockerfile.model_server
@@ -5,38 +5,114 @@ on:
    tags:
      - '*'

env:
  REGISTRY_IMAGE: danswer/danswer-web-server

jobs:
  build-and-push:
  build:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        platform:
          - linux/amd64
          - linux/arm64

    steps:
      - name: Checkout code
        uses: actions/checkout@v2
      - name: Prepare
        run: |
          platform=${{ matrix.platform }}
          echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

      - name: Checkout
        uses: actions/checkout@v4

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY_IMAGE }}
          tags: |
            type=raw,value=danswer/danswer-web-server:${{ github.ref_name }}
            type=raw,value=danswer/danswer-web-server:latest

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push by digest
        id: build
        uses: docker/build-push-action@v5
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: ${{ matrix.platform }}
          push: true
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}
          # needed due to weird interactions with the builds for different platforms
          no-cache: true
          labels: ${{ steps.meta.outputs.labels }}
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

      - name: Export digest
        run: |
          mkdir -p /tmp/digests
          digest="${{ steps.build.outputs.digest }}"
          touch "/tmp/digests/${digest#sha256:}"

      - name: Upload digest
        uses: actions/upload-artifact@v4
        with:
          name: digests-${{ env.PLATFORM_PAIR }}
          path: /tmp/digests/*
          if-no-files-found: error
          retention-days: 1

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
  merge:
    runs-on: ubuntu-latest
    needs:
      - build
    steps:
      - name: Download digests
        uses: actions/download-artifact@v4
        with:
          path: /tmp/digests
          pattern: digests-*
          merge-multiple: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY_IMAGE }}

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Create manifest list and push
        working-directory: /tmp/digests
        run: |
          docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
            $(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)

      - name: Inspect image
        run: |
          docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Web Image Docker Build and Push
        uses: docker/build-push-action@v2
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            danswer/danswer-web-server:${{ github.ref_name }}
            danswer/danswer-web-server:latest
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}

      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        with:
          image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        with:
          image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
          severity: 'CRITICAL,HIGH'
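The `merge` job combines the per-platform images, which were pushed by digest, into a single multi-arch manifest under the release tags. A minimal sketch of how the `jq`/`printf` command line in `Create manifest list and push` expands, using made-up tag and digest values:

```bash
# docker/metadata-action exposes its output as DOCKER_METADATA_OUTPUT_JSON; a trimmed example:
DOCKER_METADATA_OUTPUT_JSON='{"tags":["danswer/danswer-web-server:v0.3.1","danswer/danswer-web-server:latest"]}'

# The jq filter turns the tags array into repeated -t flags for imagetools:
jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON"
# -> -t danswer/danswer-web-server:v0.3.1 -t danswer/danswer-web-server:latest

# The printf over the digest files uploaded by the build job appends one image@digest
# reference per platform, so the assembled command looks like:
#   docker buildx imagetools create \
#     -t danswer/danswer-web-server:v0.3.1 -t danswer/danswer-web-server:latest \
#     danswer/danswer-web-server@sha256:<amd64-digest> danswer/danswer-web-server@sha256:<arm64-digest>
```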
.github/workflows/pr-python-checks.yml (2 changed lines)

@@ -20,10 +20,12 @@ jobs:
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt
      - run: |
          python -m pip install --upgrade pip
          pip install -r backend/requirements/default.txt
          pip install -r backend/requirements/dev.txt
          pip install -r backend/requirements/model_server.txt

      - name: Run MyPy
        run: |
.vscode/launch.template.jsonc (135 changed lines)

@@ -11,62 +11,6 @@
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "API Server",
            "type": "python",
            "request": "launch",
            "module": "uvicorn",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "DISABLE_AUTH": "True",
                "TYPESENSE_API_KEY": "typesense_api_key",
                "DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage"
            },
            "args": [
                "danswer.main:app",
                "--reload",
                "--port",
                "8080"
            ]
        },
        {
            "name": "Indexer",
            "type": "python",
            "request": "launch",
            "program": "danswer/background/update.py",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "PYTHONPATH": ".",
                "TYPESENSE_API_KEY": "typesense_api_key",
                "DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage"
            }
        },
        {
            "name": "Temp File Deletion",
            "type": "python",
            "request": "launch",
            "program": "danswer/background/file_deletion.py",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "PYTHONPATH": "${workspaceFolder}/backend"
            }
        },
        // For the listener to access the Slack API,
        // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in the .env file located in the root of the project
        {
            "name": "Slack Bot Listener",
            "type": "python",
            "request": "launch",
            "program": "danswer/listeners/slack_listener.py",
            "cwd": "${workspaceFolder}/backend",
            "envFile": "${workspaceFolder}/.env",
            "env": {
                "LOG_LEVEL": "DEBUG"
            }
        },
        {
            "name": "Web Server",
            "type": "node",
@@ -77,6 +21,85 @@
                "run", "dev"
            ],
            "console": "integratedTerminal"
        },
        {
            "name": "Model Server",
            "type": "python",
            "request": "launch",
            "module": "uvicorn",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1"
            },
            "args": [
                "model_server.main:app",
                "--reload",
                "--port",
                "9000"
            ]
        },
        {
            "name": "API Server",
            "type": "python",
            "request": "launch",
            "module": "uvicorn",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "LOG_ALL_MODEL_INTERACTIONS": "True",
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1"
            },
            "args": [
                "danswer.main:app",
                "--reload",
                "--port",
                "8080"
            ]
        },
        {
            "name": "Indexing",
            "type": "python",
            "request": "launch",
            "program": "danswer/background/update.py",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "ENABLE_MINI_CHUNK": "false",
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
            }
        },
        // Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
        {
            "name": "Background Jobs",
            "type": "python",
            "request": "launch",
            "program": "scripts/dev_run_background_jobs.py",
            "cwd": "${workspaceFolder}/backend",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
            },
            "args": [
                "--no-indexing"
            ]
        },
        // For the listener to access the Slack API,
        // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in the .env file located in the root of the project
        {
            "name": "Slack Bot",
            "type": "python",
            "request": "launch",
            "program": "danswer/danswerbot/slack/listener.py",
            "cwd": "${workspaceFolder}/backend",
            "envFile": "${workspaceFolder}/.env",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
            }
        }
    ]
}
@@ -72,15 +72,20 @@ For convenience here's a command for it:
python -m venv .venv
source .venv/bin/activate
```
_For Windows activate via:_
_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate
```
If using PowerShell, the command slightly differs:
```powershell
.venv\Scripts\Activate.ps1
```

Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/model_server.txt
```

Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.

@@ -108,26 +113,24 @@ docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational
(index refers to Vespa and relational_db refers to Postgres)

#### Running Danswer

Setup a folder to store config. Navigate to `danswer/backend` and run:
```bash
mkdir dynamic_config_storage
```

To start the frontend, navigate to `danswer/web` and run:
```bash
npm run dev
```

Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally.

Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
Next, start the model server which runs the local NLP models.
Navigate to `danswer/backend` and run:
```bash
zip -r ../vespa-app.zip .
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
    uvicorn model_server.main:app --reload --port 9000
    "
```
- Note: If you don't have the `zip` utility, you will need to install it prior to running the above

The first time running Danswer, you will also need to run the DB migrations for Postgres.
The first time running Danswer, you will need to run the DB migrations for Postgres.
After the first time, this is no longer required unless the DB models change.

Navigate to `danswer/backend` and with the venv active, run:
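The migrations are managed with Alembic (the `alembic/` directory and `alembic.ini` live in `danswer/backend`). A sketch of the usual upgrade invocation, assuming that default configuration:

```bash
# Apply all pending Alembic migrations to the local Postgres instance
alembic upgrade head
```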
@@ -145,17 +148,12 @@ python ./scripts/dev_run_background_jobs.py

To run the backend API server, navigate back to `danswer/backend` and run:
```bash
AUTH_TYPE=disabled \
DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \
VESPA_DEPLOYMENT_ZIP=./danswer/document_index/vespa/vespa-app.zip \
uvicorn danswer.main:app --reload --port 8080
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
    $env:AUTH_TYPE='disabled'
    $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'
    $env:VESPA_DEPLOYMENT_ZIP='./danswer/document_index/vespa/vespa-app.zip'
    uvicorn danswer.main:app --reload --port 8080
    "
```

@@ -174,20 +172,16 @@ pre-commit install

Additionally, we use `mypy` for static type checking.
Danswer is fully type-annotated, and we would like to keep it that way!
Right now, there is no automated type checking at the moment (coming soon), but we ask you to manually run it before
creating a pull requests with `python -m mypy .` from the `danswer/backend` directory.
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.


#### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory.
To run the formatter, use `npx prettier --write .` from the `danswer/web` directory.
Like `mypy`, we have no automated formatting yet (coming soon), but we request that, for now,
you run this manually before creating a pull request.
Please double check that prettier passes before creating a pull request.


### Release Process
Danswer follows the semver versioning standard.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).

As pre-1.0 software, even patch releases may contain breaking or non-backwards-compatible changes.
README.md (43 changed lines)

@@ -5,7 +5,7 @@
</h2>

<p align="center">
<p align="center">Open Source Unified Search and Gen-AI Chat with your Docs.</p>
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>

<p align="center">
    <a href="https://docs.danswer.dev/" target="_blank">
@@ -22,16 +22,17 @@
    </a>
</p>

<strong>[Danswer](https://www.danswer.ai/)</strong> lets you ask questions in natural language questions and get back
answers based on your team specific documents. Think ChatGPT if it had access to your team's unique
knowledge. Connects to all common workplace tools such as Slack, Google Drive, Confluence, etc.
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and at any
scale - on a laptop, on-premise, or in the cloud. Since you own the deployment, your user data and chats are fully in your
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring Personas (AI Assistants) and their Prompts.

Teams have used Danswer to:
- Speed up customer support and escalation turnaround time.
- Improve Engineering efficiency by making documentation and code changelogs easy to find.
- Give the sales team fuller context, faster, when preparing for calls.
- Track customer requests and priorities for Product teams.
- Help teams self-serve IT, Onboarding, HR, etc.
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team-specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"

<h3>Usage</h3>

@@ -57,19 +58,27 @@ We also have built-in support for deployment on Kubernetes. Files for that can b


## 💃 Main Features
* Chat UI with the ability to select documents to chat with.
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
* Document Search + AI Answers for natural language queries.
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
* Chat support (think ChatGPT but it has access to your private knowledge sources).
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Slack integration to get answers and search results directly in Slack.


## 🚧 Roadmap
* Chat/Prompt sharing with specific teammates and user groups.
* Multi-modal model support: chat with images, video, etc.
* Choosing between LLMs and parameters during chat sessions.
* Tool calling and agent configuration options.
* Organizational understanding and the ability to locate and suggest experts from your team.


## Other Notable Benefits of Danswer
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.
* Custom deep learning models + learning from user feedback.
* Connect Danswer with the LLM of your choice for a fully airgapped solution.
* Easy deployment and the ability to host Danswer anywhere of your choosing.


@@ -96,11 +105,5 @@ Efficiently pulls the latest changes from:
* Websites
* And more ...

## 🚧 Roadmap
* Organizational understanding.
* Ability to locate and suggest experts from your team.
* Code Search
* Structured Query Languages (SQL, Excel formulas, etc.)

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
backend/.gitignore (2 changed lines)

@@ -8,4 +8,4 @@ api_keys.py
.env
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule
celerybeat-schedule*
backend/.trivyignore (new file, 46 lines)

@@ -0,0 +1,46 @@
# https://github.com/madler/zlib/issues/868
# Pulled in with base Debian image, it's part of the contrib folder but unused
# zlib1g is fine
# Will be gone with Debian image upgrade
# No impact in our settings
CVE-2023-45853

# krb5 related, worst case is denial of service by resource exhaustion
# Accept the risk
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462

# Specific to Firefox which we do not use
# No impact in our settings
CVE-2024-0743

# bind9 related, worst case is denial of service by CPU resource exhaustion
# Accept the risk
CVE-2023-50387
CVE-2023-50868
CVE-2023-50387
CVE-2023-50868

# libexpat1, XML parsing resource exhaustion
# We don't parse any user provided XMLs
# No impact in our settings
CVE-2023-52425
CVE-2024-28757

# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104

# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193
@@ -1,5 +1,10 @@
FROM python:3.11.7-slim-bookworm

LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
more details, visit https://github.com/danswer-ai/danswer."

# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
@@ -12,7 +17,9 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# zip for Vespa step further down
# ca-certificates for HTTPS
RUN apt-get update && \
    apt-get install -y cmake curl zip ca-certificates && \
    apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
    libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
    libuuid1=2.38.1-5+deb12u1 && \
    rm -rf /var/lib/apt/lists/* && \
    apt-get clean

@@ -29,15 +36,25 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
    libldap-2.5-0 libldap-2.5-0 && \
    apt-get autoremove -y && \
    rm -rf /var/lib/apt/lists/* && \
    rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key

# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"

# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
    nltk.download('stopwords', quiet=True); \
    nltk.download('wordnet', quiet=True); \
    nltk.download('punkt', quiet=True);"

# Set up application files
WORKDIR /app
COPY ./danswer /app/danswer
COPY ./shared_models /app/shared_models
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
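The pre-download steps above bake the Hugging Face model and NLTK data into the image so it can run in egress-restricted environments. One way to sanity-check that is to load the tokenizer with networking disabled; the image tag and the `TRANSFORMERS_OFFLINE` variable below are illustrative assumptions, not part of the Dockerfile:

```bash
# Verify the baked-in tokenizer loads without any network access (illustrative tag)
docker run --rm --network none -e TRANSFORMERS_OFFLINE=1 \
  danswer/danswer-backend:latest \
  python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
```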
@@ -1,5 +1,11 @@
FROM python:3.11.7-slim-bookworm

LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is for the Danswer model server which runs all of the \
AI models for Danswer. This container and all the code is MIT Licensed and free for all to use. \
You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For more details, \
visit https://github.com/danswer-ai/danswer."

# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
@@ -11,25 +17,26 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
    apt-get autoremove -y

WORKDIR /app
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
    from huggingface_hub import snapshot_download; \
    AutoTokenizer.from_pretrained('danswer/intent-model'); \
    AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
    AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
    snapshot_download('danswer/intent-model'); \
    snapshot_download('intfloat/e5-base-v2'); \
    snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"

# Needed for model configs and defaults
COPY ./danswer/configs /app/danswer/configs
COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs
WORKDIR /app

# Utils used by model server
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py

# Place to fetch version information
COPY ./danswer/__init__.py /app/danswer/__init__.py

# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py

# Request/Response models
COPY ./shared_models /app/shared_models
# Shared between Danswer Backend and Model Server
COPY ./shared_configs /app/shared_configs

# Model Server main code
COPY ./model_server /app/model_server
backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py (new file, 31 lines)

@@ -0,0 +1,31 @@
"""Add starter prompts

Revision ID: 0a2b51deb0b8
Revises: 5f4b8568a221
Create Date: 2024-03-02 23:23:49.960309

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "0a2b51deb0b8"
down_revision = "5f4b8568a221"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "persona",
        sa.Column(
            "starter_messages",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("persona", "starter_messages")
backend/alembic/versions/0a98909f2757_enable_encrypted_fields.py (new file, 113 lines)

@@ -0,0 +1,113 @@
"""Enable Encrypted Fields

Revision ID: 0a98909f2757
Revises: 570282d33c49
Create Date: 2024-05-05 19:30:34.317972

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.dialects import postgresql
import json

from danswer.utils.encryption import encrypt_string_to_bytes

# revision identifiers, used by Alembic.
revision = "0a98909f2757"
down_revision = "570282d33c49"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    connection = op.get_bind()

    op.alter_column("key_value_store", "value", nullable=True)
    op.add_column(
        "key_value_store",
        sa.Column(
            "encrypted_value",
            sa.LargeBinary,
            nullable=True,
        ),
    )

    # Need a temporary column to translate the JSONB to binary
    op.add_column("credential", sa.Column("temp_column", sa.LargeBinary()))

    creds_table = table(
        "credential",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "credential_json",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=False,
        ),
        sa.Column(
            "temp_column",
            sa.LargeBinary(),
            nullable=False,
        ),
    )

    results = connection.execute(sa.select(creds_table))

    # This uses the MIT encrypt which does not actually encrypt the credentials
    # In other words, this upgrade does not apply the encryption. Porting existing sensitive data
    # and key rotation currently is not supported and will come out in the future
    for row_id, creds, _ in results:
        creds_binary = encrypt_string_to_bytes(json.dumps(creds))
        connection.execute(
            creds_table.update()
            .where(creds_table.c.id == row_id)
            .values(temp_column=creds_binary)
        )

    op.drop_column("credential", "credential_json")
    op.alter_column("credential", "temp_column", new_column_name="credential_json")

    op.add_column("llm_provider", sa.Column("temp_column", sa.LargeBinary()))

    llm_table = table(
        "llm_provider",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "api_key",
            sa.String(),
            nullable=False,
        ),
        sa.Column(
            "temp_column",
            sa.LargeBinary(),
            nullable=False,
        ),
    )
    results = connection.execute(sa.select(llm_table))

    for row_id, api_key, _ in results:
        llm_key = encrypt_string_to_bytes(api_key)
        connection.execute(
            llm_table.update()
            .where(llm_table.c.id == row_id)
            .values(temp_column=llm_key)
        )

    op.drop_column("llm_provider", "api_key")
    op.alter_column("llm_provider", "temp_column", new_column_name="api_key")


def downgrade() -> None:
    # Some information loss but this is ok. Should not allow decryption via downgrade.
    op.drop_column("credential", "credential_json")
    op.drop_column("llm_provider", "api_key")

    op.add_column("llm_provider", sa.Column("api_key", sa.String()))
    op.add_column(
        "credential",
        sa.Column("credential_json", postgresql.JSONB(astext_type=sa.Text())),
    )

    op.execute("DELETE FROM key_value_store WHERE value IS NULL")
    op.alter_column("key_value_store", "value", nullable=False)
    op.drop_column("key_value_store", "encrypted_value")
@@ -13,8 +13,8 @@ from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "15326fcec57e"
down_revision = "77d07dffae64"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
backend/alembic/versions/173cae5bba26_port_config_store.py (new file, 29 lines)

@@ -0,0 +1,29 @@
"""Port Config Store

Revision ID: 173cae5bba26
Revises: e50154680a5c
Create Date: 2024-03-19 15:30:44.425436

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "173cae5bba26"
down_revision = "e50154680a5c"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "key_value_store",
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.PrimaryKeyConstraint("key"),
    )


def downgrade() -> None:
    op.drop_table("key_value_store")
@@ -13,8 +13,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "2666d766cb9b"
down_revision = "6d387b3196c2"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
@@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "27c6ecc08586"
down_revision = "2666d766cb9b"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
@@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "30c1d5744104"
down_revision = "7f99be1cb9f5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
45
backend/alembic/versions/3879338f8ba1_add_tool_table.py
Normal file
45
backend/alembic/versions/3879338f8ba1_add_tool_table.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Add tool table
|
||||
|
||||
Revision ID: 3879338f8ba1
|
||||
Revises: f1c6478c3fd8
|
||||
Create Date: 2024-05-11 16:11:23.718084
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3879338f8ba1"
|
||||
down_revision = "f1c6478c3fd8"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tool",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("in_code_tool_id", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"persona__tool",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("tool_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["tool_id"],
|
||||
["tool.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "tool_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("persona__tool")
|
||||
op.drop_table("tool")
|
@ -0,0 +1,41 @@
|
||||
"""Add chat session sharing
|
||||
|
||||
Revision ID: 38eda64af7fe
|
||||
Revises: 776b3bbe9092
|
||||
Create Date: 2024-03-27 19:41:29.073594
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "38eda64af7fe"
|
||||
down_revision = "776b3bbe9092"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column(
|
||||
"shared_status",
|
||||
sa.Enum(
|
||||
"PUBLIC",
|
||||
"PRIVATE",
|
||||
name="chatsessionsharedstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE chat_session SET shared_status='PRIVATE'")
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"shared_status",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "shared_status")
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b25685ff73c"
|
||||
down_revision = "e0a68a81d434"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from alembic import op
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3c5e35aa9af0"
|
||||
down_revision = "27c6ecc08586"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,49 @@
|
||||
"""Add tables for UI-based LLM configuration
|
||||
|
||||
Revision ID: 401c1ac29467
|
||||
Revises: 703313b75876
|
||||
Create Date: 2024-04-13 18:07:29.153817
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "401c1ac29467"
|
||||
down_revision = "703313b75876"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"llm_provider",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("api_version", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"custom_config",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("default_model_name", sa.String(), nullable=False),
|
||||
sa.Column("fast_default_model_name", sa.String(), nullable=True),
|
||||
sa.Column("is_default_provider", sa.Boolean(), unique=True, nullable=True),
|
||||
sa.Column("model_names", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_model_provider_override", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "llm_model_provider_override")
|
||||
|
||||
op.drop_table("llm_provider")
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "465f78d9b7f9"
|
||||
down_revision = "3c5e35aa9af0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ from sqlalchemy import String
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "46625e4745d4"
|
||||
down_revision = "9d97fecfab7f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
28
backend/alembic/versions/4738e4b3bae1_pg_file_store.py
Normal file
28
backend/alembic/versions/4738e4b3bae1_pg_file_store.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""PG File Store
|
||||
|
||||
Revision ID: 4738e4b3bae1
|
||||
Revises: e91df4e935ef
|
||||
Create Date: 2024-03-20 18:53:32.461518
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4738e4b3bae1"
|
||||
down_revision = "e91df4e935ef"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"file_store",
|
||||
sa.Column("file_name", sa.String(), nullable=False),
|
||||
sa.Column("lobj_oid", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("file_name"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("file_store")
|
@ -11,9 +11,9 @@ from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "47433d30de82"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
down_revision: None = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
23
backend/alembic/versions/475fcefe8826_add_name_to_api_key.py
Normal file
23
backend/alembic/versions/475fcefe8826_add_name_to_api_key.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Add name to api_key
|
||||
|
||||
Revision ID: 475fcefe8826
|
||||
Revises: ecab2b3f1a3b
|
||||
Create Date: 2024-04-11 11:05:18.414438
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "475fcefe8826"
|
||||
down_revision = "ecab2b3f1a3b"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("api_key", sa.Column("name", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("api_key", "name")
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "50b683a8295c"
|
||||
down_revision = "7da0ae5ad583"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,27 @@
|
||||
"""Track Danswerbot Explicitly
|
||||
|
||||
Revision ID: 570282d33c49
|
||||
Revises: 7547d982db8f
|
||||
Create Date: 2024-05-04 17:49:28.568109
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "570282d33c49"
|
||||
down_revision = "7547d982db8f"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("danswerbot_flow", sa.Boolean(), nullable=True)
|
||||
)
|
||||
op.execute("UPDATE chat_session SET danswerbot_flow = one_shot")
|
||||
op.alter_column("chat_session", "danswerbot_flow", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "danswerbot_flow")
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "57b53544726e"
|
||||
down_revision = "800f48024ae9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -13,8 +13,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5809c0787398"
|
||||
down_revision = "d929f0c1c6af"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5e84129c8be3"
|
||||
down_revision = "e6a4bbc13fe4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,27 @@
|
||||
"""add removed documents to index_attempt
|
||||
|
||||
Revision ID: 5f4b8568a221
|
||||
Revises: dbaa756c2ccf
|
||||
Create Date: 2024-02-16 15:02:03.319907
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5f4b8568a221"
|
||||
down_revision = "8987770549c0"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("docs_removed_from_index", sa.Integer()),
|
||||
)
|
||||
op.execute("UPDATE index_attempt SET docs_removed_from_index = 0")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt", "docs_removed_from_index")
|
@ -0,0 +1,45 @@
|
||||
"""Add user-configured names to LLMProvider
|
||||
|
||||
Revision ID: 643a84a42a33
|
||||
Revises: 0a98909f2757
|
||||
Create Date: 2024-05-07 14:54:55.493100
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "643a84a42a33"
|
||||
down_revision = "0a98909f2757"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("llm_provider", sa.Column("provider", sa.String(), nullable=True))
|
||||
# move "name" -> "provider" to match the new schema
|
||||
op.execute("UPDATE llm_provider SET provider = name")
|
||||
# pretty up display name
|
||||
op.execute("UPDATE llm_provider SET name = 'OpenAI' WHERE name = 'openai'")
|
||||
op.execute("UPDATE llm_provider SET name = 'Anthropic' WHERE name = 'anthropic'")
|
||||
op.execute("UPDATE llm_provider SET name = 'Azure OpenAI' WHERE name = 'azure'")
|
||||
op.execute("UPDATE llm_provider SET name = 'AWS Bedrock' WHERE name = 'bedrock'")
|
||||
|
||||
# update personas to use the new provider names
|
||||
op.execute(
|
||||
"UPDATE persona SET llm_model_provider_override = 'OpenAI' WHERE llm_model_provider_override = 'openai'"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE persona SET llm_model_provider_override = 'Anthropic' WHERE llm_model_provider_override = 'anthropic'"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE persona SET llm_model_provider_override = 'Azure OpenAI' WHERE llm_model_provider_override = 'azure'"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE persona SET llm_model_provider_override = 'AWS Bedrock' WHERE llm_model_provider_override = 'bedrock'"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE llm_provider SET name = provider")
|
||||
op.drop_column("llm_provider", "provider")
|
@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6d387b3196c2"
|
||||
down_revision = "47433d30de82"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,83 @@
|
||||
"""Add TokenRateLimit Tables
|
||||
|
||||
Revision ID: 703313b75876
|
||||
Revises: fad14119fb92
|
||||
Create Date: 2024-04-15 01:36:02.952809
|
||||
|
||||
"""
|
||||
import json
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "703313b75876"
|
||||
down_revision = "fad14119fb92"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"token_rate_limit",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False),
|
||||
sa.Column("token_budget", sa.Integer(), nullable=False),
|
||||
sa.Column("period_hours", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"scope",
|
||||
sa.String(length=10),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"token_rate_limit__user_group",
|
||||
sa.Column("rate_limit_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["rate_limit_id"],
|
||||
["token_rate_limit.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
settings_json = cast(
|
||||
str, get_dynamic_config_store().load("token_budget_settings")
|
||||
)
|
||||
settings = json.loads(settings_json)
|
||||
|
||||
is_enabled = settings.get("enable_token_budget", False)
|
||||
token_budget = settings.get("token_budget", -1)
|
||||
period_hours = settings.get("period_hours", -1)
|
||||
|
||||
if is_enabled and token_budget > 0 and period_hours > 0:
|
||||
op.execute(
|
||||
f"INSERT INTO token_rate_limit \
|
||||
(enabled, token_budget, period_hours, scope) VALUES \
|
||||
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
|
||||
)
|
||||
|
||||
# Delete the dynamic config
|
||||
get_dynamic_config_store().delete("token_budget_settings")
|
||||
|
||||
except Exception:
|
||||
# Ignore if the dynamic config is not found
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("token_rate_limit__user_group")
|
||||
op.drop_table("token_rate_limit")
|
@ -0,0 +1,68 @@
|
||||
"""More Descriptive Filestore
|
||||
|
||||
Revision ID: 70f00c45c0f2
|
||||
Revises: 3879338f8ba1
|
||||
Create Date: 2024-05-17 17:51:41.926893
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "70f00c45c0f2"
|
||||
down_revision = "3879338f8ba1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True))
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_origin",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="connector", # Default to connector
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_type", sa.String(), nullable=False, server_default="text/plain"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_store",
|
||||
sa.Column(
|
||||
"file_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE file_store
|
||||
SET file_origin = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN 'chat_upload'
|
||||
ELSE 'connector'
|
||||
END,
|
||||
file_name = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7)
|
||||
ELSE file_name
|
||||
END,
|
||||
file_type = CASE
|
||||
WHEN file_name LIKE 'chat__%' THEN 'image/png'
|
||||
ELSE 'text/plain'
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("file_store", "file_metadata")
|
||||
op.drop_column("file_store", "file_type")
|
||||
op.drop_column("file_store", "file_origin")
|
||||
op.drop_column("file_store", "display_name")
|
@ -0,0 +1,81 @@
|
||||
"""Permission Auto Sync Framework
|
||||
|
||||
Revision ID: 72bdc9929a46
|
||||
Revises: 475fcefe8826
|
||||
Create Date: 2024-04-14 21:15:28.659634
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "72bdc9929a46"
|
||||
down_revision = "475fcefe8826"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"email_to_external_user_cache",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_user_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"external_permission",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("external_permission_group", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"permission_sync_run",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("update_type", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("permission_sync_run")
|
||||
op.drop_table("external_permission")
|
||||
op.drop_table("email_to_external_user_cache")
|
51
backend/alembic/versions/7547d982db8f_chat_folders.py
Normal file
51
backend/alembic/versions/7547d982db8f_chat_folders.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Chat Folders
|
||||
|
||||
Revision ID: 7547d982db8f
|
||||
Revises: ef7da92f7213
|
||||
Create Date: 2024-05-02 15:18:56.573347
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7547d982db8f"
|
||||
down_revision = "ef7da92f7213"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chat_folder",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.add_column("chat_session", sa.Column("folder_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"chat_session_chat_folder_fk",
|
||||
"chat_session",
|
||||
"chat_folder",
|
||||
["folder_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("chat_session", "folder_id")
|
||||
op.drop_table("chat_folder")
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "767f1c2a00eb"
|
||||
down_revision = "dba7f71618f5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "76b60d407dfb"
|
||||
down_revision = "b156fa702355"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,71 @@
"""Remove Remaining Enums

Revision ID: 776b3bbe9092
Revises: 4738e4b3bae1
Create Date: 2024-03-22 21:34:27.629444

"""
from alembic import op
import sqlalchemy as sa

from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import SearchType

# revision identifiers, used by Alembic.
revision = "776b3bbe9092"
down_revision = "4738e4b3bae1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.alter_column(
        "persona",
        "search_type",
        type_=sa.String,
        existing_type=sa.Enum(SearchType, native_enum=False),
        existing_nullable=False,
    )
    op.alter_column(
        "persona",
        "recency_bias",
        type_=sa.String,
        existing_type=sa.Enum(RecencyBiasSetting, native_enum=False),
        existing_nullable=False,
    )

    # Because the indexmodelstatus enum does not have a mapping to a string type
    # we need this workaround instead of directly changing the type
    op.add_column("embedding_model", sa.Column("temp_status", sa.String))
    op.execute("UPDATE embedding_model SET temp_status = status::text")
    op.drop_column("embedding_model", "status")
    op.alter_column("embedding_model", "temp_status", new_column_name="status")

    op.execute("DROP TYPE IF EXISTS searchtype")
    op.execute("DROP TYPE IF EXISTS recencybiassetting")
    op.execute("DROP TYPE IF EXISTS indexmodelstatus")


def downgrade() -> None:
    op.alter_column(
        "persona",
        "search_type",
        type_=sa.Enum(SearchType, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
    op.alter_column(
        "persona",
        "recency_bias",
        type_=sa.Enum(RecencyBiasSetting, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
    op.alter_column(
        "embedding_model",
        "status",
        type_=sa.Enum(IndexModelStatus, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
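The temp-column shuffle above works around the missing cast from the native enum to text. A more direct alternative, shown only as a sketch and not what this migration does, would be to let Postgres cast the column in place with USING via Alembic's postgresql_using option.

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # hypothetical one-step variant: cast the enum column to text in place
    op.alter_column(
        "embedding_model",
        "status",
        type_=sa.String(),
        postgresql_using="status::text",
        existing_nullable=False,
    )
    op.execute("DROP TYPE IF EXISTS indexmodelstatus")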
@ -12,8 +12,8 @@ from sqlalchemy import String
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "77d07dffae64"
|
||||
down_revision = "d61e513bef0a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "78dbe7e38469"
|
||||
down_revision = "7ccea01261f6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "79acd316403a"
|
||||
down_revision = "904e5138fffb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7ccea01261f6"
|
||||
down_revision = "a570b80a5f20"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7da0ae5ad583"
|
||||
down_revision = "e86866a9c78a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7da543f5672f"
|
||||
down_revision = "febe9eaa0644"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7f726bad5367"
|
||||
down_revision = "79acd316403a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ from alembic import op
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7f99be1cb9f5"
|
||||
down_revision = "78dbe7e38469"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.schema import Sequence, CreateSequence
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "800f48024ae9"
|
||||
down_revision = "767f1c2a00eb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "80696cf850ae"
|
||||
down_revision = "15326fcec57e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "891cd83c87a8"
|
||||
down_revision = "76b60d407dfb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8987770549c0"
|
||||
down_revision = "ec3ec2eabf7b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8aabb57f3b49"
|
||||
down_revision = "5e84129c8be3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8e26726b7683"
|
||||
down_revision = "5809c0787398"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "904451035c9b"
|
||||
down_revision = "3b25685ff73c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "904e5138fffb"
|
||||
down_revision = "891cd83c87a8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,36 @@
"""Remove DocumentSource from Tag

Revision ID: 91fd3b470d1a
Revises: 173cae5bba26
Create Date: 2024-03-21 12:05:23.956734

"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource

# revision identifiers, used by Alembic.
revision = "91fd3b470d1a"
down_revision = "173cae5bba26"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.alter_column(
        "tag",
        "source",
        type_=sa.String(length=50),
        existing_type=sa.Enum(DocumentSource, native_enum=False),
        existing_nullable=False,
    )


def downgrade() -> None:
    op.alter_column(
        "tag",
        "source",
        type_=sa.Enum(DocumentSource, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9d97fecfab7f"
|
||||
down_revision = "ffc707a226b4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,27 @@
"""Add chosen_assistants to User table

Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("user", "chosen_assistants")
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a570b80a5f20"
|
||||
down_revision = "904451035c9b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ae62505e3acc"
|
||||
down_revision = "7da543f5672f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b082fec533f0"
|
||||
down_revision = "df0c7ad8a076"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -15,8 +15,8 @@ from danswer.configs.constants import DocumentSource
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b156fa702355"
|
||||
down_revision = "baf71f781b9e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
searchtype_enum = ENUM(
|
||||
|
@ -0,0 +1,28 @@
"""fix-file-type-migration

Revision ID: b85f02ec1308
Revises: a3bfd0d64902
Create Date: 2024-05-31 18:09:26.658164

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "b85f02ec1308"
down_revision = "a3bfd0d64902"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE file_store
        SET file_origin = UPPER(file_origin)
        """
    )


def downgrade() -> None:
    # Let's not break anything on purpose :)
    pass
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "baf71f781b9e"
|
||||
down_revision = "50b683a8295c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d5645c915d0e"
|
||||
down_revision = "8e26726b7683"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d61e513bef0a"
|
||||
down_revision = "46625e4745d4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d7111c1238cd"
|
||||
down_revision = "465f78d9b7f9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -13,8 +13,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d929f0c1c6af"
|
||||
down_revision = "8aabb57f3b49"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dba7f71618f5"
|
||||
down_revision = "d5645c915d0e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -9,18 +9,18 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
||||
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
from danswer.db.embedding_model import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
)
|
||||
from danswer.db.models import IndexModelStatus
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dbaa756c2ccf"
|
||||
down_revision = "7f726bad5367"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
@ -40,6 +40,9 @@ def upgrade() -> None:
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# since all index attempts must be associated with an embedding model,
|
||||
# need to put something in here to avoid nulls. On server startup,
|
||||
# this value will be overriden
|
||||
EmbeddingModel = table(
|
||||
"embedding_model",
|
||||
column("id", Integer),
|
||||
@ -53,20 +56,44 @@ def upgrade() -> None:
|
||||
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
|
||||
),
|
||||
)
|
||||
# insert an embedding model row that corresponds to the embedding model
|
||||
# the user selected via env variables before this change. This is needed since
|
||||
# all index_attempts must be associated with an embedding model, so without this
|
||||
# we will run into violations of non-null contraints
|
||||
old_embedding_model = get_old_default_embedding_model()
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
{
|
||||
"model_name": DOCUMENT_ENCODER_MODEL,
|
||||
"model_dim": DOC_EMBEDDING_DIM,
|
||||
"normalize": NORMALIZE_EMBEDDINGS,
|
||||
"query_prefix": ASYM_QUERY_PREFIX,
|
||||
"passage_prefix": ASYM_PASSAGE_PREFIX,
|
||||
"index_name": "danswer_chunk",
|
||||
"status": IndexModelStatus.PRESENT,
|
||||
"model_name": old_embedding_model.model_name,
|
||||
"model_dim": old_embedding_model.model_dim,
|
||||
"normalize": old_embedding_model.normalize,
|
||||
"query_prefix": old_embedding_model.query_prefix,
|
||||
"passage_prefix": old_embedding_model.passage_prefix,
|
||||
"index_name": old_embedding_model.index_name,
|
||||
"status": old_embedding_model.status,
|
||||
}
|
||||
],
|
||||
)
|
||||
# if the user has not overridden the default embedding model via env variables,
|
||||
# insert the new default model into the database to auto-upgrade them
|
||||
if not user_has_overridden_embedding_model():
|
||||
new_embedding_model = get_new_default_embedding_model(is_present=False)
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
{
|
||||
"model_name": new_embedding_model.model_name,
|
||||
"model_dim": new_embedding_model.model_dim,
|
||||
"normalize": new_embedding_model.normalize,
|
||||
"query_prefix": new_embedding_model.query_prefix,
|
||||
"passage_prefix": new_embedding_model.passage_prefix,
|
||||
"index_name": new_embedding_model.index_name,
|
||||
"status": IndexModelStatus.FUTURE,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("embedding_model_id", sa.Integer(), nullable=True),
|
||||
|
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "df0c7ad8a076"
|
||||
down_revision = "d7111c1238cd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e0a68a81d434"
|
||||
down_revision = "ae62505e3acc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
38
backend/alembic/versions/e50154680a5c_no_source_enum.py
Normal file
@ -0,0 +1,38 @@
"""No Source Enum

Revision ID: e50154680a5c
Revises: fcd135795f21
Create Date: 2024-03-14 18:06:08.523106

"""
from alembic import op
import sqlalchemy as sa

from danswer.configs.constants import DocumentSource

# revision identifiers, used by Alembic.
revision = "e50154680a5c"
down_revision = "fcd135795f21"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.alter_column(
        "search_doc",
        "source_type",
        type_=sa.String(length=50),
        existing_type=sa.Enum(DocumentSource, native_enum=False),
        existing_nullable=False,
    )
    op.execute("DROP TYPE IF EXISTS documentsource")


def downgrade() -> None:
    op.alter_column(
        "search_doc",
        "source_type",
        type_=sa.Enum(DocumentSource, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
@ -11,8 +11,8 @@ from alembic import op
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e6a4bbc13fe4"
|
||||
down_revision = "b082fec533f0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e86866a9c78a"
|
||||
down_revision = "80696cf850ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,118 @@
|
||||
"""Private Personas DocumentSets
|
||||
|
||||
Revision ID: e91df4e935ef
|
||||
Revises: 91fd3b470d1a
|
||||
Create Date: 2024-03-17 11:47:24.675881
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e91df4e935ef"
|
||||
down_revision = "91fd3b470d1a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"document_set__user",
|
||||
sa.Column("document_set_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_set_id"],
|
||||
["document_set.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("document_set_id", "user_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"persona__user",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "user_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"document_set__user_group",
|
||||
sa.Column("document_set_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_group_id",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_set_id"],
|
||||
["document_set.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("document_set_id", "user_group_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"persona__user_group",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_group_id",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "user_group_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"document_set",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
# fill in is_public for existing rows
|
||||
op.execute("UPDATE document_set SET is_public = true WHERE is_public IS NULL")
|
||||
op.alter_column("document_set", "is_public", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
# fill in is_public for existing rows
|
||||
op.execute("UPDATE persona SET is_public = true WHERE is_public IS NULL")
|
||||
op.alter_column("persona", "is_public", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "is_public")
|
||||
|
||||
op.drop_column("document_set", "is_public")
|
||||
|
||||
op.drop_table("persona__user")
|
||||
op.drop_table("document_set__user")
|
||||
op.drop_table("persona__user_group")
|
||||
op.drop_table("document_set__user_group")
|
@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ec3ec2eabf7b"
|
||||
down_revision = "dbaa756c2ccf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -0,0 +1,31 @@
"""Remove Last Attempt Status from CC Pair

Revision ID: ec85f2b3c544
Revises: 3879338f8ba1
Create Date: 2024-05-23 21:39:46.126010

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "ec85f2b3c544"
down_revision = "70f00c45c0f2"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_column("connector_credential_pair", "last_attempt_status")


def downgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_attempt_status",
            sa.VARCHAR(),
            autoincrement=False,
            nullable=True,
        ),
    )
@ -0,0 +1,40 @@
"""Add overrides to the chat session

Revision ID: ecab2b3f1a3b
Revises: 38eda64af7fe
Create Date: 2024-04-01 19:08:21.359102

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "ecab2b3f1a3b"
down_revision = "38eda64af7fe"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_session",
        sa.Column(
            "llm_override",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_session",
        sa.Column(
            "prompt_override",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "prompt_override")
    op.drop_column("chat_session", "llm_override")
@ -0,0 +1,27 @@
"""Add files to ChatMessage

Revision ID: ef7da92f7213
Revises: 401c1ac29467
Create Date: 2024-04-28 16:59:33.199153

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "ef7da92f7213"
down_revision = "401c1ac29467"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "files")
@ -0,0 +1,25 @@
"""Add pre-defined feedback

Revision ID: f1c6478c3fd8
Revises: 643a84a42a33
Create Date: 2024-05-09 18:11:49.210667

"""
from alembic import op
import sqlalchemy as sa

revision = "f1c6478c3fd8"
down_revision = "643a84a42a33"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_feedback",
        sa.Column("predefined_feedback", sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_feedback", "predefined_feedback")
@ -0,0 +1,39 @@
"""Delete Tags with wrong Enum

Revision ID: fad14119fb92
Revises: 72bdc9929a46
Create Date: 2024-04-25 17:05:09.695703

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "fad14119fb92"
down_revision = "72bdc9929a46"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Some documents may lose their tags but this is the only way as the enum
    # mapping may have changed since tag switched to string (it will be reindexed anyway)
    op.execute(
        """
        DELETE FROM document__tag
        WHERE tag_id IN (
            SELECT id FROM tag
            WHERE source ~ '^[0-9]+$'
        )
        """
    )

    op.execute(
        """
        DELETE FROM tag
        WHERE source ~ '^[0-9]+$'
        """
    )


def downgrade() -> None:
    pass
@ -0,0 +1,39 @@
"""Add slack bot display type

Revision ID: fcd135795f21
Revises: 0a2b51deb0b8
Create Date: 2024-03-04 17:03:27.116284

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "fcd135795f21"
down_revision = "0a2b51deb0b8"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "slack_bot_config",
        sa.Column(
            "response_type",
            sa.Enum(
                "QUOTES",
                "CITATIONS",
                name="slackbotresponsetype",
                native_enum=False,
            ),
            nullable=True,
        ),
    )
    op.execute(
        "UPDATE slack_bot_config SET response_type = 'QUOTES' WHERE response_type IS NULL"
    )
    op.alter_column("slack_bot_config", "response_type", nullable=False)


def downgrade() -> None:
    op.drop_column("slack_bot_config", "response_type")
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "febe9eaa0644"
|
||||
down_revision = "57b53544726e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ffc707a226b4"
|
||||
down_revision = "30c1d5744104"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
40
backend/danswer/auth/noauth_user.py
Normal file
@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any
from typing import cast

from danswer.auth.schemas import UserRole
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences


NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"


def set_no_auth_user_preferences(
    store: DynamicConfigStore, preferences: UserPreferences
) -> None:
    store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())


def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
    try:
        preferences_data = cast(
            Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
        )
        return UserPreferences(**preferences_data)
    except ConfigNotFoundError:
        return UserPreferences(chosen_assistants=None)


def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
    return UserInfo(
        id="__no_auth_user__",
        email="anonymous@danswer.ai",
        is_active=True,
        is_superuser=False,
        is_verified=True,
        role=UserRole.ADMIN,
        preferences=load_no_auth_user_preferences(store),
    )
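A rough usage sketch for these helpers follows; the get_dynamic_config_store() factory and its import path are assumptions about the surrounding codebase and are not part of this diff.

from danswer.auth.noauth_user import fetch_no_auth_user
from danswer.auth.noauth_user import set_no_auth_user_preferences
from danswer.dynamic_configs.factory import get_dynamic_config_store  # assumed import path
from danswer.server.manage.models import UserPreferences

store = get_dynamic_config_store()

# persist which assistants the anonymous user has picked
set_no_auth_user_preferences(store, UserPreferences(chosen_assistants=[1, 2]))

# later, build the pseudo-user returned by /me when auth is disabled
user_info = fetch_no_auth_user(store)
print(user_info.preferences.chosen_assistants)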
@ -23,8 +23,8 @@ from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
@ -33,15 +33,18 @@ from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
@ -69,6 +72,20 @@ def verify_auth_setting() -> None:
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
if email and email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN):
|
||||
name = email.split("@")[0]
|
||||
if name == DANSWER_API_KEY_PREFIX + UNNAMED_KEY_PLACEHOLDER:
|
||||
return "Unnamed API Key"
|
||||
|
||||
if space_less:
|
||||
return name
|
||||
|
||||
return name.replace("API_KEY__", "API Key: ")
|
||||
|
||||
return email or ""
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
@ -133,8 +150,8 @@ def send_user_verification_email(
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = SECRET
|
||||
verification_token_secret = SECRET
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
|
||||
async def create(
|
||||
self,
|
||||
@ -213,7 +230,10 @@ async def get_user_manager(
|
||||
yield UserManager(user_db)
|
||||
|
||||
|
||||
cookie_transport = CookieTransport(cookie_max_age=SESSION_EXPIRE_TIME_SECONDS)
|
||||
cookie_transport = CookieTransport(
|
||||
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
cookie_secure=WEB_DOMAIN.startswith("https"),
|
||||
)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
@ -276,13 +296,32 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
# take care of that in `double_check_user` ourself. This is needed, since
|
||||
# we want the /me endpoint to still return a user even if they are not
|
||||
# yet verified, so that the frontend knows they exist
|
||||
optional_valid_user = fastapi_users.current_user(active=True, optional=True)
|
||||
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> User | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
return user
|
||||
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
)
|
||||
return await versioned_fetch_user(request, user, db_session)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
@ -304,15 +343,9 @@ async def double_check_user(
|
||||
|
||||
|
||||
async def current_user(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_valid_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
double_check_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "double_check_user"
|
||||
)
|
||||
user = await double_check_user(request, user, db_session)
|
||||
return user
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
|
@ -1,6 +1,4 @@
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery # type: ignore
|
||||
@ -10,16 +8,14 @@ from danswer.background.connector_deletion import delete_connector_credential_pa
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.connectors.file.utils import file_age_in_hours
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import fetch_documents_for_document_set
|
||||
from danswer.db.document_set import fetch_documents_for_document_set_paginated
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import build_connection_string
|
||||
@ -31,7 +27,6 @@ from danswer.db.tasks import get_latest_task
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -42,7 +37,7 @@ celery_backend_url = f"db+{connection_string}"
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
|
||||
|
||||
_SYNC_BATCH_SIZE = 1000
|
||||
_SYNC_BATCH_SIZE = 100
|
||||
|
||||
|
||||
#####
|
||||
@ -67,15 +62,18 @@ def cleanup_connector_credential_pair_task(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
if not cc_pair:
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
|
||||
f"{connector_id} and Credential ID: {credential_id} does not exist."
|
||||
)
|
||||
|
||||
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair, db_session=db_session
|
||||
)
|
||||
if deletion_attempt_disallowed_reason:
|
||||
raise ValueError(deletion_attempt_disallowed_reason)
|
||||
|
||||
try:
|
||||
# The bulk of the work is in here, updates Postgres and Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
@ -98,15 +96,13 @@ def sync_document_set_task(document_set_id: int) -> None:
|
||||
"""For document sets marked as not up to date, sync the state from postgres
|
||||
into the datastore. Also handles deletions."""
|
||||
|
||||
def _sync_document_batch(document_ids: list[str]) -> None:
|
||||
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
# begin a transaction, release lock at the end
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquires a lock on the documents so that no other process can modify them
|
||||
prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
):
|
||||
# get current state of document sets for these documents
|
||||
document_set_map = {
|
||||
document_id: document_sets
|
||||
@ -131,17 +127,21 @@ def sync_document_set_task(document_set_id: int) -> None:
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
documents_to_update = fetch_documents_for_document_set(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
)
|
||||
for document_batch in batch_generator(
|
||||
documents_to_update, _SYNC_BATCH_SIZE
|
||||
):
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch]
|
||||
cursor = None
|
||||
while True:
|
||||
document_batch, cursor = fetch_documents_for_document_set_paginated(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
last_document_id=cursor,
|
||||
limit=_SYNC_BATCH_SIZE,
|
||||
)
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch],
|
||||
db_session=db_session,
|
||||
)
|
||||
if cursor is None:
|
||||
break
|
||||
|
||||
# if there are no connectors, then delete the document set. Otherwise, just
|
||||
# mark it as successfully synced.
|
||||
@ -182,7 +182,7 @@ def check_for_document_sets_sync_task() -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
db_session=db_session, include_outdated=True
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
@ -203,21 +203,6 @@ def check_for_document_sets_sync_task() -> None:
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
|
||||
def clean_old_temp_files_task(
|
||||
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
|
||||
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
|
||||
) -> None:
|
||||
"""Files added via the File connector need to be deleted after ingestion
|
||||
Currently handled async of the indexing job"""
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
for file in os.listdir(base_path):
|
||||
full_file_path = Path(base_path) / file
|
||||
if file_age_in_hours(full_file_path) > age_threshold_in_hours:
|
||||
logger.info(f"Cleaning up uploaded file: {full_file_path}")
|
||||
os.remove(full_file_path)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
@ -226,8 +211,4 @@ celery_app.conf.beat_schedule = {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
"clean-old-temp-files": {
|
||||
"task": "clean_old_temp_files_task",
|
||||
"schedule": timedelta(minutes=30),
|
||||
},
|
||||
}
|
||||
|
@ -19,8 +19,8 @@ from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair
|
||||
from danswer.db.document import delete_documents_complete
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
@ -47,60 +47,65 @@ def _delete_connector_credential_pair_batch(
|
||||
credential_id: int,
|
||||
document_index: DocumentIndex,
|
||||
) -> None:
|
||||
"""
|
||||
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
|
||||
it gets permanently deleted.
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquire lock for all documents in this batch so that indexing can't
|
||||
# override the deletion
|
||||
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
|
||||
|
||||
document_connector_cnts = get_document_connector_cnts(
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# figure out which docs need to be completely deleted
|
||||
document_ids_to_delete = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
|
||||
document_index.delete(doc_ids=document_ids_to_delete)
|
||||
|
||||
delete_documents_complete(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_delete,
|
||||
)
|
||||
|
||||
# figure out which docs need to be updated
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
):
|
||||
document_connector_cnts = get_document_connector_cnts(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
|
||||
document_index.update(update_requests=update_requests)
|
||||
# figure out which docs need to be completely deleted
|
||||
document_ids_to_delete = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
|
||||
delete_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
document_index.delete(doc_ids=document_ids_to_delete)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_delete,
|
||||
)
|
||||
|
||||
# figure out which docs need to be updated
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_synced_entities(
|
||||
|
@ -6,16 +6,16 @@ NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from torch import multiprocessing
|
||||
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
JobStatusType = (
|
||||
Literal["error"]
|
||||
| Literal["finished"]
|
||||
@ -25,12 +25,28 @@ JobStatusType = (
|
||||
)
|
||||
|
||||
|
||||
def _initializer(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Ensure the parent proc's database connections are not touched
|
||||
in the new connection pool
|
||||
|
||||
Based on the recommended approach in the SQLAlchemy docs found:
|
||||
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
get_sqlalchemy_engine().dispose(close=False)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
|
||||
id: int
|
||||
process: multiprocessing.Process | None = None
|
||||
process: Optional["Process"] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@ -95,7 +111,7 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = multiprocessing.Process(target=func, args=args, daemon=True)
|
||||
process = Process(target=_initializer(func=func, args=args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
@ -4,10 +4,13 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import torch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.connector_deletion import (
|
||||
_delete_connector_credential_pair_batch,
|
||||
)
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@ -19,10 +22,11 @@ from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
@ -42,8 +46,14 @@ def _get_document_generator(
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""NOTE: `start_time` and `end_time` are only used for poll connectors"""
|
||||
) -> tuple[GenerateDocumentsOutput, bool]:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
Returns an interator of document batches and whether the returned documents
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
@ -65,7 +75,7 @@ def _get_document_generator(
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
is_listing_complete = True
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
@ -78,12 +88,13 @@ def _get_document_generator(
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
is_listing_complete = False
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator
|
||||
return doc_batch_generator, is_listing_complete
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
@ -104,16 +115,6 @@ def _run_indexing(
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
|
||||
# Mark as started
|
||||
mark_attempt_in_progress(index_attempt, db_session)
|
||||
if is_primary:
|
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)

# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
@ -131,6 +132,7 @@ def _run_indexing(
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (db_embedding_model.status == IndexModelStatus.FUTURE),
db_session=db_session,
)

db_connector = index_attempt.connector
@ -158,19 +160,20 @@ def _run_indexing(
source_type=db_connector.source,
)
):
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)

doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)

try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)

doc_batch_generator, is_listing_complete = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)

all_connector_doc_ids: set[str] = set()
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
@ -186,6 +189,7 @@ def _run_indexing(

db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")

logger.debug(
@ -202,6 +206,7 @@ def _run_indexing(
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
all_connector_doc_ids.update(doc.id for doc in doc_batch)

# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@ -216,6 +221,40 @@ def _run_indexing(
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)

if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
# clean up all documents from the index that have not been returned from the connector
all_indexed_document_ids = {
d.id
for d in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
)
}
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids
)
logger.debug(
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
)

# delete docs from cc-pair and receive the number of completely deleted docs in return
_delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=db_connector.id,
credential_id=db_credential.id,
document_index=document_index,
)

update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=len(doc_ids_to_remove),
)

run_end_dt = window_end
@ -224,7 +263,6 @@ def _run_indexing(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
@ -255,7 +293,6 @@ def _run_indexing(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
)
raise e
@ -270,7 +307,6 @@ def _run_indexing(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
run_dt=run_end_dt,
)

@ -282,7 +318,35 @@ def _run_indexing(
)


def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"})  # type: ignore

attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")

if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)

# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress__no_commit(attempt)
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
db_session.commit()

return attempt


def run_indexing_entrypoint(index_attempt_id: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
@ -291,17 +355,10 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)

logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)

with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
attempt = _prepare_index_attempt(db_session, index_attempt_id)

logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
@ -309,10 +366,7 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
f"with credentials: '{attempt.credential_id}'"
)

_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
_run_indexing(db_session, attempt)

logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
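For reference, a minimal, self-contained sketch of the check-then-mark pattern that the new _prepare_index_attempt hunk above relies on: per-transaction SERIALIZABLE isolation around a "read the status, then flip it to in progress" step. Everything here except the SQLAlchemy isolation-level call is an illustrative stand-in (the Attempt model, the prepare() helper, the in-memory SQLite engine), not danswer code; danswer's real target is Postgres.

# Sketch only: per-transaction SERIALIZABLE isolation in SQLAlchemy 2.x so that the
# status check and the status update happen race-free in one transaction.
# `Attempt` and `prepare()` are hypothetical stand-ins; SQLite just keeps it runnable.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Attempt(Base):
    __tablename__ = "attempt"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="NOT_STARTED")


def prepare(session: Session, attempt_id: int) -> Attempt:
    # Applies only to the next transaction and is discarded after commit,
    # per the SQLAlchemy docs linked in the diff above.
    session.connection(execution_options={"isolation_level": "SERIALIZABLE"})
    attempt = session.scalars(select(Attempt).where(Attempt.id == attempt_id)).one()
    if attempt.status != "NOT_STARTED":
        raise RuntimeError(f"Attempt {attempt_id} already started")
    attempt.status = "IN_PROGRESS"
    session.commit()  # single commit keeps the check and the update together
    return attempt


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as s:
        s.add(Attempt(id=1))
        s.commit()
        print(prepare(s, 1).status)  # IN_PROGRESS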
@ -3,7 +3,6 @@ import time
from datetime import datetime

import dask
import torch
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
@ -15,21 +14,13 @@ from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@ -41,7 +32,12 @@ from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT

logger = setup_logger()

@ -54,22 +50,19 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
)


"""Util funcs"""


def _get_num_threads() -> int:
"""Get # of "threads" to use for ML models in an indexing job. By default uses
the torch implementation, which returns the # of physical cores on the machine.
"""
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())


def _should_create_new_indexing(
connector: Connector,
last_index: IndexAttempt | None,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False

# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0:  # Ingestion API
@ -124,17 +117,6 @@ def _mark_run_failed(
db_session=db_session,
failure_reason=failure_reason,
)
if (
index_attempt.connector_id is not None
and index_attempt.credential_id is not None
and index_attempt.embedding_model.status == IndexModelStatus.PRESENT
):
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_id,
credential_id=index_attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
)


"""Main funcs"""
@ -185,7 +167,11 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector, last_attempt, model, db_session
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue

@ -193,16 +179,6 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
connector.id, credential.id, model.id, db_session
)

# CC-Pair will have the status that it should for the primary index
# Will be re-sync-ed once the indices are swapped
if model.status == IndexModelStatus.PRESENT:
update_connector_credential_pair(
db_session=db_session,
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
)


def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
@ -254,6 +230,9 @@ def cleanup_indexing_jobs(
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
@ -328,12 +307,10 @@ def kickoff_indexing_jobs(

if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
run_indexing_entrypoint, attempt.id, pure=False
)
else:
run = client.submit(
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
)
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)

if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
@ -348,49 +325,22 @@ def kickoff_indexing_jobs(
return existing_jobs_copy


def check_index_swap(db_session: Session) -> None:
"""Get count of cc-pairs and count of index_attempts for the new model grouped by
connector + credential, if it's the same, then assume new index is done building.
This does not take into consideration if the attempt failed or not"""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = len(all_cc_pairs) - 1
embedding_model = get_secondary_db_embedding_model(db_session)
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)

if not embedding_model:
return

unique_cc_indexings = count_unique_cc_pairs_with_index_attempts(
embedding_model_id=embedding_model.id, db_session=db_session
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)

if unique_cc_indexings > cc_pair_count:
raise RuntimeError("More unique indexings than cc pairs, should not occur")

if cc_pair_count == unique_cc_indexings:
# Swap indices
now_old_embedding_model = get_current_db_embedding_model(db_session)
update_embedding_model_status(
embedding_model=now_old_embedding_model,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)

update_embedding_model_status(
embedding_model=embedding_model,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)

# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)

# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)


def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
@ -417,12 +367,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_secondary = SimpleJobClient(n_workers=num_workers)

existing_jobs: dict[int, Future | SimpleJob] = {}
engine = get_sqlalchemy_engine()

with Session(engine) as db_session:
# Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
# This ensures that bad states get cleaned up
mark_all_in_progress_cc_pairs_failed(db_session)

while True:
start = time.time()
@ -454,12 +398,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non


def update__main() -> None:
# needed for CUDA to work with multiprocessing
# NOTE: needs to be done on application startup
# before any other torch code has been run
if not DASK_JOB_CLIENT_ENABLED:
torch.multiprocessing.set_start_method("spawn")

logger.info("Starting Indexing Loop")
update_loop()
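For reference, a minimal sketch of why the kickoff hunk above submits the entrypoint with pure=False on dask.distributed: with the default pure=True, identical (function, arguments) pairs hash to the same task key and get deduplicated, so re-submitting the same attempt id would silently reuse the old result instead of running the side-effecting job again. The entrypoint below is a hypothetical stand-in, not danswer's run_indexing_entrypoint, and the threaded LocalCluster is just to keep the sketch self-contained.

# Sketch only: pure=False forces a fresh task even if the same arguments were
# submitted before. Requires dask[distributed]; the worker count is arbitrary.
from dask.distributed import Client, LocalCluster


def run_indexing_entrypoint(index_attempt_id: int) -> str:
    # Placeholder for a real side-effecting indexing run
    return f"indexed attempt {index_attempt_id}"


if __name__ == "__main__":
    cluster = LocalCluster(n_workers=1, threads_per_worker=1, processes=False)
    client = Client(cluster)
    future = client.submit(run_indexing_entrypoint, 42, pure=False)
    print(future.result())
    client.close()
    cluster.close()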
@ -1,168 +1,39 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from functools import lru_cache
from collections.abc import Sequence
from typing import cast

from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger

# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
CONNECTOR_NAME_MAP = {
"web": "Website",
"requesttracker": "Request Tracker",
"github": "GitHub",
"file": "File Upload",
}
logger = setup_logger()


def clean_up_source(source_str: str) -> str:
if source_str in CONNECTOR_NAME_MAP:
return CONNECTOR_NAME_MAP[source_str]
return source_str.replace("_", " ").title()


def build_doc_context_str(
semantic_identifier: str,
source_type: DocumentSource,
content: str,
metadata_dict: dict[str, str | list[str]],
updated_at: datetime | None,
ind: int,
include_metadata: bool = True,
) -> str:
context_str = ""
if include_metadata:
context_str += f"DOCUMENT {ind}: {semantic_identifier}\n"
context_str += f"Source: {clean_up_source(source_type)}\n"

for k, v in metadata_dict.items():
if isinstance(v, list):
v_str = ", ".join(v)
context_str += f"{k.capitalize()}: {v_str}\n"
else:
context_str += f"{k.capitalize()}: {v}\n"

if updated_at:
update_str = updated_at.strftime("%B %d, %Y %H:%M")
context_str += f"Updated: {update_str}\n"
context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n"
return context_str


def build_complete_context_str(
context_docs: list[LlmDoc | InferenceChunk],
include_metadata: bool = True,
) -> str:
context_str = ""
for ind, doc in enumerate(context_docs, start=1):
context_str += build_doc_context_str(
semantic_identifier=doc.semantic_identifier,
source_type=doc.source_type,
content=doc.content,
metadata_dict=doc.metadata,
updated_at=doc.updated_at,
ind=ind,
include_metadata=include_metadata,
)

return context_str.strip()


@lru_cache()
def build_chat_system_message(
prompt: Prompt,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if system_prompt:
system_prompt += (
f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}."
)
else:
system_prompt = get_current_llm_day_time()

if not system_prompt:
return None, 0

token_count = len(llm_tokenizer_encode_func(system_prompt))
system_msg = SystemMessage(content=system_prompt)

return system_msg, token_count


def build_task_prompt_reminders(
prompt: Prompt,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
citation_str: str = CITATION_REMINDER,
language_hint_str: str = LANGUAGE_HINT,
) -> str:
base_task = prompt.task_prompt
citation_or_nothing = citation_str if prompt.include_citations else ""
language_hint_or_nothing = language_hint_str.lstrip() if use_language_hint else ""
return base_task + citation_or_nothing + language_hint_or_nothing


def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
content=inf_chunk.content,
# This one is using the combined content of all the chunks of the section
# In default settings, this is the same as just the content of base chunk
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
)


def map_document_id_order(
chunks: list[InferenceChunk | LlmDoc], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
@ -174,157 +45,6 @@ def map_document_id_order(
return order_mapping


def build_chat_user_message(
chat_message: ChatMessage,
prompt: Prompt,
context_docs: list[LlmDoc],
llm_tokenizer_encode_func: Callable,
all_doc_useful: bool,
user_prompt_template: str = CHAT_USER_PROMPT,
context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
ignore_str: str = DEFAULT_IGNORE_STATEMENT,
) -> tuple[HumanMessage, int]:
user_query = chat_message.message

if not context_docs:
# Simpler prompt for cases where there is no context
user_prompt = (
context_free_template.format(
task_prompt=prompt.task_prompt, user_query=user_query
)
if prompt.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count

context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_docs)
)
optional_ignore = "" if all_doc_useful else ignore_str

task_prompt_with_reminder = build_task_prompt_reminders(prompt)

user_prompt = user_prompt_template.format(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=user_query,
)

user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)

return user_msg, token_count


def _get_usable_chunks(
chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
total_token_count = 0
usable_chunks = []
for chunk in chunks:
chunk_token_count = check_number_of_tokens(chunk.content)
if total_token_count + chunk_token_count > token_limit:
break

total_token_count += chunk_token_count
usable_chunks.append(chunk)

# try and return at least one chunk if possible. This chunk will
# get truncated later on in the pipeline. This would only occur if
# the first chunk is larger than the token limit (usually due to character
# count -> token count mismatches caused by special characters / non-ascii
# languages)
if not usable_chunks and chunks:
usable_chunks = [chunks[0]]

return usable_chunks


def get_usable_chunks(
chunks: list[InferenceChunk],
token_limit: int,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
usable_chunks: list[InferenceChunk] = []
for _ in range(min(offset + 1, 1)):  # go through this process at least once
if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
raise ValueError(
"Chunks offset too large, should not retry this many times"
)

usable_chunks = _get_usable_chunks(
chunks=chunks[offset_into_chunks:], token_limit=token_limit
)
offset_into_chunks += len(usable_chunks)

return usable_chunks


def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: int | None,
batch_offset: int = 0,
) -> list[int]:
"""
Gives back indices of chunks to pass into the LLM for Q&A.

Only selects chunks viable for Q&A, within the token limit, and prioritize those selected
by the LLM in a separate flow (this can be turned off)

Note, the batch_offset calculation has to count the batches from the beginning each time as
there's no way to know which chunks were included in the prior batches without recounting atm,
this is somewhat slow as it requires tokenizing all the chunks again
"""
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0

# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
for selection_target in [True, False]:
for ind, chunk in enumerate(chunks):
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
IGNORE_FOR_QA
):
continue

# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
# 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
token_count += chunk_token + 50

# Always use at least 1 chunk
if (
token_limit is None
or token_count <= token_limit
or not latest_batch_indices
):
latest_batch_indices.append(ind)
current_chunk_unused = False
else:
current_chunk_unused = True

if token_limit is not None and token_count >= token_limit:
if batch_index < batch_offset:
batch_index += 1
if current_chunk_unused:
latest_batch_indices = [ind]
token_count = chunk_token
else:
latest_batch_indices = []
token_count = 0
else:
return latest_batch_indices

return latest_batch_indices


def create_chat_chain(
chat_session_id: int,
db_session: Session,
@ -340,7 +60,7 @@ def create_chat_chain(
id_to_msg = {msg.id: msg for msg in all_chat_messages}

if not all_chat_messages:
raise ValueError("No messages in Chat Session")
raise RuntimeError("No messages in Chat Session")

root_message = all_chat_messages[0]
if root_message.parent_message is not None:
@ -370,7 +90,7 @@ def create_chat_chain(


def combine_message_chain(
messages: list[ChatMessage],
messages: list[ChatMessage] | list[PreviousMessage],
token_limit: int,
msg_limit: int | None = None,
) -> str:
@ -381,7 +101,7 @@ def combine_message_chain(
if msg_limit is not None:
messages = messages[-msg_limit:]

for message in reversed(messages):
for message in cast(list[ChatMessage] | list[PreviousMessage], reversed(messages)):
message_token_count = message.token_count

if total_token_count + message_token_count > token_limit:
@ -394,218 +114,58 @@ def combine_message_chain(
return "\n\n".join(message_strs)


_PER_MESSAGE_TOKEN_BUFFER = 7
def reorganize_citations(
answer: str, citations: list[CitationInfo]
) -> tuple[str, list[CitationInfo]]:
"""For a complete, citation-aware response, we want to reorganize the citations so that
they are in the order of the documents that were used in the response. This just looks nicer / avoids
confusion ("Why is there [7] when only 2 documents are cited?")."""

# Regular expression to find all instances of [[x]](LINK)
pattern = r"\[\[(.*?)\]\]\((.*?)\)"

def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
all_citation_matches = re.findall(pattern, answer)

last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
raise ValueError("Last message alone is too large!")
return last_ind


def drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
max_allowed_tokens: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible and the latest user input which is inserted in the
prompt template must be included"""
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")

prompt: list[BaseMessage] = []

# Start dropping from the history if necessary
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
ind_prev_msg_start = find_last_index(
all_tokens, max_prompt_tokens=max_allowed_tokens
)

if system_msg and ind_prev_msg_start <= len(history_msgs):
prompt.append(system_msg)

prompt.extend(history_msgs[ind_prev_msg_start:])

prompt.append(final_msg)

return prompt


def in_code_block(llm_text: str) -> bool:
count = llm_text.count(TRIPLE_BACKTICK)
return count % 2 != 0


def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
llm_out = ""
max_citation_num = len(context_docs)
curr_segment = ""
prepend_bracket = False
cited_inds = set()
hold = ""
for raw_token in tokens:
if stop_stream:
next_hold = hold + raw_token

if stop_stream in next_hold:
break

if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
new_citation_info: dict[int, CitationInfo] = {}
for citation_match in all_citation_matches:
try:
citation_num = int(citation_match[0])
if citation_num in new_citation_info:
continue

token = next_hold
hold = ""
else:
token = raw_token
matching_citation = next(
iter([c for c in citations if c.citation_num == int(citation_num)]),
None,
)
if matching_citation is None:
continue

# Special case of [1][ where ][ is a single token
# This is where the model attempts to do consecutive citations like [1][2]
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False
new_citation_info[citation_num] = CitationInfo(
citation_num=len(new_citation_info) + 1,
document_id=matching_citation.document_id,
)
except Exception:
pass

curr_segment += token
llm_out += token
# Function to replace citations with their new number
def slack_link_format(match: re.Match) -> str:
link_text = match.group(1)
try:
citation_num = int(link_text)
if citation_num in new_citation_info:
link_text = new_citation_info[citation_num].citation_num
except Exception:
pass

possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
link_url = match.group(2)
return f"[[{link_text}]]({link_url})"

citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
# Substitute all matches in the input text
new_answer = re.sub(pattern, slack_link_format, answer)

if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[
numerical_value - 1
]  # remove 1 index offset
# if any citations weren't parsable, just add them back to be safe
for citation in citations:
if citation.citation_num not in new_citation_info:
new_citation_info[citation.citation_num] = citation

link = context_llm_doc.link
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]

# Use the citation number for the document's rank in
# the search (or selected docs) results
curr_segment = re.sub(
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
)

if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)

if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)

# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None

# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if possible_citation_found:
continue

# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True

yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""

if curr_segment:
if prepend_bracket:
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)


def get_prompt_tokens(prompt: Prompt) -> int:
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
)


# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
"""Estimates the number of tokens available for context documents. Formula is roughly:

(
model_context_window - reserved_output_tokens - prompt_tokens
- (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
)

The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = GEN_AI_MODEL_VERSION
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

# if we can't find a number of tokens, just assume some common default
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(model_name=llm_name)
)
if persona.prompts:
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
raise RuntimeError("Persona has no prompts - this should never happen")
user_input_tokens = (
check_number_of_tokens(actual_user_input)
if actual_user_input is not None
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
)

return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER


def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = GEN_AI_MODEL_VERSION
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

input_tokens = get_max_input_tokens(model_name=llm_name)
return input_tokens - _MISC_BUFFER
return new_answer, list(new_citation_info.values())
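For reference, a minimal, self-contained sketch of the citation-renumbering idea introduced by the reorganize_citations hunk above: find [[n]](LINK) markers, renumber them by order of first appearance, and rewrite the answer with a regex substitution. The CitationInfo dataclass below is a simplified stand-in for danswer's model and the renumbering helper is hypothetical, not the project's exact implementation.

# Sketch only: renumber [[n]](link) citation markers in order of first appearance.
import re
from dataclasses import dataclass


@dataclass
class CitationInfo:
    citation_num: int
    document_id: str


def reorder_citations(answer: str, citations: list[CitationInfo]) -> tuple[str, list[CitationInfo]]:
    pattern = r"\[\[(.*?)\]\]\((.*?)\)"
    by_num = {c.citation_num: c for c in citations}
    new_info: dict[int, CitationInfo] = {}

    # First pass: assign new numbers in order of first appearance in the answer
    for num_str, _link in re.findall(pattern, answer):
        try:
            num = int(num_str)
        except ValueError:
            continue
        if num in by_num and num not in new_info:
            new_info[num] = CitationInfo(len(new_info) + 1, by_num[num].document_id)

    # Second pass: rewrite each marker with its new number, leaving unknowns alone
    def renumber(match: re.Match) -> str:
        text: object = match.group(1)
        try:
            num = int(match.group(1))
            if num in new_info:
                text = new_info[num].citation_num
        except ValueError:
            pass
        return f"[[{text}]]({match.group(2)})"

    return re.sub(pattern, renumber, answer), list(new_info.values())


if __name__ == "__main__":
    answer = "See [[7]](http://a) and [[2]](http://b); also [[7]](http://a)."
    cites = [CitationInfo(7, "doc-a"), CitationInfo(2, "doc-b")]
    new_answer, new_cites = reorder_citations(answer, cites)
    print(new_answer)  # [[7]] becomes [[1]], [[2]] stays [[2]]
    print(new_cites)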
@ -13,7 +13,7 @@ from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Prompt as PromptDBModel
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting


def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
@ -24,7 +24,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user_id=None,
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
@ -34,7 +34,6 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
shared=True,
db_session=db_session,
commit=True,
)
@ -67,9 +66,7 @@ def load_personas_from_yaml(
prompts: list[PromptDBModel | None] | None = None
else:
prompts = [
get_prompt_by_name(
prompt_name, user_id=None, shared=True, db_session=db_session
)
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
@ -78,22 +75,26 @@ def load_personas_from_yaml(
if not prompts:
prompts = None

p_id = persona.get("id")
upsert_persona(
user_id=None,
persona_id=persona.get("id"),
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
default_persona=True,
shared=True,
is_public=True,
db_session=db_session,
)

@ -5,10 +5,10 @@ from typing import Any
from pydantic import BaseModel

from danswer.configs.constants import DocumentSource
from danswer.search.models import QueryFlow
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.search.models import SearchType


class LlmDoc(BaseModel):
@ -16,11 +16,13 @@ class LlmDoc(BaseModel):

document_id: str
content: str
blurb: str
semantic_identifier: str
source_type: DocumentSource
metadata: dict[str, str | list[str]]
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None


# First chunk of info for streaming QA
@ -100,9 +102,21 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None


AnswerQuestionStreamReturn = Iterator[
DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
]
class ImageGenerationDisplay(BaseModel):
file_ids: list[str]


AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| StreamingError
)


AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]


class LLMMetricsContainer(BaseModel):
@ -5,9 +5,9 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Default"
name: "Danswer"
description: >
Default Danswer Question Answering functionality.
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
prompts:
- "Answer-Question"
@ -39,22 +39,23 @@ personas:
document_sets: []


- name: "Summarize"
- id: 1
name: "GPT"
description: >
A less creative assistant which summarizes relevant documents but does not try to
extrapolate any answers for you.
Assistant with no access to documents. Chat with just the Language Model.
prompts:
- "Summarize"
num_chunks: 10
- "OnlyLLM"
num_chunks: 0
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []


- name: "Paraphrase"
- id: 2
name: "Paraphrase"
description: >
The least creative default assistant that only provides quotes from the documents.
Assistant that is heavily constrained and only provides exact quotes from Connected Sources.
prompts:
- "Paraphrase"
num_chunks: 10
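For reference, a minimal sketch of how persona entries like the YAML above can be read before being upserted, applying the same "negate the id so seeded personas don't clash with user-created ones" convention from the loader diff earlier. The file path, PyYAML usage, and printout are assumptions for illustration only; this is not danswer's actual loader.

# Sketch only: parse a personas YAML and derive the negative seeded-persona id.
import yaml

with open("personas.yaml") as f:  # hypothetical path
    personas = yaml.safe_load(f)["personas"]

for persona in personas:
    p_id = persona.get("id")
    persona_id = (-1 * p_id) if p_id is not None else None
    print(persona_id, persona["name"], persona.get("num_chunks"))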
Some files were not shown because too many files have changed in this diff.