Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ warn_unused_ignores = true

[tool.pylint.main]
extension-pkg-allow-list = ["orjson"]
ignore = ["tests", "dist", "build"]
ignore = ["dist", "build"]
fail-under = 9.0
jobs = 0
limit-inference-results = 100
Expand Down Expand Up @@ -118,6 +118,13 @@ disable = [
"W0231", # super-init-not-called
]

[tool.pylint."tests/*"]
disable = [
"W0212", # protected-access
"W0612", # unused-variable
"W0621", # redefined-outer-name --> pytest fixtures
]

[tool.ruff]
line-length = 88
indent-width = 4
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
attrs
google-benchmark
numpy
orjson
Expand Down
7 changes: 4 additions & 3 deletions src/BenchMatcha/complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@ def __repr__(self) -> str:

# Define common complexity functions with all coefficients and intercept
Equation = (
Callable[[np.ndarray, float, float], np.ndarray]
Callable[[np.ndarray, float], np.ndarray]
| Callable[[np.ndarray, float, float], np.ndarray]
| Callable[[np.ndarray, float, float, float], np.ndarray]
| Callable[[np.ndarray, float, float, float, float], np.ndarray]
)


def constant(n: np.ndarray, a: float, b: float) -> np.ndarray:
def constant(n: np.ndarray, a: float) -> np.ndarray:
"""Constant O(1) equation."""
return a * np.ones_like(n) + b
return a * np.ones_like(n)


def logn(n: np.ndarray, a: float, b: float) -> np.ndarray:
Expand Down
71 changes: 52 additions & 19 deletions src/BenchMatcha/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,33 @@
"""Default runner configuration."""

import logging
from typing import Any, Iterable

import toml # type: ignore[import-untyped]

from . import plotting
from attrs import asdict, define, field


log: logging.Logger = logging.getLogger(__name__)


class Config:
"""default configuration.
def traverse(d: dict, keys: Iterable[str]) -> dict:
"""Safely traverse dictionary keys."""
for k in keys:
d = d.get(k, {})

return d


@define
class _ConfigBase:
def tojson(self) -> dict[str, Any]:
"""Convert instance into json object."""
return asdict(self, recurse=True)


@define
class ConfigBase(_ConfigBase):
"""default configuration settings.

Attributes:
color (str): plot marker color.
Expand All @@ -50,49 +66,66 @@ class Config:

"""

color: str = plotting.Prism[3]
line_color: str = plotting.Prism[4]
font: str = "Space Grotesk Light, Courier New, monospace"
x_axis: int = 13
color: str = field(converter=str, default="#0f8554")
line_color: str = field(converter=str, default="#73af48")
font: str = field(
converter=str,
default="Space Grotesk Light, Courier New, monospace",
)
x_axis: int = field(converter=int, default=13)


class ConfigUpdater:
"""Configuration updater through pyproject config file.

Args:
path (str): path to valid configuration file.
config (Config): configuration class to update.
config (_ConfigBase): configuration class to update.

"""

path: str
config: type[Config]
config: _ConfigBase
_tool_key: str = "BenchMatcha"

def __init__(self, path: str, config: type[Config] = Config) -> None:
def __init__(self, path: str, config: _ConfigBase) -> None:
self.path = path
self.config = config

def load(self) -> dict:
def load(self) -> dict[str, Any]:
"""Load toml data from path."""
return toml.load(self.path)

def _update(self, data: dict) -> None:
for key, value in data.get("tool", {}).get("BenchMatcha", {}).items():
def _update(self, data: dict[str, Any]) -> None:
for key, value in data.items():
# Support recursive configuration
if isinstance(value, dict):
previous = self.config
current = getattr(self.config, key)
if isinstance(current, _ConfigBase):
self.config = current
self._update(value)
self.config = previous
continue

if not hasattr(self.config, key):
log.info("Unsupported tool key: %s", key)
continue

setattr(self.config, key, value)

def update(self) -> None:
"""Parse toml path and update default configuration."""
data: dict = self.load()
self._update(data)
data: dict[str, Any] = self.load()
self._update(traverse(data, ("tool", self._tool_key)))


def update_config_from_pyproject(path: str) -> None:
def update_config_from_pyproject(path: str, config: _ConfigBase) -> None:
"""Update default config from pyproject toml file.

Args:
path (str): Filepath to pyproject.toml config.
config (ConfigBase): configuration class object.

Example:

.. code-block: toml
Expand All @@ -104,5 +137,5 @@ def update_config_from_pyproject(path: str) -> None:
x_axis=5

"""
cu = ConfigUpdater(path)
cu = ConfigUpdater(path, config)
cu.update()
70 changes: 70 additions & 0 deletions src/BenchMatcha/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from plotly.express import colors # type: ignore[import-untyped]
from plotly.io import to_json as _to_json # type: ignore[import-untyped]

from .config import ConfigBase
from .structure import BenchmarkArray
from .utils import _simple_stats, power_of_2


Expand Down Expand Up @@ -245,3 +247,71 @@ def draw_complexity_line(
),
opacity=0.7,
)


def plot_benchmark_array(benchmark: BenchmarkArray, config: ConfigBase) -> go.Figure:
"""Plot benchmark array.

Args:
benchmark (BenchmarkArray): benchmark array data.
config (ConfigBase): configuration settings.

Returns:
(go.Figure) returns plotly figure.

"""
fig = go.Figure()
fig.add_trace(
create_scatter_trace(
benchmark.size,
benchmark.cpu_time,
"CPU Time",
config.color,
)
)

fig.add_trace(
draw_complexity_line(
benchmark.size,
benchmark.complexity.cpu_coefficient,
benchmark.complexity.big_o,
f"CPU Time Fit ({benchmark.complexity.big_o})",
config.line_color,
)
)

fig.add_annotation(
**create_annotation_text(
benchmark.complexity.big_o,
benchmark.complexity.rms,
)
)

vals, labels = construct_log2_axis(benchmark.size)
if (p := len(vals) // config.x_axis) > 0:
vals = vals[:: p + 1]
labels = labels[:: p + 1]

fig.update_layout(
title=f"Benchmark Results<br><i>{benchmark.function}</i>",
xaxis=dict(
type="log",
tickvals=vals,
ticktext=labels,
tickmode="array",
title="Input Size (n)",
),
yaxis=dict(
title=f"Time ({benchmark.unit})",
type="log",
dtick=1,
exponentformat="power",
),
legend_title="Timing",
font=dict(
family=config.font,
size=12,
),
)

return fig
Loading
Loading