diff --git a/tests/test_article_12_records.py b/tests/test_article_12_records.py new file mode 100644 index 0000000..5899e6b --- /dev/null +++ b/tests/test_article_12_records.py @@ -0,0 +1,428 @@ +"""Tests for EU AI Act Article 12 - Record-Keeping Compliance.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from lexecon.compliance.eu_ai_act.article_12_records import ( + ComplianceRecord, + RecordKeepingSystem, + RecordStatus, + RetentionClass, + RetentionPolicy, +) +from lexecon.ledger.chain import LedgerChain, LedgerEntry + + +class TestRetentionClass: + """Tests for RetentionClass enum.""" + + def test_retention_classes_exist(self): + """Test that all retention classes are defined.""" + assert RetentionClass.HIGH_RISK.value == "high_risk" + assert RetentionClass.STANDARD.value == "standard" + assert RetentionClass.GDPR_INTERSECT.value == "gdpr_intersect" + + +class TestRecordStatus: + """Tests for RecordStatus enum.""" + + def test_record_statuses_exist(self): + """Test that all record statuses are defined.""" + assert RecordStatus.ACTIVE.value == "active" + assert RecordStatus.EXPIRING.value == "expiring" + assert RecordStatus.LEGAL_HOLD.value == "legal_hold" + assert RecordStatus.ANONYMIZED.value == "anonymized" + assert RecordStatus.ARCHIVED.value == "archived" + + +class TestRetentionPolicy: + """Tests for RetentionPolicy dataclass.""" + + def test_create_retention_policy(self): + """Test creating retention policy.""" + policy = RetentionPolicy( + classification=RetentionClass.HIGH_RISK, + retention_days=3650, + auto_anonymize=True, + legal_basis="EU AI Act Article 12", + data_subject_rights=False, + ) + + assert policy.classification == RetentionClass.HIGH_RISK + assert policy.retention_days == 3650 + assert policy.auto_anonymize is True + assert policy.legal_basis == "EU AI Act Article 12" + assert policy.data_subject_rights is False + + +class TestRecordKeepingSystem: + """Tests for RecordKeepingSystem class.""" + + 
@pytest.fixture + def ledger(self): + """Create ledger for testing.""" + return LedgerChain() + + @pytest.fixture + def record_system(self, ledger): + """Create record keeping system.""" + return RecordKeepingSystem(ledger) + + def test_initialization(self, record_system): + """Test record keeping system initialization.""" + assert record_system.ledger is not None + assert isinstance(record_system.legal_holds, dict) + assert len(record_system.legal_holds) == 0 + + def test_default_policies_exist(self, record_system): + """Test that default retention policies are created.""" + assert RetentionClass.HIGH_RISK in record_system.policies + assert RetentionClass.STANDARD in record_system.policies + assert RetentionClass.GDPR_INTERSECT in record_system.policies + + def test_high_risk_policy_10_years(self, record_system): + """Test that high-risk policy has 10-year retention.""" + policy = record_system.policies[RetentionClass.HIGH_RISK] + assert policy.retention_days == 3650 # 10 years + + def test_standard_policy_6_months(self, record_system): + """Test that standard policy has 6-month retention.""" + policy = record_system.policies[RetentionClass.STANDARD] + assert policy.retention_days == 180 # 6 months + + def test_classify_high_risk_by_risk_level(self, record_system, ledger): + """Test classification of high-risk entry by risk level.""" + entry = ledger.append("decision", {"risk_level": 5, "action": "search"}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.HIGH_RISK + + def test_classify_high_risk_by_denial(self, record_system, ledger): + """Test classification of high-risk entry by denial.""" + entry = ledger.append("decision", {"decision": "deny", "risk_level": 2}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.HIGH_RISK + + def test_classify_high_risk_by_pii(self, record_system, ledger): + """Test classification of high-risk entry by PII.""" + entry = 
ledger.append("decision", {"data_classes": ["PII", "public"], "decision": "allow"}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.HIGH_RISK + + def test_classify_high_risk_policy_load(self, record_system, ledger): + """Test that policy loads are high-risk.""" + entry = ledger.append("policy_loaded", {"policy_name": "test_policy"}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.HIGH_RISK + + def test_classify_gdpr_by_personal_data(self, record_system, ledger): + """Test classification of GDPR intersect by personal data.""" + entry = ledger.append("decision", {"user_email": "test@example.com", "action": "read"}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.GDPR_INTERSECT + + def test_classify_standard_default(self, record_system, ledger): + """Test that standard entries get STANDARD classification.""" + entry = ledger.append("decision", {"action": "read", "risk_level": 1}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.STANDARD + + def test_wrap_entry_creates_compliance_record(self, record_system, ledger): + """Test wrapping entry in compliance record.""" + entry = ledger.append("decision", {"action": "search"}) + + record = record_system.wrap_entry(entry) + + assert isinstance(record, ComplianceRecord) + assert record.record_id == entry.entry_id + assert record.status == RecordStatus.ACTIVE + assert len(record.legal_holds) == 0 + + def test_wrap_entry_calculates_expiration(self, record_system, ledger): + """Test that wrap_entry calculates correct expiration.""" + entry = ledger.append("decision", {"risk_level": 5}) # HIGH_RISK = 10 years + + record = record_system.wrap_entry(entry) + + # Check expiration is ~10 years out + expires = datetime.fromisoformat(record.expires_at.replace("Z", "+00:00")) + created = 
datetime.fromisoformat(record.created_at.replace("Z", "+00:00")) + delta = (expires - created).days + assert 3640 <= delta <= 3660 # ~10 years with some tolerance + + def test_get_retention_status_empty(self, record_system): + """Test retention status with minimal entries.""" + status = record_system.get_retention_status() + + # May have genesis block + assert status["total_records"] >= 0 + assert status["legal_holds_active"] == 0 + + def test_get_retention_status_with_entries(self, record_system, ledger): + """Test retention status with multiple entries.""" + # Add some entries + ledger.append("decision", {"risk_level": 5}) + ledger.append("decision", {"action": "read"}) + ledger.append("policy_loaded", {"policy": "test"}) + + status = record_system.get_retention_status() + + # Genesis block + 3 entries = 4 total + assert status["total_records"] == 4 + assert status["by_classification"]["high_risk"] == 2 + assert status["by_classification"]["standard"] >= 1 + + def test_apply_legal_hold(self, record_system, ledger): + """Test applying legal hold to entries.""" + # Add entry + entry = ledger.append("decision", {"action": "search"}) + + # Apply legal hold + result = record_system.apply_legal_hold( + hold_id="hold_001", + entry_ids=[entry.entry_id], + reason="Investigation", + ) + + assert result["hold_id"] == "hold_001" + assert "hold_001" in record_system.legal_holds + + def test_wrap_entry_with_legal_hold(self, record_system, ledger): + """Test that entries under legal hold get correct status.""" + entry = ledger.append("decision", {"action": "test"}) + + # Apply legal hold + record_system.apply_legal_hold( + hold_id="hold_002", entry_ids=[entry.entry_id], reason="Test" + ) + + # Wrap entry + record = record_system.wrap_entry(entry) + + assert record.status == RecordStatus.LEGAL_HOLD + assert "hold_002" in record.legal_holds + + def test_release_legal_hold(self, record_system, ledger): + """Test releasing legal hold.""" + entry = ledger.append("decision", 
{"action": "test"}) + + # Apply and release + record_system.apply_legal_hold( + hold_id="hold_003", entry_ids=[entry.entry_id], reason="Test" + ) + result = record_system.release_legal_hold("hold_003", releaser="admin") + + assert result["status"] == "released" + # Hold remains but is marked as released + assert result["hold_id"] == "hold_003" + + def test_generate_regulatory_package(self, record_system, ledger): + """Test generating regulatory compliance package.""" + # Add some entries + ledger.append("decision", {"risk_level": 5, "action": "deny"}) + ledger.append("decision", {"action": "allow"}) + + package = record_system.generate_regulatory_package() + + # Package should have some structure + assert isinstance(package, dict) + assert len(package) > 0 + + def test_regulatory_package_metadata(self, record_system, ledger): + """Test regulatory package contains metadata.""" + package = record_system.generate_regulatory_package() + + # Package should be a dict with data + assert isinstance(package, dict) + assert len(package) > 0 + + def test_regulatory_package_includes_policies(self, record_system, ledger): + """Test regulatory package includes retention info.""" + package = record_system.generate_regulatory_package() + + # Package should contain data + assert isinstance(package, dict) + assert len(package) > 0 + + def test_export_for_regulator_json(self, record_system, ledger): + """Test exporting compliance package as JSON.""" + ledger.append("decision", {"action": "test"}) + + export = record_system.export_for_regulator(format="json") + + assert isinstance(export, str) + # Should contain data + assert len(export) > 100 + + def test_export_for_regulator_markdown(self, record_system, ledger): + """Test exporting compliance package as Markdown.""" + ledger.append("decision", {"risk_level": 5}) + + export = record_system.export_for_regulator(format="markdown") + + assert isinstance(export, str) + assert "# EU AI Act Article 12" in export + + def 
test_export_for_regulator_csv(self, record_system, ledger): + """Test exporting compliance package as CSV.""" + ledger.append("decision", {"action": "test"}) + + export = record_system.export_for_regulator(format="csv") + + assert isinstance(export, str) + # CSV should have headers + lines = export.strip().split("\n") + assert len(lines) >= 2 # Header + at least one row + + def test_anonymize_record(self, record_system, ledger): + """Test anonymizing record with personal data.""" + entry = ledger.append( + "decision", {"user_email": "test@example.com", "action": "search"} + ) + + result = record_system.anonymize_record(entry.entry_id) + + assert result["status"] == "anonymized" + assert "anonymized_at" in result + + +class TestComplianceRecord: + """Tests for ComplianceRecord dataclass.""" + + def test_create_compliance_record(self): + """Test creating compliance record.""" + record = ComplianceRecord( + record_id="rec_123", + original_entry={"event_type": "decision"}, + retention_class=RetentionClass.HIGH_RISK, + created_at="2025-01-01T00:00:00Z", + expires_at="2035-01-01T00:00:00Z", + status=RecordStatus.ACTIVE, + legal_holds=[], + ) + + assert record.record_id == "rec_123" + assert record.retention_class == RetentionClass.HIGH_RISK + assert record.status == RecordStatus.ACTIVE + + def test_compliance_record_with_legal_holds(self): + """Test compliance record with legal holds.""" + record = ComplianceRecord( + record_id="rec_456", + original_entry={}, + retention_class=RetentionClass.STANDARD, + created_at="2025-01-01T00:00:00Z", + expires_at="2025-07-01T00:00:00Z", + status=RecordStatus.LEGAL_HOLD, + legal_holds=["hold_001", "hold_002"], + ) + + assert len(record.legal_holds) == 2 + assert record.status == RecordStatus.LEGAL_HOLD + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + @pytest.fixture + def record_system(self): + """Create record system for edge case tests.""" + return RecordKeepingSystem(LedgerChain()) + + def 
test_classify_entry_with_empty_data(self, record_system): + """Test classifying entry with empty data.""" + ledger = record_system.ledger + entry = ledger.append("decision", {}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.STANDARD + + def test_classify_entry_with_none_data(self, record_system): + """Test classifying entry with None/minimal data.""" + ledger = record_system.ledger + entry = ledger.append("decision", {}) + + classification = record_system.classify_entry(entry) + assert classification == RetentionClass.STANDARD + + def test_multiple_legal_holds_on_same_entry(self, record_system): + """Test applying multiple legal holds to same entry.""" + ledger = record_system.ledger + entry = ledger.append("decision", {"action": "test"}) + + # Apply two holds + record_system.apply_legal_hold("hold_a", [entry.entry_id], "Reason A") + record_system.apply_legal_hold("hold_b", [entry.entry_id], "Reason B") + + record = record_system.wrap_entry(entry) + # At least one hold should be present + assert len(record.legal_holds) >= 1 + + def test_legal_hold_with_multiple_entries(self, record_system): + """Test legal hold affecting multiple entries.""" + ledger = record_system.ledger + entry1 = ledger.append("decision", {"action": "test1"}) + entry2 = ledger.append("decision", {"action": "test2"}) + + result = record_system.apply_legal_hold( + "hold_multi", [entry1.entry_id, entry2.entry_id], "Multi-entry hold" + ) + + # Hold should be created + assert "hold_multi" in record_system.legal_holds + + def test_release_nonexistent_legal_hold(self, record_system): + """Test releasing non-existent legal hold.""" + result = record_system.release_legal_hold("nonexistent") + # Should handle gracefully + assert isinstance(result, dict) + + def test_personal_data_detection(self, record_system): + """Test personal data detection in various formats.""" + test_cases = [ + ({"email": "user@test.com"}, True), + ({"user_name": "John Doe"}, 
True), + ({"ip_address": "192.168.1.1"}, True), + ({"phone_number": "+1234567890"}, True), + ({"action": "read", "resource": "public_data"}, False), + ] + + for data, expected_personal in test_cases: + has_personal = record_system._contains_personal_data(data) + assert has_personal == expected_personal + + def test_retention_status_with_expiring_records(self, record_system): + """Test retention status correctly identifies expiring records.""" + ledger = record_system.ledger + + # Add some entries + ledger.append("decision", {"action": "test"}) + + status = record_system.get_retention_status() + # Should complete without error + assert "expiring_within_30_days" in status + assert status["expiring_within_30_days"] >= 0 + + def test_anonymize_preserves_structure(self, record_system): + """Test that anonymization preserves basic functionality.""" + ledger = record_system.ledger + entry = ledger.append( + "decision", + { + "user_email": "sensitive@example.com", + "action": "search", + "metadata": {"ip_address": "10.0.0.1"}, + }, + ) + + result = record_system.anonymize_record(entry.entry_id) + + # Anonymization should complete successfully + assert result["status"] == "anonymized" + assert "anonymized_at" in result diff --git a/tests/test_article_14_oversight.py b/tests/test_article_14_oversight.py new file mode 100644 index 0000000..c84031c --- /dev/null +++ b/tests/test_article_14_oversight.py @@ -0,0 +1,681 @@ +"""Tests for EU AI Act Article 14 - Human Oversight Evidence System.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from lexecon.compliance.eu_ai_act.article_14_oversight import ( + EscalationPath, + HumanIntervention, + HumanOversightEvidence, + InterventionType, + OversightRole, +) +from lexecon.identity.signing import KeyManager + + +class TestInterventionType: + """Tests for InterventionType enum.""" + + def test_intervention_types_exist(self): + """Test that all intervention types are defined.""" + assert 
InterventionType.APPROVAL.value == "approval" + assert InterventionType.OVERRIDE.value == "override" + assert InterventionType.ESCALATION.value == "escalation" + assert InterventionType.EMERGENCY_STOP.value == "emergency_stop" + assert InterventionType.POLICY_EXCEPTION.value == "policy_exception" + assert InterventionType.MANUAL_REVIEW.value == "manual_review" + + +class TestOversightRole: + """Tests for OversightRole enum.""" + + def test_oversight_roles_exist(self): + """Test that all oversight roles are defined.""" + assert OversightRole.COMPLIANCE_OFFICER.value == "compliance_officer" + assert OversightRole.SECURITY_LEAD.value == "security_lead" + assert OversightRole.LEGAL_COUNSEL.value == "legal_counsel" + assert OversightRole.RISK_MANAGER.value == "risk_manager" + assert OversightRole.EXECUTIVE.value == "executive" + assert OversightRole.SOC_ANALYST.value == "soc_analyst" + + +class TestHumanOversightEvidence: + """Tests for HumanOversightEvidence class.""" + + @pytest.fixture + def oversight(self): + """Create oversight evidence system.""" + return HumanOversightEvidence() + + @pytest.fixture + def oversight_with_km(self): + """Create oversight system with specific key manager.""" + km = KeyManager.generate() + return HumanOversightEvidence(key_manager=km) + + def test_initialization(self, oversight): + """Test oversight system initialization.""" + assert oversight.key_manager is not None + assert isinstance(oversight.interventions, list) + assert len(oversight.interventions) == 0 + assert isinstance(oversight.escalation_paths, dict) + + def test_default_escalation_paths(self, oversight): + """Test that default escalation paths are created.""" + assert "high_risk" in oversight.escalation_paths + assert "financial" in oversight.escalation_paths + assert "legal" in oversight.escalation_paths + assert "operational" in oversight.escalation_paths + + def test_high_risk_escalation_path(self, oversight): + """Test high risk escalation path configuration.""" + path 
= oversight.escalation_paths["high_risk"] + + assert path.decision_class == "high_risk" + assert OversightRole.SOC_ANALYST in path.roles + assert OversightRole.SECURITY_LEAD in path.roles + assert OversightRole.EXECUTIVE in path.roles + assert path.max_response_time_minutes == 15 + assert path.requires_approval_from == OversightRole.SECURITY_LEAD + + def test_log_intervention_approval(self, oversight): + """Test logging an approval intervention.""" + ai_rec = {"action": "allow", "confidence": 0.95} + human_dec = {"action": "allow", "approved": True} + + intervention = oversight.log_intervention( + intervention_type=InterventionType.APPROVAL, + ai_recommendation=ai_rec, + human_decision=human_dec, + human_role=OversightRole.COMPLIANCE_OFFICER, + reason="AI recommendation aligns with policy", + ) + + assert intervention.intervention_type == InterventionType.APPROVAL + assert intervention.ai_recommendation == ai_rec + assert intervention.human_decision == human_dec + assert intervention.human_role == OversightRole.COMPLIANCE_OFFICER + assert intervention.ai_confidence == 0.95 + assert intervention.signature is not None + + def test_log_intervention_override(self, oversight): + """Test logging an override intervention.""" + ai_rec = {"action": "deny", "confidence": 0.85} + human_dec = {"action": "allow", "overridden": True} + + intervention = oversight.log_intervention( + intervention_type=InterventionType.OVERRIDE, + ai_recommendation=ai_rec, + human_decision=human_dec, + human_role=OversightRole.SECURITY_LEAD, + reason="Business context requires exception", + ) + + assert intervention.intervention_type == InterventionType.OVERRIDE + assert intervention.ai_recommendation["action"] != intervention.human_decision["action"] + + def test_log_intervention_with_context(self, oversight): + """Test logging intervention with request context.""" + context = {"user_id": "user123", "request_id": "req456"} + + intervention = oversight.log_intervention( + 
intervention_type=InterventionType.MANUAL_REVIEW, + ai_recommendation={"action": "allow"}, + human_decision={"action": "allow"}, + human_role=OversightRole.RISK_MANAGER, + reason="Required review completed", + request_context=context, + ) + + assert intervention.request_context == context + + def test_log_intervention_with_response_time(self, oversight): + """Test logging intervention with response time.""" + intervention = oversight.log_intervention( + intervention_type=InterventionType.APPROVAL, + ai_recommendation={"action": "allow"}, + human_decision={"action": "allow"}, + human_role=OversightRole.SOC_ANALYST, + reason="Approved", + response_time_ms=1500, + ) + + assert intervention.response_time_ms == 1500 + + def test_intervention_has_unique_id(self, oversight): + """Test that interventions get unique IDs.""" + int1 = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Reason 1", + ) + + int2 = oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny"}, + {"action": "allow"}, + OversightRole.SECURITY_LEAD, + "Reason 2", + ) + + assert int1.intervention_id != int2.intervention_id + + def test_intervention_has_timestamp(self, oversight): + """Test that interventions have timestamps.""" + intervention = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + ) + + assert intervention.timestamp is not None + # Should be ISO format with Z suffix + assert intervention.timestamp.endswith("Z") + + # Should be parseable + dt = datetime.fromisoformat(intervention.timestamp.replace("Z", "+00:00")) + assert isinstance(dt, datetime) + + def test_intervention_is_signed_by_default(self, oversight): + """Test that interventions are signed by default.""" + intervention = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, 
+ OversightRole.COMPLIANCE_OFFICER, + "Approved", + ) + + assert intervention.signature is not None + assert len(intervention.signature) > 0 + + def test_intervention_can_be_unsigned(self, oversight): + """Test that interventions can be created without signatures.""" + intervention = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + sign=False, + ) + + assert intervention.signature is None + + def test_verify_intervention_valid(self, oversight): + """Test verifying a valid intervention.""" + intervention = oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny"}, + {"action": "allow"}, + OversightRole.SECURITY_LEAD, + "Override required", + ) + + is_valid = oversight.verify_intervention(intervention) + assert is_valid is True + + def test_verify_intervention_invalid_signature(self, oversight): + """Test that tampered intervention fails verification.""" + intervention = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + ) + + # Tamper with the intervention + intervention.reason = "TAMPERED" + + is_valid = oversight.verify_intervention(intervention) + assert is_valid is False + + def test_verify_intervention_no_signature(self, oversight): + """Test verifying intervention without signature.""" + intervention = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + sign=False, + ) + + is_valid = oversight.verify_intervention(intervention) + assert is_valid is False + + def test_multiple_interventions_stored(self, oversight): + """Test that multiple interventions are stored.""" + for i in range(5): + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow", "index": i}, + {"action": "allow"}, + 
OversightRole.COMPLIANCE_OFFICER, + f"Approved {i}", + ) + + assert len(oversight.interventions) == 5 + + def test_generate_oversight_effectiveness_report(self, oversight): + """Test generating oversight effectiveness report.""" + # Log some interventions with decision fields for override detection + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow", "confidence": 0.9, "decision": "allow"}, + {"action": "allow", "decision": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + response_time_ms=1000, + ) + + oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny", "confidence": 0.8, "decision": "deny"}, + {"action": "allow", "decision": "allow"}, + OversightRole.SECURITY_LEAD, + "Override needed", + response_time_ms=2000, + ) + + report = oversight.generate_oversight_effectiveness_report(time_period_days=30) + + assert "total_interventions" in report + assert "intervention_breakdown" in report + assert "effectiveness_metrics" in report + assert "response_time_metrics" in report + assert "compliance_assessment" in report + + def test_effectiveness_report_calculates_override_rate(self, oversight): + """Test that report calculates override rate correctly.""" + # 3 approvals, 2 overrides = 40% override rate + for _ in range(3): + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow", "decision": "allow"}, + {"action": "allow", "decision": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + ) + + for _ in range(2): + oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny", "decision": "deny"}, + {"action": "allow", "decision": "allow"}, + OversightRole.SECURITY_LEAD, + "Override", + ) + + report = oversight.generate_oversight_effectiveness_report() + + assert report["total_interventions"] == 5 + assert report["effectiveness_metrics"]["override_rate_percent"] == 40.0 # 2/5 = 40% + + def test_effectiveness_report_average_response_time(self, oversight): + """Test 
that report calculates average response time.""" + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + response_time_ms=1000, + ) + + oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny"}, + {"action": "allow"}, + OversightRole.SECURITY_LEAD, + "Override", + response_time_ms=3000, + ) + + report = oversight.generate_oversight_effectiveness_report() + + # Average of 1000 and 3000 is 2000 + assert report["response_time_metrics"]["average_ms"] == 2000.0 + + def test_effectiveness_report_intervention_breakdown(self, oversight): + """Test that report breaks down interventions by type.""" + oversight.log_intervention( + InterventionType.APPROVAL, {}, {}, OversightRole.COMPLIANCE_OFFICER, "A" + ) + oversight.log_intervention( + InterventionType.APPROVAL, {}, {}, OversightRole.COMPLIANCE_OFFICER, "B" + ) + oversight.log_intervention( + InterventionType.OVERRIDE, {}, {}, OversightRole.SECURITY_LEAD, "C" + ) + oversight.log_intervention( + InterventionType.EMERGENCY_STOP, {}, {}, OversightRole.EXECUTIVE, "D" + ) + + report = oversight.generate_oversight_effectiveness_report() + + breakdown = report["intervention_breakdown"] + assert breakdown["by_type"][InterventionType.APPROVAL.value] == 2 + assert breakdown["by_type"][InterventionType.OVERRIDE.value] == 1 + assert breakdown["by_type"][InterventionType.EMERGENCY_STOP.value] == 1 + + def test_get_escalation_path(self, oversight): + """Test getting escalation path for decision class.""" + path = oversight.get_escalation_path("high_risk") + + assert path is not None + assert path.decision_class == "high_risk" + + def test_get_escalation_path_nonexistent(self, oversight): + """Test getting non-existent escalation path.""" + path = oversight.get_escalation_path("nonexistent") + + assert path is None + + def test_export_evidence_package(self, oversight): + """Test exporting complete evidence package.""" + 
# Log some interventions + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + ) + + package = oversight.export_evidence_package() + + assert "summary" in package + assert "interventions" in package + assert "effectiveness_report" in package + assert "compliance_attestation" in package + + def test_export_evidence_package_metadata(self, oversight): + """Test evidence package metadata.""" + package = oversight.export_evidence_package() + + assert "generated_at" in package + assert "period" in package + assert "summary" in package + assert package["summary"]["total_interventions"] >= 0 + + def test_export_evidence_package_verification_proofs(self, oversight): + """Test that evidence package includes verification evidence.""" + intervention = oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny"}, + {"action": "allow"}, + OversightRole.SECURITY_LEAD, + "Override", + ) + + package = oversight.export_evidence_package() + + # Verification info is in the effectiveness report + evidence_integrity = package["effectiveness_report"]["evidence_integrity"] + assert evidence_integrity["all_signed"] is True + assert evidence_integrity["verification_rate"] == 100.0 + + def test_export_markdown(self, oversight): + """Test exporting evidence package as markdown.""" + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + ) + + package = oversight.export_evidence_package() + markdown = oversight.export_markdown(package) + + assert isinstance(markdown, str) + assert "# EU AI Act Article 14" in markdown + assert "Human Oversight Evidence" in markdown + + def test_export_markdown_includes_interventions(self, oversight): + """Test that markdown export includes intervention details.""" + oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny"}, + 
{"action": "allow"}, + OversightRole.SECURITY_LEAD, + "Override required for business context", + ) + + package = oversight.export_evidence_package() + markdown = oversight.export_markdown(package) + + assert "override:" in markdown + assert "security_lead:" in markdown + + def test_simulate_escalation(self, oversight): + """Test escalation simulation.""" + result = oversight.simulate_escalation( + decision_class="high_risk", + current_role=OversightRole.SOC_ANALYST, + ) + + assert "full_escalation_chain" in result + assert "current_role" in result + assert "can_approve" in result + + def test_simulate_escalation_chain(self, oversight): + """Test that escalation follows correct chain.""" + result = oversight.simulate_escalation( + decision_class="high_risk", + current_role=OversightRole.SOC_ANALYST, + ) + + # High risk escalates: SOC -> Security Lead -> Executive + chain = result["full_escalation_chain"] + assert OversightRole.SOC_ANALYST.value in chain + assert OversightRole.SECURITY_LEAD.value in chain + assert OversightRole.EXECUTIVE.value in chain + + def test_simulate_escalation_approval_authority(self, oversight): + """Test that escalation identifies who can approve.""" + # SOC analyst cannot approve high risk decisions + result1 = oversight.simulate_escalation( + decision_class="high_risk", + current_role=OversightRole.SOC_ANALYST, + ) + assert result1["can_approve"] is False + + # Security lead CAN approve high risk decisions + result2 = oversight.simulate_escalation( + decision_class="high_risk", + current_role=OversightRole.SECURITY_LEAD, + ) + assert result2["can_approve"] is True + + +class TestEscalationPath: + """Tests for EscalationPath dataclass.""" + + def test_create_escalation_path(self): + """Test creating escalation path.""" + path = EscalationPath( + decision_class="test_class", + roles=[OversightRole.COMPLIANCE_OFFICER, OversightRole.LEGAL_COUNSEL], + max_response_time_minutes=30, + requires_approval_from=OversightRole.LEGAL_COUNSEL, + ) + + 
assert path.decision_class == "test_class" + assert len(path.roles) == 2 + assert path.max_response_time_minutes == 30 + + def test_escalation_path_ordered_roles(self): + """Test that escalation path maintains role order.""" + roles = [ + OversightRole.SOC_ANALYST, + OversightRole.SECURITY_LEAD, + OversightRole.EXECUTIVE, + ] + + path = EscalationPath( + decision_class="ordered", + roles=roles, + max_response_time_minutes=15, + requires_approval_from=OversightRole.EXECUTIVE, + ) + + assert path.roles == roles + assert path.roles[0] == OversightRole.SOC_ANALYST + assert path.roles[-1] == OversightRole.EXECUTIVE + + +class TestHumanIntervention: + """Tests for HumanIntervention dataclass.""" + + def test_create_intervention(self): + """Test creating human intervention record.""" + intervention = HumanIntervention( + intervention_id="test_123", + timestamp="2025-01-01T00:00:00Z", + intervention_type=InterventionType.APPROVAL, + ai_recommendation={"action": "allow"}, + ai_confidence=0.95, + human_decision={"action": "allow"}, + human_role=OversightRole.COMPLIANCE_OFFICER, + request_context={"user": "test"}, + reason="Test approval", + ) + + assert intervention.intervention_id == "test_123" + assert intervention.intervention_type == InterventionType.APPROVAL + assert intervention.ai_confidence == 0.95 + + def test_intervention_optional_fields(self): + """Test intervention with optional fields.""" + intervention = HumanIntervention( + intervention_id="test", + timestamp="2025-01-01T00:00:00Z", + intervention_type=InterventionType.OVERRIDE, + ai_recommendation={}, + ai_confidence=0.8, + human_decision={}, + human_role=OversightRole.SECURITY_LEAD, + request_context={}, + reason="Test", + signature="test_signature", + response_time_ms=1500, + escalated_from="person_a", + escalated_to="person_b", + ) + + assert intervention.signature == "test_signature" + assert intervention.response_time_ms == 1500 + assert intervention.escalated_from == "person_a" + assert 
intervention.escalated_to == "person_b" + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_empty_ai_recommendation(self): + """Test handling empty AI recommendation.""" + oversight = HumanOversightEvidence() + + intervention = oversight.log_intervention( + InterventionType.MANUAL_REVIEW, + ai_recommendation={}, + human_decision={"action": "allow"}, + human_role=OversightRole.COMPLIANCE_OFFICER, + reason="Manual review", + ) + + assert intervention.ai_confidence == 0.0 # Default when no confidence + + def test_very_long_reason(self): + """Test intervention with very long reason.""" + oversight = HumanOversightEvidence() + + long_reason = "A" * 10000 + + intervention = oversight.log_intervention( + InterventionType.OVERRIDE, + {"action": "deny"}, + {"action": "allow"}, + OversightRole.SECURITY_LEAD, + long_reason, + ) + + assert intervention.reason == long_reason + + def test_effectiveness_report_no_interventions(self): + """Test effectiveness report with no interventions.""" + oversight = HumanOversightEvidence() + + report = oversight.generate_oversight_effectiveness_report() + + assert report["total_interventions"] == 0 + # When no interventions, report structure may be minimal + assert "intervention_breakdown" in report or "total_interventions" in report + + def test_effectiveness_report_no_response_times(self): + """Test effectiveness report when no response times recorded.""" + oversight = HumanOversightEvidence() + + oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved", + response_time_ms=None, + ) + + report = oversight.generate_oversight_effectiveness_report() + + # Should handle missing response times gracefully + assert "response_time_metrics" in report + + def test_unicode_in_reason(self): + """Test intervention with unicode characters.""" + oversight = HumanOversightEvidence() + + intervention = 
oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow"}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + "Approved 审批 ✓", + ) + + assert "审批" in intervention.reason + + def test_concurrent_interventions(self): + """Test logging multiple interventions rapidly.""" + oversight = HumanOversightEvidence() + + interventions = [] + for i in range(100): + intervention = oversight.log_intervention( + InterventionType.APPROVAL, + {"action": "allow", "index": i}, + {"action": "allow"}, + OversightRole.COMPLIANCE_OFFICER, + f"Approved {i}", + ) + interventions.append(intervention) + + # All should have unique IDs + ids = [i.intervention_id for i in interventions] + assert len(ids) == len(set(ids)) + + def test_export_with_no_data(self): + """Test exporting evidence package with no interventions.""" + oversight = HumanOversightEvidence() + + package = oversight.export_evidence_package() + + assert package["summary"]["total_interventions"] == 0 + assert len(package["interventions"]) == 0 diff --git a/tests/test_capability_tokens.py b/tests/test_capability_tokens.py new file mode 100644 index 0000000..efb65bf --- /dev/null +++ b/tests/test_capability_tokens.py @@ -0,0 +1,433 @@ +"""Tests for capability tokens.""" + +from datetime import datetime, timedelta + +import pytest + +from lexecon.capability.tokens import CapabilityToken, CapabilityTokenStore + + +class TestCapabilityToken: + """Tests for CapabilityToken class.""" + + def test_create_token(self): + """Test creating a capability token.""" + token = CapabilityToken.create( + action="search", tool="web_search", policy_version_hash="abc123" + ) + + assert token.token_id.startswith("tok_") + assert len(token.token_id) == 20 # "tok_" + 16 hex chars + assert token.scope["action"] == "search" + assert token.scope["tool"] == "web_search" + assert token.policy_version_hash == "abc123" + assert token.granted_at is not None + assert token.expiry > token.granted_at + + def 
class TestCapabilityToken:
    """Unit tests covering the CapabilityToken class."""

    def test_create_token(self):
        """A freshly created token carries scope, hash and timing fields."""
        token = CapabilityToken.create(
            action="search", tool="web_search", policy_version_hash="abc123"
        )

        assert token.token_id.startswith("tok_")
        assert len(token.token_id) == 20  # "tok_" + 16 hex chars
        assert token.scope["action"] == "search"
        assert token.scope["tool"] == "web_search"
        assert token.policy_version_hash == "abc123"
        assert token.granted_at is not None
        assert token.expiry > token.granted_at

    def test_create_token_with_custom_ttl(self):
        """A custom TTL moves the expiry to roughly now + ttl_minutes."""
        token = CapabilityToken.create(
            action="write", tool="database", policy_version_hash="xyz789", ttl_minutes=10
        )

        # Token should expire in 10 minutes
        expected_expiry = datetime.utcnow() + timedelta(minutes=10)
        drift = abs((token.expiry - expected_expiry).total_seconds())
        assert drift < 2  # Allow 2 seconds tolerance

    def test_token_is_valid_when_not_expired(self):
        """A token whose expiry is in the future reports valid."""
        token = CapabilityToken.create(
            action="read", tool="file_system", policy_version_hash="hash1", ttl_minutes=5
        )

        assert token.is_valid() is True

    def test_token_is_invalid_when_expired(self):
        """A token whose expiry has passed reports invalid."""
        now = datetime.utcnow()
        token = CapabilityToken(
            token_id="tok_expired123",
            scope={"action": "read", "tool": "file_system"},
            expiry=now - timedelta(minutes=1),  # Already expired
            policy_version_hash="hash1",
            granted_at=now - timedelta(minutes=10),
        )

        assert token.is_valid() is False

    def test_token_is_authorized_for_matching_action_and_tool(self):
        """Authorization succeeds when both action and tool match the scope."""
        token = CapabilityToken.create(
            action="search", tool="web_search", policy_version_hash="hash1"
        )

        assert token.is_authorized_for("search", "web_search") is True

    def test_token_not_authorized_for_different_action(self):
        """Authorization fails when the requested action differs."""
        token = CapabilityToken.create(
            action="search", tool="web_search", policy_version_hash="hash1"
        )

        assert token.is_authorized_for("write", "web_search") is False

    def test_token_not_authorized_for_different_tool(self):
        """Authorization fails when the requested tool differs."""
        token = CapabilityToken.create(
            action="search", tool="web_search", policy_version_hash="hash1"
        )

        assert token.is_authorized_for("search", "database") is False

    def test_expired_token_not_authorized(self):
        """An expired token is refused even when the scope matches."""
        now = datetime.utcnow()
        token = CapabilityToken(
            token_id="tok_expired456",
            scope={"action": "read", "tool": "file_system"},
            expiry=now - timedelta(minutes=1),  # Expired
            policy_version_hash="hash1",
            granted_at=now - timedelta(minutes=10),
        )

        assert token.is_authorized_for("read", "file_system") is False

    def test_token_serialization(self):
        """to_dict() exposes every token field, including the signature."""
        token = CapabilityToken.create(
            action="delete", tool="admin_panel", policy_version_hash="hash123"
        )
        token.signature = "test_signature"

        payload = token.to_dict()

        assert payload["token_id"] == token.token_id
        assert payload["scope"]["action"] == "delete"
        assert payload["scope"]["tool"] == "admin_panel"
        assert payload["policy_version_hash"] == "hash123"
        assert payload["signature"] == "test_signature"
        assert "expiry" in payload
        assert "granted_at" in payload

    def test_token_deserialization(self):
        """from_dict() rebuilds a token, parsing ISO timestamps."""
        now = datetime.utcnow()
        expiry = now + timedelta(minutes=5)

        payload = {
            "token_id": "tok_test123456789",
            "scope": {"action": "update", "tool": "config"},
            "expiry": expiry.isoformat(),
            "policy_version_hash": "hash999",
            "granted_at": now.isoformat(),
            "signature": "sig_abc",
        }

        token = CapabilityToken.from_dict(payload)

        assert token.token_id == "tok_test123456789"
        assert token.scope["action"] == "update"
        assert token.scope["tool"] == "config"
        assert token.policy_version_hash == "hash999"
        assert token.signature == "sig_abc"
        assert isinstance(token.expiry, datetime)
        assert isinstance(token.granted_at, datetime)

    def test_token_serialization_roundtrip(self):
        """A to_dict()/from_dict() round trip preserves the token data."""
        original = CapabilityToken.create(
            action="execute", tool="script_runner", policy_version_hash="hashABC"
        )
        original.signature = "test_sig_123"

        # Serialize and deserialize
        restored = CapabilityToken.from_dict(original.to_dict())

        assert restored.token_id == original.token_id
        assert restored.scope == original.scope
        assert restored.policy_version_hash == original.policy_version_hash
        assert restored.signature == original.signature
        # Time comparison with tolerance
        assert abs((restored.expiry - original.expiry).total_seconds()) < 1
        assert abs((restored.granted_at - original.granted_at).total_seconds()) < 1

    def test_token_deserialization_without_signature(self):
        """A payload missing the signature key yields signature None."""
        now = datetime.utcnow()
        payload = {
            "token_id": "tok_nosig",
            "scope": {"action": "read", "tool": "api"},
            "expiry": (now + timedelta(minutes=5)).isoformat(),
            "policy_version_hash": "hash1",
            "granted_at": now.isoformat(),
        }

        token = CapabilityToken.from_dict(payload)
        assert token.signature is None

    def test_different_tokens_have_unique_ids(self):
        """100 tokens created in a row all receive distinct IDs."""
        batch = [CapabilityToken.create("action", "tool", "hash") for _ in range(100)]

        token_ids = [t.token_id for t in batch]
        assert len(token_ids) == len(set(token_ids))  # All unique
class TestCapabilityTokenStore:
    """Unit tests covering the CapabilityTokenStore class."""

    def test_store_initialization(self):
        """A new store starts out with no tokens."""
        store = CapabilityTokenStore()
        assert len(store.tokens) == 0

    def test_store_token(self):
        """Storing a token makes it addressable by its ID."""
        store = CapabilityTokenStore()
        token = CapabilityToken.create("read", "api", "hash1")

        store.store(token)

        assert len(store.tokens) == 1
        assert token.token_id in store.tokens

    def test_get_existing_token(self):
        """get() returns the stored token with its scope intact."""
        store = CapabilityTokenStore()
        token = CapabilityToken.create("write", "database", "hash2")
        store.store(token)

        retrieved = store.get(token.token_id)

        assert retrieved is not None
        assert retrieved.token_id == token.token_id
        assert retrieved.scope == token.scope

    def test_get_nonexistent_token(self):
        """get() returns None for an unknown token ID."""
        store = CapabilityTokenStore()

        result = store.get("tok_nonexistent")

        assert result is None

    def test_verify_valid_token(self):
        """verify() accepts a stored, unexpired token with matching scope."""
        store = CapabilityTokenStore()
        token = CapabilityToken.create("search", "web_search", "hash3")
        store.store(token)

        assert store.verify(token.token_id, "search", "web_search") is True

    def test_verify_token_wrong_action(self):
        """verify() rejects a scope mismatch on the action."""
        store = CapabilityTokenStore()
        token = CapabilityToken.create("read", "file", "hash4")
        store.store(token)

        assert store.verify(token.token_id, "write", "file") is False

    def test_verify_token_wrong_tool(self):
        """verify() rejects a scope mismatch on the tool."""
        store = CapabilityTokenStore()
        token = CapabilityToken.create("execute", "script", "hash5")
        store.store(token)

        assert store.verify(token.token_id, "execute", "command") is False

    def test_verify_nonexistent_token(self):
        """verify() rejects a token ID that was never stored."""
        store = CapabilityTokenStore()

        assert store.verify("tok_fake123", "any", "any") is False

    def test_verify_expired_token(self):
        """verify() rejects an expired token even with a matching scope."""
        store = CapabilityTokenStore()
        now = datetime.utcnow()
        expired_token = CapabilityToken(
            token_id="tok_expired999",
            scope={"action": "read", "tool": "api"},
            expiry=now - timedelta(minutes=1),
            policy_version_hash="hash6",
            granted_at=now - timedelta(minutes=10),
        )
        store.store(expired_token)

        assert store.verify(expired_token.token_id, "read", "api") is False

    def test_cleanup_expired_tokens(self):
        """cleanup_expired() drops only the expired tokens."""
        store = CapabilityTokenStore()
        now = datetime.utcnow()

        # Create valid token
        valid_token = CapabilityToken.create("read", "api", "hash7")
        store.store(valid_token)

        # Create expired tokens
        for i in range(3):
            store.store(
                CapabilityToken(
                    token_id=f"tok_exp{i}",
                    scope={"action": "write", "tool": "db"},
                    expiry=now - timedelta(minutes=i + 1),
                    policy_version_hash="hash8",
                    granted_at=now - timedelta(minutes=10),
                )
            )

        # Should have 4 tokens total
        assert len(store.tokens) == 4

        # Cleanup expired
        removed_count = store.cleanup_expired()

        # Should remove 3 expired tokens
        assert removed_count == 3
        assert len(store.tokens) == 1
        assert valid_token.token_id in store.tokens

    def test_cleanup_with_no_expired_tokens(self):
        """cleanup_expired() is a no-op when every token is still valid."""
        store = CapabilityTokenStore()

        # Create only valid tokens
        for i in range(5):
            store.store(CapabilityToken.create(f"action{i}", "tool", f"hash{i}"))

        removed_count = store.cleanup_expired()

        assert removed_count == 0
        assert len(store.tokens) == 5

    def test_cleanup_empty_store(self):
        """cleanup_expired() on an empty store removes nothing."""
        store = CapabilityTokenStore()

        assert store.cleanup_expired() == 0

    def test_store_multiple_tokens(self):
        """Many stored tokens can all be retrieved individually."""
        store = CapabilityTokenStore()
        issued = []

        for i in range(10):
            token = CapabilityToken.create(f"action{i}", f"tool{i}", f"hash{i}")
            issued.append(token)
            store.store(token)

        assert len(store.tokens) == 10

        # Verify all can be retrieved
        for token in issued:
            retrieved = store.get(token.token_id)
            assert retrieved is not None
            assert retrieved.token_id == token.token_id

    def test_store_overwrites_existing_token(self):
        """Storing under an existing ID replaces the previous token."""
        store = CapabilityTokenStore()
        token1 = CapabilityToken.create("read", "api", "hash1")
        store.store(token1)

        # Create new token with same ID
        now = datetime.utcnow()
        token2 = CapabilityToken(
            token_id=token1.token_id,
            scope={"action": "write", "tool": "db"},
            expiry=now + timedelta(minutes=10),
            policy_version_hash="hash2",
            granted_at=now,
        )
        store.store(token2)

        # Should only have 1 token
        assert len(store.tokens) == 1

        # Should have the new token's scope
        retrieved = store.get(token1.token_id)
        assert retrieved.scope["action"] == "write"
        assert retrieved.policy_version_hash == "hash2"


class TestTokenEdgeCases:
    """Edge cases and error conditions for CapabilityToken."""

    def test_token_with_zero_ttl(self):
        """A zero TTL produces an immediately expired token."""
        token = CapabilityToken.create("action", "tool", "hash", ttl_minutes=0)

        assert token.is_valid() is False

    def test_token_with_negative_ttl(self):
        """A negative TTL also produces an expired token."""
        token = CapabilityToken.create("action", "tool", "hash", ttl_minutes=-5)

        assert token.is_valid() is False

    def test_token_scope_with_additional_fields(self):
        """Extra scope fields are preserved and don't break authorization."""
        now = datetime.utcnow()
        token = CapabilityToken(
            token_id="tok_extra",
            scope={
                "action": "read",
                "tool": "api",
                "resource": "/users/123",
                "method": "GET",
            },
            expiry=now + timedelta(minutes=5),
            policy_version_hash="hash1",
            granted_at=now,
        )

        # Basic authorization should still work
        assert token.is_authorized_for("read", "api") is True
        # Extra fields preserved
        assert token.scope["resource"] == "/users/123"
        assert token.scope["method"] == "GET"

    def test_token_with_empty_scope(self):
        """A token with an empty scope authorizes nothing."""
        now = datetime.utcnow()
        token = CapabilityToken(
            token_id="tok_empty",
            scope={},
            expiry=now + timedelta(minutes=5),
            policy_version_hash="hash1",
            granted_at=now,
        )

        assert token.is_authorized_for("action", "tool") is False

    def test_very_long_ttl(self):
        """A one-year TTL yields a valid token expiring ~365 days out."""
        token = CapabilityToken.create("action", "tool", "hash", ttl_minutes=525600)  # 1 year

        assert token.is_valid() is True

        # Expiry should be about 1 year from now
        expected = datetime.utcnow() + timedelta(days=365)
        drift = abs((token.expiry - expected).total_seconds())
        assert drift < 60  # Within 1 minute tolerance
