Skip to content

Commit 208029d

Browse files
committed
added --workers arg
1 parent: f66d48d · commit: 208029d

1 file changed

Lines changed: 43 additions & 14 deletions

File tree

scripts/full_pipeline.py

Lines changed: 43 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,9 @@
2525
import os
2626
import json
2727
import argparse
28+
from concurrent.futures import ProcessPoolExecutor, as_completed
2829
from pathlib import Path
29-
from typing import Dict
30+
from typing import Dict, Tuple
3031
import subprocess
3132
import sys
3233

@@ -49,6 +50,19 @@
4950
from src.preprocessing.pdf_text_extraction import extract_text_from_pdf_bytes
5051

5152

53+
def _extract_local_pdf(args: Tuple[Path, str]) -> Tuple[str, str, "str | None"]:
    """Worker: read one local PDF and return ``(txt_name, label, text | None)``.

    The work item is passed as a single tuple so the function can be handed
    as-is to an executor's ``submit`` as well as called directly in a loop.

    Args:
        args: ``(pdf_path, label)`` -- the PDF file to read and the label
            that is passed through unchanged to the result.

    Returns:
        A ``(txt_name, label, text)`` tuple. ``txt_name`` is the PDF's stem
        with a ``.txt`` suffix. ``text`` is whatever
        ``extract_text_from_pdf_bytes`` returned, or ``None`` if reading the
        file or extracting its text raised an exception.
    """
    # NOTE: the union in the return annotation is written as a string so it
    # is not evaluated at definition time; a bare ``str | None`` inside
    # ``Tuple[...]`` raises TypeError on Python versions before 3.10.
    pdf_path, label = args
    txt_name = f"{pdf_path.stem}.txt"
    try:
        with open(pdf_path, "rb") as f:
            pdf_bytes = f.read()
        text = extract_text_from_pdf_bytes(pdf_bytes)
    except Exception as e:
        # Deliberately broad: a single unreadable/corrupt PDF is reported and
        # signalled with ``None`` instead of aborting the rest of the batch.
        print(f"Error processing {pdf_path.name}: {e}")
        return (txt_name, label, None)
    return (txt_name, label, text)
64+
65+
5266
def run(cmd):
5367
print(f"$ {' '.join(cmd)}")
5468
r = subprocess.run(cmd)
@@ -97,7 +111,7 @@ def process_api_mode():
97111
print(f"Wrote {len(labels)} labeled text files.")
98112

99113

100-
def process_local_mode(data_path: Path):
114+
def process_local_mode(data_path: Path, workers: int = 1):
101115
"""Process PDFs from local directory."""
102116
if not data_path.exists():
103117
raise RuntimeError(f"Data path does not exist: {data_path}")
@@ -114,23 +128,31 @@ def process_local_mode(data_path: Path):
114128
out_dir.mkdir(parents=True, exist_ok=True)
115129
labels: Dict[str, str] = {}
116130

131+
# Build work items: (pdf_path, label)
132+
work_items = []
117133
for folder, label in [(useful_dir, "useful"), (not_useful_dir, "not-useful")]:
118134
pdf_files = list(folder.glob("*.pdf"))
119135
print(f"Found {len(pdf_files)} PDFs in local folder '{label}'")
120-
121136
for pdf_path in pdf_files:
122-
try:
123-
with open(pdf_path, "rb") as f:
124-
pdf_bytes = f.read()
125-
text = extract_text_from_pdf_bytes(pdf_bytes)
126-
stem = pdf_path.stem
127-
txt_name = f"{stem}.txt"
137+
work_items.append((pdf_path, label))
138+
139+
if workers > 1 and len(work_items) > 1:
140+
print(f"[INFO] Using {workers} worker processes for PDF extraction.")
141+
with ProcessPoolExecutor(max_workers=workers) as executor:
142+
futures = {executor.submit(_extract_local_pdf, item): item for item in work_items}
143+
for future in as_completed(futures):
144+
txt_name, label, text = future.result()
145+
if text is not None:
146+
(out_dir / txt_name).write_text(text, encoding="utf-8")
147+
labels[txt_name] = label
148+
print(f"Processed {txt_name}")
149+
else:
150+
for item in work_items:
151+
txt_name, label, text = _extract_local_pdf(item)
152+
if text is not None:
128153
(out_dir / txt_name).write_text(text, encoding="utf-8")
129154
labels[txt_name] = label
130-
print(f"Processed {pdf_path.name}")
131-
except Exception as e:
132-
print(f"Error processing {pdf_path.name}: {e}")
133-
continue
155+
print(f"Processed {txt_name}")
134156

135157
write_labels(labels, Path("data/labels.json"))
136158
print(f"Wrote {len(labels)} labeled text files.")
@@ -152,11 +174,18 @@ def main():
152174
group.add_argument("--api", action="store_true", help="Use API mode to download PDFs from Google Drive")
153175
group.add_argument("--local", type=Path, metavar="PATH", help="Use local mode with PDFs from specified directory (should contain 'useful' and 'not-useful' subfolders)")
154176

177+
parser.add_argument(
178+
"--workers",
179+
type=int,
180+
default=1,
181+
help="Number of parallel worker processes for PDF extraction (default: 1 = sequential).",
182+
)
183+
155184
args = parser.parse_args()
156185

157186
if args.local:
158187
print(f"Running in LOCAL mode with data path: {args.local}")
159-
process_local_mode(args.local)
188+
process_local_mode(args.local, workers=args.workers)
160189
else: # args.api
161190
print("Running in API mode (Google Drive)")
162191
process_api_mode()

0 commit comments

Comments
 (0)