diff --git a/pyproject.toml b/pyproject.toml index ca93b3c..01ebe99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ warn_unused_ignores = true [tool.pylint.main] extension-pkg-allow-list = ["orjson"] -ignore = ["tests", "dist", "build"] +ignore = ["dist", "build"] fail-under = 9.0 jobs = 0 limit-inference-results = 100 @@ -118,6 +118,13 @@ disable = [ "W0231", # super-init-not-called ] +[tool.pylint."tests/*"] +disable = [ + "W0212", # protected-access + "W0612", # unused-variable + "W0621", # redefined-outer-name --> pytest fixtures +] + [tool.ruff] line-length = 88 indent-width = 4 diff --git a/requirements.txt b/requirements.txt index 49652d6..d9b881c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +attrs google-benchmark numpy orjson diff --git a/src/BenchMatcha/complexity.py b/src/BenchMatcha/complexity.py index 95e8f5a..9b277a4 100644 --- a/src/BenchMatcha/complexity.py +++ b/src/BenchMatcha/complexity.py @@ -71,15 +71,16 @@ def __repr__(self) -> str: # Define common complexity functions with all coefficients and intercept Equation = ( - Callable[[np.ndarray, float, float], np.ndarray] + Callable[[np.ndarray, float], np.ndarray] + | Callable[[np.ndarray, float, float], np.ndarray] | Callable[[np.ndarray, float, float, float], np.ndarray] | Callable[[np.ndarray, float, float, float, float], np.ndarray] ) -def constant(n: np.ndarray, a: float, b: float) -> np.ndarray: +def constant(n: np.ndarray, a: float) -> np.ndarray: """Constant O(1) equation.""" - return a * np.ones_like(n) + b + return a * np.ones_like(n) def logn(n: np.ndarray, a: float, b: float) -> np.ndarray: diff --git a/src/BenchMatcha/config.py b/src/BenchMatcha/config.py index b3fd28d..9cddf3c 100644 --- a/src/BenchMatcha/config.py +++ b/src/BenchMatcha/config.py @@ -30,17 +30,33 @@ """Default runner configuration.""" import logging +from typing import Any, Iterable import toml # type: ignore[import-untyped] - -from . import plotting +from attrs import asdict, define, field log: logging.Logger = logging.getLogger(__name__) -class Config: - """default configuration. +def traverse(d: dict, keys: Iterable[str]) -> dict: + """Safely traverse dictionary keys.""" + for k in keys: + d = d.get(k, {}) + + return d + + +@define +class _ConfigBase: + def tojson(self) -> dict[str, Any]: + """Convert instance into json object.""" + return asdict(self, recurse=True) + + +@define +class ConfigBase(_ConfigBase): + """default configuration settings. Attributes: color (str): plot marker color. @@ -50,10 +66,13 @@ class Config: """ - color: str = plotting.Prism[3] - line_color: str = plotting.Prism[4] - font: str = "Space Grotesk Light, Courier New, monospace" - x_axis: int = 13 + color: str = field(converter=str, default="#0f8554") + line_color: str = field(converter=str, default="#73af48") + font: str = field( + converter=str, + default="Space Grotesk Light, Courier New, monospace", + ) + x_axis: int = field(converter=int, default=13) class ConfigUpdater: @@ -61,38 +80,52 @@ class ConfigUpdater: Args: path (str): path to valid configuration file. - config (Config): configuration class to update. + config (_ConfigBase): configuration class to update. """ path: str - config: type[Config] + config: _ConfigBase + _tool_key: str = "BenchMatcha" - def __init__(self, path: str, config: type[Config] = Config) -> None: + def __init__(self, path: str, config: _ConfigBase) -> None: self.path = path self.config = config - def load(self) -> dict: + def load(self) -> dict[str, Any]: """Load toml data from path.""" return toml.load(self.path) - def _update(self, data: dict) -> None: - for key, value in data.get("tool", {}).get("BenchMatcha", {}).items(): + def _update(self, data: dict[str, Any]) -> None: + for key, value in data.items(): + # Support recursive configuration + if isinstance(value, dict): + previous = self.config + current = getattr(self.config, key) + if isinstance(current, _ConfigBase): + self.config = current + self._update(value) + self.config = previous + continue + if not hasattr(self.config, key): log.info("Unsupported tool key: %s", key) continue - setattr(self.config, key, value) def update(self) -> None: """Parse toml path and update default configuration.""" - data: dict = self.load() - self._update(data) + data: dict[str, Any] = self.load() + self._update(traverse(data, ("tool", self._tool_key))) -def update_config_from_pyproject(path: str) -> None: +def update_config_from_pyproject(path: str, config: _ConfigBase) -> None: """Update default config from pyproject toml file. + Args: + path (str): Filepath to pyproject.toml config. + config (ConfigBase): configuration class object. + Example: .. code-block: toml @@ -104,5 +137,5 @@ def update_config_from_pyproject(path: str) -> None: x_axis=5 """ - cu = ConfigUpdater(path) + cu = ConfigUpdater(path, config) cu.update() diff --git a/src/BenchMatcha/plotting.py b/src/BenchMatcha/plotting.py index 042e908..01ca7b7 100644 --- a/src/BenchMatcha/plotting.py +++ b/src/BenchMatcha/plotting.py @@ -36,6 +36,8 @@ from plotly.express import colors # type: ignore[import-untyped] from plotly.io import to_json as _to_json # type: ignore[import-untyped] +from .config import ConfigBase +from .structure import BenchmarkArray from .utils import _simple_stats, power_of_2 @@ -245,3 +247,71 @@ def draw_complexity_line( ), opacity=0.7, ) + + +def plot_benchmark_array(benchmark: BenchmarkArray, config: ConfigBase) -> go.Figure: + """Plot benchmark array. + + Args: + benchmark (BenchmarkArray): benchmark array data. + config (ConfigBase): configuration settings. + + Returns: + (go.Figure) returns plotly figure. + + """ + fig = go.Figure() + fig.add_trace( + create_scatter_trace( + benchmark.size, + benchmark.cpu_time, + "CPU Time", + config.color, + ) + ) + + fig.add_trace( + draw_complexity_line( + benchmark.size, + benchmark.complexity.cpu_coefficient, + benchmark.complexity.big_o, + f"CPU Time Fit ({benchmark.complexity.big_o})", + config.line_color, + ) + ) + + fig.add_annotation( + **create_annotation_text( + benchmark.complexity.big_o, + benchmark.complexity.rms, + ) + ) + + vals, labels = construct_log2_axis(benchmark.size) + if (p := len(vals) // config.x_axis) > 0: + vals = vals[:: p + 1] + labels = labels[:: p + 1] + + fig.update_layout( + title=f"Benchmark Results
{benchmark.function}", + xaxis=dict( + type="log", + tickvals=vals, + ticktext=labels, + tickmode="array", + title="Input Size (n)", + ), + yaxis=dict( + title=f"Time ({benchmark.unit})", + type="log", + dtick=1, + exponentformat="power", + ), + legend_title="Timing", + font=dict( + family=config.font, + size=12, + ), + ) + + return fig diff --git a/src/BenchMatcha/runner.py b/src/BenchMatcha/runner.py index ac15962..ad734be 100644 --- a/src/BenchMatcha/runner.py +++ b/src/BenchMatcha/runner.py @@ -43,97 +43,17 @@ from . import plotting # from .complexity import analyze_complexity -from .config import Config, update_config_from_pyproject +from .config import ConfigBase, update_config_from_pyproject from .errors import ParsingError from .handlers import HandleText -from .sifter import collect_benchmarks, load_benchmark -from .structure import BenchmarkArray, BenchmarkContext, parse_version +from .plotting import plot_benchmark_array +from .sifter import manage_registration +from .structure import BenchmarkContext, parse_version log: logging.Logger = logging.getLogger(__name__) -def manage_registration(path: str) -> None: - """Manage import, depending on whether path is a directory or file.""" - abspath: str = os.path.abspath(path) - log.debug("Loading path: %s", abspath) - if not os.path.exists(abspath): - raise FileNotFoundError("Invalid filepath") - - if os.path.isdir(abspath): - collect_benchmarks(abspath) - - elif os.path.isfile(abspath) and abspath.endswith(".py"): - load_benchmark(abspath, os.path.abspath(os.path.dirname(abspath))) - - else: - log.warning( - "Unsupported path provided. While the path does exist, it is neither a" - " python file nor a directory: %s", - abspath, - ) - raise TypeError(f"Unsupported path type: {abspath}") - - -def plot_benchmark_array(benchmark: BenchmarkArray) -> go.Figure: - """Plot benchmark array.""" - fig = go.Figure() - fig.add_trace( - plotting.create_scatter_trace( - benchmark.size, - benchmark.cpu_time, - "CPU Time", - Config.color, - ) - ) - - fig.add_trace( - plotting.draw_complexity_line( - benchmark.size, - benchmark.complexity.cpu_coefficient, - benchmark.complexity.big_o, - f"CPU Time Fit ({benchmark.complexity.big_o})", - Config.line_color, - ) - ) - - fig.add_annotation( - **plotting.create_annotation_text( - benchmark.complexity.big_o, - benchmark.complexity.rms, - ) - ) - - vals, labels = plotting.construct_log2_axis(benchmark.size) - if (p := len(vals) // Config.x_axis) > 0: - vals = vals[:: p + 1] - labels = labels[:: p + 1] - - fig.update_layout( - title=f"Benchmark Results
{benchmark.function}", - xaxis=dict( - type="log", - tickvals=vals, - ticktext=labels, - tickmode="array", - title="Input Size (n)", - ), - yaxis=dict( - title=f"Time ({benchmark.unit})", - type="log", - dtick=1, - exponentformat="power", - ), - legend_title="Timing", - font=dict( - family=Config.font, - size=12, - ), - ) - - return fig - - # TODO: Consider defining CLI Exit Status in an Enum def _run() -> BenchmarkContext: if "--benchmark_format=json" not in sys.argv: @@ -153,7 +73,8 @@ def _run() -> BenchmarkContext: text: str = stdout.read() error: str = stderr.read() - stdout.close(), stderr.close() # pylint: disable=W0106 + stdout.close() + stderr.close() # Pass stderr from google_benchmark if len(error): @@ -170,10 +91,10 @@ def _run() -> BenchmarkContext: return context -def save(context: BenchmarkContext, cache_dir: str) -> None: +def save(context: BenchmarkContext, cache_dir: str, config: ConfigBase) -> None: """Save benchmark data.""" for j in context.benchmarks: - figure: go.Figure = plot_benchmark_array(j) + figure: go.Figure = plot_benchmark_array(j, config) plotting.to_html(figure, os.path.join(cache_dir, "out.html"), "a") # TODO: Save data to database. Serialize to json in the interim. @@ -193,7 +114,7 @@ def save(context: BenchmarkContext, cache_dir: str) -> None: f.write(serialized) -def run(cache_dir: str) -> None: +def run(cache_dir: str, config: ConfigBase) -> None: """BenchMatcha Runner.""" context: BenchmarkContext = _run() @@ -202,10 +123,29 @@ def run(cache_dir: str) -> None: # for bench in context.benchmarks: # analyze_complexity(bench.size, bench.real_time) - save(context, cache_dir) + save(context, cache_dir, config) + + +def prepare_benchmark_sys_args(known: argparse.Namespace, unknown: list[str]) -> None: + """Handle google benchmark system arg preparation.""" + # Google_benchmark CLI arguments, For Reference: + # https://github.com/google/benchmark/blob/main/src/benchmark.cc#L751-L812 + problems: list[str] = [] + for k in filter( + lambda x: isinstance(x, str) and "--benchmark_format=" in x, + unknown, + ): + if "json" not in k: + log.warning("Benchmark Format must be json: `%s`", k) + problems.append(k) + for p in problems: + unknown.remove(p) + + # Prune / Reset for google_benchmark + sys.argv = [sys.argv[0], *unknown, *known.others] -def get_args() -> argparse.Namespace: +def get_args() -> tuple[argparse.Namespace, list[str]]: """Get BenchMatcha command line arguments and reset to support google_benchmark.""" args = argparse.ArgumentParser("benchmatcha", conflict_handler="error") args.add_argument( @@ -282,29 +222,15 @@ def get_args() -> argparse.Namespace: # ) known, unknown = args.parse_known_args() - # NOTE: Only validate `benchmark_format` argument from google_benchmark cli, since - # we require json format to correctly work downstream. All other argument - # validations should be handled by google_benchmark cli parsing directly. - problems: list[str] = [] - for k in filter( - lambda x: isinstance(x, str) and "--benchmark_format=" in x, - unknown, - ): - if "json" not in k: - log.warning("Benchmark Format must be json: `%s`", k) - problems.append(k) - for p in problems: - unknown.remove(p) - - # Prune / Reset for google_benchmark - sys.argv = [sys.argv[0], *unknown, *known.others] - - return known + return known, unknown def main() -> None: """Primary CLI Entry Point.""" - args: argparse.Namespace = get_args() + args: argparse.Namespace + unknowns: list[str] + args, unknowns = get_args() + default_config = ConfigBase() if args.verbose: logging.basicConfig(level=logging.DEBUG) @@ -312,20 +238,22 @@ def main() -> None: if os.path.exists(args.config): log.debug("Updating default configuration from file: %s", args.config) - update_config_from_pyproject(args.config) + update_config_from_pyproject(args.config, default_config) + else: + log.debug("Configuration file not found: %s", args.config) # NOTE: Configuration Args should overwrite values set in config file if args.color is not None: log.debug("Overriding color from arg: %s", args.color) - Config.color = args.color + default_config.color = args.color if args.line_color is not None: log.debug("Overriding line_color from arg: %s", args.line_color) - Config.line_color = args.line_color + default_config.line_color = args.line_color if args.x_axis is not None: log.debug("Overriding x_axis from arg: %s", args.x_axis) - Config.x_axis = args.x_axis + default_config.x_axis = int(args.x_axis) # Create cache directory if it does not exist if not os.path.exists(cache := args.cache): @@ -336,4 +264,5 @@ def main() -> None: for path in args.path: manage_registration(path) - run(cache) + prepare_benchmark_sys_args(args, unknowns) + run(cache, default_config) diff --git a/src/BenchMatcha/sifter.py b/src/BenchMatcha/sifter.py index 9415ef8..938c767 100644 --- a/src/BenchMatcha/sifter.py +++ b/src/BenchMatcha/sifter.py @@ -30,6 +30,7 @@ """Discovery of benchmark tests to register.""" import glob +import logging import os from collections.abc import Iterator from pathlib import Path @@ -38,6 +39,9 @@ from _pytest.pathlib import import_path +log: logging.Logger = logging.getLogger(__name__) + + def scandir(filepath: str) -> Iterator[os.DirEntry[str]]: """Simple wrapper around os.scandir to use more simply as an iterator.""" with os.scandir(os.path.abspath(filepath)) as scanner: @@ -87,3 +91,25 @@ def collect_benchmarks(root: str) -> None: root = os.path.abspath(root) for j in collect(root): load_benchmark(j, root=root) + + +def manage_registration(path: str) -> None: + """Manage import, depending on whether path is a directory or file.""" + abspath: str = os.path.abspath(path) + log.debug("Loading path: %s", abspath) + if not os.path.exists(abspath): + raise FileNotFoundError(f"Invalid filepath: {abspath}") + + if os.path.isdir(abspath): + collect_benchmarks(abspath) + + elif os.path.isfile(abspath) and abspath.endswith(".py"): + load_benchmark(abspath, os.path.abspath(os.path.dirname(abspath))) + + else: + log.warning( + "Unsupported path provided. While the path does exist, it is neither a" + " python file nor a directory: %s", + abspath, + ) + raise TypeError(f"Unsupported path type: {abspath}") diff --git a/src/BenchMatcha/structure.py b/src/BenchMatcha/structure.py index 3f881a9..92cb6bd 100644 --- a/src/BenchMatcha/structure.py +++ b/src/BenchMatcha/structure.py @@ -31,6 +31,9 @@ from __future__ import annotations +import os +import subprocess +import sys from collections import defaultdict from dataclasses import dataclass from datetime import UTC, datetime @@ -64,6 +67,28 @@ def get_size(record: dict[str, str]) -> int: return int(_get_function(record)[1]) +def get_python_version() -> str: + """Get current python version.""" + version: tuple[str, str, str] = ( + str(sys.version_info.major), + str(sys.version_info.minor), + str(sys.version_info.micro), + ) + + return ".".join(version) + + +def _get_commit_hash(path: str) -> str: + return ( + subprocess.check_output( + ["git", "describe", "HEAD", "--always"], + cwd=path, + ) + .strip() + .decode() + ) + + @dataclass class Cache: """System cache information. @@ -295,8 +320,6 @@ def convert_to_arrays( # TODO: Consider how we would capture custom data (Counters). -# TODO: Capture additional information not included in google_benchmark output -# (e.g. git sha) @dataclass class BenchmarkContext: """Google benchmark context.""" @@ -315,6 +338,8 @@ class BenchmarkContext: json_schema_version: int benchmarks: list[BenchmarkArray] aslr_enabled: bool + python_version: str + git_sha: str @classmethod def from_json(cls, record: dict[str, Any]) -> Self: @@ -338,6 +363,8 @@ def from_json(cls, record: dict[str, Any]) -> Self: date=date, benchmarks=benchmarks, aslr_enabled=aslr, + python_version=get_python_version(), + git_sha=_get_commit_hash(os.getcwd()), ) def to_json(self) -> dict: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 368b789..8ada40b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -27,6 +27,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Common pytest fixtures for integration test suite.""" + import os import subprocess import tempfile diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 558ab48..d4ef94c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -27,6 +27,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Common pytest fixtures for unit test suite.""" + import tempfile from collections.abc import Iterator @@ -35,6 +37,7 @@ @pytest.fixture def mock_context() -> str: + """Sample output from google_benchmark context.""" return """ "context": { "date": "2025-07-13T12:09:31-07:00", @@ -73,6 +76,7 @@ def mock_context() -> str: @pytest.fixture def mock_bench() -> str: + """Sample output from google_benchmark benchmarks.""" return """ "benchmarks": [ { @@ -136,12 +140,13 @@ def mock_bench() -> str: @pytest.fixture def mock_data(mock_context: str, mock_bench: str) -> str: - """mock benchmark data.""" + """Mock google_benchmark json data.""" return "{" + f"{mock_context}{mock_bench}" + "}" @pytest.fixture def mock_file(mock_data: str) -> Iterator[tempfile._TemporaryFileWrapper]: + """Mock temporary file wrapper around mock google_benchmark json data.""" with tempfile.NamedTemporaryFile("w+") as f: f.write(mock_data) f.seek(0) diff --git a/tests/unit/test_complexity.py b/tests/unit/test_complexity.py index da35df9..e8066de 100644 --- a/tests/unit/test_complexity.py +++ b/tests/unit/test_complexity.py @@ -40,6 +40,7 @@ @pytest.fixture def fit_result() -> comp.FitResult: + """Sample Fit Result.""" return comp.FitResult( bigo="N", params=np.asarray([2.73]), @@ -50,6 +51,7 @@ def fit_result() -> comp.FitResult: @pytest.fixture def coords() -> tuple[np.ndarray, np.ndarray]: + """Mock coordinates.""" x = np.arange(3) y = np.arange(9).reshape((3, 3)) @@ -58,6 +60,8 @@ def coords() -> tuple[np.ndarray, np.ndarray]: @pytest.fixture def shunt_fit() -> Iterator[None]: + """Temporarily add a complexity function that fails, and removing on teardown.""" + def _failed_fit(x, a, b): raise RuntimeError("Intentionally Fail.") diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index d7fda90..9b9ee63 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -29,99 +29,200 @@ """Test config module.""" -from contextlib import contextmanager from io import StringIO from unittest.mock import MagicMock, patch import pytest +import toml +from attrs import define, field from BenchMatcha import config @pytest.fixture -def toml_str() -> str: - return """ -[tool.BenchMatcha] -color="#FFF" -line_color="#333" -font="Courier" -upsupported_key="test" -""" - - -@pytest.fixture -def toml_data() -> dict: +def toml_inner() -> dict: return { - "tool": { - "BenchMatcha": { - "color": "#FFF", - "line_color": "#333", - "font": "Courier", - "upsupported_key": "test", - } - } + "color": "#FFF", + "line_color": "#333", + "font": "Courier", + "x_axis": 5, + "unsupported_key": "test", } -@contextmanager -def reset(c: type[config.Config]): - """Simple contextmanager to reset config to original default values.""" - default: dict = {key: getattr(c, key) for key in c.__annotations__.keys()} - yield +@pytest.fixture +def toml_data(toml_inner: dict) -> dict: + """Mock toml config as a dictionary object.""" + return {"tool": {"BenchMatcha": toml_inner}} + - for k, v in default.items(): - setattr(c, k, v) - assert getattr(c, k) == default[k], f"Expected Value to be reset: {k}" +@pytest.fixture +def toml_str(toml_data: dict) -> str: + """Mock toml config text.""" + return toml.dumps(toml_data) -def test_config_load(toml_str: str, toml_data: dict) -> None: +@pytest.fixture +def configuration() -> config.ConfigBase: + """Create configuration and reset instance singleton.""" + return config.ConfigBase() + + +@pytest.mark.parametrize( + ["obj", "keys", "expected"], + [ + ({}, ("a", "b", "c"), {}), + ({"a": {"b": {"c": "d"}}}, ("a", "b", "c"), "d"), + ({"a": {"b": {"d": "d"}}}, ("a", "b", "c"), {}), + ], +) +def test_traverse(obj: dict, keys: tuple[str], expected: str) -> None: + """Test dictionary traversal works as anticipated.""" + result = config.traverse(obj, keys) + assert result == expected + + +def test_config_load( + toml_str: str, + toml_data: dict, + configuration: config.ConfigBase, +) -> None: """Confirm toml data loads correctly.""" - with reset(config.Config), StringIO(toml_str) as stream: - instance = config.ConfigUpdater(stream) + with StringIO(toml_str) as stream: + instance = config.ConfigUpdater(stream, configuration) result = instance.load() assert result == toml_data, "Expected same object" -def _assert_config_is_updated() -> None: - assert config.Config.color == "#FFF", "Expected color to be updated." - assert config.Config.line_color == "#333", "Expected line color to be updated." - assert config.Config.font == "Courier", "Expected font to be updated." - assert not hasattr(config.Config, "upsupported_key"), ( +def _assert_config_is_updated(conf: config.ConfigBase) -> None: + assert conf.color == "#FFF", "Expected color to be updated." + assert conf.line_color == "#333", "Expected line color to be updated." + assert conf.font == "Courier", "Expected font to be updated." + assert not hasattr(conf, "unsupported_key"), ( "Expected unsupported key to be bypassed." ) -def test_config_private_update(toml_data: dict) -> None: +def test_config_private_update( + toml_inner: dict, + configuration: config.ConfigBase, +) -> None: """Confirm config data is updated correctly.""" - with reset(config.Config): - instance = config.ConfigUpdater("") - instance._update(toml_data) - _assert_config_is_updated() + instance = config.ConfigUpdater("", configuration) + instance._update(toml_inner) + _assert_config_is_updated(configuration) -def test_config_update(toml_str: str) -> None: +def test_config_update( + toml_str: str, + configuration: config.ConfigBase, +) -> None: """Confirm toml data loads correctly.""" - with reset(config.Config), StringIO(toml_str) as stream: - instance = config.ConfigUpdater(stream) + with StringIO(toml_str) as stream: + instance = config.ConfigUpdater(stream, configuration) + instance.update() + + _assert_config_is_updated(configuration) + + +def test_config_json( + toml_str: str, + toml_inner: dict, + configuration: config.ConfigBase, +) -> None: + """Confirm configuration returns as json object correctly.""" + with StringIO(toml_str) as stream: + instance = config.ConfigUpdater(stream, configuration) instance.update() - _assert_config_is_updated() + + _assert_config_is_updated(configuration) + toml_inner.pop("unsupported_key", None) + assert configuration.tojson() == toml_inner, "Unexpected json object data." @patch.object(config.ConfigUpdater, "load") -def test_config_update_mock(mock: MagicMock, toml_data: dict) -> None: +def test_config_update_mock( + mock: MagicMock, + toml_data: dict, + configuration: config.ConfigBase, +) -> None: """Confirm config is updated by mocking load method of ConfigUpdater.""" mock.return_value = toml_data - with reset(config.Config): - instance = config.ConfigUpdater("") - instance.update() - _assert_config_is_updated() + instance = config.ConfigUpdater("", configuration) + instance.update() + _assert_config_is_updated(configuration) @patch.object(config.ConfigUpdater, "load") -def test_config_update_function(mock: MagicMock, toml_data: dict) -> None: +def test_config_update_function( + mock: MagicMock, + toml_data: dict, + configuration: config.ConfigBase, +) -> None: """Confirm config is updated from available function api.""" mock.return_value = toml_data - with reset(config.Config): - config.update_config_from_pyproject("") - _assert_config_is_updated() + config.update_config_from_pyproject("", configuration) + _assert_config_is_updated(configuration) + + +def test_config_base_setters(configuration: config.ConfigBase) -> None: + """Confirm class setattr engages designated converters.""" + configuration.x_axis = "5" + assert isinstance(configuration.x_axis, int), "Expected value to be coerced to int." + assert configuration.x_axis == 5, "Expected value to be updated to 5." + + setattr(configuration, "x_axis", "7") # noqa: B010 + assert configuration.x_axis == 7, "Expected value to be updated to 7." + + +@define +class A(config._ConfigBase): + a: str = field(default="a") + b: int = field(default=1) + + +@define +class B(config._ConfigBase): + a: A = field(factory=A) + b: str = field(default="b") + + +@pytest.fixture +def toml_data_b() -> str: + return """ +[tool.BenchMatcha] +b = "loaded b" + +[tool.BenchMatcha.a] +a = "loaded a" +b = 5 +""" + + +@patch.object(config.ConfigUpdater, "load") +def test_recursive_update_function( + mock: MagicMock, + toml_data_b: str, +): + """Confirm nested recursive configurations are appropriately updated.""" + # sanity checks + con = B() + assert con.a.a == "a" + assert con.a.b == 1 + assert con.b == "b" + + # setup + data: dict = toml.loads(toml_data_b) + mock.return_value = data + + instance = config.ConfigUpdater("", con) + instance.update() + + # checks + assert con.a.a == "loaded a" + assert con.a.b == 5 + assert con.b == "loaded b" + + # sanity + assert instance.config is con, "Expected config to be same on instance." diff --git a/tests/unit/test_errors.py b/tests/unit/test_errors.py index 8d8b546..07c1e45 100644 --- a/tests/unit/test_errors.py +++ b/tests/unit/test_errors.py @@ -27,6 +27,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Unit test errors module.""" + import builtins from collections.abc import Callable, Iterator from json import JSONDecodeError @@ -38,6 +40,7 @@ @pytest.fixture def register_error() -> Iterator[Callable[[type[Exception]], type[Exception]]]: + """Register custom error dynamically with teardown.""" e: type[Exception] | None = None def inner(exc: type[Exception]) -> type[Exception]: diff --git a/tests/unit/test_handlers.py b/tests/unit/test_handlers.py index 6e68831..202a936 100644 --- a/tests/unit/test_handlers.py +++ b/tests/unit/test_handlers.py @@ -50,7 +50,7 @@ def test_load(mock_data: str) -> None: [ (handlers.HandlePath, lambda x: x), (handlers.HandleBytes, lambda x: x.encode()), - (handlers.HandleIO, lambda x: StringIO(x)), + (handlers.HandleIO, StringIO), (handlers.HandleIO, lambda x: BytesIO(x.encode())), ], ) @@ -59,7 +59,7 @@ def test_handlers( transformer: Callable[[str], Any], mock_data: str, ) -> None: - """""" + """Test handlers correctly parse mock json data.""" mock = transformer(mock_data) result = handler(mock).handle() assert isinstance(result, dict), "Expected a dictionary loaded." diff --git a/tests/unit/test_plotting.py b/tests/unit/test_plotting.py index 6e520a4..dd8c361 100644 --- a/tests/unit/test_plotting.py +++ b/tests/unit/test_plotting.py @@ -36,6 +36,32 @@ import pytest from BenchMatcha import plotting +from BenchMatcha.config import ConfigBase +from BenchMatcha.structure import BenchmarkArray, ComplexityInfo + + +@pytest.fixture +def bench_arr(): + """Sample BenchmarkArray instance object.""" + name: str = "test" + c = ComplexityInfo( + function=name, + big_o="lgN", + real_coefficient=1.0, + cpu_coefficient=2.0, + ) + + b = BenchmarkArray( + function=name, + unit="s", + size=np.asarray([2, 4, 8]), + iterations=np.asarray([[100, 50, 25], [100, 50, 25], [100, 50, 25]]), + real_time=np.asarray([[5.2, 5.6, 5.3], [10.3, 10.2, 10.0], [15.3, 15.1, 15.2]]), + cpu_time=np.asarray([[5.2, 5.2, 5.0], [10.1, 10.0, 9.9], [15.2, 14.9, 15.0]]), + complexity=c, + ) + + return b def test_serialization_to_html(): @@ -144,3 +170,9 @@ def test_draw_complexity_line() -> None: x = np.arange(20) result = plotting.draw_complexity_line(x, 1.2, "N", "test", "red") assert isinstance(result, go.Scatter) + + +def test_plot_benchmark_array(bench_arr: BenchmarkArray): + """Confirm an array is constructed.""" + result = plotting.plot_benchmark_array(bench_arr, ConfigBase()) + assert isinstance(result, go.Figure), "Expected a figure object." diff --git a/tests/unit/test_sifter.py b/tests/unit/test_sifter.py index f1ba107..2580adb 100644 --- a/tests/unit/test_sifter.py +++ b/tests/unit/test_sifter.py @@ -47,6 +47,7 @@ def __init__(self, path: str, directory: bool = False): self.directory = directory def is_dir(self, **kwargs) -> bool: + """Confirm if path is a directory.""" return self.directory def __str__(self): diff --git a/tests/unit/test_structure.py b/tests/unit/test_structure.py index ffbe0c0..4be7191 100644 --- a/tests/unit/test_structure.py +++ b/tests/unit/test_structure.py @@ -40,6 +40,7 @@ @pytest.fixture def shunt_version() -> Iterator[Callable[[tuple[int, ...]], tuple[int, ...]]]: + """Temporarily modify supported versions, with teardown to reinstate original.""" previous = structure.SUPPORTED_VERSIONS def modify(value: tuple[int, ...]) -> tuple[int, ...]: @@ -52,6 +53,28 @@ def modify(value: tuple[int, ...]) -> tuple[int, ...]: assert structure.SUPPORTED_VERSIONS == previous +@pytest.fixture +def complexity_info() -> dict: + """Expected complexity info data.""" + return { + "function": "function", + "big_o": "N", + "cpu_coefficient": 1.2549524739108346e01, + "real_coefficient": 1.2555517635286144e01, + "rms": 5.4107447739157266e-02, + } + + +@pytest.fixture +def cache_data() -> list[tuple]: + """Expected cache data.""" + return [ + ("Data", 1, 65536, 0), + ("Instruction", 1, 131072, 0), + ("Unified", 2, 4194304, 1), + ] + + @pytest.mark.parametrize( ["value", "expected"], [ @@ -66,7 +89,99 @@ def test_parse_datetime(value: str, expected: datetime) -> None: assert result == expected, "Unexpected date time." -def test_parse_json_data(mock_data: str) -> None: +def _check_cache(cache: structure.Cache, expected: tuple) -> None: + assert isinstance(cache, structure.Cache), f"Expected Cache type: {type(cache)}." + assert cache.type == expected[0], f"Unexpected type: {cache.type}" + assert cache.level == expected[1], f"Unexpected level: {cache.level}" + assert cache.size == expected[2], f"Unexpected size: {cache.size}" + assert cache.num_sharing == expected[3], f"Unexpected sharing: {cache.num_sharing}" + + +def test_convert_cache_data() -> None: + """Test we correctly convert cache json data into a dataclass.""" + cache = { + "type": "test", + "level": 17, + "size": 45, + "num_sharing": 50, + "unknown_key": 3, + } + result = structure.Cache.from_json(cache) + _check_cache(result, ("test", 17, 45, 50)) + + +def test_parse_cache_json_data(mock_data: str, cache_data: list[tuple]) -> None: + """Test we correctly parse cache json data into a dataclass.""" + data: dict = load(mock_data) + caches: list[dict] = data["context"]["caches"] + for cache, e in zip(caches, cache_data, strict=False): + result = structure.Cache.from_json(cache) + _check_cache(result, e) + + +def _check_complexity_info(complexity: structure.ComplexityInfo, expected: dict): + assert isinstance(complexity, structure.ComplexityInfo), ( + "Expected value to be a ComplexityInfo object." + ) + assert complexity.function == expected["function"], ( + f"Incorrect Function Name: {complexity.function}" + ) + assert complexity.big_o == expected["big_o"], ( + f"Incorrect Function Name: {complexity.function}" + ) + assert complexity.cpu_coefficient == expected["cpu_coefficient"], ( + f"Unexpected CPU Coefficient: {complexity.cpu_coefficient}" + ) + assert complexity.real_coefficient == expected["real_coefficient"], ( + f"Unexpected Real Coefficient: {complexity.real_coefficient}" + ) + assert complexity.rms == expected["rms"], f"Unexpected RMS: {complexity.rms}" + + +def test_convert_benchmark_json_to_complexity(complexity_info: dict) -> None: + """Test conversion of benchmark data correctly collects complexity info.""" + obj: list[dict] = [ + { + "name": "function/repeats:3_BigO", + "family_index": 0, + "per_family_instance_index": 0, + "run_name": "function/repeats:3", + "run_type": "aggregate", + "repetitions": 3, + "threads": 1, + "aggregate_name": "BigO", + "aggregate_unit": "time", + "cpu_coefficient": 1.2549524739108346e01, + "real_coefficient": 1.2555517635286144e01, + "big_o": "N", + "time_unit": "ns", + }, + { + "name": "function/repeats:3_RMS", + "family_index": 0, + "per_family_instance_index": 0, + "run_name": "function/repeats:3", + "run_type": "aggregate", + "repetitions": 3, + "threads": 1, + "aggregate_name": "RMS", + "aggregate_unit": "percentage", + "rms": 5.4107447739157266e-02, + }, + ] + + result = structure.get_complexity_info(obj) + assert isinstance(result, dict), "Expected dictionary object returned." + assert "function" in result, "Expected function name key to be present." + value = result["function"] + _check_complexity_info(value, complexity_info) + + +def test_parse_full_json_data( + mock_data: str, + complexity_info: dict, + cache_data: list[tuple], +) -> None: """Test we correctly parse json data into a dataclass.""" data = load(mock_data) result = structure.BenchmarkContext.from_json(data) @@ -79,15 +194,23 @@ def test_parse_json_data(mock_data: str) -> None: assert result.cpu_scaling_enabled is False assert len(result.caches) == 3 + for c, e in zip(result.caches, cache_data, strict=False): + _check_cache(c, e) + assert result.load_avg == [4.69092, 4.60693, 4.47949] assert result.library_version == "1.9.4" assert result.library_build_type == "release" assert result.json_schema_version == 1 + assert result.aslr_enabled is False + assert isinstance(result.python_version, str), "Expected a string" + assert isinstance(result.git_sha, str), "Expected a string" + assert isinstance(result.benchmarks, list), "Expected benchmarks to be a list." assert len(result.benchmarks) == 1 + _check_complexity_info(result.benchmarks[0].complexity, complexity_info) -def test_convert_BenchmarkContext_to_json(mock_data: str) -> None: +def test_convert_benchmark_context_to_json(mock_data: str) -> None: """Test we convert dataclass into dictionary json like objects.""" data = load(mock_data) obj = structure.BenchmarkContext.from_json(data) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 4ef1d61..b141bb1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -27,6 +27,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Unit test utils module.""" + import numpy as np import pytest @@ -38,20 +40,20 @@ (1, 1), ] for j in range(2, 24): - value = 1 << j - _power2.append((value - 1, value)) - _power2.append((value, value)) - _power2.append((value + 1, 1 << (j + 1))) + _v = 1 << j + _power2.append((_v - 1, _v)) + _power2.append((_v, _v)) + _power2.append((_v + 1, 1 << (j + 1))) @pytest.mark.parametrize(["value", "expected"], _power2) -def test_power_of_2(value: int, expected: int): +def test_power_of_2(value: int, expected: int) -> None: """Test returns next power of two.""" - result = utils.power_of_2(value) + result: int = utils.power_of_2(value) assert result == expected, f"Unexpected result: {result}" -def test_simple_stats(): +def test_simple_stats() -> None: """Test mean and std.""" x = np.asarray([[1, 2, 3], [2, 3, 1], [3, 1, 2]]) result = utils._simple_stats(x) @@ -72,7 +74,7 @@ def test_simple_stats(): ("oLambda", "f(N)"), ], ) -def test_bigo_enum_get(value: str, expected: str): +def test_bigo_enum_get(value: str, expected: str) -> None: """Test conversion of big o notation identifier get classmethod.""" result: str = utils.BigO.get(value) assert result == expected, "Unexpected result." @@ -90,7 +92,7 @@ def test_bigo_enum_get(value: str, expected: str): ("f(N)", "oLambda"), ], ) -def test_bigo_enum_back(value: str, expected: str): +def test_bigo_enum_back(value: str, expected: str) -> None: """Test conversion of big o notation identifier back classmethod.""" result: str = utils.BigO.back(value) assert result == expected, "Unexpected result."