diff --git a/src/blacki/agent.py b/src/blacki/agent.py
index 03eaea7..1b138c4 100644
--- a/src/blacki/agent.py
+++ b/src/blacki/agent.py
@@ -199,12 +199,15 @@ def create_app(agent: LlmAgent | None = None) -> App:
if agent is None:
agent = create_agent()
+ from blacki.declarative_db.plugin import DeclarativeDbPlugin
+
return App(
name="blacki",
root_agent=agent,
plugins=[
TelegramModelOverridePlugin(name="telegram_model_override"),
GlobalInstructionPlugin(return_global_instruction),
+ DeclarativeDbPlugin(name="declarative_db"),
LoggingPlugin(),
],
events_compaction_config=None,
diff --git a/src/blacki/container.py b/src/blacki/container.py
index 2c72070..71e4b06 100644
--- a/src/blacki/container.py
+++ b/src/blacki/container.py
@@ -27,6 +27,7 @@
import aiosqlite
from blacki.calories.storage import SqliteCalorieStorage
+ from blacki.declarative_db.storage import SqliteDeclarativeDbStorage
from blacki.reminders.storage import SqliteReminderStorage
from blacki.utils.preferences import SqlitePreferencesStorage
from blacki.workouts.storage import SqliteWorkoutStorage
@@ -134,6 +135,9 @@ class AppContainer:
_preferences_storage: SqlitePreferencesStorage | None = field(
default=None, init=False, repr=False
)
+ _declarative_db_storage: SqliteDeclarativeDbStorage | None = field(
+ default=None, init=False, repr=False
+ )
@classmethod
async def create(cls, sqlite_path: str | Path) -> Self:
@@ -174,6 +178,10 @@ async def _close_storages(self) -> None:
await self._preferences_storage.close()
self._preferences_storage = None
+ if self._declarative_db_storage is not None:
+ await self._declarative_db_storage.close()
+ self._declarative_db_storage = None
+
async def initialize_all_storages(self) -> None:
"""Initialize all storage instances.
@@ -184,6 +192,7 @@ async def initialize_all_storages(self) -> None:
await self.calorie_storage.initialize()
await self.workout_storage.initialize()
await self.preferences_storage.initialize()
+ await self.declarative_db_storage.initialize()
@property
def lock(self) -> asyncio.Lock:
@@ -225,3 +234,14 @@ def preferences_storage(self) -> SqlitePreferencesStorage:
self._preferences_storage = SqlitePreferencesStorage(self.conn, self._lock)
return self._preferences_storage
+
+ @property
+ def declarative_db_storage(self) -> SqliteDeclarativeDbStorage:
+ """Get or create the declarative database storage instance."""
+ if self._declarative_db_storage is None:
+ from blacki.declarative_db.storage import SqliteDeclarativeDbStorage
+
+ self._declarative_db_storage = SqliteDeclarativeDbStorage(
+ self.conn, self._lock
+ )
+ return self._declarative_db_storage
diff --git a/src/blacki/declarative_db/plugin.py b/src/blacki/declarative_db/plugin.py
new file mode 100644
index 0000000..e14eaca
--- /dev/null
+++ b/src/blacki/declarative_db/plugin.py
@@ -0,0 +1,58 @@
+"""ADK plugin for injecting declarative database instructions."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from google.adk.plugins.base_plugin import BasePlugin
+
+from blacki.declarative_db.storage import get_declarative_db_storage
+
+if TYPE_CHECKING:
+ from google.adk.agents.callback_context import CallbackContext
+ from google.adk.models.llm_request import LlmRequest
+
+logger = logging.getLogger(__name__)
+
+
+class DeclarativeDbPlugin(BasePlugin):
+ """ADK plugin to dynamically load declarative database schemas.
+
+ Loads both schemas and custom instruction overrides.
+ This plugin queries the storage layer using the active session user's ID
+ and appends user-defined tables, query templates, and instruction overrides
+ to the system instruction context on every LLM call.
+ """
+
+ def __init__(self, name: str = "declarative_db") -> None:
+ super().__init__(name=name)
+
+ async def before_model_callback(
+ self, *, callback_context: CallbackContext, llm_request: LlmRequest
+ ) -> None:
+ """Callback executed before the LLM is invoked."""
+ if not callback_context.session:
+ return
+
+ user_id = callback_context.session.state.get(
+ "user_id"
+ ) or callback_context.session.state.get("telegram_chat_id")
+ if not user_id:
+ logger.debug(
+ "before_model_callback: No user_id or telegram_chat_id in session state"
+ )
+ return
+
+ try:
+ storage = get_declarative_db_storage()
+ schema_xml = await storage.get_schema_instructions_xml(str(user_id))
+ if schema_xml:
+ logger.info(
+ "Injecting custom database instructions for user %s (%d chars)",
+ user_id,
+ len(schema_xml),
+ )
+ llm_request.append_instructions([schema_xml])
+ except Exception:
+ logger.exception("Failed to inject user database schemas and instructions")
diff --git a/src/blacki/declarative_db/storage.py b/src/blacki/declarative_db/storage.py
new file mode 100644
index 0000000..6efea21
--- /dev/null
+++ b/src/blacki/declarative_db/storage.py
@@ -0,0 +1,783 @@
+"""SQLite storage implementation for declarative tables and query templates."""
+
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+from blacki.declarative_db.validation import validate_column_type, validate_identifier
+from blacki.storage.base import SqlStorage
+from blacki.utils.timezone import now_utc
+
+if TYPE_CHECKING:
+ import aiosqlite
+
+logger = logging.getLogger(__name__)
+
+
+class SqliteDeclarativeDbStorage(SqlStorage):
+ """Storage implementation for custom schemas, query templates, and overrides."""
+
+ def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None:
+ super().__init__(conn, lock)
+
+ async def _create_tables(self) -> None:
+ # Create metadata table for tracking tables
+ await self._conn.execute("""
+ CREATE TABLE IF NOT EXISTS custom_tables (
+ user_id TEXT NOT NULL,
+ table_name TEXT NOT NULL,
+ physical_name TEXT NOT NULL,
+ description TEXT,
+ created_at TEXT NOT NULL,
+ PRIMARY KEY (user_id, table_name)
+ )
+ """)
+
+ # Create metadata table for column definitions
+ await self._conn.execute("""
+ CREATE TABLE IF NOT EXISTS custom_table_columns (
+ user_id TEXT NOT NULL,
+ table_name TEXT NOT NULL,
+ column_name TEXT NOT NULL,
+ column_type TEXT NOT NULL,
+ is_primary_key INTEGER NOT NULL DEFAULT 0,
+ is_not_null INTEGER NOT NULL DEFAULT 0,
+ default_value TEXT,
+ PRIMARY KEY (user_id, table_name, column_name),
+ FOREIGN KEY (user_id, table_name)
+ REFERENCES custom_tables(user_id, table_name) ON DELETE CASCADE
+ )
+ """)
+
+ # Create metadata table for query templates
+ await self._conn.execute("""
+ CREATE TABLE IF NOT EXISTS saved_query_templates (
+ user_id TEXT NOT NULL,
+ template_name TEXT NOT NULL,
+ table_name TEXT NOT NULL,
+ query_type TEXT NOT NULL,
+ description TEXT,
+ select_columns TEXT, -- JSON list of column names
+ filter_columns TEXT, -- JSON list of column names
+ order_by_column TEXT,
+ order_by_direction TEXT, -- ASC, DESC
+ limit_val INTEGER,
+ PRIMARY KEY (user_id, template_name),
+ FOREIGN KEY (user_id, table_name)
+ REFERENCES custom_tables(user_id, table_name) ON DELETE CASCADE
+ )
+ """)
+
+ # Create metadata table for user-scoped instructions
+ await self._conn.execute("""
+ CREATE TABLE IF NOT EXISTS custom_instruction_overrides (
+ user_id TEXT NOT NULL,
+ key TEXT NOT NULL,
+ instructions TEXT NOT NULL,
+ updated_at TEXT NOT NULL,
+ PRIMARY KEY (user_id, key)
+ )
+ """)
+
+ def _get_physical_name(self, user_id: str, table_name: str) -> str:
+ """Derive a safe physical table name scoped by hashed user_id."""
+ user_hash = hashlib.md5(
+ user_id.encode("utf-8"), usedforsecurity=False
+ ).hexdigest()[:16]
+ return f"usr_{user_hash}_{table_name}"
+
+ async def create_custom_table(
+ self,
+ user_id: str,
+ table_name: str,
+ columns: list[dict[str, Any]],
+ description: str | None = None,
+ ) -> None:
+ """Create a safe custom physical table and store its metadata.
+
+ Args:
+ user_id: The scoping user ID.
+ table_name: Logical table name.
+ columns: List of columns containing name, type, primary_key,
+ not_null, and default keys.
+ description: Optional table description.
+ """
+ validate_identifier(table_name)
+ physical_name = self._get_physical_name(user_id, table_name)
+
+ if not columns:
+ raise ValueError("Table must define at least one column")
+
+ # Validate columns and prepare dynamic DDL components
+ validated_cols: list[dict[str, Any]] = []
+ ddl_parts: list[str] = []
+
+ for col in columns:
+ col_name = col.get("name", "").strip()
+ col_type = col.get("type", "").strip().upper()
+ is_pk = int(bool(col.get("primary_key")))
+ is_nn = int(bool(col.get("not_null")))
+ default_val = col.get("default")
+
+ validate_identifier(col_name)
+ validate_column_type(col_type)
+
+ default_str = None
+ if default_val is not None:
+ if isinstance(default_val, bool):
+ default_str = "1" if default_val else "0"
+ part_def = f" DEFAULT {default_str}"
+ elif isinstance(default_val, (int, float)):
+ default_str = str(default_val)
+ part_def = f" DEFAULT {default_str}"
+ elif isinstance(default_val, str):
+ escaped_default = default_val.replace("'", "''")
+ part_def = f" DEFAULT '{escaped_default}'"
+ default_str = default_val
+ elif isinstance(default_val, (list, dict)):
+ default_str = json.dumps(default_val)
+ escaped_default = default_str.replace("'", "''")
+ part_def = f" DEFAULT '{escaped_default}'"
+ else:
+ default_str = str(default_val)
+ escaped_default = default_str.replace("'", "''")
+ part_def = f" DEFAULT '{escaped_default}'"
+ else:
+ part_def = ""
+
+ validated_cols.append(
+ {
+ "name": col_name,
+ "type": col_type,
+ "is_pk": is_pk,
+ "is_nn": is_nn,
+ "default": default_str,
+ }
+ )
+
+ # Form column DDL segment
+ part = f'"{col_name}" {col_type}'
+ if is_pk:
+ part += " PRIMARY KEY"
+ if is_nn:
+ part += " NOT NULL"
+ part += part_def
+
+ ddl_parts.append(part)
+
+ # Build physical CREATE TABLE string
+ create_ddl = (
+ f'CREATE TABLE IF NOT EXISTS "{physical_name}" (\n '
+ + ",\n ".join(ddl_parts)
+ + "\n)"
+ )
+
+ now = now_utc().isoformat(timespec="seconds")
+
+ async with self._lock:
+ # Check if table already exists
+ table_exists_row = await self._fetch_one(
+ "SELECT 1 FROM custom_tables WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ table_exists = table_exists_row is not None
+
+ if not table_exists:
+ # Check table counts guardrail
+ existing_count_row = await self._fetch_one(
+ "SELECT COUNT(*) as count FROM custom_tables WHERE user_id = ?",
+ (user_id,),
+ )
+ existing_count = (
+ existing_count_row["count"] if existing_count_row else 0
+ )
+ if existing_count >= 5:
+ raise ValueError("Limit of 5 custom tables per user reached")
+
+ if len(columns) > 15:
+ raise ValueError("Limit of 15 columns per custom table reached")
+
+ # Initialize physical table and save metadata within one serial transaction
+ await self._conn.execute("BEGIN")
+ try:
+ if table_exists:
+ # Drop old table to prevent column/metadata desync
+ await self._conn.execute(f'DROP TABLE IF EXISTS "{physical_name}"')
+ await self._conn.execute(
+ "DELETE FROM custom_tables "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+
+ # Physically create table
+ await self._conn.execute(create_ddl)
+
+ # Save table metadata
+ await self._conn.execute(
+ """
+ INSERT INTO custom_tables (
+ user_id, table_name, physical_name, description, created_at
+ ) VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (user_id, table_name) DO UPDATE SET
+ description = excluded.description,
+ physical_name = excluded.physical_name
+ """,
+ (user_id, table_name, physical_name, description, now),
+ )
+
+ # Delete any old metadata columns
+ await self._conn.execute(
+ "DELETE FROM custom_table_columns "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+
+ # Save column metadata
+ for c in validated_cols:
+ await self._conn.execute(
+ """
+ INSERT INTO custom_table_columns (
+ user_id, table_name, column_name, column_type,
+ is_primary_key, is_not_null, default_value
+ ) VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ user_id,
+ table_name,
+ c["name"],
+ c["type"],
+ c["is_pk"],
+ c["is_nn"],
+ c["default"],
+ ),
+ )
+
+ await self._conn.execute("COMMIT")
+ logger.info(
+ "Successfully created table %s with %d columns",
+ table_name,
+ len(validated_cols),
+ )
+ except Exception as e:
+ await self._conn.execute("ROLLBACK")
+ logger.exception("Failed to create custom table %s", table_name)
+ raise e
+
+ async def delete_custom_table(self, user_id: str, table_name: str) -> bool:
+ """Physically drop a table and delete its metadata records."""
+ validate_identifier(table_name)
+ physical_name = self._get_physical_name(user_id, table_name)
+
+ async with self._lock:
+ # Check if table metadata exists first
+ row = await self._fetch_one(
+ "SELECT table_name FROM custom_tables "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ if not row:
+ return False
+
+ await self._conn.execute("BEGIN")
+ try:
+ # Drop physical table
+ await self._conn.execute(f'DROP TABLE IF EXISTS "{physical_name}"')
+
+ # Clear metadata (ON DELETE CASCADE propagates to columns & templates)
+ await self._conn.execute(
+ "DELETE FROM custom_tables WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+
+ await self._conn.execute("COMMIT")
+ logger.info("Successfully dropped table %s", table_name)
+ return True
+ except Exception as e:
+ await self._conn.execute("ROLLBACK")
+ logger.exception("Failed to drop custom table %s", table_name)
+ raise e
+
+ async def create_query_template(
+ self,
+ user_id: str,
+ template_name: str,
+ table_name: str,
+ query_type: str,
+ select_columns: list[str] | None = None,
+ filter_columns: list[str] | None = None,
+ order_by_column: str | None = None,
+ order_by_direction: str | None = None,
+ limit_val: int | None = None,
+ description: str | None = None,
+ ) -> None:
+ """Create or save a query template against a logical custom table."""
+ validate_identifier(template_name)
+ validate_identifier(table_name)
+ query_type = query_type.strip().upper()
+ if query_type not in {"SELECT", "INSERT", "UPDATE", "DELETE"}:
+ raise ValueError("Query type must be SELECT, INSERT, UPDATE, or DELETE")
+
+ # Validate identifiers inside lists
+ if select_columns:
+ for s in select_columns:
+ validate_identifier(s)
+ if filter_columns:
+ for f in filter_columns:
+ validate_identifier(f)
+ if order_by_column:
+ validate_identifier(order_by_column)
+ if order_by_direction:
+ order_by_dir_upper = order_by_direction.strip().upper()
+ if order_by_dir_upper not in {"ASC", "DESC"}:
+ raise ValueError("Sorting direction must be ASC or DESC")
+ order_by_direction = order_by_dir_upper
+
+ async with self._lock:
+ # Check if template already exists
+ template_exists_row = await self._fetch_one(
+ "SELECT 1 FROM saved_query_templates "
+ "WHERE user_id = ? AND template_name = ?",
+ (user_id, template_name),
+ )
+ template_exists = template_exists_row is not None
+
+ if not template_exists:
+ # Check templates limit (10 per user)
+ existing_count_row = await self._fetch_one(
+ "SELECT COUNT(*) as count FROM saved_query_templates "
+ "WHERE user_id = ?",
+ (user_id,),
+ )
+ existing_count = (
+ existing_count_row["count"] if existing_count_row else 0
+ )
+ if existing_count >= 10:
+ raise ValueError(
+ "Limit of 10 saved query templates per user reached"
+ )
+
+ # Check if custom table exists
+ table_row = await self._fetch_one(
+ "SELECT table_name FROM custom_tables "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ if not table_row:
+ raise ValueError(f"Table '{table_name}' does not exist for this user")
+
+ # Validate column existence against metadata
+ allowed_cols_rows = await self._fetch_all(
+ "SELECT column_name FROM custom_table_columns "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ allowed_cols = {r["column_name"] for r in allowed_cols_rows}
+
+ for col_list, label in [
+ (select_columns, "select_columns"),
+ (filter_columns, "filter_columns"),
+ ]:
+ if col_list:
+ for c in col_list:
+ if c not in allowed_cols:
+ raise ValueError(
+ f"Column '{c}' in {label} does not "
+ f"exist in table '{table_name}'"
+ )
+
+ if order_by_column and order_by_column not in allowed_cols:
+ raise ValueError(
+ f"Column '{order_by_column}' in order_by_column "
+ f"does not exist in table '{table_name}'"
+ )
+
+ # Insert metadata into saved templates
+ await self._conn.execute(
+ """
+ INSERT INTO saved_query_templates (
+ user_id, template_name, table_name, query_type, description,
+ select_columns, filter_columns, order_by_column,
+ order_by_direction, limit_val
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ ON CONFLICT (user_id, template_name) DO UPDATE SET
+ table_name = excluded.table_name,
+ query_type = excluded.query_type,
+ description = excluded.description,
+ select_columns = excluded.select_columns,
+ filter_columns = excluded.filter_columns,
+ order_by_column = excluded.order_by_column,
+ order_by_direction = excluded.order_by_direction,
+ limit_val = excluded.limit_val
+ """,
+ (
+ user_id,
+ template_name,
+ table_name,
+ query_type,
+ description,
+ json.dumps(select_columns) if select_columns is not None else None,
+ json.dumps(filter_columns) if filter_columns is not None else None,
+ order_by_column,
+ order_by_direction,
+ limit_val,
+ ),
+ )
+ logger.info(
+ "Saved query template %s against table %s", template_name, table_name
+ )
+
+ async def execute_query_template(
+ self,
+ user_id: str,
+ template_name: str,
+ parameters: dict[str, Any],
+ ) -> list[dict[str, Any]] | int:
+ """Securely build and run a saved template query with validation."""
+ validate_identifier(template_name)
+
+ # 1. Fetch template from DB
+ template = await self._fetch_one(
+ "SELECT * FROM saved_query_templates "
+ "WHERE user_id = ? AND template_name = ?",
+ (user_id, template_name),
+ )
+ if not template:
+ raise ValueError(f"Template '{template_name}' not found")
+
+ table_name = template["table_name"]
+ query_type = template["query_type"]
+
+ # Fetch physical table name
+ table_meta = await self._fetch_one(
+ "SELECT physical_name FROM custom_tables "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ if not table_meta:
+ raise ValueError(f"Underlying table '{table_name}' does not exist")
+ physical_name = table_meta["physical_name"]
+
+ # Fetch valid columns
+ valid_cols_rows = await self._fetch_all(
+ "SELECT column_name FROM custom_table_columns "
+ "WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ valid_cols = {r["column_name"] for r in valid_cols_rows}
+
+ # 2. Strict Parameter Validation
+ # Verify that all keys supplied in parameters are valid matching column names
+ for key in parameters:
+ validate_identifier(key)
+ if key not in valid_cols:
+ raise ValueError(
+ f"Parameter key '{key}' is not a valid "
+ f"column for table '{table_name}'"
+ )
+
+ # Parse column lists from JSON
+ select_cols = (
+ json.loads(template["select_columns"]) if template["select_columns"] else []
+ )
+ filter_cols = (
+ json.loads(template["filter_columns"]) if template["filter_columns"] else []
+ )
+
+ # 3. Secure Query Assembly
+ sql = ""
+ bindings: tuple[Any, ...] = ()
+
+ if query_type == "SELECT":
+ # Form SELECT segment
+ cols_str = ", ".join(f'"{c}"' for c in select_cols) if select_cols else "*"
+
+ # Validated & quoted identifiers: SQL safe.
+ sql = f'SELECT {cols_str} FROM "{physical_name}"' # noqa: S608
+
+ # Form WHERE segment
+ if filter_cols:
+ where_segments = []
+ where_bindings = []
+ for c in filter_cols:
+ if c in parameters:
+ where_segments.append(f'"{c}" = ?')
+ where_bindings.append(parameters[c])
+ else:
+ raise ValueError(f"Missing required filter parameter '{c}'")
+ sql += " WHERE " + " AND ".join(where_segments)
+ bindings = tuple(where_bindings)
+
+ # Form ORDER BY segment
+ if template["order_by_column"]:
+ ob_col = template["order_by_column"]
+ ob_dir = template["order_by_direction"] or "ASC"
+ sql += f' ORDER BY "{ob_col}" {ob_dir}'
+
+ # Form LIMIT segment
+ if template["limit_val"] is not None:
+ sql += f" LIMIT {int(template['limit_val'])}"
+
+ # Run SELECT
+ return await self._fetch_all(sql, bindings)
+
+ elif query_type == "INSERT":
+ # Form INSERT parameters from direct parameters matching columns
+ insert_keys = [k for k in parameters if k in valid_cols]
+ if not insert_keys:
+ raise ValueError("Must provide at least one valid parameter to insert")
+
+ cols_str = ", ".join(f'"{k}"' for k in insert_keys)
+ placeholders = ", ".join("?" for _ in insert_keys)
+ sql = f'INSERT INTO "{physical_name}" ({cols_str}) VALUES ({placeholders})' # noqa: S608
+ bindings = tuple(parameters[k] for k in insert_keys)
+
+ # Run write insert
+ return await self._execute(sql, bindings)
+
+ elif query_type == "UPDATE":
+ # Form UPDATE columns from params not in filter_columns
+ update_keys = [
+ k for k in parameters if k in valid_cols and k not in filter_cols
+ ]
+ if not update_keys:
+ raise ValueError("No columns provided to update")
+
+ set_segments = [f'"{k}" = ?' for k in update_keys]
+ set_bindings = [parameters[k] for k in update_keys]
+
+ sql = f'UPDATE "{physical_name}" SET ' + ", ".join(set_segments) # noqa: S608
+
+ # Form WHERE segment
+ if filter_cols:
+ where_segments = []
+ where_bindings = []
+ for c in filter_cols:
+ if c in parameters:
+ where_segments.append(f'"{c}" = ?')
+ where_bindings.append(parameters[c])
+ else:
+ raise ValueError(
+ f"Missing required update filter parameter '{c}'"
+ )
+ sql += " WHERE " + " AND ".join(where_segments)
+ bindings = tuple(set_bindings + where_bindings)
+ else:
+ bindings = tuple(set_bindings)
+
+ # Run write update
+ async with self._lock, self._conn.execute(sql, bindings) as cursor:
+ return cursor.rowcount
+
+ elif query_type == "DELETE":
+ sql = f'DELETE FROM "{physical_name}"' # noqa: S608
+ if filter_cols:
+ where_segments = []
+ where_bindings = []
+ for c in filter_cols:
+ if c in parameters:
+ where_segments.append(f'"{c}" = ?')
+ where_bindings.append(parameters[c])
+ else:
+ raise ValueError(
+ f"Missing required delete filter parameter '{c}'"
+ )
+ sql += " WHERE " + " AND ".join(where_segments)
+ bindings = tuple(where_bindings)
+
+ # Run write delete
+ async with self._lock, self._conn.execute(sql, bindings) as cursor:
+ return cursor.rowcount
+
+ else:
+ raise ValueError(f"Unsupported query type: {query_type}")
+
+ async def set_custom_instruction_override(
+ self, user_id: str, instructions: str
+ ) -> None:
+ """Update or insert a custom instruction override."""
+ now = now_utc().isoformat(timespec="seconds")
+ async with self._lock:
+ await self._conn.execute(
+ """
+ INSERT INTO custom_instruction_overrides (
+ user_id, key, instructions, updated_at
+ ) VALUES (?, 'custom_instructions', ?, ?)
+ ON CONFLICT (user_id, key) DO UPDATE SET
+ instructions = excluded.instructions,
+ updated_at = excluded.updated_at
+ """,
+ (user_id, instructions, now),
+ )
+ logger.info("Updated custom instructions for user %s", user_id)
+
+ async def get_custom_instruction_override(self, user_id: str) -> str | None:
+ """Retrieve custom instruction override for user."""
+ row = await self._fetch_one(
+ "SELECT instructions FROM custom_instruction_overrides "
+ "WHERE user_id = ? AND key = 'custom_instructions'",
+ (user_id,),
+ )
+ return row["instructions"] if row else None
+
+ async def delete_custom_instruction_override(self, user_id: str) -> bool:
+ """Clear custom instruction override for user."""
+ async with self._lock:
+ cursor = await self._conn.execute(
+ "DELETE FROM custom_instruction_overrides "
+ "WHERE user_id = ? AND key = 'custom_instructions'",
+ (user_id,),
+ )
+ return cursor.rowcount > 0
+
+ async def list_custom_tables_and_templates(self, user_id: str) -> dict[str, Any]:
+ """List tables, columns, and saved templates for a given user."""
+ tables = await self._fetch_all(
+ "SELECT table_name, description, created_at FROM custom_tables "
+ "WHERE user_id = ? ORDER BY table_name ASC",
+ (user_id,),
+ )
+ columns = await self._fetch_all(
+ "SELECT table_name, column_name, column_type, is_primary_key, "
+ "is_not_null, default_value FROM custom_table_columns "
+ "WHERE user_id = ? ORDER BY table_name ASC, column_name ASC",
+ (user_id,),
+ )
+ templates = await self._fetch_all(
+ "SELECT template_name, table_name, query_type, description, "
+ "select_columns, filter_columns, order_by_column, "
+ "order_by_direction, limit_val FROM saved_query_templates "
+ "WHERE user_id = ? ORDER BY template_name ASC",
+ (user_id,),
+ )
+
+ # Structure response
+ structured_tables: dict[str, Any] = {}
+ for t in tables:
+ t_name = t["table_name"]
+ structured_tables[t_name] = {
+ "description": t["description"],
+ "created_at": t["created_at"],
+ "columns": [],
+ "templates": [],
+ }
+
+ for col in columns:
+ t_name = col["table_name"]
+ if t_name in structured_tables:
+ structured_tables[t_name]["columns"].append(
+ {
+ "name": col["column_name"],
+ "type": col["column_type"],
+ "primary_key": bool(col["is_primary_key"]),
+ "not_null": bool(col["is_not_null"]),
+ "default": col["default_value"],
+ }
+ )
+
+ for tmpl in templates:
+ t_name = tmpl["table_name"]
+ if t_name in structured_tables:
+ structured_tables[t_name]["templates"].append(
+ {
+ "name": tmpl["template_name"],
+ "type": tmpl["query_type"],
+ "description": tmpl["description"],
+ "select_columns": json.loads(tmpl["select_columns"])
+ if tmpl["select_columns"]
+ else None,
+ "filter_columns": json.loads(tmpl["filter_columns"])
+ if tmpl["filter_columns"]
+ else None,
+ "order_by": (
+ f"{tmpl['order_by_column']} {tmpl['order_by_direction']}"
+ if tmpl["order_by_column"]
+ else None
+ ),
+ "limit": tmpl["limit_val"],
+ }
+ )
+
+ return structured_tables
+
+ async def get_schema_instructions_xml(self, user_id: str) -> str:
+ """Compile schemas and overrides into instructions XML."""
+ schema_data = await self.list_custom_tables_and_templates(user_id)
+ override = await self.get_custom_instruction_override(user_id)
+
+ blocks: list[str] = []
+
+ if override:
+ blocks.append(
+ f"\n{override}\n"
+ )
+
+ if schema_data:
+ schema_lines = [
+ "You have access to the following custom user-defined database tables:"
+ ]
+ for t_name, t_val in schema_data.items():
+ schema_lines.append(f"\nTable: {t_name}")
+ if t_val["description"]:
+ schema_lines.append(f" Description: {t_val['description']}")
+ schema_lines.append(" Columns:")
+ for col in t_val["columns"]:
+ pk_label = " PRIMARY KEY" if col["primary_key"] else ""
+ nn_label = " NOT NULL" if col["not_null"] else ""
+ def_label = (
+ f" DEFAULT {col['default']}"
+ if col["default"] is not None
+ else ""
+ )
+ col_info = (
+ f" - {col['name']} ({col['type']})"
+ f"{pk_label}{nn_label}{def_label}"
+ )
+ schema_lines.append(col_info)
+
+ if t_val["templates"]:
+ schema_lines.append(" Saved Query Templates:")
+ for tmpl in t_val["templates"]:
+ desc = (
+ f" ({tmpl['description']})" if tmpl["description"] else ""
+ )
+ schema_lines.append(f" - Template: {tmpl['name']}{desc}")
+ schema_lines.append(f" Operation: {tmpl['type']}")
+ if tmpl["select_columns"]:
+ cols_joined = ", ".join(tmpl["select_columns"])
+ schema_lines.append(f" Returns Columns: {cols_joined}")
+ if tmpl["filter_columns"]:
+ params_joined = ", ".join(tmpl["filter_columns"])
+ schema_lines.append(
+ f" Required Parameters: {params_joined}"
+ )
+ if tmpl["order_by"]:
+ schema_lines.append(f" Sorted By: {tmpl['order_by']}")
+ if tmpl["limit"]:
+ schema_lines.append(f" Limit: {tmpl['limit']}")
+
+ schema_str = "\n".join(schema_lines)
+ blocks.append(
+ f"\n{schema_str}\n"
+ )
+
+ return "\n\n".join(blocks) if blocks else ""
+
+
+_storage: SqliteDeclarativeDbStorage | None = None
+
+
+def get_declarative_db_storage() -> SqliteDeclarativeDbStorage:
+ """Return the process-wide singleton SqliteDeclarativeDbStorage instance."""
+ from blacki.container import get_container
+
+ container = get_container()
+ storage = container.declarative_db_storage
+ if not storage.is_initialized:
+ raise RuntimeError(
+ "Declarative DB storage not initialized. Call storage.initialize() first."
+ )
+ return storage
diff --git a/src/blacki/declarative_db/tools.py b/src/blacki/declarative_db/tools.py
new file mode 100644
index 0000000..3fd478f
--- /dev/null
+++ b/src/blacki/declarative_db/tools.py
@@ -0,0 +1,351 @@
+"""Agent-facing ADK tools for declarative databases."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from google.adk.tools import ToolContext
+
+from blacki.declarative_db.storage import get_declarative_db_storage
+
+logger = logging.getLogger(__name__)
+
+
+async def create_custom_table(
+ table_name: str,
+ columns: list[dict[str, Any]],
+ tool_context: ToolContext,
+ description: str | None = None,
+) -> dict[str, Any]:
+ """Create a new custom physical table and register its metadata.
+
+ The physical SQLite table is securely created with the specified columns and
+ types, and is scoped and isolated per user.
+
+ Args:
+ table_name: The name of the table to create (lowercase, alphanumeric).
+ columns: A list of dicts specifying column definitions. Each column dict
+ can have:
+ - "name" (str): Column name (alphanumeric, lowercase).
+ - "type" (str): Column type. Restricted strictly to:
+ "TEXT", "INTEGER", "REAL", "BLOB".
+ - "primary_key" (bool): Whether column is primary key.
+ - "not_null" (bool): Whether column cannot be null.
+ - "default" (any): Default column value.
+ tool_context: ADK tool context.
+ description: A short optional description of the table's purpose.
+
+ Returns:
+ A status dictionary confirming creation or detailing validation errors.
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ await storage.create_custom_table(
+ user_id=str(user_id),
+ table_name=table_name,
+ columns=columns,
+ description=description,
+ )
+ return {
+ "status": "success",
+ "message": (
+ f"Table '{table_name}' was successfully created with "
+ f"{len(columns)} columns."
+ ),
+ }
+ except Exception as e:
+ logger.exception("Failed to create custom table %s", table_name)
+ return {
+ "status": "error",
+ "message": f"Failed to create table: {e}",
+ }
+
+
+async def delete_custom_table(
+ table_name: str,
+ tool_context: ToolContext,
+) -> dict[str, Any]:
+ """Physically drop a custom table and delete all its query templates.
+
+ Args:
+ table_name: The logical name of the table to drop.
+ tool_context: ADK tool context.
+
+ Returns:
+ A status dictionary.
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ deleted = await storage.delete_custom_table(str(user_id), table_name)
+ if deleted:
+ return {
+ "status": "success",
+ "message": (
+ f"Table '{table_name}' and all its templates "
+ "were successfully deleted."
+ ),
+ }
+ return {
+ "status": "error",
+ "message": f"Table '{table_name}' was not found.",
+ }
+ except Exception as e:
+ logger.exception("Failed to delete custom table %s", table_name)
+ return {
+ "status": "error",
+ "message": f"Failed to delete table: {e}",
+ }
+
+
+async def create_query_template(
+ template_name: str,
+ table_name: str,
+ query_type: str,
+ tool_context: ToolContext,
+ select_columns: list[str] | None = None,
+ filter_columns: list[str] | None = None,
+ order_by_column: str | None = None,
+ order_by_direction: str | None = None,
+ limit_val: int | None = None,
+ description: str | None = None,
+) -> dict[str, Any]:
+ """Save a query template against a custom user table.
+
+ Rather than compiling raw string statements, templates define structure.
+ The storage manager compiles them securely into parameterized queries
+ during runtime execution.
+
+ Args:
+ template_name: The name of this query template (unique per user).
+ table_name: The logical table this query executes against.
+ query_type: The action of the query. Must be one of:
+ "SELECT", "INSERT", "UPDATE", "DELETE".
+ tool_context: ADK tool context.
+ select_columns: List of columns to return (for SELECT). If empty, returns all.
+ filter_columns: List of columns to bind parameters to inside the WHERE clause
+ (e.g., ["id"] maps to WHERE "id" = ?). All must be
+ passed in parameters during execution.
+ order_by_column: Name of column to sort on (for SELECT).
+ order_by_direction: Sorting direction (for SELECT). Must be "ASC" or "DESC".
+ limit_val: Optional integer pagination constraint (for SELECT).
+ description: A brief optional description of what the query template does.
+
+ Returns:
+ A status dictionary.
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ await storage.create_query_template(
+ user_id=str(user_id),
+ template_name=template_name,
+ table_name=table_name,
+ query_type=query_type,
+ select_columns=select_columns,
+ filter_columns=filter_columns,
+ order_by_column=order_by_column,
+ order_by_direction=order_by_direction,
+ limit_val=limit_val,
+ description=description,
+ )
+ return {
+ "status": "success",
+ "message": (
+ f"Query template '{template_name}' was successfully "
+ f"registered for table '{table_name}'."
+ ),
+ }
+ except Exception as e:
+ logger.exception("Failed to create query template %s", template_name)
+ return {
+ "status": "error",
+ "message": f"Failed to register query template: {e}",
+ }
+
+
+async def execute_query_template(
+ template_name: str,
+ parameters: dict[str, Any],
+ tool_context: ToolContext,
+) -> dict[str, Any]:
+ """Execute a saved query template securely with inputs.
+
+ Args:
+ template_name: Name of the template to run.
+ parameters: A dictionary of key-value bindings. Every key must map
+ exactly to a valid column defined for the underlying table.
+ tool_context: ADK tool context.
+
+ Returns:
+ A result dictionary containing status and output records (for SELECT) or
+ affected rows count / last inserted ID (for write queries).
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ res = await storage.execute_query_template(
+ user_id=str(user_id),
+ template_name=template_name,
+ parameters=parameters,
+ )
+ return {
+ "status": "success",
+ "result": res,
+ }
+ except Exception as e:
+ logger.exception("Failed to execute query template %s", template_name)
+ return {
+ "status": "error",
+ "message": f"Execution failed: {e}",
+ }
+
+
+async def list_custom_tables_and_templates(
+ tool_context: ToolContext,
+) -> dict[str, Any]:
+ """List all custom tables, columns, and registered templates for the active user.
+
+ Args:
+ tool_context: ADK tool context.
+
+ Returns:
+ A dictionary mapping table names to their schema structures and query templates.
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ schemas = await storage.list_custom_tables_and_templates(str(user_id))
+ return {
+ "status": "success",
+ "tables": schemas,
+ }
+ except Exception as e:
+ logger.exception("Failed to list schemas")
+ return {
+ "status": "error",
+ "message": f"Failed to fetch custom schemas: {e}",
+ }
+
+
+async def set_custom_instruction_override(
+ instructions: str,
+ tool_context: ToolContext,
+) -> dict[str, Any]:
+ """Persist a custom instruction override for the user.
+
+ This instructions block is injected into system instruction prompts to guide
+ the model's persona, handling, or workflows dynamically, safely isolated from
+ the system prompt codebase.
+
+ Args:
+ instructions: Raw text instructions to guide the model.
+ tool_context: ADK tool context.
+
+ Returns:
+ A status dictionary.
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ await storage.set_custom_instruction_override(str(user_id), instructions)
+ return {
+ "status": "success",
+ "message": "Custom instructions successfully saved.",
+ }
+ except Exception as e:
+ logger.exception("Failed to set instruction override")
+ return {
+ "status": "error",
+ "message": f"Failed to save instructions: {e}",
+ }
+
+
+async def delete_custom_instruction_override(
+ tool_context: ToolContext,
+) -> dict[str, Any]:
+ """Clear custom instruction override for the user.
+
+ Args:
+ tool_context: ADK tool context.
+
+ Returns:
+ A status dictionary.
+ """
+ user_id = getattr(tool_context, "user_id", None) or tool_context.state.get(
+ "user_id"
+ )
+ if not user_id:
+ return {
+ "status": "error",
+ "message": "User not identified in tool context.",
+ }
+
+ try:
+ storage = get_declarative_db_storage()
+ deleted = await storage.delete_custom_instruction_override(str(user_id))
+ if deleted:
+ return {
+ "status": "success",
+ "message": "Custom instructions successfully deleted.",
+ }
+ return {
+ "status": "success",
+ "message": "No custom instructions existed to delete.",
+ }
+ except Exception as e:
+ logger.exception("Failed to delete instruction override")
+ return {
+ "status": "error",
+ "message": f"Failed to clear instructions: {e}",
+ }
diff --git a/src/blacki/declarative_db/validation.py b/src/blacki/declarative_db/validation.py
new file mode 100644
index 0000000..8b81831
--- /dev/null
+++ b/src/blacki/declarative_db/validation.py
@@ -0,0 +1,96 @@
+"""Safety and validation layer for declarative database identifiers and types."""
+
+from __future__ import annotations
+
+import re
+
+# Strict regex matching letters, numbers, and underscores,
+# starting with a letter or underscore.
+IDENTIFIER_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+
+# Safe SQLite types allowlist
+ALLOWED_TYPES = {"TEXT", "INTEGER", "REAL", "BLOB"}
+
+# Standard SQLite reserved keywords blocklist to avoid syntax overlaps and confusion
+RESERVED_KEYWORDS = {
+ "SELECT",
+ "DROP",
+ "ALTER",
+ "DELETE",
+ "INSERT",
+ "PRAGMA",
+ "CREATE",
+ "TABLE",
+ "INDEX",
+ "UPDATE",
+ "WHERE",
+ "AND",
+ "OR",
+ "JOIN",
+ "FROM",
+ "LIMIT",
+ "ORDER",
+ "BY",
+ "IN",
+ "IS",
+ "NULL",
+ "NOT",
+ "PRIMARY",
+ "KEY",
+ "FOREIGN",
+ "REFERENCES",
+ "DEFAULT",
+ "CHECK",
+ "COLLATE",
+ "UNIQUE",
+ "EXISTS",
+ "INTO",
+ "VALUES",
+ "SET",
+ "ON",
+}
+
+
+def validate_identifier(name: str) -> None:
+ """Validate a table, column, or template name.
+
+ Args:
+ name: Name to validate.
+
+ Raises:
+ ValueError: If name fails format, length, or keyword validation.
+ """
+ if not name:
+ raise ValueError("Identifier cannot be empty")
+
+ if len(name) > 64:
+ raise ValueError(f"Identifier '{name}' exceeds maximum length of 64 characters")
+
+ if not IDENTIFIER_REGEX.match(name):
+ raise ValueError(
+ f"Identifier '{name}' is invalid. "
+ "Must start with a letter or underscore and "
+ "contain only alphanumeric characters and underscores."
+ )
+
+ if name.upper() in RESERVED_KEYWORDS:
+ raise ValueError(
+ f"Identifier '{name}' is a reserved SQL keyword and cannot be used"
+ )
+
+
+def validate_column_type(col_type: str) -> None:
+ """Validate that the column type is in the strict allowlist.
+
+ Args:
+ col_type: Type string to validate.
+
+ Raises:
+ ValueError: If the type is not allowed.
+ """
+ cleaned_type = col_type.strip().upper()
+ if cleaned_type not in ALLOWED_TYPES:
+ raise ValueError(
+ f"Type '{col_type}' is not allowed. "
+ f"Must be one of: {', '.join(sorted(ALLOWED_TYPES))}"
+ )
diff --git a/src/blacki/registry.py b/src/blacki/registry.py
index c2efa51..890d026 100644
--- a/src/blacki/registry.py
+++ b/src/blacki/registry.py
@@ -54,6 +54,7 @@ def build_tools(config: ToolConfig) -> list[Any]:
tools.extend(_build_reminder_tools())
tools.extend(_build_calorie_tools())
tools.extend(_build_workout_tools())
+ tools.extend(_build_declarative_db_tools())
logger.info("Database-backed tools enabled")
if config.sandbox_enabled:
@@ -233,6 +234,33 @@ def _build_memory_tools() -> list[Any]:
return []
+def _build_declarative_db_tools() -> list[Any]:
+ """Build declarative database tools."""
+ try:
+ from blacki.declarative_db.tools import (
+ create_custom_table,
+ create_query_template,
+ delete_custom_instruction_override,
+ delete_custom_table,
+ execute_query_template,
+ list_custom_tables_and_templates,
+ set_custom_instruction_override,
+ )
+
+ return [
+ create_custom_table,
+ delete_custom_table,
+ create_query_template,
+ execute_query_template,
+ list_custom_tables_and_templates,
+ set_custom_instruction_override,
+ delete_custom_instruction_override,
+ ]
+ except ImportError as e: # pragma: no cover
+ logger.warning("Failed to load Declarative DB tools: %s", e)
+ return []
+
+
def build_tool_config_from_env() -> ToolConfig:
"""Build tool configuration from environment variables.
diff --git a/tests/declarative_db/test_declarative_db.py b/tests/declarative_db/test_declarative_db.py
new file mode 100644
index 0000000..60741f7
--- /dev/null
+++ b/tests/declarative_db/test_declarative_db.py
@@ -0,0 +1,1199 @@
+# mypy: disable-error-code="no-untyped-def"
+"""Unit and integration tests for the declarative database tool template system."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from collections.abc import AsyncIterator, Generator
+from pathlib import Path
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import aiosqlite
+import pytest
+
+from blacki.container import set_container_from_connection
+from blacki.declarative_db.plugin import DeclarativeDbPlugin
+from blacki.declarative_db.storage import (
+ SqliteDeclarativeDbStorage,
+)
+from blacki.declarative_db.tools import (
+ create_custom_table,
+ create_query_template,
+ delete_custom_instruction_override,
+ delete_custom_table,
+ execute_query_template,
+ list_custom_tables_and_templates,
+ set_custom_instruction_override,
+)
+from blacki.declarative_db.validation import validate_column_type, validate_identifier
+
+# ==============================================================================
+# 1. Validation Layer Tests
+# ==============================================================================
+
+
+class TestValidation:
+ """Test safe identifier and type validation constraints."""
+
+ def test_valid_identifiers(self) -> None:
+ """Should accept valid identifiers."""
+ validate_identifier("users")
+ validate_identifier("user_todos")
+ validate_identifier("_private_table")
+ validate_identifier("column1")
+
+ def test_invalid_identifiers_pattern(self) -> None:
+ """Should reject names with special characters or starting with numbers."""
+ with pytest.raises(ValueError, match="is invalid"):
+ validate_identifier("user-todos")
+ with pytest.raises(ValueError, match="is invalid"):
+ validate_identifier("1todos")
+ with pytest.raises(ValueError, match="is invalid"):
+ validate_identifier("user todos")
+
+ def test_invalid_identifiers_keywords(self) -> None:
+ """Should reject case-insensitive SQL reserved keywords."""
+ with pytest.raises(ValueError, match="reserved SQL keyword"):
+ validate_identifier("SELECT")
+ with pytest.raises(ValueError, match="reserved SQL keyword"):
+ validate_identifier("Drop")
+ with pytest.raises(ValueError, match="reserved SQL keyword"):
+ validate_identifier("table")
+
+ def test_invalid_identifiers_length(self) -> None:
+ """Should reject names exceeding 64 characters."""
+ long_name = "a" * 65
+ with pytest.raises(ValueError, match="exceeds maximum length"):
+ validate_identifier(long_name)
+
+ def test_invalid_identifiers_empty(self) -> None:
+ """Should reject empty names."""
+ with pytest.raises(ValueError, match="cannot be empty"):
+ validate_identifier("")
+
+ def test_valid_column_types(self) -> None:
+ """Should accept allowlisted column types."""
+ validate_column_type("TEXT")
+ validate_column_type("integer")
+ validate_column_type("REAL")
+ validate_column_type("blob")
+
+ def test_invalid_column_types(self) -> None:
+ """Should reject non-allowlisted column types."""
+ with pytest.raises(ValueError, match="is not allowed"):
+ validate_column_type("VARCHAR(50)")
+ with pytest.raises(ValueError, match="is not allowed"):
+ validate_column_type("DATETIME")
+ with pytest.raises(ValueError, match="is not allowed"):
+ validate_column_type("SERIAL")
+
+
+# ==============================================================================
+# 2. Storage Integration Tests
+# ==============================================================================
+
+
+@pytest.fixture
+async def sqlite_conn(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]:
+ """Provide a real, isolated SQLite connection with schema created."""
+ db_path = tmp_path / "test_declarative.db"
+ from blacki.storage.sqlite import create_connection
+
+ conn = await create_connection(db_path)
+ yield conn
+ await conn.close()
+
+
+@pytest.fixture
+async def storage(sqlite_conn: aiosqlite.Connection) -> SqliteDeclarativeDbStorage:
+ """Provide initialized SqliteDeclarativeDbStorage instance."""
+ lock = asyncio.Lock()
+ storage = SqliteDeclarativeDbStorage(sqlite_conn, lock)
+ await storage.initialize()
+ return storage
+
+
+class TestStorageIntegration:
+ """Integration tests for SqliteDeclarativeDbStorage."""
+
+ @pytest.mark.anyio
+ async def test_create_and_delete_custom_table(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Should create a physical custom table and metadata, and drop it properly."""
+ user_id = "user_123"
+ table_name = "items"
+ columns: list[dict[str, Any]] = [
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "title", "type": "TEXT", "not_null": True},
+ {"name": "price", "type": "REAL", "default": 0.0},
+ ]
+
+ # 1. Create table
+ await storage.create_custom_table(
+ user_id=user_id,
+ table_name=table_name,
+ columns=columns,
+ description="Item catalog",
+ )
+
+ # Verify physical table structure via PRAGMA
+ physical_name = storage._get_physical_name(user_id, table_name)
+ async with sqlite_conn.execute(
+ f"PRAGMA table_info('{physical_name}')"
+ ) as cursor:
+ rows = list(await cursor.fetchall())
+ assert len(rows) == 3
+ assert rows[0]["name"] == "id"
+ assert rows[0]["type"] == "INTEGER"
+ assert rows[0]["pk"] == 1
+ assert rows[1]["name"] == "title"
+ assert rows[1]["notnull"] == 1
+ assert rows[2]["name"] == "price"
+ assert rows[2]["dflt_value"] == "0.0"
+
+ # Verify metadata
+ metadata = await storage.list_custom_tables_and_templates(user_id)
+ assert table_name in metadata
+ assert metadata[table_name]["description"] == "Item catalog"
+ assert len(metadata[table_name]["columns"]) == 3
+
+ # 2. Delete table
+ deleted = await storage.delete_custom_table(user_id, table_name)
+ assert deleted is True
+
+ # Verify metadata is cleared
+ metadata = await storage.list_custom_tables_and_templates(user_id)
+ assert table_name not in metadata
+
+ # Verify physical table is dropped
+ drop_query = (
+ "SELECT count(*) FROM sqlite_master " # noqa: S608
+ f"WHERE type='table' AND name='{physical_name}'" # noqa: S608
+ )
+ async with sqlite_conn.execute(drop_query) as cursor:
+ row = await cursor.fetchone()
+ assert row is not None
+ assert row[0] == 0
+
+ @pytest.mark.anyio
+ async def test_table_limits_guardrails(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should block creating more than 5 tables or tables with > 15 columns."""
+ user_id = "user_guard"
+
+ # Table with too many columns
+ too_many_cols: list[dict[str, Any]] = [
+ {"name": f"col_{i}", "type": "TEXT"} for i in range(16)
+ ]
+ with pytest.raises(ValueError, match="Limit of 15 columns"):
+ await storage.create_custom_table(user_id, "big_table", too_many_cols)
+
+ # Register 5 tables
+ cols: list[dict[str, Any]] = [{"name": "id", "type": "INTEGER"}]
+ for i in range(5):
+ await storage.create_custom_table(user_id, f"table_{i}", cols)
+
+ # Attempting the 6th table should fail
+ with pytest.raises(ValueError, match="Limit of 5 custom tables"):
+ await storage.create_custom_table(user_id, "table_6", cols)
+
+ @pytest.mark.anyio
+ async def test_query_templates_crud_and_validation(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should manage and validate templates, enforcing a limit of 10."""
+ user_id = "user_templates"
+ table_name = "tasks"
+ columns: list[dict[str, Any]] = [
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "task_name", "type": "TEXT"},
+ {"name": "status", "type": "TEXT"},
+ ]
+
+ await storage.create_custom_table(user_id, table_name, columns)
+
+ # 1. Create a SELECT template
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="get_by_status",
+ table_name=table_name,
+ query_type="SELECT",
+ select_columns=["id", "task_name"],
+ filter_columns=["status"],
+ order_by_column="id",
+ order_by_direction="DESC",
+ limit_val=5,
+ description="Fetch tasks by status",
+ )
+
+ # Validate duplicate/overwrite
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="get_by_status",
+ table_name=table_name,
+ query_type="SELECT",
+ select_columns=["id", "task_name"],
+ filter_columns=["status"],
+ description="Fetch tasks by status (updated)",
+ )
+
+ # 2. Reject template with invalid query_type
+ with pytest.raises(ValueError, match="Query type must be"):
+ await storage.create_query_template(
+ user_id, "invalid_tmpl", table_name, "DROP"
+ )
+
+ # 3. Reject template with invalid sorting direction
+ with pytest.raises(ValueError, match="Sorting direction must be"):
+ await storage.create_query_template(
+ user_id,
+ "invalid_tmpl",
+ table_name,
+ "SELECT",
+ order_by_direction="OTHER",
+ )
+
+ # 4. Reject template referencing non-existent table
+ with pytest.raises(ValueError, match="does not exist"):
+ await storage.create_query_template(
+ user_id, "tmpl", "non_existent_table", "SELECT"
+ )
+
+ # 5. Reject template referencing non-existent column
+ with pytest.raises(ValueError, match="does not exist in table"):
+ await storage.create_query_template(
+ user_id,
+ "tmpl",
+ table_name,
+ "SELECT",
+ select_columns=["non_existent_col"],
+ )
+
+ # 6. Template limits guardrail (max 10)
+ for i in range(1, 10): # Already have 1, add 9 more
+ await storage.create_query_template(
+ user_id, f"tmpl_{i}", table_name, "INSERT"
+ )
+
+ with pytest.raises(ValueError, match="Limit of 10 saved query templates"):
+ await storage.create_query_template(
+ user_id, "tmpl_11", table_name, "INSERT"
+ )
+
+ @pytest.mark.anyio
+ async def test_execute_queries_and_sql_injection_protection(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should run parameterized templates and catch malicious key injections."""
+ user_id = "user_exec"
+ table_name = "notes"
+ columns: list[dict[str, Any]] = [
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "content", "type": "TEXT"},
+ {"name": "category", "type": "TEXT"},
+ ]
+
+ await storage.create_custom_table(user_id, table_name, columns)
+
+ # Create INSERT template
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="add_note",
+ table_name=table_name,
+ query_type="INSERT",
+ )
+
+ # Create SELECT template
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="get_notes",
+ table_name=table_name,
+ query_type="SELECT",
+ select_columns=["id", "content"],
+ filter_columns=["category"],
+ )
+
+ # Create UPDATE template
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="update_note",
+ table_name=table_name,
+ query_type="UPDATE",
+ filter_columns=["id"],
+ )
+
+ # Create DELETE template
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="delete_notes",
+ table_name=table_name,
+ query_type="DELETE",
+ filter_columns=["category"],
+ )
+
+ # 1. Execute INSERT
+ row_id1 = await storage.execute_query_template(
+ user_id, "add_note", {"content": "Buy milk", "category": "shopping"}
+ )
+ assert row_id1 == 1
+
+ row_id2 = await storage.execute_query_template(
+ user_id, "add_note", {"content": "Work gym", "category": "fitness"}
+ )
+ assert row_id2 == 2
+
+ # 2. Execute SELECT
+ res = await storage.execute_query_template(
+ user_id, "get_notes", {"category": "shopping"}
+ )
+ assert isinstance(res, list)
+ assert len(res) == 1
+ assert res[0]["content"] == "Buy milk"
+ assert "category" not in res[0] # select_columns only returns id and content
+
+ # 3. Execute UPDATE
+ affected_rows = await storage.execute_query_template(
+ user_id, "update_note", {"id": 1, "content": "Buy whole milk"}
+ )
+ assert affected_rows == 1
+
+ # Verify update worked
+ res = await storage.execute_query_template(
+ user_id, "get_notes", {"category": "shopping"}
+ )
+ assert isinstance(res, list)
+ assert res[0]["content"] == "Buy whole milk"
+
+ # 4. Execute DELETE
+ deleted_count = await storage.execute_query_template(
+ user_id, "delete_notes", {"category": "shopping"}
+ )
+ assert deleted_count == 1
+
+ # Verify no items left in shopping category
+ res = await storage.execute_query_template(
+ user_id, "get_notes", {"category": "shopping"}
+ )
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ # 5. Security Parameter Key Injection Guard check
+ # Attempt to pass a malicious SQL string as a parameter key.
+ mal_key = "category OR 1=1; DROP TABLE notes; --"
+ with pytest.raises(ValueError, match="is invalid|is not a valid column"):
+ await storage.execute_query_template(
+ user_id, "get_notes", {mal_key: "shopping"}
+ )
+
+ @pytest.mark.anyio
+ async def test_custom_instructions_overrides(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should save, get, and delete custom user-scoped prompt overrides."""
+ user_id = "user_override"
+
+ # Initially empty
+ val = await storage.get_custom_instruction_override(user_id)
+ assert val is None
+
+ # Set instructions
+ await storage.set_custom_instruction_override(user_id, "Tone: Sarcastic")
+ val = await storage.get_custom_instruction_override(user_id)
+ assert val == "Tone: Sarcastic"
+
+ # Clear instructions
+ deleted = await storage.delete_custom_instruction_override(user_id)
+ assert deleted is True
+ val = await storage.get_custom_instruction_override(user_id)
+ assert val is None
+
+ @pytest.mark.anyio
+ async def test_get_schema_instructions_xml(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should compile schemas and overrides into instructions XML."""
+ user_id = "user_xml"
+ assert await storage.get_schema_instructions_xml(user_id) == ""
+
+ # Set instructions override
+ await storage.set_custom_instruction_override(user_id, "Be extra kind.")
+
+ # Create table & template
+ await storage.create_custom_table(
+ user_id, "logs", [{"name": "id", "type": "INTEGER", "primary_key": True}]
+ )
+ await storage.create_query_template(user_id, "add_log", "logs", "INSERT")
+
+ xml = await storage.get_schema_instructions_xml(user_id)
+ assert "" in xml
+ assert "Be extra kind." in xml
+ assert "" in xml
+ assert "Table: logs" in xml
+ assert "id (INTEGER) PRIMARY KEY" in xml
+ assert "Template: add_log" in xml
+
+ @pytest.mark.anyio
+ async def test_create_custom_table_with_no_columns(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should raise ValueError when creating a table with no columns."""
+ with pytest.raises(ValueError, match="Table must define at least one column"):
+ await storage.create_custom_table(
+ user_id="user_test",
+ table_name="no_cols",
+ columns=[],
+ )
+
+ @pytest.mark.anyio
+ async def test_create_custom_table_already_exists(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Verify overwriting an existing table works and bypasses limit."""
+ user_id = "user_overwrite_test"
+ cols = [{"name": "id", "type": "INTEGER", "primary_key": True}]
+
+ # Create 5 tables (limit is 5)
+ for i in range(5):
+ await storage.create_custom_table(user_id, f"tbl_{i}", cols)
+
+ # Overwrite tbl_0 (bypasses limit)
+ await storage.create_custom_table(
+ user_id,
+ "tbl_0",
+ [
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "new_col", "type": "TEXT"},
+ ],
+ description="updated",
+ )
+
+ metadata = await storage.list_custom_tables_and_templates(user_id)
+ assert len(metadata["tbl_0"]["columns"]) == 2
+ assert metadata["tbl_0"]["description"] == "updated"
+
+ @pytest.mark.anyio
+ async def test_column_default_value_types(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Test different default value types in DDL construction."""
+ user_id = "user_defaults"
+ table_name = "defaults_test"
+
+ class Dummy:
+ def __str__(self) -> str:
+ return "dummy'value"
+
+ columns: list[dict[str, Any]] = [
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "bool_t", "type": "INTEGER", "default": True},
+ {"name": "bool_f", "type": "INTEGER", "default": False},
+ {"name": "complex_list", "type": "TEXT", "default": [1, 2, "three's"]},
+ {"name": "complex_dict", "type": "TEXT", "default": {"a'b": 1}},
+ {"name": "str_val", "type": "TEXT", "default": "hello'world"},
+ {"name": "fallback", "type": "TEXT", "default": 42.5},
+ {"name": "object_fallback", "type": "TEXT", "default": Dummy()},
+ ]
+
+ await storage.create_custom_table(user_id, table_name, columns)
+
+ # Insert a row with defaults
+ physical_name = storage._get_physical_name(user_id, table_name)
+ await sqlite_conn.execute(f'INSERT INTO "{physical_name}" (id) VALUES (1)') # noqa: S608
+ await sqlite_conn.commit()
+
+ # Retrieve row to verify values
+ async with sqlite_conn.execute(
+ f'SELECT * FROM "{physical_name}" WHERE id = 1' # noqa: S608
+ ) as cursor:
+ row = await cursor.fetchone()
+ assert row is not None
+ assert row["bool_t"] == 1
+ assert row["bool_f"] == 0
+ assert json.loads(row["complex_list"]) == [1, 2, "three's"]
+ assert json.loads(row["complex_dict"]) == {"a'b": 1}
+ assert row["str_val"] == "hello'world"
+ assert float(row["fallback"]) == 42.5
+ assert row["object_fallback"] == "dummy'value"
+
+ @pytest.mark.anyio
+ async def test_delete_custom_table_not_exists(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should return False if dropping a non-existent table."""
+ res = await storage.delete_custom_table("user_id", "non_existent")
+ assert res is False
+
+ @pytest.mark.anyio
+ async def test_storage_exceptions_rollback(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Verify storage exceptions during create/delete rollback transaction."""
+ user_id = "user_exceptions"
+ cols = [{"name": "id", "type": "INTEGER", "primary_key": True}]
+
+ # 1. Create table failure
+ original_execute = sqlite_conn.execute
+
+ def mock_execute(sql: str, *args, **kwargs):
+ if "INSERT INTO custom_table_columns" in sql:
+ raise Exception("forced column insert failure")
+ return original_execute(sql, *args, **kwargs)
+
+ with (
+ patch.object(sqlite_conn, "execute", side_effect=mock_execute),
+ pytest.raises(Exception, match="forced column insert failure"),
+ ):
+ await storage.create_custom_table(user_id, "fail_tbl", cols)
+
+ # Check that table does not exist in metadata
+ metadata = await storage.list_custom_tables_and_templates(user_id)
+ assert "fail_tbl" not in metadata
+
+ # 2. Delete table failure
+ await storage.create_custom_table(user_id, "success_tbl", cols)
+
+ def mock_execute_delete(sql: str, *args, **kwargs):
+ if "DELETE FROM custom_tables" in sql:
+ raise Exception("forced delete metadata failure")
+ return original_execute(sql, *args, **kwargs)
+
+ with (
+ patch.object(sqlite_conn, "execute", side_effect=mock_execute_delete),
+ pytest.raises(Exception, match="forced delete metadata failure"),
+ ):
+ await storage.delete_custom_table(user_id, "success_tbl")
+
+ # Table metadata should still exist
+ metadata = await storage.list_custom_tables_and_templates(user_id)
+ assert "success_tbl" in metadata
+
+ @pytest.mark.anyio
+ async def test_create_query_template_overwrites(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Verify limit check is bypassed when template already exists (overwrites)."""
+ user_id = "user_tmpl_limit"
+ cols = [{"name": "id", "type": "INTEGER", "primary_key": True}]
+ await storage.create_custom_table(user_id, "tbl", cols)
+
+ # Create 10 templates (the limit)
+ for i in range(10):
+ await storage.create_query_template(user_id, f"tmpl_{i}", "tbl", "SELECT")
+
+ # Try creating an 11th new template -> should fail
+ with pytest.raises(ValueError, match="Limit of 10 saved query templates"):
+ await storage.create_query_template(user_id, "tmpl_10", "tbl", "SELECT")
+
+ # Try overwriting an existing template (e.g. tmpl_0) -> should succeed
+ await storage.create_query_template(
+ user_id, "tmpl_0", "tbl", "SELECT", description="updated template"
+ )
+
+ @pytest.mark.anyio
+ async def test_create_query_template_invalid_sort_col(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Should raise ValueError if sorting column does not exist in table."""
+ user_id = "user_test"
+ cols = [{"name": "id", "type": "INTEGER", "primary_key": True}]
+ await storage.create_custom_table(user_id, "tbl", cols)
+
+ with pytest.raises(ValueError, match="does not exist in table"):
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="tmpl",
+ table_name="tbl",
+ query_type="SELECT",
+ order_by_column="non_existent",
+ )
+
+ @pytest.mark.anyio
+ async def test_execute_query_template_errors_and_edge_cases(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Verify errors and edge cases when executing query templates."""
+ await sqlite_conn.execute("PRAGMA foreign_keys = OFF")
+ user_id = "user_exec_edge"
+ table_name = "tbl"
+ cols: list[dict[str, Any]] = [
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "name", "type": "TEXT"},
+ {"name": "age", "type": "INTEGER"},
+ ]
+ await storage.create_custom_table(user_id, table_name, cols)
+
+ # 1. Execute non-existent template
+ with pytest.raises(ValueError, match="Template 'non_existent' not found"):
+ await storage.execute_query_template(user_id, "non_existent", {})
+
+ # 2. Execute where underlying table does not exist
+ await storage.create_query_template(
+ user_id, "get_val", table_name, "SELECT", select_columns=["name"]
+ )
+ # Directly delete table from custom_tables metadata table.
+ # This orphans the template (bypasses default CASCADE).
+ await sqlite_conn.execute(
+ "DELETE FROM custom_tables WHERE user_id = ? AND table_name = ?",
+ (user_id, table_name),
+ )
+ await sqlite_conn.commit()
+
+ with pytest.raises(ValueError, match="Underlying table 'tbl' does not exist"):
+ await storage.execute_query_template(user_id, "get_val", {})
+
+ # Re-create table for subsequent tests
+ await storage.create_custom_table(user_id, table_name, cols)
+
+ # 3. Parameter key not in columns
+ with pytest.raises(ValueError, match="is not a valid column"):
+ await storage.execute_query_template(user_id, "get_val", {"invalid_col": 1})
+
+ # 4. SELECT missing required filter parameter
+ await storage.create_query_template(
+ user_id, "get_with_filter", table_name, "SELECT", filter_columns=["age"]
+ )
+ with pytest.raises(ValueError, match="Missing required filter parameter 'age'"):
+ await storage.execute_query_template(user_id, "get_with_filter", {})
+
+ # 5. INSERT no insert keys in parameters
+ await storage.create_query_template(user_id, "add_val", table_name, "INSERT")
+ with pytest.raises(
+ ValueError, match="Must provide at least one valid parameter to insert"
+ ):
+ # Empty parameters
+ await storage.execute_query_template(user_id, "add_val", {})
+
+ # 6. UPDATE no update keys in parameters
+ await storage.create_query_template(
+ user_id, "update_val", table_name, "UPDATE", filter_columns=["id"]
+ )
+ with pytest.raises(ValueError, match="No columns provided to update"):
+ # Passing only filter column, nothing to UPDATE/SET
+ await storage.execute_query_template(user_id, "update_val", {"id": 1})
+
+ # 7. UPDATE missing required filter parameter
+ with pytest.raises(
+ ValueError, match="Missing required update filter parameter 'id'"
+ ):
+ await storage.execute_query_template(user_id, "update_val", {"name": "Bob"})
+
+ # 8. UPDATE with no filters (should succeed without WHERE)
+ await storage.create_query_template(user_id, "update_all", table_name, "UPDATE")
+ # Insert a row first
+ await storage.execute_query_template(
+ user_id, "add_val", {"id": 1, "name": "Alice"}
+ )
+ # Update name globally without filter
+ affected = await storage.execute_query_template(
+ user_id, "update_all", {"name": "Bob"}
+ )
+ assert affected == 1
+
+ # 9. DELETE missing required filter parameter
+ await storage.create_query_template(
+ user_id, "delete_val", table_name, "DELETE", filter_columns=["id"]
+ )
+ with pytest.raises(
+ ValueError, match="Missing required delete filter parameter 'id'"
+ ):
+ await storage.execute_query_template(user_id, "delete_val", {})
+
+ # 10. Unsupported query type
+ await storage.create_query_template(user_id, "bad_type", table_name, "SELECT")
+ await sqlite_conn.execute(
+ "UPDATE saved_query_templates SET query_type = 'INVALID' "
+ "WHERE user_id = ? AND template_name = ?",
+ (user_id, "bad_type"),
+ )
+ await sqlite_conn.commit()
+ with pytest.raises(ValueError, match="Unsupported query type: INVALID"):
+ await storage.execute_query_template(user_id, "bad_type", {})
+
+ # 11. SELECT order by direction fallback (None -> ASC) and limit
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="get_ordered",
+ table_name=table_name,
+ query_type="SELECT",
+ select_columns=["id", "name"],
+ order_by_column="name",
+ limit_val=2,
+ )
+ res_ordered = await storage.execute_query_template(user_id, "get_ordered", {})
+ assert isinstance(res_ordered, list)
+
+ # 12. DELETE with no filters (should succeed)
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="delete_all",
+ table_name=table_name,
+ query_type="DELETE",
+ )
+ deleted_cnt = await storage.execute_query_template(user_id, "delete_all", {})
+ assert isinstance(deleted_cnt, int)
+ assert deleted_cnt >= 0
+
+ @pytest.mark.anyio
+ async def test_list_schemas_orphans(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Verify list schemas handles columns/templates for non-existent tables."""
+ await sqlite_conn.execute("PRAGMA foreign_keys = OFF")
+ user_id = "user_orphans"
+ # Manually insert column metadata referencing a non-existent table
+ await sqlite_conn.execute(
+ "INSERT INTO custom_table_columns "
+ "(user_id, table_name, column_name, column_type) "
+ "VALUES (?, 'non_existent_table', 'ghost_col', 'TEXT')",
+ (user_id,),
+ )
+ # Manually insert template metadata referencing a non-existent table
+ await sqlite_conn.execute(
+ "INSERT INTO saved_query_templates "
+ "(user_id, template_name, table_name, query_type) "
+ "VALUES (?, 'ghost_tmpl', 'non_existent_table', 'SELECT')",
+ (user_id,),
+ )
+ await sqlite_conn.commit()
+
+ # Call listing
+ res = await storage.list_custom_tables_and_templates(user_id)
+ # The non_existent_table should NOT be in res because it's not in custom_tables
+ assert "non_existent_table" not in res
+
+ @pytest.mark.anyio
+ async def test_get_schema_instructions_xml_variations(
+ self, storage: SqliteDeclarativeDbStorage
+ ) -> None:
+ """Verify xml instructions formatting with different metadata states."""
+ user_id = "user_xml_vars"
+
+ # 1. Table with no description and no templates
+ await storage.create_custom_table(
+ user_id=user_id,
+ table_name="simple_table",
+ columns=[{"name": "id", "type": "INTEGER", "primary_key": True}],
+ )
+
+ xml1 = await storage.get_schema_instructions_xml(user_id)
+ assert "" in xml1
+ assert "Table: simple_table" in xml1
+ assert "Description:" not in xml1
+ assert "Saved Query Templates:" not in xml1
+
+ # 2. Template with no select_columns, filter_columns, order_by, limit
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="get_simple",
+ table_name="simple_table",
+ query_type="SELECT",
+ )
+
+ xml2 = await storage.get_schema_instructions_xml(user_id)
+ assert "Saved Query Templates:" in xml2
+ assert "Returns Columns:" not in xml2
+ assert "Required Parameters:" not in xml2
+ assert "Sorted By:" not in xml2
+ assert "Limit:" not in xml2
+
+ # 3. Table WITH description and template with optional params
+ await storage.create_custom_table(
+ user_id=user_id,
+ table_name="detailed_table",
+ columns=[
+ {"name": "id", "type": "INTEGER", "primary_key": True},
+ {"name": "col_a", "type": "TEXT"},
+ ],
+ description="Detailed table description",
+ )
+ await storage.create_query_template(
+ user_id=user_id,
+ template_name="get_detailed",
+ table_name="detailed_table",
+ query_type="SELECT",
+ select_columns=["id"],
+ filter_columns=["col_a"],
+ order_by_column="id",
+ order_by_direction="ASC",
+ limit_val=10,
+ description="Fetch detailed",
+ )
+
+ xml3 = await storage.get_schema_instructions_xml(user_id)
+ assert "Description: Detailed table description" in xml3
+ assert "Returns Columns: id" in xml3
+ assert "Required Parameters: col_a" in xml3
+ assert "Sorted By: id ASC" in xml3
+ assert "Limit: 10" in xml3
+
+ def test_get_declarative_db_storage_uninitialized(self) -> None:
+ """Should raise RuntimeError when storage is not initialized."""
+ from blacki.declarative_db.storage import get_declarative_db_storage
+
+ mock_container = MagicMock()
+ mock_container.declarative_db_storage.is_initialized = False
+
+ with (
+ patch("blacki.container.get_container", return_value=mock_container),
+ pytest.raises(RuntimeError, match="Declarative DB storage not initialized"),
+ ):
+ get_declarative_db_storage()
+
+
+# ==============================================================================
+# 3. ADK Plugin Layer Tests
+# ==============================================================================
+
+
+class TestDeclarativeDbPlugin:
+ """Test dynamic context injection plugin."""
+
+ @pytest.mark.anyio
+ async def test_plugin_appends_instructions(
+ self, storage: SqliteDeclarativeDbStorage, sqlite_conn: aiosqlite.Connection
+ ) -> None:
+ """Should compile active schema XML and safely append to instructions."""
+ # Setup container for process-wide lazy instantiation
+ container = set_container_from_connection(sqlite_conn)
+ container._declarative_db_storage = storage
+
+ plugin = DeclarativeDbPlugin()
+
+ # Build mock callback_context and llm_request
+ mock_session = MagicMock()
+ mock_session.state = {"user_id": "user_plugin_test"}
+ callback_context = MagicMock()
+ callback_context.session = mock_session
+
+ llm_request = MagicMock()
+ llm_request.append_instructions = MagicMock()
+
+ # Add data to storage to compile
+ await storage.set_custom_instruction_override(
+ "user_plugin_test", "Always speak in Spanish."
+ )
+
+ # Run callback
+ await plugin.before_model_callback(
+ callback_context=callback_context, llm_request=llm_request
+ )
+
+ # Assert compiled schema was appended to LLM instructions
+ llm_request.append_instructions.assert_called_once()
+ args = llm_request.append_instructions.call_args[0][0]
+ assert len(args) == 1
+ assert "Spanish" in args[0]
+ assert "" in args[0]
+
+ @pytest.mark.anyio
+ async def test_plugin_no_session(self) -> None:
+ """Plugin should return early if callback_context.session is None."""
+ plugin = DeclarativeDbPlugin()
+ callback_context = MagicMock()
+ callback_context.session = None
+ llm_request = MagicMock()
+
+ await plugin.before_model_callback(
+ callback_context=callback_context, llm_request=llm_request
+ )
+ llm_request.append_instructions.assert_not_called()
+
+ @pytest.mark.anyio
+ async def test_plugin_no_user_id_in_session_state(self) -> None:
+ """Plugin should return early if session has no user_id or telegram_chat_id."""
+ plugin = DeclarativeDbPlugin()
+ mock_session = MagicMock()
+ mock_session.state = {} # Empty state
+ callback_context = MagicMock()
+ callback_context.session = mock_session
+ llm_request = MagicMock()
+
+ await plugin.before_model_callback(
+ callback_context=callback_context, llm_request=llm_request
+ )
+ llm_request.append_instructions.assert_not_called()
+
+ @pytest.mark.anyio
+ async def test_plugin_empty_schema_xml(self) -> None:
+ """Plugin should not call append_instructions if schema XML is empty."""
+ plugin = DeclarativeDbPlugin()
+ mock_session = MagicMock()
+ mock_session.state = {"user_id": "test_user_empty"}
+ callback_context = MagicMock()
+ callback_context.session = mock_session
+ llm_request = MagicMock()
+
+ mock_storage = MagicMock()
+ mock_storage.get_schema_instructions_xml = AsyncMock(return_value="")
+
+ with patch(
+ "blacki.declarative_db.plugin.get_declarative_db_storage",
+ return_value=mock_storage,
+ ):
+ await plugin.before_model_callback(
+ callback_context=callback_context, llm_request=llm_request
+ )
+ llm_request.append_instructions.assert_not_called()
+
+ @pytest.mark.anyio
+ async def test_plugin_storage_exception_logged(self) -> None:
+ """Plugin should catch and log storage exceptions without crashing."""
+ plugin = DeclarativeDbPlugin()
+ mock_session = MagicMock()
+ mock_session.state = {"telegram_chat_id": "chat_123"}
+ callback_context = MagicMock()
+ callback_context.session = mock_session
+ llm_request = MagicMock()
+
+ mock_storage = MagicMock()
+ mock_storage.get_schema_instructions_xml = AsyncMock(
+ side_effect=Exception("forced db error")
+ )
+
+ with patch(
+ "blacki.declarative_db.plugin.get_declarative_db_storage",
+ return_value=mock_storage,
+ ):
+ # This should not raise an error
+ await plugin.before_model_callback(
+ callback_context=callback_context, llm_request=llm_request
+ )
+ llm_request.append_instructions.assert_not_called()
+
+
+# ==============================================================================
+# 4. Agent Tools Layer Tests
+# ==============================================================================
+
+
+class TestAgentTools:
+ """Test ADK tools proxy functions."""
+
+ @pytest.fixture(autouse=True)
+ def setup_mock_container(self) -> Generator[None, None, None]:
+ """Mock out get_declarative_db_storage singleton lookup."""
+ self.mock_storage = MagicMock()
+ patcher = MagicMock()
+ patcher.return_value = self.mock_storage
+
+ # Patch get_declarative_db_storage in the tools module
+ self.patcher = patch(
+ "blacki.declarative_db.tools.get_declarative_db_storage", patcher
+ )
+ self.patcher.start()
+ yield
+ self.patcher.stop()
+
+ @pytest.mark.anyio
+ async def test_tools_extract_user_id_and_proxy_calls(self) -> None:
+ """Agent tools should pull user_id and call equivalent storage methods."""
+ tool_context = MagicMock()
+ tool_context.user_id = None
+ tool_context.state = {"user_id": "tool_user"}
+
+ # 1. create_custom_table tool
+ self.mock_storage.create_custom_table = AsyncMock()
+ res = await create_custom_table(
+ "events",
+ [{"name": "id", "type": "INTEGER"}],
+ tool_context,
+ "History of events",
+ )
+ assert res["status"] == "success"
+ self.mock_storage.create_custom_table.assert_called_once_with(
+ user_id="tool_user",
+ table_name="events",
+ columns=[{"name": "id", "type": "INTEGER"}],
+ description="History of events",
+ )
+
+ # 2. delete_custom_table tool
+ self.mock_storage.delete_custom_table = AsyncMock(return_value=True)
+ res = await delete_custom_table("events", tool_context)
+ assert res["status"] == "success"
+ self.mock_storage.delete_custom_table.assert_called_once_with(
+ "tool_user", "events"
+ )
+
+ # 3. create_query_template tool
+ self.mock_storage.create_query_template = AsyncMock()
+ res = await create_query_template("add_event", "events", "INSERT", tool_context)
+ assert res["status"] == "success"
+ self.mock_storage.create_query_template.assert_called_once_with(
+ user_id="tool_user",
+ template_name="add_event",
+ table_name="events",
+ query_type="INSERT",
+ select_columns=None,
+ filter_columns=None,
+ order_by_column=None,
+ order_by_direction=None,
+ limit_val=None,
+ description=None,
+ )
+
+ # 4. execute_query_template tool
+ self.mock_storage.execute_query_template = AsyncMock(return_value=[{"id": 1}])
+ res = await execute_query_template("get_events", {"id": 1}, tool_context)
+ assert res["status"] == "success"
+ assert res["result"] == [{"id": 1}]
+ self.mock_storage.execute_query_template.assert_called_once_with(
+ user_id="tool_user",
+ template_name="get_events",
+ parameters={"id": 1},
+ )
+
+ # 5. list_custom_tables_and_templates tool
+ self.mock_storage.list_custom_tables_and_templates = AsyncMock(return_value={})
+ res = await list_custom_tables_and_templates(tool_context)
+ assert res["status"] == "success"
+ assert res["tables"] == {}
+ self.mock_storage.list_custom_tables_and_templates.assert_called_once_with(
+ "tool_user"
+ )
+
+ # 6. set_custom_instruction_override tool
+ self.mock_storage.set_custom_instruction_override = AsyncMock()
+ res = await set_custom_instruction_override("Fly high", tool_context)
+ assert res["status"] == "success"
+ self.mock_storage.set_custom_instruction_override.assert_called_once_with(
+ "tool_user", "Fly high"
+ )
+
+ # 7. delete_custom_instruction_override tool
+ self.mock_storage.delete_custom_instruction_override = AsyncMock(
+ return_value=True
+ )
+ res = await delete_custom_instruction_override(tool_context)
+ assert res["status"] == "success"
+ self.mock_storage.delete_custom_instruction_override.assert_called_once_with(
+ "tool_user"
+ )
+
+ @pytest.mark.anyio
+ async def test_tools_missing_user_id_errors(self) -> None:
+ """Verify tools return error when user_id is missing."""
+ bad_context = MagicMock()
+ bad_context.user_id = None
+ bad_context.state = {}
+
+ # 1. create_custom_table
+ res = await create_custom_table("tbl", [], bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ # 2. delete_custom_table
+ res = await delete_custom_table("tbl", bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ # 3. create_query_template
+ res = await create_query_template("tmpl", "tbl", "SELECT", bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ # 4. execute_query_template
+ res = await execute_query_template("tmpl", {}, bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ # 5. list_custom_tables_and_templates
+ res = await list_custom_tables_and_templates(bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ # 6. set_custom_instruction_override
+ res = await set_custom_instruction_override("instr", bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ # 7. delete_custom_instruction_override
+ res = await delete_custom_instruction_override(bad_context)
+ assert res["status"] == "error"
+ assert "user not identified" in res["message"].lower()
+
+ @pytest.mark.anyio
+ async def test_tools_exception_handling_and_errors(self) -> None:
+ """Verify that tools handle and log exceptions properly."""
+ tool_context = MagicMock()
+ tool_context.user_id = "test_user"
+ tool_context.state = {}
+
+ # 1. create_custom_table raising exception
+ self.mock_storage.create_custom_table = AsyncMock(
+ side_effect=Exception("create error")
+ )
+ res = await create_custom_table("tbl", [], tool_context)
+ assert res["status"] == "error"
+ assert "failed to create table: create error" in res["message"].lower()
+
+ # 2. delete_custom_table raising exception
+ self.mock_storage.delete_custom_table = AsyncMock(
+ side_effect=Exception("delete error")
+ )
+ res = await delete_custom_table("tbl", tool_context)
+ assert res["status"] == "error"
+ assert "failed to delete table: delete error" in res["message"].lower()
+
+ # 3. delete_custom_table not found
+ self.mock_storage.delete_custom_table = AsyncMock(return_value=False)
+ res = await delete_custom_table("tbl", tool_context)
+ assert res["status"] == "error"
+ assert "not found" in res["message"].lower()
+
+ # 4. create_query_template raising exception
+ self.mock_storage.create_query_template = AsyncMock(
+ side_effect=Exception("tmpl error")
+ )
+ res = await create_query_template("tmpl", "tbl", "SELECT", tool_context)
+ assert res["status"] == "error"
+ assert "failed to register query template: tmpl error" in res["message"].lower()
+
+ # 5. execute_query_template raising exception
+ self.mock_storage.execute_query_template = AsyncMock(
+ side_effect=Exception("exec error")
+ )
+ res = await execute_query_template("tmpl", {}, tool_context)
+ assert res["status"] == "error"
+ assert "execution failed: exec error" in res["message"].lower()
+
+ # 6. list_custom_tables_and_templates raising exception
+ self.mock_storage.list_custom_tables_and_templates = AsyncMock(
+ side_effect=Exception("list error")
+ )
+ res = await list_custom_tables_and_templates(tool_context)
+ assert res["status"] == "error"
+ assert "failed to fetch custom schemas: list error" in res["message"].lower()
+
+ # 7. set_custom_instruction_override raising exception
+ self.mock_storage.set_custom_instruction_override = AsyncMock(
+ side_effect=Exception("set override error")
+ )
+ res = await set_custom_instruction_override("instr", tool_context)
+ assert res["status"] == "error"
+ assert (
+ "failed to save instructions: set override error" in res["message"].lower()
+ )
+
+ # 8. delete_custom_instruction_override raising exception
+ self.mock_storage.delete_custom_instruction_override = AsyncMock(
+ side_effect=Exception("delete override error")
+ )
+ res = await delete_custom_instruction_override(tool_context)
+ assert res["status"] == "error"
+ assert (
+ "failed to clear instructions: delete override error"
+ in res["message"].lower()
+ )
+
+ # 9. delete_custom_instruction_override not found (returns False)
+ self.mock_storage.delete_custom_instruction_override = AsyncMock(
+ return_value=False
+ )
+ res = await delete_custom_instruction_override(tool_context)
+ assert res["status"] == "success"
+ assert "no custom instructions existed to delete" in res["message"].lower()
diff --git a/tests/test_container.py b/tests/test_container.py
index 25a4775..c4a678c 100644
--- a/tests/test_container.py
+++ b/tests/test_container.py
@@ -195,6 +195,15 @@ async def test_preferences_storage_property(self, conn, lock) -> None:
assert storage is not None
assert container._preferences_storage is storage
+ @pytest.mark.asyncio
+ async def test_declarative_db_storage_property(self, conn, lock) -> None:
+ """Should lazily instantiate declarative DB storage."""
+ container = AppContainer(conn=conn, _lock=lock)
+
+ storage = container.declarative_db_storage
+ assert storage is not None
+ assert container._declarative_db_storage is storage
+
@pytest.mark.asyncio
async def test_close_closes_connection_and_storages(self, conn, lock) -> None:
"""Should close connection and all storage instances."""
@@ -203,9 +212,13 @@ async def test_close_closes_connection_and_storages(self, conn, lock) -> None:
reminder = container.reminder_storage
reminder.close = AsyncMock()
+ declarative_db = container.declarative_db_storage
+ declarative_db.close = AsyncMock()
+
await container.close()
reminder.close.assert_called_once()
+ declarative_db.close.assert_called_once()
@pytest.mark.asyncio
async def test_close_storages_resets_references(self, conn, lock) -> None:
@@ -216,6 +229,7 @@ async def test_close_storages_resets_references(self, conn, lock) -> None:
_ = container.calorie_storage
_ = container.workout_storage
_ = container.preferences_storage
+ _ = container.declarative_db_storage
await container._close_storages()
@@ -223,6 +237,7 @@ async def test_close_storages_resets_references(self, conn, lock) -> None:
assert container._calorie_storage is None
assert container._workout_storage is None
assert container._preferences_storage is None
+ assert container._declarative_db_storage is None
@pytest.mark.asyncio
async def test_close_storages_partial(self, conn, lock) -> None:
@@ -232,6 +247,7 @@ async def test_close_storages_partial(self, conn, lock) -> None:
_ = container.calorie_storage
_ = container.workout_storage
_ = container.preferences_storage
+ _ = container.declarative_db_storage
await container._close_storages()
@@ -239,6 +255,9 @@ async def test_close_storages_partial(self, conn, lock) -> None:
assert container._calorie_storage is None
assert container._workout_storage is None
assert container._preferences_storage is None
+ assert container._declarative_db_storage is None
+ assert container._workout_storage is None
+ assert container._preferences_storage is None
@pytest.mark.asyncio
async def test_initialize_all_storages(self, conn, lock) -> None:
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 1b6d5a5..de4bf49 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -264,3 +264,15 @@ def test_returns_empty_for_nonexistent_dir(self) -> None:
tools = _build_skill_tools(Path("/nonexistent/skills"))
assert tools == []
+
+
+class TestBuildDeclarativeDbTools:
+ """Tests for _build_declarative_db_tools."""
+
+ def test_returns_tools_when_available(self) -> None:
+ """Should return declarative database tools when available."""
+ from blacki.registry import _build_declarative_db_tools
+
+ tools = _build_declarative_db_tools()
+
+ assert len(tools) == 7