Skip to content

Commit b0e0dee

Browse files
Added support to configure DB path, added result to CSV report and minor timestamp addition to logs
1 parent ae62f8e commit b0e0dee

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
*.db
2-
.venv/
2+
*.csv
3+
.mypy_cache/
4+
.venv/
5+
.python-version

main.py

+43-34
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
from tqdm import tqdm
1717
from unstract.api_deployments.client import APIDeploymentsClient
1818

19-
DB_NAME = "file_processing.db"
20-
global_arguments = None
2119
logger = logging.getLogger(__name__)
2220

2321

@@ -29,6 +27,7 @@ class Arguments:
2927
api_timeout: int = 10
3028
poll_interval: int = 5
3129
input_folder_path: str = ""
30+
db_path: str = ""
3231
parallel_call_count: int = 5
3332
retry_failed: bool = False
3433
retry_pending: bool = False
@@ -42,8 +41,8 @@ class Arguments:
4241

4342

4443
# Initialize SQLite DB
45-
def init_db():
46-
conn = sqlite3.connect(DB_NAME)
44+
def init_db(args: Arguments):
45+
conn = sqlite3.connect(args.db_path)
4746
c = conn.cursor()
4847

4948
# Create the table if it doesn't exist
@@ -89,7 +88,7 @@ def init_db():
8988

