16
16
from tqdm import tqdm
17
17
from unstract .api_deployments .client import APIDeploymentsClient
18
18
19
- DB_NAME = "file_processing.db"
20
- global_arguments = None
21
19
logger = logging .getLogger (__name__ )
22
20
23
21
@@ -29,6 +27,7 @@ class Arguments:
29
27
api_timeout : int = 10
30
28
poll_interval : int = 5
31
29
input_folder_path : str = ""
30
+ db_path : str = ""
32
31
parallel_call_count : int = 5
33
32
retry_failed : bool = False
34
33
retry_pending : bool = False
@@ -42,8 +41,8 @@ class Arguments:
42
41
43
42
44
43
# Initialize SQLite DB
45
- def init_db ():
46
- conn = sqlite3 .connect (DB_NAME )
44
+ def init_db (args : Arguments ):
45
+ conn = sqlite3 .connect (args . db_path )
47
46
c = conn .cursor ()
48
47
49
48
# Create the table if it doesn't exist
@@ -89,7 +88,7 @@ def init_db():
89
88
90
89
# Check if the file is already processed
91
90
def skip_file_processing (file_name , args : Arguments ):
92
- conn = sqlite3 .connect (DB_NAME )
91
+ conn = sqlite3 .connect (args . db_path )
93
92
c = conn .cursor ()
94
93
c .execute (
95
94
"SELECT execution_status FROM file_status WHERE file_name = ?" , (file_name ,)
@@ -124,6 +123,7 @@ def update_db(
124
123
time_taken ,
125
124
status_code ,
126
125
status_api_endpoint ,
126
+ args : Arguments
127
127
):
128
128
129
129
total_embedding_cost = None
@@ -138,7 +138,7 @@ def update_db(
138
138
if execution_status == "ERROR" :
139
139
error_message = extract_error_message (result )
140
140
141
- conn = sqlite3 .connect (DB_NAME )
141
+ conn = sqlite3 .connect (args . db_path )
142
142
conn .set_trace_callback (
143
143
lambda x : (
144
144
logger .debug (f"[{ file_name } ] Executing statement: { x } " )
@@ -232,8 +232,8 @@ def extract_error_message(result):
232
232
return result .get ("error" , "No error message found" )
233
233
234
234
# Print final summary with count of each status and average time using a single SQL query
235
- def print_summary ():
236
- conn = sqlite3 .connect (DB_NAME )
235
+ def print_summary (args : Arguments ):
236
+ conn = sqlite3 .connect (args . db_path )
237
237
c = conn .cursor ()
238
238
239
239
# Fetch count and average time for each status
@@ -255,8 +255,8 @@ def print_summary():
255
255
print (f"Status '{ status } ': { count } " )
256
256
257
257
258
- def print_report ():
259
- conn = sqlite3 .connect (DB_NAME )
258
+ def print_report (args : Arguments ):
259
+ conn = sqlite3 .connect (args . db_path )
260
260
c = conn .cursor ()
261
261
262
262
# Fetch required fields, including total_cost and total_tokens
@@ -318,36 +318,36 @@ def print_report():
318
318
319
319
print ("\n Note: For more detailed error messages, use the CSV report argument." )
320
320
321
- def export_report_to_csv (output_path ):
322
- conn = sqlite3 .connect (DB_NAME )
321
+ def export_report_to_csv (args : Arguments ):
322
+ conn = sqlite3 .connect (args . db_path )
323
323
c = conn .cursor ()
324
324
325
325
c .execute (
326
326
"""
327
- SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
327
+ SELECT file_name, execution_status, result, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
328
328
FROM file_status
329
329
"""
330
330
)
331
331
report_data = c .fetchall ()
332
332
conn .close ()
333
333
334
334
if not report_data :
335
- print ("No data available to export." )
335
+ print ("No data available to export as CSV ." )
336
336
return
337
337
338
338
# Define the headers
339
339
headers = [
340
- "File Name" , "Execution Status" , "Time Elapsed (seconds)" ,
340
+ "File Name" , "Execution Status" , "Result" , " Time Elapsed (seconds)" ,
341
341
"Total Embedding Cost" , "Total Embedding Tokens" ,
342
342
"Total LLM Cost" , "Total LLM Tokens" , "Error Message"
343
343
]
344
344
345
345
try :
346
- with open (output_path , 'w' , newline = '' ) as csvfile :
346
+ with open (args . csv_report , 'w' , newline = '' ) as csvfile :
347
347
writer = csv .writer (csvfile )
348
348
writer .writerow (headers ) # Write headers
349
349
writer .writerows (report_data ) # Write data rows
350
- print (f"CSV successfully exported to { output_path } " )
350
+ print (f"CSV successfully exported to ' { args . csv_report } ' " )
351
351
except Exception as e :
352
352
print (f"Error exporting to CSV: { e } " )
353
353
@@ -357,7 +357,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
357
357
status_endpoint = None
358
358
359
359
# If retry_pending is True, check if the status API endpoint is available
360
- conn = sqlite3 .connect (DB_NAME )
360
+ conn = sqlite3 .connect (args . db_path )
361
361
c = conn .cursor ()
362
362
c .execute (
363
363
"SELECT status_api_endpoint FROM file_status WHERE file_name = ? AND execution_status NOT IN ('COMPLETED', 'ERROR')" ,
@@ -382,7 +382,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
382
382
383
383
# Fresh API call to process the file
384
384
execution_status = "STARTING"
385
- update_db (file_path , execution_status , None , None , None , None )
385
+ update_db (file_path , execution_status , None , None , None , None , args = args )
386
386
response = client .structure_file (file_paths = [file_path ])
387
387
logger .debug (f"[{ file_path } ] Response of initial API call: { response } " )
388
388
status_endpoint = response .get (
@@ -397,6 +397,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
397
397
None ,
398
398
status_code ,
399
399
status_endpoint ,
400
+ args = args
400
401
)
401
402
return status_endpoint , execution_status , response
402
403
@@ -436,7 +437,7 @@ def process_file(
436
437
execution_status = response .get ("execution_status" )
437
438
status_code = response .get ("status_code" ) # Default to 200 if not provided
438
439
update_db (
439
- file_path , execution_status , None , None , status_code , status_endpoint
440
+ file_path , execution_status , None , None , status_code , status_endpoint , args = args
440
441
)
441
442
442
443
result = response
@@ -456,7 +457,7 @@ def process_file(
456
457
end_time = time .time ()
457
458
time_taken = round (end_time - start_time , 2 )
458
459
update_db (
459
- file_path , execution_status , result , time_taken , status_code , status_endpoint
460
+ file_path , execution_status , result , time_taken , status_code , status_endpoint , args = args
460
461
)
461
462
logger .info (f"[{ file_path } ]: Processing completed: { execution_status } " )
462
463
@@ -550,6 +551,19 @@ def main():
550
551
default = 5 ,
551
552
help = "Number of calls to be made in parallel." ,
552
553
)
554
+ parser .add_argument (
555
+ "--db_path" ,
556
+ dest = "db_path" ,
557
+ type = str ,
558
+ default = "file_processing.db" ,
559
+ help = "Path where the SQlite DB file is stored, defaults to './file_processing.db'" ,
560
+ )
561
+ parser .add_argument (
562
+ '--csv_report' ,
563
+ dest = "csv_report" ,
564
+ type = str ,
565
+ help = 'Path to export the detailed report as a CSV file' ,
566
+ )
553
567
parser .add_argument (
554
568
"--retry_failed" ,
555
569
dest = "retry_failed" ,
@@ -588,50 +602,45 @@ def main():
588
602
action = "store_true" ,
589
603
help = "Print a detailed report of all file processed." ,
590
604
)
591
-
592
605
parser .add_argument (
593
606
"--exclude_metadata" ,
594
607
dest = "include_metadata" ,
595
608
action = "store_false" ,
596
609
help = "Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file." ,
597
610
)
598
-
599
611
parser .add_argument (
600
612
"--no_verify" ,
601
613
dest = "verify" ,
602
614
action = "store_false" ,
603
615
help = "Disable SSL certificate verification." ,
604
616
)
605
617
606
- parser .add_argument (
607
- '--csv_report' ,
608
- dest = "csv_report" ,
609
- type = str ,
610
- help = 'Path to export the detailed report as a CSV file' ,
611
- )
612
-
613
618
args = Arguments (** vars (parser .parse_args ()))
614
619
615
620
ch = logging .StreamHandler (sys .stdout )
616
621
ch .setLevel (args .log_level )
622
+ formatter = logging .Formatter (
623
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
624
+ )
625
+ ch .setFormatter (formatter )
617
626
logging .basicConfig (level = args .log_level , handlers = [ch ])
618
627
619
628
logger .warning (f"Running with params: { args } " )
620
629
621
- init_db () # Initialize DB
630
+ init_db (args = args ) # Initialize DB
622
631
623
632
load_folder (args = args )
624
633
625
- print_summary () # Print summary at the end
634
+ print_summary (args = args ) # Print summary at the end
626
635
if args .print_report :
627
- print_report ()
636
+ print_report (args = args )
628
637
logger .warning (
629
638
"Elapsed time calculation of a file which was resumed"
630
639
" from pending state will not be correct"
631
640
)
632
641
633
642
if args .csv_report :
634
- export_report_to_csv (args . csv_report )
643
+ export_report_to_csv (args = args )
635
644
636
645
637
646
if __name__ == "__main__" :
0 commit comments