5
5
import sqlite3
6
6
import sys
7
7
import time
8
+ import textwrap
8
9
from dataclasses import dataclass
9
10
from datetime import datetime
10
11
from functools import partial
14
15
from tqdm import tqdm
15
16
from unstract .api_deployments .client import APIDeploymentsClient
16
17
17
- DB_NAME = "file_processing .db"
18
+ DB_NAME = "/home/praveen/Documents/db/demo .db"
18
19
global_arguments = None
19
20
logger = logging .getLogger (__name__ )
20
21
@@ -51,6 +52,10 @@ def init_db():
51
52
time_taken REAL,
52
53
status_code INTEGER,
53
54
status_api_endpoint TEXT,
55
+ total_embedding_cost TEXT,
56
+ total_embedding_tokens INTEGER DEFAULT 0,
57
+ total_llm_cost TEXT,
58
+ total_llm_tokens INTEGER DEFAULT 0,
54
59
updated_at TEXT,
55
60
created_at TEXT
56
61
)"""
@@ -97,6 +102,39 @@ def update_db(
97
102
status_code ,
98
103
status_api_endpoint ,
99
104
):
105
+
106
+ total_embedding_cost = 0.0
107
+ total_embedding_tokens = 0
108
+ total_llm_cost = 0.0
109
+ total_llm_tokens = 0
110
+
111
+ if result is not None :
112
+ # Extract 'extraction_result' from the result
113
+ extraction_result = result .get ("extraction_result" , [])
114
+
115
+ if extraction_result :
116
+ extraction_data = extraction_result [0 ].get ("result" , "" )
117
+
118
+ # If extraction_data is a string, attempt to parse it as JSON
119
+ if isinstance (extraction_data , str ):
120
+ try :
121
+ extraction_data = json .loads (extraction_data ) if extraction_data else {}
122
+ except json .JSONDecodeError :
123
+ extraction_data = {}
124
+
125
+ # Now we can safely access metadata, embedding_llm, and extraction_llm under the 'result'
126
+ metadata = extraction_data .get ("metadata" , {})
127
+ embedding_llm = metadata .get ("embedding" , [])
128
+ extraction_llm = metadata .get ("extraction_llm" , [])
129
+
130
+ # Calculate total cost from `cost_in_dollars` in both LLM arrays, converting to float as needed
131
+ total_embedding_cost += sum (float (item .get ("cost_in_dollars" , "0" )) for item in embedding_llm )
132
+ total_llm_cost += sum (float (item .get ("cost_in_dollars" , "0" )) for item in extraction_llm )
133
+
134
+ # Calculate total tokens using `embedding_tokens` for embedding and `total_tokens` for extraction
135
+ total_embedding_tokens += sum (item .get ("embedding_tokens" , 0 ) for item in embedding_llm )
136
+ total_llm_tokens += sum (item .get ("total_tokens" , 0 ) for item in extraction_llm )
137
+
100
138
conn = sqlite3 .connect (DB_NAME )
101
139
conn .set_trace_callback (
102
140
lambda x : (
@@ -109,16 +147,20 @@ def update_db(
109
147
now = datetime .now ().isoformat ()
110
148
c .execute (
111
149
"""
112
- INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, updated_at, created_at)
113
- VALUES (?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
114
- """ ,
150
+ INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
151
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
152
+ """ ,
115
153
(
116
154
file_name ,
117
155
execution_status ,
118
156
json .dumps (result ),
119
157
time_taken ,
120
158
status_code ,
121
159
status_api_endpoint ,
160
+ total_embedding_cost ,
161
+ total_embedding_tokens ,
162
+ total_llm_cost ,
163
+ total_llm_tokens ,
122
164
now ,
123
165
file_name ,
124
166
now ,
@@ -130,7 +172,7 @@ def update_db(
130
172
131
173
# Print final summary with count of each status and average time using a single SQL query
132
174
def print_summary ():
133
- conn = sqlite3 .connect ("file_processing.db" )
175
+ conn = sqlite3 .connect (DB_NAME )
134
176
c = conn .cursor ()
135
177
136
178
# Fetch count and average time for each status
@@ -153,13 +195,13 @@ def print_summary():
153
195
154
196
155
197
def print_report ():
156
- conn = sqlite3 .connect ("file_processing.db" )
198
+ conn = sqlite3 .connect (DB_NAME )
157
199
c = conn .cursor ()
158
200
159
- # Fetch count and average time for each status
201
+ # Fetch required fields, including total_cost and total_tokens
160
202
c .execute (
161
203
"""
162
- SELECT file_name, execution_status, time_taken
204
+ SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
163
205
FROM file_status
164
206
"""
165
207
)
@@ -170,8 +212,18 @@ def print_report():
170
212
print ("\n Detailed Report:" )
171
213
if report_data :
172
214
# Tabulate the data with column headers
173
- headers = ["File Name" , "Execution Status" , "Time Elapsed (seconds)" ]
174
- print (tabulate (report_data , headers = headers , tablefmt = "pretty" ))
215
+ headers = ["File Name" , "Execution Status" , "Time Elapsed (seconds)" , "Total Embedding Cost" , "Total Embedding Tokens" , "Total LLM Cost" , "Total LLM Tokens" ]
216
+
217
+ # Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others)
218
+ formatted_data = []
219
+ for row in report_data :
220
+ formatted_row = [
221
+ textwrap .fill (str (cell ), width = 30 ) if isinstance (cell , str ) else f"{ cell :.8f} " if isinstance (cell , float ) else cell
222
+ for cell in row
223
+ ]
224
+ formatted_data .append (formatted_row )
225
+
226
+ print (tabulate (formatted_data , headers = headers , tablefmt = "pretty" ))
175
227
else :
176
228
print ("No records found in the database." )
177
229
0 commit comments