diff --git a/basalt/basalt_facade.py b/basalt/basalt_facade.py index 925773d..5d392f4 100644 --- a/basalt/basalt_facade.py +++ b/basalt/basalt_facade.py @@ -1,5 +1,7 @@ +from typing import Optional + from .utils.api import Api -from .utils.protocols import IPromptSDK, IBasaltSDK, LogLevel, IDatasetSDK +from .utils.protocols import ICache, IPromptSDK, IBasaltSDK, LogLevel, IDatasetSDK from .sdk.promptsdk import PromptSDK from .sdk.monitorsdk import MonitorSDK from .sdk.datasetsdk import DatasetSDK @@ -18,15 +20,24 @@ class BasaltFacade(IBasaltSDK): The Basalt client. """ - def __init__(self, api_key: str, log_level: LogLevel = 'all'): + def __init__( + self, + api_key: str, + log_level: LogLevel = "all", + cache: Optional[ICache] = None, + ): """ Initializes the Basalt client with the given API key and log level. Args: api_key (str): The API key for authenticating with the Basalt SDK. log_level (str, optional): The log level for the logger. Defaults to 'all'. (all, warn, error, debug, none) + cache (ICache, optional): The cache to use for the SDK. Defaults to None, which means a MemoryCache will be used. """ - cache = MemoryCache() + + if cache is None: + cache = MemoryCache() + logger = Logger(log_level=log_level) networker = Networker() diff --git a/basalt/utils/memcache.py b/basalt/utils/memcache.py index 4aa7a47..b4d84bb 100644 --- a/basalt/utils/memcache.py +++ b/basalt/utils/memcache.py @@ -1,7 +1,9 @@ import time from typing import Any, Dict, Hashable -class MemoryCache: +from .protocols import ICache + +class MemoryCache(ICache): """ MemoryCache is a simple in-memory cache that stores values for a given key. It implements the ICache protocol.