diff --git a/pyproject.toml b/pyproject.toml
index ca93b3c..01ebe99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,7 +84,7 @@ warn_unused_ignores = true
[tool.pylint.main]
extension-pkg-allow-list = ["orjson"]
-ignore = ["tests", "dist", "build"]
+ignore = ["dist", "build"]
fail-under = 9.0
jobs = 0
limit-inference-results = 100
@@ -118,6 +118,13 @@ disable = [
"W0231", # super-init-not-called
]
+[tool.pylint."tests/*"]
+disable = [
+ "W0212", # protected-access
+ "W0612", # unused-variable
+ "W0621", # redefined-outer-name --> pytest fixtures
+]
+
[tool.ruff]
line-length = 88
indent-width = 4
diff --git a/requirements.txt b/requirements.txt
index 49652d6..d9b881c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+attrs
google-benchmark
numpy
orjson
diff --git a/src/BenchMatcha/complexity.py b/src/BenchMatcha/complexity.py
index 95e8f5a..9b277a4 100644
--- a/src/BenchMatcha/complexity.py
+++ b/src/BenchMatcha/complexity.py
@@ -71,15 +71,16 @@ def __repr__(self) -> str:
# Define common complexity functions with all coefficients and intercept
Equation = (
- Callable[[np.ndarray, float, float], np.ndarray]
+ Callable[[np.ndarray, float], np.ndarray]
+ | Callable[[np.ndarray, float, float], np.ndarray]
| Callable[[np.ndarray, float, float, float], np.ndarray]
| Callable[[np.ndarray, float, float, float, float], np.ndarray]
)
-def constant(n: np.ndarray, a: float, b: float) -> np.ndarray:
+def constant(n: np.ndarray, a: float) -> np.ndarray:
"""Constant O(1) equation."""
- return a * np.ones_like(n) + b
+ return a * np.ones_like(n)
def logn(n: np.ndarray, a: float, b: float) -> np.ndarray:
diff --git a/src/BenchMatcha/config.py b/src/BenchMatcha/config.py
index b3fd28d..9cddf3c 100644
--- a/src/BenchMatcha/config.py
+++ b/src/BenchMatcha/config.py
@@ -30,17 +30,33 @@
"""Default runner configuration."""
import logging
+from typing import Any, Iterable
import toml # type: ignore[import-untyped]
-
-from . import plotting
+from attrs import asdict, define, field
log: logging.Logger = logging.getLogger(__name__)
-class Config:
- """default configuration.
+def traverse(d: dict, keys: Iterable[str]) -> dict:
+ """Safely traverse dictionary keys."""
+ for k in keys:
+ d = d.get(k, {})
+
+ return d
+
+
+@define
+class _ConfigBase:
+ def tojson(self) -> dict[str, Any]:
+ """Convert instance into json object."""
+ return asdict(self, recurse=True)
+
+
+@define
+class ConfigBase(_ConfigBase):
+ """default configuration settings.
Attributes:
color (str): plot marker color.
@@ -50,10 +66,13 @@ class Config:
"""
- color: str = plotting.Prism[3]
- line_color: str = plotting.Prism[4]
- font: str = "Space Grotesk Light, Courier New, monospace"
- x_axis: int = 13
+ color: str = field(converter=str, default="#0f8554")
+ line_color: str = field(converter=str, default="#73af48")
+ font: str = field(
+ converter=str,
+ default="Space Grotesk Light, Courier New, monospace",
+ )
+ x_axis: int = field(converter=int, default=13)
class ConfigUpdater:
@@ -61,38 +80,52 @@ class ConfigUpdater:
Args:
path (str): path to valid configuration file.
- config (Config): configuration class to update.
+ config (_ConfigBase): configuration class to update.
"""
path: str
- config: type[Config]
+ config: _ConfigBase
+ _tool_key: str = "BenchMatcha"
- def __init__(self, path: str, config: type[Config] = Config) -> None:
+ def __init__(self, path: str, config: _ConfigBase) -> None:
self.path = path
self.config = config
- def load(self) -> dict:
+ def load(self) -> dict[str, Any]:
"""Load toml data from path."""
return toml.load(self.path)
- def _update(self, data: dict) -> None:
- for key, value in data.get("tool", {}).get("BenchMatcha", {}).items():
+ def _update(self, data: dict[str, Any]) -> None:
+ for key, value in data.items():
+ # Support recursive configuration
+ if isinstance(value, dict):
+ previous = self.config
+ current = getattr(self.config, key)
+ if isinstance(current, _ConfigBase):
+ self.config = current
+ self._update(value)
+ self.config = previous
+ continue
+
if not hasattr(self.config, key):
log.info("Unsupported tool key: %s", key)
continue
-
setattr(self.config, key, value)
def update(self) -> None:
"""Parse toml path and update default configuration."""
- data: dict = self.load()
- self._update(data)
+ data: dict[str, Any] = self.load()
+ self._update(traverse(data, ("tool", self._tool_key)))
-def update_config_from_pyproject(path: str) -> None:
+def update_config_from_pyproject(path: str, config: _ConfigBase) -> None:
"""Update default config from pyproject toml file.
+ Args:
+ path (str): Filepath to pyproject.toml config.
+ config (ConfigBase): configuration class object.
+
Example:
.. code-block: toml
@@ -104,5 +137,5 @@ def update_config_from_pyproject(path: str) -> None:
x_axis=5
"""
- cu = ConfigUpdater(path)
+ cu = ConfigUpdater(path, config)
cu.update()
diff --git a/src/BenchMatcha/plotting.py b/src/BenchMatcha/plotting.py
index 042e908..01ca7b7 100644
--- a/src/BenchMatcha/plotting.py
+++ b/src/BenchMatcha/plotting.py
@@ -36,6 +36,8 @@
from plotly.express import colors # type: ignore[import-untyped]
from plotly.io import to_json as _to_json # type: ignore[import-untyped]
+from .config import ConfigBase
+from .structure import BenchmarkArray
from .utils import _simple_stats, power_of_2
@@ -245,3 +247,71 @@ def draw_complexity_line(
),
opacity=0.7,
)
+
+
+def plot_benchmark_array(benchmark: BenchmarkArray, config: ConfigBase) -> go.Figure:
+ """Plot benchmark array.
+
+ Args:
+ benchmark (BenchmarkArray): benchmark array data.
+ config (ConfigBase): configuration settings.
+
+ Returns:
+ (go.Figure) returns plotly figure.
+
+ """
+ fig = go.Figure()
+ fig.add_trace(
+ create_scatter_trace(
+ benchmark.size,
+ benchmark.cpu_time,
+ "CPU Time",
+ config.color,
+ )
+ )
+
+ fig.add_trace(
+ draw_complexity_line(
+ benchmark.size,
+ benchmark.complexity.cpu_coefficient,
+ benchmark.complexity.big_o,
+ f"CPU Time Fit ({benchmark.complexity.big_o})",
+ config.line_color,
+ )
+ )
+
+ fig.add_annotation(
+ **create_annotation_text(
+ benchmark.complexity.big_o,
+ benchmark.complexity.rms,
+ )
+ )
+
+ vals, labels = construct_log2_axis(benchmark.size)
+ if (p := len(vals) // config.x_axis) > 0:
+ vals = vals[:: p + 1]
+ labels = labels[:: p + 1]
+
+ fig.update_layout(
+ title=f"Benchmark Results
{benchmark.function}",
+ xaxis=dict(
+ type="log",
+ tickvals=vals,
+ ticktext=labels,
+ tickmode="array",
+ title="Input Size (n)",
+ ),
+ yaxis=dict(
+ title=f"Time ({benchmark.unit})",
+ type="log",
+ dtick=1,
+ exponentformat="power",
+ ),
+ legend_title="Timing",
+ font=dict(
+ family=config.font,
+ size=12,
+ ),
+ )
+
+ return fig
diff --git a/src/BenchMatcha/runner.py b/src/BenchMatcha/runner.py
index ac15962..ad734be 100644
--- a/src/BenchMatcha/runner.py
+++ b/src/BenchMatcha/runner.py
@@ -43,97 +43,17 @@
from . import plotting
# from .complexity import analyze_complexity
-from .config import Config, update_config_from_pyproject
+from .config import ConfigBase, update_config_from_pyproject
from .errors import ParsingError
from .handlers import HandleText
-from .sifter import collect_benchmarks, load_benchmark
-from .structure import BenchmarkArray, BenchmarkContext, parse_version
+from .plotting import plot_benchmark_array
+from .sifter import manage_registration
+from .structure import BenchmarkContext, parse_version
log: logging.Logger = logging.getLogger(__name__)
-def manage_registration(path: str) -> None:
- """Manage import, depending on whether path is a directory or file."""
- abspath: str = os.path.abspath(path)
- log.debug("Loading path: %s", abspath)
- if not os.path.exists(abspath):
- raise FileNotFoundError("Invalid filepath")
-
- if os.path.isdir(abspath):
- collect_benchmarks(abspath)
-
- elif os.path.isfile(abspath) and abspath.endswith(".py"):
- load_benchmark(abspath, os.path.abspath(os.path.dirname(abspath)))
-
- else:
- log.warning(
- "Unsupported path provided. While the path does exist, it is neither a"
- " python file nor a directory: %s",
- abspath,
- )
- raise TypeError(f"Unsupported path type: {abspath}")
-
-
-def plot_benchmark_array(benchmark: BenchmarkArray) -> go.Figure:
- """Plot benchmark array."""
- fig = go.Figure()
- fig.add_trace(
- plotting.create_scatter_trace(
- benchmark.size,
- benchmark.cpu_time,
- "CPU Time",
- Config.color,
- )
- )
-
- fig.add_trace(
- plotting.draw_complexity_line(
- benchmark.size,
- benchmark.complexity.cpu_coefficient,
- benchmark.complexity.big_o,
- f"CPU Time Fit ({benchmark.complexity.big_o})",
- Config.line_color,
- )
- )
-
- fig.add_annotation(
- **plotting.create_annotation_text(
- benchmark.complexity.big_o,
- benchmark.complexity.rms,
- )
- )
-
- vals, labels = plotting.construct_log2_axis(benchmark.size)
- if (p := len(vals) // Config.x_axis) > 0:
- vals = vals[:: p + 1]
- labels = labels[:: p + 1]
-
- fig.update_layout(
- title=f"Benchmark Results
{benchmark.function}",
- xaxis=dict(
- type="log",
- tickvals=vals,
- ticktext=labels,
- tickmode="array",
- title="Input Size (n)",
- ),
- yaxis=dict(
- title=f"Time ({benchmark.unit})",
- type="log",
- dtick=1,
- exponentformat="power",
- ),
- legend_title="Timing",
- font=dict(
- family=Config.font,
- size=12,
- ),
- )
-
- return fig
-
-
# TODO: Consider defining CLI Exit Status in an Enum
def _run() -> BenchmarkContext:
if "--benchmark_format=json" not in sys.argv:
@@ -153,7 +73,8 @@ def _run() -> BenchmarkContext:
text: str = stdout.read()
error: str = stderr.read()
- stdout.close(), stderr.close() # pylint: disable=W0106
+ stdout.close()
+ stderr.close()
# Pass stderr from google_benchmark
if len(error):
@@ -170,10 +91,10 @@ def _run() -> BenchmarkContext:
return context
-def save(context: BenchmarkContext, cache_dir: str) -> None:
+def save(context: BenchmarkContext, cache_dir: str, config: ConfigBase) -> None:
"""Save benchmark data."""
for j in context.benchmarks:
- figure: go.Figure = plot_benchmark_array(j)
+ figure: go.Figure = plot_benchmark_array(j, config)
plotting.to_html(figure, os.path.join(cache_dir, "out.html"), "a")
# TODO: Save data to database. Serialize to json in the interim.
@@ -193,7 +114,7 @@ def save(context: BenchmarkContext, cache_dir: str) -> None:
f.write(serialized)
-def run(cache_dir: str) -> None:
+def run(cache_dir: str, config: ConfigBase) -> None:
"""BenchMatcha Runner."""
context: BenchmarkContext = _run()
@@ -202,10 +123,29 @@ def run(cache_dir: str) -> None:
# for bench in context.benchmarks:
# analyze_complexity(bench.size, bench.real_time)
- save(context, cache_dir)
+ save(context, cache_dir, config)
+
+
+def prepare_benchmark_sys_args(known: argparse.Namespace, unknown: list[str]) -> None:
+ """Handle google benchmark system arg preparation."""
+ # Google_benchmark CLI arguments, For Reference:
+ # https://github.com/google/benchmark/blob/main/src/benchmark.cc#L751-L812
+ problems: list[str] = []
+ for k in filter(
+ lambda x: isinstance(x, str) and "--benchmark_format=" in x,
+ unknown,
+ ):
+ if "json" not in k:
+ log.warning("Benchmark Format must be json: `%s`", k)
+ problems.append(k)
+ for p in problems:
+ unknown.remove(p)
+
+ # Prune / Reset for google_benchmark
+ sys.argv = [sys.argv[0], *unknown, *known.others]
-def get_args() -> argparse.Namespace:
+def get_args() -> tuple[argparse.Namespace, list[str]]:
"""Get BenchMatcha command line arguments and reset to support google_benchmark."""
args = argparse.ArgumentParser("benchmatcha", conflict_handler="error")
args.add_argument(
@@ -282,29 +222,15 @@ def get_args() -> argparse.Namespace:
# )
known, unknown = args.parse_known_args()
- # NOTE: Only validate `benchmark_format` argument from google_benchmark cli, since
- # we require json format to correctly work downstream. All other argument
- # validations should be handled by google_benchmark cli parsing directly.
- problems: list[str] = []
- for k in filter(
- lambda x: isinstance(x, str) and "--benchmark_format=" in x,
- unknown,
- ):
- if "json" not in k:
- log.warning("Benchmark Format must be json: `%s`", k)
- problems.append(k)
- for p in problems:
- unknown.remove(p)
-
- # Prune / Reset for google_benchmark
- sys.argv = [sys.argv[0], *unknown, *known.others]
-
- return known
+ return known, unknown
def main() -> None:
"""Primary CLI Entry Point."""
- args: argparse.Namespace = get_args()
+ args: argparse.Namespace
+ unknowns: list[str]
+ args, unknowns = get_args()
+ default_config = ConfigBase()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
@@ -312,20 +238,22 @@ def main() -> None:
if os.path.exists(args.config):
log.debug("Updating default configuration from file: %s", args.config)
- update_config_from_pyproject(args.config)
+ update_config_from_pyproject(args.config, default_config)
+ else:
+ log.debug("Configuration file not found: %s", args.config)
# NOTE: Configuration Args should overwrite values set in config file
if args.color is not None:
log.debug("Overriding color from arg: %s", args.color)
- Config.color = args.color
+ default_config.color = args.color
if args.line_color is not None:
log.debug("Overriding line_color from arg: %s", args.line_color)
- Config.line_color = args.line_color
+ default_config.line_color = args.line_color
if args.x_axis is not None:
log.debug("Overriding x_axis from arg: %s", args.x_axis)
- Config.x_axis = args.x_axis
+ default_config.x_axis = int(args.x_axis)
# Create cache directory if it does not exist
if not os.path.exists(cache := args.cache):
@@ -336,4 +264,5 @@ def main() -> None:
for path in args.path:
manage_registration(path)
- run(cache)
+ prepare_benchmark_sys_args(args, unknowns)
+ run(cache, default_config)
diff --git a/src/BenchMatcha/sifter.py b/src/BenchMatcha/sifter.py
index 9415ef8..938c767 100644
--- a/src/BenchMatcha/sifter.py
+++ b/src/BenchMatcha/sifter.py
@@ -30,6 +30,7 @@
"""Discovery of benchmark tests to register."""
import glob
+import logging
import os
from collections.abc import Iterator
from pathlib import Path
@@ -38,6 +39,9 @@
from _pytest.pathlib import import_path
+log: logging.Logger = logging.getLogger(__name__)
+
+
def scandir(filepath: str) -> Iterator[os.DirEntry[str]]:
"""Simple wrapper around os.scandir to use more simply as an iterator."""
with os.scandir(os.path.abspath(filepath)) as scanner:
@@ -87,3 +91,25 @@ def collect_benchmarks(root: str) -> None:
root = os.path.abspath(root)
for j in collect(root):
load_benchmark(j, root=root)
+
+
+def manage_registration(path: str) -> None:
+ """Manage import, depending on whether path is a directory or file."""
+ abspath: str = os.path.abspath(path)
+ log.debug("Loading path: %s", abspath)
+ if not os.path.exists(abspath):
+ raise FileNotFoundError(f"Invalid filepath: {abspath}")
+
+ if os.path.isdir(abspath):
+ collect_benchmarks(abspath)
+
+ elif os.path.isfile(abspath) and abspath.endswith(".py"):
+ load_benchmark(abspath, os.path.abspath(os.path.dirname(abspath)))
+
+ else:
+ log.warning(
+ "Unsupported path provided. While the path does exist, it is neither a"
+ " python file nor a directory: %s",
+ abspath,
+ )
+ raise TypeError(f"Unsupported path type: {abspath}")
diff --git a/src/BenchMatcha/structure.py b/src/BenchMatcha/structure.py
index 3f881a9..92cb6bd 100644
--- a/src/BenchMatcha/structure.py
+++ b/src/BenchMatcha/structure.py
@@ -31,6 +31,9 @@
from __future__ import annotations
+import os
+import subprocess
+import sys
from collections import defaultdict
from dataclasses import dataclass
from datetime import UTC, datetime
@@ -64,6 +67,28 @@ def get_size(record: dict[str, str]) -> int:
return int(_get_function(record)[1])
+def get_python_version() -> str:
+ """Get current python version."""
+ version: tuple[str, str, str] = (
+ str(sys.version_info.major),
+ str(sys.version_info.minor),
+ str(sys.version_info.micro),
+ )
+
+ return ".".join(version)
+
+
+def _get_commit_hash(path: str) -> str:
+ return (
+ subprocess.check_output(
+ ["git", "describe", "HEAD", "--always"],
+ cwd=path,
+ )
+ .strip()
+ .decode()
+ )
+
+
@dataclass
class Cache:
"""System cache information.
@@ -295,8 +320,6 @@ def convert_to_arrays(
# TODO: Consider how we would capture custom data (Counters).
-# TODO: Capture additional information not included in google_benchmark output
-# (e.g. git sha)
@dataclass
class BenchmarkContext:
"""Google benchmark context."""
@@ -315,6 +338,8 @@ class BenchmarkContext:
json_schema_version: int
benchmarks: list[BenchmarkArray]
aslr_enabled: bool
+ python_version: str
+ git_sha: str
@classmethod
def from_json(cls, record: dict[str, Any]) -> Self:
@@ -338,6 +363,8 @@ def from_json(cls, record: dict[str, Any]) -> Self:
date=date,
benchmarks=benchmarks,
aslr_enabled=aslr,
+ python_version=get_python_version(),
+ git_sha=_get_commit_hash(os.getcwd()),
)
def to_json(self) -> dict:
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 368b789..8ada40b 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -27,6 +27,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Common pytest fixtures for integration test suite."""
+
import os
import subprocess
import tempfile
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 558ab48..d4ef94c 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -27,6 +27,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Common pytest fixtures for unit test suite."""
+
import tempfile
from collections.abc import Iterator
@@ -35,6 +37,7 @@
@pytest.fixture
def mock_context() -> str:
+ """Sample output from google_benchmark context."""
return """
"context": {
"date": "2025-07-13T12:09:31-07:00",
@@ -73,6 +76,7 @@ def mock_context() -> str:
@pytest.fixture
def mock_bench() -> str:
+ """Sample output from google_benchmark benchmarks."""
return """
"benchmarks": [
{
@@ -136,12 +140,13 @@ def mock_bench() -> str:
@pytest.fixture
def mock_data(mock_context: str, mock_bench: str) -> str:
- """mock benchmark data."""
+ """Mock google_benchmark json data."""
return "{" + f"{mock_context}{mock_bench}" + "}"
@pytest.fixture
def mock_file(mock_data: str) -> Iterator[tempfile._TemporaryFileWrapper]:
+ """Mock temporary file wrapper around mock google_benchmark json data."""
with tempfile.NamedTemporaryFile("w+") as f:
f.write(mock_data)
f.seek(0)
diff --git a/tests/unit/test_complexity.py b/tests/unit/test_complexity.py
index da35df9..e8066de 100644
--- a/tests/unit/test_complexity.py
+++ b/tests/unit/test_complexity.py
@@ -40,6 +40,7 @@
@pytest.fixture
def fit_result() -> comp.FitResult:
+ """Sample Fit Result."""
return comp.FitResult(
bigo="N",
params=np.asarray([2.73]),
@@ -50,6 +51,7 @@ def fit_result() -> comp.FitResult:
@pytest.fixture
def coords() -> tuple[np.ndarray, np.ndarray]:
+ """Mock coordinates."""
x = np.arange(3)
y = np.arange(9).reshape((3, 3))
@@ -58,6 +60,8 @@ def coords() -> tuple[np.ndarray, np.ndarray]:
@pytest.fixture
def shunt_fit() -> Iterator[None]:
+ """Temporarily add a complexity function that fails, and removing on teardown."""
+
def _failed_fit(x, a, b):
raise RuntimeError("Intentionally Fail.")
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index d7fda90..9b9ee63 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -29,99 +29,200 @@
"""Test config module."""
-from contextlib import contextmanager
from io import StringIO
from unittest.mock import MagicMock, patch
import pytest
+import toml
+from attrs import define, field
from BenchMatcha import config
@pytest.fixture
-def toml_str() -> str:
- return """
-[tool.BenchMatcha]
-color="#FFF"
-line_color="#333"
-font="Courier"
-upsupported_key="test"
-"""
-
-
-@pytest.fixture
-def toml_data() -> dict:
+def toml_inner() -> dict:
return {
- "tool": {
- "BenchMatcha": {
- "color": "#FFF",
- "line_color": "#333",
- "font": "Courier",
- "upsupported_key": "test",
- }
- }
+ "color": "#FFF",
+ "line_color": "#333",
+ "font": "Courier",
+ "x_axis": 5,
+ "unsupported_key": "test",
}
-@contextmanager
-def reset(c: type[config.Config]):
- """Simple contextmanager to reset config to original default values."""
- default: dict = {key: getattr(c, key) for key in c.__annotations__.keys()}
- yield
+@pytest.fixture
+def toml_data(toml_inner: dict) -> dict:
+ """Mock toml config as a dictionary object."""
+ return {"tool": {"BenchMatcha": toml_inner}}
+
- for k, v in default.items():
- setattr(c, k, v)
- assert getattr(c, k) == default[k], f"Expected Value to be reset: {k}"
+@pytest.fixture
+def toml_str(toml_data: dict) -> str:
+ """Mock toml config text."""
+ return toml.dumps(toml_data)
-def test_config_load(toml_str: str, toml_data: dict) -> None:
+@pytest.fixture
+def configuration() -> config.ConfigBase:
+ """Create configuration and reset instance singleton."""
+ return config.ConfigBase()
+
+
+@pytest.mark.parametrize(
+ ["obj", "keys", "expected"],
+ [
+ ({}, ("a", "b", "c"), {}),
+ ({"a": {"b": {"c": "d"}}}, ("a", "b", "c"), "d"),
+ ({"a": {"b": {"d": "d"}}}, ("a", "b", "c"), {}),
+ ],
+)
+def test_traverse(obj: dict, keys: tuple[str], expected: str) -> None:
+ """Test dictionary traversal works as anticipated."""
+ result = config.traverse(obj, keys)
+ assert result == expected
+
+
+def test_config_load(
+ toml_str: str,
+ toml_data: dict,
+ configuration: config.ConfigBase,
+) -> None:
"""Confirm toml data loads correctly."""
- with reset(config.Config), StringIO(toml_str) as stream:
- instance = config.ConfigUpdater(stream)
+ with StringIO(toml_str) as stream:
+ instance = config.ConfigUpdater(stream, configuration)
result = instance.load()
assert result == toml_data, "Expected same object"
-def _assert_config_is_updated() -> None:
- assert config.Config.color == "#FFF", "Expected color to be updated."
- assert config.Config.line_color == "#333", "Expected line color to be updated."
- assert config.Config.font == "Courier", "Expected font to be updated."
- assert not hasattr(config.Config, "upsupported_key"), (
+def _assert_config_is_updated(conf: config.ConfigBase) -> None:
+ assert conf.color == "#FFF", "Expected color to be updated."
+ assert conf.line_color == "#333", "Expected line color to be updated."
+ assert conf.font == "Courier", "Expected font to be updated."
+ assert not hasattr(conf, "unsupported_key"), (
"Expected unsupported key to be bypassed."
)
-def test_config_private_update(toml_data: dict) -> None:
+def test_config_private_update(
+ toml_inner: dict,
+ configuration: config.ConfigBase,
+) -> None:
"""Confirm config data is updated correctly."""
- with reset(config.Config):
- instance = config.ConfigUpdater("")
- instance._update(toml_data)
- _assert_config_is_updated()
+ instance = config.ConfigUpdater("", configuration)
+ instance._update(toml_inner)
+ _assert_config_is_updated(configuration)
-def test_config_update(toml_str: str) -> None:
+def test_config_update(
+ toml_str: str,
+ configuration: config.ConfigBase,
+) -> None:
"""Confirm toml data loads correctly."""
- with reset(config.Config), StringIO(toml_str) as stream:
- instance = config.ConfigUpdater(stream)
+ with StringIO(toml_str) as stream:
+ instance = config.ConfigUpdater(stream, configuration)
+ instance.update()
+
+ _assert_config_is_updated(configuration)
+
+
+def test_config_json(
+ toml_str: str,
+ toml_inner: dict,
+ configuration: config.ConfigBase,
+) -> None:
+ """Confirm configuration returns as json object correctly."""
+ with StringIO(toml_str) as stream:
+ instance = config.ConfigUpdater(stream, configuration)
instance.update()
- _assert_config_is_updated()
+
+ _assert_config_is_updated(configuration)
+ toml_inner.pop("unsupported_key", None)
+ assert configuration.tojson() == toml_inner, "Unexpected json object data."
@patch.object(config.ConfigUpdater, "load")
-def test_config_update_mock(mock: MagicMock, toml_data: dict) -> None:
+def test_config_update_mock(
+ mock: MagicMock,
+ toml_data: dict,
+ configuration: config.ConfigBase,
+) -> None:
"""Confirm config is updated by mocking load method of ConfigUpdater."""
mock.return_value = toml_data
- with reset(config.Config):
- instance = config.ConfigUpdater("")
- instance.update()
- _assert_config_is_updated()
+ instance = config.ConfigUpdater("", configuration)
+ instance.update()
+ _assert_config_is_updated(configuration)
@patch.object(config.ConfigUpdater, "load")
-def test_config_update_function(mock: MagicMock, toml_data: dict) -> None:
+def test_config_update_function(
+ mock: MagicMock,
+ toml_data: dict,
+ configuration: config.ConfigBase,
+) -> None:
"""Confirm config is updated from available function api."""
mock.return_value = toml_data
- with reset(config.Config):
- config.update_config_from_pyproject("")
- _assert_config_is_updated()
+ config.update_config_from_pyproject("", configuration)
+ _assert_config_is_updated(configuration)
+
+
+def test_config_base_setters(configuration: config.ConfigBase) -> None:
+ """Confirm class setattr engages designated converters."""
+ configuration.x_axis = "5"
+ assert isinstance(configuration.x_axis, int), "Expected value to be coerced to int."
+ assert configuration.x_axis == 5, "Expected value to be updated to 5."
+
+ setattr(configuration, "x_axis", "7") # noqa: B010
+ assert configuration.x_axis == 7, "Expected value to be updated to 7."
+
+
+@define
+class A(config._ConfigBase):
+ a: str = field(default="a")
+ b: int = field(default=1)
+
+
+@define
+class B(config._ConfigBase):
+ a: A = field(factory=A)
+ b: str = field(default="b")
+
+
+@pytest.fixture
+def toml_data_b() -> str:
+ return """
+[tool.BenchMatcha]
+b = "loaded b"
+
+[tool.BenchMatcha.a]
+a = "loaded a"
+b = 5
+"""
+
+
+@patch.object(config.ConfigUpdater, "load")
+def test_recursive_update_function(
+ mock: MagicMock,
+ toml_data_b: str,
+):
+ """Confirm nested recursive configurations are appropriately updated."""
+ # sanity checks
+ con = B()
+ assert con.a.a == "a"
+ assert con.a.b == 1
+ assert con.b == "b"
+
+ # setup
+ data: dict = toml.loads(toml_data_b)
+ mock.return_value = data
+
+ instance = config.ConfigUpdater("", con)
+ instance.update()
+
+ # checks
+ assert con.a.a == "loaded a"
+ assert con.a.b == 5
+ assert con.b == "loaded b"
+
+ # sanity
+ assert instance.config is con, "Expected config to be same on instance."
diff --git a/tests/unit/test_errors.py b/tests/unit/test_errors.py
index 8d8b546..07c1e45 100644
--- a/tests/unit/test_errors.py
+++ b/tests/unit/test_errors.py
@@ -27,6 +27,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Unit test errors module."""
+
import builtins
from collections.abc import Callable, Iterator
from json import JSONDecodeError
@@ -38,6 +40,7 @@
@pytest.fixture
def register_error() -> Iterator[Callable[[type[Exception]], type[Exception]]]:
+ """Register custom error dynamically with teardown."""
e: type[Exception] | None = None
def inner(exc: type[Exception]) -> type[Exception]:
diff --git a/tests/unit/test_handlers.py b/tests/unit/test_handlers.py
index 6e68831..202a936 100644
--- a/tests/unit/test_handlers.py
+++ b/tests/unit/test_handlers.py
@@ -50,7 +50,7 @@ def test_load(mock_data: str) -> None:
[
(handlers.HandlePath, lambda x: x),
(handlers.HandleBytes, lambda x: x.encode()),
- (handlers.HandleIO, lambda x: StringIO(x)),
+ (handlers.HandleIO, StringIO),
(handlers.HandleIO, lambda x: BytesIO(x.encode())),
],
)
@@ -59,7 +59,7 @@ def test_handlers(
transformer: Callable[[str], Any],
mock_data: str,
) -> None:
- """"""
+ """Test handlers correctly parse mock json data."""
mock = transformer(mock_data)
result = handler(mock).handle()
assert isinstance(result, dict), "Expected a dictionary loaded."
diff --git a/tests/unit/test_plotting.py b/tests/unit/test_plotting.py
index 6e520a4..dd8c361 100644
--- a/tests/unit/test_plotting.py
+++ b/tests/unit/test_plotting.py
@@ -36,6 +36,32 @@
import pytest
from BenchMatcha import plotting
+from BenchMatcha.config import ConfigBase
+from BenchMatcha.structure import BenchmarkArray, ComplexityInfo
+
+
+@pytest.fixture
+def bench_arr():
+ """Sample BenchmarkArray instance object."""
+ name: str = "test"
+ c = ComplexityInfo(
+ function=name,
+ big_o="lgN",
+ real_coefficient=1.0,
+ cpu_coefficient=2.0,
+ )
+
+ b = BenchmarkArray(
+ function=name,
+ unit="s",
+ size=np.asarray([2, 4, 8]),
+ iterations=np.asarray([[100, 50, 25], [100, 50, 25], [100, 50, 25]]),
+ real_time=np.asarray([[5.2, 5.6, 5.3], [10.3, 10.2, 10.0], [15.3, 15.1, 15.2]]),
+ cpu_time=np.asarray([[5.2, 5.2, 5.0], [10.1, 10.0, 9.9], [15.2, 14.9, 15.0]]),
+ complexity=c,
+ )
+
+ return b
def test_serialization_to_html():
@@ -144,3 +170,9 @@ def test_draw_complexity_line() -> None:
x = np.arange(20)
result = plotting.draw_complexity_line(x, 1.2, "N", "test", "red")
assert isinstance(result, go.Scatter)
+
+
+def test_plot_benchmark_array(bench_arr: BenchmarkArray):
+ """Confirm an array is constructed."""
+ result = plotting.plot_benchmark_array(bench_arr, ConfigBase())
+ assert isinstance(result, go.Figure), "Expected a figure object."
diff --git a/tests/unit/test_sifter.py b/tests/unit/test_sifter.py
index f1ba107..2580adb 100644
--- a/tests/unit/test_sifter.py
+++ b/tests/unit/test_sifter.py
@@ -47,6 +47,7 @@ def __init__(self, path: str, directory: bool = False):
self.directory = directory
def is_dir(self, **kwargs) -> bool:
+ """Confirm if path is a directory."""
return self.directory
def __str__(self):
diff --git a/tests/unit/test_structure.py b/tests/unit/test_structure.py
index ffbe0c0..4be7191 100644
--- a/tests/unit/test_structure.py
+++ b/tests/unit/test_structure.py
@@ -40,6 +40,7 @@
@pytest.fixture
def shunt_version() -> Iterator[Callable[[tuple[int, ...]], tuple[int, ...]]]:
+ """Temporarily modify supported versions, with teardown to reinstate original."""
previous = structure.SUPPORTED_VERSIONS
def modify(value: tuple[int, ...]) -> tuple[int, ...]:
@@ -52,6 +53,28 @@ def modify(value: tuple[int, ...]) -> tuple[int, ...]:
assert structure.SUPPORTED_VERSIONS == previous
+@pytest.fixture
+def complexity_info() -> dict:
+ """Expected complexity info data."""
+ return {
+ "function": "function",
+ "big_o": "N",
+ "cpu_coefficient": 1.2549524739108346e01,
+ "real_coefficient": 1.2555517635286144e01,
+ "rms": 5.4107447739157266e-02,
+ }
+
+
+@pytest.fixture
+def cache_data() -> list[tuple]:
+ """Expected cache data."""
+ return [
+ ("Data", 1, 65536, 0),
+ ("Instruction", 1, 131072, 0),
+ ("Unified", 2, 4194304, 1),
+ ]
+
+
@pytest.mark.parametrize(
["value", "expected"],
[
@@ -66,7 +89,99 @@ def test_parse_datetime(value: str, expected: datetime) -> None:
assert result == expected, "Unexpected date time."
-def test_parse_json_data(mock_data: str) -> None:
+def _check_cache(cache: structure.Cache, expected: tuple) -> None:
+ assert isinstance(cache, structure.Cache), f"Expected Cache type: {type(cache)}."
+ assert cache.type == expected[0], f"Unexpected type: {cache.type}"
+ assert cache.level == expected[1], f"Unexpected level: {cache.level}"
+ assert cache.size == expected[2], f"Unexpected size: {cache.size}"
+ assert cache.num_sharing == expected[3], f"Unexpected sharing: {cache.num_sharing}"
+
+
+def test_convert_cache_data() -> None:
+ """Test we correctly convert cache json data into a dataclass."""
+ cache = {
+ "type": "test",
+ "level": 17,
+ "size": 45,
+ "num_sharing": 50,
+ "unknown_key": 3,
+ }
+ result = structure.Cache.from_json(cache)
+ _check_cache(result, ("test", 17, 45, 50))
+
+
+def test_parse_cache_json_data(mock_data: str, cache_data: list[tuple]) -> None:
+ """Test we correctly parse cache json data into a dataclass."""
+ data: dict = load(mock_data)
+ caches: list[dict] = data["context"]["caches"]
+ for cache, e in zip(caches, cache_data, strict=False):
+ result = structure.Cache.from_json(cache)
+ _check_cache(result, e)
+
+
+def _check_complexity_info(complexity: structure.ComplexityInfo, expected: dict):
+ assert isinstance(complexity, structure.ComplexityInfo), (
+ "Expected value to be a ComplexityInfo object."
+ )
+ assert complexity.function == expected["function"], (
+ f"Incorrect Function Name: {complexity.function}"
+ )
+ assert complexity.big_o == expected["big_o"], (
+ f"Incorrect Function Name: {complexity.function}"
+ )
+ assert complexity.cpu_coefficient == expected["cpu_coefficient"], (
+ f"Unexpected CPU Coefficient: {complexity.cpu_coefficient}"
+ )
+ assert complexity.real_coefficient == expected["real_coefficient"], (
+ f"Unexpected Real Coefficient: {complexity.real_coefficient}"
+ )
+ assert complexity.rms == expected["rms"], f"Unexpected RMS: {complexity.rms}"
+
+
+def test_convert_benchmark_json_to_complexity(complexity_info: dict) -> None:
+ """Test conversion of benchmark data correctly collects complexity info."""
+ obj: list[dict] = [
+ {
+ "name": "function/repeats:3_BigO",
+ "family_index": 0,
+ "per_family_instance_index": 0,
+ "run_name": "function/repeats:3",
+ "run_type": "aggregate",
+ "repetitions": 3,
+ "threads": 1,
+ "aggregate_name": "BigO",
+ "aggregate_unit": "time",
+ "cpu_coefficient": 1.2549524739108346e01,
+ "real_coefficient": 1.2555517635286144e01,
+ "big_o": "N",
+ "time_unit": "ns",
+ },
+ {
+ "name": "function/repeats:3_RMS",
+ "family_index": 0,
+ "per_family_instance_index": 0,
+ "run_name": "function/repeats:3",
+ "run_type": "aggregate",
+ "repetitions": 3,
+ "threads": 1,
+ "aggregate_name": "RMS",
+ "aggregate_unit": "percentage",
+ "rms": 5.4107447739157266e-02,
+ },
+ ]
+
+ result = structure.get_complexity_info(obj)
+ assert isinstance(result, dict), "Expected dictionary object returned."
+ assert "function" in result, "Expected function name key to be present."
+ value = result["function"]
+ _check_complexity_info(value, complexity_info)
+
+
+def test_parse_full_json_data(
+ mock_data: str,
+ complexity_info: dict,
+ cache_data: list[tuple],
+) -> None:
"""Test we correctly parse json data into a dataclass."""
data = load(mock_data)
result = structure.BenchmarkContext.from_json(data)
@@ -79,15 +194,23 @@ def test_parse_json_data(mock_data: str) -> None:
assert result.cpu_scaling_enabled is False
assert len(result.caches) == 3
+ for c, e in zip(result.caches, cache_data, strict=False):
+ _check_cache(c, e)
+
assert result.load_avg == [4.69092, 4.60693, 4.47949]
assert result.library_version == "1.9.4"
assert result.library_build_type == "release"
assert result.json_schema_version == 1
+ assert result.aslr_enabled is False
+ assert isinstance(result.python_version, str), "Expected a string"
+ assert isinstance(result.git_sha, str), "Expected a string"
+ assert isinstance(result.benchmarks, list), "Expected benchmarks to be a list."
assert len(result.benchmarks) == 1
+ _check_complexity_info(result.benchmarks[0].complexity, complexity_info)
-def test_convert_BenchmarkContext_to_json(mock_data: str) -> None:
+def test_convert_benchmark_context_to_json(mock_data: str) -> None:
"""Test we convert dataclass into dictionary json like objects."""
data = load(mock_data)
obj = structure.BenchmarkContext.from_json(data)
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 4ef1d61..b141bb1 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -27,6 +27,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Unit test utils module."""
+
import numpy as np
import pytest
@@ -38,20 +40,20 @@
(1, 1),
]
for j in range(2, 24):
- value = 1 << j
- _power2.append((value - 1, value))
- _power2.append((value, value))
- _power2.append((value + 1, 1 << (j + 1)))
+ _v = 1 << j
+ _power2.append((_v - 1, _v))
+ _power2.append((_v, _v))
+ _power2.append((_v + 1, 1 << (j + 1)))
@pytest.mark.parametrize(["value", "expected"], _power2)
-def test_power_of_2(value: int, expected: int):
+def test_power_of_2(value: int, expected: int) -> None:
"""Test returns next power of two."""
- result = utils.power_of_2(value)
+ result: int = utils.power_of_2(value)
assert result == expected, f"Unexpected result: {result}"
-def test_simple_stats():
+def test_simple_stats() -> None:
"""Test mean and std."""
x = np.asarray([[1, 2, 3], [2, 3, 1], [3, 1, 2]])
result = utils._simple_stats(x)
@@ -72,7 +74,7 @@ def test_simple_stats():
("oLambda", "f(N)"),
],
)
-def test_bigo_enum_get(value: str, expected: str):
+def test_bigo_enum_get(value: str, expected: str) -> None:
"""Test conversion of big o notation identifier get classmethod."""
result: str = utils.BigO.get(value)
assert result == expected, "Unexpected result."
@@ -90,7 +92,7 @@ def test_bigo_enum_get(value: str, expected: str):
("f(N)", "oLambda"),
],
)
-def test_bigo_enum_back(value: str, expected: str):
+def test_bigo_enum_back(value: str, expected: str) -> None:
"""Test conversion of big o notation identifier back classmethod."""
result: str = utils.BigO.back(value)
assert result == expected, "Unexpected result."