Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
Expand Down
24 changes: 9 additions & 15 deletions graphai/celery/image/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
extract_slide_text_task,
extract_slide_text_callback_task,
convert_pdf_to_pages_task,
fanout_pdf_ocr_task,
extract_multi_image_text_task,
collect_multi_image_ocr_task
)
Expand Down Expand Up @@ -153,24 +154,17 @@ def ocr_job(
# OCR computation job
#####################
if is_pdf(token):
n_parallel = 8
task_list = [
convert_pdf_to_pages_task.s(token),
group(
extract_multi_image_text_task.s(
i,
n_parallel,
method,
google_api_token,
openai_api_token,
gemini_api_token,
rcp_api_token,
model_type,
enable_tikz,
)
for i in range(n_parallel)
fanout_pdf_ocr_task.s(
method,
google_api_token,
openai_api_token,
gemini_api_token,
rcp_api_token,
model_type,
enable_tikz,
),
collect_multi_image_ocr_task.s()
]
else:
task_list = [
Expand Down
53 changes: 45 additions & 8 deletions graphai/celery/image/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from celery import shared_task
from celery import shared_task, group, chord

from graphai.core.image.image import (
cache_lookup_retrieve_image_from_url,
Expand Down Expand Up @@ -146,6 +146,47 @@ def convert_pdf_to_pages_task(self, token):
return break_pdf_into_images(token, self.file_manager)


@shared_task(
    bind=True,
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_kwargs={"max_retries": 2},
    name="image.fanout_pdf_ocr_task",
    ignore_result=False,
)
def fanout_pdf_ocr_task(
    self,
    pages,
    method,
    google_api_token=None,
    openai_api_token=None,
    gemini_api_token=None,
    rcp_api_token=None,
    model_type=None,
    enable_tikz=False,
):
    """Fan a PDF's pages out into per-page OCR tasks joined by a collector.

    Receives the page list produced upstream, builds one
    ``extract_multi_image_text_task`` signature per page, and substitutes
    itself with a chord of those tasks whose callback is
    ``collect_multi_image_ocr_task``.
    """
    # One OCR signature per page of the PDF.
    per_page_signatures = [
        extract_multi_image_text_task.s(
            page,
            method,
            google_api_token,
            openai_api_token,
            gemini_api_token,
            rcp_api_token,
            model_type,
            enable_tikz,
        )
        for page in pages
    ]

    # Chord: run all page OCRs in parallel, then aggregate their results.
    page_chord = chord(group(per_page_signatures), collect_multi_image_ocr_task.s())

    # self.replace swaps this task for the chord, so the enclosing chain
    # blocks until every page has been OCR'd and collected.
    raise self.replace(page_chord)


@shared_task(
bind=True,
autoretry_for=(Exception,),
Expand All @@ -156,9 +197,7 @@ def convert_pdf_to_pages_task(self, token):
)
def extract_multi_image_text_task(
self,
page_and_filename_list,
i,
n,
page_and_filename,
method="google",
google_api_token=None,
openai_api_token=None,
Expand All @@ -167,11 +206,9 @@ def extract_multi_image_text_task(
model_type=None,
enable_tikz=False,
):
print(f'Starting {extract_multi_image_text_task} task for page_and_filename_list {page_and_filename_list}, i {i} and n {n}')
print(f'Starting {extract_multi_image_text_task} task for page_and_filename {page_and_filename}')
return extract_multi_image_text(
page_and_filename_list,
i,
n,
page_and_filename,
method,
google_api_token,
openai_api_token,
Expand Down
91 changes: 32 additions & 59 deletions graphai/core/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,30 +210,19 @@ def perform_ocr(
model_type=None,
enable_tikz=False,
):
ocr_colnames = get_ocr_colnames(method)

results = None
language = None
text = None

if method == 'tesseract':
res = perform_tesseract_ocr(file_path, language='enfr')
text = perform_tesseract_ocr(file_path, language='enfr')

if res:
language = detect_text_language(res)
results = [{'method': ocr_colnames[0], 'text': res}]
elif method == 'google' and google_api_token:
ocr_model = GoogleOCRModel(google_api_token)
ocr_model.establish_connection()
res1, res2 = ocr_model.perform_ocr(file_path)
text1, text2 = ocr_model.perform_ocr(file_path)

# Since DTD usually performs better, method #1 is our point of reference for langdetect
text = text1

if res1:
# Since DTD usually performs better, method #1 is our point of reference for langdetect
language = detect_text_language(res1)
res_list = [res1]
results = [
{'method': ocr_colnames[i], 'text': res_list[i]}
for i in range(len(res_list))
]
else:
ocr_model = None
if method == 'openai' and openai_api_token:
Expand All @@ -245,17 +234,14 @@ def perform_ocr(

if ocr_model:
ocr_model.establish_connection()
res = ocr_model.perform_ocr(
file_path, model_type=model_type, enable_tikz=enable_tikz
)
text = ocr_model.perform_ocr(file_path, model_type=model_type, enable_tikz=enable_tikz)

if res:
language = detect_text_language(res)
results = [{'method': ocr_colnames[0], 'text': res}]
if not text:
text = ''

return {
'results': results,
'language': language,
'results': [{'method': get_ocr_colnames(method)[0], 'text': text}],
'language': detect_text_language(text),
}


Expand Down Expand Up @@ -296,9 +282,7 @@ def extract_slide_text(


def extract_multi_image_text(
page_and_filename_list,
i,
n,
page_and_filename,
method="google",
google_api_token=None,
openai_api_token=None,
Expand All @@ -307,44 +291,33 @@ def extract_multi_image_text(
model_type=None,
enable_tikz=False,
):
# Extract subset of pages to process
n_pages = len(page_and_filename_list)
start_index = int(i / n * n_pages)
end_index = int((i + 1) / n * n_pages)
pages_to_handle = page_and_filename_list[start_index: end_index]

# Perform OCR on subset of pages
results = list()
for page in pages_to_handle:
results.append(
perform_ocr(
page["filename"],
method,
google_api_token,
openai_api_token,
gemini_api_token,
rcp_api_token,
model_type,
enable_tikz,
)
)
# Perform OCR on page
result = perform_ocr(
page_and_filename["filename"],
method,
google_api_token,
openai_api_token,
gemini_api_token,
rcp_api_token,
model_type,
enable_tikz,
)

print(f"Performed OCR on page {page_and_filename['page']}. Result: {result}")

# Build result and return it
return {
'results': [
{
'page': pages_to_handle[i]['page'],
'content': results[i]['results'][0]['text']
}
for i in range(len(results))
],
'language': get_most_common_element([result['language'] for result in results]),
'method': get_most_common_element([result['results'][0]['method'] for result in results])
'result': {
'page': page_and_filename['page'],
'content': result['results'][0]['text']
},
'language': result['language'],
'method': result['results'][0]['method'],
}


def collect_multi_image_ocr(results):
all_results = list(chain.from_iterable(result['results'] for result in results))
all_results = [result['result'] for result in results]
language = get_most_common_element([result['language'] for result in results])
method = get_most_common_element([result['method'] for result in results])
return {
Expand Down
1 change: 0 additions & 1 deletion graphai/core/image/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ def perform_ocr(self, input_filename_with_path, model_type=None, **kwargs):
response = self.model.chat.completions.create(model=model_type, messages=messages, response_format={"type": "json_object"})
print(f'Got {response}')
content = response.choices[0].message.content.strip()
print(f'Got {content}')

# Strip thinking tokens
thinking_tag = '</think>'
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
"Operating System :: OS Independent"
]
dependencies = [
"loguru"
"loguru",
"numpy",
"scipy",
"pandas",
Expand Down
Loading