diff --git a/CHANGELOG.md b/CHANGELOG.md index a806ef4..f9c4124 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## v0.4.2 + +Bugfixes before the upcoming release. +- Better imports and lazy loading +- Default device is now 'auto', which resolves to cuda/cpu depending on availability +- Rich text now available as `AutoPageFormatter` + - Fixed bug with permuted coordinates (e0c6dc52) +- CroppedTable now directly has `angle` property +- CI tests, Python 3.9 support +- More type hints +- Light restructuring (non-breaking) +- Internal data structure tweaks + - (`fctn_results` → `predictions.tatr`) + - (`effective_*` → `predictions.effective`) + ## v0.4.0 Features: 3 new table structure recognition options! diff --git a/gmft/__init__.py b/gmft/__init__.py index 3d8dbed..4276ba8 100644 --- a/gmft/__init__.py +++ b/gmft/__init__.py @@ -1,140 +1,133 @@ """ Currently, contains aliases for key classes and functions. -Unfortunately, although at one point the ability to import classes from the top level module (ie. `from gmft import AutoTableFormatter`) was encouraged, -it is now discouraged and may be removed in future versions. The reason being: importing through the top level module -loads the entire library, even when you're using only a small part of it. +Importing from the top-level module previously resulted in long load times. +However, v0.5 introduces lazy loading, which greatly improves the situation. -Instead, `gmft.auto` is now encouraged. For example, `from gmft.auto import AutoTableFormatter`. +Now, classes may either be imported from their original locations, +`gmft.auto`, or from here, where they will be lazy loaded. """ +# small classes are fine, but discouraged. from gmft.base import Rect +from gmft.core.legacy.mirror import DeprecationMirrorMeta from gmft.pdf_bindings.base import BasePDFDocument, BasePage from gmft.detectors.base import CroppedTable, RotatedCroppedTable from gmft.formatters.base import FormattedTable -from gmft.auto import ( - TATRDetector as TATRTableDetectorOrig, - TableDetectorConfig as TableDetectorConfigOrig, - TableDetector as TableDetectorOrig, - TATRFormatConfig as TATRFormatConfigOrig, - TATRFormattedTable as TATRFormattedTableOrig, - TATRFormatter as TATRTableFormatterOrig, - AutoTableFormatter as AutoTableFormatterOrig, - AutoFormatConfig as AutoFormatConfigOrig, - AutoTableDetector as AutoTableDetectorOrig, +# config-only classes specific to TATR are still discouraged. + +# these auto classes are lazy-loaded +from gmft.core.auto_lazy import ( + AutoTableFormatter, + AutoFormatConfig, + AutoTableDetector, ) -has_warned = False +# We need to support these imports for compatibility: +# TATRTableDetector +# TableDetectorConfig +# TableDetector +# TATRFormatConfig +# TATRFormattedTable +# TATRTableFormatter +# AutoTableFormatter +# AutoFormatConfig +# AutoTableDetector -def _deprecation_warning(name): - global has_warned - if has_warned: - return - import warnings +# These bulky TATR-specific detectors are discouraged, but still available for compatibility. +class TATRTableDetector(metaclass=DeprecationMirrorMeta): + """ + This import is deprecated. - msg = f"(Deprecation) While once encouraged, \ -importing {name} and other classes from the top level module is now deprecated and will break in v0.5.0. \ -Please import from gmft.auto instead." - warnings.warn(msg, DeprecationWarning, stacklevel=2) - print(msg) - has_warned = True + Please use: + - gmft.AutoTableDetector + - gmft.detectors.tatr.TATRDetector + """ + @classmethod + def get_mirrored_class(cls): + from gmft.detectors.tatr import TATRDetector as OrigCls -# These small classes are fine, but still discouraged. -# Rect -# BasePDFDocument -# BasePage -# CroppedTable -# RotatedCroppedTable + return OrigCls -class TATRTableDetector(TATRTableDetectorOrig): - """ - Deprecated. Please import from gmft.auto instead. +class TableDetectorConfig(metaclass=DeprecationMirrorMeta): """ + This import is deprecated. - def __init__(self, *args, **kwargs): - _deprecation_warning("TATRTableDetector") - super().__init__(*args, **kwargs) - - -class TableDetectorConfig(TableDetectorConfigOrig): - """ - Deprecated. Please import from gmft.auto instead. + Please use: + - Reformat API (v0.5) + - gmft.detectors.tatr.TATRDetectorConfig """ - def __init__(self, *args, **kwargs): - _deprecation_warning("TableDetectorConfig") - super().__init__(*args, **kwargs) + @classmethod + def get_mirrored_class(cls): + from gmft.impl.tatr.config import TATRDetectorConfig as OrigCls + return OrigCls -class TableDetector(TableDetectorOrig): - """ - Deprecated. Please import from gmft.auto instead. - """ - - def __init__(self, *args, **kwargs): - _deprecation_warning("TableDetector") - super().__init__(*args, **kwargs) - -class TATRFormatConfig(TATRFormatConfigOrig): - """ - Deprecated. Please import from gmft.auto instead. +class TableDetector(metaclass=DeprecationMirrorMeta): """ + This import is deprecated. - def __init__(self, *args, **kwargs): - _deprecation_warning("TATRFormatConfig") - super().__init__(*args, **kwargs) - - -class TATRFormattedTable(TATRFormattedTableOrig): - """ - Deprecated. Please import from gmft.auto instead. + Please use: + - gmft.AutoTableDetector + - gmft.detectors.tatr.TATRDetector """ - def __init__(self, *args, **kwargs): - _deprecation_warning("TATRFormattedTable") - super().__init__(*args, **kwargs) - - -class TATRTableFormatter(TATRTableFormatterOrig): - """ - Deprecated. Please import from gmft.auto instead. - """ + @classmethod + def get_mirrored_class(cls): + from gmft.auto import TATRDetector as OrigCls - def __init__(self, *args, **kwargs): - _deprecation_warning("TATRTableFormatter") - super().__init__(*args, **kwargs) + return OrigCls -class AutoTableFormatter(AutoTableFormatterOrig): +class TATRFormatConfig(metaclass=DeprecationMirrorMeta): """ - Deprecated. Please import from gmft.auto instead. + This import is deprecated. + + Please use: + - Reformat API (v0.5) + - gmft.formatters.tatr.TATRFormatConfig """ - def __init__(self, *args, **kwargs): - _deprecation_warning("AutoTableFormatter") - super().__init__(*args, **kwargs) + @classmethod + def get_mirrored_class(cls): + from gmft.impl.tatr.config import TATRFormatConfig as OrigCls + + return OrigCls -class AutoFormatConfig(AutoFormatConfigOrig): +class TATRFormattedTable(metaclass=DeprecationMirrorMeta): """ - Deprecated. Please import from gmft.auto instead. + This import is deprecated. + + Please use: + - Reformat API (v0.5) + - gmft.formatters.tatr.TATRFormattedTable """ - def __init__(self, *args, **kwargs): - _deprecation_warning("AutoFormatConfig") - super().__init__(*args, **kwargs) + @classmethod + def get_mirrored_class(cls): + from gmft.formatters.tatr import TATRFormattedTable as OrigCls + return OrigCls -class AutoTableDetector(AutoTableDetectorOrig): + +class TATRTableFormatter(metaclass=DeprecationMirrorMeta): """ - Deprecated. Please import from gmft.auto instead. + This import is deprecated. + + Please use: + - gmft.auto.AutoTableFormatter + - gmft.formatters.tatr.TATRFormatter """ - def __init__(self, *args, **kwargs): - _deprecation_warning("AutoTableDetector") - super().__init__(*args, **kwargs) + @classmethod + def get_mirrored_class(cls): + from gmft.formatters.tatr import TATRFormatter as OrigCls + + return OrigCls diff --git a/gmft/algorithm/structure.py b/gmft/algorithm/structure.py index 3c5e0db..255098a 100644 --- a/gmft/algorithm/structure.py +++ b/gmft/algorithm/structure.py @@ -767,7 +767,7 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None): outliers = {} # store table-wide information about outliers or pecularities - results = table.predictions["tatr"] + results = table.predictions.tatr # 1. collate identified boxes boxes = [] @@ -889,8 +889,9 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None): if not known_means: # no text was detected outliers["no text"] = True - table.predictions["effective"] = _empty_effective_predictions() - table.predictions["indices"] = _empty_indices_predictions() + table.predictions.effective = _empty_effective_predictions() + table.predictions.indices = _empty_indices_predictions() + table.predictions.status = "ready" table._df = pd.DataFrame() table.outliers = outliers return table._df @@ -930,7 +931,7 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None): ) # nms takes care of deduplication - table.predictions["effective"] = { + table.predictions.effective = { "rows": sorted_rows, "columns": sorted_columns, "headers": sorted_headers, @@ -1071,7 +1072,8 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None): ] indices_preds["_projecting"] = [i for i, x in enumerate(is_projecting) if x] - table.predictions["indices"] = indices_preds + table.predictions.indices = indices_preds + table.predictions.status = "ready" # if projecting_indices: # insert at end diff --git a/gmft/auto.py b/gmft/auto.py index ec4ccbf..5c1b6cb 100644 --- a/gmft/auto.py +++ b/gmft/auto.py @@ -22,41 +22,8 @@ TATRTableFormatter = TATRFormatter # TATRFormatConfig = TATRFormatConfig - -class AutoTableFormatter: - """ - The recommended :class:`~gmft.formatters.base.BaseFormatter`. Currently points to :class:`~gmft.formatters.tatr.TATRFormatter`. - Uses a TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables. - - Using :meth:`extract`, a :class:`~gmft.formatters.base.FormattedTable` is produced, which can be exported to csv, df, etc. - """ - - def __new__(cls, *args, **kwargs): - from gmft.formatters.tatr import TATRFormatter - - return TATRFormatter(*args, **kwargs) - - -class AutoFormatConfig: - """ - Configuration for the recommended :class:`~gmft.formatters.base.BaseFormatter`. Currently points to :class:`~gmft.formatters.tatr.TATRFormatConfig`. - """ - - def __new__(cls, *args, **kwargs): - from gmft.impl.tatr.config import TATRFormatConfig - - return TATRFormatConfig(*args, **kwargs) - - -class AutoTableDetector: - """ - The recommended :class:`~gmft.detectors.base.BaseDetector`. Currently points to :class:`~gmft.detectors.tatr.TATRDetector`. - Uses TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables. - - Using :meth:`~gmft.detectors.base.BaseDetector.extract` produces a :class:`~gmft.formatters.base.FormattedTable`, which can be exported to csv, df, etc. - """ - - def __new__(cls, *args, **kwargs): - from gmft.detectors.tatr import TATRDetector - - return TATRDetector(*args, **kwargs) +from gmft.core.auto_lazy import ( + AutoTableFormatter, + AutoFormatConfig, + AutoTableDetector, +) diff --git a/gmft/core/auto_lazy.py b/gmft/core/auto_lazy.py new file mode 100644 index 0000000..8b7f2ff --- /dev/null +++ b/gmft/core/auto_lazy.py @@ -0,0 +1,37 @@ +class AutoTableFormatter: + """ + The recommended :class:`~gmft.formatters.base.BaseFormatter`. Currently points to :class:`~gmft.formatters.tatr.TATRFormatter`. + Uses a TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables. + + Using :meth:`extract`, a :class:`~gmft.formatters.base.FormattedTable` is produced, which can be exported to csv, df, etc. + """ + + def __new__(cls, *args, **kwargs): + from gmft.formatters.tatr import TATRFormatter + + return TATRFormatter(*args, **kwargs) + + +class AutoFormatConfig: + """ + Configuration for the recommended :class:`~gmft.formatters.base.BaseFormatter`. Currently points to :class:`~gmft.formatters.tatr.TATRFormatConfig`. + """ + + def __new__(cls, *args, **kwargs): + from gmft.impl.tatr.config import TATRFormatConfig + + return TATRFormatConfig(*args, **kwargs) + + +class AutoTableDetector: + """ + The recommended :class:`~gmft.detectors.base.BaseDetector`. Currently points to :class:`~gmft.detectors.tatr.TATRDetector`. + Uses TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables. + + Using :meth:`~gmft.detectors.base.BaseDetector.extract` produces a :class:`~gmft.formatters.base.FormattedTable`, which can be exported to csv, df, etc. + """ + + def __new__(cls, *args, **kwargs): + from gmft.detectors.tatr import TATRDetector + + return TATRDetector(*args, **kwargs) diff --git a/gmft/core/io/serial/dicts.py b/gmft/core/io/serial/dicts.py index 1969f1f..90b86e5 100644 --- a/gmft/core/io/serial/dicts.py +++ b/gmft/core/io/serial/dicts.py @@ -3,6 +3,7 @@ from gmft.core.ml.prediction import ( IndicesPredictions, RawBboxPredictions, + _empty_effective_predictions, _empty_indices_predictions, ) from gmft.detectors.base import CroppedTable @@ -57,3 +58,11 @@ def _extract_indices(d: dict) -> IndicesPredictions: } return _empty_indices_predictions() + + +def _extract_effective(d: dict) -> IndicesPredictions: + # version gmft>=0.5 format + if "predictions.effective" in d: + return d["predictions.effective"] + + return _empty_effective_predictions() diff --git a/gmft/core/legacy/fctn_results.py b/gmft/core/legacy/fctn_results.py index 97ccb30..71ed5b1 100644 --- a/gmft/core/legacy/fctn_results.py +++ b/gmft/core/legacy/fctn_results.py @@ -14,91 +14,91 @@ class LegacyFctnResults: predictions: TablePredictions @property - @deprecated("Use self.predictions['tatr']") + @deprecated("Use self.predictions.tatr") def fctn_results(self) -> RawBboxPredictions: - return self.predictions["tatr"] + return self.predictions.tatr @fctn_results.setter - @deprecated("Use self.predictions['tatr']") + @deprecated("Use self.predictions.tatr") def fctn_results(self, value: RawBboxPredictions): - self.predictions["tatr"] = value + self.predictions.tatr = value @property - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_rows(self): - return self.predictions["effective"]["rows"] + return self.predictions.effective["rows"] @effective_rows.setter - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_rows(self, value): - self.predictions["effective"]["rows"] = value + self.predictions.effective["rows"] = value @property - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_columns(self): - return self.predictions["effective"]["columns"] + return self.predictions.effective["columns"] @effective_columns.setter - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_columns(self, value): - self.predictions["effective"]["columns"] = value + self.predictions.effective["columns"] = value @property - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_headers(self): - return self.predictions["effective"]["headers"] + return self.predictions.effective["headers"] @effective_headers.setter - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_headers(self, value): - self.predictions["effective"]["headers"] = value + self.predictions.effective["headers"] = value @property - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_projecting(self): - return self.predictions["effective"]["projecting"] + return self.predictions.effective["projecting"] @effective_projecting.setter - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_projecting(self, value): - self.predictions["effective"]["projecting"] = value + self.predictions.effective["projecting"] = value @property - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_spanning(self): - return self.predictions["effective"]["spanning"] + return self.predictions.effective["spanning"] @effective_spanning.setter - @deprecated("Use self.predictions['effective']") + @deprecated("Use self.predictions.effective") def effective_spanning(self, value): - self.predictions["effective"]["spanning"] = value + self.predictions.effective["spanning"] = value @property - @deprecated("Use self.predictions['indices']['top_header']") + @deprecated("Use self.predictions.indices['top_header']") def _top_header_indices(self): - return self.predictions["indices"].get("top_header") + return self.predictions.indices.get("top_header") @_top_header_indices.setter - @deprecated("Use self.predictions['indices']['_top_header']") + @deprecated("Use self.predictions.indices['_top_header']") def _top_header_indices(self, value): - self.predictions["indices"]["_top_header"] = value + self.predictions.indices["_top_header"] = value @property - @deprecated("Use self.predictions['indices']['_projecting']") + @deprecated("Use self.predictions.indices['_projecting']") def _projecting_indices(self): - return self.predictions["indices"].get("_projecting") + return self.predictions.indices.get("_projecting") @_projecting_indices.setter - @deprecated("Use self.predictions['indices']['_projecting']") + @deprecated("Use self.predictions.indices['_projecting']") def _projecting_indices(self, value): - self.predictions["indices"]["_projecting"] = value + self.predictions.indices["_projecting"] = value @property - @deprecated("Use self.predictions['indices']['_hier_left']") + @deprecated("Use self.predictions.indices['_hier_left']") def _hier_left_indices(self): - return self.predictions["indices"].get("_hier_left") + return self.predictions.indices.get("_hier_left") @_hier_left_indices.setter - @deprecated("Use self.predictions['indices']['hier_left']") + @deprecated("Use self.predictions.indices['hier_left']") def _hier_left_indices(self, value): - self.predictions["indices"]["hier_left"] = value + self.predictions.indices["hier_left"] = value diff --git a/gmft/core/legacy/mirror.py b/gmft/core/legacy/mirror.py new file mode 100644 index 0000000..b740d38 --- /dev/null +++ b/gmft/core/legacy/mirror.py @@ -0,0 +1,47 @@ +has_warned = False + + +def _deprecation_warning(name): + global has_warned + if has_warned: + return + import warnings + + msg = ( + f"(Deprecation) Importing {name} et al. from the top level module is deprecated. \ +Refer to the import guide, or import from gmft.auto." # TODO: add documentation link + ) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + print(msg) + has_warned = True + + +class DeprecationMirrorMeta(type): + """ + A metaclass that wraps a class to issue a deprecation warning when instantiated. + + It mirrors a class (which needs to provided as a classmethod `get_mirrored_class`). + + Though the power of magic, the class and wrapped class will be nearly equivalent. + `isinstance()` is modified so that the original class and wrapped class are interchangeable. + """ + + def __init__(cls, name, bases, dct): + # Call the classmethod to get the real class + # cls._orig_cls = cls.get_mirrored_class() + super().__init__(name, bases, dct) + + def __call__(cls, *args, **kwargs): + # Issue warning once per instantiation + _deprecation_warning(cls.__name__) + instance = cls.get_mirrored_class()(*args, **kwargs) + # instance.__class__ = cls # Make it look like the wrapper + return instance + + def __instancecheck__(cls, instance): + """ + Allow isinstance checks to work with the original class OR the wrapped class. + """ + return isinstance(instance, cls.get_mirrored_class()) or isinstance( + instance, cls + ) diff --git a/gmft/core/ml/prediction/__init__.py b/gmft/core/ml/prediction/__init__.py index 08b070f..d998664 100644 --- a/gmft/core/ml/prediction/__init__.py +++ b/gmft/core/ml/prediction/__init__.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple, TypedDict, List, Union +from dataclasses import dataclass +from typing import Literal, Optional, Tuple, TypedDict, List, Union from typing_extensions import NotRequired @@ -47,7 +48,8 @@ class IndicesPredictions(TypedDict): _hier_left: NotRequired[List[int]] -class TablePredictions(TypedDict): +@dataclass +class TablePredictions: """Type definition for the complete predictions dictionary.""" tatr: RawBboxPredictions @@ -55,6 +57,8 @@ class TablePredictions(TypedDict): effective: EffectivePredictions indices: IndicesPredictions + status: Literal["unready", "ready"] = "unready" + def _empty_effective_predictions(): return { diff --git a/gmft/detectors/tatr.py b/gmft/detectors/tatr.py index de86cf5..83d4187 100644 --- a/gmft/detectors/tatr.py +++ b/gmft/detectors/tatr.py @@ -5,43 +5,10 @@ from gmft.core.ml import _resolve_device from gmft.detectors.base import BaseDetector, CroppedTable, RotatedCroppedTable +from gmft.impl.tatr.config import TATRDetectorConfig from gmft.pdf_bindings.base import BasePage -@dataclass -class TATRDetectorConfig: - """ - Configuration for the :class:`.TATRDetector` class. - - Specific to the TableTransformerForObjectDetection model. (Do not subclass this.) - """ - - image_processor_path: str = "microsoft/table-transformer-detection" - detector_path: str = "microsoft/table-transformer-detection" - no_timm: bool = True # huggingface revision - warn_uninitialized_weights: bool = False - torch_device: str = "cuda" if torch.cuda.is_available() else "cpu" - - detector_base_threshold: float = 0.9 - """Minimum confidence score required for a table""" - - @property - def confidence_score_threshold(self): - raise DeprecationWarning( - "Use detector_base_threshold instead. Will break in v0.6.0." - ) - - @confidence_score_threshold.setter - def confidence_score_threshold(self, value): - raise DeprecationWarning( - "Use detector_base_threshold instead. Will break in v0.6.0." - ) - - def __post_init__(self): - # use cuda if available - pass - - class TATRDetector(BaseDetector[TATRDetectorConfig]): """ Uses TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables. diff --git a/gmft/formatters/ditr.py b/gmft/formatters/ditr.py index e2df8a6..bbf13eb 100644 --- a/gmft/formatters/ditr.py +++ b/gmft/formatters/ditr.py @@ -12,10 +12,15 @@ _ioa, get_good_between_dividers, ) -from gmft.core.io.serial.dicts import _extract_fctn_results, _extract_indices +from gmft.core.io.serial.dicts import ( + _extract_effective, + _extract_fctn_results, + _extract_indices, +) from gmft.core.legacy.fctn_results import LegacyFctnResults from gmft.core.ml import _resolve_device from gmft.core.ml.prediction import ( + TablePredictions, _empty_effective_predictions, _empty_indices_predictions, ) @@ -70,11 +75,12 @@ def __init__( super(DITRFormattedTable, self).__init__( cropped_table, None, irvl_results, config=config ) - self.predictions = { - "tatr": fctn_results, - "effective": _empty_effective_predictions(), - "indices": _empty_indices_predictions(), - } + self.predictions = TablePredictions( + tatr=fctn_results, + effective=_empty_effective_predictions(), + indices=_empty_indices_predictions(), + status="unready", + ) if config is None: config = DITRFormatConfig() @@ -121,13 +127,13 @@ def visualize(self, **kwargs): for y0, y1 in self.irvl_results["row_dividers"]: bboxes.append([0, y0, tbl_width, y1]) labels.append(2) - for x0, y0, x1, y1 in self.predictions["effective"]["headers"]: + for x0, y0, x1, y1 in self.predictions.effective["headers"]: bboxes.append([x0, y0, x1, y1]) labels.append(3) - for x0, y0, x1, y1 in self.predictions["effective"]["headers"]: + for x0, y0, x1, y1 in self.predictions.effective["headers"]: bboxes.append([x0, y0, x1, y1]) labels.append(4) - for x0, y0, x1, y1 in self.predictions["effective"]["headers"]: + for x0, y0, x1, y1 in self.predictions.effective["headers"]: bboxes.append([x0, y0, x1, y1]) labels.append(5) return plot_shaded_boxes(img, labels=labels, boxes=bboxes, **kwargs) @@ -141,14 +147,15 @@ def to_dict(self): else: parent = CroppedTable.to_dict(self) optional = {} - if self.predictions["indices"]: - optional["predictions.indices"] = self.predictions["indices"] + if self.predictions.status == "ready": + optional["predictions.effective"] = self.predictions.effective + optional["predictions.indices"] = self.predictions.indices return { **parent, **{ "config": non_defaults_only(self.config), "outliers": self.outliers, - "fctn_results": self.predictions["tatr"], + "fctn_results": self.predictions.tatr, }, **optional, } @@ -174,7 +181,10 @@ def from_dict(d: dict, page: BasePage): ) table.recompute() table.outliers = d.get("outliers", None) - table.predictions["indices"] = _extract_indices(d) + table.predictions.indices = _extract_indices(d) + table.predictions.effective = _extract_effective(d) + if "predictions.effective" in d: + table.predictions.status = "ready" return table @@ -434,7 +444,7 @@ def ditr_extract_to_df(table: DITRFormattedTable, config: DITRFormatConfig = Non outliers = {} # store table-wide information about outliers or pecularities - results = table.predictions["tatr"] + results = table.predictions.tatr row_divider_boxes, col_divider_boxes, top_headers, projected, spanning_cells = ( proportion_fctn_results(results, config) ) @@ -462,7 +472,7 @@ def ditr_extract_to_df(table: DITRFormattedTable, config: DITRFormatConfig = Non "row_dividers": row_divider_intervals, "col_dividers": col_divider_intervals, } - table.predictions["effective"] = { + table.predictions.effective = { "rows": [], "columns": [], "headers": top_headers, @@ -599,7 +609,8 @@ def ditr_extract_to_df(table: DITRFormattedTable, config: DITRFormatConfig = Non # automatically makes is_projecting and header_indices mutually exclusive indices_preds["_projecting"] = [i for i, x in enumerate(is_projecting) if x] - table.predictions["indices"] = indices_preds + table.predictions.indices = indices_preds + table.predictions.status = "ready" # b. drop the former header rows always table._df.drop(index=header_indices, inplace=True) diff --git a/gmft/formatters/tatr.py b/gmft/formatters/tatr.py index c56207c..bf17430 100644 --- a/gmft/formatters/tatr.py +++ b/gmft/formatters/tatr.py @@ -7,6 +7,7 @@ from gmft.core.ml import _resolve_device from gmft.core.ml.prediction import ( BboxPrediction, + TablePredictions, _empty_effective_predictions, _empty_indices_predictions, ) @@ -58,11 +59,11 @@ def __init__( config: TATRFormatConfig = None, ): super(TATRFormattedTable, self).__init__(cropped_table) - self.predictions = { - "tatr": fctn_results, - "effective": _empty_effective_predictions(), - "indices": _empty_indices_predictions(), - } + self.predictions = TablePredictions( + tatr=fctn_results, + effective=_empty_effective_predictions(), + indices=_empty_indices_predictions(), + ) if config is None: config = TATRFormatConfig() @@ -123,7 +124,7 @@ def visualize( self._df = self.df() vis: List[BboxPrediction] = [ item - for sublist in self.predictions["effective"].values() + for sublist in self.predictions.effective.values() for item in sublist ] boxes = [x["bbox"] for x in vis] @@ -136,13 +137,12 @@ def visualize( else: # transform functionalized coordinates into image coordinates boxes = [ - (x * scale_by for x in bbox) - for bbox in self.predictions["tatr"]["boxes"] + (x * scale_by for x in bbox) for bbox in self.predictions.tatr["boxes"] ] _to_visualize = { - "scores": self.predictions["tatr"]["scores"], - "labels": self.predictions["tatr"]["labels"], + "scores": self.predictions.tatr["scores"], + "labels": self.predictions.tatr["labels"], "boxes": boxes, } @@ -171,14 +171,14 @@ def to_dict(self): else: parent = CroppedTable.to_dict(self) optional = {} - if self.predictions["indices"]: - optional["predictions.indices"] = self.predictions["indices"] + if self.predictions.indices: + optional["predictions.indices"] = self.predictions.indices return { **parent, **{ "config": non_defaults_only(self.config), "outliers": self.outliers, - "fctn_results": self.predictions["tatr"], + "fctn_results": self.predictions.tatr, }, **optional, } @@ -201,7 +201,7 @@ def from_dict(d: dict, page: BasePage): config=config, ) table.outliers = d.get("outliers", None) - table.predictions["indices"] = _extract_indices(d) + table.predictions.indices = _extract_indices(d) return table diff --git a/gmft/impl/tatr/config.py b/gmft/impl/tatr/config.py index 3f806a9..f4b4c06 100644 --- a/gmft/impl/tatr/config.py +++ b/gmft/impl/tatr/config.py @@ -8,6 +8,40 @@ from gmft.core.legacy.removed_config import LegacyRemovedConfig +@dataclass +class TATRDetectorConfig: + """ + Configuration for the :class:`.TATRDetector` class. + + Specific to the TableTransformerForObjectDetection model. (Do not subclass this.) + """ + + image_processor_path: str = "microsoft/table-transformer-detection" + detector_path: str = "microsoft/table-transformer-detection" + no_timm: bool = True # huggingface revision + warn_uninitialized_weights: bool = False + torch_device: str = "cuda" if torch.cuda.is_available() else "cpu" + + detector_base_threshold: float = 0.9 + """Minimum confidence score required for a table""" + + @property + def confidence_score_threshold(self): + raise DeprecationWarning( + "Use detector_base_threshold instead. Will break in v0.6.0." + ) + + @confidence_score_threshold.setter + def confidence_score_threshold(self, value): + raise DeprecationWarning( + "Use detector_base_threshold instead. Will break in v0.6.0." + ) + + def __post_init__(self): + # use cuda if available + pass + + @dataclass class TATRFormatConfig(LegacyRemovedConfig): """ diff --git a/pyproject.toml b/pyproject.toml index 98f5aee..1f7d0a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "gmft" -version = "0.4.1" +version = "0.4.2" description = "Lightweight, performant, deep table extraction" authors = [ { name = "conjunct" }, diff --git a/test/compat/test_imports.py b/test/compat/test_imports.py index d8ad29c..e9d5f5a 100644 --- a/test/compat/test_imports.py +++ b/test/compat/test_imports.py @@ -51,14 +51,33 @@ def test_aliases(): assert isinstance(ct, CroppedTableOrig) - from gmft import TATRFormatConfig - config = TATRFormatConfig(large_table_threshold=2) +def test_idem_isinstance(): + # test that isinstance(self, cls) is true + # which is tricky with lazy-loaded classes + from gmft import TATRFormatConfig as LazyConfig + + from gmft.impl.tatr.config import TATRFormatConfig as OrigConfig + + config = LazyConfig(large_table_threshold=2) assert config.large_table_threshold == 2 - assert isinstance(config, TATRFormatConfig) + # check is instanceof self + assert isinstance(config, LazyConfig) + + # check equivalency of OrigConfig and TATRFormatConfig + assert isinstance(config, OrigConfig) + + config_orig = OrigConfig(large_table_threshold=3) + assert config_orig.large_table_threshold == 3 + assert isinstance(config_orig, OrigConfig) + assert isinstance(config_orig, LazyConfig) # FAILS + + +def test_common_alias(): # import from "common" as an alias for "base" + from gmft import Rect from gmft.common import Rect as CommonRect from gmft.formatters.common import BaseFormatter diff --git a/test/formatters/ditr/test_df.py b/test/formatters/ditr/test_df.py index dcb65cb..42ca65d 100644 --- a/test/formatters/ditr/test_df.py +++ b/test/formatters/ditr/test_df.py @@ -95,7 +95,7 @@ def test_bulk_pdf5_t0(self, pdf5_tables): pass # this one just doesn't work very well # TODO make it work based on minima # try_jth_table(pdf5_tables, 5, 0) - # assert pdf5_tables[0].predictions["indices"]["_projecting"] == [15, 18, 22, 29] + # assert pdf5_tables[0].predictions.indices["_projecting"] == [15, 18, 22, 29] def test_bulk_pdf5_t1(self, ditr_tables, ditr_csvs, docs_bulk): try_table("pdf5_t1", ditr_tables, ditr_csvs, docs_bulk[5 - 1]) diff --git a/test/formatters/histogram/test_df.py b/test/formatters/histogram/test_df.py index 19b4d02..9b999c1 100644 --- a/test/formatters/histogram/test_df.py +++ b/test/formatters/histogram/test_df.py @@ -112,7 +112,7 @@ def test_bulk_pdf5_t0(self, pdf5_tables): pass # this one just doesn't work very well # TODO make it work based on minima # try_jth_table(pdf5_tables, 5, 0) - # assert pdf5_tables[0].predictions["indices"]["_projecting"] == [15, 18, 22, 29] + # assert pdf5_tables[0].predictions.indices["_projecting"] == [15, 18, 22, 29] def test_bulk_pdf5_t1(self, pdf5_tables, tatr_csvs): try_jth_table(pdf5_tables, tatr_csvs, 5, 1) diff --git a/test/formatters/tatr/test_df.py b/test/formatters/tatr/test_df.py index dfba1b5..ab7a9b7 100644 --- a/test/formatters/tatr/test_df.py +++ b/test/formatters/tatr/test_df.py @@ -92,11 +92,11 @@ def test_bulk_pdf2_t0(self, pdf2_tables, tatr_csvs): def test_bulk_pdf2_t1(self, pdf2_tables, tatr_csvs): try_jth_table(pdf2_tables, tatr_csvs, 2, 1) # hint: subtract 2 from the line no to get the proj. index (assume 1 header) - assert pdf2_tables[1].predictions["indices"]["_projecting"] == [9, 12, 16] + assert pdf2_tables[1].predictions.indices["_projecting"] == [9, 12, 16] def test_bulk_pdf2_t2(self, pdf2_tables, tatr_csvs): try_jth_table(pdf2_tables, tatr_csvs, 2, 2) - assert pdf2_tables[2].predictions["indices"]["_projecting"] == [0, 5] + assert pdf2_tables[2].predictions.indices["_projecting"] == [0, 5] def test_bulk_pdf2_t3(self, pdf2_tables, tatr_csvs): try_jth_table(pdf2_tables, tatr_csvs, 2, 3) @@ -112,7 +112,7 @@ def test_bulk_pdf3_t1(self, pdf3_tables, tatr_csvs): def test_bulk_pdf3_t2(self, pdf3_tables, tatr_csvs): try_jth_table(pdf3_tables, tatr_csvs, 3, 2) - assert pdf3_tables[2].predictions["indices"]["_projecting"] == [0, 8] + assert pdf3_tables[2].predictions.indices["_projecting"] == [0, 8] def test_bulk_pdf3_t3(self, pdf3_tables, tatr_csvs): try_jth_table(pdf3_tables, tatr_csvs, 3, 3) @@ -124,17 +124,17 @@ def test_bulk_pdf4_t0(self, pdf4_tables, tatr_csvs): def test_bulk_pdf4_t1(self, pdf4_tables, tatr_csvs): try_jth_table(pdf4_tables, tatr_csvs, 4, 1) - assert pdf4_tables[1].predictions["indices"]["_projecting"] == [0, 14] + assert pdf4_tables[1].predictions.indices["_projecting"] == [0, 14] class TestPdf5: def test_bulk_pdf5_t0(self, pdf5_tables, tatr_csvs): try_jth_table(pdf5_tables, tatr_csvs, 5, 0) - assert pdf5_tables[0].predictions["indices"]["_projecting"] == [15, 18, 22, 29] + assert pdf5_tables[0].predictions.indices["_projecting"] == [15, 18, 22, 29] def test_bulk_pdf5_t1(self, pdf5_tables, tatr_csvs): try_jth_table(pdf5_tables, tatr_csvs, 5, 1) - assert pdf5_tables[1].predictions["indices"]["_projecting"] == [13, 16, 22, 26] + assert pdf5_tables[1].predictions.indices["_projecting"] == [13, 16, 22, 26] class TestPdf6: diff --git a/test/formatters/tatr/test_spanning.py b/test/formatters/tatr/test_spanning.py index 2d413d8..716fc67 100644 --- a/test/formatters/tatr/test_spanning.py +++ b/test/formatters/tatr/test_spanning.py @@ -178,7 +178,7 @@ def test_pdf2_t2(self, pdf2_tables): try_jth_table(pdf2_tables, 2, 2, expected, config=config2) - assert pdf2_tables[2].predictions["indices"]["_projecting"] == [0, 5] + assert pdf2_tables[2].predictions.indices["_projecting"] == [0, 5] # pdf4 t1 is arguably HierTop, but the ground truth is not yet clear diff --git a/test/formatters/tatr/test_visualize.py b/test/formatters/tatr/test_visualize.py index 4b5346f..34331f8 100644 --- a/test/formatters/tatr/test_visualize.py +++ b/test/formatters/tatr/test_visualize.py @@ -15,7 +15,7 @@ # self._img_margin = (0, 0, 0, 0) # self.angle = 0 # self._df = None -# self.predictions = {} +# self.predictions = TablePredictions(TODO) # self.image_shape = (100, 100, 3) # def image(self, dpi=None, padding=None, margin=None):