Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Tests

on:
push:
branches: [ main, dev ]
pull_request:
branches: [ main, dev ]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "latest"

- name: Cache uv dependencies
uses: actions/cache@v3
with:
path: |
.venv
.uv/cache
key: ${{ runner.os }}-uv-${{ hashFiles('**/uv.lock') }}
restore-keys: |
${{ runner.os }}-uv-

- name: Install dependencies
run: |
uv sync --group dev

- name: Run tests
run: |
uv run pytest test/
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__pycache__
legacy
./legacy
dist
.vscode
.pytest_cache
Expand All @@ -14,3 +14,6 @@ support_arena/*
data/test/outputs
experiments
TODO.md
.coverage
coverage.xml

Binary file added data/test/references/img/pdf2_t2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
43 changes: 23 additions & 20 deletions gmft/algorithm/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from gmft.base import Rect
from typing import TYPE_CHECKING

from gmft.core.ml.prediction import (
_empty_effective_predictions,
_empty_indices_predictions,
)

if TYPE_CHECKING:
from gmft.impl.tatr.config import TATRFormatConfig
from gmft.formatters.tatr import TATRFormattedTable
Expand Down Expand Up @@ -772,7 +777,7 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None):

outliers = {} # store table-wide information about outliers or pecularities

results = table.fctn_results
results = table.predictions["tatr"]

# 1. collate identified boxes
boxes = []
Expand Down Expand Up @@ -894,14 +899,8 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None):
if not known_means:
# no text was detected
outliers["no text"] = True
table.effective_rows = []
table.effective_columns = []
table.effective_headers = []
table.effective_projecting = []
table.effective_spanning = []
table._top_header_indices = []
table._projecting_indices = []
table._hier_left_indices = []
table.predictions["effective"] = _empty_effective_predictions()
table.predictions["indices"] = _empty_indices_predictions()
table._df = pd.DataFrame()
table.outliers = outliers
return table._df
Expand Down Expand Up @@ -941,12 +940,13 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None):
)

# nms takes care of deduplication

table.effective_rows = sorted_rows
table.effective_columns = sorted_columns
table.effective_headers = sorted_headers
table.effective_projecting = sorted_projecting
table.effective_spanning = spanning_cells
table.predictions["effective"] = {
"rows": sorted_rows,
"columns": sorted_columns,
"headers": sorted_headers,
"projecting": sorted_projecting,
"spanning": spanning_cells,
}

# 4b. check for catastrophic overlap
total_column_area = 0
Expand Down Expand Up @@ -1004,6 +1004,7 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None):
)

# semantic spanning fill
indices_preds = {}
if config.semantic_spanning_cells:
sorted_headers_bboxes = [x["bbox"] for x in sorted_headers]
sorted_row_bboxes = [x["bbox"] for x in sorted_rows]
Expand Down Expand Up @@ -1037,15 +1038,15 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None):
header_indices=header_indices,
config=config,
)
table._hier_left_indices = hier_left_idxs
indices_preds["_hier_left"] = hier_left_idxs
else:
table._hier_left_indices = [] # for the user
indices_preds["_hier_left"] = [] # for the user

# technically these indices will be off by the number of header rows ;-;
if config.enable_multi_header:
table._top_header_indices = header_indices
indices_preds["_top_header"] = header_indices
else:
table._top_header_indices = [0] if header_indices else []
indices_preds["_top_header"] = [0] if header_indices else []

# extract out the headers
header_rows = table_array[header_indices]
Expand Down Expand Up @@ -1078,7 +1079,9 @@ def extract_to_df(table: TATRFormattedTable, config: TATRFormatConfig = None):
is_projecting = [
x for i, x in enumerate(is_projecting) if i not in header_indices
]
table._projecting_indices = [i for i, x in enumerate(is_projecting) if x]
indices_preds["_projecting"] = [i for i, x in enumerate(is_projecting) if x]

table.predictions["indices"] = indices_preds

# if projecting_indices:
# insert at end
Expand Down
7 changes: 7 additions & 0 deletions gmft/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,29 @@
TATRTableFormatter = TATRFormatter
# TATRFormatConfig = TATRFormatConfig


class AutoTableFormatter:
"""
The recommended :class:`~gmft.formatters.base.BaseFormatter`. Currently points to :class:`~gmft.formatters.tatr.TATRFormatter`.
Uses a TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables.

Using :meth:`extract`, a :class:`~gmft.formatters.base.FormattedTable` is produced, which can be exported to csv, df, etc.
"""

def __new__(cls, *args, **kwargs):
from gmft.formatters.tatr import TATRFormatter

return TATRFormatter(*args, **kwargs)


class AutoFormatConfig:
"""
Configuration for the recommended :class:`~gmft.formatters.base.BaseFormatter`. Currently points to :class:`~gmft.formatters.tatr.TATRFormatConfig`.
"""

def __new__(cls, *args, **kwargs):
from gmft.impl.tatr.config import TATRFormatConfig

return TATRFormatConfig(*args, **kwargs)


Expand All @@ -50,6 +55,8 @@ class AutoTableDetector:

Using :meth:`~gmft.detectors.base.BaseDetector.extract` produces a :class:`~gmft.formatters.base.FormattedTable`, which can be exported to csv, df, etc.
"""

def __new__(cls, *args, **kwargs):
from gmft.detectors.tatr import TATRDetector

return TATRDetector(*args, **kwargs)
2 changes: 1 addition & 1 deletion gmft/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TypeVar, Union
from gmft.core.exceptions import DocumentClosedException
from gmft.core.exception import DocumentClosedException


class Rect:
Expand Down
75 changes: 0 additions & 75 deletions gmft/core/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,78 +56,3 @@ def non_defaults_only(config: object) -> dict:
if default_value != current_value:
result[f.name] = current_value
return result


import warnings

string_types = (type(b""), type(""))


def removed_property(reason):
"""
Custom decorator for marking class properties as removed.
Automatically raises a DeprecationWarning when the property is accessed or set.

See https://stackoverflow.com/questions/2536307/decorators-in-the-python-standard-lib-deprecated-specifically
"""
if isinstance(reason, string_types):
# The @deprecated is used with a 'reason'.
#
# .. code-block:: python
#
# @deprecated("please, use another function")
# def old_function(x, y):
# pass

def decorator(func1):
if inspect.isclass(func1):
fmt1 = "Call to deprecated class {name} ({reason})."
else:
fmt1 = "Call to deprecated function {name} ({reason})."

@functools.wraps(func1)
def new_func1(*args, **kwargs):
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
fmt1.format(name=func1.__name__, reason=reason),
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func1(*args, **kwargs)

return new_func1

return decorator

elif inspect.isclass(reason) or inspect.isfunction(reason):
# The @deprecated is used without any 'reason'.
#
# .. code-block:: python
#
# @deprecated
# def old_function(x, y):
# pass

func2 = reason

if inspect.isclass(func2):
fmt2 = "Call to deprecated class {name}."
else:
fmt2 = "Call to deprecated function {name}."

@functools.wraps(func2)
def new_func2(*args, **kwargs):
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
fmt2.format(name=func2.__name__),
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func2(*args, **kwargs)

return new_func2

else:
raise TypeError(repr(type(reason)))
File renamed without changes.
File renamed without changes.
59 changes: 59 additions & 0 deletions gmft/core/io/serial/dicts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import copy
from typing import Optional
from gmft.core.ml.prediction import (
IndicesPredictions,
RawBboxPredictions,
_empty_indices_predictions,
)
from gmft.detectors.base import CroppedTable
from gmft.formatters.base import _normalize_bbox
from gmft.impl.tatr.config import TATRFormatConfig
from gmft.pdf_bindings.base import BasePage


def _extract_fctn_results(d: dict) -> RawBboxPredictions:
"""
Extract prediction["tatr"], formerly known as fctn_results
"""
if "fctn_results" not in d:
raise ValueError(
"fctn_results not found in dict -- dict may be a CroppedTable but not a TATRFormattedTable."
)

results = d["fctn_results"] # fix shallow copy issue
if (
"fctn_scale_factor" in d
or "scale_factor" in d
or "fctn_padding" in d
or "padding" in d
):
# deprecated: this is for backwards compatibility
scale_factor = d.get("fctn_scale_factor", d.get("scale_factor", 1))
padding = d.get("fctn_padding", d.get("padding", (0, 0)))
padding = tuple(padding)

# normalize results here
for i, bbox in enumerate(results["boxes"]):
results["boxes"][i] = _normalize_bbox(
bbox, used_scale_factor=scale_factor, used_padding=padding
)
return results


def _extract_indices(d: dict) -> IndicesPredictions:
# version gmft>=0.5 format
if "predictions.indices" in d:
return d["predictions.indices"]

# version gmft<0.5 format
if any(
x in d
for x in ["_hier_left_indices", "_top_header_indices", "_projecting_indices"]
):
return {
"_projecting": d.get("_projecting_indices"),
"_hier_left": d.get("_hier_left_indices"),
"_top_header": d.get("_top_header_indices"),
}

return _empty_indices_predictions()
Loading
Loading