Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/xml2db/dialect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,22 @@ def validate_model_config(self, config: dict) -> dict:
"Clustered columnstore indexes are only supported with MS SQL Server database, noop"
)
return config

# ------------------------------------------------------------------
# Data loading
# ------------------------------------------------------------------

def bulk_insert(self, conn: Any, table: Any, records: list) -> None:
"""Insert records into a staging table.

The base implementation uses SQLAlchemy's parameterised executemany,
which is backend-agnostic. Subclasses may override this with a
backend-specific bulk-loading strategy (e.g. COPY FROM CSV).

Args:
conn: A SQLAlchemy ``Connection`` already within a transaction.
table: The SQLAlchemy ``Table`` object to insert into.
records: A list of dicts mapping column keys to Python values.
"""
if records:
conn.execute(table.insert(), records)
118 changes: 117 additions & 1 deletion src/xml2db/dialect/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import csv
import os
import tempfile
from typing import Any

from sqlalchemy import Column, Integer, Sequence
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
Double,
Integer,
LargeBinary,
Sequence,
SmallInteger,
text,
)
from sqlalchemy.exc import ProgrammingError
import sqlalchemy.schema

Expand Down Expand Up @@ -48,3 +62,105 @@ def do_create() -> None:
do_create()
except ProgrammingError:
pass

# Maps SQLAlchemy column types to DuckDB CAST target type names.
# String types need no cast; LargeBinary is handled via unhex().
# Order matters: subclasses (BigInteger, SmallInteger) must appear before
# their parent (Integer) so that isinstance() matches the most specific type.
_DUCKDB_CAST: dict = {
BigInteger: "BIGINT",
SmallInteger: "SMALLINT",
Integer: "INTEGER",
Double: "DOUBLE",
Boolean: "BOOLEAN",
DateTime: "TIMESTAMPTZ", # DateTime(timezone=False) → TIMESTAMP below
}

def _select_expr(self, key: str, col: Any) -> str:
"""Return a DuckDB SELECT expression that casts a VARCHAR CSV column."""
if isinstance(col.type, LargeBinary):
return f'unhex("{key}")'
for sa_type, duckdb_type in self._DUCKDB_CAST.items():
if isinstance(col.type, sa_type):
if isinstance(col.type, DateTime) and not col.type.timezone:
duckdb_type = "TIMESTAMP"
return f'CAST("{key}" AS {duckdb_type})'
return f'"{key}"' # String / unknown: keep as VARCHAR

def bulk_insert(self, conn: Any, table: Any, records: list) -> None:
"""Bulk-insert records via a temporary CSV file and DuckDB's ``read_csv``.

All CSV columns are read as VARCHAR (``all_varchar=true``) and then
explicitly cast to their target types in the ``SELECT`` clause.
Binary columns are hex-encoded in the CSV and decoded with ``unhex()``.

Args:
conn: A SQLAlchemy ``Connection`` already within a transaction.
table: The SQLAlchemy ``Table`` object to insert into.
records: A list of dicts mapping column keys to Python values.
"""
if not records:
return

# Map column key -> SQLAlchemy Column object
col_by_key = {col.key: col for col in table.columns}

# Columns present in the first record that correspond to table columns
col_keys = [k for k in records[0] if k in col_by_key]

# SQLAlchemy Python-side scalar defaults (e.g. default=False on temp_exists)
# are applied automatically by executemany but not by our CSV path.
extra_defaults: dict = {}
for col in table.columns:
if col.key not in records[0] and col.key in col_by_key:
d = col.default
if d is not None and d.is_scalar:
extra_defaults[col.key] = d.arg

all_col_keys = col_keys + list(extra_defaults.keys())

