diff --git a/TAP/configparam.py b/TAP/configparam.py index 58518b9..5f533b7 100644 --- a/TAP/configparam.py +++ b/TAP/configparam.py @@ -649,6 +649,8 @@ def __init__(self, path, **kwargs): self.connectInfo['port'] = self.dbport self.connectInfo['socket'] = self.socket self.connectInfo['dbschema'] = self.dbschema + self.connectInfo['accesstbl'] = self.accesstbl + self.connectInfo['usertbl'] = self.usertbl if self.debug: logging.debug('') diff --git a/TAP/datadictionary.py b/TAP/datadictionary.py index a427af9..3f556d7 100644 --- a/TAP/datadictionary.py +++ b/TAP/datadictionary.py @@ -171,12 +171,24 @@ def __init__(self, conn, table, connectInfo, **kwargs): if self.ddtbl == None: - sql = "select * from " + self.connectInfo["tap_schema"] + "." + self.connectInfo["columns_table"] + " where lower(table_name) = " + \ - "'" + self.dbtable + "'" + # Detect placeholder style from connection type + conn_type = self.conn.__class__.__module__ + if 'cx_Oracle' in conn_type or 'oracledb' in conn_type: + placeholder = ':1' + elif 'psycopg2' in conn_type: + placeholder = '%s' + elif 'mysql' in conn_type: + placeholder = '%s' + else: + placeholder = '?' + + sql = "select * from " + self.connectInfo["tap_schema"] + "." 
+ self.connectInfo["columns_table"] + " where lower(table_name) = " \ + + placeholder if self.debug: logging.debug('') logging.debug(f'TAP_SCHEMA sql = {sql:s}') + logging.debug(f' param = {self.dbtable:s}') else: @@ -187,7 +199,10 @@ def __init__(self, conn, table, connectInfo, **kwargs): logging.debug(f'Internal DD table sql = {sql:s}') try: - cursor.execute(sql) + if self.ddtbl is None: + cursor.execute(sql, (self.dbtable,)) + else: + cursor.execute(sql) except Exception as e: diff --git a/TAP/propfilter.py b/TAP/propfilter.py index 436e105..3f9a0fc 100644 --- a/TAP/propfilter.py +++ b/TAP/propfilter.py @@ -12,6 +12,7 @@ from TAP.writeresult import writeResult from TAP.datadictionary import dataDictionary from TAP.tablenames import TableNames +from TAP.tablevalidator import TableValidator class propFilter: @@ -558,6 +559,44 @@ def __init__(self, **kwargs): logging.debug('') logging.debug(f'dbtable = [{self.dbtable:s}]') + # + # Defense-in-depth: validate tables against TAP_SCHEMA. + # + # Exclude server-configured internal tables (access control + # tables used by propfilter) from validation. These are NOT + # in TAP_SCHEMA.tables so they would be rejected, but + # propfilter needs them for proprietary-data filtering. + # Users still cannot query them directly: any query that + # enters through the non-propfilter path (tapQuery) will + # reject them because they are absent from TAP_SCHEMA. 
class TableValidator:
    """
    Validates that table names in an ADQL query are registered in
    the TAP_SCHEMA tables table, preventing access to unauthorized
    database objects.

    The whitelist is read once, at construction time, from
    ``<tap_schema>.<tables_table>`` (default ``TAP_SCHEMA.tables``) and
    kept in three lower-cased sets:

    allowed_tables  -- names exactly as registered (maybe schema-qualified)
    allowed_bare    -- the part after the first '.', or the whole name
    allowed_schemas -- the part before the first '.', when there is one

    Parameters
    ----------
    conn : DB-API style connection; only ``conn.cursor()`` is used.
    connectInfo : optional mapping; the keys 'tap_schema' and
        'tables_table' override the defaults when present.
    debug : when truthy, progress is written with ``logging.debug``.
    """

    def __init__(self, conn, connectInfo=None, debug=0):

        self.conn = conn
        self.debug = debug

        self.tap_schema = 'TAP_SCHEMA'
        self.tables_table = 'tables'

        if connectInfo is not None:
            if 'tap_schema' in connectInfo:
                self.tap_schema = connectInfo['tap_schema']
            if 'tables_table' in connectInfo:
                self.tables_table = connectInfo['tables_table']

        self.allowed_tables = set()
        self.allowed_bare = set()
        self.allowed_schemas = set()

        self._load_allowed_tables()

    def _load_allowed_tables(self):
        """Fill the three whitelist sets from the TAP_SCHEMA tables table.

        Any exception raised by the driver propagates to the caller; the
        cursor is closed either way.
        """

        # The schema/table identifiers come from server configuration
        # (connectInfo), not from the user's query, and identifiers
        # cannot be bound as parameters, hence the concatenation.
        tables_ref = self.tap_schema + '.' + self.tables_table

        cursor = self.conn.cursor()
        try:
            cursor.execute('SELECT table_name FROM ' + tables_ref)
            rows = cursor.fetchall()
        finally:
            # Release the cursor even when the query fails.
            cursor.close()

        for row in rows:
            raw_name = row[0]

            # A NULL table_name carries no information; skip it rather
            # than crash on None.strip().
            if raw_name is None:
                continue

            full_name = raw_name.strip().lower()

            # A blank entry must never whitelist the empty name.
            if not full_name:
                continue

            self.allowed_tables.add(full_name)

            if '.' in full_name:
                schema, bare = full_name.split('.', 1)
                # Only non-empty parts are recorded so that a malformed
                # row such as '.x' or 'x.' cannot whitelist ''.
                if bare:
                    self.allowed_bare.add(bare)
                if schema:
                    self.allowed_schemas.add(schema)
            else:
                self.allowed_bare.add(full_name)

        if self.debug:
            logging.debug('')
            logging.debug(
                f'TableValidator: loaded {len(self.allowed_tables)} '
                f'allowed tables: {self.allowed_tables}')

    def _is_allowed(self, tname):
        """Return True when a single table name passes the whitelist."""

        if not isinstance(tname, str):
            return False

        tname_lower = tname.strip().lower()

        # Exact match against full table names (e.g. "tap_schema.columns")
        if tname_lower in self.allowed_tables:
            return True

        if '.' in tname_lower:
            schema, bare = tname_lower.split('.', 1)

            # Schema-qualified query table: only match if the schema
            # is one we know about AND the bare name is allowed.
            # This prevents "information_schema.tables" from matching
            # just because "tables" is a bare name in TAP_SCHEMA.
            return schema in self.allowed_schemas and \
                bare in self.allowed_bare

        # Unqualified query table: bare-name match is fine
        # (e.g. query says "columns", whitelist has "tap_schema.columns")
        return tname_lower in self.allowed_bare

    def validate(self, table_names):
        """Check every name in table_names against the whitelist.

        Parameters
        ----------
        table_names : iterable of table-name strings taken from the
            query; a single string is treated as a one-element list.

        Raises
        ------
        Exception
            If table_names is empty, or at the first name that is not
            available for querying (the message names that table).
        """

        if not table_names:
            raise Exception('No table names to validate.')

        # Iterating a bare string would validate it one character at a
        # time; treat it as a single table name instead.
        if isinstance(table_names, str):
            table_names = [table_names]

        for tname in table_names:
            if self._is_allowed(tname):
                continue

            raise Exception(
                f'Table \'{tname}\' is not available for querying. '
                f'Use TAP_SCHEMA.tables to see available tables.')

        if self.debug:
            logging.debug('')
            logging.debug(
                f'TableValidator: all tables validated: {table_names}')
