Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions airflow_dbt_python/utils/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,30 +177,34 @@ class BaseConfig:
static: bool = False
upgrade: bool = False

require_model_names_without_spaces: bool = False
require_ref_searches_node_package_before_root: bool = False
exclude_resource_types: list[str] = dataclasses.field(
default_factory=list, repr=False
)

# legacy behaviors - https://github.com/dbt-labs/dbt-core/blob/main/docs/guides/behavior-change-flags.md
require_batched_execution_for_custom_microbatch_strategy: bool = False
require_event_names_in_deprecations: bool = False
require_explicit_package_overrides_for_builtin_materializations: bool = True
require_resource_names_without_spaces: bool = True
source_freshness_run_project_hooks: bool = True
skip_nodes_if_on_run_start_fails: bool = False
state_modified_compare_more_unrendered_values: bool = False
state_modified_compare_vars: bool = False
require_yaml_configuration_for_mf_time_spines: bool = False
require_nested_cumulative_type_params: bool = False
validate_macro_args: bool = False
require_all_warnings_handled_by_warn_error: bool = False
require_generic_test_arguments_property: bool = False
# Behavior change flags
# See: https://docs.getdbt.com/reference/global-configs/behavior-changes#behavior-change-flags
require_all_warnings_handled_by_warn_error: Optional[bool] = None
require_batched_execution_for_custom_microbatch_strategy: Optional[bool] = None
require_explicit_package_overrides_for_builtin_materializations: Optional[bool] = (
None
)
require_generic_test_arguments_property: Optional[bool] = None
require_nested_cumulative_type_params: Optional[bool] = None
require_ref_searches_node_package_before_root: Optional[bool] = None
require_resource_names_without_spaces: Optional[bool] = None
require_unique_project_resource_names: Optional[bool] = None
require_valid_schema_from_generate_schema_name: Optional[bool] = None
require_yaml_configuration_for_mf_time_spines: Optional[bool] = None
restrict_direct_pg_catalog_access: Optional[bool] = None
skip_nodes_if_on_run_start_fails: Optional[bool] = None
source_freshness_run_project_hooks: Optional[bool] = None
state_modified_compare_more_unrendered_values: Optional[bool] = None
validate_macro_args: Optional[bool] = None

def __post_init__(self):
"""Post initialization actions for a dbt configuration."""
self.vars = parse_yaml_args(self.vars)
self.set_flags_from_dbt_project_file()
self.set_mutually_exclusive_attributes()

def set_mutually_exclusive_attributes(self):
Expand Down Expand Up @@ -257,6 +261,46 @@ def set_mutually_exclusive_attributes(self):
else:
setattr(self, attr, not negative_value)

def set_flags_from_dbt_project_file(self):
"""Attempt to load configured flags from a project configuration file.

Dbt allows flags to be set in the configuration file. Since we create a project
here, we must attempt to load them when they are set.

Important to keep in mind dbt's precedence rules and not override anything
passed as an argument or set in an environment variable.
"""
if not self.project_dir:
return

dbt_project_path = Path(self.project_dir) / "dbt_project.yml"
if dbt_project_path.exists() is False:
dbt_project_path = Path(self.project_dir) / "dbt_project.yaml"
if dbt_project_path.exists() is False:
return

try:
with open(dbt_project_path) as dbt_project_yaml:
yaml = dbt_project_yaml.read()
contents = yaml_helper.load_yaml_text(yaml) or {}
except Exception:
return

if "flags" not in contents:
return

for flag_name, flag_value in contents["flags"].items():
current_value = getattr(self, flag_name, None)
env_value = os.getenv(f"DBT_{flag_name.upper()}", None)

if current_value is not None or env_value is not None:
# According to dbt config precedence rules, a value passed as argument
# or in the environment wins over values set in dbt project config.
# https://docs.getdbt.com/reference/global-configs/project-flags#config-precedence
continue

setattr(self, flag_name, flag_value)

def __getattribute__(self, item: str):
"""Dbt 1.5+ uses uppercase attributes, let's handle this."""
if item.isupper():
Expand Down
96 changes: 96 additions & 0 deletions tests/operators/test_dbt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,99 @@ def test_dbt_test_uses_correct_argument_according_to_version():
assert getattr(op, "models", None) is None
assert op.selector == "a-selector"
assert getattr(op, "selector_name", None) is None


