diff --git a/.gitignore b/.gitignore index 14f17fe9..2ac56777 100644 --- a/.gitignore +++ b/.gitignore @@ -176,7 +176,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ # Abstra # Abstra is an AI-powered process automation framework. diff --git a/graphai/celery/image/jobs.py b/graphai/celery/image/jobs.py index cd1c7b76..dfecfad6 100644 --- a/graphai/celery/image/jobs.py +++ b/graphai/celery/image/jobs.py @@ -11,6 +11,7 @@ extract_slide_text_task, extract_slide_text_callback_task, convert_pdf_to_pages_task, + fanout_pdf_ocr_task, extract_multi_image_text_task, collect_multi_image_ocr_task ) @@ -153,24 +154,17 @@ def ocr_job( # OCR computation job ##################### if is_pdf(token): - n_parallel = 8 task_list = [ convert_pdf_to_pages_task.s(token), - group( - extract_multi_image_text_task.s( - i, - n_parallel, - method, - google_api_token, - openai_api_token, - gemini_api_token, - rcp_api_token, - model_type, - enable_tikz, - ) - for i in range(n_parallel) + fanout_pdf_ocr_task.s( + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, ), - collect_multi_image_ocr_task.s() ] else: task_list = [ diff --git a/graphai/celery/image/tasks.py b/graphai/celery/image/tasks.py index f79e3913..7a12350b 100644 --- a/graphai/celery/image/tasks.py +++ b/graphai/celery/image/tasks.py @@ -1,4 +1,4 @@ -from celery import shared_task +from celery import shared_task, group, chord from graphai.core.image.image import ( cache_lookup_retrieve_image_from_url, @@ -146,6 +146,47 @@ def convert_pdf_to_pages_task(self, token): return break_pdf_into_images(token, self.file_manager) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.fanout_pdf_ocr_task", + ignore_result=False, +) +def fanout_pdf_ocr_task( + self, + pages, + method, + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + # Build one OCR task per page + header = group( + extract_multi_image_text_task.s( + page, + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + for page in pages + ) + + # When all pages are OCR'd, collect results + callback = collect_multi_image_ocr_task.s() + + # Replace this task with the chord so the outer chain waits properly + raise self.replace(chord(header, callback)) + + @shared_task( bind=True, autoretry_for=(Exception,), @@ -156,9 +197,7 @@ def convert_pdf_to_pages_task(self, token): ) def extract_multi_image_text_task( self, - page_and_filename_list, - i, - n, + page_and_filename, method="google", google_api_token=None, openai_api_token=None, @@ -167,11 +206,9 @@ def extract_multi_image_text_task( model_type=None, enable_tikz=False, ): - print(f'Starting {extract_multi_image_text_task} task for page_and_filename_list {page_and_filename_list}, i {i} and n {n}') + print(f'Starting {extract_multi_image_text_task} task for page_and_filename {page_and_filename}') return extract_multi_image_text( - page_and_filename_list, - i, - n, + page_and_filename, method, google_api_token, openai_api_token, diff --git a/graphai/core/image/image.py b/graphai/core/image/image.py index ddac7abc..80b6bee6 100644 --- a/graphai/core/image/image.py +++ b/graphai/core/image/image.py @@ -210,30 +210,19 @@ def perform_ocr( model_type=None, enable_tikz=False, ): - ocr_colnames = get_ocr_colnames(method) - - results = None - language = None + text = None if method == 'tesseract': - res = perform_tesseract_ocr(file_path, language='enfr') + text = perform_tesseract_ocr(file_path, language='enfr') - if res: - language = detect_text_language(res) - results = [{'method': ocr_colnames[0], 'text': res}] elif method == 'google' and google_api_token: ocr_model = GoogleOCRModel(google_api_token) ocr_model.establish_connection() - res1, res2 = ocr_model.perform_ocr(file_path) + text1, text2 = ocr_model.perform_ocr(file_path) + + # Since DTD usually performs better, method #1 is our point of reference for langdetect + text = text1 - if res1: - # Since DTD usually performs better, method #1 is our point of reference for langdetect - language = detect_text_language(res1) - res_list = [res1] - results = [ - {'method': ocr_colnames[i], 'text': res_list[i]} - for i in range(len(res_list)) - ] else: ocr_model = None if method == 'openai' and openai_api_token: @@ -245,17 +234,14 @@ def perform_ocr( if ocr_model: ocr_model.establish_connection() - res = ocr_model.perform_ocr( - file_path, model_type=model_type, enable_tikz=enable_tikz - ) + text = ocr_model.perform_ocr(file_path, model_type=model_type, enable_tikz=enable_tikz) - if res: - language = detect_text_language(res) - results = [{'method': ocr_colnames[0], 'text': res}] + if not text: + text = '' return { - 'results': results, - 'language': language, + 'results': [{'method': get_ocr_colnames(method)[0], 'text': text}], + 'language': detect_text_language(text), } @@ -296,9 +282,7 @@ def extract_slide_text( def extract_multi_image_text( - page_and_filename_list, - i, - n, + page_and_filename, method="google", google_api_token=None, openai_api_token=None, @@ -307,44 +291,33 @@ def extract_multi_image_text( model_type=None, enable_tikz=False, ): - # Extract subset of pages to process - n_pages = len(page_and_filename_list) - start_index = int(i / n * n_pages) - end_index = int((i + 1) / n * n_pages) - pages_to_handle = page_and_filename_list[start_index: end_index] - - # Perform OCR on subset of pages - results = list() - for page in pages_to_handle: - results.append( - perform_ocr( - page["filename"], - method, - google_api_token, - openai_api_token, - gemini_api_token, - rcp_api_token, - model_type, - enable_tikz, - ) - ) + # Perform OCR on page + result = perform_ocr( + page_and_filename["filename"], + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + + print(f"Performed OCR on page {page_and_filename['page']}. Result: {result}") # Build result and return it return { - 'results': [ - { - 'page': pages_to_handle[i]['page'], - 'content': results[i]['results'][0]['text'] - } - for i in range(len(results)) - ], - 'language': get_most_common_element([result['language'] for result in results]), - 'method': get_most_common_element([result['results'][0]['method'] for result in results]) + 'result': { + 'page': page_and_filename['page'], + 'content': result['results'][0]['text'] + }, + 'language': result['language'], + 'method': result['results'][0]['method'], } def collect_multi_image_ocr(results): - all_results = list(chain.from_iterable(result['results'] for result in results)) + all_results = [result['result'] for result in results] language = get_most_common_element([result['language'] for result in results]) method = get_most_common_element([result['method'] for result in results]) return { diff --git a/graphai/core/image/ocr.py b/graphai/core/image/ocr.py index 9facde0f..e2f2d960 100644 --- a/graphai/core/image/ocr.py +++ b/graphai/core/image/ocr.py @@ -342,7 +342,6 @@ def perform_ocr(self, input_filename_with_path, model_type=None, **kwargs): response = self.model.chat.completions.create(model=model_type, messages=messages, response_format={"type": "json_object"}) print(f'Got {response}') content = response.choices[0].message.content.strip() - print(f'Got {content}') # Strip thinking tokens thinking_tag = '' diff --git a/pyproject.toml b/pyproject.toml index f800d77c..90390c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent" ] dependencies = [ - "loguru" + "loguru", "numpy", "scipy", "pandas",