Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions gmft/detectors/tatr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gmft.core._dataclasses import with_config
from gmft.core.ml import _resolve_device
from gmft.detectors.base import BaseDetector, CroppedTable, RotatedCroppedTable
from gmft.base import Rect

from gmft.impl.tatr.config import TATRDetectorConfig
from gmft.pdf_bindings.base import BasePage
Expand Down Expand Up @@ -54,19 +55,20 @@ def __init__(self, config: TATRDetectorConfig = None, default_implementation=Tru
self.config = config

def extract(
self, page: BasePage, config_overrides: TATRDetectorConfig = None
self, page: BasePage, config_overrides: TATRDetectorConfig = None, rect: Rect = None
) -> list[CroppedTable]:
"""
Detect tables in a page.

:param page: BasePage
:param config_overrides: override the config for this call only
:param config_overrides: Optional config overrides for this extraction
:param rect: Optional Rect to constrain detection within given dimensions
:return: list of CroppedTable objects
"""
config = with_config(self.config, config_overrides)

img = page.get_image(
72
72, rect=rect
) # use standard dpi = 72, which means we don't need any scaling
encoding = self.image_processor(img, return_tensors="pt").to(
_resolve_device(self.config.torch_device)
Expand Down