def _make_db(table_names, schema_name='TAP_SCHEMA', tables_table='tables'):
    """Create SQLite DBs with TAP_SCHEMA attached, mimicking real setup.

    A temporary on-disk database holding a single ``tables_table`` table
    (one ``table_name`` TEXT column, one row per entry of table_names) is
    attached under ``schema_name`` to a fresh in-memory connection.

    Returns (conn, tap_schema_path) — caller is responsible for cleanup.
    If any step fails, the connections are closed and the temporary file
    is removed before the exception propagates.
    """
    fd, tap_schema_path = tempfile.mkstemp(suffix='.db')
    os.close(fd)

    conn = None
    try:
        schema_conn = sqlite3.connect(tap_schema_path)
        try:
            schema_conn.execute(
                'CREATE TABLE ' + tables_table + ' (table_name TEXT)')
            schema_conn.executemany(
                'INSERT INTO ' + tables_table + ' VALUES (?)',
                [(name,) for name in table_names])
            schema_conn.commit()
        finally:
            schema_conn.close()

        conn = sqlite3.connect(':memory:')
        conn.execute(
            'ATTACH DATABASE ? AS ' + schema_name, (tap_schema_path,))
    except Exception:
        # Do not leave a stray temp file or open handle behind.
        if conn is not None:
            conn.close()
        if os.path.exists(tap_schema_path):
            os.unlink(tap_schema_path)
        raise

    return conn, tap_schema_path