9089
# Check if the file is already processed
9190
def skip_file_processing(file_name, args: Arguments):
92-
conn = sqlite3.connect(DB_NAME)
91+
conn = sqlite3.connect(args.db_path)
9392
c = conn.cursor()
9493
c.execute(
9594
"SELECT execution_status FROM file_status WHERE file_name = ?", (file_name,)
@@ -124,6 +123,7 @@ def update_db(
124123
time_taken,
125124
status_code,
126125
status_api_endpoint,
126+
args: Arguments
127127
):
128128

129129
total_embedding_cost = None
@@ -138,7 +138,7 @@ def update_db(
138138
if execution_status == "ERROR":
139139
error_message = extract_error_message(result)
140140

141-
conn = sqlite3.connect(DB_NAME)
141+
conn = sqlite3.connect(args.db_path)
142142
conn.set_trace_callback(
143143
lambda x: (
144144
logger.debug(f"[{file_name}] Executing statement: {x}")
@@ -232,8 +232,8 @@ def extract_error_message(result):
232232
return result.get("error", "No error message found")
233233

234234
# Print final summary with count of each status and average time using a single SQL query
235-
def print_summary():
236-
conn = sqlite3.connect(DB_NAME)
235+
def print_summary(args: Arguments):
236+
conn = sqlite3.connect(args.db_path)
237237
c = conn.cursor()
238238

239239
# Fetch count and average time for each status
@@ -255,8 +255,8 @@ def print_summary():
255255
print(f"Status '{status}': {count}")
256256

257257

258-
def print_report():
259-
conn = sqlite3.connect(DB_NAME)
258+
def print_report(args: Arguments):
259+
conn = sqlite3.connect(args.db_path)
260260
c = conn.cursor()
261261

262262
# Fetch required fields, including total_cost and total_tokens
@@ -318,36 +318,36 @@ def print_report():
318318

319319
print("\nNote: For more detailed error messages, use the CSV report argument.")
320320

321-
def export_report_to_csv(output_path):
322-
conn = sqlite3.connect(DB_NAME)
321+
def export_report_to_csv(args: Arguments):
322+
conn = sqlite3.connect(args.db_path)
323323
c = conn.cursor()
324324

325325
c.execute(
326326
"""
327-
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
327+
SELECT file_name, execution_status, result, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
328328
FROM file_status
329329
"""
330330
)
331331
report_data = c.fetchall()
332332
conn.close()
333333

334334
if not report_data:
335-
print("No data available to export.")
335+
print("No data available to export as CSV.")
336336
return
337337

338338
# Define the headers
339339
headers = [
340-
"File Name", "Execution Status", "Time Elapsed (seconds)",
340+
"File Name", "Execution Status", "Result", "Time Elapsed (seconds)",
341341
"Total Embedding Cost", "Total Embedding Tokens",
342342
"Total LLM Cost", "Total LLM Tokens", "Error Message"
343343
]
344344

345345
try:
346-
with open(output_path, 'w', newline='') as csvfile:
346+
with open(args.csv_report, 'w', newline='') as csvfile:
347347
writer = csv.writer(csvfile)
348348
writer.writerow(headers) # Write headers
349349
writer.writerows(report_data) # Write data rows
350-
print(f"CSV successfully exported to {output_path}")
350+
print(f"CSV successfully exported to '{args.csv_report}'")
351351
except Exception as e:
352352
print(f"Error exporting to CSV: {e}")
353353

@@ -357,7 +357,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
357357
status_endpoint = None
358358

359359
# If retry_pending is True, check if the status API endpoint is available
360-
conn = sqlite3.connect(DB_NAME)
360+
conn = sqlite3.connect(args.db_path)
361361
c = conn.cursor()
362362
c.execute(
363363
"SELECT status_api_endpoint FROM file_status WHERE file_name = ? AND execution_status NOT IN ('COMPLETED', 'ERROR')",
@@ -382,7 +382,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
382382

383383
# Fresh API call to process the file
384384
execution_status = "STARTING"
385-
update_db(file_path, execution_status, None, None, None, None)
385+
update_db(file_path, execution_status, None, None, None, None, args=args)
386386
response = client.structure_file(file_paths=[file_path])
387387
logger.debug(f"[{file_path}] Response of initial API call: {response}")
388388
status_endpoint = response.get(
@@ -397,6 +397,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
397397
None,
398398
status_code,
399399
status_endpoint,
400+
args=args
400401
)
401402
return status_endpoint, execution_status, response
402403

@@ -436,7 +437,7 @@ def process_file(
436437
execution_status = response.get("execution_status")
437438
status_code = response.get("status_code") # Default to 200 if not provided
438439
update_db(
439-
file_path, execution_status, None, None, status_code, status_endpoint
440+
file_path, execution_status, None, None, status_code, status_endpoint, args=args
440441
)
441442

442443
result = response
@@ -456,7 +457,7 @@ def process_file(
456457
end_time = time.time()
457458
time_taken = round(end_time - start_time, 2)
458459
update_db(
459-
file_path, execution_status, result, time_taken, status_code, status_endpoint
460+
file_path, execution_status, result, time_taken, status_code, status_endpoint, args=args
460461
)
461462
logger.info(f"[{file_path}]: Processing completed: {execution_status}")
462463

@@ -550,6 +551,19 @@ def main():
550551
default=5,
551552
help="Number of calls to be made in parallel.",
552553
)
554+
parser.add_argument(
555+
"--db_path",
556+
dest="db_path",
557+
type=str,
558+
default="file_processing.db",
559+
help="Path where the SQlite DB file is stored, defaults to './file_processing.db'",
560+
)
561+
parser.add_argument(
562+
'--csv_report',
563+
dest="csv_report",
564+
type=str,
565+
help='Path to export the detailed report as a CSV file',
566+
)
553567
parser.add_argument(
554568
"--retry_failed",
555569
dest="retry_failed",
@@ -588,50 +602,45 @@ def main():
588602
action="store_true",
589603
help="Print a detailed report of all file processed.",
590604
)
591-
592605
parser.add_argument(
593606
"--exclude_metadata",
594607
dest="include_metadata",
595608
action="store_false",
596609
help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.",
597610
)
598-
599611
parser.add_argument(
600612
"--no_verify",
601613
dest="verify",
602614
action="store_false",
603615
help="Disable SSL certificate verification.",
604616
)
605617

606-
parser.add_argument(
607-
'--csv_report',
608-
dest="csv_report",
609-
type=str,
610-
help='Path to export the detailed report as a CSV file',
611-
)
612-
613618
args = Arguments(**vars(parser.parse_args()))
614619

615620
ch = logging.StreamHandler(sys.stdout)
616621
ch.setLevel(args.log_level)
622+
formatter = logging.Formatter(
623+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
624+
)
625+
ch.setFormatter(formatter)
617626
logging.basicConfig(level=args.log_level, handlers=[ch])
618627

619628
logger.warning(f"Running with params: {args}")
620629

621-
init_db() # Initialize DB
630+
init_db(args=args) # Initialize DB
622631

623632
load_folder(args=args)
624633

625-
print_summary() # Print summary at the end
634+
print_summary(args=args) # Print summary at the end
626635
if args.print_report:
627-
print_report()
636+
print_report(args=args)
628637
logger.warning(
629638
"Elapsed time calculation of a file which was resumed"
630639
" from pending state will not be correct"
631640
)
632641

633642
if args.csv_report:
634-
export_report_to_csv(args.csv_report)
643+
export_report_to_csv(args=args)
635644

636645

637646
if __name__ == "__main__":

0 commit comments

Comments
 (0)