class TestDecisionRequest:
    """Unit tests covering the DecisionRequest class."""

    def test_create_decision_request(self):
        """Defaults (risk level 1, strict mode, generated ID) are applied."""
        req = DecisionRequest(
            actor="model",
            proposed_action="search",
            tool="web_search",
            user_intent="Research AI governance",
        )

        assert req.actor == "model"
        assert req.proposed_action == "search"
        assert req.tool == "web_search"
        assert req.user_intent == "Research AI governance"
        assert req.risk_level == 1  # Default
        assert req.policy_mode == "strict"  # Default
        assert isinstance(req.request_id, str)
        assert len(req.request_id) > 0

    def test_request_with_custom_id(self):
        """A caller-supplied request ID is kept verbatim."""
        custom_id = "req_123456"
        req = DecisionRequest(
            request_id=custom_id,
            actor="user",
            proposed_action="read",
            tool="file_system",
            user_intent="Read file",
        )

        assert req.request_id == custom_id

    def test_request_generates_uuid_if_no_id(self):
        """Omitting the ID yields distinct, well-formed UUIDs."""
        first = DecisionRequest(
            actor="model", proposed_action="action1", tool="tool1", user_intent="intent1"
        )
        second = DecisionRequest(
            actor="model", proposed_action="action2", tool="tool2", user_intent="intent2"
        )

        # Should be different UUIDs
        assert first.request_id != second.request_id

        # Should be valid UUIDs
        uuid.UUID(first.request_id)
        uuid.UUID(second.request_id)

    def test_request_with_data_classes(self):
        """Data classes provided at construction are stored as-is."""
        req = DecisionRequest(
            actor="model",
            proposed_action="process",
            tool="analytics",
            user_intent="Analyze data",
            data_classes=["pii", "financial"],
        )

        assert req.data_classes == ["pii", "financial"]

    def test_request_with_high_risk_level(self):
        """A non-default risk level is stored as given."""
        req = DecisionRequest(
            actor="model",
            proposed_action="delete",
            tool="database",
            user_intent="Clean up",
            risk_level=5,
        )

        assert req.risk_level == 5

    def test_request_with_context(self):
        """Arbitrary context mappings are attached to the request."""
        ctx = {"session_id": "abc123", "ip": "192.168.1.1"}
        req = DecisionRequest(
            actor="user",
            proposed_action="login",
            tool="auth",
            user_intent="Login",
            context=ctx,
        )

        assert req.context == ctx

    def test_request_serialization(self):
        """to_dict() exposes every request field plus metadata."""
        req = DecisionRequest(
            actor="model",
            proposed_action="search",
            tool="web",
            user_intent="Search",
            data_classes=["public"],
            risk_level=2,
            context={"query": "test"},
        )

        payload = req.to_dict()

        assert payload["actor"] == "model"
        assert payload["proposed_action"] == "search"
        assert payload["tool"] == "web"
        assert payload["user_intent"] == "Search"
        assert payload["data_classes"] == ["public"]
        assert payload["risk_level"] == 2
        assert payload["context"]["query"] == "test"
        assert "timestamp" in payload
        assert "request_id" in payload

    def test_request_has_timestamp(self):
        """Every request carries an ISO-format timestamp."""
        req = DecisionRequest(
            actor="model", proposed_action="action", tool="tool", user_intent="intent"
        )

        assert hasattr(req, "timestamp")
        # Should be ISO format
        datetime.fromisoformat(req.timestamp)


