Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 14 additions & 136 deletions src/memos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,159 +91,37 @@ def wrapper(*args, **kwargs):
self._start = time.perf_counter()
try:
return func(*args, **kwargs)
except Exception as e:
self.duration_ms = int((time.perf_counter() - self._start) * 1000)
self._emit(self.duration_ms, type(e))
raise
finally:
self.duration_ms = int((time.perf_counter() - self._start) * 1000)
self._emit(self.duration_ms)

return wrapper

# -- public API --------------------------------------------------------

def set(self, **fields: Any):
"""Add / overwrite fields after execution (e.g. counts only known after the block runs)."""
self._fields.update(fields)
# -- public helpers ----------------------------------------------------

@staticmethod
def emit_now(biz: str, stage: str, **fields: Any):
"""Fire a one-shot structured log without timing (e.g. summary rollups)."""
parts = [f"biz={biz}", f"stage={stage}"]
for k, v in fields.items():
parts.append(f"{k}={v}")
logger.info("[STAGE] " + " ".join(parts))
def set(self, **kwargs: Any) -> None:
"""Attach extra fields mid-block (context-manager usage)."""
self._fields.update(kwargs)

# -- internals ---------------------------------------------------------
# -- internal ----------------------------------------------------------

def _emit(self, duration_ms: int, exc_type=None):
parts: list[str] = []
parts = ["[STAGE]"]
if self._biz:
parts.append(f"biz={self._biz}")
if self._stage:
parts.append(f"stage={self._stage}")
parts.append(f"duration_ms={duration_ms}")

if self._static_extra:
self._fields.update(self._static_extra)

for k, v in self._static_extra.items():
parts.append(f"{k}={v}")
for k, v in self._fields.items():
parts.append(f"{k}={v}")

parts.append(f"duration_ms={duration_ms}")
if exc_type is not None:
parts.append(f"error={exc_type.__name__}")

msg = "[STAGE] " + " ".join(parts)
msg = " ".join(parts)
getattr(logger, self._level, logger.info)(msg)


def timed_with_status(
func=None,
*,
log_prefix="",
log_args=None,
log_extra_args=None,
fallback=None,
):
"""
Parameters:
- log: enable timing logs (default True)
- log_prefix: prefix; falls back to function name
- log_args: names to include in logs (str or list/tuple of str), values are taken from kwargs by name.
- log_extra_args:
- can be a dict: fixed contextual fields that are always attached to logs;
- or a callable: like `fn(*args, **kwargs) -> dict`, used to dynamically generate contextual fields at runtime.
"""

if isinstance(log_args, str):
effective_log_args = [log_args]
else:
effective_log_args = list(log_args) if log_args else []

def decorator(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
start = time.perf_counter()
exc_type = None
exc_message = None
result = None
success_flag = False

try:
result = fn(*args, **kwargs)
success_flag = True
return result
except Exception as e:
exc_type = type(e)
stack_info = "".join(traceback.format_stack()[:-1])
exc_message = f"{stack_info}{traceback.format_exc()}"
success_flag = False

if fallback is not None and callable(fallback):
result = fallback(e, *args, **kwargs)
return result
finally:
elapsed_ms = (time.perf_counter() - start) * 1000.0

ctx_parts = []
# 1) Collect parameters from kwargs by name
for key in effective_log_args:
val = kwargs.get(key)
ctx_parts.append(f"{key}={val}")

# 2) Support log_extra_args as dict or callable, so we can dynamically
# extract values from self or other runtime context
extra_items = {}
try:
if callable(log_extra_args):
extra_items = log_extra_args(*args, **kwargs) or {}
elif isinstance(log_extra_args, dict):
extra_items = log_extra_args
except Exception as e:
logger.warning(f"[TIMER_WITH_STATUS] log_extra_args callback error: {e!r}")

if extra_items:
ctx_parts.extend(f"{key}={val}" for key, val in extra_items.items())

ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""

status = "SUCCESS" if success_flag else "FAILED"
status_info = f", status: {status}"
if not success_flag and exc_type is not None:
status_info += (
f", error_type: {exc_type.__name__}, error_message: {exc_message}"
)

msg = (
f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} "
f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}"
)

logger.info(msg)

return wrapper

if func is None:
return decorator
return decorator(func)


def timed(func=None, *, log=True, log_prefix=""):
def decorator(fn):
def wrapper(*args, **kwargs):
start = time.perf_counter()
result = fn(*args, **kwargs)
elapsed_ms = (time.perf_counter() - start) * 1000.0

if log is not True:
return result

# 100ms threshold
if elapsed_ms >= 100.0:
logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")

return result

return wrapper

# Handle both @timed and @timed(log=True) cases
if func is None:
return decorator
return decorator(func)
Loading