diff --git a/nac_test/pyats_core/constants.py b/nac_test/pyats_core/constants.py index 6522ed3c..c9a4f052 100644 --- a/nac_test/pyats_core/constants.py +++ b/nac_test/pyats_core/constants.py @@ -37,6 +37,10 @@ # PyATS-specific file paths AUTH_CACHE_DIR: str = os.path.join(tempfile.gettempdir(), "nac-test-auth-cache") +# PyATS config files written to output directory during test execution +PYATS_PLUGIN_CONFIG_FILENAME: str = ".pyats_plugin.yaml" +PYATS_CONFIG_FILENAME: str = ".pyats.conf" + # pushed to pyats device connection settings to speed up disconnects (default is 10s/1s) PYATS_POST_DISCONNECT_WAIT_SECONDS: int = 0 PYATS_GRACEFUL_DISCONNECT_WAIT_SECONDS: int = 0 diff --git a/nac_test/pyats_core/execution/subprocess_runner.py b/nac_test/pyats_core/execution/subprocess_runner.py index 003092f0..8ee23fb4 100644 --- a/nac_test/pyats_core/execution/subprocess_runner.py +++ b/nac_test/pyats_core/execution/subprocess_runner.py @@ -8,7 +8,6 @@ import logging import os import sysconfig -import tempfile import textwrap import time from collections.abc import Callable @@ -19,13 +18,33 @@ from nac_test.pyats_core.constants import ( PIPE_DRAIN_DELAY_SECONDS, PIPE_DRAIN_TIMEOUT_SECONDS, + PYATS_CONFIG_FILENAME, PYATS_OUTPUT_BUFFER_LIMIT, + PYATS_PLUGIN_CONFIG_FILENAME, ) from nac_test.utils.logging import DEFAULT_LOGLEVEL, LogLevel logger = logging.getLogger(__name__) +# disable EnvironmentDebugPlugin to prevent sensitive environment vars +# from being logged by PyATS +PLUGIN_CONFIG = textwrap.dedent("""\ + plugins: + ProgressReporterPlugin: + enabled: True + module: nac_test.pyats_core.progress.plugin + order: 1.0 + EnvironmentDebugPlugin: + enabled: False + """) + +PYATS_CONFIG = textwrap.dedent("""\ + [report] + git_info = false + """) + + class SubprocessRunner: """Executes PyATS jobs as subprocesses and handles their output.""" @@ -33,7 +52,6 @@ def __init__( self, output_dir: Path, output_handler: Callable[[str], None], - plugin_config_path: Path | None = None, loglevel: LogLevel = DEFAULT_LOGLEVEL, ): """Initialize the subprocess runner. @@ -41,12 +59,10 @@ def __init__( Args: output_dir: Directory for test output output_handler: Function to process each line of stdout - plugin_config_path: Path to the PyATS plugin configuration file loglevel: Logging level to pass to PyATS CLI """ self.output_dir = output_dir self.output_handler = output_handler - self.plugin_config_path = plugin_config_path self.loglevel = loglevel # Ensure pyats is in the same environment as nac-test @@ -57,11 +73,61 @@ def __init__( ) self.pyats_executable = str(pyats_path) + self._plugin_config_file: Path | None = None + self._pyats_config_file: Path | None = None + self._create_config_files() + + def _create_config_files(self) -> None: + """Create config files for PyATS execution in the output directory. + + Raises: + RuntimeError: If file creation fails + """ + plugin_config_file = self.output_dir / PYATS_PLUGIN_CONFIG_FILENAME + pyats_config_file = self.output_dir / PYATS_CONFIG_FILENAME + + try: + plugin_config_file.write_text(PLUGIN_CONFIG) + pyats_config_file.write_text(PYATS_CONFIG) + except OSError as e: + # Clean up any successfully written files before raising + plugin_config_file.unlink(missing_ok=True) + pyats_config_file.unlink(missing_ok=True) + raise RuntimeError(f"Failed to create PyATS config files: {e}") from e + + self._plugin_config_file = plugin_config_file + self._pyats_config_file = pyats_config_file + logger.debug(f"Created plugin_config {self._plugin_config_file}") + logger.debug(f"Created pyats_config {self._pyats_config_file}") + + def cleanup(self) -> None: + """Remove config files created during initialization. + + Called explicitly by the orchestrator after normal execution. Also called + opportunistically from __del__ for unexpected exits (best-effort only). + """ + for config_file in [self._plugin_config_file, self._pyats_config_file]: + if config_file is not None: + config_file.unlink(missing_ok=True) + logger.debug(f"Cleaned up config file: {config_file}") + + def __del__(self) -> None: + """Opportunistic cleanup on garbage collection. + + Not guaranteed: CPython-specific, not called on SIGKILL or interpreter shutdown. + Handles unexpected exits without complicating call sites in the orchestrator. + A more robust cleanup mechanism will be implemented as part of #677 (which + primarily targets a different file, but config file cleanup will be included) — + until then, leaked config files are acceptable (contents not sensitive, files small). + """ + try: + self.cleanup() + except Exception: + pass # Best-effort: never raise from __del__ + def _build_command( self, job_file_path: Path, - plugin_config_file: str, - pyats_config_file: str, archive_name: str, testbed_file_path: Path | None = None, ) -> list[str]: @@ -69,8 +135,6 @@ def _build_command( Args: job_file_path: Path to the job file - plugin_config_file: Path to the plugin configuration file - pyats_config_file: Path to the PyATS configuration file archive_name: Name for the archive file testbed_file_path: Optional path to the testbed file (for D2D tests) @@ -87,12 +151,20 @@ def _build_command( if testbed_file_path is not None: cmd.extend(["--testbed-file", str(testbed_file_path)]) + # Unreachable in practice: _create_config_files() always sets these in __init__, + # or raises before __init__ completes. Guard exists for mypy type narrowing only. + if ( + self._plugin_config_file is None or self._pyats_config_file is None + ): # pragma: no cover + raise RuntimeError( + "Config files not initialized — this is a bug in SubprocessRunner." + ) cmd.extend( [ "--configuration", - plugin_config_file, + str(self._plugin_config_file), "--pyats-configuration", - pyats_config_file, + str(self._pyats_config_file), "--archive-dir", str(self.output_dir), "--archive-name", @@ -131,51 +203,12 @@ async def execute_job( Returns: Path to the archive file if successful, None otherwise """ - # Create plugin configuration for progress reporting - plugin_config_file = None - pyats_config_file = None - try: - plugin_config = textwrap.dedent(""" - plugins: - ProgressReporterPlugin: - enabled: True - module: nac_test.pyats_core.progress.plugin - order: 1.0 - EnvironmentDebugPlugin: - enabled: False - """) - - with tempfile.NamedTemporaryFile( - mode="w", suffix="_plugin_config.yaml", delete=False - ) as f: - f.write(plugin_config) - plugin_config_file = f.name - logger.debug( - f"Created plugin_config {plugin_config_file} with content\n{plugin_config}" - ) - - # Create PyATS configuration to disable git_info collection - # This prevents fork() crashes on macOS with Python 3.12+ caused by - # CoreFoundation lock corruption in get_git_info() - pyats_config = "[report]\ngit_info = false\n" - with tempfile.NamedTemporaryFile( - mode="w", suffix="_pyats_config.conf", delete=False - ) as f: - f.write(pyats_config) - pyats_config_file = f.name - logger.debug(f"Created pyats_config {pyats_config_file}") - - except Exception as e: - logger.warning(f"Failed to create config files: {e}") - # If we can't create config files, we should probably fail - return None - - # Generate archive name with timestamp job_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] archive_name = f"nac_test_job_{job_timestamp}.zip" cmd = self._build_command( - job_file_path, plugin_config_file, pyats_config_file, archive_name + job_file_path, + archive_name, ) logger.info(f"Executing command: {' '.join(cmd)}") @@ -241,49 +274,12 @@ async def execute_job_with_testbed( Returns: Path to the archive file if successful, None otherwise """ - # Create plugin configuration for progress reporting - plugin_config_file = None - pyats_config_file = None - try: - plugin_config = textwrap.dedent(""" - plugins: - ProgressReporterPlugin: - enabled: True - module: nac_test.pyats_core.progress.plugin - order: 1.0 - EnvironmentDebugPlugin: - enabled: False - """) - - with tempfile.NamedTemporaryFile( - mode="w", suffix="_plugin_config.yaml", delete=False - ) as f: - f.write(plugin_config) - plugin_config_file = f.name - - # Create PyATS configuration to disable git_info collection - # This prevents fork() crashes on macOS with Python 3.12+ caused by - # CoreFoundation lock corruption in get_git_info() - pyats_config = "[report]\ngit_info = false\n" - with tempfile.NamedTemporaryFile( - mode="w", suffix="_pyats_config.conf", delete=False - ) as f: - f.write(pyats_config) - pyats_config_file = f.name - - except Exception as e: - logger.warning(f"Failed to create config files: {e}") - # If we can't create config files, we should probably fail - return None - # Get device ID from environment for archive naming hostname = env.get("HOSTNAME", "unknown") archive_name = f"pyats_archive_device_{hostname}" cmd = self._build_command( job_file_path, - plugin_config_file, - pyats_config_file, archive_name, testbed_file_path=testbed_file_path, ) diff --git a/nac_test/pyats_core/orchestrator.py b/nac_test/pyats_core/orchestrator.py index 62a5f7c4..047dbb03 100644 --- a/nac_test/pyats_core/orchestrator.py +++ b/nac_test/pyats_core/orchestrator.py @@ -53,7 +53,6 @@ from nac_test.utils.logging import DEFAULT_LOGLEVEL, LogLevel from nac_test.utils.system_resources import SystemResourceCalculator from nac_test.utils.terminal import terminal -from nac_test.utils.yaml import dump_to_stream logger = logging.getLogger(__name__) @@ -168,51 +167,6 @@ def _calculate_workers(self) -> int: return cpu_workers - def _build_reporter_config(self) -> dict[str, Any]: - """Build the configuration for PyATS reporters. - - This centralizes the reporter setup to use an asynchronous QueueHandler - which puts all incoming reporting messages into a queue and lets a - separate thread handle the slow disk I/O. This makes the ReportServer - non-blocking and prevents client timeouts under heavy load. - - Returns: - A dictionary representing the reporter configuration. - """ - return { - "reporter": { - "server": { - "handlers": { - "fh": { - "class": "pyats.reporter.handlers.FileHandler", - }, - "qh": { - "class": "pyats.reporter.handlers.QueueHandler", - "handlers": ["fh"], - }, - } - }, - "root": { - "handlers": ["qh"], - }, - } - } - - def _generate_plugin_config(self, temp_dir: Path) -> Path: - """Generate the PyATS plugin configuration file. - - Args: - temp_dir: The temporary directory to write the file in. - - Returns: - The path to the generated configuration file. - """ - reporter_config = self._build_reporter_config() - config_path = temp_dir / "plugin_config.yaml" - with open(config_path, "w") as f: - dump_to_stream(reporter_config, f) - return config_path - def _populate_test_status_from_archive(self, archive_path: Path) -> None: """Populate test_status from archive results.json when progress events are missing. @@ -667,59 +621,57 @@ async def _run_tests_async(self) -> PyATSResults: loglevel=self.loglevel, ) # Archives should be stored at base level, not in pyats_results subdirectory - self.subprocess_runner = SubprocessRunner( - self.base_output_dir, - output_handler=self.output_processor.process_line, - loglevel=self.loglevel, - ) - # Generate the plugin config and pass it to the runner - with tempfile.TemporaryDirectory() as temp_dir: - plugin_config_path = self._generate_plugin_config(Path(temp_dir)) - if self.subprocess_runner is not None: - self.subprocess_runner.plugin_config_path = plugin_config_path - - # Execute tests based on their type - tasks = [] - - if api_tests: - tasks.append(self._execute_api_tests_standard(api_tests)) - - if d2d_tests: - # Get device inventory for D2D tests - devices = self.device_inventory_discovery.get_device_inventory( - d2d_tests - ) + try: + self.subprocess_runner = SubprocessRunner( + self.base_output_dir, + output_handler=self.output_processor.process_line, + loglevel=self.loglevel, + ) + except RuntimeError as e: + # pyats entrypoint not found or config file creation failed. + error_msg = str(e) + api_result = TestResults.from_error(error_msg) if api_tests else None + d2d_result = TestResults.from_error(error_msg) if d2d_tests else None + return PyATSResults(api=api_result, d2d=d2d_result) - # Display any skipped devices - skipped = self.device_inventory_discovery.skipped_devices - if skipped: - print() # Blank line before warnings - for skip_info in skipped: - device_id = skip_info.get("device_id", "") - reason = skip_info.get("reason", "Unknown error") - print( - terminal.warning( - f"WARNING - Skipping device {device_id}: {reason}" - ) - ) - print() # Blank line after warnings + # Execute tests based on their type + tasks = [] - if devices: - tasks.append( - self._execute_ssh_tests_device_centric(d2d_tests, devices) - ) - else: + if api_tests: + tasks.append(self._execute_api_tests_standard(api_tests)) + + if d2d_tests: + # Get device inventory for D2D tests + devices = self.device_inventory_discovery.get_device_inventory(d2d_tests) + + # Display any skipped devices + skipped = self.device_inventory_discovery.skipped_devices + if skipped: + print() # Blank line before warnings + for skip_info in skipped: + device_id = skip_info.get("device_id", "") + reason = skip_info.get("reason", "Unknown error") print( terminal.warning( - "No devices found in inventory. D2D tests will be skipped." + f"WARNING - Skipping device {device_id}: {reason}" ) ) + print() # Blank line after warnings - # Run all test types in parallel - if tasks: - await asyncio.gather(*tasks) + if devices: + tasks.append(self._execute_ssh_tests_device_centric(d2d_tests, devices)) else: - print("No tests to execute after categorization") + print( + terminal.warning( + "No devices found in inventory. D2D tests will be skipped." + ) + ) + + # Run all test types in parallel + if tasks: + await asyncio.gather(*tasks) + else: + print("No tests to execute after categorization") # Split test_status into api_test_status and d2d_test_status based on test type. # OutputProcessor correctly parses results for ALL tests into test_status. @@ -769,6 +721,16 @@ async def _run_tests_async(self) -> PyATSResults: return pyats_results + def _cleanup_subprocess_runner(self, keep_artifacts: bool) -> None: + """Explicitly clean up config files created by SubprocessRunner. + + Called after report generation completes (success or failure paths). + Unexpected exit paths are handled opportunistically via SubprocessRunner.__del__ + until a more robust mechanism is implemented as part of #677. + """ + if self.subprocess_runner and not keep_artifacts: + self.subprocess_runner.cleanup() + async def _generate_html_reports_async( self, ) -> PyATSResults: @@ -807,6 +769,11 @@ async def _generate_html_reports_async( ) result = await generator.generate_reports_from_archives(archive_paths) + # Determine whether to keep artifacts (debug mode or explicit env var) + keep_artifacts: bool = bool( + DEBUG_MODE or os.environ.get("NAC_TEST_PYATS_KEEP_REPORT_DATA") + ) + if result["status"] in ["success", "partial"]: # Log report generation timing (procedural info) duration_str = format_duration(result["duration"]) @@ -835,7 +802,7 @@ async def _generate_html_reports_async( # Clean up archives after successful extraction and report generation # (unless in debug mode or user wants to keep data) - if not (DEBUG_MODE or os.environ.get("NAC_TEST_PYATS_KEEP_REPORT_DATA")): + if not keep_artifacts: for archive_path in archive_paths: try: archive_path.unlink() @@ -863,6 +830,9 @@ async def _generate_html_reports_async( except Exception as e: logger.debug(f"Could not remove directory {type_dir}: {e}") + # Clean up PyATS config files created by SubprocessRunner + self._cleanup_subprocess_runner(keep_artifacts) + # Extract and return test statistics if result.get("pyats_stats"): return self._extract_pyats_stats(result["pyats_stats"]) @@ -873,4 +843,6 @@ async def _generate_html_reports_async( print(f"\n{terminal.error('Failed to generate reports')}") if result.get("error"): print(f"Error: {result['error']}") + # Clean up PyATS config files; report generation failed so no stats to return + self._cleanup_subprocess_runner(keep_artifacts) return PyATSResults() diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 09ea8772..fdcb31a5 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -44,7 +44,7 @@ # Sentinel value for credential exposure detection (#689) # All test passwords use this value so we can detect if credentials leak into artifacts -TEST_CREDENTIAL_SENTINEL = "CRED_SENTINEL_MUST_NOT_APPEAR_IN_ARTIFACTS" +TEST_CREDENTIAL_SENTINEL: str = "CRED_SENTINEL_MUST_NOT_APPEAR_IN_ARTIFACTS" @dataclass diff --git a/tests/pyats_core/execution/test_subprocess_runner_integration.py b/tests/pyats_core/execution/test_subprocess_runner_integration.py index b8db71dc..57569510 100644 --- a/tests/pyats_core/execution/test_subprocess_runner_integration.py +++ b/tests/pyats_core/execution/test_subprocess_runner_integration.py @@ -6,9 +6,8 @@ Tests the subprocess execution logic: 1. Config file content verification (git_info = false for macOS fork() safety) 2. Command construction includes all required PyATS flags -3. Error handling when config file creation fails -4. Return code interpretation (0 = success, 1 = test failures, >1 = error) -5. Output processing and progress event parsing +3. Return code interpretation (0 = success, 1 = test failures, >1 = error) +4. Output processing and progress event parsing See tests/unit/pyats_core/execution/test_subprocess_runner.py for additional unit tests covering subprocess crash handling, @@ -16,13 +15,16 @@ """ import asyncio +import re from pathlib import Path from typing import Any from unittest.mock import AsyncMock, patch import pytest -from nac_test.pyats_core.execution.subprocess_runner import SubprocessRunner +from nac_test.pyats_core.execution.subprocess_runner import ( + SubprocessRunner, +) def _make_mock_process( @@ -123,12 +125,9 @@ def test_execute_job_writes_git_info_false_to_pyats_config( config_idx = cmd.index("--pyats-configuration") config_path = Path(cmd[config_idx + 1]) - try: - content = config_path.read_text() - assert "[report]" in content - assert "git_info = false" in content - finally: - config_path.unlink(missing_ok=True) + content = config_path.read_text() + assert re.search(r"\[report\]", content) + assert re.search(r"git_info\s*=\s*false", content) def test_execute_job_with_testbed_writes_git_info_false_to_pyats_config( self, runner: SubprocessRunner @@ -148,12 +147,9 @@ def test_execute_job_with_testbed_writes_git_info_false_to_pyats_config( config_idx = cmd.index("--pyats-configuration") config_path = Path(cmd[config_idx + 1]) - try: - content = config_path.read_text() - assert "[report]" in content - assert "git_info = false" in content - finally: - config_path.unlink(missing_ok=True) + content = config_path.read_text() + assert re.search(r"\[report\]", content) + assert re.search(r"git_info\s*=\s*false", content) def test_execute_job_writes_plugin_config_with_progress_reporter( self, runner: SubprocessRunner @@ -166,12 +162,9 @@ def test_execute_job_writes_plugin_config_with_progress_reporter( config_idx = cmd.index("--configuration") config_path = Path(cmd[config_idx + 1]) - try: - content = config_path.read_text() - assert "ProgressReporterPlugin" in content - assert "enabled: True" in content - finally: - config_path.unlink(missing_ok=True) + content = config_path.read_text() + assert re.search(r"ProgressReporterPlugin:\s+enabled:\s+True", content) + assert re.search(r"EnvironmentDebugPlugin:\s+enabled:\s+False", content) class TestCommandConstruction: @@ -220,71 +213,6 @@ def test_execute_job_with_testbed_includes_testbed_and_config_flags( assert "--pyats-configuration" in cmd -class TestConfigCreationFailure: - """Tests error handling when config file creation fails.""" - - def test_execute_job_returns_none_on_config_failure(self, tmp_path: Path) -> None: - """Verify execute_job returns None and does NOT launch subprocess when config fails. - - If we can't create the config files, we must not proceed with execution - because PyATS would use default settings that cause fork() crashes on macOS. - """ - runner = SubprocessRunner( - output_dir=tmp_path, - output_handler=lambda line: None, - ) - - with ( - patch( - "tempfile.NamedTemporaryFile", - side_effect=OSError("disk full"), - ), - patch( - "asyncio.create_subprocess_exec", - new_callable=AsyncMock, - ) as mock_exec, - ): - result = asyncio.run( - runner.execute_job( - job_file_path=Path("/fake/job.py"), - env={}, - ) - ) - - assert result is None, "execute_job must return None when config creation fails" - mock_exec.assert_not_called() - - def test_execute_job_with_testbed_returns_none_on_config_failure( - self, tmp_path: Path - ) -> None: - """Verify execute_job_with_testbed returns None when config creation fails.""" - runner = SubprocessRunner( - output_dir=tmp_path, - output_handler=lambda line: None, - ) - - with ( - patch( - "tempfile.NamedTemporaryFile", - side_effect=OSError("disk full"), - ), - patch( - "asyncio.create_subprocess_exec", - new_callable=AsyncMock, - ) as mock_exec, - ): - result = asyncio.run( - runner.execute_job_with_testbed( - job_file_path=Path("/fake/job.py"), - testbed_file_path=Path("/fake/testbed.yaml"), - env={"HOSTNAME": "test-device"}, - ) - ) - - assert result is None - mock_exec.assert_not_called() - - class TestReturnCodeHandling: """Tests subprocess return code interpretation.""" diff --git a/tests/pyats_core/test_orchestrator_config_error.py b/tests/pyats_core/test_orchestrator_config_error.py new file mode 100644 index 00000000..493e681b --- /dev/null +++ b/tests/pyats_core/test_orchestrator_config_error.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: MPL-2.0 +# Copyright (c) 2025 Daniel Schmidt + +"""Tests for PyATSOrchestrator handling of SubprocessRunner init failures.""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from nac_test.pyats_core.orchestrator import PyATSOrchestrator + +from .conftest import PyATSTestDirs + + +class TestOrchestratorSubprocessRunnerInitError: + """Tests for RuntimeError handling in PyATSOrchestrator when SubprocessRunner cannot initialize.""" + + @pytest.mark.parametrize( + ("has_api", "has_d2d"), + [ + (True, False), + (False, True), + (True, True), + ], + ) + def test_subprocess_runner_init_error_returns_from_error_results( + self, + aci_controller_env: None, + pyats_test_dirs: PyATSTestDirs, + has_api: bool, + has_d2d: bool, + ) -> None: + """OSError in SubprocessRunner._create_config_files raises RuntimeError, returns from_error results.""" + api_tests = [Path("/fake/tests/api/test_one.py")] if has_api else [] + d2d_tests = [Path("/fake/tests/d2d/test_two.py")] if has_d2d else [] + + orchestrator = PyATSOrchestrator( + data_paths=[pyats_test_dirs.output_dir.parent / "data"], + test_dir=pyats_test_dirs.test_dir, + output_dir=pyats_test_dirs.output_dir, + merged_data_filename="merged.yaml", + ) + + with ( + patch.object( + orchestrator.test_discovery, "discover_pyats_tests" + ) as mock_discover, + patch.object( + orchestrator.test_discovery, "categorize_tests_by_type" + ) as mock_categorize, + patch.object(orchestrator, "validate_environment"), + patch.object(Path, "write_text", side_effect=OSError("disk full")), + ): + mock_discover.return_value = (api_tests + d2d_tests, []) + mock_categorize.return_value = (api_tests, d2d_tests) + result = orchestrator.run_tests() + + if has_api: + assert result.api is not None + assert result.api.has_error is True + assert result.api.reason is not None + assert "disk full" in result.api.reason + else: + assert result.api is None + + if has_d2d: + assert result.d2d is not None + assert result.d2d.has_error is True + assert result.d2d.reason is not None + assert "disk full" in result.d2d.reason + else: + assert result.d2d is None diff --git a/tests/unit/pyats_core/execution/test_subprocess_runner.py b/tests/unit/pyats_core/execution/test_subprocess_runner.py index 4393226c..4f2ae967 100644 --- a/tests/unit/pyats_core/execution/test_subprocess_runner.py +++ b/tests/unit/pyats_core/execution/test_subprocess_runner.py @@ -19,11 +19,18 @@ import asyncio import logging from pathlib import Path +from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from nac_test.pyats_core.execution.subprocess_runner import SubprocessRunner +from nac_test.pyats_core.constants import ( + PYATS_CONFIG_FILENAME, + PYATS_PLUGIN_CONFIG_FILENAME, +) +from nac_test.pyats_core.execution.subprocess_runner import ( + SubprocessRunner, +) from nac_test.utils.logging import LogLevel @@ -289,10 +296,171 @@ def test_build_command_loglevel_to_cli_flags( runner = SubprocessRunner(temp_output_dir, mock_output_handler, loglevel=loglevel) cmd = runner._build_command( job_file_path=Path("/tmp/job.py"), - plugin_config_file="/tmp/plugin.yaml", - pyats_config_file="/tmp/pyats.conf", archive_name="test_archive.zip", ) assert cmd.count("--verbose") == expected_verbose_count assert cmd.count("--quiet") == expected_quiet_count + + +def test_build_command_includes_config_files( + temp_output_dir: Path, + mock_output_handler: Mock, +) -> None: + """Test that _build_command includes plugin and pyats config files in the command.""" + runner = SubprocessRunner(temp_output_dir, mock_output_handler) + cmd = runner._build_command( + job_file_path=Path("/tmp/job.py"), + archive_name="test_archive.zip", + ) + + # Verify --configuration flag with plugin config file + assert "--configuration" in cmd + config_idx = cmd.index("--configuration") + assert cmd[config_idx + 1] == str(runner._plugin_config_file) + + # Verify --pyats-configuration flag with pyats config file + assert "--pyats-configuration" in cmd + pyats_config_idx = cmd.index("--pyats-configuration") + assert cmd[pyats_config_idx + 1] == str(runner._pyats_config_file) + + +def test_init_creates_config_files_in_output_dir( + temp_output_dir: Path, mock_output_handler: Mock +) -> None: + """Test that __init__ creates config files in the output directory.""" + runner = SubprocessRunner(temp_output_dir, mock_output_handler) + + assert runner._plugin_config_file is not None + assert runner._pyats_config_file is not None + + assert runner._plugin_config_file.exists() + assert runner._pyats_config_file.exists() + assert runner._plugin_config_file.parent == temp_output_dir + assert runner._pyats_config_file.parent == temp_output_dir + assert runner._plugin_config_file.name == PYATS_PLUGIN_CONFIG_FILENAME + assert runner._pyats_config_file.name == PYATS_CONFIG_FILENAME + + +def test_init_raises_runtime_error_on_write_failure( + temp_output_dir: Path, mock_output_handler: Mock +) -> None: + """Test that RuntimeError is raised when config file write fails.""" + with patch.object(Path, "write_text", side_effect=OSError("disk full")): + with pytest.raises(RuntimeError, match="disk full"): + SubprocessRunner(temp_output_dir, mock_output_handler) + + +def test_write_failure_leaves_attributes_none_and_cleanup_is_safe( + temp_output_dir: Path, +) -> None: + """Test that write failure leaves attributes None and cleanup handles it safely.""" + runner = object.__new__(SubprocessRunner) + runner.output_dir = temp_output_dir + runner._plugin_config_file = None + runner._pyats_config_file = None + + with patch.object(Path, "write_text", side_effect=OSError("disk full")): + with pytest.raises(RuntimeError): + runner._create_config_files() + + assert runner._plugin_config_file is None + assert runner._pyats_config_file is None + + with patch.object(Path, "unlink") as mock_unlink: + runner.cleanup() + mock_unlink.assert_not_called() + + +def test_partial_write_failure_cleans_up_first_file( + temp_output_dir: Path, +) -> None: + """Test that partial write failure cleans up successfully written files. + + If the first config file is written successfully but the second fails, + the first file should be cleaned up to avoid orphaned files. + """ + runner = object.__new__(SubprocessRunner) + runner.output_dir = temp_output_dir + runner._plugin_config_file = None + runner._pyats_config_file = None + + # Track which file is being written + write_count = {"count": 0} + original_write_text = Path.write_text + + def write_text_fail_on_second( + self: Path, content: str, *args: Any, **kwargs: Any + ) -> None: + write_count["count"] += 1 + if write_count["count"] == 1: + # First write succeeds + original_write_text(self, content, *args, **kwargs) + else: + # Second write fails + raise OSError("disk full on second file") + + with patch.object(Path, "write_text", write_text_fail_on_second): + with pytest.raises(RuntimeError, match="disk full"): + runner._create_config_files() + + # Attributes should still be None (not set until both writes succeed) + assert runner._plugin_config_file is None + assert runner._pyats_config_file is None + + # The first file that was successfully written should have been cleaned up + plugin_config_path = temp_output_dir / PYATS_PLUGIN_CONFIG_FILENAME + assert not plugin_config_path.exists(), ( + "First config file should be cleaned up after second write fails" + ) + + +# --- Cleanup tests --- + + +def test_cleanup_removes_config_files( + temp_output_dir: Path, mock_output_handler: Mock +) -> None: + """Test that cleanup() removes config files.""" + runner = SubprocessRunner(temp_output_dir, mock_output_handler) + + assert runner._plugin_config_file is not None + assert runner._pyats_config_file is not None + assert runner._plugin_config_file.exists() + assert runner._pyats_config_file.exists() + + runner.cleanup() + + assert not runner._plugin_config_file.exists() + assert not runner._pyats_config_file.exists() + + +def test_cleanup_is_idempotent( + temp_output_dir: Path, mock_output_handler: Mock +) -> None: + """Test that cleanup() can be called multiple times safely.""" + runner = SubprocessRunner(temp_output_dir, mock_output_handler) + assert runner._plugin_config_file is not None + assert runner._pyats_config_file is not None + + runner.cleanup() + assert not runner._plugin_config_file.exists() + assert not runner._pyats_config_file.exists() + + runner.cleanup() # Second call must not raise + assert not runner._plugin_config_file.exists() + assert not runner._pyats_config_file.exists() + + +def test_del_calls_cleanup(temp_output_dir: Path, mock_output_handler: Mock) -> None: + """Test that __del__ triggers opportunistic cleanup of config files.""" + runner = SubprocessRunner(temp_output_dir, mock_output_handler) + assert runner._plugin_config_file is not None + assert runner._pyats_config_file is not None + assert runner._plugin_config_file.exists() + assert runner._pyats_config_file.exists() + + runner.__del__() + + assert not runner._plugin_config_file.exists() + assert not runner._pyats_config_file.exists()