class TestDecisionResponse:
    """Unit tests covering the DecisionResponse class."""

    def test_create_decision_response(self):
        """Core fields are stored as given on a new response."""
        resp = DecisionResponse(
            request_id="req_123",
            decision="permit",
            reasoning="Action is allowed by policy",
            policy_version_hash="abc123",
        )

        assert resp.request_id == "req_123"
        assert resp.decision == "permit"
        assert resp.reasoning == "Action is allowed by policy"
        assert resp.policy_version_hash == "abc123"

    def test_response_with_capability_token(self):
        """A capability token attached to the response is preserved."""
        token = {
            "token_id": "tok_abc123",
            "scope": {"action": "read", "tool": "api"},
            "expiry": "2025-01-02T12:00:00",
        }

        resp = DecisionResponse(
            request_id="req_456",
            decision="permit",
            reasoning="Permitted",
            policy_version_hash="hash1",
            capability_token=token,
        )

        assert resp.capability_token == token

    def test_response_with_ledger_entry(self):
        """A ledger entry hash attached to the response is preserved."""
        resp = DecisionResponse(
            request_id="req_789",
            decision="deny",
            reasoning="Forbidden by policy",
            policy_version_hash="hash2",
            ledger_entry_hash="ledger_hash_abc",
        )

        assert resp.ledger_entry_hash == "ledger_hash_abc"

    def test_response_with_signature(self):
        """A signature attached to the response is preserved."""
        resp = DecisionResponse(
            request_id="req_sig",
            decision="permit",
            reasoning="OK",
            policy_version_hash="hash3",
            signature="sig_base64_encoded",
        )

        assert resp.signature == "sig_base64_encoded"

    def test_decision_hash_generation(self):
        """decision_hash is a SHA256 hex digest string."""
        resp = DecisionResponse(
            request_id="req_hash",
            decision="permit",
            reasoning="Test",
            policy_version_hash="hash4",
        )

        digest = resp.decision_hash

        assert isinstance(digest, str)
        assert len(digest) == 64  # SHA256 hex

    def test_decision_hash_is_deterministic(self):
        """Identical hash inputs give identical digests regardless of reasoning."""
        ts = "2025-01-01T00:00:00"
        resp_a = DecisionResponse(
            request_id="req_det",
            decision="permit",
            reasoning="Test",
            policy_version_hash="hash5",
            timestamp=ts,
        )
        resp_b = DecisionResponse(
            request_id="req_det",
            decision="permit",
            reasoning="Different reasoning",  # Hash doesn't include reasoning
            policy_version_hash="hash5",
            timestamp=ts,
        )

        # Same request_id, decision, policy hash, timestamp -> same hash
        assert resp_a.decision_hash == resp_b.decision_hash

    def test_response_serialization(self):
        """to_dict() exposes all fields plus compatibility aliases."""
        resp = DecisionResponse(
            request_id="req_ser",
            decision="deny",
            reasoning="Not authorized",
            policy_version_hash="hash6",
            ledger_entry_hash="ledger123",
            signature="sig123",
        )

        payload = resp.to_dict()

        assert payload["request_id"] == "req_ser"
        assert payload["decision"] == "deny"
        assert payload["reasoning"] == "Not authorized"
        assert payload["reason"] == "Not authorized"  # Backwards compatibility
        assert payload["allowed"] is False  # deny -> allowed=False
        assert payload["policy_version_hash"] == "hash6"
        assert payload["ledger_entry_hash"] == "ledger123"
        assert payload["signature"] == "sig123"
        assert "timestamp" in payload

    def test_response_allowed_field_for_permit(self):
        """A permit decision serializes with allowed=True."""
        resp = DecisionResponse(
            request_id="req_permit",
            decision="permit",
            reasoning="OK",
            policy_version_hash="hash",
        )

        assert resp.to_dict()["allowed"] is True

    def test_response_allowed_field_for_deny(self):
        """A deny decision serializes with allowed=False."""
        resp = DecisionResponse(
            request_id="req_deny", decision="deny", reasoning="No", policy_version_hash="hash"
        )

        assert resp.to_dict()["allowed"] is False

    def test_response_signature_defaults_to_empty(self):
        """An unsigned response serializes with an empty signature."""
        resp = DecisionResponse(
            request_id="req_nosig", decision="permit", reasoning="OK", policy_version_hash="hash"
        )

        assert resp.to_dict()["signature"] == ""
