Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions TAP/configparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ def __init__(self, path, **kwargs):
self.connectInfo['port'] = self.dbport
self.connectInfo['socket'] = self.socket
self.connectInfo['dbschema'] = self.dbschema
self.connectInfo['accesstbl'] = self.accesstbl
self.connectInfo['usertbl'] = self.usertbl

if self.debug:
logging.debug('')
Expand Down
21 changes: 18 additions & 3 deletions TAP/datadictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,24 @@ def __init__(self, conn, table, connectInfo, **kwargs):

if self.ddtbl == None:

sql = "select * from " + self.connectInfo["tap_schema"] + "." + self.connectInfo["columns_table"] + " where lower(table_name) = " + \
"'" + self.dbtable + "'"
# Detect placeholder style from connection type
conn_type = self.conn.__class__.__module__
if 'cx_Oracle' in conn_type or 'oracledb' in conn_type:
placeholder = ':1'
elif 'psycopg2' in conn_type:
placeholder = '%s'
elif 'mysql' in conn_type:
placeholder = '%s'
else:
placeholder = '?'

sql = "select * from " + self.connectInfo["tap_schema"] + "." + self.connectInfo["columns_table"] + " where lower(table_name) = " \
+ placeholder

if self.debug:
logging.debug('')
logging.debug(f'TAP_SCHEMA sql = {sql:s}')
logging.debug(f' param = {self.dbtable:s}')

else:

Expand All @@ -187,7 +199,10 @@ def __init__(self, conn, table, connectInfo, **kwargs):
logging.debug(f'Internal DD table sql = {sql:s}')

try:
cursor.execute(sql)
if self.ddtbl is None:
cursor.execute(sql, (self.dbtable,))
else:
cursor.execute(sql)

except Exception as e:

Expand Down
39 changes: 39 additions & 0 deletions TAP/propfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from TAP.writeresult import writeResult
from TAP.datadictionary import dataDictionary
from TAP.tablenames import TableNames
from TAP.tablevalidator import TableValidator


class propFilter:
Expand Down Expand Up @@ -558,6 +559,44 @@ def __init__(self, **kwargs):
logging.debug('')
logging.debug(f'dbtable = [{self.dbtable:s}]')

#
# Defense-in-depth: validate tables against TAP_SCHEMA.
#
# Exclude server-configured internal tables (access control
# tables used by propfilter) from validation. These are NOT
# in TAP_SCHEMA.tables so they would be rejected, but
# propfilter needs them for proprietary-data filtering.
# Users still cannot query them directly: any query that
# enters through the non-propfilter path (tapQuery) will
# reject them because they are absent from TAP_SCHEMA.
#

internal_tables = set()
if self.accesstbl:
internal_tables.add(self.accesstbl.lower())
if self.usertbl:
internal_tables.add(self.usertbl.lower())

user_tables = [t for t in tables
if t.lower() not in internal_tables]

if user_tables:
try:
validator = TableValidator(self.conn,
connectInfo=self.connectInfo,
debug=self.debug)
validator.validate(user_tables)

except Exception as e:

if self.debug:
logging.debug('')
logging.debug(
f'Table validation exception: {str(e):s}')

self.msg = str(e)
raise Exception(self.msg)

#
# Parse query: to extract query pieces for propfilter
#
Expand Down
3 changes: 2 additions & 1 deletion TAP/tablenames.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def extract_from_part(self, parsed):
for x in self.extract_from_part(item):
yield x
elif item.ttype is Keyword and item.value.upper() in \
['ORDER', 'GROUP', 'BY', 'HAVING', 'GROUP BY']:
['ORDER', 'ORDER BY', 'GROUP', 'GROUP BY',
'BY', 'HAVING', 'LIMIT', 'OFFSET']:
from_seen = False
StopIteration
else:
Expand Down
99 changes: 99 additions & 0 deletions TAP/tablevalidator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2020, Caltech IPAC.
# This code is released with a BSD 3-clause license. License information is at
# https://github.com/Caltech-IPAC/nexsciTAP/blob/master/LICENSE


