Skip to content

Commit ccd3cb2

Browse files
committed
Add: total cost and embedding in detailed report
1 parent d05e5c4 commit ccd3cb2

File tree

1 file changed

+62
-10
lines changed

1 file changed

+62
-10
lines changed

main.py

+62-10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sqlite3
66
import sys
77
import time
8+
import textwrap
89
from dataclasses import dataclass
910
from datetime import datetime
1011
from functools import partial
@@ -14,7 +15,7 @@
1415
from tqdm import tqdm
1516
from unstract.api_deployments.client import APIDeploymentsClient
1617

17-
DB_NAME = "file_processing.db"
18+
DB_NAME = "/home/praveen/Documents/db/demo.db"
1819
global_arguments = None
1920
logger = logging.getLogger(__name__)
2021

@@ -51,6 +52,10 @@ def init_db():
5152
time_taken REAL,
5253
status_code INTEGER,
5354
status_api_endpoint TEXT,
55+
total_embedding_cost TEXT,
56+
total_embedding_tokens INTEGER DEFAULT 0,
57+
total_llm_cost TEXT,
58+
total_llm_tokens INTEGER DEFAULT 0,
5459
updated_at TEXT,
5560
created_at TEXT
5661
)"""
@@ -97,6 +102,39 @@ def update_db(
97102
status_code,
98103
status_api_endpoint,
99104
):
105+
106+
total_embedding_cost = 0.0
107+
total_embedding_tokens = 0
108+
total_llm_cost = 0.0
109+
total_llm_tokens = 0
110+
111+
if result is not None:
112+
# Extract 'extraction_result' from the result
113+
extraction_result = result.get("extraction_result", [])
114+
115+
if extraction_result:
116+
extraction_data = extraction_result[0].get("result", "")
117+
118+
# If extraction_data is a string, attempt to parse it as JSON
119+
if isinstance(extraction_data, str):
120+
try:
121+
extraction_data = json.loads(extraction_data) if extraction_data else {}
122+
except json.JSONDecodeError:
123+
extraction_data = {}
124+
125+
# Now we can safely access metadata, embedding_llm, and extraction_llm under the 'result'
126+
metadata = extraction_data.get("metadata", {})
127+
embedding_llm = metadata.get("embedding", [])
128+
extraction_llm = metadata.get("extraction_llm", [])
129+
130+
# Calculate total cost from `cost_in_dollars` in both LLM arrays, converting to float as needed
131+
total_embedding_cost += sum(float(item.get("cost_in_dollars", "0")) for item in embedding_llm)
132+
total_llm_cost += sum(float(item.get("cost_in_dollars", "0")) for item in extraction_llm)
133+
134+
# Calculate total tokens using `embedding_tokens` for embedding and `total_tokens` for extraction
135+
total_embedding_tokens += sum(item.get("embedding_tokens", 0) for item in embedding_llm)
136+
total_llm_tokens += sum(item.get("total_tokens", 0) for item in extraction_llm)
137+
100138
conn = sqlite3.connect(DB_NAME)
101139
conn.set_trace_callback(
102140
lambda x: (
@@ -109,16 +147,20 @@ def update_db(
109147
now = datetime.now().isoformat()
110148
c.execute(
111149
"""
112-
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, updated_at, created_at)
113-
VALUES (?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
114-
""",
150+
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
151+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
152+
""",
115153
(
116154
file_name,
117155
execution_status,
118156
json.dumps(result),
119157
time_taken,
120158
status_code,
121159
status_api_endpoint,
160+
total_embedding_cost,
161+
total_embedding_tokens,
162+
total_llm_cost,
163+
total_llm_tokens,
122164
now,
123165
file_name,
124166
now,
@@ -130,7 +172,7 @@ def update_db(
130172

131173
# Print final summary with count of each status and average time using a single SQL query
132174
def print_summary():
133-
conn = sqlite3.connect("file_processing.db")
175+
conn = sqlite3.connect(DB_NAME)
134176
c = conn.cursor()
135177

136178
# Fetch count and average time for each status
@@ -153,13 +195,13 @@ def print_summary():
153195

154196

155197
def print_report():
156-
conn = sqlite3.connect("file_processing.db")
198+
conn = sqlite3.connect(DB_NAME)
157199
c = conn.cursor()
158200

159-
# Fetch count and average time for each status
201+
# Fetch required fields, including total_cost and total_tokens
160202
c.execute(
161203
"""
162-
SELECT file_name, execution_status, time_taken
204+
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
163205
FROM file_status
164206
"""
165207
)
@@ -170,8 +212,18 @@ def print_report():
170212
print("\nDetailed Report:")
171213
if report_data:
172214
# Tabulate the data with column headers
173-
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)"]
174-
print(tabulate(report_data, headers=headers, tablefmt="pretty"))
215+
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)", "Total Embedding Cost", "Total Embedding Tokens", "Total LLM Cost", "Total LLM Tokens"]
216+
217+
# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others)
218+
formatted_data = []
219+
for row in report_data:
220+
formatted_row = [
221+
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else f"{cell:.8f}" if isinstance(cell, float) else cell
222+
for cell in row
223+
]
224+
formatted_data.append(formatted_row)
225+
226+
print(tabulate(formatted_data, headers=headers, tablefmt="pretty"))
175227
else:
176228
print("No records found in the database.")
177229

0 commit comments

Comments
 (0)