GENERIC_TESTS_WITH_ARGUMENTS = """
version: 2

models:
- name: model_2
columns:
- name: field1
tests:
- not_null
- accepted_values:
arguments:
values: ['123', '456']
"""


@pytest.fixture(scope="function")
def generic_tests_files_with_arguments(
model_files, dbt_project_dir, generic_tests_files
):
"""Create a dbt generic test YAML file using arguments.

We also take care to preserve the existing ``schema.yml`` file and restore it
afterwards.
"""
schema_file = generic_tests_files[0]
schema_file_old = schema_file.rename(schema_file.with_suffix(".old"))
d = dbt_project_dir / "models"
d.mkdir(exist_ok=True, parents=True)

schema = d / "schema.yml"
schema.write_text(GENERIC_TESTS_WITH_ARGUMENTS)

yield [schema]

schema.unlink()
schema_file_old.rename(schema_file)


@pytest.fixture(scope="function")
def dbt_project_file_with_arguments_flag(
dbt_project_file, dbt_project_dir, logs_dir, request
):
"""Create a test dbt_project.yml file with flag to require arguments.

We also take care to preserve the existing ``dbt_project_file`` and restore it
afterwards.
"""
dbt_project_file_old = dbt_project_file.rename(dbt_project_file.with_suffix(".old"))

p = dbt_project_dir / "dbt_project.yml"
contents = """
name: test
profile: default
config-version: 2
version: 1.0.0
flags:
require_generic_test_arguments_property: true
"""
p.write_text(contents)

yield p

p.unlink()
dbt_project_file_old.rename(dbt_project_file)


def test_dbt_test_generic_tests_with_arguments(
profiles_file,
dbt_project_file_with_arguments_flag,
generic_tests_files_with_arguments,
hook,
):
"""Test a dbt test operator for a generic test with arguments.

Since the require_generic_test_arguments_property flag is enabled, this should pass.
"""
hook.run_dbt_task(
"run",
project_dir=dbt_project_file_with_arguments_flag.parent,
profiles_dir=profiles_file.parent,
)

op = DbtTestOperator(
task_id="dbt_task",
project_dir=dbt_project_file_with_arguments_flag.parent,
profiles_dir=profiles_file.parent,
generic=True,
)
results = op.execute({})

assert results["args"]["generic"] is True
assert len(results["results"]) == 2
for test_result in results["results"]:
assert test_result["status"] == TestStatus.Pass
55 changes: 55 additions & 0 deletions tests/utils/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,58 @@ def test_base_config_create_dbt_project_and_profile_with_no_profile(
target = profile.to_target_dict()
assert target["name"] == conn_id
assert target["type"] == "postgres"


@pytest.fixture(scope="function")
def dbt_project_file_with_flags(dbt_project_dir, logs_dir, request):
"""Create a test dbt_project.yml file."""
p = dbt_project_dir / "dbt_project.yml"
contents = """
name: test
profile: default
config-version: 2
version: 1.0.0
flags:
fail_fast: false
require_generic_test_arguments_property: true
"""
p.write_text(contents)

return p


def test_base_config_sets_flag_from_dbt_project_file(dbt_project_file_with_flags):
"""Test the configuration reads flags from dbt project file."""
config = BaseConfig(
project_dir=dbt_project_file_with_flags.parent,
)
assert config.fail_fast is False
assert config.require_generic_test_arguments_property is True


def test_base_config_does_not_override_when_value_passed(dbt_project_file_with_flags):
"""Test flags in dbt project file do not override any values passed."""
config = BaseConfig(
project_dir=dbt_project_file_with_flags.parent,
fail_fast=True,
require_generic_test_arguments_property=False,
)
assert config.fail_fast is True
assert config.require_generic_test_arguments_property is False


def test_base_config_does_not_override_when_value_in_environment(
dbt_project_file_with_flags,
):
"""Test flags in dbt project file do not override when environment values set."""
env = {
"DBT_FAIL_FAST": "1",
"DBT_REQUIRE_GENERIC_TEST_ARGUMENTS_PROPERTY": "0",
}

with patch.dict(os.environ, env):
config = BaseConfig(
project_dir=dbt_project_file_with_flags.parent,
)
assert config.fail_fast is None
assert config.require_generic_test_arguments_property is None
Loading