import logging


class TableValidator:
"""
Validates that table names in an ADQL query are registered in
the TAP_SCHEMA tables table, preventing access to unauthorized
database objects.
"""

def __init__(self, conn, connectInfo=None, debug=0):

self.conn = conn
self.debug = debug

self.tap_schema = 'TAP_SCHEMA'
self.tables_table = 'tables'

if connectInfo is not None:
if 'tap_schema' in connectInfo:
self.tap_schema = connectInfo['tap_schema']
if 'tables_table' in connectInfo:
self.tables_table = connectInfo['tables_table']

self.allowed_tables = set()
self.allowed_bare = set()
self.allowed_schemas = set()

self._load_allowed_tables()

def _load_allowed_tables(self):

cursor = self.conn.cursor()

tables_ref = self.tap_schema + '.' + self.tables_table

cursor.execute('SELECT table_name FROM ' + tables_ref)
rows = cursor.fetchall()
cursor.close()

for row in rows:
full_name = row[0].strip().lower()
self.allowed_tables.add(full_name)

if '.' in full_name:
schema, bare = full_name.split('.', 1)
self.allowed_bare.add(bare)
self.allowed_schemas.add(schema)
else:
self.allowed_bare.add(full_name)

if self.debug:
logging.debug('')
logging.debug(
f'TableValidator: loaded {len(self.allowed_tables)} '
f'allowed tables: {self.allowed_tables}')

def validate(self, table_names):

if not table_names:
raise Exception('No table names to validate.')

for tname in table_names:
tname_lower = tname.strip().lower()

# Exact match against full table names (e.g. "tap_schema.columns")
if tname_lower in self.allowed_tables:
continue

if '.' in tname_lower:
schema, bare = tname_lower.split('.', 1)

# Schema-qualified query table: only match if the schema
# is one we know about AND the bare name is allowed.
# This prevents "information_schema.tables" from matching
# just because "tables" is a bare name in TAP_SCHEMA.
if schema in self.allowed_schemas and \
bare in self.allowed_bare:
continue
else:
# Unqualified query table: bare-name match is fine
# (e.g. query says "columns", whitelist has
# "tap_schema.columns")
if tname_lower in self.allowed_bare:
continue

raise Exception(
f'Table \'{tname}\' is not available for querying. '
f'Use TAP_SCHEMA.tables to see available tables.')

if self.debug:
logging.debug('')
logging.debug(
f'TableValidator: all tables validated: {table_names}')
40 changes: 40 additions & 0 deletions TAP/tapquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from TAP.datadictionary import dataDictionary
from TAP.writeresult import writeResult
from TAP.tablenames import TableNames
from TAP.tablevalidator import TableValidator
from TAP.configparam import configParam

from ADQL.adql import ADQL
Expand Down Expand Up @@ -563,6 +564,45 @@ def __init__(self, **kwargs):
if self.debug:
logging.debug(f'dbtable= [{self.dbtable:s}]')

#
# Defense-in-depth: validate tables against TAP_SCHEMA.
#
# Exclude server-configured internal tables (access control
# tables used by propfilter) from validation. These are NOT
# in TAP_SCHEMA.tables so they would be rejected, but the
# system needs them for proprietary-data filtering.
#

if self.tap_schema.lower() != 'none':

internal_tables = set()
accesstbl = self.connectInfo.get('accesstbl', '')
usertbl = self.connectInfo.get('usertbl', '')
if accesstbl:
internal_tables.add(accesstbl.lower())
if usertbl:
internal_tables.add(usertbl.lower())

user_tables = [t for t in tables
if t.lower() not in internal_tables]

if user_tables:
try:
validator = TableValidator(
self.conn,
connectInfo=self.connectInfo,
debug=self.debug)
validator.validate(user_tables)

except Exception as e:

if self.debug:
logging.debug('')
logging.debug(
f'Table validation exception: {str(e):s}')

self.msg = str(e)
raise Exception(self.msg)