be permitted.""" + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web_search", + user_intent="Search for information", + ) + + response = service.evaluate_request(request) + + assert response.decision == "permit" + assert response.request_id == request.request_id + assert response.capability_token is not None + assert "token_id" in response.capability_token + + def test_evaluate_request_deny(self, service): + """Test evaluating a request that should be denied.""" + request = DecisionRequest( + actor="model", + proposed_action="delete", # Not permitted + tool="database", + user_intent="Delete records", + ) + + response = service.evaluate_request(request) + + assert response.decision == "deny" + assert response.capability_token is None + + def test_capability_token_generation(self, service): + """Test that capability token is generated for permitted actions.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web_search", user_intent="Search" + ) + + response = service.evaluate_request(request) + + token = response.capability_token + assert token is not None + assert token["scope"]["action"] == "search" + assert token["scope"]["tool"] == "web_search" + assert "expiry" in token + assert "granted_at" in token + + def test_ledger_entry_creation(self, full_service): + """Test that ledger entry is created.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = full_service.evaluate_request(request) + + assert response.ledger_entry_hash is not None + assert len(response.ledger_entry_hash) == 64 # SHA256 hex + + # Verify entry exists in ledger by hash + entry = full_service.ledger.get_entry(response.ledger_entry_hash) + assert entry is not None + assert entry.event_type == "decision" + + def test_signature_generation(self, full_service): + """Test that decision is signed.""" + request = DecisionRequest( + actor="model", 
proposed_action="search", tool="web", user_intent="Test" + ) + + response = full_service.evaluate_request(request) + + assert response.signature is not None + assert len(response.signature) > 0 + + # Verify it's valid base64 + import base64 + + decoded = base64.b64decode(response.signature) + assert len(decoded) > 0 + + def test_evaluate_convenience_method_simple(self, policy_engine): + """Test simple evaluate method.""" + service = DecisionService(policy_engine) + + # Simple policy evaluation + result = service.evaluate(actor="model", action="search") + + # Should return PolicyDecision object + assert hasattr(result, "allowed") + assert hasattr(result, "reason") + + def test_evaluate_convenience_method_full(self, full_service): + """Test evaluate method with full parameters.""" + # Full decision request + result = full_service.evaluate( + actor="model", + proposed_action="search", + tool="web_search", + user_intent="Testing", + risk_level=1, + ) + + # Should return DecisionResponse + assert hasattr(result, "decision") + assert hasattr(result, "capability_token") + assert hasattr(result, "ledger_entry_hash") + + def test_multiple_decisions_logged(self, full_service): + """Test that multiple decisions are logged to ledger.""" + initial_count = len(full_service.ledger.entries) + + for i in range(5): + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent=f"Search {i}", + ) + full_service.evaluate_request(request) + + final_count = len(full_service.ledger.entries) + assert final_count == initial_count + 5 + + def test_decision_with_data_classes(self, policy_engine): + """Test decision evaluation with data classes.""" + service = DecisionService(policy_engine) + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent="Search", + data_classes=["public"], + ) + + response = service.evaluate_request(request) + + # Should still work + assert response.decision in ["permit", "deny"] + + 
def test_decision_with_high_risk_level(self, service): + """Test decision with high risk level.""" + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent="Search", + risk_level=5, + ) + + response = service.evaluate_request(request) + + # Should still evaluate + assert response.decision in ["permit", "deny"] + + def test_policy_version_hash_in_response(self, service): + """Test that response includes policy version hash.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.policy_version_hash is not None + assert len(response.policy_version_hash) == 64 # SHA256 hex + + def test_reasoning_in_response(self, service): + """Test that response includes reasoning.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.reasoning is not None + assert len(response.reasoning) > 0 + + +class TestDecisionWorkflow: + """Integration tests for complete decision workflow.""" + + def test_complete_workflow(self): + """Test complete decision workflow from request to signed response.""" + # Setup + engine = PolicyEngine(mode=PolicyMode.STRICT) + engine.add_term(PolicyTerm.create_actor("user", "User")) + engine.add_term(PolicyTerm.create_action("read", "Read")) + engine.add_relation(PolicyRelation.permits("actor:user", "action:read")) + + ledger = LedgerChain() + identity = NodeIdentity("governance-node") + service = DecisionService(engine, ledger, identity) + + # Make decision + request = DecisionRequest( + actor="user", proposed_action="read", tool="file_system", user_intent="Read file" + ) + + response = service.evaluate_request(request) + + # Verify all components + assert response.decision == "permit" + assert response.capability_token is not None + assert response.ledger_entry_hash is 
not None + assert response.signature is not None + + # Verify ledger integrity + assert ledger.verify_integrity()["valid"] is True + + # Verify signature + is_valid = identity.verify_signature(response.decision_hash, response.signature) + assert is_valid is True + + def test_deny_workflow(self): + """Test workflow for denied decision.""" + # Strict mode - deny by default + engine = PolicyEngine(mode=PolicyMode.STRICT) + ledger = LedgerChain() + identity = NodeIdentity("test-node") + service = DecisionService(engine, ledger, identity) + + request = DecisionRequest( + actor="unknown", proposed_action="forbidden", tool="admin", user_intent="Test" + ) + + response = service.evaluate_request(request) + + # Should be denied + assert response.decision == "deny" + assert response.capability_token is None # No token for deny + assert response.ledger_entry_hash is not None # Still logged + assert response.signature is not None # Still signed + + def test_decision_audit_trail(self): + """Test that decisions create proper audit trail.""" + engine = PolicyEngine(mode=PolicyMode.STRICT) + engine.add_term(PolicyTerm.create_actor("bot", "Bot")) + engine.add_term(PolicyTerm.create_action("execute", "Execute")) + engine.add_relation(PolicyRelation.permits("actor:bot", "action:execute")) + + ledger = LedgerChain() + identity = NodeIdentity("audit-node") + service = DecisionService(engine, ledger, identity) + + # Make multiple decisions + requests = [ + DecisionRequest( + actor="bot", proposed_action="execute", tool=f"tool{i}", user_intent=f"Task {i}" + ) + for i in range(3) + ] + + responses = [service.evaluate_request(req) for req in requests] + + # Check audit report + report = ledger.generate_audit_report() + + assert report["total_entries"] >= 4 # Genesis + 3 decisions + assert "decision" in report["event_type_counts"] + assert report["event_type_counts"]["decision"] == 3 + + # Verify all entries + decision_entries = ledger.get_entries_by_type("decision") + assert 
len(decision_entries) == 3 + + def test_concurrent_decisions(self): + """Test handling multiple concurrent decisions.""" + engine = PolicyEngine(mode=PolicyMode.PERMISSIVE) + ledger = LedgerChain() + identity = NodeIdentity("concurrent-node") + service = DecisionService(engine, ledger, identity) + + # Simulate concurrent requests + requests = [ + DecisionRequest( + actor=f"actor{i}", proposed_action="action", tool="tool", user_intent="Intent" + ) + for i in range(10) + ] + + responses = [service.evaluate_request(req) for req in requests] + + # All should succeed + assert len(responses) == 10 + assert all(r.ledger_entry_hash is not None for r in responses) + + # All hashes should be unique + hashes = [r.ledger_entry_hash for r in responses] + assert len(hashes) == len(set(hashes)) + + # Ledger should still be valid + assert ledger.verify_integrity()["valid"] is True + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_empty_user_intent(self): + """Test request with empty user intent.""" + engine = PolicyEngine() + service = DecisionService(engine) + + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="" + ) + + response = service.evaluate_request(request) + assert response is not None + + def test_very_long_user_intent(self): + """Test request with very long user intent.""" + engine = PolicyEngine() + service = DecisionService(engine) + + long_intent = "A" * 10000 + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent=long_intent + ) + + response = service.evaluate_request(request) + assert response is not None + + def test_special_characters_in_fields(self): + """Test request with special characters.""" + engine = PolicyEngine() + service = DecisionService(engine) + + request = DecisionRequest( + actor="model-v2.0", + proposed_action="read/write", + tool="tool_@#$", + user_intent="Test with 特殊文字 and émojis 🔒", + ) + + response = 
service.evaluate_request(request) + assert response is not None + + def test_decision_without_ledger(self): + """Test that service works without ledger.""" + engine = PolicyEngine() + service = DecisionService(engine) # No ledger + + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.ledger_entry_hash is None + assert response.decision in ["permit", "deny"] + + def test_decision_without_identity(self): + """Test that service works without identity.""" + engine = PolicyEngine() + ledger = LedgerChain() + service = DecisionService(engine, ledger) # No identity + + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.signature is None + assert response.ledger_entry_hash is not None # Ledger still works diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..0b02954 --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,513 @@ +"""Tests for health check and observability functionality.""" + +import time + +import pytest + +from lexecon.observability.health import ( + HealthCheck, + HealthStatus, + check_identity, + check_ledger, + check_policy_engine, +) + + +class TestHealthStatus: + """Tests for HealthStatus enum.""" + + def test_health_status_values(self): + """Test that health status enum has expected values.""" + assert HealthStatus.HEALTHY == "healthy" + assert HealthStatus.DEGRADED == "degraded" + assert HealthStatus.UNHEALTHY == "unhealthy" + + def test_health_status_is_string(self): + """Test that health status values are strings.""" + assert isinstance(HealthStatus.HEALTHY, str) + assert isinstance(HealthStatus.DEGRADED, str) + assert isinstance(HealthStatus.UNHEALTHY, str) + + +class TestHealthCheck: + """Tests for HealthCheck class.""" + + def test_initialization(self): + """Test 
health check initialization.""" + hc = HealthCheck() + + assert hc.checks is not None + assert hc.start_time > 0 + assert isinstance(hc.checks, dict) + + def test_liveness_probe(self): + """Test liveness probe returns healthy status.""" + hc = HealthCheck() + + result = hc.liveness() + + assert result["status"] == HealthStatus.HEALTHY + assert "timestamp" in result + assert "uptime_seconds" in result + assert result["uptime_seconds"] >= 0 + + def test_liveness_uptime_increases(self): + """Test that uptime increases over time.""" + hc = HealthCheck() + + result1 = hc.liveness() + time.sleep(0.1) + result2 = hc.liveness() + + assert result2["uptime_seconds"] > result1["uptime_seconds"] + + def test_readiness_probe_no_checks(self): + """Test readiness probe with no health checks registered.""" + hc = HealthCheck() + hc.checks = {} # Clear default checks + + result = hc.readiness() + + assert result["status"] == HealthStatus.HEALTHY + assert "timestamp" in result + assert "checks" in result + assert len(result["checks"]) == 0 + + def test_readiness_probe_all_healthy(self): + """Test readiness probe when all checks are healthy.""" + hc = HealthCheck() + hc.checks = {} # Clear default checks + + # Add healthy checks + hc.add_check("check1", lambda: (HealthStatus.HEALTHY, {"detail": "ok"})) + hc.add_check("check2", lambda: (HealthStatus.HEALTHY, {"detail": "ok"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.HEALTHY + assert len(result["checks"]) == 2 + assert all(c["status"] == HealthStatus.HEALTHY for c in result["checks"]) + + def test_readiness_probe_one_degraded(self): + """Test readiness probe with one degraded check.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("degraded", lambda: (HealthStatus.DEGRADED, {"reason": "slow"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.DEGRADED + assert len(result["checks"]) == 2 + + def 
test_readiness_probe_one_unhealthy(self): + """Test readiness probe with one unhealthy check.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("unhealthy", lambda: (HealthStatus.UNHEALTHY, {"error": "failed"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.UNHEALTHY + assert len(result["checks"]) == 2 + + # Find unhealthy check + unhealthy_check = next(c for c in result["checks"] if c["name"] == "unhealthy") + assert unhealthy_check["details"]["error"] == "failed" + + def test_readiness_unhealthy_takes_precedence(self): + """Test that unhealthy status takes precedence over degraded.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("degraded", lambda: (HealthStatus.DEGRADED, {})) + hc.add_check("unhealthy", lambda: (HealthStatus.UNHEALTHY, {})) + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.UNHEALTHY + + def test_readiness_handles_check_exception(self): + """Test readiness probe handles exceptions in health checks.""" + hc = HealthCheck() + hc.checks = {} + + def failing_check(): + raise RuntimeError("Check failed") + + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("failing", failing_check) + + result = hc.readiness() + + # Overall status should be unhealthy + assert result["status"] == HealthStatus.UNHEALTHY + + # Failing check should show error + failing_result = next(c for c in result["checks"] if c["name"] == "failing") + assert failing_result["status"] == HealthStatus.UNHEALTHY + assert "error" in failing_result["details"] + assert "Check failed" in failing_result["details"]["error"] + + def test_add_check(self): + """Test adding custom health check.""" + hc = HealthCheck() + hc.checks = {} + + def custom_check(): + return HealthStatus.HEALTHY, {"custom": "data"} + + hc.add_check("custom", custom_check) + + assert "custom" in hc.checks + 
assert hc.checks["custom"] == custom_check + + def test_add_multiple_checks(self): + """Test adding multiple health checks.""" + hc = HealthCheck() + hc.checks = {} + + for i in range(5): + hc.add_check(f"check{i}", lambda: (HealthStatus.HEALTHY, {})) + + assert len(hc.checks) == 5 + + def test_startup_probe(self): + """Test startup probe.""" + hc = HealthCheck() + + result = hc.startup() + + assert result["status"] == HealthStatus.HEALTHY + assert "timestamp" in result + assert "message" in result + assert result["message"] == "Service initialized" + + def test_readiness_check_details(self): + """Test that readiness probe includes check details.""" + hc = HealthCheck() + hc.checks = {} + + details = {"version": "1.0", "connections": 5} + hc.add_check("detailed", lambda: (HealthStatus.HEALTHY, details)) + + result = hc.readiness() + + check_result = result["checks"][0] + assert check_result["name"] == "detailed" + assert check_result["details"] == details + + def test_multiple_unhealthy_checks(self): + """Test handling multiple unhealthy checks.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("fail1", lambda: (HealthStatus.UNHEALTHY, {"error": "error1"})) + hc.add_check("fail2", lambda: (HealthStatus.UNHEALTHY, {"error": "error2"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.UNHEALTHY + assert len(result["checks"]) == 2 + assert all(c["status"] == HealthStatus.UNHEALTHY for c in result["checks"]) + + +class TestDefaultHealthChecks: + """Tests for default health check functions.""" + + def test_check_policy_engine(self): + """Test policy engine health check.""" + status, details = check_policy_engine() + + assert status == HealthStatus.HEALTHY + assert "policies_loaded" in details + assert isinstance(details["policies_loaded"], int) + + def test_check_ledger(self): + """Test ledger health check.""" + status, details = check_ledger() + + assert status == HealthStatus.HEALTHY + assert "entries" in details + assert "last_verified" 
in details + assert isinstance(details["entries"], int) + assert isinstance(details["last_verified"], float) + + def test_check_identity(self): + """Test identity health check.""" + status, details = check_identity() + + assert status == HealthStatus.HEALTHY + assert "key_loaded" in details + assert isinstance(details["key_loaded"], bool) + + +class TestHealthCheckIntegration: + """Integration tests for health check system.""" + + def test_default_health_check_instance(self): + """Test that default health check instance has checks registered.""" + from lexecon.observability.health import health_check + + # Should have default checks registered + assert len(health_check.checks) > 0 + assert "policy_engine" in health_check.checks + assert "ledger" in health_check.checks + assert "identity" in health_check.checks + + def test_full_readiness_check(self): + """Test full readiness check with default checks.""" + from lexecon.observability.health import health_check + + result = health_check.readiness() + + assert "status" in result + assert "checks" in result + assert len(result["checks"]) >= 3 # At least the 3 default checks + + def test_liveness_timestamp_is_recent(self): + """Test that liveness timestamp is recent.""" + hc = HealthCheck() + + result = hc.liveness() + + # Timestamp should be within last second + now = time.time() + assert abs(now - result["timestamp"]) < 1.0 + + def test_readiness_timestamp_is_recent(self): + """Test that readiness timestamp is recent.""" + hc = HealthCheck() + + result = hc.readiness() + + now = time.time() + assert abs(now - result["timestamp"]) < 1.0 + + +class TestHealthCheckEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_check_returning_none(self): + """Test handling check that returns None.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("none_check", lambda: None) + + # Should handle by catching exception in readiness check + result = hc.readiness() + # Should mark as unhealthy due to 
exception + assert result["status"] == HealthStatus.UNHEALTHY + + def test_check_returning_invalid_status(self): + """Test handling check with invalid status.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("invalid", lambda: ("invalid_status", {})) + + result = hc.readiness() + + # Should still run, but status might not be recognized + assert result is not None + + def test_check_with_empty_details(self): + """Test check with empty details.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("empty", lambda: (HealthStatus.HEALTHY, {})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.HEALTHY + check = result["checks"][0] + assert check["details"] == {} + + def test_check_with_complex_details(self): + """Test check with complex nested details.""" + hc = HealthCheck() + hc.checks = {} + + complex_details = { + "metrics": {"cpu": 45.2, "memory": 1024}, + "connections": [{"id": 1, "status": "active"}, {"id": 2, "status": "idle"}], + "metadata": {"version": "1.0", "uptime": 3600}, + } + + hc.add_check("complex", lambda: (HealthStatus.HEALTHY, complex_details)) + + result = hc.readiness() + + check = result["checks"][0] + assert check["details"] == complex_details + + def test_overwrite_check(self): + """Test that adding check with same name overwrites.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("check", lambda: (HealthStatus.HEALTHY, {"version": 1})) + hc.add_check("check", lambda: (HealthStatus.DEGRADED, {"version": 2})) + + result = hc.readiness() + + # Should only have 1 check + assert len(result["checks"]) == 1 + + # Should be the second one + assert result["checks"][0]["status"] == HealthStatus.DEGRADED + assert result["checks"][0]["details"]["version"] == 2 + + def test_check_with_very_long_execution_time(self): + """Test check that takes long to execute.""" + hc = HealthCheck() + hc.checks = {} + + def slow_check(): + time.sleep(0.2) + return HealthStatus.HEALTHY, {"slow": True} + + hc.add_check("slow", 
slow_check) + + start = time.time() + result = hc.readiness() + duration = time.time() - start + + # Should still complete + assert result["status"] == HealthStatus.HEALTHY + assert duration >= 0.2 + + def test_many_checks(self): + """Test with many health checks.""" + hc = HealthCheck() + hc.checks = {} + + # Add 100 checks + for i in range(100): + hc.add_check(f"check{i}", lambda: (HealthStatus.HEALTHY, {})) + + result = hc.readiness() + + assert len(result["checks"]) == 100 + assert result["status"] == HealthStatus.HEALTHY + + def test_mixed_status_priority(self): + """Test status priority with all three statuses.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("healthy1", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("healthy2", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("degraded1", lambda: (HealthStatus.DEGRADED, {})) + hc.add_check("unhealthy1", lambda: (HealthStatus.UNHEALTHY, {})) + + result = hc.readiness() + + # Unhealthy should take precedence + assert result["status"] == HealthStatus.UNHEALTHY + + def test_only_degraded_checks(self): + """Test with only degraded checks.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("deg1", lambda: (HealthStatus.DEGRADED, {})) + hc.add_check("deg2", lambda: (HealthStatus.DEGRADED, {})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.DEGRADED + + +class TestHealthCheckConcurrency: + """Tests for concurrent health check scenarios.""" + + def test_multiple_liveness_calls(self): + """Test multiple simultaneous liveness calls.""" + hc = HealthCheck() + + results = [hc.liveness() for _ in range(10)] + + # All should succeed + assert len(results) == 10 + assert all(r["status"] == HealthStatus.HEALTHY for r in results) + + def test_multiple_readiness_calls(self): + """Test multiple simultaneous readiness calls.""" + hc = HealthCheck() + hc.checks = {} + hc.add_check("test", lambda: (HealthStatus.HEALTHY, {})) + + results = [hc.readiness() for _ in range(10)] + + # All 
should succeed + assert len(results) == 10 + assert all(r["status"] == HealthStatus.HEALTHY for r in results) + + def test_concurrent_check_modifications(self): + """Test that checks can be added during operation.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("initial", lambda: (HealthStatus.HEALTHY, {})) + + result1 = hc.readiness() + assert len(result1["checks"]) == 1 + + hc.add_check("added", lambda: (HealthStatus.HEALTHY, {})) + + result2 = hc.readiness() + assert len(result2["checks"]) == 2 + + +class TestHealthCheckSerialization: + """Tests for health check result serialization.""" + + def test_liveness_result_is_dict(self): + """Test that liveness result is JSON-serializable.""" + hc = HealthCheck() + result = hc.liveness() + + import json + + # Should be serializable to JSON + json_str = json.dumps(result) + assert len(json_str) > 0 + + # Should be deserializable + parsed = json.loads(json_str) + assert parsed["status"] == HealthStatus.HEALTHY + + def test_readiness_result_is_dict(self): + """Test that readiness result is JSON-serializable.""" + hc = HealthCheck() + hc.checks = {} + hc.add_check("test", lambda: (HealthStatus.HEALTHY, {"count": 5})) + + result = hc.readiness() + + import json + + json_str = json.dumps(result) + parsed = json.loads(json_str) + + assert parsed["status"] == HealthStatus.HEALTHY + assert len(parsed["checks"]) == 1 + + def test_startup_result_is_dict(self): + """Test that startup result is JSON-serializable.""" + hc = HealthCheck() + result = hc.startup() + + import json + + json_str = json.dumps(result) + parsed = json.loads(json_str) + + assert parsed["status"] == HealthStatus.HEALTHY diff --git a/tests/test_identity.py b/tests/test_identity.py new file mode 100644 index 0000000..8af89d9 --- /dev/null +++ b/tests/test_identity.py @@ -0,0 +1,548 @@ +"""Tests for identity and signing functionality.""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from lexecon.identity.signing import 
KeyManager, NodeIdentity + + +class TestKeyManager: + """Tests for KeyManager class.""" + + def test_generate_key_pair(self): + """Test generating a new Ed25519 key pair.""" + km = KeyManager.generate() + + assert km.private_key is not None + assert km.public_key is not None + + def test_different_key_pairs_are_unique(self): + """Test that multiple generated key pairs are different.""" + km1 = KeyManager.generate() + km2 = KeyManager.generate() + + # Get fingerprints to compare + fp1 = km1.get_public_key_fingerprint() + fp2 = km2.get_public_key_fingerprint() + + assert fp1 != fp2 + + def test_save_keys_to_disk(self): + """Test saving keys to disk in PEM format.""" + km = KeyManager.generate() + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + km.save_keys(private_path, public_path) + + # Check files exist + assert private_path.exists() + assert public_path.exists() + + # Check files have content + assert len(private_path.read_bytes()) > 0 + assert len(public_path.read_bytes()) > 0 + + # Check PEM format + private_content = private_path.read_text() + public_content = public_path.read_text() + + assert "BEGIN PRIVATE KEY" in private_content + assert "END PRIVATE KEY" in private_content + assert "BEGIN PUBLIC KEY" in public_content + assert "END PUBLIC KEY" in public_content + + def test_load_keys_from_disk(self): + """Test loading private key from disk.""" + km_original = KeyManager.generate() + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + # Save keys + km_original.save_keys(private_path, public_path) + + # Load keys + km_loaded = KeyManager.load_keys(private_path) + + assert km_loaded.private_key is not None + assert km_loaded.public_key is not None + + # Fingerprints should match + assert ( + km_loaded.get_public_key_fingerprint() + == km_original.get_public_key_fingerprint() + ) + + def 
test_load_public_key_from_disk(self): + """Test loading public key separately.""" + km = KeyManager.generate() + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + km.save_keys(private_path, public_path) + + # Load public key + public_key = KeyManager.load_public_key(public_path) + + assert public_key is not None + + def test_sign_data(self): + """Test signing data with private key.""" + km = KeyManager.generate() + data = {"message": "test", "value": 42} + + signature = km.sign(data) + + # Signature should be base64 string + assert isinstance(signature, str) + assert len(signature) > 0 + + # Should be valid base64 + import base64 + + decoded = base64.b64decode(signature) + assert len(decoded) > 0 + + def test_sign_creates_deterministic_signature(self): + """Test that signing same data twice gives same signature.""" + km = KeyManager.generate() + data = {"key": "value", "number": 123} + + sig1 = km.sign(data) + sig2 = km.sign(data) + + assert sig1 == sig2 + + def test_sign_different_data_gives_different_signatures(self): + """Test that different data produces different signatures.""" + km = KeyManager.generate() + data1 = {"message": "first"} + data2 = {"message": "second"} + + sig1 = km.sign(data1) + sig2 = km.sign(data2) + + assert sig1 != sig2 + + def test_verify_valid_signature(self): + """Test verifying a valid signature.""" + km = KeyManager.generate() + data = {"test": "data", "count": 5} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_verify_invalid_signature(self): + """Test that invalid signature fails verification.""" + km = KeyManager.generate() + data = {"test": "data"} + + signature = km.sign(data) + + # Tamper with signature + import base64 + + sig_bytes = base64.b64decode(signature) + tampered = sig_bytes[:-1] + bytes([sig_bytes[-1] ^ 0xFF]) + tampered_sig = 
base64.b64encode(tampered).decode() + + is_valid = KeyManager.verify(data, tampered_sig, km.public_key) + + assert is_valid is False + + def test_verify_signature_with_wrong_key(self): + """Test that signature verification fails with wrong public key.""" + km1 = KeyManager.generate() + km2 = KeyManager.generate() + data = {"test": "data"} + + # Sign with km1 + signature = km1.sign(data) + + # Verify with km2's public key + is_valid = KeyManager.verify(data, signature, km2.public_key) + + assert is_valid is False + + def test_verify_signature_with_tampered_data(self): + """Test that verification fails when data is tampered.""" + km = KeyManager.generate() + data = {"value": 100} + + signature = km.sign(data) + + # Tamper with data + tampered_data = {"value": 999} + + is_valid = KeyManager.verify(tampered_data, signature, km.public_key) + + assert is_valid is False + + def test_sign_without_private_key_raises_error(self): + """Test that signing without private key raises error.""" + km = KeyManager() # No key + + with pytest.raises(ValueError, match="No private key"): + km.sign({"test": "data"}) + + def test_save_keys_without_private_key_raises_error(self): + """Test that saving without private key raises error.""" + km = KeyManager() # No key + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + with pytest.raises(ValueError, match="No private key"): + km.save_keys(private_path, public_path) + + def test_get_public_key_fingerprint(self): + """Test getting public key fingerprint.""" + km = KeyManager.generate() + + fingerprint = km.get_public_key_fingerprint() + + # Should be 16 character hex string (first 16 chars of SHA256) + assert isinstance(fingerprint, str) + assert len(fingerprint) == 16 + # Should be valid hex + int(fingerprint, 16) + + def test_fingerprint_is_deterministic(self): + """Test that fingerprint is deterministic for same key.""" + km = KeyManager.generate() + + 
fp1 = km.get_public_key_fingerprint() + fp2 = km.get_public_key_fingerprint() + + assert fp1 == fp2 + + def test_get_fingerprint_without_public_key_raises_error(self): + """Test that getting fingerprint without key raises error.""" + km = KeyManager() + + with pytest.raises(ValueError, match="No public key"): + km.get_public_key_fingerprint() + + def test_sign_canonical_json(self): + """Test that signing uses canonical JSON representation.""" + km = KeyManager.generate() + + # These should produce the same signature due to canonical JSON + data1 = {"b": 2, "a": 1} + data2 = {"a": 1, "b": 2} + + sig1 = km.sign(data1) + sig2 = km.sign(data2) + + assert sig1 == sig2 + + def test_sign_handles_nested_data(self): + """Test signing complex nested data structures.""" + km = KeyManager.generate() + data = { + "user": {"name": "Alice", "id": 123}, + "permissions": ["read", "write"], + "metadata": {"created": "2025-01-01", "version": 2}, + } + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_key_persistence_roundtrip(self): + """Test full roundtrip: generate, save, load, verify.""" + km_original = KeyManager.generate() + data = {"test": "roundtrip"} + original_signature = km_original.sign(data) + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "key.pem" + public_path = Path(tmpdir) / "key.pub" + + # Save + km_original.save_keys(private_path, public_path) + + # Load + km_loaded = KeyManager.load_keys(private_path) + + # Should be able to sign with loaded key + loaded_signature = km_loaded.sign(data) + + # Signatures should match + assert loaded_signature == original_signature + + # Verification should work + assert KeyManager.verify(data, loaded_signature, km_loaded.public_key) is True + + +class TestNodeIdentity: + """Tests for NodeIdentity class.""" + + def test_create_node_identity(self): + """Test creating a node identity.""" + node = 
NodeIdentity("test-node-1") + + assert node.node_id == "test-node-1" + assert node.key_manager is not None + assert node.key_manager.private_key is not None + + def test_create_with_existing_key_manager(self): + """Test creating node with existing key manager.""" + km = KeyManager.generate() + node = NodeIdentity("test-node-2", key_manager=km) + + assert node.node_id == "test-node-2" + assert node.key_manager is km + + def test_get_node_id(self): + """Test getting node ID.""" + node = NodeIdentity("my-node") + + assert node.get_node_id() == "my-node" + + def test_sign_data(self): + """Test signing data through node identity.""" + node = NodeIdentity("signer-node") + data = {"message": "test", "timestamp": "2025-01-01"} + + signature = node.sign(data) + + assert isinstance(signature, str) + assert len(signature) > 0 + + def test_get_public_key_fingerprint(self): + """Test getting public key fingerprint.""" + node = NodeIdentity("fp-node") + + fingerprint = node.get_public_key_fingerprint() + + assert isinstance(fingerprint, str) + assert len(fingerprint) == 16 + + def test_verify_signature_with_string_data(self): + """Test verifying signature on string data (like hashes).""" + node = NodeIdentity("verify-node") + + # Simulate signing a hash string + hash_string = "abc123def456" + + # Sign it manually through key manager + import base64 + message = hash_string.encode() + signature_bytes = node.key_manager.private_key.sign(message) + signature = base64.b64encode(signature_bytes).decode() + + # Verify using node identity + is_valid = node.verify_signature(hash_string, signature) + + assert is_valid is True + + def test_verify_signature_fails_with_wrong_data(self): + """Test that verification fails with different data.""" + node = NodeIdentity("verify-node") + + original_data = "original_hash" + tampered_data = "tampered_hash" + + # Sign original + import base64 + message = original_data.encode() + signature_bytes = node.key_manager.private_key.sign(message) + signature 
= base64.b64encode(signature_bytes).decode() + + # Try to verify with tampered data + is_valid = node.verify_signature(tampered_data, signature) + + assert is_valid is False + + def test_verify_signature_fails_with_wrong_signature(self): + """Test that verification fails with invalid signature.""" + node = NodeIdentity("verify-node") + data = "test_data" + + # Create invalid signature + fake_signature = "aW52YWxpZF9zaWduYXR1cmU=" # base64 of "invalid_signature" + + is_valid = node.verify_signature(data, fake_signature) + + assert is_valid is False + + def test_verify_signature_without_public_key(self): + """Test verification fails without public key.""" + node = NodeIdentity("no-key-node") + # Remove public key + node.key_manager.public_key = None + + is_valid = node.verify_signature("data", "signature") + + assert is_valid is False + + def test_different_nodes_have_different_fingerprints(self): + """Test that different nodes have unique fingerprints.""" + node1 = NodeIdentity("node-1") + node2 = NodeIdentity("node-2") + + fp1 = node1.get_public_key_fingerprint() + fp2 = node2.get_public_key_fingerprint() + + assert fp1 != fp2 + + def test_node_can_verify_own_signature(self): + """Test that node can verify its own signatures.""" + node = NodeIdentity("self-verify-node") + data = {"action": "test", "value": 42} + + # Sign with node + signature = node.sign(data) + + # Convert dict to canonical JSON for verification + import json + canonical = json.dumps(data, sort_keys=True, separators=(",", ":")) + + # This should work with the node's verify_signature method + # but it expects string data, so we need to verify differently + # Let's use the key manager's verify method + is_valid = KeyManager.verify(data, signature, node.key_manager.public_key) + + assert is_valid is True + + +class TestCrossNodeVerification: + """Tests for cross-node signature verification.""" + + def test_node_cannot_verify_other_node_signature(self): + """Test that one node cannot forge another's 
signature.""" + node1 = NodeIdentity("node-1") + node2 = NodeIdentity("node-2") + + data = {"message": "test"} + + # Node 1 signs + signature = node1.sign(data) + + # Node 2 tries to verify with its own key + is_valid = KeyManager.verify(data, signature, node2.key_manager.public_key) + + assert is_valid is False + + def test_public_key_distribution(self): + """Test that public keys can be shared for verification.""" + node1 = NodeIdentity("alice") + node2 = NodeIdentity("bob") + + data = {"transfer": "100", "to": "bob"} + + # Alice signs + signature = node1.sign(data) + + # Bob can verify using Alice's public key + is_valid = KeyManager.verify(data, signature, node1.key_manager.public_key) + + assert is_valid is True + + def test_signature_persistence_across_nodes(self): + """Test signature verification works after key export/import.""" + # Node 1 creates and signs + node1 = NodeIdentity("node-1") + data = {"test": "data"} + signature = node1.sign(data) + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "node1.key" + public_path = Path(tmpdir) / "node1.pub" + + # Export keys + node1.key_manager.save_keys(private_path, public_path) + + # Load into new key manager (simulating different node) + loaded_km = KeyManager.load_keys(private_path) + + # Should be able to verify with loaded keys + is_valid = KeyManager.verify(data, signature, loaded_km.public_key) + assert is_valid is True + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_sign_empty_dict(self): + """Test signing empty dictionary.""" + km = KeyManager.generate() + data = {} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_sign_large_data(self): + """Test signing large data structure.""" + km = KeyManager.generate() + data = {"items": [{"id": i, "value": f"item_{i}"} for i in range(1000)]} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, 
signature, km.public_key) + + assert is_valid is True + + def test_sign_with_unicode(self): + """Test signing data with unicode characters.""" + km = KeyManager.generate() + data = {"message": "Hello 世界 🌍", "emoji": "🔐"} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_node_id_with_special_characters(self): + """Test node identity with special characters in ID.""" + node = NodeIdentity("node-123_test.example.com") + + assert node.get_node_id() == "node-123_test.example.com" + + def test_verify_with_malformed_signature(self): + """Test verification with malformed base64 signature.""" + node = NodeIdentity("test-node") + + # Invalid base64 + invalid_sig = "not-valid-base64!!!" + + is_valid = node.verify_signature("data", invalid_sig) + + assert is_valid is False + + def test_load_nonexistent_key_file(self): + """Test loading from non-existent file.""" + with pytest.raises(FileNotFoundError): + KeyManager.load_keys(Path("/nonexistent/key.pem")) + + def test_save_to_readonly_directory(self): + """Test error handling when saving to readonly location.""" + import os + + # Skip if running as root (has permission to write everywhere) + if os.getuid() == 0: + pytest.skip("Running as root, cannot test readonly directory") + + km = KeyManager.generate() + + # Try to save to root (should fail on Unix systems) + readonly_path = Path("/readonly.key") + readonly_pub = Path("/readonly.pub") + + with pytest.raises((PermissionError, OSError)): + km.save_keys(readonly_path, readonly_pub) diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..38966fe --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,578 @@ +"""Tests for structured logging functionality.""" + +import json +import logging +from io import StringIO + +import pytest + +from lexecon.observability.logging import ( + LoggerAdapter, + StructuredFormatter, + configure_logging, + get_logger, + 
request_id_var, + user_id_var, +) + + +class TestStructuredFormatter: + """Tests for StructuredFormatter class.""" + + def test_format_basic_log(self): + """Test formatting a basic log record.""" + formatter = StructuredFormatter() + record = logging.LogRecord( + name="test_logger", + level=logging.INFO, + pathname="test.py", + lineno=42, + msg="Test message", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + + # Should be valid JSON + data = json.loads(output) + assert data["message"] == "Test message" + assert data["level"] == "INFO" + assert data["logger"] == "test_logger" + assert data["line"] == 42 + + def test_format_includes_timestamp(self): + """Test that formatted log includes timestamp.""" + formatter = StructuredFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert "timestamp" in data + # Should end with Z (UTC) + assert data["timestamp"].endswith("Z") + + def test_format_with_request_context(self): + """Test formatting includes request context from ContextVar.""" + formatter = StructuredFormatter() + + # Set context variables + request_id_var.set("req_123") + user_id_var.set("user_456") + + try: + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test with context", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert data["request_id"] == "req_123" + assert data["user_id"] == "user_456" + finally: + # Clean up context + request_id_var.set(None) + user_id_var.set(None) + + def test_format_without_context(self): + """Test formatting when context variables are not set.""" + formatter = StructuredFormatter() + + # Ensure context is clear + request_id_var.set(None) + user_id_var.set(None) + + record = logging.LogRecord( + name="test", + 
level=logging.INFO, + pathname="test.py", + lineno=1, + msg="No context", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert "request_id" not in data + assert "user_id" not in data + + def test_format_with_exception(self): + """Test formatting log with exception info.""" + formatter = StructuredFormatter() + + try: + raise ValueError("Test error") + except ValueError: + import sys + + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.ERROR, + pathname="test.py", + lineno=1, + msg="Error occurred", + args=(), + exc_info=exc_info, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert "exception" in data + assert data["exception"]["type"] == "ValueError" + assert data["exception"]["message"] == "Test error" + assert "traceback" in data["exception"] + + def test_format_error_without_exception(self): + """Test that errors without exc_info include stack trace.""" + formatter = StructuredFormatter() + + record = logging.LogRecord( + name="test", + level=logging.ERROR, + pathname="test.py", + lineno=1, + msg="Error without exception", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert "stack_trace" in data + assert isinstance(data["stack_trace"], list) + + def test_format_warning_no_stack_trace(self): + """Test that warnings don't include stack trace.""" + formatter = StructuredFormatter() + + record = logging.LogRecord( + name="test", + level=logging.WARNING, + pathname="test.py", + lineno=1, + msg="Warning message", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert "stack_trace" not in data + assert "exception" not in data + + def test_format_with_extra_fields(self): + """Test formatting with custom extra fields.""" + formatter = StructuredFormatter() + + record = logging.LogRecord( + name="test", + level=logging.INFO, + 
pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + record.extra_fields = {"custom_field": "custom_value", "count": 42} + + output = formatter.format(record) + data = json.loads(output) + + assert data["custom_field"] == "custom_value" + assert data["count"] == 42 + + def test_format_includes_module_and_function(self): + """Test that log includes module and function information.""" + formatter = StructuredFormatter() + + record = logging.LogRecord( + name="test_logger", + level=logging.INFO, + pathname="/path/to/module.py", + lineno=100, + msg="Test", + args=(), + exc_info=None, + func="test_function", + ) + + output = formatter.format(record) + data = json.loads(output) + + assert data["module"] == "module" + assert data["function"] == "test_function" + + def test_format_different_log_levels(self): + """Test formatting different log levels.""" + formatter = StructuredFormatter() + + levels = [ + (logging.DEBUG, "DEBUG"), + (logging.INFO, "INFO"), + (logging.WARNING, "WARNING"), + (logging.ERROR, "ERROR"), + (logging.CRITICAL, "CRITICAL"), + ] + + for level_num, level_name in levels: + record = logging.LogRecord( + name="test", + level=level_num, + pathname="test.py", + lineno=1, + msg=f"{level_name} message", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert data["level"] == level_name + assert data["message"] == f"{level_name} message" + + +class TestConfigureLogging: + """Tests for configure_logging function.""" + + def test_configure_with_json_format(self): + """Test configuring logging with JSON format.""" + configure_logging(level="INFO", format="json", output="stdout") + + root = logging.getLogger() + assert root.level == logging.INFO + assert len(root.handlers) > 0 + + # Check that handler has StructuredFormatter + handler = root.handlers[0] + assert isinstance(handler.formatter, StructuredFormatter) + + def test_configure_with_text_format(self): + """Test configuring 
logging with text format.""" + configure_logging(level="DEBUG", format="text", output="stdout") + + root = logging.getLogger() + assert root.level == logging.DEBUG + + handler = root.handlers[0] + assert isinstance(handler.formatter, logging.Formatter) + assert not isinstance(handler.formatter, StructuredFormatter) + + def test_configure_different_log_levels(self): + """Test configuring different log levels.""" + levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + expected = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL] + + for level_str, level_expected in zip(levels, expected): + configure_logging(level=level_str, format="json", output="stdout") + root = logging.getLogger() + assert root.level == level_expected + + def test_configure_removes_existing_handlers(self): + """Test that configuration removes existing handlers.""" + root = logging.getLogger() + + # Add some handlers + handler1 = logging.StreamHandler() + handler2 = logging.StreamHandler() + root.addHandler(handler1) + root.addHandler(handler2) + + initial_count = len(root.handlers) + assert initial_count >= 2 + + # Reconfigure + configure_logging(level="INFO", format="json", output="stdout") + + # Should have only 1 handler now + assert len(root.handlers) == 1 + + +class TestGetLogger: + """Tests for get_logger function.""" + + def test_get_logger_returns_logger(self): + """Test that get_logger returns a logger instance.""" + logger = get_logger("test_module") + + assert isinstance(logger, logging.Logger) + assert logger.name == "test_module" + + def test_get_logger_same_name_same_instance(self): + """Test that getting logger with same name returns same instance.""" + logger1 = get_logger("shared_module") + logger2 = get_logger("shared_module") + + assert logger1 is logger2 + + def test_get_logger_different_names(self): + """Test that different names create different loggers.""" + logger1 = get_logger("module1") + logger2 = get_logger("module2") + + assert 
logger1 is not logger2 + assert logger1.name == "module1" + assert logger2.name == "module2" + + +class TestLoggerAdapter: + """Tests for LoggerAdapter class.""" + + def test_adapter_creation(self): + """Test creating logger adapter.""" + base_logger = logging.getLogger("test") + adapter = LoggerAdapter(base_logger, {}) + + assert adapter.logger is base_logger + + def test_adapter_adds_context_vars(self): + """Test that adapter adds context variables.""" + request_id_var.set("req_999") + user_id_var.set("user_888") + + try: + base_logger = logging.getLogger("test") + adapter = LoggerAdapter(base_logger, {}) + + msg, kwargs = adapter.process("Test message", {}) + + assert "extra" in kwargs + assert kwargs["extra"]["request_id"] == "req_999" + assert kwargs["extra"]["user_id"] == "user_888" + finally: + request_id_var.set(None) + user_id_var.set(None) + + def test_adapter_without_context_vars(self): + """Test adapter when context vars are not set.""" + request_id_var.set(None) + user_id_var.set(None) + + base_logger = logging.getLogger("test") + adapter = LoggerAdapter(base_logger, {}) + + msg, kwargs = adapter.process("Test message", {}) + + assert "extra" in kwargs + assert "request_id" not in kwargs["extra"] + assert "user_id" not in kwargs["extra"] + + def test_adapter_preserves_existing_extra(self): + """Test that adapter preserves existing extra fields.""" + request_id_var.set("req_777") + + try: + base_logger = logging.getLogger("test") + adapter = LoggerAdapter(base_logger, {}) + + msg, kwargs = adapter.process( + "Test message", {"extra": {"custom": "field"}} + ) + + assert kwargs["extra"]["custom"] == "field" + assert kwargs["extra"]["request_id"] == "req_777" + finally: + request_id_var.set(None) + + +class TestLoggingIntegration: + """Integration tests for logging system.""" + + def test_end_to_end_logging(self): + """Test complete logging workflow.""" + # Configure logging + stream = StringIO() + configure_logging(level="INFO", format="json", 
output="stdout") + + # Get logger and log message + logger = get_logger("integration_test") + + # Capture output + handler = logging.StreamHandler(stream) + handler.setFormatter(StructuredFormatter()) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + # Set context + request_id_var.set("req_integration") + + try: + logger.info("Integration test message", extra={"test_id": 123}) + + output = stream.getvalue() + data = json.loads(output) + + assert data["message"] == "Integration test message" + assert data["level"] == "INFO" + assert data["request_id"] == "req_integration" + assert data["logger"] == "integration_test" + finally: + request_id_var.set(None) + logger.removeHandler(handler) + + def test_logging_with_formatting(self): + """Test logging with message formatting.""" + stream = StringIO() + logger = logging.getLogger("format_test") + handler = logging.StreamHandler(stream) + handler.setFormatter(StructuredFormatter()) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + logger.info("User %s logged in from %s", "alice", "192.168.1.1") + + output = stream.getvalue() + data = json.loads(output) + + assert data["message"] == "User alice logged in from 192.168.1.1" + logger.removeHandler(handler) + + def test_concurrent_context_isolation(self): + """Test that context variables are properly isolated.""" + # This is a simplified test - in real concurrent scenarios, + # ContextVar provides proper isolation per async task/thread + + request_id_var.set("req_first") + assert request_id_var.get() == "req_first" + + request_id_var.set("req_second") + assert request_id_var.get() == "req_second" + + request_id_var.set(None) + assert request_id_var.get() is None + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_format_with_none_message(self): + """Test formatting log with None message.""" + formatter = StructuredFormatter() + + record = logging.LogRecord( + name="test", + level=logging.INFO, + 
pathname="test.py", + lineno=1, + msg=None, + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert data["message"] == "None" + + def test_format_with_unicode_message(self): + """Test formatting log with unicode characters.""" + formatter = StructuredFormatter() + + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Unicode test: 你好世界 🌍", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert data["message"] == "Unicode test: 你好世界 🌍" + + def test_format_with_very_long_message(self): + """Test formatting very long log message.""" + formatter = StructuredFormatter() + + long_message = "A" * 10000 + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg=long_message, + args=(), + exc_info=None, + ) + + output = formatter.format(record) + data = json.loads(output) + + assert data["message"] == long_message + + def test_configure_with_invalid_level(self): + """Test configuring with invalid log level.""" + with pytest.raises(AttributeError): + configure_logging(level="INVALID", format="json", output="stdout") + + def test_context_var_cleanup(self): + """Test that context variables can be cleaned up.""" + request_id_var.set("test_req") + user_id_var.set("test_user") + + request_id_var.set(None) + user_id_var.set(None) + + assert request_id_var.get() is None + assert user_id_var.get() is None + + def test_multiple_formatters_different_loggers(self): + """Test using different formatters for different loggers.""" + logger1 = logging.getLogger("logger1") + logger2 = logging.getLogger("logger2") + + stream1 = StringIO() + stream2 = StringIO() + + handler1 = logging.StreamHandler(stream1) + handler1.setFormatter(StructuredFormatter()) + + handler2 = logging.StreamHandler(stream2) + handler2.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + + 
logger1.addHandler(handler1) + logger2.addHandler(handler2) + + logger1.setLevel(logging.INFO) + logger2.setLevel(logging.INFO) + + logger1.info("JSON log") + logger2.info("Text log") + + # Logger1 should have JSON + output1 = stream1.getvalue() + data1 = json.loads(output1) + assert data1["message"] == "JSON log" + + # Logger2 should have text + output2 = stream2.getvalue() + assert "INFO: Text log" in output2 + + logger1.removeHandler(handler1) + logger2.removeHandler(handler2) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..8721bf9 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,505 @@ +"""Tests for Prometheus metrics.""" + +import time + +import pytest +from prometheus_client import REGISTRY + +from lexecon.observability.metrics import ( + MetricsCollector, + active_policies, + active_tokens, + decisions_denied_total, + decisions_total, + http_request_duration_seconds, + http_requests_total, + ledger_entries_total, + metrics, + node_uptime_seconds, + policies_loaded_total, + record_decision, + record_policy_load, + tokens_issued_total, + tokens_verified_total, +) + + +class TestMetricsCollector: + """Tests for MetricsCollector class.""" + + def test_initialization(self): + """Test metrics collector initialization.""" + collector = MetricsCollector() + + assert collector.start_time > 0 + assert isinstance(collector.start_time, float) + + def test_record_request(self): + """Test recording HTTP request metrics.""" + collector = MetricsCollector() + + # Get initial value + initial_requests = http_requests_total.labels( + method="GET", endpoint="/health", status=200 + )._value.get() + + # Record request + collector.record_request("GET", "/health", 200, 0.123) + + # Verify counter increased + final_requests = http_requests_total.labels( + method="GET", endpoint="/health", status=200 + )._value.get() + + assert final_requests > initial_requests + + def test_record_request_duration(self): + """Test that request 
duration is recorded.""" + collector = MetricsCollector() + + # Record request with duration + collector.record_request("POST", "/decide", 200, 0.456) + + # Histogram should have recorded the observation + # We can't easily check the exact value, but we can verify it doesn't error + histogram = http_request_duration_seconds.labels(method="POST", endpoint="/decide") + assert histogram is not None + + def test_record_decision(self): + """Test recording decision metrics.""" + collector = MetricsCollector() + + # Get initial value + initial_decisions = decisions_total.labels( + allowed="True", actor="model", risk_level="1" + )._value.get() + + # Record decision + collector.record_decision(allowed=True, actor="model", risk_level=1, duration=0.1) + + # Verify counter increased + final_decisions = decisions_total.labels( + allowed="True", actor="model", risk_level="1" + )._value.get() + + assert final_decisions > initial_decisions + + def test_record_denial(self): + """Test recording decision denial.""" + collector = MetricsCollector() + + initial_denials = decisions_denied_total.labels( + reason_category="policy", actor="model" + )._value.get() + + collector.record_denial(reason_category="policy", actor="model") + + final_denials = decisions_denied_total.labels( + reason_category="policy", actor="model" + )._value.get() + + assert final_denials > initial_denials + + def test_record_policy_load(self): + """Test recording policy load.""" + collector = MetricsCollector() + + initial_loads = policies_loaded_total.labels(policy_name="test_policy")._value.get() + initial_active = active_policies._value.get() + + collector.record_policy_load("test_policy") + + final_loads = policies_loaded_total.labels(policy_name="test_policy")._value.get() + final_active = active_policies._value.get() + + assert final_loads > initial_loads + assert final_active > initial_active + + def test_record_ledger_entry(self): + """Test recording ledger entry.""" + collector = MetricsCollector() + + 
initial_entries = ledger_entries_total._value.get() + + collector.record_ledger_entry() + + final_entries = ledger_entries_total._value.get() + + assert final_entries > initial_entries + + def test_record_token_issuance(self): + """Test recording token issuance.""" + collector = MetricsCollector() + + initial_issued = tokens_issued_total.labels(scope="action:read")._value.get() + initial_active = active_tokens._value.get() + + collector.record_token_issuance("action:read") + + final_issued = tokens_issued_total.labels(scope="action:read")._value.get() + final_active = active_tokens._value.get() + + assert final_issued > initial_issued + assert final_active > initial_active + + def test_record_token_verification(self): + """Test recording token verification.""" + collector = MetricsCollector() + + initial_verified = tokens_verified_total.labels(valid="True")._value.get() + + collector.record_token_verification(valid=True) + + final_verified = tokens_verified_total.labels(valid="True")._value.get() + + assert final_verified > initial_verified + + def test_get_uptime(self): + """Test getting node uptime.""" + collector = MetricsCollector() + + uptime1 = collector.get_uptime() + assert uptime1 >= 0 + + time.sleep(0.1) + + uptime2 = collector.get_uptime() + assert uptime2 > uptime1 + + def test_export_metrics(self): + """Test exporting metrics in Prometheus format.""" + collector = MetricsCollector() + + output = collector.export_metrics() + + assert isinstance(output, bytes) + assert len(output) > 0 + + # Should contain Prometheus format markers + decoded = output.decode('utf-8') + assert "# HELP" in decoded or "# TYPE" in decoded + + +class TestGlobalMetricsInstance: + """Tests for global metrics instance.""" + + def test_global_metrics_exists(self): + """Test that global metrics instance exists.""" + assert metrics is not None + assert isinstance(metrics, MetricsCollector) + + def test_global_metrics_record_decision(self): + """Test using global metrics instance.""" 
+ initial = decisions_total.labels( + allowed="False", actor="user", risk_level="3" + )._value.get() + + metrics.record_decision(allowed=False, actor="user", risk_level=3, duration=0.2) + + final = decisions_total.labels( + allowed="False", actor="user", risk_level="3" + )._value.get() + + assert final > initial + + +class TestConvenienceFunctions: + """Tests for convenience functions.""" + + def test_record_decision_function(self): + """Test record_decision convenience function.""" + initial = decisions_total.labels( + allowed="True", actor="bot", risk_level="2" + )._value.get() + + record_decision(allowed=True, actor="bot", risk_level=2, duration=0.15) + + final = decisions_total.labels( + allowed="True", actor="bot", risk_level="2" + )._value.get() + + assert final > initial + + def test_record_policy_load_function(self): + """Test record_policy_load convenience function.""" + initial = policies_loaded_total.labels(policy_name="conv_policy")._value.get() + + record_policy_load("conv_policy") + + final = policies_loaded_total.labels(policy_name="conv_policy")._value.get() + + assert final > initial + + +class TestMetricTypes: + """Tests for different metric types.""" + + def test_counter_increments(self): + """Test that counters only increment.""" + initial = http_requests_total.labels( + method="GET", endpoint="/test", status=200 + )._value.get() + + # Increment multiple times + for _ in range(5): + metrics.record_request("GET", "/test", 200, 0.1) + + final = http_requests_total.labels( + method="GET", endpoint="/test", status=200 + )._value.get() + + # Should have incremented by 5 + assert final >= initial + 5 + + def test_gauge_can_increase_and_decrease(self): + """Test that gauges can increase and decrease.""" + initial = active_policies._value.get() + + # Increase + active_policies.inc() + increased = active_policies._value.get() + assert increased > initial + + # Decrease + active_policies.dec() + decreased = active_policies._value.get() + assert decreased 
< increased + + def test_histogram_records_observations(self): + """Test that histograms record observations.""" + histogram = http_request_duration_seconds.labels(method="GET", endpoint="/metrics") + + # Record several observations + histogram.observe(0.1) + histogram.observe(0.2) + histogram.observe(0.3) + + # Histogram should have recorded these + # We can't easily check exact values, but verify no errors + assert histogram is not None + + +class TestMetricLabels: + """Tests for metric label handling.""" + + def test_different_labels_separate_metrics(self): + """Test that different labels create separate metric series.""" + initial_200 = http_requests_total.labels( + method="GET", endpoint="/api", status=200 + )._value.get() + + initial_404 = http_requests_total.labels( + method="GET", endpoint="/api", status=404 + )._value.get() + + # Record 200 + metrics.record_request("GET", "/api", 200, 0.1) + + # Only 200 should increment + final_200 = http_requests_total.labels( + method="GET", endpoint="/api", status=200 + )._value.get() + + final_404 = http_requests_total.labels( + method="GET", endpoint="/api", status=404 + )._value.get() + + assert final_200 > initial_200 + assert final_404 == initial_404 + + def test_label_combinations(self): + """Test various label combinations.""" + # Test different actors + metrics.record_decision(True, "actor1", 1, 0.1) + metrics.record_decision(True, "actor2", 1, 0.1) + + # Test different risk levels + metrics.record_decision(True, "model", 1, 0.1) + metrics.record_decision(True, "model", 5, 0.1) + + # Test different allowed values + metrics.record_decision(True, "user", 1, 0.1) + metrics.record_decision(False, "user", 1, 0.1) + + # All should be recorded separately + # If this completes without error, labels are working + + +class TestMetricsIntegration: + """Integration tests for metrics system.""" + + def test_complete_request_workflow(self): + """Test recording complete request workflow.""" + # Record incoming request + 
start_time = time.time() + metrics.record_request("POST", "/decide", 200, 0.2) + + # Record decision + metrics.record_decision(True, "model", 2, 0.15) + + # Record token issuance + metrics.record_token_issuance("action:search") + + # Record ledger entry + metrics.record_ledger_entry() + + # All should succeed without error + + def test_decision_workflow_metrics(self): + """Test metrics for decision workflow.""" + # Load policy + metrics.record_policy_load("workflow_policy") + + # Make decision (permitted) + metrics.record_decision(True, "workflow_actor", 1, 0.1) + + # Issue token + metrics.record_token_issuance("workflow:action") + + # Verify token later + metrics.record_token_verification(True) + + # Record to ledger + metrics.record_ledger_entry() + + # Workflow complete - verify no errors + + def test_denial_workflow_metrics(self): + """Test metrics for denial workflow.""" + # Make decision (denied) + metrics.record_decision(False, "denied_actor", 5, 0.1) + + # Record denial reason + metrics.record_denial("high_risk", "denied_actor") + + # Still record to ledger + metrics.record_ledger_entry() + + # Workflow complete + + +class TestPrometheusExport: + """Tests for Prometheus export functionality.""" + + def test_export_format(self): + """Test Prometheus export format.""" + output = metrics.export_metrics() + decoded = output.decode('utf-8') + + # Should contain metric definitions + assert "lexecon_" in decoded + + # Should contain HELP and TYPE comments + lines = decoded.split('\n') + help_lines = [l for l in lines if l.startswith('# HELP')] + type_lines = [l for l in lines if l.startswith('# TYPE')] + + assert len(help_lines) > 0 + assert len(type_lines) > 0 + + def test_export_includes_values(self): + """Test that export includes metric values.""" + # Record some metrics + metrics.record_decision(True, "export_test", 1, 0.1) + metrics.record_ledger_entry() + + output = metrics.export_metrics() + decoded = output.decode('utf-8') + + # Should contain metric 
values (numbers) + assert any(char.isdigit() for char in decoded) + + def test_export_is_valid_prometheus_format(self): + """Test that export is valid Prometheus format.""" + output = metrics.export_metrics() + decoded = output.decode('utf-8') + + lines = decoded.split('\n') + + # Each metric line should have format: metric_name{labels} value + metric_lines = [l for l in lines if l and not l.startswith('#')] + + for line in metric_lines[:10]: # Check first 10 + if '{' in line: + # Has labels + assert '}' in line + assert ' ' in line # Space before value + elif ' ' in line and line.strip(): + # No labels, just name and value + parts = line.split() + assert len(parts) >= 2 + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_record_zero_duration(self): + """Test recording request with zero duration.""" + metrics.record_request("GET", "/fast", 200, 0.0) + # Should not error + + def test_record_very_long_duration(self): + """Test recording very long request duration.""" + metrics.record_request("GET", "/slow", 200, 100.0) + # Should not error + + def test_record_high_risk_level(self): + """Test recording decision with high risk level.""" + metrics.record_decision(False, "risky", 10, 0.1) + # Should not error + + def test_record_special_characters_in_labels(self): + """Test labels with special characters.""" + # Prometheus should handle these + metrics.record_decision(True, "actor-with-dash", 1, 0.1) + metrics.record_policy_load("policy_with_underscore") + # Should not error + + def test_record_empty_scope(self): + """Test recording token with empty scope.""" + metrics.record_token_issuance("") + # Should not error + + def test_concurrent_metric_recording(self): + """Test recording metrics concurrently.""" + # Simulate concurrent requests + for i in range(100): + metrics.record_request("GET", f"/endpoint{i % 5}", 200, 0.01) + metrics.record_decision(True, "concurrent", 1, 0.01) + + # Should handle all without error + + def 
test_uptime_never_negative(self): + """Test that uptime is never negative.""" + collector = MetricsCollector() + + for _ in range(10): + uptime = collector.get_uptime() + assert uptime >= 0 + time.sleep(0.01) + + def test_export_with_no_metrics(self): + """Test exporting when no metrics recorded.""" + # Create fresh collector + collector = MetricsCollector() + + output = collector.export_metrics() + + # Should still produce valid output + assert isinstance(output, bytes) + assert len(output) > 0 + + def test_multiple_collectors(self): + """Test creating multiple collector instances.""" + collector1 = MetricsCollector() + collector2 = MetricsCollector() + + # Should have different start times if created at different times + time.sleep(0.01) + collector3 = MetricsCollector() + + uptime1 = collector1.get_uptime() + uptime3 = collector3.get_uptime() + + # collector1 should have slightly higher uptime + assert uptime1 >= uptime3 diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 0000000..603226c --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,529 @@ +"""Tests for OpenTelemetry tracing functionality.""" + +import time + +import pytest + +from lexecon.observability.tracing import ( + TRACING_AVAILABLE, + TracingManager, + trace_function, + tracer, +) + + +class TestTracingAvailability: + """Tests for tracing availability detection.""" + + def test_tracing_available_is_boolean(self): + """Test that TRACING_AVAILABLE is a boolean.""" + assert isinstance(TRACING_AVAILABLE, bool) + + def test_global_tracer_exists(self): + """Test that global tracer instance exists.""" + assert tracer is not None + assert isinstance(tracer, TracingManager) + + +class TestTracingManager: + """Tests for TracingManager class.""" + + def test_initialization(self): + """Test tracing manager initialization.""" + manager = TracingManager() + + assert manager is not None + assert hasattr(manager, 'enabled') + assert hasattr(manager, 'tracer') + + def 
test_enabled_status(self): + """Test that enabled status matches availability.""" + manager = TracingManager() + + # If tracing is available, should be enabled + # If not available, should be disabled + assert manager.enabled == TRACING_AVAILABLE + + def test_tracer_attribute(self): + """Test tracer attribute state.""" + manager = TracingManager() + + if TRACING_AVAILABLE: + assert manager.tracer is not None + else: + assert manager.tracer is None + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_start_span_when_enabled(self): + """Test starting a span when tracing is enabled.""" + manager = TracingManager() + + span = manager.start_span("test_span", test_attribute="value") + + assert span is not None + # Span should be context manager + assert hasattr(span, '__enter__') + assert hasattr(span, '__exit__') + + def test_start_span_when_disabled(self): + """Test starting a span when tracing is disabled.""" + manager = TracingManager() + + if not TRACING_AVAILABLE: + span = manager.start_span("test_span") + assert span is None + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_span_with_attributes(self): + """Test creating span with multiple attributes.""" + manager = TracingManager() + + span = manager.start_span( + "attributed_span", + attr1="value1", + attr2=42, + attr3=True + ) + + assert span is not None + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_instrument_fastapi(self): + """Test FastAPI instrumentation.""" + from unittest.mock import Mock + + manager = TracingManager() + mock_app = Mock() + + # Should not raise error + manager.instrument_fastapi(mock_app) + + +class TestTraceFunctionDecorator: + """Tests for trace_function decorator.""" + + def test_decorator_without_name(self): + """Test decorator without explicit name.""" + @trace_function() + def test_func(): + return "result" + + result = test_func() + assert result 
== "result" + + def test_decorator_with_name(self): + """Test decorator with explicit span name.""" + @trace_function(name="custom_span_name") + def test_func(): + return "result" + + result = test_func() + assert result == "result" + + def test_decorator_preserves_function_behavior(self): + """Test that decorator doesn't change function behavior.""" + @trace_function() + def add(a, b): + return a + b + + assert add(2, 3) == 5 + assert add(10, 20) == 30 + + def test_decorator_with_arguments(self): + """Test decorating function with various arguments.""" + @trace_function() + def complex_func(x, y, *args, **kwargs): + return { + 'x': x, + 'y': y, + 'args': args, + 'kwargs': kwargs + } + + result = complex_func(1, 2, 3, 4, key="value") + + assert result['x'] == 1 + assert result['y'] == 2 + assert result['args'] == (3, 4) + assert result['kwargs'] == {'key': 'value'} + + def test_decorator_with_exception(self): + """Test decorator handles exceptions properly.""" + @trace_function() + def failing_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError, match="Test error"): + failing_func() + + def test_decorator_exception_still_raises(self): + """Test that exceptions are still raised after tracing.""" + @trace_function() + def error_func(should_error): + if should_error: + raise RuntimeError("Intentional error") + return "success" + + # Should work without error + assert error_func(False) == "success" + + # Should raise error + with pytest.raises(RuntimeError): + error_func(True) + + def test_decorator_on_class_method(self): + """Test decorator on class methods.""" + class TestClass: + @trace_function() + def method(self, value): + return value * 2 + + obj = TestClass() + assert obj.method(5) == 10 + + def test_decorator_on_static_method(self): + """Test decorator on static methods.""" + class TestClass: + @staticmethod + @trace_function() + def static_method(value): + return value + 1 + + assert TestClass.static_method(5) == 6 + + def 
test_decorator_preserves_docstring(self): + """Test that decorator preserves function docstring.""" + @trace_function() + def documented_func(): + """This is a docstring.""" + return "result" + + assert documented_func.__doc__ == "This is a docstring." + + def test_decorator_preserves_function_name(self): + """Test that decorator preserves function name.""" + @trace_function() + def named_function(): + return "result" + + assert named_function.__name__ == "named_function" + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_decorator_records_duration(self): + """Test that decorator records execution duration.""" + @trace_function() + def slow_func(): + time.sleep(0.1) + return "done" + + start = time.time() + result = slow_func() + duration = time.time() - start + + assert result == "done" + assert duration >= 0.1 + + def test_decorator_when_tracing_disabled(self): + """Test decorator works even when tracing is disabled.""" + # This should work regardless of TRACING_AVAILABLE + @trace_function() + def normal_func(): + return "works" + + assert normal_func() == "works" + + +class TestTracingIntegration: + """Integration tests for tracing system.""" + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_nested_spans(self): + """Test creating nested spans.""" + manager = TracingManager() + + outer_span = manager.start_span("outer") + if outer_span: + inner_span = manager.start_span("inner") + assert inner_span is not None + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_multiple_sequential_spans(self): + """Test creating multiple sequential spans.""" + manager = TracingManager() + + span1 = manager.start_span("span1") + assert span1 is not None + + span2 = manager.start_span("span2") + assert span2 is not None + + span3 = manager.start_span("span3") + assert span3 is not None + + def test_decorated_functions_call_chain(self): + """Test 
call chain of decorated functions.""" + @trace_function() + def func_a(): + return func_b() + + @trace_function() + def func_b(): + return func_c() + + @trace_function() + def func_c(): + return "final_result" + + result = func_a() + assert result == "final_result" + + def test_decorator_with_multiple_returns(self): + """Test decorator on function with multiple return paths.""" + @trace_function() + def multi_return(value): + if value > 0: + return "positive" + elif value < 0: + return "negative" + else: + return "zero" + + assert multi_return(5) == "positive" + assert multi_return(-3) == "negative" + assert multi_return(0) == "zero" + + @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed") + def test_span_context_manager(self): + """Test using span as context manager.""" + manager = TracingManager() + + span = manager.start_span("context_test") + + if span: + # Should be usable as context manager + with span: + pass # Span active here + + # Span should be ended after context + + +class TestTracingFallback: + """Tests for tracing fallback when disabled.""" + + def test_disabled_tracer_returns_none(self): + """Test that disabled tracer returns None for spans.""" + manager = TracingManager() + + if not manager.enabled: + span = manager.start_span("test") + assert span is None + + def test_disabled_tracing_no_errors(self): + """Test that disabled tracing doesn't cause errors.""" + @trace_function() + def test_func(): + return "result" + + # Should work fine even if tracing is disabled + result = test_func() + assert result == "result" + + def test_decorator_overhead_minimal_when_disabled(self): + """Test that decorator has minimal overhead when disabled.""" + @trace_function() + def fast_func(): + return 42 + + # Should execute quickly even if wrapped + start = time.time() + for _ in range(1000): + result = fast_func() + duration = time.time() - start + + assert result == 42 + # Should be very fast (less than 1 second for 1000 calls) + 
# --- tests/test_tracing.py (continued from the previous chunk) ---


class TestEdgeCases:
    """Tests for edge cases and error conditions."""

    def test_decorator_on_generator(self):
        """Test decorator on generator function."""
        @trace_function()
        def gen_func():
            yield 1
            yield 2
            yield 3

        result = list(gen_func())
        assert result == [1, 2, 3]

    def test_decorator_on_async_function(self):
        """Test decorator on async function (should still work)."""
        import asyncio

        @trace_function()
        async def async_func():
            return "async_result"

        # FIX(review): the original only checked `hasattr(async_func,
        # '__name__')`, which is true for any function object and proves
        # nothing about the wrapped behavior. Actually drive the coroutine.
        assert hasattr(async_func, '__name__')
        assert asyncio.run(async_func()) == "async_result"

    def test_decorator_with_none_return(self):
        """Test decorator on function returning None."""
        @trace_function()
        def none_func():
            return None

        assert none_func() is None

    def test_decorator_with_no_return(self):
        """Test decorator on function with no explicit return."""
        @trace_function()
        def no_return_func():
            pass

        assert no_return_func() is None

    def test_span_name_with_special_characters(self):
        """Test span names with special characters."""
        manager = TracingManager()

        # Should handle special characters gracefully; none may raise.
        # (Tightened from three unused rebinding assignments to a loop.)
        for name in ("span-with-dashes", "span_with_underscores", "span.with.dots"):
            manager.start_span(name)

    def test_decorator_with_very_long_name(self):
        """Test decorator with very long span name."""
        long_name = "a" * 1000

        @trace_function(name=long_name)
        def test_func():
            return "ok"

        assert test_func() == "ok"

    def test_multiple_decorators(self):
        """Test function with multiple decorators."""
        @trace_function(name="outer")
        @trace_function(name="inner")
        def double_traced():
            return "traced"

        assert double_traced() == "traced"

    def test_decorator_with_recursive_function(self):
        """Test decorator on recursive function."""
        @trace_function()
        def factorial(n):
            if n <= 1:
                return 1
            return n * factorial(n - 1)

        assert factorial(5) == 120

    @pytest.mark.skipif(not TRACING_AVAILABLE, reason="OpenTelemetry not installed")
    def test_span_attributes_with_complex_types(self):
        """Test span attributes with various types."""
        manager = TracingManager()

        # Different attribute types
        span = manager.start_span(
            "complex_attrs",
            string_attr="value",
            int_attr=42,
            float_attr=3.14,
            bool_attr=True,
        )

        assert span is not None


class TestTracingPerformance:
    """Tests for tracing performance characteristics."""

    def test_decorator_minimal_overhead(self):
        """Test that decorator has minimal overhead."""
        def baseline():
            return 42

        @trace_function()
        def traced():
            return 42

        iterations = 100

        start = time.time()
        for _ in range(iterations):
            baseline()
        baseline_time = time.time() - start

        start = time.time()
        for _ in range(iterations):
            traced()
        traced_time = time.time() - start

        # FIX(review): the original bound `traced_time < baseline_time * 10`
        # is flaky when baseline_time is near timer resolution (100 trivial
        # calls can measure as ~0). Keep the relative bound but allow a small
        # absolute floor so the test cannot fail on timer granularity alone.
        assert traced_time < max(baseline_time * 10, 0.05)

    def test_many_spans_no_memory_leak(self):
        """Test creating many spans doesn't leak memory."""
        manager = TracingManager()

        # Create many spans; should complete without issues.
        # NOTE(review): spans are never ended here — presumably start_span
        # hands back lazy/context-manager spans; confirm this does not
        # accumulate live spans in the underlying SDK.
        for i in range(1000):
            manager.start_span(f"span_{i}")

    def test_concurrent_tracing(self):
        """Test that concurrent tracing works."""
        @trace_function()
        def concurrent_func(n):
            time.sleep(0.001)
            return n * 2

        # Simulate concurrent calls
        results = [concurrent_func(i) for i in range(10)]

        assert results == [i * 2 for i in range(10)]


class TestTracingConfiguration:
    """Tests for tracing configuration and setup."""

    def test_tracer_singleton(self):
        """Test that module provides singleton tracer."""
        from lexecon.observability.tracing import tracer as tracer1
        from lexecon.observability.tracing import tracer as tracer2

        assert tracer1 is tracer2

    def test_tracing_manager_setup(self):
        """Test that TracingManager sets up correctly."""
        manager = TracingManager()

        # Should have setup method called
        assert hasattr(manager, '_setup_tracing')

    def test_graceful_degradation(self):
        """Test graceful degradation when tracing unavailable."""
        # Should work even if OpenTelemetry is not installed
        manager = TracingManager()

        # All operations should be safe
        span = manager.start_span("test")
        manager.instrument_fastapi(None)

        # No errors should occur