diff --git a/.gitignore b/.gitignore index 14f17fe9..2ac56777 100644 --- a/.gitignore +++ b/.gitignore @@ -176,7 +176,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ # Abstra # Abstra is an AI-powered process automation framework. diff --git a/docker-compose.yml b/docker-compose.yml index af1a9d1b..8ffb8b1e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ services: RABBITMQ_DEFAULT_PASS: ${AI_RABBITMQ_DEV_CELERY_PASS} volumes: - rabbitmq_data:/var/lib/rabbitmq + - ./rabbitmq/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf:ro # Redis with AOF persistence redis: @@ -99,7 +100,7 @@ services: network_mode: "host" volumes: - ./nginx.graphai-https.conf:/etc/nginx/conf.d/default.conf:ro - - ./.certs:/etc/nginx/certs:ro + - /home/dockerhost/graphcert.cede-apps.ch:/etc/nginx/certs:ro restart: unless-stopped networks: diff --git a/graphai/api/image/router.py b/graphai/api/image/router.py index 71e2ebb2..b384ee18 100644 --- a/graphai/api/image/router.py +++ b/graphai/api/image/router.py @@ -129,12 +129,24 @@ async def extract_text(data: ExtractTextRequest): method = data.method force = data.force no_cache = data.no_cache - api_token = data.google_api_token - openai_token = data.openai_api_token - gemini_token = data.gemini_api_token + google_api_token = data.google_api_token + openai_api_token = data.openai_api_token + gemini_api_token = data.gemini_api_token + rcp_api_token = data.rcp_api_token model_type = data.model_type enable_tikz = data.enable_tikz - task_id = ocr_job(token, force, no_cache, method, api_token, openai_token, gemini_token, model_type, enable_tikz) + task_id = ocr_job( + token, + force=force, + no_cache=no_cache, + method=method, + google_api_token=google_api_token, + 
openai_api_token=openai_api_token, + gemini_api_token=gemini_api_token, + rcp_api_token=rcp_api_token, + model_type=model_type, + enable_tikz=enable_tikz, + ) return {'task_id': task_id} diff --git a/graphai/api/image/schemas.py b/graphai/api/image/schemas.py index c15664dd..a1ae6e3d 100644 --- a/graphai/api/image/schemas.py +++ b/graphai/api/image/schemas.py @@ -166,10 +166,9 @@ class ExtractTextRequest(BaseModel): description="The token that identifies the requested file" ) - method: Literal['google', 'tesseract', 'openai', 'gemini'] = Field( + method: Literal['tesseract', 'google', 'openai', 'gemini', 'rcp'] = Field( title="Method", - description="OCR method. Available methods are 'google' (default), 'openai', 'gemini'," - "and 'tesseract' (not recommended)", + description="OCR method. Available methods are 'tesseract' (not recommended), 'google' (default), 'openai', 'gemini' and 'rcp'", default="google" ) @@ -187,36 +186,37 @@ class ExtractTextRequest(BaseModel): google_api_token: Union[str, None] = Field( title="Google API token", - description="Token that authenticates the user on the Google OCR API." - "Without a valid token, Google OCR will fail. Not required for Tesseract, OpenAI, or Gemini.", + description="Token that authenticates the user on the Google OCR API. Without a valid token, Google OCR will fail. Only required for method 'google'.", default=None ) openai_api_token: Union[str, None] = Field( title="OpenAI API token", - description="Token that authenticates the user on the OpenAI API." - "Without a valid token, OpenAI OCR will fail. Not required for Tesseract, Google, or Gemini.", + description="Token that authenticates the user on the OpenAI API. Without a valid token, OpenAI OCR will fail. Only required for method 'openai'.", default=None ) gemini_api_token: Union[str, None] = Field( title="Gemini API token", - description="Token that authenticates the user on the Gemini API." - "Without a valid token, Gemini OCR will fail. 
Not required for Tesseract, Google, or OpenAI.", + description="Token that authenticates the user on the Gemini API. Without a valid token, Gemini OCR will fail. Only required for method 'gemini'.", + default=None + ) + + rcp_api_token: Union[str, None] = Field( + title="RCP API token", + description="Token that authenticates the user on the RCP platform. Without a valid token, RCP OCR will fail. Only required for method 'rcp'.", default=None ) model_type: Union[str, None] = Field( title="Model type", - description="For OpenAI and Gemini options, allows the user to specify the model that they want to use. " - "Do not specify this option unless you know exactly what you are doing.", + description="For LLM-based options, allows the user to specify the model that they want to use. Do not specify this option unless you know exactly what you are doing.", default=None ) enable_tikz: bool = Field( title="Enable TikZ", - description="For PDF OCR, if True, attempts to extract any figures as valid TikZ. If False, " - "replaces the figures with an alt text describing them instead.", + description="For PDF OCR, if True, attempts to extract any figures as valid TikZ. 
If False, replaces the figures with an alt text describing them instead.", default=False ) diff --git a/graphai/celery/image/jobs.py b/graphai/celery/image/jobs.py index 56243a01..dfecfad6 100644 --- a/graphai/celery/image/jobs.py +++ b/graphai/celery/image/jobs.py @@ -11,6 +11,7 @@ extract_slide_text_task, extract_slide_text_callback_task, convert_pdf_to_pages_task, + fanout_pdf_ocr_task, extract_multi_image_text_task, collect_multi_image_ocr_task ) @@ -24,39 +25,77 @@ def retrieve_image_from_url_job(url, force=False, no_cache=False): + ################## + # Cache lookup + ################## if not force: - direct_lookup_task_id = direct_lookup_generic_job(cache_lookup_retrieve_image_from_url_task, url, - False, DEFAULT_TIMEOUT) + direct_lookup_task_id = direct_lookup_generic_job( + cache_lookup_retrieve_image_from_url_task, + url, + False, + DEFAULT_TIMEOUT + ) if direct_lookup_task_id is not None: return direct_lookup_task_id + ################## + # Retrieve image + ################## # First retrieve the file, and then do the database callback - task_list = [retrieve_image_from_url_task.s(url, None), - retrieve_image_from_url_callback_task.s(url)] - if not no_cache: - task_list += get_slide_fingerprint_chain_list(None, None, ignore_fp_results=True) - else: + task_list = [ + retrieve_image_from_url_task.s(url, None), + retrieve_image_from_url_callback_task.s(url), + ] + + if no_cache: task_list += [add_token_status_to_single_image_results_callback_task.s()] + else: + task_list += get_slide_fingerprint_chain_list( + token=None, + origin_token=None, + ignore_fp_results=True, + ) + task = chain(task_list) task = task.apply_async(priority=2) return task.id -def upload_image_from_file_job(contents, file_extension, origin, origin_info, force=False, no_cache=False): +def upload_image_from_file_job( + contents, + file_extension, + origin, + origin_info, + force=False, + no_cache=False, +): effective_url = create_origin_token_using_info(origin, origin_info) + if not 
force: - direct_lookup_task_id = direct_lookup_generic_job(cache_lookup_retrieve_image_from_url_task, effective_url, - False, DEFAULT_TIMEOUT) + direct_lookup_task_id = direct_lookup_generic_job( + cache_lookup_retrieve_image_from_url_task, + effective_url, + False, + DEFAULT_TIMEOUT, + ) + if direct_lookup_task_id is not None: return direct_lookup_task_id + task_list = [ upload_image_from_file_task.s(contents, file_extension), - retrieve_image_from_url_callback_task.s(effective_url) + retrieve_image_from_url_callback_task.s(effective_url), ] - if not no_cache: - task_list += get_slide_fingerprint_chain_list(None, None, ignore_fp_results=True) - else: + + if no_cache: task_list += [add_token_status_to_single_image_results_callback_task.s()] + else: + task_list += get_slide_fingerprint_chain_list( + token=None, + origin_token=None, + ignore_fp_results=True + ) + task = chain(task_list) task = task.apply_async(priority=2) return task.id @@ -85,45 +124,72 @@ def fingerprint_job(token, force): return task.id -def ocr_job(token, force=False, no_cache=False, method='google', - api_token=None, openai_token=None, gemini_token=None, - model_type=None, enable_tikz=True): +def ocr_job( + token, + force=False, + no_cache=False, + method='google', + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): ################## # OCR cache lookup ################## if not force and not no_cache: - direct_lookup_task_id = direct_lookup_generic_job(cache_lookup_extract_slide_text_task, token, - False, DEFAULT_TIMEOUT, method) + direct_lookup_task_id = direct_lookup_generic_job( + cache_lookup_extract_slide_text_task, + token, + False, + DEFAULT_TIMEOUT, + method, + ) if direct_lookup_task_id is not None: return direct_lookup_task_id ##################### # OCR computation job ##################### - if not is_pdf(token): + if is_pdf(token): task_list = [ - extract_slide_text_task.s(token, method, - 
api_token, openai_token, gemini_token, model_type, enable_tikz) + convert_pdf_to_pages_task.s(token), + fanout_pdf_ocr_task.s( + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ), ] else: - n_parallel = 8 task_list = [ - convert_pdf_to_pages_task.s(token), - group( - extract_multi_image_text_task.s(i, - n_parallel, - method, - api_token, - openai_token, - gemini_token, - model_type, - enable_tikz) - for i in range(n_parallel) - ), - collect_multi_image_ocr_task.s() + extract_slide_text_task.s( + token, + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) ] + + ################## + # OCR cache write + ################## if not no_cache: task_list.append(extract_slide_text_callback_task.s(token, force)) + + ################## + # Run task list + ################## task = chain(task_list) task = task.apply_async(priority=2) + return task.id diff --git a/graphai/celery/image/tasks.py b/graphai/celery/image/tasks.py index 1c4e64cf..7a12350b 100644 --- a/graphai/celery/image/tasks.py +++ b/graphai/celery/image/tasks.py @@ -1,4 +1,4 @@ -from celery import shared_task +from celery import shared_task, group, chord from graphai.core.image.image import ( cache_lookup_retrieve_image_from_url, @@ -21,98 +21,224 @@ file_management_config = VideoConfig() -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='caching.cache_lookup_retrieve_image', ignore_result=False, - file_manager=file_management_config) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="caching.cache_lookup_retrieve_image", + ignore_result=False, + file_manager=file_management_config, +) def cache_lookup_retrieve_image_from_url_task(self, url): return cache_lookup_retrieve_image_from_url(url, self.file_manager) -@shared_task(bind=True, 
autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.retrieve_image', ignore_result=False, - file_manager=file_management_config) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.retrieve_image", + ignore_result=False, + file_manager=file_management_config, +) def retrieve_image_from_url_task(self, url, force_token=None): return retrieve_image_file_from_url(url, self.file_manager, force_token) -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.upload_image', ignore_result=False, - file_manager=file_management_config) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.upload_image", + ignore_result=False, + file_manager=file_management_config, +) def upload_image_from_file_task(self, contents, file_extension): return upload_image_from_file(contents, file_extension, self.file_manager) -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.retrieve_image_callback', ignore_result=False, - file_manager=file_management_config) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.retrieve_image_callback", + ignore_result=False, + file_manager=file_management_config, +) def retrieve_image_from_url_callback_task(self, results, url): return retrieve_image_file_from_url_callback(results, url) -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='caching.cache_lookup_fingerprint_slide', ignore_result=False) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="caching.cache_lookup_fingerprint_slide", + ignore_result=False, +) def 
cache_lookup_slide_fingerprint_task(self, token): - return fingerprint_cache_lookup_with_most_similar(token, SlideDBCachingManager(), None) - - -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='caching.cache_lookup_extract_slide_text', ignore_result=False) -def cache_lookup_extract_slide_text_task(self, token, method='tesseract'): + return fingerprint_cache_lookup_with_most_similar( + token, SlideDBCachingManager(), None + ) + + +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="caching.cache_lookup_extract_slide_text", + ignore_result=False, +) +def cache_lookup_extract_slide_text_task(self, token, method="tesseract"): return cache_lookup_extract_slide_text(token, method) -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.extract_slide_text', ignore_result=False, - file_manager=file_management_config) -def extract_slide_text_task(self, token, method='google', api_token=None, openai_token=None, gemini_token=None, - model_type=None, enable_tikz=True): - return extract_slide_text(token, - self.file_manager, - method, - api_token, - openai_token, - gemini_token, - model_type, - enable_tikz) - - -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.pdf_to_pages', ignore_result=False, - file_manager=file_management_config) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.extract_slide_text", + ignore_result=False, + file_manager=file_management_config, +) +def extract_slide_text_task( + self, + token, + method="google", + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + return extract_slide_text( + token, + self.file_manager, + method, 
+ google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + + +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.pdf_to_pages", + ignore_result=False, + file_manager=file_management_config, +) def convert_pdf_to_pages_task(self, token): + print(f'Starting {convert_pdf_to_pages_task} task for token {token}') return break_pdf_into_images(token, self.file_manager) -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.extract_multi_image_text', ignore_result=False) -def extract_multi_image_text_task(self, - page_and_filename_list, - i, - n, - method='google', - api_token=None, - openai_token=None, - gemini_token=None, - model_type=None, - enable_tikz=True): - return extract_multi_image_text(page_and_filename_list, - i, - n, - method, - api_token, - openai_token, - gemini_token, - model_type, - enable_tikz) - - -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.extract_multi_image_text_callback', ignore_result=False) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.fanout_pdf_ocr_task", + ignore_result=False, +) +def fanout_pdf_ocr_task( + self, + pages, + method, + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + # Build one OCR task per page + header = group( + extract_multi_image_text_task.s( + page, + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + for page in pages + ) + + # When all pages are OCR'd, collect results + callback = collect_multi_image_ocr_task.s() + + # Replace this task with the chord so the outer chain waits properly + raise 
self.replace(chord(header, callback)) + + +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.extract_multi_image_text", + ignore_result=False, +) +def extract_multi_image_text_task( + self, + page_and_filename, + method="google", + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + print(f'Starting {extract_multi_image_text_task} task for page_and_filename {page_and_filename}') + return extract_multi_image_text( + page_and_filename, + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + + +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.extract_multi_image_text_callback", + ignore_result=False, +) def collect_multi_image_ocr_task(self, results): + print(f'Starting {collect_multi_image_ocr_task} task for results {results}') return collect_multi_image_ocr(results) -@shared_task(bind=True, autoretry_for=(Exception,), retry_backoff=True, retry_kwargs={"max_retries": 2}, - name='image.extract_slide_text_callback', ignore_result=False) +@shared_task( + bind=True, + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 2}, + name="image.extract_slide_text_callback", + ignore_result=False, +) def extract_slide_text_callback_task(self, results, token, force=False): return extract_slide_text_callback(results, token, force) diff --git a/graphai/core/image/image.py b/graphai/core/image/image.py index cc7cb6df..80b6bee6 100644 --- a/graphai/core/image/image.py +++ b/graphai/core/image/image.py @@ -15,7 +15,8 @@ get_ocr_colnames, GoogleOCRModel, OpenAIOCRModel, - GeminiOCRModel + GeminiOCRModel, + RCPOCRModel, ) import pymupdf from graphai.core.common.common_utils import ( @@ -199,146 +200,124 @@ def break_pdf_into_images(token, file_manager): 
return output_filenames -def perform_ocr(file_path, - method='google', - api_token=None, - openai_token=None, - gemini_token=None, - model_type=None, - enable_tikz=True): - ocr_colnames = get_ocr_colnames(method) +def perform_ocr( + file_path, + method="google", + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + text = None if method == 'tesseract': - res = perform_tesseract_ocr(file_path, language='enfr') - if res is None: - results = None - language = None - else: - language = detect_text_language(res) - results = [ - { - 'method': ocr_colnames[0], - 'text': res - } - ] + text = perform_tesseract_ocr(file_path, language='enfr') + + elif method == 'google' and google_api_token: + ocr_model = GoogleOCRModel(google_api_token) + ocr_model.establish_connection() + text1, text2 = ocr_model.perform_ocr(file_path) + + # Since DTD usually performs better, method #1 is our point of reference for langdetect + text = text1 + else: - if method == 'google': - # Google OCR - if api_token is None: - results = None - language = None - else: - ocr_model = GoogleOCRModel(api_token) - ocr_model.establish_connection() - res1, res2 = ocr_model.perform_ocr(file_path) - - if res1 is None: - results = None - language = None - else: - # Since DTD usually performs better, method #1 is our point of reference for langdetect - language = detect_text_language(res1) - res_list = [res1] - results = [ - { - 'method': ocr_colnames[i], - 'text': res_list[i] - } - for i in range(len(res_list)) - ] - else: - if method == 'openai': - # OpenAI OCR - if openai_token is None: - ocr_model = None - else: - ocr_model = OpenAIOCRModel(openai_token) - else: - # Gemini OCR - if gemini_token is None: - ocr_model = None - else: - ocr_model = GeminiOCRModel(gemini_token) - if ocr_model is not None: - ocr_model.establish_connection() - res = ocr_model.perform_ocr(file_path, model_type=model_type, enable_tikz=enable_tikz) - 
if res is None: - results = None - language = None - else: - language = detect_text_language(res) - results = [ - { - 'method': ocr_colnames[0], - 'text': res - } - ] - else: - results = None - language = None + ocr_model = None + if method == 'openai' and openai_api_token: + ocr_model = OpenAIOCRModel(openai_api_token) + elif method == 'gemini' and gemini_api_token: + ocr_model = GeminiOCRModel(gemini_api_token) + elif method == 'rcp' and rcp_api_token: + ocr_model = RCPOCRModel(rcp_api_token) + + if ocr_model: + ocr_model.establish_connection() + text = ocr_model.perform_ocr(file_path, model_type=model_type, enable_tikz=enable_tikz) + + if not text: + # OCR produced nothing (missing token, unknown method or model failure): keep the legacy contract of returning no results, so callers do not mark this as fresh and cache an empty transcript. + return { + 'results': None, + 'language': None, + } return { - 'results': results, - 'language': language, + 'results': [{'method': get_ocr_colnames(method)[0], 'text': text}], + 'language': detect_text_language(text), } -def extract_slide_text(token, - file_manager, - method='google', - api_token=None, - openai_token=None, - gemini_token=None, - model_type=None, - enable_tikz=True): +def extract_slide_text( + token, + file_manager, + method="google", + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + # Return no results if not a token if not is_token(token): return { 'results': None, 'language': None, 'fresh': False } + + # Perform OCR file_path = file_manager.generate_filepath(token) - res = perform_ocr(file_path, method, api_token, openai_token, gemini_token, model_type, enable_tikz) - res['fresh'] = res['results'] is not None + res = perform_ocr( + file_path, + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + res["fresh"] = res["results"] is not None return res -def extract_multi_image_text(page_and_filename_list, - i, - n, - method='google', - api_token=None, - openai_token=None, - gemini_token=None, - model_type=None, - enable_tikz=True): - n_pages = len(page_and_filename_list) - start_index 
= int(i / n * n_pages) - end_index = int((i + 1) / n * n_pages) - pages_to_handle = page_and_filename_list[start_index: end_index] - results = list() - for page in pages_to_handle: - results.append( - perform_ocr( - page['filename'], method, api_token, openai_token, gemini_token, model_type, enable_tikz - ) - ) +def extract_multi_image_text( + page_and_filename, + method="google", + google_api_token=None, + openai_api_token=None, + gemini_api_token=None, + rcp_api_token=None, + model_type=None, + enable_tikz=False, +): + # Perform OCR on page + result = perform_ocr( + page_and_filename["filename"], + method, + google_api_token, + openai_api_token, + gemini_api_token, + rcp_api_token, + model_type, + enable_tikz, + ) + + print(f"Performed OCR on page {page_and_filename['page']}. Result: {result}") + + # Build result and return it return { - 'results': [ - { - 'page': pages_to_handle[i]['page'], - 'content': results[i]['results'][0]['text'] - } - for i in range(len(results)) - ], - 'language': get_most_common_element([result['language'] for result in results]), - 'method': get_most_common_element([result['results'][0]['method'] for result in results]) + 'result': { + 'page': page_and_filename['page'], + 'content': result['results'][0]['text'] + }, + 'language': result['language'], + 'method': result['results'][0]['method'], } def collect_multi_image_ocr(results): - all_results = list(chain.from_iterable(result['results'] for result in results)) + all_results = [result['result'] for result in results] language = get_most_common_element([result['language'] for result in results]) method = get_most_common_element([result['method'] for result in results]) return { diff --git a/graphai/core/image/ocr.py b/graphai/core/image/ocr.py index 82dbe103..e2f2d960 100644 --- a/graphai/core/image/ocr.py +++ b/graphai/core/image/ocr.py @@ -20,7 +20,7 @@ import base64 -def get_ocr_prompt(enable_tikz=True): +def get_ocr_prompt(enable_tikz=False): if enable_tikz: figure_prompt_section 
= """ Figures are to be extracted as valid TikZ within LaTeX, inside \\begin{tikzpicture} and @@ -58,6 +58,22 @@ def get_ocr_prompt(enable_tikz=True): return ocr_prompt +def get_ocr_messages(image_path): + # Convert image to data uri + img_b64_str = ImgToBase64Converter(image_path).get_base64() + img_type = f'image/{image_path.split(".")[-1]}' + img_uri = f"data:{img_type};base64,{img_b64_str}" + + messages = [ + {"role": "user", "content": [ + {"type": "text", "text": get_ocr_prompt()}, + {"type": "image_url", "image_url": {"url": img_uri}}, + ]} + ] + + return messages + + def is_valid_latex(text): try: s = LatexNodes2Text().latex_to_text(text, tolerant_parsing=False) @@ -95,10 +111,7 @@ def __init__(self, api_key, model_class, model_name): self.model_params = None if self.api_key is None: - print( - f"No {model_name} API key was provided. " - f"{model_name} API endpoints cannot be used as there is no default API key." - ) + print(f"No {model_name} API key was provided. {model_name} API endpoints cannot be used as there is no default API key.") self.model = None self.load_lock = Lock() @@ -305,6 +318,48 @@ def perform_ocr(self, input_filename_with_path, model_type=None, **kwargs): return cleanup_json(response.text) +class RCPOCRModel(AbstractOCRModel): + def __init__(self, api_key): + super().__init__(api_key, OpenAI, "RCP") + self.model_params = dict( + base_url='https://inference.rcp.epfl.ch/v1', + api_key=self.api_key, + ) + + def perform_ocr(self, input_filename_with_path, model_type=None, **kwargs): + model_loaded = self.establish_connection() + + if not model_loaded: + return None + + if model_type is None: + model_type = "Qwen/Qwen3-VL-235B-A22B-Thinking-fp8" + + messages = get_ocr_messages(input_filename_with_path) + + try: + print(f'Performing OCR on RCP for file {input_filename_with_path}') + response = self.model.chat.completions.create(model=model_type, messages=messages, response_format={"type": "json_object"}) + print(f'Got {response}') + content = 
response.choices[0].message.content.strip() + + # Strip thinking tokens + thinking_tag = '</think>' + if thinking_tag in content: + content = content.split(thinking_tag)[-1].strip() + + # Try to parse json and extract text, otherwise keep as is + try: + content = json.loads(content)['text'] + except Exception: + pass + + return content + except Exception as e: + print(e) + return None + + def get_ocr_colnames(method): if method == 'tesseract': return ['ocr_tesseract_results'] @@ -312,8 +367,12 @@ return ['ocr_google_1_results', 'ocr_google_2_results'] elif method == 'openai': return ['ocr_openai_results'] - else: + elif method == 'gemini': return ['ocr_gemini_results'] + elif method == 'rcp': + return ['ocr_rcp_results'] + else: + raise ValueError(f'Unexpected method {method}') def perform_tesseract_ocr_on_pdf(pdf_path, language=None, in_pages=True): diff --git a/nginx.graphai-https.conf b/nginx.graphai-https.conf index c6503984..cd7d231c 100644 --- a/nginx.graphai-https.conf +++ b/nginx.graphai-https.conf @@ -5,64 +5,70 @@ # Redirect HTTP -> HTTPS server { listen 80; - listen [::]:80; - server_name graphai.epfl.ch _; + #listen [::]:80; + server_name graphai.graphcert.cede-apps.ch; return 301 https://$host$request_uri; } server { listen 443 ssl; - listen [::]:443 ssl; + #listen [::]:443 ssl; http2 on; - server_name graphai.epfl.ch _; + server_name graphai.graphcert.cede-apps.ch; - # --- TLS certs (your own certs) --- - # http.crt should be the server certificate (ideally fullchain). - # If your http.crt is NOT a full chain, see the note below about fullchain.pem. - ssl_certificate /etc/nginx/certs/graphai-fullchain.pem; - ssl_certificate_key /etc/nginx/certs/graphai-http.key; + add_header Content-Security-Policy "frame-ancestors 'self' graphai.graphcert.cede-apps.ch;"; - # (Optional but recommended) Let nginx send the full chain to clients - # If you have a separate intermediate/CA bundle, build a fullchain (see notes). 
- ssl_trusted_certificate /etc/nginx/certs/ca.crt; + ssl_certificate /etc/nginx/certs/fullchain.pem; + ssl_certificate_key /etc/nginx/certs/privkey.pem; - # --- TLS hardening --- - ssl_protocols TLSv1.2 TLSv1.3; - ssl_prefer_server_ciphers off; + ssl_protocols TLSv1.2 TLSv1.3; + ssl_prefer_server_ciphers on; + # ssl_dhparam /etc/ssl/certs/dhparam.pem; + ssl_ciphers 'ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-DSS-AES128-GCM-SHA256:kEDH+AESGCM:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA:ECDHE-ECDSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-DSS-AES128-SHA256:DHE-RSA-AES256-SHA256:DHE-DSS-AES256-SHA:DHE-RSA-AES256-SHA:AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:AES:CAMELLIA:DES-CBC3-SHA:!aNULL:!eNULL:!EXPORT:!DES:!RC4:!MD5:!PSK:!aECDH:!EDH-DSS-DES-CBC3-SHA:!EDH-RSA-DES-CBC3-SHA:!KRB5-DES-CBC3-SHA'; + ssl_session_timeout 1d; + ssl_session_cache shared:SSL:50m; + ssl_stapling off; + ssl_stapling_verify on; + add_header Strict-Transport-Security max-age=15768000; - ssl_session_cache shared:SSL:10m; - ssl_session_timeout 10m; + # Increase if you upload large files + client_max_body_size 200m; - # Reasonable security headers (safe defaults) - add_header X-Content-Type-Options nosniff always; - add_header X-Frame-Options SAMEORIGIN always; - add_header Referrer-Policy strict-origin-when-cross-origin always; + # # Proxy to GraphAI API + # location / { + # proxy_pass http://127.0.0.1:28800; - # If you know you'll only ever serve via HTTPS: - add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always; + # proxy_http_version 1.1; + # proxy_set_header Host $host; + # proxy_set_header X-Real-IP $remote_addr; + # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + # 
proxy_set_header X-Forwarded-Proto $scheme; - # Increase if you upload large files - client_max_body_size 200m; + # # WebSocket support (harmless even if unused) + # proxy_set_header Upgrade $http_upgrade; + # proxy_set_header Connection "upgrade"; + + # # Long requests (RAG, audio/video, etc.) + # proxy_connect_timeout 60s; + # proxy_send_timeout 600s; + # proxy_read_timeout 600s; + # } - # Proxy to GraphAI API + # Managing literal requests to "xxx.graphcert.cede-apps.ch" location / { proxy_pass http://127.0.0.1:28800; - - proxy_http_version 1.1; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header Host $host; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; - - # WebSocket support (harmless even if unused) + proxy_http_version 1.1; + # websocket headers proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection "upgrade"; - - # Long requests (RAG, audio/video, etc.) - proxy_connect_timeout 60s; - proxy_send_timeout 600s; - proxy_read_timeout 600s; - } + # proxy_set_header Connection $connection_upgrade; + # timeouts + proxy_read_timeout 3600; # If the proxied server does not transmit anything within this time, the connection is closed. + proxy_send_timeout 3600; # If the proxied server does not receive anything within this time, the connection is closed. + } } diff --git a/pyproject.toml b/pyproject.toml index f800d77c..90390c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent" ] dependencies = [ - "loguru" + "loguru", "numpy", "scipy", "pandas", diff --git a/rabbitmq/rabbitmq.conf b/rabbitmq/rabbitmq.conf new file mode 100644 index 00000000..15677dda --- /dev/null +++ b/rabbitmq/rabbitmq.conf @@ -0,0 +1,2 @@ +# Set connection timeout to 6 hours +consumer_timeout = 21600000