#
# Retrieve dd table
Expand Down
Empty file added tests/__init__.py
Empty file.
121 changes: 121 additions & 0 deletions tests/test_tablevalidator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
import sqlite3
import tempfile
import unittest

from TAP.tablevalidator import TableValidator


def _make_db(table_names, schema_name='TAP_SCHEMA', tables_table='tables'):
"""Create SQLite DBs with TAP_SCHEMA attached, mimicking real setup.

Returns (conn, tap_schema_path) — caller is responsible for cleanup.
"""
fd, tap_schema_path = tempfile.mkstemp(suffix='.db')
os.close(fd)

schema_conn = sqlite3.connect(tap_schema_path)
schema_conn.execute('CREATE TABLE ' + tables_table + ' (table_name TEXT)')
for name in table_names:
schema_conn.execute(
'INSERT INTO ' + tables_table + ' VALUES (?)', (name,))
schema_conn.commit()
schema_conn.close()

conn = sqlite3.connect(':memory:')
conn.execute(
'ATTACH DATABASE ? AS ' + schema_name, (tap_schema_path,))

return conn, tap_schema_path


class TestTableValidator(unittest.TestCase):

def setUp(self):
self.conn, self._tap_schema_path = _make_db([
'ps',
'pscomppars',
'stellarhosts',
'TAP_SCHEMA.tables',
'TAP_SCHEMA.columns',
'TAP_SCHEMA.schemas',
'cumulative',
])
self.connectInfo = {
'tap_schema': 'TAP_SCHEMA',
'tables_table': 'tables',
}

def tearDown(self):
self.conn.close()
if os.path.exists(self._tap_schema_path):
os.unlink(self._tap_schema_path)

def test_exact_match(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
v.validate(['ps']) # should not raise

def test_default_connectInfo(self):
"""Works with no connectInfo (uses TAP_SCHEMA.tables default)."""
v = TableValidator(self.conn)
v.validate(['ps'])

def test_case_insensitive(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
v.validate(['PS'])
v.validate(['Ps'])
v.validate(['TAP_SCHEMA.Tables'])

def test_schema_prefix_in_whitelist_bare_in_query(self):
"""TAP_SCHEMA.columns is whitelisted; query says just 'columns'."""
v = TableValidator(self.conn, connectInfo=self.connectInfo)
v.validate(['columns'])

def test_bare_in_whitelist_unknown_schema_in_query(self):
"""'ps' is whitelisted bare; 'public.ps' has unknown schema — rejected."""
v = TableValidator(self.conn, connectInfo=self.connectInfo)
with self.assertRaises(Exception):
v.validate(['public.ps'])

def test_known_schema_bare_table_in_query(self):
"""'ps' is whitelisted bare; 'tap_schema.ps' uses known schema — allowed."""
v = TableValidator(self.conn, connectInfo=self.connectInfo)
v.validate(['tap_schema.ps'])

def test_disallowed_table_raises(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
with self.assertRaises(Exception) as ctx:
v.validate(['pg_catalog.pg_tables'])
self.assertIn('not available', str(ctx.exception))

def test_multi_table_all_valid(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
v.validate(['ps', 'pscomppars', 'stellarhosts'])

def test_multi_table_one_invalid(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
with self.assertRaises(Exception) as ctx:
v.validate(['ps', 'information_schema.tables', 'stellarhosts'])
self.assertIn('information_schema.tables', str(ctx.exception))

def test_empty_table_list_raises(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
with self.assertRaises(Exception):
v.validate([])

def test_system_catalog_blocked(self):
v = TableValidator(self.conn, connectInfo=self.connectInfo)
for bad_table in [
'ALL_TABLES',
'DBA_USERS',
'V$SESSION',
'information_schema.tables',
'pg_catalog.pg_class',
'EXOFOP.FILES',
]:
with self.assertRaises(Exception, msg=f'{bad_table} should be blocked'):
v.validate([bad_table])


if __name__ == '__main__':
unittest.main()