fd, csv_path = tempfile.mkstemp(suffix=".csv")
try:
with os.fdopen(fd, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow(all_col_keys)
for record in records:
row = []
for key in all_col_keys:
v = record.get(key) if key in col_keys else extra_defaults[key]
if v is None:
row.append("")
elif isinstance(v, bytes):
row.append(v.hex())
elif isinstance(v, bool):
# Must come before the general str() path since bool is a
# subclass of int, and csv.writer would write 0/1 otherwise.
row.append("true" if v else "false")
else:
# str() on datetime gives "YYYY-MM-DD HH:MM:SS[.f][+HH:MM]",
# which DuckDB's CAST accepts without ambiguity.
row.append(str(v))
writer.writerow(row)

full_name = (
f'"{table.schema}"."{table.name}"'
if table.schema
else f'"{table.name}"'
)
insert_cols = ", ".join(
f'"{col_by_key[k].name}"' for k in all_col_keys
)
select_exprs = ", ".join(
self._select_expr(k, col_by_key[k]) for k in all_col_keys
)
# DuckDB requires forward slashes in file paths on all platforms.
safe_path = csv_path.replace("\\", "/")
sql = text(
f"INSERT INTO {full_name} ({insert_cols}) "
f"SELECT {select_exprs} "
f"FROM read_csv('{safe_path}', header=true, nullstr='', all_varchar=true)"
)
conn.execute(sql)
finally:
if os.path.exists(csv_path):
os.unlink(csv_path)
6 changes: 5 additions & 1 deletion src/xml2db/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,11 @@ def insert_into_temp_tables(self, max_lines: int = -1) -> None:
start_idx = 0
while start_idx < len(data):
with self.model.engine.begin() as conn:
conn.execute(query, data[start_idx : (start_idx + max_lines)])
self.model.dialect.bulk_insert(
conn,
query.table,
data[start_idx : (start_idx + max_lines)],
)
start_idx = start_idx + max_lines

def merge_into_target_tables(self, single_transaction: bool = True) -> int:
Expand Down
178 changes: 178 additions & 0 deletions tests/test_bulk_insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""Unit tests for dialect bulk_insert implementations."""
import datetime

import pytest

pytest.importorskip("duckdb", reason="duckdb not installed")

from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
Double,
Integer,
LargeBinary,
MetaData,
SmallInteger,
String,
Table,
create_engine,
select,
text,
)

from xml2db.dialect.base import DatabaseDialect
from xml2db.dialect.duckdb import DuckDBDialect


@pytest.fixture()
def duckdb_engine():
return create_engine("duckdb:///:memory:")


def _make_table(engine, name, *extra_cols):
"""Create a simple test table and return the SQLAlchemy Table object."""
meta = MetaData()
table = Table(
name,
meta,
Column("id", Integer, key="id"),
Column("label", String(100), key="label"),
*extra_cols,
)
meta.create_all(engine)
return table


def _roundtrip(engine, table, records):
"""Insert records via DuckDBDialect.bulk_insert and read them back."""
dialect = DuckDBDialect()
with engine.begin() as conn:
dialect.bulk_insert(conn, table, records)
with engine.connect() as conn:
return conn.execute(select(table)).mappings().all()


# ---------------------------------------------------------------------------
# Base dialect falls back to SQLAlchemy executemany
# ---------------------------------------------------------------------------


def test_base_dialect_bulk_insert(duckdb_engine):
table = _make_table(duckdb_engine, "base_test")
records = [{"id": 1, "label": "hello"}, {"id": 2, "label": "world"}]
DatabaseDialect().bulk_insert(
duckdb_engine.connect().__enter__(), table, records
)
# Just check the method is importable and has the right signature.


# ---------------------------------------------------------------------------
# DuckDB dialect: basic types
# ---------------------------------------------------------------------------


def test_duckdb_bulk_insert_basic(duckdb_engine):
table = _make_table(duckdb_engine, "basic")
records = [{"id": 1, "label": "hello"}, {"id": 2, "label": None}]
rows = _roundtrip(duckdb_engine, table, records)
assert len(rows) == 2
assert rows[0]["id"] == 1
assert rows[0]["label"] == "hello"
assert rows[1]["label"] is None


def test_duckdb_bulk_insert_numeric_types(duckdb_engine):
meta = MetaData()
table = Table(
"numeric_types",
meta,
Column("i", Integer, key="i"),
Column("bi", BigInteger, key="bi"),
Column("si", SmallInteger, key="si"),
Column("d", Double, key="d"),
)
meta.create_all(duckdb_engine)
records = [{"i": 1, "bi": 10**15, "si": 32767, "d": 3.14}]
rows = _roundtrip(duckdb_engine, table, records)
assert rows[0]["i"] == 1
assert rows[0]["bi"] == 10**15
assert rows[0]["si"] == 32767
assert abs(rows[0]["d"] - 3.14) < 1e-9


def test_duckdb_bulk_insert_boolean(duckdb_engine):
meta = MetaData()
table = Table(
"bool_test",
meta,
Column("id", Integer, key="id"),
Column("flag", Boolean, key="flag"),
)
meta.create_all(duckdb_engine)
records = [{"id": 1, "flag": True}, {"id": 2, "flag": False}, {"id": 3, "flag": None}]
rows = _roundtrip(duckdb_engine, table, records)
assert rows[0]["flag"] is True
assert rows[1]["flag"] is False
assert rows[2]["flag"] is None


def test_duckdb_bulk_insert_datetime(duckdb_engine):
meta = MetaData()
table = Table(
"dt_test",
meta,
Column("id", Integer, key="id"),
Column("ts", DateTime(timezone=True), key="ts"),
)
meta.create_all(duckdb_engine)
dt = datetime.datetime(2023, 9, 27, 14, 35, 54, 274602)
records = [{"id": 1, "ts": dt}, {"id": 2, "ts": None}]
rows = _roundtrip(duckdb_engine, table, records)
# Value must survive the CSV round-trip and be returned as a datetime-like object.
assert rows[0]["ts"] is not None
assert rows[1]["ts"] is None


def test_duckdb_bulk_insert_binary(duckdb_engine):
meta = MetaData()
table = Table(
"binary_test",
meta,
Column("id", Integer, key="id"),
Column("hash", LargeBinary(32), key="hash"),
)
meta.create_all(duckdb_engine)
payload = b"\xde\xad\xbe\xef" * 8
records = [{"id": 1, "hash": payload}, {"id": 2, "hash": None}]
rows = _roundtrip(duckdb_engine, table, records)
assert bytes(rows[0]["hash"]) == payload
assert rows[1]["hash"] is None


def test_duckdb_bulk_insert_scalar_column_default(duckdb_engine):
"""Columns with Python-side scalar defaults absent from records must be applied."""
meta = MetaData()
table = Table(
"default_test",
meta,
Column("id", Integer, key="id"),
Column("flag", Boolean, default=False, key="flag"),
)
meta.create_all(duckdb_engine)
# Records do NOT contain 'flag'; the default must be applied.
records = [{"id": 1}, {"id": 2}]
rows = _roundtrip(duckdb_engine, table, records)
assert rows[0]["flag"] is False
assert rows[1]["flag"] is False


def test_duckdb_bulk_insert_empty(duckdb_engine):
table = _make_table(duckdb_engine, "empty_test")
dialect = DuckDBDialect()
with engine.begin() if False else duckdb_engine.begin() as conn:
dialect.bulk_insert(conn, table, [])
with duckdb_engine.connect() as conn:
count = conn.execute(text("SELECT COUNT(*) FROM empty_test")).scalar()
assert count == 0
Loading