updated tool runner + displays

This commit is contained in:
pablodanswer 2024-09-17 08:46:37 -07:00
parent 2beffdaa6e
commit 3170430673
13 changed files with 68 additions and 51 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 26 KiB

After

Width:  |  Height:  |  Size: 27 KiB

View File

@ -709,7 +709,10 @@ def stream_chat_message_objects(
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_arguments={
k: v if not isinstance(v, bytes) else v.decode("utf-8")
for k, v in tool_result.tool_args.items()
},
tool_result=tool_result.tool_result,
)
@ -814,11 +817,12 @@ def stream_chat_message_objects(
)
elif packet.id == GRAPHING_RESPONSE_ID:
graph_generation = cast(GraphingResponse, packet.response)
yield graph_generation
yield GraphGenerationDisplay(
file_id=graph_generation.extra_graph_display.file_id,
line_graph=graph_generation.extra_graph_display.line_graph,
)
# yield GraphGenerationDisplay(
# file_id=graph_generation.extra_graph_display.file_id,
# line_graph=graph_generation.extra_graph_display.line_graph,
# )
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(

View File

@ -1,19 +1,22 @@
import base64
import json
from datetime import datetime
from typing import Any
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
class DateTimeAndBytesEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings and bytes to base64."""
def default(self, obj: Any) -> Any:
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, bytes):
return base64.b64encode(obj).decode("utf-8")
return super().default(obj)
def get_json_line(
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeAndBytesEncoder
) -> str:
"""
Convert a dictionary to a JSON string with datetime handling, and add a newline.

View File

@ -1,4 +1,3 @@
import base64
import json
import os
import re
@ -29,7 +28,7 @@ from danswer.prompts.chat_prompts import (
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
from danswer.tools.graphing.models import GraphingError
from danswer.tools.graphing.models import GraphingResponse
from danswer.tools.graphing.models import GraphingResult
from danswer.tools.graphing.models import GraphType
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
@ -276,6 +275,7 @@ class GraphingTool(Tool):
print(kwargs)
file_content = kwargs["filename"]
file_content = file_content.decode("utf-8")
csv_file = StringIO(file_content)
df = pd.read_csv(csv_file)
@ -320,13 +320,13 @@ class GraphingTool(Tool):
ax = fig.gca() # Get the current Axes
plot_data = None
plot_type = None
plot_type: GraphType | None = None
if self.is_line_plot(ax):
plot_data = self.extract_line_plot_data(ax)
plot_type = "line"
plot_type = GraphType.LINE_GRAPH
elif self.is_bar_plot(ax):
plot_data = self.extract_bar_plot_data(ax)
plot_type = "bar"
plot_type = GraphType.BAR_CHART
if plot_data:
plot_data_file = os.path.join(self.output_dir, "plot_data.json")
@ -350,23 +350,19 @@ class GraphingTool(Tool):
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight") # type: ignore
img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
with open("aaa garp.png", "wb") as f:
f.write(buf.getvalue())
graph_result = GraphingResult(image=img_base64, plot_data=plot_data)
print("da plot type iza")
print(plot_type)
print("\n\n\n")
print(code)
response = GraphingResponse(
graph_result=graph_result,
extra_graph_display={
"file_id": file_id,
"line_graph": plot_type == "line",
},
yield ToolResponse(
id=GRAPHING_RESPONSE_ID,
response=GraphingResponse(
file_id=str(file_id),
graph_type=plot_type.value
if plot_type
else None, # Use .value to get the string
plot_data=plot_data, # Pass the dictionary directly, not as a JSON string
),
)
yield ToolResponse(id=GRAPHING_RESPONSE_ID, response=response)
except Exception as e:
error_msg = f"Error generating graph: {str(e)}\n{traceback.format_exc()}"

View File

@ -1,4 +1,5 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
@ -8,21 +9,15 @@ class GraphGenerationDisplay(BaseModel):
line_graph: bool
class GraphType(Enum):
class GraphType(str, Enum):
BAR_CHART = "bar_chart"
LINE_GRAPH = "line_graph"
# SCATTER_PLOT = "scatter_plot"
# PIE_CHART = "pie_chart"
# HISTOGRAM = "histogram"
class GraphingResponse(BaseModel):
revised_query: str | None = None
file_id: str
graph_tye: GraphType
# graph_display: GraphGenerationDisplay | None
plot_data: dict[str, Any] | None
graph_type: GraphType
class GraphingError(BaseModel):

View File

@ -1,3 +1,4 @@
import base64
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
@ -27,6 +28,7 @@ class ToolRunner:
if self._tool_responses is not None:
print("prev")
print(self._tool_responses)
yield from self._tool_responses
return
@ -35,6 +37,11 @@ class ToolRunner:
print(self.tool.name)
for tool_response in self.tool.run(llm=self._llm, **self.args):
if isinstance(tool_response.response, bytes):
tool_response.response = base64.b64encode(
tool_response.response
).decode("utf-8")
print("tool response")
yield tool_response
tool_responses.append(tool_response)

View File

@ -1 +1 @@
{"data": [{"x": 0.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 1.0, "y": 46.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 2.0, "y": 26.41338, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 3.0, "y": 1.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 4.0, "y": 23.5, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 5.0, "y": 46.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 6.0, "y": 68.5, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 7.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}], "title": "Summary Statistics of Id", "xlabel": "Statistic", "ylabel": "Value", "xticks": [0, 1, 2, 3, 4, 5, 6, 7], "xticklabels": ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]}
{"data": [{"x": 0.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 1.0, "y": 46.0, "width": 0.7999999999999999, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 2.0, "y": 26.41338, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 3.0, "y": 1.0, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 4.0, "y": 23.5, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 5.0, "y": 46.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 6.0, "y": 68.5, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 7.0, "y": 91.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}], "title": "Summary Statistics of Id Column", "xlabel": "Statistic", "ylabel": "Value", "xticks": [0, 1, 2, 3, 4, 5, 6, 7], "xticklabels": ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]}

View File

@ -154,7 +154,6 @@ export default function AddConnector({
initialValues={createConnectorInitialValues(connector)}
validationSchema={createConnectorValidationSchema(connector)}
onSubmit={async (values) => {
console.log(" Iam submiing the connector");
const {
name,
groups,

View File

@ -1143,6 +1143,7 @@ export function ChatPage({
continue;
}
console.log(packet);
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
console.error(

View File

@ -77,13 +77,6 @@ const TOOLS_WITH_CUSTOM_HANDLING = [
INTERNET_SEARCH_TOOL_NAME,
IMAGE_GENERATION_TOOL_NAME,
];
import plotDataJson from "./linechart.json";
import barChartDataJson from "./barchart_data.json";
import polarChartDataJson from "./polar_plot_data.json";
import { JSONUpload } from "./JSONUpload";
import { ImageDisplay } from "@/components/chat_display/graphs/ImageDisplay";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
function FileDisplay({
files,
@ -164,15 +157,18 @@ function FileDisplay({
</>
);
}
export interface graph {
file_id: string;
line: boolean;
enum GraphType {
BAR_CHART = "bar_chart",
LINE_GRAPH = "line_graph",
}
export interface GraphChunk {
file_id: string;
line_graph: boolean;
plot_data: Record<string, any> | null;
graph_type: GraphType | null;
}
export const AIMessage = ({
hasChildAI,
hasParentAI,
@ -208,7 +204,7 @@ export const AIMessage = ({
shared?: boolean;
hasChildAI?: boolean;
hasParentAI?: boolean;
graphs?: graph[];
graphs?: GraphChunk[];
isActive?: boolean;
continueGenerating?: () => void;
otherMessagesCanSwitchTo?: number[];
@ -236,6 +232,7 @@ export const AIMessage = ({
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPopup?: (popupSpec: PopupSpec | null) => void;
}) => {
console.log(toolCall);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const toolCallGenerating = toolCall && !toolCall.tool_result;
@ -439,7 +436,7 @@ export const AIMessage = ({
</div>
)}
{graphs.map((graph, ind) => {
return graph.line ? (
return graph.graph_type === GraphType.LINE_GRAPH ? (
<ModalChartWrapper
key={ind}
chartType="line"
@ -460,6 +457,17 @@ export const AIMessage = ({
{content || files ? (
<>
{toolCall?.tool_name == "create_graph" && (
<ModalChartWrapper
key={0}
chartType="line"
fileId={toolCall?.tool_result?.file_id}
>
<LineChartDisplay
fileId={toolCall?.tool_result?.file_id}
/>
</ModalChartWrapper>
)}
<FileDisplay files={files || []} />
{typeof content === "string" ? (

View File

@ -124,7 +124,6 @@ export const CsvSection = ({
setFadeIn(false);
}
}, [isLoading]);
console.log("rerendering");
const downloadFile = () => {
if (!fileId) return;

View File

@ -41,6 +41,7 @@ export function BarChartDisplay({ fileId }: { fileId: string }) {
if (!barPlotData) {
return <div>Loading...</div>;
}
console.log("IN THE FUNCTION");
// Transform data to match Recharts expected format
const transformedData = barPlotData.data.map((point, index) => ({

View File

@ -190,6 +190,8 @@ export function LineChartDisplay({ fileId }: { fileId: string }) {
throw new Error("Failed to fetch plot data");
}
const plotDataJson: PlotData = await response.json();
console.log("plot data");
console.log(plotDataJson);
const transformedData: ChartDataPoint[] = plotDataJson.data[0].x.map(
(x, index) => ({
@ -212,6 +214,8 @@ export function LineChartDisplay({ fileId }: { fileId: string }) {
console.error("Error fetching plot data:", error);
}
};
console.log("chartData");
console.log(chartData);
return (
<div className="w-full h-full">