diff --git a/src/blacki/agent.py b/src/blacki/agent.py index 03eaea7..1b138c4 100644 --- a/src/blacki/agent.py +++ b/src/blacki/agent.py @@ -199,12 +199,15 @@ def create_app(agent: LlmAgent | None = None) -> App: if agent is None: agent = create_agent() + from blacki.declarative_db.plugin import DeclarativeDbPlugin + return App( name="blacki", root_agent=agent, plugins=[ TelegramModelOverridePlugin(name="telegram_model_override"), GlobalInstructionPlugin(return_global_instruction), + DeclarativeDbPlugin(name="declarative_db"), LoggingPlugin(), ], events_compaction_config=None, diff --git a/src/blacki/container.py b/src/blacki/container.py index 2c72070..71e4b06 100644 --- a/src/blacki/container.py +++ b/src/blacki/container.py @@ -27,6 +27,7 @@ import aiosqlite from blacki.calories.storage import SqliteCalorieStorage + from blacki.declarative_db.storage import SqliteDeclarativeDbStorage from blacki.reminders.storage import SqliteReminderStorage from blacki.utils.preferences import SqlitePreferencesStorage from blacki.workouts.storage import SqliteWorkoutStorage @@ -134,6 +135,9 @@ class AppContainer: _preferences_storage: SqlitePreferencesStorage | None = field( default=None, init=False, repr=False ) + _declarative_db_storage: SqliteDeclarativeDbStorage | None = field( + default=None, init=False, repr=False + ) @classmethod async def create(cls, sqlite_path: str | Path) -> Self: @@ -174,6 +178,10 @@ async def _close_storages(self) -> None: await self._preferences_storage.close() self._preferences_storage = None + if self._declarative_db_storage is not None: + await self._declarative_db_storage.close() + self._declarative_db_storage = None + async def initialize_all_storages(self) -> None: """Initialize all storage instances. @@ -184,6 +192,7 @@ async def initialize_all_storages(self) -> None: await self.calorie_storage.initialize() await self.workout_storage.initialize() await self.preferences_storage.initialize() + await self.declarative_db_storage.initialize() @property def lock(self) -> asyncio.Lock: @@ -225,3 +234,14 @@ def preferences_storage(self) -> SqlitePreferencesStorage: self._preferences_storage = SqlitePreferencesStorage(self.conn, self._lock) return self._preferences_storage + + @property + def declarative_db_storage(self) -> SqliteDeclarativeDbStorage: + """Get or create the declarative database storage instance.""" + if self._declarative_db_storage is None: + from blacki.declarative_db.storage import SqliteDeclarativeDbStorage + + self._declarative_db_storage = SqliteDeclarativeDbStorage( + self.conn, self._lock + ) + return self._declarative_db_storage diff --git a/src/blacki/declarative_db/plugin.py b/src/blacki/declarative_db/plugin.py new file mode 100644 index 0000000..e14eaca --- /dev/null +++ b/src/blacki/declarative_db/plugin.py @@ -0,0 +1,58 @@ +"""ADK plugin for injecting declarative database instructions.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from google.adk.plugins.base_plugin import BasePlugin + +from blacki.declarative_db.storage import get_declarative_db_storage + +if TYPE_CHECKING: + from google.adk.agents.callback_context import CallbackContext + from google.adk.models.llm_request import LlmRequest + +logger = logging.getLogger(__name__) + + +class DeclarativeDbPlugin(BasePlugin): + """ADK plugin to dynamically load declarative database schemas. + + Loads both schemas and custom instruction overrides. + This plugin queries the storage layer using the active session user's ID + and appends user-defined tables, query templates, and instruction overrides + to the system instruction context on every LLM call. + """ + + def __init__(self, name: str = "declarative_db") -> None: + super().__init__(name=name) + + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> None: + """Callback executed before the LLM is invoked.""" + if not callback_context.session: + return + + user_id = callback_context.session.state.get( + "user_id" + ) or callback_context.session.state.get("telegram_chat_id") + if not user_id: + logger.debug( + "before_model_callback: No user_id or telegram_chat_id in session state" + ) + return + + try: + storage = get_declarative_db_storage() + schema_xml = await storage.get_schema_instructions_xml(str(user_id)) + if schema_xml: + logger.info( + "Injecting custom database instructions for user %s (%d chars)", + user_id, + len(schema_xml), + ) + llm_request.append_instructions([schema_xml]) + except Exception: + logger.exception("Failed to inject user database schemas and instructions") diff --git a/src/blacki/declarative_db/storage.py b/src/blacki/declarative_db/storage.py new file mode 100644 index 0000000..6efea21 --- /dev/null +++ b/src/blacki/declarative_db/storage.py @@ -0,0 +1,783 @@ +"""SQLite storage implementation for declarative tables and query templates.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +from typing import TYPE_CHECKING, Any + +from blacki.declarative_db.validation import validate_column_type, validate_identifier +from blacki.storage.base import SqlStorage +from blacki.utils.timezone import now_utc + +if TYPE_CHECKING: + import aiosqlite + +logger = logging.getLogger(__name__) + + +class SqliteDeclarativeDbStorage(SqlStorage): + """Storage implementation for custom schemas, query templates, and overrides.""" + + def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None: + super().__init__(conn, lock) + + async def _create_tables(self) -> None: + # Create metadata table for tracking tables + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS custom_tables ( + user_id TEXT NOT NULL, + table_name TEXT NOT NULL, + physical_name TEXT NOT NULL, + description TEXT, + created_at TEXT NOT NULL, + PRIMARY KEY (user_id, table_name) + ) + """) + + # Create metadata table for column definitions + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS custom_table_columns ( + user_id TEXT NOT NULL, + table_name TEXT NOT NULL, + column_name TEXT NOT NULL, + column_type TEXT NOT NULL, + is_primary_key INTEGER NOT NULL DEFAULT 0, + is_not_null INTEGER NOT NULL DEFAULT 0, + default_value TEXT, + PRIMARY KEY (user_id, table_name, column_name), + FOREIGN KEY (user_id, table_name) + REFERENCES custom_tables(user_id, table_name) ON DELETE CASCADE + ) + """) + + # Create metadata table for query templates + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS saved_query_templates ( + user_id TEXT NOT NULL, + template_name TEXT NOT NULL, + table_name TEXT NOT NULL, + query_type TEXT NOT NULL, + description TEXT, + select_columns TEXT, -- JSON list of column names + filter_columns TEXT, -- JSON list of column names + order_by_column TEXT, + order_by_direction TEXT, -- ASC, DESC + limit_val INTEGER, + PRIMARY KEY (user_id, template_name), + FOREIGN KEY (user_id, table_name) + REFERENCES custom_tables(user_id, table_name) ON DELETE CASCADE + ) + """) + + # Create metadata table for user-scoped instructions + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS custom_instruction_overrides ( + user_id TEXT NOT NULL, + key TEXT NOT NULL, + instructions TEXT NOT NULL, + updated_at TEXT NOT NULL, + PRIMARY KEY (user_id, key) + ) + """) + + def _get_physical_name(self, user_id: str, table_name: str) -> str: + """Derive a safe physical table name scoped by hashed user_id.""" + user_hash = hashlib.md5( + user_id.encode("utf-8"), usedforsecurity=False + ).hexdigest()[:16] + return f"usr_{user_hash}_{table_name}" + + async def create_custom_table( + self, + user_id: str, + table_name: str, + columns: list[dict[str, Any]], + description: str | None = None, + ) -> None: + """Create a safe custom physical table and store its metadata. + + Args: + user_id: The scoping user ID. + table_name: Logical table name. + columns: List of columns containing name, type, primary_key, + not_null, and default keys. + description: Optional table description. + """ + validate_identifier(table_name) + physical_name = self._get_physical_name(user_id, table_name) + + if not columns: + raise ValueError("Table must define at least one column") + + # Validate columns and prepare dynamic DDL components + validated_cols: list[dict[str, Any]] = [] + ddl_parts: list[str] = [] + + for col in columns: + col_name = col.get("name", "").strip() + col_type = col.get("type", "").strip().upper() + is_pk = int(bool(col.get("primary_key"))) + is_nn = int(bool(col.get("not_null"))) + default_val = col.get("default") + + validate_identifier(col_name) + validate_column_type(col_type) + + default_str = None + if default_val is not None: + if isinstance(default_val, bool): + default_str = "1" if default_val else "0" + part_def = f" DEFAULT {default_str}" + elif isinstance(default_val, (int, float)): + default_str = str(default_val) + part_def = f" DEFAULT {default_str}" + elif isinstance(default_val, str): + escaped_default = default_val.replace("'", "''") + part_def = f" DEFAULT '{escaped_default}'" + default_str = default_val + elif isinstance(default_val, (list, dict)): + default_str = json.dumps(default_val) + escaped_default = default_str.replace("'", "''") + part_def = f" DEFAULT '{escaped_default}'" + else: + default_str = str(default_val) + escaped_default = default_str.replace("'", "''") + part_def = f" DEFAULT '{escaped_default}'" + else: + part_def = "" + + validated_cols.append( + { + "name": col_name, + "type": col_type, + "is_pk": is_pk, + "is_nn": is_nn, + "default": default_str, + } + ) + + # Form column DDL segment + part = f'"{col_name}" {col_type}' + if is_pk: + part += " PRIMARY KEY" + if is_nn: + part += " NOT NULL" + part += part_def + + ddl_parts.append(part) + + # Build physical CREATE TABLE string + create_ddl = ( + f'CREATE TABLE IF NOT EXISTS "{physical_name}" (\n ' + + ",\n ".join(ddl_parts) + + "\n)" + ) + + now = now_utc().isoformat(timespec="seconds") + + async with self._lock: + # Check if table already exists + table_exists_row = await self._fetch_one( + "SELECT 1 FROM custom_tables WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + table_exists = table_exists_row is not None + + if not table_exists: + # Check table counts guardrail + existing_count_row = await self._fetch_one( + "SELECT COUNT(*) as count FROM custom_tables WHERE user_id = ?", + (user_id,), + ) + existing_count = ( + existing_count_row["count"] if existing_count_row else 0 + ) + if existing_count >= 5: + raise ValueError("Limit of 5 custom tables per user reached") + + if len(columns) > 15: + raise ValueError("Limit of 15 columns per custom table reached") + + # Initialize physical table and save metadata within one serial transaction + await self._conn.execute("BEGIN") + try: + if table_exists: + # Drop old table to prevent column/metadata desync + await self._conn.execute(f'DROP TABLE IF EXISTS "{physical_name}"') + await self._conn.execute( + "DELETE FROM custom_tables " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + + # Physically create table + await self._conn.execute(create_ddl) + + # Save table metadata + await self._conn.execute( + """ + INSERT INTO custom_tables ( + user_id, table_name, physical_name, description, created_at + ) VALUES (?, ?, ?, ?, ?) + ON CONFLICT (user_id, table_name) DO UPDATE SET + description = excluded.description, + physical_name = excluded.physical_name + """, + (user_id, table_name, physical_name, description, now), + ) + + # Delete any old metadata columns + await self._conn.execute( + "DELETE FROM custom_table_columns " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + + # Save column metadata + for c in validated_cols: + await self._conn.execute( + """ + INSERT INTO custom_table_columns ( + user_id, table_name, column_name, column_type, + is_primary_key, is_not_null, default_value + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + user_id, + table_name, + c["name"], + c["type"], + c["is_pk"], + c["is_nn"], + c["default"], + ), + ) + + await self._conn.execute("COMMIT") + logger.info( + "Successfully created table %s with %d columns", + table_name, + len(validated_cols), + ) + except Exception as e: + await self._conn.execute("ROLLBACK") + logger.exception("Failed to create custom table %s", table_name) + raise e + + async def delete_custom_table(self, user_id: str, table_name: str) -> bool: + """Physically drop a table and delete its metadata records.""" + validate_identifier(table_name) + physical_name = self._get_physical_name(user_id, table_name) + + async with self._lock: + # Check if table metadata exists first + row = await self._fetch_one( + "SELECT table_name FROM custom_tables " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + if not row: + return False + + await self._conn.execute("BEGIN") + try: + # Drop physical table + await self._conn.execute(f'DROP TABLE IF EXISTS "{physical_name}"') + + # Clear metadata (ON DELETE CASCADE propagates to columns & templates) + await self._conn.execute( + "DELETE FROM custom_tables WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + + await self._conn.execute("COMMIT") + logger.info("Successfully dropped table %s", table_name) + return True + except Exception as e: + await self._conn.execute("ROLLBACK") + logger.exception("Failed to drop custom table %s", table_name) + raise e + + async def create_query_template( + self, + user_id: str, + template_name: str, + table_name: str, + query_type: str, + select_columns: list[str] | None = None, + filter_columns: list[str] | None = None, + order_by_column: str | None = None, + order_by_direction: str | None = None, + limit_val: int | None = None, + description: str | None = None, + ) -> None: + """Create or save a query template against a logical custom table.""" + validate_identifier(template_name) + validate_identifier(table_name) + query_type = query_type.strip().upper() + if query_type not in {"SELECT", "INSERT", "UPDATE", "DELETE"}: + raise ValueError("Query type must be SELECT, INSERT, UPDATE, or DELETE") + + # Validate identifiers inside lists + if select_columns: + for s in select_columns: + validate_identifier(s) + if filter_columns: + for f in filter_columns: + validate_identifier(f) + if order_by_column: + validate_identifier(order_by_column) + if order_by_direction: + order_by_dir_upper = order_by_direction.strip().upper() + if order_by_dir_upper not in {"ASC", "DESC"}: + raise ValueError("Sorting direction must be ASC or DESC") + order_by_direction = order_by_dir_upper + + async with self._lock: + # Check if template already exists + template_exists_row = await self._fetch_one( + "SELECT 1 FROM saved_query_templates " + "WHERE user_id = ? AND template_name = ?", + (user_id, template_name), + ) + template_exists = template_exists_row is not None + + if not template_exists: + # Check templates limit (10 per user) + existing_count_row = await self._fetch_one( + "SELECT COUNT(*) as count FROM saved_query_templates " + "WHERE user_id = ?", + (user_id,), + ) + existing_count = ( + existing_count_row["count"] if existing_count_row else 0 + ) + if existing_count >= 10: + raise ValueError( + "Limit of 10 saved query templates per user reached" + ) + + # Check if custom table exists + table_row = await self._fetch_one( + "SELECT table_name FROM custom_tables " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + if not table_row: + raise ValueError(f"Table '{table_name}' does not exist for this user") + + # Validate column existence against metadata + allowed_cols_rows = await self._fetch_all( + "SELECT column_name FROM custom_table_columns " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + allowed_cols = {r["column_name"] for r in allowed_cols_rows} + + for col_list, label in [ + (select_columns, "select_columns"), + (filter_columns, "filter_columns"), + ]: + if col_list: + for c in col_list: + if c not in allowed_cols: + raise ValueError( + f"Column '{c}' in {label} does not " + f"exist in table '{table_name}'" + ) + + if order_by_column and order_by_column not in allowed_cols: + raise ValueError( + f"Column '{order_by_column}' in order_by_column " + f"does not exist in table '{table_name}'" + ) + + # Insert metadata into saved templates + await self._conn.execute( + """ + INSERT INTO saved_query_templates ( + user_id, template_name, table_name, query_type, description, + select_columns, filter_columns, order_by_column, + order_by_direction, limit_val + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (user_id, template_name) DO UPDATE SET + table_name = excluded.table_name, + query_type = excluded.query_type, + description = excluded.description, + select_columns = excluded.select_columns, + filter_columns = excluded.filter_columns, + order_by_column = excluded.order_by_column, + order_by_direction = excluded.order_by_direction, + limit_val = excluded.limit_val + """, + ( + user_id, + template_name, + table_name, + query_type, + description, + json.dumps(select_columns) if select_columns is not None else None, + json.dumps(filter_columns) if filter_columns is not None else None, + order_by_column, + order_by_direction, + limit_val, + ), + ) + logger.info( + "Saved query template %s against table %s", template_name, table_name + ) + + async def execute_query_template( + self, + user_id: str, + template_name: str, + parameters: dict[str, Any], + ) -> list[dict[str, Any]] | int: + """Securely build and run a saved template query with validation.""" + validate_identifier(template_name) + + # 1. Fetch template from DB + template = await self._fetch_one( + "SELECT * FROM saved_query_templates " + "WHERE user_id = ? AND template_name = ?", + (user_id, template_name), + ) + if not template: + raise ValueError(f"Template '{template_name}' not found") + + table_name = template["table_name"] + query_type = template["query_type"] + + # Fetch physical table name + table_meta = await self._fetch_one( + "SELECT physical_name FROM custom_tables " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + if not table_meta: + raise ValueError(f"Underlying table '{table_name}' does not exist") + physical_name = table_meta["physical_name"] + + # Fetch valid columns + valid_cols_rows = await self._fetch_all( + "SELECT column_name FROM custom_table_columns " + "WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + valid_cols = {r["column_name"] for r in valid_cols_rows} + + # 2. Strict Parameter Validation + # Verify that all keys supplied in parameters are valid matching column names + for key in parameters: + validate_identifier(key) + if key not in valid_cols: + raise ValueError( + f"Parameter key '{key}' is not a valid " + f"column for table '{table_name}'" + ) + + # Parse column lists from JSON + select_cols = ( + json.loads(template["select_columns"]) if template["select_columns"] else [] + ) + filter_cols = ( + json.loads(template["filter_columns"]) if template["filter_columns"] else [] + ) + + # 3. Secure Query Assembly + sql = "" + bindings: tuple[Any, ...] = () + + if query_type == "SELECT": + # Form SELECT segment + cols_str = ", ".join(f'"{c}"' for c in select_cols) if select_cols else "*" + + # Validated & quoted identifiers: SQL safe. + sql = f'SELECT {cols_str} FROM "{physical_name}"' # noqa: S608 + + # Form WHERE segment + if filter_cols: + where_segments = [] + where_bindings = [] + for c in filter_cols: + if c in parameters: + where_segments.append(f'"{c}" = ?') + where_bindings.append(parameters[c]) + else: + raise ValueError(f"Missing required filter parameter '{c}'") + sql += " WHERE " + " AND ".join(where_segments) + bindings = tuple(where_bindings) + + # Form ORDER BY segment + if template["order_by_column"]: + ob_col = template["order_by_column"] + ob_dir = template["order_by_direction"] or "ASC" + sql += f' ORDER BY "{ob_col}" {ob_dir}' + + # Form LIMIT segment + if template["limit_val"] is not None: + sql += f" LIMIT {int(template['limit_val'])}" + + # Run SELECT + return await self._fetch_all(sql, bindings) + + elif query_type == "INSERT": + # Form INSERT parameters from direct parameters matching columns + insert_keys = [k for k in parameters if k in valid_cols] + if not insert_keys: + raise ValueError("Must provide at least one valid parameter to insert") + + cols_str = ", ".join(f'"{k}"' for k in insert_keys) + placeholders = ", ".join("?" for _ in insert_keys) + sql = f'INSERT INTO "{physical_name}" ({cols_str}) VALUES ({placeholders})' # noqa: S608 + bindings = tuple(parameters[k] for k in insert_keys) + + # Run write insert + return await self._execute(sql, bindings) + + elif query_type == "UPDATE": + # Form UPDATE columns from params not in filter_columns + update_keys = [ + k for k in parameters if k in valid_cols and k not in filter_cols + ] + if not update_keys: + raise ValueError("No columns provided to update") + + set_segments = [f'"{k}" = ?' for k in update_keys] + set_bindings = [parameters[k] for k in update_keys] + + sql = f'UPDATE "{physical_name}" SET ' + ", ".join(set_segments) # noqa: S608 + + # Form WHERE segment + if filter_cols: + where_segments = [] + where_bindings = [] + for c in filter_cols: + if c in parameters: + where_segments.append(f'"{c}" = ?') + where_bindings.append(parameters[c]) + else: + raise ValueError( + f"Missing required update filter parameter '{c}'" + ) + sql += " WHERE " + " AND ".join(where_segments) + bindings = tuple(set_bindings + where_bindings) + else: + bindings = tuple(set_bindings) + + # Run write update + async with self._lock, self._conn.execute(sql, bindings) as cursor: + return cursor.rowcount + + elif query_type == "DELETE": + sql = f'DELETE FROM "{physical_name}"' # noqa: S608 + if filter_cols: + where_segments = [] + where_bindings = [] + for c in filter_cols: + if c in parameters: + where_segments.append(f'"{c}" = ?') + where_bindings.append(parameters[c]) + else: + raise ValueError( + f"Missing required delete filter parameter '{c}'" + ) + sql += " WHERE " + " AND ".join(where_segments) + bindings = tuple(where_bindings) + + # Run write delete + async with self._lock, self._conn.execute(sql, bindings) as cursor: + return cursor.rowcount + + else: + raise ValueError(f"Unsupported query type: {query_type}") + + async def set_custom_instruction_override( + self, user_id: str, instructions: str + ) -> None: + """Update or insert a custom instruction override.""" + now = now_utc().isoformat(timespec="seconds") + async with self._lock: + await self._conn.execute( + """ + INSERT INTO custom_instruction_overrides ( + user_id, key, instructions, updated_at + ) VALUES (?, 'custom_instructions', ?, ?) + ON CONFLICT (user_id, key) DO UPDATE SET + instructions = excluded.instructions, + updated_at = excluded.updated_at + """, + (user_id, instructions, now), + ) + logger.info("Updated custom instructions for user %s", user_id) + + async def get_custom_instruction_override(self, user_id: str) -> str | None: + """Retrieve custom instruction override for user.""" + row = await self._fetch_one( + "SELECT instructions FROM custom_instruction_overrides " + "WHERE user_id = ? AND key = 'custom_instructions'", + (user_id,), + ) + return row["instructions"] if row else None + + async def delete_custom_instruction_override(self, user_id: str) -> bool: + """Clear custom instruction override for user.""" + async with self._lock: + cursor = await self._conn.execute( + "DELETE FROM custom_instruction_overrides " + "WHERE user_id = ? AND key = 'custom_instructions'", + (user_id,), + ) + return cursor.rowcount > 0 + + async def list_custom_tables_and_templates(self, user_id: str) -> dict[str, Any]: + """List tables, columns, and saved templates for a given user.""" + tables = await self._fetch_all( + "SELECT table_name, description, created_at FROM custom_tables " + "WHERE user_id = ? ORDER BY table_name ASC", + (user_id,), + ) + columns = await self._fetch_all( + "SELECT table_name, column_name, column_type, is_primary_key, " + "is_not_null, default_value FROM custom_table_columns " + "WHERE user_id = ? ORDER BY table_name ASC, column_name ASC", + (user_id,), + ) + templates = await self._fetch_all( + "SELECT template_name, table_name, query_type, description, " + "select_columns, filter_columns, order_by_column, " + "order_by_direction, limit_val FROM saved_query_templates " + "WHERE user_id = ? ORDER BY template_name ASC", + (user_id,), + ) + + # Structure response + structured_tables: dict[str, Any] = {} + for t in tables: + t_name = t["table_name"] + structured_tables[t_name] = { + "description": t["description"], + "created_at": t["created_at"], + "columns": [], + "templates": [], + } + + for col in columns: + t_name = col["table_name"] + if t_name in structured_tables: + structured_tables[t_name]["columns"].append( + { + "name": col["column_name"], + "type": col["column_type"], + "primary_key": bool(col["is_primary_key"]), + "not_null": bool(col["is_not_null"]), + "default": col["default_value"], + } + ) + + for tmpl in templates: + t_name = tmpl["table_name"] + if t_name in structured_tables: + structured_tables[t_name]["templates"].append( + { + "name": tmpl["template_name"], + "type": tmpl["query_type"], + "description": tmpl["description"], + "select_columns": json.loads(tmpl["select_columns"]) + if tmpl["select_columns"] + else None, + "filter_columns": json.loads(tmpl["filter_columns"]) + if tmpl["filter_columns"] + else None, + "order_by": ( + f"{tmpl['order_by_column']} {tmpl['order_by_direction']}" + if tmpl["order_by_column"] + else None + ), + "limit": tmpl["limit_val"], + } + ) + + return structured_tables + + async def get_schema_instructions_xml(self, user_id: str) -> str: + """Compile schemas and overrides into instructions XML.""" + schema_data = await self.list_custom_tables_and_templates(user_id) + override = await self.get_custom_instruction_override(user_id) + + blocks: list[str] = [] + + if override: + blocks.append( + f"\n{override}\n" + ) + + if schema_data: + schema_lines = [ + "You have access to the following custom user-defined database tables:" + ] + for t_name, t_val in schema_data.items(): + schema_lines.append(f"\nTable: {t_name}") + if t_val["description"]: + schema_lines.append(f" Description: {t_val['description']}") + schema_lines.append(" Columns:") + for col in t_val["columns"]: + pk_label = " PRIMARY KEY" if col["primary_key"] else "" + nn_label = " NOT NULL" if col["not_null"] else "" + def_label = ( + f" DEFAULT {col['default']}" + if col["default"] is not None + else "" + ) + col_info = ( + f" - {col['name']} ({col['type']})" + f"{pk_label}{nn_label}{def_label}" + ) + schema_lines.append(col_info) + + if t_val["templates"]: + schema_lines.append(" Saved Query Templates:") + for tmpl in t_val["templates"]: + desc = ( + f" ({tmpl['description']})" if tmpl["description"] else "" + ) + schema_lines.append(f" - Template: {tmpl['name']}{desc}") + schema_lines.append(f" Operation: {tmpl['type']}") + if tmpl["select_columns"]: + cols_joined = ", ".join(tmpl["select_columns"]) + schema_lines.append(f" Returns Columns: {cols_joined}") + if tmpl["filter_columns"]: + params_joined = ", ".join(tmpl["filter_columns"]) + schema_lines.append( + f" Required Parameters: {params_joined}" + ) + if tmpl["order_by"]: + schema_lines.append(f" Sorted By: {tmpl['order_by']}") + if tmpl["limit"]: + schema_lines.append(f" Limit: {tmpl['limit']}") + + schema_str = "\n".join(schema_lines) + blocks.append( + f"\n{schema_str}\n" + ) + + return "\n\n".join(blocks) if blocks else "" + + +_storage: SqliteDeclarativeDbStorage | None = None + + +def get_declarative_db_storage() -> SqliteDeclarativeDbStorage: + """Return the process-wide singleton SqliteDeclarativeDbStorage instance.""" + from blacki.container import get_container + + container = get_container() + storage = container.declarative_db_storage + if not storage.is_initialized: + raise RuntimeError( + "Declarative DB storage not initialized. Call storage.initialize() first." + ) + return storage diff --git a/src/blacki/declarative_db/tools.py b/src/blacki/declarative_db/tools.py new file mode 100644 index 0000000..3fd478f --- /dev/null +++ b/src/blacki/declarative_db/tools.py @@ -0,0 +1,351 @@ +"""Agent-facing ADK tools for declarative databases.""" + +from __future__ import annotations + +import logging +from typing import Any + +from google.adk.tools import ToolContext + +from blacki.declarative_db.storage import get_declarative_db_storage + +logger = logging.getLogger(__name__) + + +async def create_custom_table( + table_name: str, + columns: list[dict[str, Any]], + tool_context: ToolContext, + description: str | None = None, +) -> dict[str, Any]: + """Create a new custom physical table and register its metadata. + + The physical SQLite table is securely created with the specified columns and + types, and is scoped and isolated per user. + + Args: + table_name: The name of the table to create (lowercase, alphanumeric). + columns: A list of dicts specifying column definitions. Each column dict + can have: + - "name" (str): Column name (alphanumeric, lowercase). + - "type" (str): Column type. Restricted strictly to: + "TEXT", "INTEGER", "REAL", "BLOB". + - "primary_key" (bool): Whether column is primary key. + - "not_null" (bool): Whether column cannot be null. + - "default" (any): Default column value. + tool_context: ADK tool context. + description: A short optional description of the table's purpose. + + Returns: + A status dictionary confirming creation or detailing validation errors. + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + await storage.create_custom_table( + user_id=str(user_id), + table_name=table_name, + columns=columns, + description=description, + ) + return { + "status": "success", + "message": ( + f"Table '{table_name}' was successfully created with " + f"{len(columns)} columns." + ), + } + except Exception as e: + logger.exception("Failed to create custom table %s", table_name) + return { + "status": "error", + "message": f"Failed to create table: {e}", + } + + +async def delete_custom_table( + table_name: str, + tool_context: ToolContext, +) -> dict[str, Any]: + """Physically drop a custom table and delete all its query templates. + + Args: + table_name: The logical name of the table to drop. + tool_context: ADK tool context. + + Returns: + A status dictionary. + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + deleted = await storage.delete_custom_table(str(user_id), table_name) + if deleted: + return { + "status": "success", + "message": ( + f"Table '{table_name}' and all its templates " + "were successfully deleted." + ), + } + return { + "status": "error", + "message": f"Table '{table_name}' was not found.", + } + except Exception as e: + logger.exception("Failed to delete custom table %s", table_name) + return { + "status": "error", + "message": f"Failed to delete table: {e}", + } + + +async def create_query_template( + template_name: str, + table_name: str, + query_type: str, + tool_context: ToolContext, + select_columns: list[str] | None = None, + filter_columns: list[str] | None = None, + order_by_column: str | None = None, + order_by_direction: str | None = None, + limit_val: int | None = None, + description: str | None = None, +) -> dict[str, Any]: + """Save a query template against a custom user table. + + Rather than compiling raw string statements, templates define structure. + The storage manager compiles them securely into parameterized queries + during runtime execution. + + Args: + template_name: The name of this query template (unique per user). + table_name: The logical table this query executes against. + query_type: The action of the query. Must be one of: + "SELECT", "INSERT", "UPDATE", "DELETE". + tool_context: ADK tool context. + select_columns: List of columns to return (for SELECT). If empty, returns all. + filter_columns: List of columns to bind parameters to inside the WHERE clause + (e.g., ["id"] maps to WHERE "id" = ?). All must be + passed in parameters during execution. + order_by_column: Name of column to sort on (for SELECT). + order_by_direction: Sorting direction (for SELECT). Must be "ASC" or "DESC". + limit_val: Optional integer pagination constraint (for SELECT). + description: A brief optional description of what the query template does. + + Returns: + A status dictionary. + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + await storage.create_query_template( + user_id=str(user_id), + template_name=template_name, + table_name=table_name, + query_type=query_type, + select_columns=select_columns, + filter_columns=filter_columns, + order_by_column=order_by_column, + order_by_direction=order_by_direction, + limit_val=limit_val, + description=description, + ) + return { + "status": "success", + "message": ( + f"Query template '{template_name}' was successfully " + f"registered for table '{table_name}'." + ), + } + except Exception as e: + logger.exception("Failed to create query template %s", template_name) + return { + "status": "error", + "message": f"Failed to register query template: {e}", + } + + +async def execute_query_template( + template_name: str, + parameters: dict[str, Any], + tool_context: ToolContext, +) -> dict[str, Any]: + """Execute a saved query template securely with inputs. + + Args: + template_name: Name of the template to run. + parameters: A dictionary of key-value bindings. Every key must map + exactly to a valid column defined for the underlying table. + tool_context: ADK tool context. + + Returns: + A result dictionary containing status and output records (for SELECT) or + affected rows count / last inserted ID (for write queries). + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + res = await storage.execute_query_template( + user_id=str(user_id), + template_name=template_name, + parameters=parameters, + ) + return { + "status": "success", + "result": res, + } + except Exception as e: + logger.exception("Failed to execute query template %s", template_name) + return { + "status": "error", + "message": f"Execution failed: {e}", + } + + +async def list_custom_tables_and_templates( + tool_context: ToolContext, +) -> dict[str, Any]: + """List all custom tables, columns, and registered templates for the active user. + + Args: + tool_context: ADK tool context. + + Returns: + A dictionary mapping table names to their schema structures and query templates. + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + schemas = await storage.list_custom_tables_and_templates(str(user_id)) + return { + "status": "success", + "tables": schemas, + } + except Exception as e: + logger.exception("Failed to list schemas") + return { + "status": "error", + "message": f"Failed to fetch custom schemas: {e}", + } + + +async def set_custom_instruction_override( + instructions: str, + tool_context: ToolContext, +) -> dict[str, Any]: + """Persist a custom instruction override for the user. + + This instructions block is injected into system instruction prompts to guide + the model's persona, handling, or workflows dynamically, safely isolated from + the system prompt codebase. + + Args: + instructions: Raw text instructions to guide the model. + tool_context: ADK tool context. + + Returns: + A status dictionary. + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + await storage.set_custom_instruction_override(str(user_id), instructions) + return { + "status": "success", + "message": "Custom instructions successfully saved.", + } + except Exception as e: + logger.exception("Failed to set instruction override") + return { + "status": "error", + "message": f"Failed to save instructions: {e}", + } + + +async def delete_custom_instruction_override( + tool_context: ToolContext, +) -> dict[str, Any]: + """Clear custom instruction override for the user. + + Args: + tool_context: ADK tool context. + + Returns: + A status dictionary. + """ + user_id = getattr(tool_context, "user_id", None) or tool_context.state.get( + "user_id" + ) + if not user_id: + return { + "status": "error", + "message": "User not identified in tool context.", + } + + try: + storage = get_declarative_db_storage() + deleted = await storage.delete_custom_instruction_override(str(user_id)) + if deleted: + return { + "status": "success", + "message": "Custom instructions successfully deleted.", + } + return { + "status": "success", + "message": "No custom instructions existed to delete.", + } + except Exception as e: + logger.exception("Failed to delete instruction override") + return { + "status": "error", + "message": f"Failed to clear instructions: {e}", + } diff --git a/src/blacki/declarative_db/validation.py b/src/blacki/declarative_db/validation.py new file mode 100644 index 0000000..8b81831 --- /dev/null +++ b/src/blacki/declarative_db/validation.py @@ -0,0 +1,96 @@ +"""Safety and validation layer for declarative database identifiers and types.""" + +from __future__ import annotations + +import re + +# Strict regex matching letters, numbers, and underscores, +# starting with a letter or underscore. +IDENTIFIER_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + +# Safe SQLite types allowlist +ALLOWED_TYPES = {"TEXT", "INTEGER", "REAL", "BLOB"} + +# Standard SQLite reserved keywords blocklist to avoid syntax overlaps and confusion +RESERVED_KEYWORDS = { + "SELECT", + "DROP", + "ALTER", + "DELETE", + "INSERT", + "PRAGMA", + "CREATE", + "TABLE", + "INDEX", + "UPDATE", + "WHERE", + "AND", + "OR", + "JOIN", + "FROM", + "LIMIT", + "ORDER", + "BY", + "IN", + "IS", + "NULL", + "NOT", + "PRIMARY", + "KEY", + "FOREIGN", + "REFERENCES", + "DEFAULT", + "CHECK", + "COLLATE", + "UNIQUE", + "EXISTS", + "INTO", + "VALUES", + "SET", + "ON", +} + + +def validate_identifier(name: str) -> None: + """Validate a table, column, or template name. + + Args: + name: Name to validate. + + Raises: + ValueError: If name fails format, length, or keyword validation. + """ + if not name: + raise ValueError("Identifier cannot be empty") + + if len(name) > 64: + raise ValueError(f"Identifier '{name}' exceeds maximum length of 64 characters") + + if not IDENTIFIER_REGEX.match(name): + raise ValueError( + f"Identifier '{name}' is invalid. " + "Must start with a letter or underscore and " + "contain only alphanumeric characters and underscores." + ) + + if name.upper() in RESERVED_KEYWORDS: + raise ValueError( + f"Identifier '{name}' is a reserved SQL keyword and cannot be used" + ) + + +def validate_column_type(col_type: str) -> None: + """Validate that the column type is in the strict allowlist. + + Args: + col_type: Type string to validate. + + Raises: + ValueError: If the type is not allowed. + """ + cleaned_type = col_type.strip().upper() + if cleaned_type not in ALLOWED_TYPES: + raise ValueError( + f"Type '{col_type}' is not allowed. " + f"Must be one of: {', '.join(sorted(ALLOWED_TYPES))}" + ) diff --git a/src/blacki/registry.py b/src/blacki/registry.py index c2efa51..890d026 100644 --- a/src/blacki/registry.py +++ b/src/blacki/registry.py @@ -54,6 +54,7 @@ def build_tools(config: ToolConfig) -> list[Any]: tools.extend(_build_reminder_tools()) tools.extend(_build_calorie_tools()) tools.extend(_build_workout_tools()) + tools.extend(_build_declarative_db_tools()) logger.info("Database-backed tools enabled") if config.sandbox_enabled: @@ -233,6 +234,33 @@ def _build_memory_tools() -> list[Any]: return [] +def _build_declarative_db_tools() -> list[Any]: + """Build declarative database tools.""" + try: + from blacki.declarative_db.tools import ( + create_custom_table, + create_query_template, + delete_custom_instruction_override, + delete_custom_table, + execute_query_template, + list_custom_tables_and_templates, + set_custom_instruction_override, + ) + + return [ + create_custom_table, + delete_custom_table, + create_query_template, + execute_query_template, + list_custom_tables_and_templates, + set_custom_instruction_override, + delete_custom_instruction_override, + ] + except ImportError as e: # pragma: no cover + logger.warning("Failed to load Declarative DB tools: %s", e) + return [] + + def build_tool_config_from_env() -> ToolConfig: """Build tool configuration from environment variables. diff --git a/tests/declarative_db/test_declarative_db.py b/tests/declarative_db/test_declarative_db.py new file mode 100644 index 0000000..60741f7 --- /dev/null +++ b/tests/declarative_db/test_declarative_db.py @@ -0,0 +1,1199 @@ +# mypy: disable-error-code="no-untyped-def" +"""Unit and integration tests for the declarative database tool template system.""" + +from __future__ import annotations + +import asyncio +import json +from collections.abc import AsyncIterator, Generator +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import aiosqlite +import pytest + +from blacki.container import set_container_from_connection +from blacki.declarative_db.plugin import DeclarativeDbPlugin +from blacki.declarative_db.storage import ( + SqliteDeclarativeDbStorage, +) +from blacki.declarative_db.tools import ( + create_custom_table, + create_query_template, + delete_custom_instruction_override, + delete_custom_table, + execute_query_template, + list_custom_tables_and_templates, + set_custom_instruction_override, +) +from blacki.declarative_db.validation import validate_column_type, validate_identifier + +# ============================================================================== +# 1. Validation Layer Tests +# ============================================================================== + + +class TestValidation: + """Test safe identifier and type validation constraints.""" + + def test_valid_identifiers(self) -> None: + """Should accept valid identifiers.""" + validate_identifier("users") + validate_identifier("user_todos") + validate_identifier("_private_table") + validate_identifier("column1") + + def test_invalid_identifiers_pattern(self) -> None: + """Should reject names with special characters or starting with numbers.""" + with pytest.raises(ValueError, match="is invalid"): + validate_identifier("user-todos") + with pytest.raises(ValueError, match="is invalid"): + validate_identifier("1todos") + with pytest.raises(ValueError, match="is invalid"): + validate_identifier("user todos") + + def test_invalid_identifiers_keywords(self) -> None: + """Should reject case-insensitive SQL reserved keywords.""" + with pytest.raises(ValueError, match="reserved SQL keyword"): + validate_identifier("SELECT") + with pytest.raises(ValueError, match="reserved SQL keyword"): + validate_identifier("Drop") + with pytest.raises(ValueError, match="reserved SQL keyword"): + validate_identifier("table") + + def test_invalid_identifiers_length(self) -> None: + """Should reject names exceeding 64 characters.""" + long_name = "a" * 65 + with pytest.raises(ValueError, match="exceeds maximum length"): + validate_identifier(long_name) + + def test_invalid_identifiers_empty(self) -> None: + """Should reject empty names.""" + with pytest.raises(ValueError, match="cannot be empty"): + validate_identifier("") + + def test_valid_column_types(self) -> None: + """Should accept allowlisted column types.""" + validate_column_type("TEXT") + validate_column_type("integer") + validate_column_type("REAL") + validate_column_type("blob") + + def test_invalid_column_types(self) -> None: + """Should reject non-allowlisted column types.""" + with pytest.raises(ValueError, match="is not allowed"): + validate_column_type("VARCHAR(50)") + with pytest.raises(ValueError, match="is not allowed"): + validate_column_type("DATETIME") + with pytest.raises(ValueError, match="is not allowed"): + validate_column_type("SERIAL") + + +# ============================================================================== +# 2. Storage Integration Tests +# ============================================================================== + + +@pytest.fixture +async def sqlite_conn(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]: + """Provide a real, isolated SQLite connection with schema created.""" + db_path = tmp_path / "test_declarative.db" + from blacki.storage.sqlite import create_connection + + conn = await create_connection(db_path) + yield conn + await conn.close() + + +@pytest.fixture +async def storage(sqlite_conn: aiosqlite.Connection) -> SqliteDeclarativeDbStorage: + """Provide initialized SqliteDeclarativeDbStorage instance.""" + lock = asyncio.Lock() + storage = SqliteDeclarativeDbStorage(sqlite_conn, lock) + await storage.initialize() + return storage + + +class TestStorageIntegration: + """Integration tests for SqliteDeclarativeDbStorage.""" + + @pytest.mark.anyio + async def test_create_and_delete_custom_table( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Should create a physical custom table and metadata, and drop it properly.""" + user_id = "user_123" + table_name = "items" + columns: list[dict[str, Any]] = [ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "title", "type": "TEXT", "not_null": True}, + {"name": "price", "type": "REAL", "default": 0.0}, + ] + + # 1. Create table + await storage.create_custom_table( + user_id=user_id, + table_name=table_name, + columns=columns, + description="Item catalog", + ) + + # Verify physical table structure via PRAGMA + physical_name = storage._get_physical_name(user_id, table_name) + async with sqlite_conn.execute( + f"PRAGMA table_info('{physical_name}')" + ) as cursor: + rows = list(await cursor.fetchall()) + assert len(rows) == 3 + assert rows[0]["name"] == "id" + assert rows[0]["type"] == "INTEGER" + assert rows[0]["pk"] == 1 + assert rows[1]["name"] == "title" + assert rows[1]["notnull"] == 1 + assert rows[2]["name"] == "price" + assert rows[2]["dflt_value"] == "0.0" + + # Verify metadata + metadata = await storage.list_custom_tables_and_templates(user_id) + assert table_name in metadata + assert metadata[table_name]["description"] == "Item catalog" + assert len(metadata[table_name]["columns"]) == 3 + + # 2. Delete table + deleted = await storage.delete_custom_table(user_id, table_name) + assert deleted is True + + # Verify metadata is cleared + metadata = await storage.list_custom_tables_and_templates(user_id) + assert table_name not in metadata + + # Verify physical table is dropped + drop_query = ( + "SELECT count(*) FROM sqlite_master " # noqa: S608 + f"WHERE type='table' AND name='{physical_name}'" # noqa: S608 + ) + async with sqlite_conn.execute(drop_query) as cursor: + row = await cursor.fetchone() + assert row is not None + assert row[0] == 0 + + @pytest.mark.anyio + async def test_table_limits_guardrails( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should block creating more than 5 tables or tables with > 15 columns.""" + user_id = "user_guard" + + # Table with too many columns + too_many_cols: list[dict[str, Any]] = [ + {"name": f"col_{i}", "type": "TEXT"} for i in range(16) + ] + with pytest.raises(ValueError, match="Limit of 15 columns"): + await storage.create_custom_table(user_id, "big_table", too_many_cols) + + # Register 5 tables + cols: list[dict[str, Any]] = [{"name": "id", "type": "INTEGER"}] + for i in range(5): + await storage.create_custom_table(user_id, f"table_{i}", cols) + + # Attempting the 6th table should fail + with pytest.raises(ValueError, match="Limit of 5 custom tables"): + await storage.create_custom_table(user_id, "table_6", cols) + + @pytest.mark.anyio + async def test_query_templates_crud_and_validation( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should manage and validate templates, enforcing a limit of 10.""" + user_id = "user_templates" + table_name = "tasks" + columns: list[dict[str, Any]] = [ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "task_name", "type": "TEXT"}, + {"name": "status", "type": "TEXT"}, + ] + + await storage.create_custom_table(user_id, table_name, columns) + + # 1. Create a SELECT template + await storage.create_query_template( + user_id=user_id, + template_name="get_by_status", + table_name=table_name, + query_type="SELECT", + select_columns=["id", "task_name"], + filter_columns=["status"], + order_by_column="id", + order_by_direction="DESC", + limit_val=5, + description="Fetch tasks by status", + ) + + # Validate duplicate/overwrite + await storage.create_query_template( + user_id=user_id, + template_name="get_by_status", + table_name=table_name, + query_type="SELECT", + select_columns=["id", "task_name"], + filter_columns=["status"], + description="Fetch tasks by status (updated)", + ) + + # 2. Reject template with invalid query_type + with pytest.raises(ValueError, match="Query type must be"): + await storage.create_query_template( + user_id, "invalid_tmpl", table_name, "DROP" + ) + + # 3. Reject template with invalid sorting direction + with pytest.raises(ValueError, match="Sorting direction must be"): + await storage.create_query_template( + user_id, + "invalid_tmpl", + table_name, + "SELECT", + order_by_direction="OTHER", + ) + + # 4. Reject template referencing non-existent table + with pytest.raises(ValueError, match="does not exist"): + await storage.create_query_template( + user_id, "tmpl", "non_existent_table", "SELECT" + ) + + # 5. Reject template referencing non-existent column + with pytest.raises(ValueError, match="does not exist in table"): + await storage.create_query_template( + user_id, + "tmpl", + table_name, + "SELECT", + select_columns=["non_existent_col"], + ) + + # 6. Template limits guardrail (max 10) + for i in range(1, 10): # Already have 1, add 9 more + await storage.create_query_template( + user_id, f"tmpl_{i}", table_name, "INSERT" + ) + + with pytest.raises(ValueError, match="Limit of 10 saved query templates"): + await storage.create_query_template( + user_id, "tmpl_11", table_name, "INSERT" + ) + + @pytest.mark.anyio + async def test_execute_queries_and_sql_injection_protection( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should run parameterized templates and catch malicious key injections.""" + user_id = "user_exec" + table_name = "notes" + columns: list[dict[str, Any]] = [ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "content", "type": "TEXT"}, + {"name": "category", "type": "TEXT"}, + ] + + await storage.create_custom_table(user_id, table_name, columns) + + # Create INSERT template + await storage.create_query_template( + user_id=user_id, + template_name="add_note", + table_name=table_name, + query_type="INSERT", + ) + + # Create SELECT template + await storage.create_query_template( + user_id=user_id, + template_name="get_notes", + table_name=table_name, + query_type="SELECT", + select_columns=["id", "content"], + filter_columns=["category"], + ) + + # Create UPDATE template + await storage.create_query_template( + user_id=user_id, + template_name="update_note", + table_name=table_name, + query_type="UPDATE", + filter_columns=["id"], + ) + + # Create DELETE template + await storage.create_query_template( + user_id=user_id, + template_name="delete_notes", + table_name=table_name, + query_type="DELETE", + filter_columns=["category"], + ) + + # 1. Execute INSERT + row_id1 = await storage.execute_query_template( + user_id, "add_note", {"content": "Buy milk", "category": "shopping"} + ) + assert row_id1 == 1 + + row_id2 = await storage.execute_query_template( + user_id, "add_note", {"content": "Work gym", "category": "fitness"} + ) + assert row_id2 == 2 + + # 2. Execute SELECT + res = await storage.execute_query_template( + user_id, "get_notes", {"category": "shopping"} + ) + assert isinstance(res, list) + assert len(res) == 1 + assert res[0]["content"] == "Buy milk" + assert "category" not in res[0] # select_columns only returns id and content + + # 3. Execute UPDATE + affected_rows = await storage.execute_query_template( + user_id, "update_note", {"id": 1, "content": "Buy whole milk"} + ) + assert affected_rows == 1 + + # Verify update worked + res = await storage.execute_query_template( + user_id, "get_notes", {"category": "shopping"} + ) + assert isinstance(res, list) + assert res[0]["content"] == "Buy whole milk" + + # 4. Execute DELETE + deleted_count = await storage.execute_query_template( + user_id, "delete_notes", {"category": "shopping"} + ) + assert deleted_count == 1 + + # Verify no items left in shopping category + res = await storage.execute_query_template( + user_id, "get_notes", {"category": "shopping"} + ) + assert isinstance(res, list) + assert len(res) == 0 + + # 5. Security Parameter Key Injection Guard check + # Attempt to pass a malicious SQL string as a parameter key. + mal_key = "category OR 1=1; DROP TABLE notes; --" + with pytest.raises(ValueError, match="is invalid|is not a valid column"): + await storage.execute_query_template( + user_id, "get_notes", {mal_key: "shopping"} + ) + + @pytest.mark.anyio + async def test_custom_instructions_overrides( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should save, get, and delete custom user-scoped prompt overrides.""" + user_id = "user_override" + + # Initially empty + val = await storage.get_custom_instruction_override(user_id) + assert val is None + + # Set instructions + await storage.set_custom_instruction_override(user_id, "Tone: Sarcastic") + val = await storage.get_custom_instruction_override(user_id) + assert val == "Tone: Sarcastic" + + # Clear instructions + deleted = await storage.delete_custom_instruction_override(user_id) + assert deleted is True + val = await storage.get_custom_instruction_override(user_id) + assert val is None + + @pytest.mark.anyio + async def test_get_schema_instructions_xml( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should compile schemas and overrides into instructions XML.""" + user_id = "user_xml" + assert await storage.get_schema_instructions_xml(user_id) == "" + + # Set instructions override + await storage.set_custom_instruction_override(user_id, "Be extra kind.") + + # Create table & template + await storage.create_custom_table( + user_id, "logs", [{"name": "id", "type": "INTEGER", "primary_key": True}] + ) + await storage.create_query_template(user_id, "add_log", "logs", "INSERT") + + xml = await storage.get_schema_instructions_xml(user_id) + assert "" in xml + assert "Be extra kind." in xml + assert "" in xml + assert "Table: logs" in xml + assert "id (INTEGER) PRIMARY KEY" in xml + assert "Template: add_log" in xml + + @pytest.mark.anyio + async def test_create_custom_table_with_no_columns( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should raise ValueError when creating a table with no columns.""" + with pytest.raises(ValueError, match="Table must define at least one column"): + await storage.create_custom_table( + user_id="user_test", + table_name="no_cols", + columns=[], + ) + + @pytest.mark.anyio + async def test_create_custom_table_already_exists( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Verify overwriting an existing table works and bypasses limit.""" + user_id = "user_overwrite_test" + cols = [{"name": "id", "type": "INTEGER", "primary_key": True}] + + # Create 5 tables (limit is 5) + for i in range(5): + await storage.create_custom_table(user_id, f"tbl_{i}", cols) + + # Overwrite tbl_0 (bypasses limit) + await storage.create_custom_table( + user_id, + "tbl_0", + [ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "new_col", "type": "TEXT"}, + ], + description="updated", + ) + + metadata = await storage.list_custom_tables_and_templates(user_id) + assert len(metadata["tbl_0"]["columns"]) == 2 + assert metadata["tbl_0"]["description"] == "updated" + + @pytest.mark.anyio + async def test_column_default_value_types( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Test different default value types in DDL construction.""" + user_id = "user_defaults" + table_name = "defaults_test" + + class Dummy: + def __str__(self) -> str: + return "dummy'value" + + columns: list[dict[str, Any]] = [ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "bool_t", "type": "INTEGER", "default": True}, + {"name": "bool_f", "type": "INTEGER", "default": False}, + {"name": "complex_list", "type": "TEXT", "default": [1, 2, "three's"]}, + {"name": "complex_dict", "type": "TEXT", "default": {"a'b": 1}}, + {"name": "str_val", "type": "TEXT", "default": "hello'world"}, + {"name": "fallback", "type": "TEXT", "default": 42.5}, + {"name": "object_fallback", "type": "TEXT", "default": Dummy()}, + ] + + await storage.create_custom_table(user_id, table_name, columns) + + # Insert a row with defaults + physical_name = storage._get_physical_name(user_id, table_name) + await sqlite_conn.execute(f'INSERT INTO "{physical_name}" (id) VALUES (1)') # noqa: S608 + await sqlite_conn.commit() + + # Retrieve row to verify values + async with sqlite_conn.execute( + f'SELECT * FROM "{physical_name}" WHERE id = 1' # noqa: S608 + ) as cursor: + row = await cursor.fetchone() + assert row is not None + assert row["bool_t"] == 1 + assert row["bool_f"] == 0 + assert json.loads(row["complex_list"]) == [1, 2, "three's"] + assert json.loads(row["complex_dict"]) == {"a'b": 1} + assert row["str_val"] == "hello'world" + assert float(row["fallback"]) == 42.5 + assert row["object_fallback"] == "dummy'value" + + @pytest.mark.anyio + async def test_delete_custom_table_not_exists( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should return False if dropping a non-existent table.""" + res = await storage.delete_custom_table("user_id", "non_existent") + assert res is False + + @pytest.mark.anyio + async def test_storage_exceptions_rollback( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Verify storage exceptions during create/delete rollback transaction.""" + user_id = "user_exceptions" + cols = [{"name": "id", "type": "INTEGER", "primary_key": True}] + + # 1. Create table failure + original_execute = sqlite_conn.execute + + def mock_execute(sql: str, *args, **kwargs): + if "INSERT INTO custom_table_columns" in sql: + raise Exception("forced column insert failure") + return original_execute(sql, *args, **kwargs) + + with ( + patch.object(sqlite_conn, "execute", side_effect=mock_execute), + pytest.raises(Exception, match="forced column insert failure"), + ): + await storage.create_custom_table(user_id, "fail_tbl", cols) + + # Check that table does not exist in metadata + metadata = await storage.list_custom_tables_and_templates(user_id) + assert "fail_tbl" not in metadata + + # 2. Delete table failure + await storage.create_custom_table(user_id, "success_tbl", cols) + + def mock_execute_delete(sql: str, *args, **kwargs): + if "DELETE FROM custom_tables" in sql: + raise Exception("forced delete metadata failure") + return original_execute(sql, *args, **kwargs) + + with ( + patch.object(sqlite_conn, "execute", side_effect=mock_execute_delete), + pytest.raises(Exception, match="forced delete metadata failure"), + ): + await storage.delete_custom_table(user_id, "success_tbl") + + # Table metadata should still exist + metadata = await storage.list_custom_tables_and_templates(user_id) + assert "success_tbl" in metadata + + @pytest.mark.anyio + async def test_create_query_template_overwrites( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Verify limit check is bypassed when template already exists (overwrites).""" + user_id = "user_tmpl_limit" + cols = [{"name": "id", "type": "INTEGER", "primary_key": True}] + await storage.create_custom_table(user_id, "tbl", cols) + + # Create 10 templates (the limit) + for i in range(10): + await storage.create_query_template(user_id, f"tmpl_{i}", "tbl", "SELECT") + + # Try creating an 11th new template -> should fail + with pytest.raises(ValueError, match="Limit of 10 saved query templates"): + await storage.create_query_template(user_id, "tmpl_10", "tbl", "SELECT") + + # Try overwriting an existing template (e.g. tmpl_0) -> should succeed + await storage.create_query_template( + user_id, "tmpl_0", "tbl", "SELECT", description="updated template" + ) + + @pytest.mark.anyio + async def test_create_query_template_invalid_sort_col( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Should raise ValueError if sorting column does not exist in table.""" + user_id = "user_test" + cols = [{"name": "id", "type": "INTEGER", "primary_key": True}] + await storage.create_custom_table(user_id, "tbl", cols) + + with pytest.raises(ValueError, match="does not exist in table"): + await storage.create_query_template( + user_id=user_id, + template_name="tmpl", + table_name="tbl", + query_type="SELECT", + order_by_column="non_existent", + ) + + @pytest.mark.anyio + async def test_execute_query_template_errors_and_edge_cases( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Verify errors and edge cases when executing query templates.""" + await sqlite_conn.execute("PRAGMA foreign_keys = OFF") + user_id = "user_exec_edge" + table_name = "tbl" + cols: list[dict[str, Any]] = [ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "name", "type": "TEXT"}, + {"name": "age", "type": "INTEGER"}, + ] + await storage.create_custom_table(user_id, table_name, cols) + + # 1. Execute non-existent template + with pytest.raises(ValueError, match="Template 'non_existent' not found"): + await storage.execute_query_template(user_id, "non_existent", {}) + + # 2. Execute where underlying table does not exist + await storage.create_query_template( + user_id, "get_val", table_name, "SELECT", select_columns=["name"] + ) + # Directly delete table from custom_tables metadata table. + # This orphans the template (bypasses default CASCADE). + await sqlite_conn.execute( + "DELETE FROM custom_tables WHERE user_id = ? AND table_name = ?", + (user_id, table_name), + ) + await sqlite_conn.commit() + + with pytest.raises(ValueError, match="Underlying table 'tbl' does not exist"): + await storage.execute_query_template(user_id, "get_val", {}) + + # Re-create table for subsequent tests + await storage.create_custom_table(user_id, table_name, cols) + + # 3. Parameter key not in columns + with pytest.raises(ValueError, match="is not a valid column"): + await storage.execute_query_template(user_id, "get_val", {"invalid_col": 1}) + + # 4. SELECT missing required filter parameter + await storage.create_query_template( + user_id, "get_with_filter", table_name, "SELECT", filter_columns=["age"] + ) + with pytest.raises(ValueError, match="Missing required filter parameter 'age'"): + await storage.execute_query_template(user_id, "get_with_filter", {}) + + # 5. INSERT no insert keys in parameters + await storage.create_query_template(user_id, "add_val", table_name, "INSERT") + with pytest.raises( + ValueError, match="Must provide at least one valid parameter to insert" + ): + # Empty parameters + await storage.execute_query_template(user_id, "add_val", {}) + + # 6. UPDATE no update keys in parameters + await storage.create_query_template( + user_id, "update_val", table_name, "UPDATE", filter_columns=["id"] + ) + with pytest.raises(ValueError, match="No columns provided to update"): + # Passing only filter column, nothing to UPDATE/SET + await storage.execute_query_template(user_id, "update_val", {"id": 1}) + + # 7. UPDATE missing required filter parameter + with pytest.raises( + ValueError, match="Missing required update filter parameter 'id'" + ): + await storage.execute_query_template(user_id, "update_val", {"name": "Bob"}) + + # 8. UPDATE with no filters (should succeed without WHERE) + await storage.create_query_template(user_id, "update_all", table_name, "UPDATE") + # Insert a row first + await storage.execute_query_template( + user_id, "add_val", {"id": 1, "name": "Alice"} + ) + # Update name globally without filter + affected = await storage.execute_query_template( + user_id, "update_all", {"name": "Bob"} + ) + assert affected == 1 + + # 9. DELETE missing required filter parameter + await storage.create_query_template( + user_id, "delete_val", table_name, "DELETE", filter_columns=["id"] + ) + with pytest.raises( + ValueError, match="Missing required delete filter parameter 'id'" + ): + await storage.execute_query_template(user_id, "delete_val", {}) + + # 10. Unsupported query type + await storage.create_query_template(user_id, "bad_type", table_name, "SELECT") + await sqlite_conn.execute( + "UPDATE saved_query_templates SET query_type = 'INVALID' " + "WHERE user_id = ? AND template_name = ?", + (user_id, "bad_type"), + ) + await sqlite_conn.commit() + with pytest.raises(ValueError, match="Unsupported query type: INVALID"): + await storage.execute_query_template(user_id, "bad_type", {}) + + # 11. SELECT order by direction fallback (None -> ASC) and limit + await storage.create_query_template( + user_id=user_id, + template_name="get_ordered", + table_name=table_name, + query_type="SELECT", + select_columns=["id", "name"], + order_by_column="name", + limit_val=2, + ) + res_ordered = await storage.execute_query_template(user_id, "get_ordered", {}) + assert isinstance(res_ordered, list) + + # 12. DELETE with no filters (should succeed) + await storage.create_query_template( + user_id=user_id, + template_name="delete_all", + table_name=table_name, + query_type="DELETE", + ) + deleted_cnt = await storage.execute_query_template(user_id, "delete_all", {}) + assert isinstance(deleted_cnt, int) + assert deleted_cnt >= 0 + + @pytest.mark.anyio + async def test_list_schemas_orphans( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Verify list schemas handles columns/templates for non-existent tables.""" + await sqlite_conn.execute("PRAGMA foreign_keys = OFF") + user_id = "user_orphans" + # Manually insert column metadata referencing a non-existent table + await sqlite_conn.execute( + "INSERT INTO custom_table_columns " + "(user_id, table_name, column_name, column_type) " + "VALUES (?, 'non_existent_table', 'ghost_col', 'TEXT')", + (user_id,), + ) + # Manually insert template metadata referencing a non-existent table + await sqlite_conn.execute( + "INSERT INTO saved_query_templates " + "(user_id, template_name, table_name, query_type) " + "VALUES (?, 'ghost_tmpl', 'non_existent_table', 'SELECT')", + (user_id,), + ) + await sqlite_conn.commit() + + # Call listing + res = await storage.list_custom_tables_and_templates(user_id) + # The non_existent_table should NOT be in res because it's not in custom_tables + assert "non_existent_table" not in res + + @pytest.mark.anyio + async def test_get_schema_instructions_xml_variations( + self, storage: SqliteDeclarativeDbStorage + ) -> None: + """Verify xml instructions formatting with different metadata states.""" + user_id = "user_xml_vars" + + # 1. Table with no description and no templates + await storage.create_custom_table( + user_id=user_id, + table_name="simple_table", + columns=[{"name": "id", "type": "INTEGER", "primary_key": True}], + ) + + xml1 = await storage.get_schema_instructions_xml(user_id) + assert "" in xml1 + assert "Table: simple_table" in xml1 + assert "Description:" not in xml1 + assert "Saved Query Templates:" not in xml1 + + # 2. Template with no select_columns, filter_columns, order_by, limit + await storage.create_query_template( + user_id=user_id, + template_name="get_simple", + table_name="simple_table", + query_type="SELECT", + ) + + xml2 = await storage.get_schema_instructions_xml(user_id) + assert "Saved Query Templates:" in xml2 + assert "Returns Columns:" not in xml2 + assert "Required Parameters:" not in xml2 + assert "Sorted By:" not in xml2 + assert "Limit:" not in xml2 + + # 3. Table WITH description and template with optional params + await storage.create_custom_table( + user_id=user_id, + table_name="detailed_table", + columns=[ + {"name": "id", "type": "INTEGER", "primary_key": True}, + {"name": "col_a", "type": "TEXT"}, + ], + description="Detailed table description", + ) + await storage.create_query_template( + user_id=user_id, + template_name="get_detailed", + table_name="detailed_table", + query_type="SELECT", + select_columns=["id"], + filter_columns=["col_a"], + order_by_column="id", + order_by_direction="ASC", + limit_val=10, + description="Fetch detailed", + ) + + xml3 = await storage.get_schema_instructions_xml(user_id) + assert "Description: Detailed table description" in xml3 + assert "Returns Columns: id" in xml3 + assert "Required Parameters: col_a" in xml3 + assert "Sorted By: id ASC" in xml3 + assert "Limit: 10" in xml3 + + def test_get_declarative_db_storage_uninitialized(self) -> None: + """Should raise RuntimeError when storage is not initialized.""" + from blacki.declarative_db.storage import get_declarative_db_storage + + mock_container = MagicMock() + mock_container.declarative_db_storage.is_initialized = False + + with ( + patch("blacki.container.get_container", return_value=mock_container), + pytest.raises(RuntimeError, match="Declarative DB storage not initialized"), + ): + get_declarative_db_storage() + + +# ============================================================================== +# 3. ADK Plugin Layer Tests +# ============================================================================== + + +class TestDeclarativeDbPlugin: + """Test dynamic context injection plugin.""" + + @pytest.mark.anyio + async def test_plugin_appends_instructions( + self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection + ) -> None: + """Should compile active schema XML and safely append to instructions.""" + # Setup container for process-wide lazy instantiation + container = set_container_from_connection(sqlite_conn) + container._declarative_db_storage = storage + + plugin = DeclarativeDbPlugin() + + # Build mock callback_context and llm_request + mock_session = MagicMock() + mock_session.state = {"user_id": "user_plugin_test"} + callback_context = MagicMock() + callback_context.session = mock_session + + llm_request = MagicMock() + llm_request.append_instructions = MagicMock() + + # Add data to storage to compile + await storage.set_custom_instruction_override( + "user_plugin_test", "Always speak in Spanish." + ) + + # Run callback + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + + # Assert compiled schema was appended to LLM instructions + llm_request.append_instructions.assert_called_once() + args = llm_request.append_instructions.call_args[0][0] + assert len(args) == 1 + assert "Spanish" in args[0] + assert "" in args[0] + + @pytest.mark.anyio + async def test_plugin_no_session(self) -> None: + """Plugin should return early if callback_context.session is None.""" + plugin = DeclarativeDbPlugin() + callback_context = MagicMock() + callback_context.session = None + llm_request = MagicMock() + + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + llm_request.append_instructions.assert_not_called() + + @pytest.mark.anyio + async def test_plugin_no_user_id_in_session_state(self) -> None: + """Plugin should return early if session has no user_id or telegram_chat_id.""" + plugin = DeclarativeDbPlugin() + mock_session = MagicMock() + mock_session.state = {} # Empty state + callback_context = MagicMock() + callback_context.session = mock_session + llm_request = MagicMock() + + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + llm_request.append_instructions.assert_not_called() + + @pytest.mark.anyio + async def test_plugin_empty_schema_xml(self) -> None: + """Plugin should not call append_instructions if schema XML is empty.""" + plugin = DeclarativeDbPlugin() + mock_session = MagicMock() + mock_session.state = {"user_id": "test_user_empty"} + callback_context = MagicMock() + callback_context.session = mock_session + llm_request = MagicMock() + + mock_storage = MagicMock() + mock_storage.get_schema_instructions_xml = AsyncMock(return_value="") + + with patch( + "blacki.declarative_db.plugin.get_declarative_db_storage", + return_value=mock_storage, + ): + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + llm_request.append_instructions.assert_not_called() + + @pytest.mark.anyio + async def test_plugin_storage_exception_logged(self) -> None: + """Plugin should catch and log storage exceptions without crashing.""" + plugin = DeclarativeDbPlugin() + mock_session = MagicMock() + mock_session.state = {"telegram_chat_id": "chat_123"} + callback_context = MagicMock() + callback_context.session = mock_session + llm_request = MagicMock() + + mock_storage = MagicMock() + mock_storage.get_schema_instructions_xml = AsyncMock( + side_effect=Exception("forced db error") + ) + + with patch( + "blacki.declarative_db.plugin.get_declarative_db_storage", + return_value=mock_storage, + ): + # This should not raise an error + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + llm_request.append_instructions.assert_not_called() + + +# ============================================================================== +# 4. Agent Tools Layer Tests +# ============================================================================== + + +class TestAgentTools: + """Test ADK tools proxy functions.""" + + @pytest.fixture(autouse=True) + def setup_mock_container(self) -> Generator[None, None, None]: + """Mock out get_declarative_db_storage singleton lookup.""" + self.mock_storage = MagicMock() + patcher = MagicMock() + patcher.return_value = self.mock_storage + + # Patch get_declarative_db_storage in the tools module + self.patcher = patch( + "blacki.declarative_db.tools.get_declarative_db_storage", patcher + ) + self.patcher.start() + yield + self.patcher.stop() + + @pytest.mark.anyio + async def test_tools_extract_user_id_and_proxy_calls(self) -> None: + """Agent tools should pull user_id and call equivalent storage methods.""" + tool_context = MagicMock() + tool_context.user_id = None + tool_context.state = {"user_id": "tool_user"} + + # 1. create_custom_table tool + self.mock_storage.create_custom_table = AsyncMock() + res = await create_custom_table( + "events", + [{"name": "id", "type": "INTEGER"}], + tool_context, + "History of events", + ) + assert res["status"] == "success" + self.mock_storage.create_custom_table.assert_called_once_with( + user_id="tool_user", + table_name="events", + columns=[{"name": "id", "type": "INTEGER"}], + description="History of events", + ) + + # 2. delete_custom_table tool + self.mock_storage.delete_custom_table = AsyncMock(return_value=True) + res = await delete_custom_table("events", tool_context) + assert res["status"] == "success" + self.mock_storage.delete_custom_table.assert_called_once_with( + "tool_user", "events" + ) + + # 3. create_query_template tool + self.mock_storage.create_query_template = AsyncMock() + res = await create_query_template("add_event", "events", "INSERT", tool_context) + assert res["status"] == "success" + self.mock_storage.create_query_template.assert_called_once_with( + user_id="tool_user", + template_name="add_event", + table_name="events", + query_type="INSERT", + select_columns=None, + filter_columns=None, + order_by_column=None, + order_by_direction=None, + limit_val=None, + description=None, + ) + + # 4. execute_query_template tool + self.mock_storage.execute_query_template = AsyncMock(return_value=[{"id": 1}]) + res = await execute_query_template("get_events", {"id": 1}, tool_context) + assert res["status"] == "success" + assert res["result"] == [{"id": 1}] + self.mock_storage.execute_query_template.assert_called_once_with( + user_id="tool_user", + template_name="get_events", + parameters={"id": 1}, + ) + + # 5. list_custom_tables_and_templates tool + self.mock_storage.list_custom_tables_and_templates = AsyncMock(return_value={}) + res = await list_custom_tables_and_templates(tool_context) + assert res["status"] == "success" + assert res["tables"] == {} + self.mock_storage.list_custom_tables_and_templates.assert_called_once_with( + "tool_user" + ) + + # 6. set_custom_instruction_override tool + self.mock_storage.set_custom_instruction_override = AsyncMock() + res = await set_custom_instruction_override("Fly high", tool_context) + assert res["status"] == "success" + self.mock_storage.set_custom_instruction_override.assert_called_once_with( + "tool_user", "Fly high" + ) + + # 7. delete_custom_instruction_override tool + self.mock_storage.delete_custom_instruction_override = AsyncMock( + return_value=True + ) + res = await delete_custom_instruction_override(tool_context) + assert res["status"] == "success" + self.mock_storage.delete_custom_instruction_override.assert_called_once_with( + "tool_user" + ) + + @pytest.mark.anyio + async def test_tools_missing_user_id_errors(self) -> None: + """Verify tools return error when user_id is missing.""" + bad_context = MagicMock() + bad_context.user_id = None + bad_context.state = {} + + # 1. create_custom_table + res = await create_custom_table("tbl", [], bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + # 2. delete_custom_table + res = await delete_custom_table("tbl", bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + # 3. create_query_template + res = await create_query_template("tmpl", "tbl", "SELECT", bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + # 4. execute_query_template + res = await execute_query_template("tmpl", {}, bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + # 5. list_custom_tables_and_templates + res = await list_custom_tables_and_templates(bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + # 6. set_custom_instruction_override + res = await set_custom_instruction_override("instr", bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + # 7. delete_custom_instruction_override + res = await delete_custom_instruction_override(bad_context) + assert res["status"] == "error" + assert "user not identified" in res["message"].lower() + + @pytest.mark.anyio + async def test_tools_exception_handling_and_errors(self) -> None: + """Verify that tools handle and log exceptions properly.""" + tool_context = MagicMock() + tool_context.user_id = "test_user" + tool_context.state = {} + + # 1. create_custom_table raising exception + self.mock_storage.create_custom_table = AsyncMock( + side_effect=Exception("create error") + ) + res = await create_custom_table("tbl", [], tool_context) + assert res["status"] == "error" + assert "failed to create table: create error" in res["message"].lower() + + # 2. delete_custom_table raising exception + self.mock_storage.delete_custom_table = AsyncMock( + side_effect=Exception("delete error") + ) + res = await delete_custom_table("tbl", tool_context) + assert res["status"] == "error" + assert "failed to delete table: delete error" in res["message"].lower() + + # 3. delete_custom_table not found + self.mock_storage.delete_custom_table = AsyncMock(return_value=False) + res = await delete_custom_table("tbl", tool_context) + assert res["status"] == "error" + assert "not found" in res["message"].lower() + + # 4. create_query_template raising exception + self.mock_storage.create_query_template = AsyncMock( + side_effect=Exception("tmpl error") + ) + res = await create_query_template("tmpl", "tbl", "SELECT", tool_context) + assert res["status"] == "error" + assert "failed to register query template: tmpl error" in res["message"].lower() + + # 5. execute_query_template raising exception + self.mock_storage.execute_query_template = AsyncMock( + side_effect=Exception("exec error") + ) + res = await execute_query_template("tmpl", {}, tool_context) + assert res["status"] == "error" + assert "execution failed: exec error" in res["message"].lower() + + # 6. list_custom_tables_and_templates raising exception + self.mock_storage.list_custom_tables_and_templates = AsyncMock( + side_effect=Exception("list error") + ) + res = await list_custom_tables_and_templates(tool_context) + assert res["status"] == "error" + assert "failed to fetch custom schemas: list error" in res["message"].lower() + + # 7. set_custom_instruction_override raising exception + self.mock_storage.set_custom_instruction_override = AsyncMock( + side_effect=Exception("set override error") + ) + res = await set_custom_instruction_override("instr", tool_context) + assert res["status"] == "error" + assert ( + "failed to save instructions: set override error" in res["message"].lower() + ) + + # 8. delete_custom_instruction_override raising exception + self.mock_storage.delete_custom_instruction_override = AsyncMock( + side_effect=Exception("delete override error") + ) + res = await delete_custom_instruction_override(tool_context) + assert res["status"] == "error" + assert ( + "failed to clear instructions: delete override error" + in res["message"].lower() + ) + + # 9. delete_custom_instruction_override not found (returns False) + self.mock_storage.delete_custom_instruction_override = AsyncMock( + return_value=False + ) + res = await delete_custom_instruction_override(tool_context) + assert res["status"] == "success" + assert "no custom instructions existed to delete" in res["message"].lower() diff --git a/tests/test_container.py b/tests/test_container.py index 25a4775..c4a678c 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -195,6 +195,15 @@ async def test_preferences_storage_property(self, conn, lock) -> None: assert storage is not None assert container._preferences_storage is storage + @pytest.mark.asyncio + async def test_declarative_db_storage_property(self, conn, lock) -> None: + """Should lazily instantiate declarative DB storage.""" + container = AppContainer(conn=conn, _lock=lock) + + storage = container.declarative_db_storage + assert storage is not None + assert container._declarative_db_storage is storage + @pytest.mark.asyncio async def test_close_closes_connection_and_storages(self, conn, lock) -> None: """Should close connection and all storage instances.""" @@ -203,9 +212,13 @@ async def test_close_closes_connection_and_storages(self, conn, lock) -> None: reminder = container.reminder_storage reminder.close = AsyncMock() + declarative_db = container.declarative_db_storage + declarative_db.close = AsyncMock() + await container.close() reminder.close.assert_called_once() + declarative_db.close.assert_called_once() @pytest.mark.asyncio async def test_close_storages_resets_references(self, conn, lock) -> None: @@ -216,6 +229,7 @@ async def test_close_storages_resets_references(self, conn, lock) -> None: _ = container.calorie_storage _ = container.workout_storage _ = container.preferences_storage + _ = container.declarative_db_storage await container._close_storages() @@ -223,6 +237,7 @@ async def test_close_storages_resets_references(self, conn, lock) -> None: assert container._calorie_storage is None assert container._workout_storage is None assert container._preferences_storage is None + assert container._declarative_db_storage is None @pytest.mark.asyncio async def test_close_storages_partial(self, conn, lock) -> None: @@ -232,6 +247,7 @@ async def test_close_storages_partial(self, conn, lock) -> None: _ = container.calorie_storage _ = container.workout_storage _ = container.preferences_storage + _ = container.declarative_db_storage await container._close_storages() @@ -239,6 +255,9 @@ async def test_close_storages_partial(self, conn, lock) -> None: assert container._calorie_storage is None assert container._workout_storage is None assert container._preferences_storage is None + assert container._declarative_db_storage is None + assert container._workout_storage is None + assert container._preferences_storage is None @pytest.mark.asyncio async def test_initialize_all_storages(self, conn, lock) -> None: diff --git a/tests/test_registry.py b/tests/test_registry.py index 1b6d5a5..de4bf49 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -264,3 +264,15 @@ def test_returns_empty_for_nonexistent_dir(self) -> None: tools = _build_skill_tools(Path("/nonexistent/skills")) assert tools == [] + + +class TestBuildDeclarativeDbTools: + """Tests for _build_declarative_db_tools.""" + + def test_returns_tools_when_available(self) -> None: + """Should return declarative database tools when available.""" + from blacki.registry import _build_declarative_db_tools + + tools = _build_declarative_db_tools() + + assert len(tools) == 7