class TestTableValidator(unittest.TestCase):
    """Unit tests for TableValidator against a throwaway SQLite whitelist."""

    def setUp(self):
        self.conn, self._tap_schema_path = _make_db([
            'ps',
            'pscomppars',
            'stellarhosts',
            'TAP_SCHEMA.tables',
            'TAP_SCHEMA.columns',
            'TAP_SCHEMA.schemas',
            'cumulative',
        ])
        self.connectInfo = {
            'tap_schema': 'TAP_SCHEMA',
            'tables_table': 'tables',
        }

    def tearDown(self):
        self.conn.close()
        if os.path.exists(self._tap_schema_path):
            os.unlink(self._tap_schema_path)

    def test_exact_match(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        v.validate(['ps'])  # should not raise

    def test_default_connectInfo(self):
        """Works with no connectInfo (uses TAP_SCHEMA.tables default)."""
        v = TableValidator(self.conn)
        v.validate(['ps'])

    def test_case_insensitive(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        v.validate(['PS'])
        v.validate(['Ps'])
        v.validate(['TAP_SCHEMA.Tables'])

    def test_schema_prefix_in_whitelist_bare_in_query(self):
        """TAP_SCHEMA.columns is whitelisted; query says just 'columns'."""
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        v.validate(['columns'])

    def test_bare_in_whitelist_unknown_schema_in_query(self):
        """'ps' is whitelisted bare; 'public.ps' has unknown schema — rejected."""
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        # Match on the message so an unrelated error cannot pass the test.
        with self.assertRaisesRegex(Exception, 'not available'):
            v.validate(['public.ps'])

    def test_known_schema_bare_table_in_query(self):
        """'ps' is whitelisted bare; 'tap_schema.ps' uses known schema — allowed."""
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        v.validate(['tap_schema.ps'])

    def test_disallowed_table_raises(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        with self.assertRaises(Exception) as ctx:
            v.validate(['pg_catalog.pg_tables'])
        self.assertIn('not available', str(ctx.exception))

    def test_multi_table_all_valid(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        v.validate(['ps', 'pscomppars', 'stellarhosts'])

    def test_multi_table_one_invalid(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        with self.assertRaises(Exception) as ctx:
            v.validate(['ps', 'information_schema.tables', 'stellarhosts'])
        self.assertIn('information_schema.tables', str(ctx.exception))

    def test_empty_table_list_raises(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        with self.assertRaisesRegex(Exception, 'No table names'):
            v.validate([])

    def test_system_catalog_blocked(self):
        v = TableValidator(self.conn, connectInfo=self.connectInfo)
        for bad_table in [
            'ALL_TABLES',
            'DBA_USERS',
            'V$SESSION',
            'information_schema.tables',
            'pg_catalog.pg_class',
            'EXOFOP.FILES',
        ]:
            # subTest reports every offending table instead of stopping
            # at the first one.
            with self.subTest(table=bad_table):
                with self.assertRaisesRegex(Exception, 'not available'):
                    v.validate([bad_table])


if __name__ == '__main__':
    unittest.main()