diff --git a/.github/workflows/docker-ci-cd.yml b/.github/workflows/docker-ci-cd.yml new file mode 100644 index 0000000..1fe0f1e --- /dev/null +++ b/.github/workflows/docker-ci-cd.yml @@ -0,0 +1,186 @@ +name: Docker CI/CD Pipeline + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + release: + types: [ created ] + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest pytest-cov + + - name: Run tests + run: | + pytest tests/ -v --cov=astroml --cov-report=xml + + - name: Upload coverage + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + + build-docker-images: + runs-on: ubuntu-latest + needs: build-and-test + strategy: + matrix: + stage: [base, development, feature-store, ingestion, training-cpu, production] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Container Registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix= + + - name: Build Docker image + uses: docker/build-push-action@v5 + with: + context: . + target: ${{ matrix.stage }} + push: true + tags: | + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ matrix.stage }}-${{ github.sha }} + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ matrix.stage }}-latest + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + + security-scan: + runs-on: ubuntu-latest + needs: build-docker-images + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:production-latest + format: 'sarif' + output: 'trivy-results.sarif' + + - name: Upload Trivy results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: 'trivy-results.sarif' + + deploy-kubernetes: + runs-on: ubuntu-latest + needs: [build-docker-images, security-scan] + if: github.ref == 'refs/heads/main' && github.event_name == 'push' + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up kubectl + uses: azure/setup-kubectl@v3 + with: + version: 'v1.28.0' + + - name: Configure kubectl + run: | + echo "${{ secrets.KUBE_CONFIG }}" | base64 -d > kubeconfig + export KUBECONFIG=kubeconfig + + - name: Install kustomize + run: | + curl -s "https://raw.githubusercontent.com/kubernetes-sigs/kustomize/master/hack/install_kustomize.sh" | bash + sudo mv kustomize /usr/local/bin/ + + - name: Deploy to Kubernetes + run: | + kustomize build k8s/ | kubectl apply -f - + + - name: Verify deployment + run: | + kubectl rollout status deployment/feature-store -n astroml + kubectl rollout status deployment/astroml-ingestion -n astroml + kubectl rollout status deployment/postgres -n astroml + kubectl rollout status deployment/redis -n astroml + + deploy-staging: + runs-on: ubuntu-latest + needs: [build-docker-images, security-scan] + if: github.ref == 'refs/heads/develop' && github.event_name == 'push' + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up kubectl + uses: azure/setup-kubectl@v3 + with: + version: 'v1.28.0' + + - name: Configure kubectl + run: | + echo "${{ secrets.KUBE_CONFIG_STAGING }}" | base64 -d > kubeconfig + export KUBECONFIG=kubeconfig + + - name: Install kustomize + run: | + curl -s "https://raw.githubusercontent.com/kubernetes-sigs/kustomize/master/hack/install_kustomize.sh" | bash + sudo mv kustomize /usr/local/bin/ + + - name: Deploy to Staging + run: | + kustomize build k8s/overlays/staging | kubectl apply -f - + + - name: Verify deployment + run: | + kubectl rollout status deployment/feature-store -n astroml-staging + kubectl rollout status deployment/astroml-ingestion -n astroml-staging + + notify: + runs-on: ubuntu-latest + needs: [deploy-kubernetes] + if: always() + steps: + - name: Send notification + uses: 8398a7/action-slack@v3 + with: + status: ${{ job.status }} + text: | + Deployment Status: ${{ job.status }} + Branch: ${{ github.ref }} + Commit: ${{ github.sha }} + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} diff --git a/docs/KUBERNETES_DEPLOYMENT.md b/docs/KUBERNETES_DEPLOYMENT.md new file mode 100644 index 0000000..3027842 --- /dev/null +++ b/docs/KUBERNETES_DEPLOYMENT.md @@ -0,0 +1,631 @@ +# Kubernetes Deployment Guide for AstroML + +This guide provides comprehensive instructions for deploying AstroML with Feature Store to Kubernetes clusters. + +## Overview + +The Kubernetes deployment provides: +- **Scalable deployment** with horizontal pod autoscaling +- **High availability** with multiple replicas +- **Monitoring** with Prometheus and Grafana +- **Logging** with Elasticsearch, Fluentd, and Kibana (EFK stack) +- **Ingress** for external access +- **CI/CD pipeline** with GitHub Actions + +## Prerequisites + +### System Requirements +- **Kubernetes cluster** v1.24+ (EKS, GKE, AKS, or minikube) +- **kubectl** v1.24+ configured for cluster access +- **kustomize** v4.0+ for configuration management +- **Helm** v3.0+ (optional, for additional packages) +- **Storage class** configured for persistent volumes +- **Ingress controller** installed (nginx, traefik, etc.) + +### Installation + +#### kubectl +```bash +# Install kubectl +curl -LO "https://dl.k8s.io/release/$(curl -L -s https://dl.k8s.io/release/stable.txt)/bin/linux/amd64/kubectl" +chmod +x kubectl +sudo mv kubectl /usr/local/bin/ + +# Verify installation +kubectl version --client +``` + +#### kustomize +```bash +# Install kustomize +curl -s "https://raw.githubusercontent.com/kubernetes-sigs/kustomize/master/hack/install_kustomize.sh" | bash +sudo mv kustomize /usr/local/bin/ + +# Verify installation +kustomize version +``` + +## Deployment Architecture + +### Components + +#### Core Infrastructure +- **PostgreSQL** - Primary database with persistent storage +- **Redis** - Caching and job queues +- **Feature Store** - Dedicated feature management service + +#### Application Services +- **Ingestion Service** - Data processing and backfill +- **Training Service** - ML model training +- **API Service** - REST API for feature access + +#### Monitoring Stack +- **Prometheus** - Metrics collection and storage +- **Grafana** - Visualization and dashboards + +#### Logging Stack +- **Elasticsearch** - Log storage and search +- **Fluentd** - Log collection and aggregation +- **Kibana** - Log visualization and analysis + +### Network Architecture + +``` +Internet + ↓ +Ingress Controller + ↓ +AstroML Services + ↓ +Feature Store, Ingestion, Training + ↓ +PostgreSQL, Redis +``` + +## Quick Start + +### 1. Clone Repository +```bash +git clone https://github.com/Menjay7/astroml.git +cd astroml +``` + +### 2. Configure Secrets +```bash +# Create secrets file +cat > k8s/secrets.yaml << EOF +apiVersion: v1 +kind: Secret +metadata: + name: postgres-secret + namespace: astroml +type: Opaque +stringData: + password: your-secure-password-here +--- +apiVersion: v1 +kind: Secret +metadata: + name: astroml-secret + namespace: astroml +type: Opaque +stringData: + database-url: "postgresql://astroml:your-password@postgres:5432/astroml" + redis-url: "redis://redis:6379/0" +EOF +``` + +### 3. Deploy Using Script +```bash +# Make script executable +chmod +x scripts/deploy-k8s.sh + +# Deploy all components +./scripts/deploy-k8s.sh deploy +``` + +### 4. Verify Deployment +```bash +# Check pod status +kubectl get pods -n astroml + +# Check services +kubectl get services -n astroml + +# Check ingress +kubectl get ingress -n astroml +``` + +### 5. Access Services +```bash +# Access Grafana +kubectl port-forward -n astroml svc/grafana 3000:3000 +# Open browser: http://localhost:3000 (admin/admin) + +# Access Kibana +kubectl port-forward -n astroml svc/kibana 5601:5601 +# Open browser: http://localhost:5601 +``` + +## Deployment Methods + +### Method 1: Using Deployment Script + +```bash +# Deploy all components +./scripts/deploy-k8s.sh deploy + +# Deploy using kustomize +./scripts/deploy-k8s.sh kustomize + +# Deploy monitoring only +./scripts/deploy-k8s.sh monitoring + +# Deploy logging only +./scripts/deploy-k8s.sh logging +``` + +### Method 2: Using kubectl Directly + +```bash +# Apply all configurations +kubectl apply -f k8s/ + +# Apply specific components +kubectl apply -f k8s/namespace.yaml +kubectl apply -f k8s/postgres-deployment.yaml +kubectl apply -f k8s/feature-store-deployment.yaml +``` + +### Method 3: Using Kustomize + +```bash +# Build and apply +kustomize build k8s/ | kubectl apply -f - + +# Build and preview +kustomize build k8s/ + +# Build to file +kustomize build k8s/ > deployment.yaml +kubectl apply -f deployment.yaml +``` + +## Configuration Management + +### Environment-Specific Configurations + +Create overlays for different environments: + +```bash +# Production overlay +k8s/overlays/production/ +├── kustomization.yaml +├── postgres-patch.yaml +└── feature-store-patch.yaml + +# Staging overlay +k8s/overlays/staging/ +├── kustomization.yaml +├── postgres-patch.yaml +└── feature-store-patch.yaml +``` + +### Example Production Overlay + +```yaml +# k8s/overlays/production/kustomization.yaml +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +namespace: astroml + +bases: + - ../../ + +patchesStrategicMerge: + - postgres-patch.yaml + - feature-store-patch.yaml + +images: + - name: astroml + newTag: v1.0.0 +``` + +### Example Patch + +```yaml +# k8s/overlays/production/postgres-patch.yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: postgres +spec: + replicas: 3 + resources: + requests: + memory: "2Gi" + cpu: "1000m" + limits: + memory: "4Gi" + cpu: "2000m" +``` + +## Scaling and High Availability + +### Horizontal Pod Autoscaling + +The Feature Store deployment includes HPA configuration: + +```yaml +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: feature-store-hpa +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: feature-store + minReplicas: 2 + maxReplicas: 10 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 70 +``` + +### Manual Scaling + +```bash +# Scale deployment +kubectl scale deployment/feature-store -n astroml --replicas=5 + +# Scale using script +./scripts/deploy-k8s.sh scale feature-store 5 +``` + +### Resource Limits + +Configure resource limits based on workload: + +```yaml +resources: + requests: + memory: "512Mi" + cpu: "500m" + limits: + memory: "1Gi" + cpu: "1000m" +``` + +## Monitoring and Observability + +### Prometheus Metrics + +Access Prometheus metrics: + +```bash +# Port forward to Prometheus +kubectl port-forward -n astroml svc/prometheus 9090:9090 + +# Access in browser +# http://localhost:9090 +``` + +### Grafana Dashboards + +Access Grafana for visualization: + +```bash +# Port forward to Grafana +kubectl port-forward -n astroml svc/grafana 3000:3000 + +# Access in browser +# http://localhost:3000 +# Default credentials: admin/admin +``` + +### Log Analysis with Kibana + +Access Kibana for log analysis: + +```bash +# Port forward to Kibana +kubectl port-forward -n astroml svc/kibana 5601:5601 + +# Access in browser +# http://localhost:5601 +``` + +## Troubleshooting + +### Common Issues + +#### Pods Not Starting +```bash +# Check pod status +kubectl describe pod -n astroml + +# Check logs +kubectl logs -n astroml + +# Check events +kubectl get events -n astroml --sort-by='.lastTimestamp' +``` + +#### Service Not Accessible +```bash +# Check service endpoints +kubectl get endpoints -n astroml + +# Check service configuration +kubectl describe service -n astroml + +# Check network policies +kubectl get networkpolicies -n astroml +``` + +#### Storage Issues +```bash +# Check PVC status +kubectl get pvc -n astroml + +# Check storage class +kubectl get storageclass + +# Check PV status +kubectl get pv +``` + +### Debugging Commands + +```bash +# Get all resources +kubectl get all -n astroml + +# Get detailed information +kubectl describe deployment/feature-store -n astroml + +# Get logs from all pods +kubectl logs -l app=feature-store -n astroml --all-containers=true + +# Execute into pod +kubectl exec -it -n astroml -- /bin/bash + +# Check resource usage +kubectl top pods -n astroml +kubectl top nodes +``` + +## CI/CD Pipeline + +### GitHub Actions Workflow + +The project includes a comprehensive CI/CD pipeline: + +```yaml +# .github/workflows/docker-ci-cd.yml +- Build and test +- Build Docker images +- Security scanning +- Deploy to Kubernetes +- Notification +``` + +### Pipeline Stages + +1. **Build and Test** - Run tests and coverage +2. **Build Docker Images** - Build multi-stage images +3. **Security Scan** - Trivy vulnerability scanning +4. **Deploy to Kubernetes** - Automatic deployment +5. **Notification** - Slack notifications + +### Manual Deployment + +```bash +# Trigger deployment manually +gh workflow run docker-ci-cd.yml + +# Deploy specific branch +gh workflow run docker-ci-cd.yml -f branch=develop +``` + +## Security Considerations + +### Secrets Management + +Use Kubernetes secrets for sensitive data: + +```bash +# Create secret from file +kubectl create secret generic db-secret \ + --from-literal=password=your-password \ + -n astroml + +# Create secret from file +kubectl create secret generic tls-secret \ + --from-file=tls.crt=./cert.pem \ + --from-file=tls.key=./key.pem \ + -n astroml +``` + +### Network Policies + +Implement network policies for security: + +```yaml +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: feature-store-network-policy + namespace: astroml +spec: + podSelector: + matchLabels: + app: feature-store + policyTypes: + - Ingress + - Egress + ingress: + - from: + - podSelector: + matchLabels: + app: astroml-ingestion + ports: + - protocol: TCP + port: 8000 +``` + +### RBAC Configuration + +The deployment includes RBAC configuration: + +```yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: astroml + namespace: astroml +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: astroml-role + namespace: astroml +rules: +- apiGroups: [""] + resources: ["configmaps", "secrets"] + verbs: ["get", "list"] +``` + +## Backup and Recovery + +### Database Backup + +```bash +# Backup PostgreSQL +kubectl exec -n astroml postgres-0 -- pg_dump -U astroml astroml > backup.sql + +# Restore PostgreSQL +kubectl exec -i -n astroml postgres-0 -- psql -U astroml astroml < backup.sql +``` + +### Volume Backup + +```bash +# Backup persistent volumes +kubectl get pvc -n astroml +# Use your cloud provider's backup solution +``` + +### Disaster Recovery + +```bash +# Restore from backup +kubectl apply -f k8s/ +kubectl exec -i -n astroml postgres-0 -- psql -U astroml astroml < backup.sql +``` + +## Performance Optimization + +### Resource Tuning + +Adjust resource limits based on usage: + +```bash +# Monitor resource usage +kubectl top pods -n astroml + +# Update resource limits +kubectl set resources deployment/feature-store \ + -n astroml \ + --limits=cpu=2000m,memory=2Gi \ + --requests=cpu=1000m,memory=1Gi +``` + +### Caching Configuration + +Optimize Redis caching: + +```yaml +env: +- name: FEATURE_STORE_CACHE_SIZE + value: "5000" +- name: FEATURE_STORE_CACHE_TTL + value: "7200" +``` + +### Database Optimization + +Configure PostgreSQL for performance: + +```yaml +env: +- name: POSTGRES_SHARED_BUFFERS + value: "256MB" +- name: POSTGRES_EFFECTIVE_CACHE_SIZE + value: "1GB" +``` + +## Maintenance + +### Rolling Updates + +```bash +# Update deployment +kubectl set image deployment/feature-store \ + feature-store=astroml:latest \ + -n astroml + +# Rollout status +kubectl rollout status deployment/feature-store -n astroml + +# Rollback if needed +kubectl rollout undo deployment/feature-store -n astroml +``` + +### Cleanup + +```bash +# Remove all components +./scripts/deploy-k8s.sh cleanup + +# Remove specific components +kubectl delete -f k8s/feature-store-deployment.yaml -n astroml + +# Remove namespace +kubectl delete namespace astroml +``` + +## Best Practices + +1. **Always use secrets** for sensitive data +2. **Implement resource limits** to prevent resource exhaustion +3. **Use liveness and readiness probes** for health checks +4. **Implement network policies** for security +5. **Monitor resource usage** regularly +6. **Backup data regularly** +7. **Test deployments in staging first** +8. **Use version tags** for images +9. **Implement proper RBAC** for access control +10. **Document custom configurations** + +## Support + +For issues and questions: +1. Check this documentation +2. Review logs and error messages +3. Search GitHub issues +4. Create new issue with details + +## Additional Resources + +- [Kubernetes Documentation](https://kubernetes.io/docs/) +- [Kustomize Documentation](https://kustomize.io/) +- [Prometheus Documentation](https://prometheus.io/docs/) +- [Grafana Documentation](https://grafana.com/docs/) +- [Elastic Stack Documentation](https://www.elastic.co/guide/) diff --git a/k8s/feature-store-deployment.yaml b/k8s/feature-store-deployment.yaml new file mode 100644 index 0000000..d3898a3 --- /dev/null +++ b/k8s/feature-store-deployment.yaml @@ -0,0 +1,169 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: feature-store-config + namespace: astroml +data: + FEATURE_STORE_PATH: "/app/feature_store" + FEATURE_STORE_CACHE_SIZE: "1000" + FEATURE_STORE_CACHE_TTL: "3600" + LOG_LEVEL: "INFO" + ASTROML_ENV: "production" +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: feature-store + namespace: astroml + labels: + app: feature-store + component: feature-store +spec: + replicas: 2 + selector: + matchLabels: + app: feature-store + template: + metadata: + labels: + app: feature-store + component: feature-store + spec: + serviceAccountName: astroml + containers: + - name: feature-store + image: astroml:latest + imagePullPolicy: IfNotPresent + env: + - name: DATABASE_URL + valueFrom: + configMapKeyRef: + name: astroml-config + key: DATABASE_URL + - name: REDIS_URL + valueFrom: + configMapKeyRef: + name: astroml-config + key: REDIS_URL + - name: FEATURE_STORE_PATH + valueFrom: + configMapKeyRef: + name: feature-store-config + key: FEATURE_STORE_PATH + - name: FEATURE_STORE_CACHE_SIZE + valueFrom: + configMapKeyRef: + name: feature-store-config + key: FEATURE_STORE_CACHE_SIZE + - name: FEATURE_STORE_CACHE_TTL + valueFrom: + configMapKeyRef: + name: feature-store-config + key: FEATURE_STORE_CACHE_TTL + - name: LOG_LEVEL + valueFrom: + configMapKeyRef: + name: feature-store-config + key: LOG_LEVEL + - name: ASTROML_ENV + valueFrom: + configMapKeyRef: + name: feature-store-config + key: ASTROML_ENV + ports: + - containerPort: 8000 + name: http + - containerPort: 8080 + name: metrics + command: ["python", "-c"] + args: + - | + from astroml.features import create_feature_store + store = create_feature_store('/app/feature_store') + print('Feature Store service ready') + import time + while True: + time.sleep(60) + resources: + requests: + memory: "512Mi" + cpu: "500m" + limits: + memory: "1Gi" + cpu: "1000m" + livenessProbe: + httpGet: + path: /health + port: 8000 + initialDelaySeconds: 30 + periodSeconds: 10 + readinessProbe: + httpGet: + path: /ready + port: 8000 + initialDelaySeconds: 10 + periodSeconds: 5 + volumeMounts: + - name: feature-store-storage + mountPath: /app/feature_store + volumes: + - name: feature-store-storage + persistentVolumeClaim: + claimName: feature-store-pvc +--- +apiVersion: v1 +kind: Service +metadata: + name: feature-store + namespace: astroml + labels: + app: feature-store +spec: + type: ClusterIP + ports: + - port: 8000 + targetPort: 8000 + name: http + - port: 8080 + targetPort: 8080 + name: metrics + selector: + app: feature-store +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: feature-store-pvc + namespace: astroml +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 10Gi +--- +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: feature-store-hpa + namespace: astroml +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: feature-store + minReplicas: 2 + maxReplicas: 10 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 70 + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: 80 diff --git a/k8s/ingress.yaml b/k8s/ingress.yaml new file mode 100644 index 0000000..6aff7d7 --- /dev/null +++ b/k8s/ingress.yaml @@ -0,0 +1,76 @@ +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: astroml-ingress + namespace: astroml + annotations: + nginx.ingress.kubernetes.io/rewrite-target: / + nginx.ingress.kubernetes.io/ssl-redirect: "true" + cert-manager.io/cluster-issuer: "letsencrypt-prod" + nginx.ingress.kubernetes.io/rate-limit: "100" + nginx.ingress.kubernetes.io/cors-allow-origin: "*" +spec: + ingressClassName: nginx + tls: + - hosts: + - astroml.example.com + secretName: astroml-tls + rules: + - host: astroml.example.com + http: + paths: + - path: /api + pathType: Prefix + backend: + service: + name: astroml-ingestion + port: + number: 8000 + - path: /feature-store + pathType: Prefix + backend: + service: + name: feature-store + port: + number: 8000 + - path: /training + pathType: Prefix + backend: + service: + name: astroml-training + port: + number: 6006 +--- +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: astroml-monitoring-ingress + namespace: astroml + annotations: + nginx.ingress.kubernetes.io/rewrite-target: / + nginx.ingress.kubernetes.io/ssl-redirect: "true" + cert-manager.io/cluster-issuer: "letsencrypt-prod" +spec: + ingressClassName: nginx + tls: + - hosts: + - monitoring.astroml.example.com + secretName: astroml-monitoring-tls + rules: + - host: monitoring.astroml.example.com + http: + paths: + - path: /grafana + pathType: Prefix + backend: + service: + name: grafana + port: + number: 3000 + - path: /prometheus + pathType: Prefix + backend: + service: + name: prometheus + port: + number: 9090 diff --git a/k8s/kustomization.yaml b/k8s/kustomization.yaml index 958bb04..d0123bd 100644 --- a/k8s/kustomization.yaml +++ b/k8s/kustomization.yaml @@ -7,7 +7,12 @@ resources: - namespace.yaml - postgres-deployment.yaml - redis-deployment.yaml + - feature-store-deployment.yaml - astroml-deployment.yaml + - services.yaml + - ingress.yaml + - monitoring.yaml + - logging.yaml - rbac.yaml commonLabels: diff --git a/k8s/logging.yaml b/k8s/logging.yaml new file mode 100644 index 0000000..810f2a0 --- /dev/null +++ b/k8s/logging.yaml @@ -0,0 +1,248 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: fluentd-config + namespace: astroml +data: + fluent.conf: | + + @type tail + path /var/log/containers/*.log + pos_file /var/log/fluentd-containers.log.pos + tag kubernetes.* + read_from_head true + + @type json + time_format %Y-%m-%dT%H:%M:%S.%NZ + + + + + @type kubernetes_metadata + + + + @type record_transformer + + hostname "#{Socket.gethostname}" + + + + + @type elasticsearch + host elasticsearch + port 9200 + logstash_format true + logstash_prefix astroml + logstash_dateformat %Y.%m.%d + include_tag_key true + tag_key @log_name + flush_interval 1s + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: fluentd + namespace: astroml + labels: + app: fluentd +spec: + replicas: 1 + selector: + matchLabels: + app: fluentd + template: + metadata: + labels: + app: fluentd + spec: + serviceAccountName: fluentd + containers: + - name: fluentd + image: fluent/fluentd-kubernetes-daemonset:v1-debian-elasticsearch + env: + - name: FLUENT_ELASTICSEARCH_HOST + value: "elasticsearch" + - name: FLUENT_ELASTICSEARCH_PORT + value: "9200" + - name: FLUENT_ELASTICSEARCH_SCHEME + value: "http" + resources: + limits: + memory: 500Mi + requests: + cpu: 100m + memory: 200Mi + volumeMounts: + - name: varlog + mountPath: /var/log + - name: varlibdockercontainers + mountPath: /var/lib/docker/containers + readOnly: true + - name: fluentd-config + mountPath: /fluentd/etc + terminationGracePeriodSeconds: 30 + volumes: + - name: varlog + hostPath: + path: /var/log + - name: varlibdockercontainers + hostPath: + path: /var/lib/docker/containers + - name: fluentd-config + configMap: + name: fluentd-config +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: fluentd + namespace: astroml +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: fluentd +rules: +- apiGroups: [""] + resources: ["pods", "namespaces"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: fluentd +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: fluentd +subjects: +- kind: ServiceAccount + name: fluentd + namespace: astroml +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: elasticsearch + namespace: astroml + labels: + app: elasticsearch +spec: + replicas: 1 + selector: + matchLabels: + app: elasticsearch + template: + metadata: + labels: + app: elasticsearch + spec: + containers: + - name: elasticsearch + image: docker.elastic.co/elasticsearch/elasticsearch:8.8.0 + ports: + - containerPort: 9200 + - containerPort: 9300 + env: + - name: discovery.type + value: single-node + - name: ES_JAVA_OPTS + value: "-Xms512m -Xmx512m" + - name: xpack.security.enabled + value: "false" + resources: + requests: + memory: "1Gi" + cpu: "500m" + limits: + memory: "2Gi" + cpu: "1000m" + volumeMounts: + - name: elasticsearch-storage + mountPath: /usr/share/elasticsearch/data + volumes: + - name: elasticsearch-storage + persistentVolumeClaim: + claimName: elasticsearch-pvc +--- +apiVersion: v1 +kind: Service +metadata: + name: elasticsearch + namespace: astroml + labels: + app: elasticsearch +spec: + type: ClusterIP + ports: + - port: 9200 + targetPort: 9200 + name: http + - port: 9300 + targetPort: 9300 + name: transport + selector: + app: elasticsearch +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: elasticsearch-pvc + namespace: astroml +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 30Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: kibana + namespace: astroml + labels: + app: kibana +spec: + replicas: 1 + selector: + matchLabels: + app: kibana + template: + metadata: + labels: + app: kibana + spec: + containers: + - name: kibana + image: docker.elastic.co/kibana/kibana:8.8.0 + ports: + - containerPort: 5601 + env: + - name: ELASTICSEARCH_HOSTS + value: "http://elasticsearch:9200" + resources: + requests: + memory: "512Mi" + cpu: "250m" + limits: + memory: "1Gi" + cpu: "500m" +--- +apiVersion: v1 +kind: Service +metadata: + name: kibana + namespace: astroml + labels: + app: kibana +spec: + type: ClusterIP + ports: + - port: 5601 + targetPort: 5601 + name: http + selector: + app: kibana diff --git a/k8s/monitoring.yaml b/k8s/monitoring.yaml new file mode 100644 index 0000000..2d4d7da --- /dev/null +++ b/k8s/monitoring.yaml @@ -0,0 +1,234 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: prometheus-config + namespace: astroml +data: + prometheus.yml: | + global: + scrape_interval: 15s + evaluation_interval: 15s + scrape_configs: + - job_name: 'feature-store' + kubernetes_sd_configs: + - role: pod + namespaces: + names: + - astroml + relabel_configs: + - source_labels: [__meta_kubernetes_pod_label_app] + action: keep + regex: feature-store + - source_labels: [__meta_kubernetes_pod_ip] + target_label: __address__ + replacement: $1:8080 + - job_name: 'astroml-ingestion' + kubernetes_sd_configs: + - role: pod + namespaces: + names: + - astroml + relabel_configs: + - source_labels: [__meta_kubernetes_pod_label_app] + action: keep + regex: astroml-ingestion + - source_labels: [__meta_kubernetes_pod_ip] + target_label: __address__ + replacement: $1:8080 + - job_name: 'postgres' + kubernetes_sd_configs: + - role: pod + namespaces: + names: + - astroml + relabel_configs: + - source_labels: [__meta_kubernetes_pod_label_app] + action: keep + regex: postgres + - source_labels: [__meta_kubernetes_pod_ip] + target_label: __address__ + replacement: $1:9187 + - job_name: 'redis' + kubernetes_sd_configs: + - role: pod + namespaces: + names: + - astroml + relabel_configs: + - source_labels: [__meta_kubernetes_pod_label_app] + action: keep + regex: redis + - source_labels: [__meta_kubernetes_pod_ip] + target_label: __address__ + replacement: $1:9121 +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: prometheus + namespace: astroml + labels: + app: prometheus +spec: + replicas: 1 + selector: + matchLabels: + app: prometheus + template: + metadata: + labels: + app: prometheus + spec: + containers: + - name: prometheus + image: prom/prometheus:latest + ports: + - containerPort: 9090 + volumeMounts: + - name: prometheus-config + mountPath: /etc/prometheus + - name: prometheus-storage + mountPath: /prometheus + resources: + requests: + memory: "512Mi" + cpu: "500m" + limits: + memory: "1Gi" + cpu: "1000m" + volumes: + - name: prometheus-config + configMap: + name: prometheus-config + - name: prometheus-storage + persistentVolumeClaim: + claimName: prometheus-pvc +--- +apiVersion: v1 +kind: Service +metadata: + name: prometheus + namespace: astroml + labels: + app: prometheus +spec: + type: ClusterIP + ports: + - port: 9090 + targetPort: 9090 + name: http + selector: + app: prometheus +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: prometheus-pvc + namespace: astroml +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 20Gi +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: grafana-config + namespace: astroml +data: + grafana.ini: | + [server] + http_port = 3000 + [security] + admin_user = admin + admin_password = admin + [database] + type = sqlite3 + path = /var/lib/grafana/grafana.db +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: grafana + namespace: astroml + labels: + app: grafana +spec: + replicas: 1 + selector: + matchLabels: + app: grafana + template: + metadata: + labels: + app: grafana + spec: + containers: + - name: grafana + image: grafana/grafana:latest + ports: + - containerPort: 3000 + env: + - name: GF_SECURITY_ADMIN_PASSWORD + valueFrom: + secretKeyRef: + name: grafana-secret + key: admin-password + volumeMounts: + - name: grafana-config + mountPath: /etc/grafana + - name: grafana-storage + mountPath: /var/lib/grafana + resources: + requests: + memory: "256Mi" + cpu: "250m" + limits: + memory: "512Mi" + cpu: "500m" + volumes: + - name: grafana-config + configMap: + name: grafana-config + - name: grafana-storage + persistentVolumeClaim: + claimName: grafana-pvc +--- +apiVersion: v1 +kind: Service +metadata: + name: grafana + namespace: astroml + labels: + app: grafana +spec: + type: ClusterIP + ports: + - port: 3000 + targetPort: 3000 + name: http + selector: + app: grafana +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: grafana-pvc + namespace: astroml +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 10Gi +--- +apiVersion: v1 +kind: Secret +metadata: + name: grafana-secret + namespace: astroml +type: Opaque +stringData: + admin-password: "admin_change_me" diff --git a/k8s/services.yaml b/k8s/services.yaml new file mode 100644 index 0000000..7378410 --- /dev/null +++ b/k8s/services.yaml @@ -0,0 +1,47 @@ +apiVersion: v1 +kind: Service +metadata: + name: astroml-ingestion + namespace: astroml + labels: + app: astroml-ingestion +spec: + type: ClusterIP + ports: + - port: 8000 + targetPort: 8000 + name: http + selector: + app: astroml-ingestion +--- +apiVersion: v1 +kind: Service +metadata: + name: astroml-training + namespace: astroml + labels: + app: astroml-training +spec: + type: ClusterIP + ports: + - port: 6006 + targetPort: 6006 + name: tensorboard + selector: + app: astroml-training +--- +apiVersion: v1 +kind: Service +metadata: + name: astroml-api + namespace: astroml + labels: + app: astroml-api +spec: + type: ClusterIP + ports: + - port: 8000 + targetPort: 8000 + name: http + selector: + app: astroml-api diff --git a/requirements.txt b/requirements.txt index 0d5b3d7..e824f0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,10 @@ omegaconf>=2.3.0 pytorch-lightning>=2.0.0 prometheus-client>=0.19.0 -# Feature Store dependencies +```txt id="z3f7qc" +# ============================================================================ +# Feature Store Dependencies +# ============================================================================ redis>=5.0.0 cachetools>=5.3.0 pyarrow>=14.0.0 @@ -34,14 +37,28 @@ tqdm>=4.66.0 click>=8.1.0 rich>=13.7.0 -# Development and testing dependencies +# ============================================================================ +# Visualization Dependencies +# ============================================================================ +matplotlib>=3.7.0 +seaborn>=0.12.0 + +# ============================================================================ +# Development & Testing Dependencies +# ============================================================================ pytest>=7.4.0 pytest-cov>=4.1.0 pytest-mock>=3.12.0 + black>=23.11.0 flake8>=6.1.0 mypy>=1.7.0 + +# ============================================================================ +# Jupyter / Notebook Dependencies +# ============================================================================ jupyter>=1.0.0 notebook>=7.0.0 ipykernel>=6.26.0 +``` diff --git a/scripts/deploy-k8s.sh b/scripts/deploy-k8s.sh new file mode 100644 index 0000000..8c958d3 --- /dev/null +++ b/scripts/deploy-k8s.sh @@ -0,0 +1,324 @@ +#!/bin/bash +# Kubernetes deployment script for AstroML +# This script handles deployment to Kubernetes clusters + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Function to print colored output +print_status() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +print_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +print_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +print_header() { + echo -e "${BLUE}=== $1 ===${NC}" +} + +# Function to check prerequisites +check_prerequisites() { + print_header "Checking Prerequisites" + + # Check kubectl + if ! command -v kubectl > /dev/null 2>&1; then + print_error "kubectl is not installed" + exit 1 + fi + print_status "kubectl is installed" + + # Check kustomize + if ! command -v kustomize > /dev/null 2>&1; then + print_warning "kustomize is not installed, installing..." + curl -s "https://raw.githubusercontent.com/kubernetes-sigs/kustomize/master/hack/install_kustomize.sh" | bash + sudo mv kustomize /usr/local/bin/ + fi + print_status "kustomize is available" + + # Check cluster connectivity + if ! kubectl cluster-info > /dev/null 2>&1; then + print_error "Cannot connect to Kubernetes cluster" + exit 1 + fi + print_status "Kubernetes cluster is accessible" +} + +# Function to deploy to namespace +deploy_namespace() { + local namespace=${1:-astroml} + print_header "Deploying Namespace" + + kubectl create namespace $namespace --dry-run=client -o yaml | kubectl apply -f - + print_status "Namespace $namespace created/verified" +} + +# Function to deploy secrets +deploy_secrets() { + print_header "Deploying Secrets" + + # Check if secrets file exists + if [ -f "k8s/secrets.yaml" ]; then + kubectl apply -f k8s/secrets.yaml + print_status "Secrets deployed" + else + print_warning "No secrets file found, using default values" + fi +} + +# Function to deploy base infrastructure +deploy_base() { + print_header "Deploying Base Infrastructure" + + kubectl apply -f k8s/namespace.yaml + kubectl apply -f k8s/postgres-deployment.yaml + kubectl apply -f k8s/redis-deployment.yaml + + print_status "Waiting for PostgreSQL to be ready..." + kubectl wait --for=condition=ready pod -l app=postgres -n astroml --timeout=300s + + print_status "Waiting for Redis to be ready..." + kubectl wait --for=condition=ready pod -l app=redis -n astroml --timeout=300s + + print_status "Base infrastructure deployed" +} + +# Function to deploy Feature Store +deploy_feature_store() { + print_header "Deploying Feature Store" + + kubectl apply -f k8s/feature-store-deployment.yaml + + print_status "Waiting for Feature Store to be ready..." + kubectl wait --for=condition=ready pod -l app=feature-store -n astroml --timeout=300s + + print_status "Feature Store deployed" +} + +# Function to deploy applications +deploy_applications() { + print_header "Deploying Applications" + + kubectl apply -f k8s/astroml-deployment.yaml + kubectl apply -f k8s/services.yaml + + print_status "Waiting for applications to be ready..." + kubectl wait --for=condition=ready pod -l app=astroml-ingestion -n astroml --timeout=300s + kubectl wait --for=condition=ready pod -l app=astroml-training -n astroml --timeout=300s + + print_status "Applications deployed" +} + +# Function to deploy monitoring +deploy_monitoring() { + print_header "Deploying Monitoring Stack" + + kubectl apply -f k8s/monitoring.yaml + + print_status "Waiting for monitoring stack to be ready..." + kubectl wait --for=condition=ready pod -l app=prometheus -n astroml --timeout=300s + kubectl wait --for=condition=ready pod -l app=grafana -n astroml --timeout=300s + + print_status "Monitoring stack deployed" +} + +# Function to deploy logging +deploy_logging() { + print_header "Deploying Logging Stack" + + kubectl apply -f k8s/logging.yaml + + print_status "Waiting for logging stack to be ready..." + kubectl wait --for=condition=ready pod -l app=elasticsearch -n astroml --timeout=300s + kubectl wait --for=condition=ready pod -l app=kibana -n astroml --timeout=300s + + print_status "Logging stack deployed" +} + +# Function to deploy ingress +deploy_ingress() { + print_header "Deploying Ingress" + + kubectl apply -f k8s/ingress.yaml + + print_status "Ingress deployed" +} + +# Function to deploy using kustomize +deploy_kustomize() { + print_header "Deploying with Kustomize" + + kustomize build k8s/ | kubectl apply -f - + + print_status "Deployment completed with Kustomize" +} + +# Function to verify deployment +verify_deployment() { + print_header "Verifying Deployment" + + print_status "Checking pod status..." + kubectl get pods -n astroml + + print_status "Checking services..." + kubectl get services -n astroml + + print_status "Checking ingress..." + kubectl get ingress -n astroml + + print_status "Deployment verification completed" +} + +# Function to get access information +get_access_info() { + print_header "Access Information" + + print_status "Service Endpoints:" + kubectl get services -n astroml + + print_status "Ingress Endpoints:" + kubectl get ingress -n astroml + + print_status "To access Grafana:" + echo "kubectl port-forward -n astroml svc/grafana 3000:3000" + + print_status "To access Kibana:" + echo "kubectl port-forward -n astroml svc/kibana 5601:5601" +} + +# Function to rollback deployment +rollback_deployment() { + local deployment=${1:-astroml-ingestion} + print_header "Rolling Back Deployment" + + kubectl rollout undo deployment/$deployment -n astroml + + print_status "Rollback completed for $deployment" +} + +# Function to scale deployment +scale_deployment() { + local deployment=${1:-astroml-ingestion} + local replicas=${2:-3} + print_header "Scaling Deployment" + + kubectl scale deployment/$deployment -n astroml --replicas=$replicas + + print_status "Deployment $deployment scaled to $replicas replicas" +} + +# Function to show logs +show_logs() { + local deployment=${1:-astroml-ingestion} + print_header "Showing Logs" + + kubectl logs -f deployment/$deployment -n astroml +} + +# Function to clean up +cleanup() { + print_header "Cleaning Up" + + kustomize build k8s/ | kubectl delete -f - + + print_status "Cleanup completed" +} + +# Main execution +main() { + local command=${1:-deploy} + local environment=${2:-production} + + print_header "AstroML Kubernetes Deployment" + + # Change to project directory + cd "$(dirname "$0")/.." + + # Check prerequisites + check_prerequisites + + case $command in + "deploy") + deploy_namespace + deploy_secrets + deploy_base + deploy_feature_store + deploy_applications + deploy_monitoring + deploy_logging + deploy_ingress + verify_deployment + get_access_info + ;; + "kustomize") + deploy_kustomize + verify_deployment + get_access_info + ;; + "monitoring") + deploy_monitoring + ;; + "logging") + deploy_logging + ;; + "verify") + verify_deployment + ;; + "access") + get_access_info + ;; + "rollback") + rollback_deployment $2 + ;; + "scale") + scale_deployment $2 $3 + ;; + "logs") + show_logs $2 + ;; + "cleanup") + cleanup + ;; + "help"|*) + echo "AstroML Kubernetes Deployment Script" + echo "" + echo "Usage: $0 [COMMAND] [OPTIONS]" + echo "" + echo "Commands:" + echo " deploy Deploy all components" + echo " kustomize Deploy using Kustomize" + echo " monitoring Deploy monitoring stack only" + echo " logging Deploy logging stack only" + echo " verify Verify deployment status" + echo " access Show access information" + echo " rollback [name] Rollback deployment" + echo " scale [name] [replicas] Scale deployment" + echo " logs [name] Show logs for deployment" + echo " cleanup Remove all components" + echo " help Show this help message" + echo "" + echo "Examples:" + echo " $0 deploy" + echo " $0 kustomize" + echo " $0 scale astroml-ingestion 5" + echo " $0 logs feature-store" + ;; + esac +} + +# Handle signals gracefully +trap 'print_warning "Deployment interrupted"; exit 1' SIGINT SIGTERM + +# Execute main function +main "$@" diff --git a/scripts/verify-k8s-deployment.sh b/scripts/verify-k8s-deployment.sh new file mode 100644 index 0000000..7746256 --- /dev/null +++ b/scripts/verify-k8s-deployment.sh @@ -0,0 +1,397 @@ +#!/bin/bash +# Kubernetes deployment verification script for AstroML +# This script verifies that all Kubernetes components are deployed correctly + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Function to print colored output +print_status() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +print_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +print_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +print_header() { + echo -e "${BLUE}=== $1 ===${NC}" +} + +# Function to check prerequisites +check_prerequisites() { + print_header "Checking Prerequisites" + + # Check kubectl + if ! command -v kubectl > /dev/null 2>&1; then + print_error "kubectl is not installed" + return 1 + fi + print_status "kubectl is installed" + + # Check cluster connectivity + if ! kubectl cluster-info > /dev/null 2>&1; then + print_error "Cannot connect to Kubernetes cluster" + return 1 + fi + print_status "Kubernetes cluster is accessible" + + # Check kustomize + if ! command -v kustomize > /dev/null 2>&1; then + print_warning "kustomize is not installed" + else + print_status "kustomize is available" + fi + + return 0 +} + +# Function to verify namespace +verify_namespace() { + print_header "Verifying Namespace" + + if kubectl get namespace astroml > /dev/null 2>&1; then + print_status "Namespace astroml exists" + kubectl get namespace astroml + else + print_error "Namespace astroml does not exist" + return 1 + fi +} + +# Function to verify deployments +verify_deployments() { + print_header "Verifying Deployments" + + local deployments=( + "postgres" + "redis" + "feature-store" + "astroml-ingestion" + "astroml-training" + "prometheus" + "grafana" + "elasticsearch" + "kibana" + ) + + local failed_deployments=0 + + for deployment in "${deployments[@]}"; do + if kubectl get deployment $deployment -n astroml > /dev/null 2>&1; then + local ready=$(kubectl get deployment $deployment -n astroml -o jsonpath='{.status.readyReplicas}') + local desired=$(kubectl get deployment $deployment -n astroml -o jsonpath='{.spec.replicas}') + + if [ "$ready" = "$desired" ] && [ "$ready" != "" ]; then + print_status "✓ $deployment is ready ($ready/$desired replicas)" + else + print_warning "⚠ $deployment is not ready ($ready/$desired replicas)" + failed_deployments=$((failed_deployments + 1)) + fi + else + print_warning "✗ $deployment does not exist" + failed_deployments=$((failed_deployments + 1)) + fi + done + + return $failed_deployments +} + +# Function to verify pods +verify_pods() { + print_header "Verifying Pods" + + print_status "Pod status in astroml namespace:" + kubectl get pods -n astroml + + local failed_pods=0 + + # Check for failed pods + local failed=$(kubectl get pods -n astroml -o json | jq -r '.items[] | select(.status.phase=="Failed") | .metadata.name') + if [ -n "$failed" ]; then + print_error "Failed pods detected: $failed" + failed_pods=$((failed_pods + 1)) + fi + + # Check for pending pods + local pending=$(kubectl get pods -n astroml -o json | jq -r '.items[] | select(.status.phase=="Pending") | .metadata.name') + if [ -n "$pending" ]; then + print_warning "Pending pods detected: $pending" + fi + + return $failed_pods +} + +# Function to verify services +verify_services() { + print_header "Verifying Services" + + print_status "Services in astroml namespace:" + kubectl get services -n astroml + + local services=( + "postgres" + "redis" + "feature-store" + "astroml-ingestion" + "astroml-training" + "prometheus" + "grafana" + "elasticsearch" + "kibana" + ) + + local failed_services=0 + + for service in "${services[@]}"; do + if kubectl get service $service -n astroml > /dev/null 2>&1; then + local type=$(kubectl get service $service -n astroml -o jsonpath='{.spec.type}') + local ports=$(kubectl get service $service -n astroml -o jsonpath='{.spec.ports[*].port}') + print_status "✓ $service exists ($type, ports: $ports)" + else + print_warning "✗ $service does not exist" + failed_services=$((failed_services + 1)) + fi + done + + return $failed_services +} + +# Function to verify ingress +verify_ingress() { + print_header "Verifying Ingress" + + if kubectl get ingress -n astroml > /dev/null 2>&1; then + print_status "Ingress resources in astroml namespace:" + kubectl get ingress -n astroml + return 0 + else + print_warning "No ingress resources found" + return 1 + fi +} + +# Function to verify persistent volumes +verify_persistent_volumes() { + print_header "Verifying Persistent Volumes" + + print_status "PVCs in astroml namespace:" + kubectl get pvc -n astroml + + local pvcs=( + "postgres-storage" + "feature-store-pvc" + "prometheus-pvc" + "grafana-pvc" + "elasticsearch-pvc" + ) + + local failed_pvcs=0 + + for pvc in "${pvcs[@]}"; do + if kubectl get pvc $pvc -n astroml > /dev/null 2>&1; then + local status=$(kubectl get pvc $pvc -n astroml -o jsonpath='{.status.phase}') + print_status "✓ $pvc exists ($status)" + else + print_warning "✗ $pvc does not exist" + failed_pvcs=$((failed_pvcs + 1)) + fi + done + + return $failed_pvcs +} + +# Function to verify configmaps +verify_configmaps() { + print_header "Verifying ConfigMaps" + + print_status "ConfigMaps in astroml namespace:" + kubectl get configmaps -n astroml + + local configmaps=( + "astroml-config" + "feature-store-config" + "postgres-config" + "prometheus-config" + "grafana-config" + "fluentd-config" + ) + + local failed_configmaps=0 + + for configmap in "${configmaps[@]}"; do + if kubectl get configmap $configmap -n astroml > /dev/null 2>&1; then + print_status "✓ $configmap exists" + else + print_warning "✗ $configmap does not exist" + failed_configmaps=$((failed_configmaps + 1)) + fi + done + + return $failed_configmaps +} + +# Function to verify secrets +verify_secrets() { + print_header "Verifying Secrets" + + print_status "Secrets in astroml namespace:" + kubectl get secrets -n astroml + + local secrets=( + "postgres-secret" + "grafana-secret" + ) + + local failed_secrets=0 + + for secret in "${secrets[@]}"; do + if kubectl get secret $secret -n astroml > /dev/null 2>&1; then + print_status "✓ $secret exists" + else + print_warning "✗ $secret does not exist" + failed_secrets=$((failed_secrets + 1)) + fi + done + + return $failed_secrets +} + +# Function to verify HPA +verify_hpa() { + print_header "Verifying Horizontal Pod Autoscalers" + + if kubectl get hpa -n astroml > /dev/null 2>&1; then + print_status "HPA resources in astroml namespace:" + kubectl get hpa -n astroml + return 0 + else + print_warning "No HPA resources found" + return 1 + fi +} + +# Function to test connectivity +test_connectivity() { + print_header "Testing Connectivity" + + # Test Feature Store + print_status "Testing Feature Store connectivity..." + if kubectl exec -n astroml deployment/feature-store -- python -c " +from astroml.features import create_feature_store +store = create_feature_store('/app/feature_store') +print('Feature Store is accessible') +" 2>/dev/null; then + print_status "✓ Feature Store is accessible" + else + print_warning "✗ Feature Store connectivity test failed" + fi + + # Test PostgreSQL + print_status "Testing PostgreSQL connectivity..." + if kubectl exec -n astroml deployment/postgres -- pg_isready -U astroml > /dev/null 2>&1; then + print_status "✓ PostgreSQL is accessible" + else + print_warning "✗ PostgreSQL connectivity test failed" + fi + + # Test Redis + print_status "Testing Redis connectivity..." + if kubectl exec -n astroml deployment/redis -- redis-cli ping | grep -q "PONG"; then + print_status "✓ Redis is accessible" + else + print_warning "✗ Redis connectivity test failed" + fi +} + +# Function to check resource usage +check_resource_usage() { + print_header "Checking Resource Usage" + + print_status "Pod resource usage:" + kubectl top pods -n astroml 2>/dev/null || print_warning "Metrics server not available" + + print_status "Node resource usage:" + kubectl top nodes 2>/dev/null || print_warning "Metrics server not available" +} + +# Function to generate report +generate_report() { + print_header "Verification Report" + + echo "Kubernetes Deployment Verification completed on $(date)" + echo "======================================================" + echo "" + echo "Components Verified:" + echo "- Namespace" + echo "- Deployments" + echo "- Pods" + echo "- Services" + echo "- Ingress" + echo "- Persistent Volumes" + echo "- ConfigMaps" + echo "- Secrets" + echo "- Horizontal Pod Autoscalers" + echo "- Connectivity" + echo "- Resource Usage" + echo "" + echo "For detailed information, check the output above." + echo "" + echo "Next Steps:" + echo "1. Review any warnings or errors above" + echo "2. Check logs for failed components: kubectl logs -n astroml" + echo "3. Access services: kubectl port-forward -n astroml svc/ :" + echo "4. Monitor deployment: kubectl get pods -n astroml -w" +} + +# Main execution +main() { + print_header "AstroML Kubernetes Deployment Verification" + + # Change to project directory + cd "$(dirname "$0")/.." + + local failed_checks=0 + + # Run verification steps + check_prerequisites || ((failed_checks++)) + verify_namespace || ((failed_checks++)) + verify_deployments || ((failed_checks++)) + verify_pods || ((failed_checks++)) + verify_services || ((failed_checks++)) + verify_ingress || ((failed_checks++)) + verify_persistent_volumes || ((failed_checks++)) + verify_configmaps || ((failed_checks++)) + verify_secrets || ((failed_checks++)) + verify_hpa || ((failed_checks++)) + test_connectivity + check_resource_usage + + # Generate report + generate_report + + # Exit with appropriate code + if [ $failed_checks -eq 0 ]; then + print_status "✅ All verification checks passed!" + exit 0 + else + print_error "❌ $failed_checks verification checks failed" + exit 1 + fi +} + +# Handle signals gracefully +trap 'print_warning "Verification interrupted"; exit 1' SIGINT SIGTERM + +# Execute main function +main "$@" diff --git a/src/auth_tests.rs b/src/auth_tests.rs new file mode 100644 index 0000000..441a060 --- /dev/null +++ b/src/auth_tests.rs @@ -0,0 +1,475 @@ +//! Authentication and authorization tests for the Fraud Registry Soroban contract. +//! +//! This module tests: +//! - Admin authentication and authorization +//! - Validator registration and lifecycle +//! - Access control for privileged operations +//! - Session-like behavior through validator state +//! +//! Run with: +//! cargo test --lib auth -- --nocapture + +#[cfg(test)] +mod auth_tests { + use soroban_sdk::{testutils::Address as _, Address, Env, String}; + use crate::{Error, FraudRegistry, FraudRegistryClient}; + + // Helper: deploy and initialise a fresh contract instance. + fn setup_contract(env: &Env) -> (FraudRegistryClient<'_>, Address) { + let contract_id = env.register_contract(None, FraudRegistry); + let client = FraudRegistryClient::new(env, &contract_id); + let admin = Address::generate(env); + client.initialize(&admin); + (client, admin) + } + + // --------------------------------------------------------------------------- + // Admin Authentication Tests + // --------------------------------------------------------------------------- + + #[test] + fn test_admin_initialization_sets_correct_admin() { + let env = Env::default(); + let contract_id = env.register_contract(None, FraudRegistry); + let client = FraudRegistryClient::new(&env, &contract_id); + + let admin = Address::generate(&env); + client.initialize(&admin); + + // Verify admin can perform admin-only operations + let validator = Address::generate(&env); + let result = client.try_register_validator(&admin, &validator, &75_u32); + assert!(result.is_ok(), "Admin should be able to register validators"); + } + + #[test] + fn test_non_admin_cannot_initialize_contract() { + let env = Env::default(); + let contract_id = env.register_contract(None, FraudRegistry); + let client = FraudRegistryClient::new(&env, &contract_id); + + let admin = Address::generate(&env); + client.initialize(&admin); + + // Try to re-initialize with different admin (documents SC-1 vulnerability) + let attacker = Address::generate(&env); + client.initialize(&attacker); + + // Original admin should no longer have access + let validator = Address::generate(&env); + let result = client.try_register_validator(&admin, &validator, &75_u32); + assert_eq!(result, Err(Ok(Error::Unauthorized))); + } + + #[test] + fn test_admin_can_update_config() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let result = client.try_update_config(&admin, &Some(60_u32), &Some(70_u32), &Some(5_u32)); + assert!(result.is_ok(), "Admin should be able to update config"); + } + + #[test] + fn test_admin_can_deactivate_validator() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + client.register_validator(&admin, &validator, &75_u32); + + let result = client.try_deactivate_validator(&admin, &validator); + assert!(result.is_ok(), "Admin should be able to deactivate validators"); + } + + #[test] + fn test_admin_can_update_validator_reputation() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + client.register_validator(&admin, &validator, &75_u32); + + let result = client.try_update_validator_reputation(&admin, &validator, &90_u32); + assert!(result.is_ok(), "Admin should be able to update validator reputation"); + } + + // --------------------------------------------------------------------------- + // Non-Admin Authorization Tests + // --------------------------------------------------------------------------- + + #[test] + fn test_non_admin_cannot_register_validator() { + let env = Env::default(); + let (client, _admin) = setup_contract(&env); + + let attacker = Address::generate(&env); + let validator = Address::generate(&env); + + let result = client.try_register_validator(&attacker, &validator, &75_u32); + assert_eq!(result, Err(Ok(Error::Unauthorized))); + } + + #[test] + fn test_non_admin_cannot_update_config() { + let env = Env::default(); + let (client, _admin) = setup_contract(&env); + + let attacker = Address::generate(&env); + let result = client.try_update_config(&attacker, &Some(60_u32), &Some(70_u32), &Some(5_u32)); + assert_eq!(result, Err(Ok(Error::Unauthorized))); + } + + #[test] + fn test_non_admin_cannot_deactivate_validator() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let attacker = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + let result = client.try_deactivate_validator(&attacker, &validator); + assert_eq!(result, Err(Ok(Error::Unauthorized))); + } + + #[test] + fn test_non_admin_cannot_update_validator_reputation() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let attacker = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + let result = client.try_update_validator_reputation(&attacker, &validator, &90_u32); + assert_eq!(result, Err(Ok(Error::Unauthorized))); + } + + // --------------------------------------------------------------------------- + // Validator Registration Authentication + // --------------------------------------------------------------------------- + + #[test] + fn test_validator_registration_requires_admin() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + + // Successful registration by admin + let result = client.try_register_validator(&admin, &validator, &75_u32); + assert!(result.is_ok()); + + // Verify validator exists + let validator_info = client.get_validator(&validator); + assert_eq!(validator_info.address, validator); + } + + #[test] + fn test_validator_registration_validates_reputation_bounds() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator1 = Address::generate(&env); + let validator2 = Address::generate(&env); + + // Reputation > 100 should fail + let result = client.try_register_validator(&admin, &validator1, &101_u32); + assert_eq!(result, Err(Ok(Error::InvalidInput))); + + // Reputation = 100 should succeed + let result = client.try_register_validator(&admin, &validator2, &100_u32); + assert!(result.is_ok()); + } + + #[test] + fn test_duplicate_validator_registration_fails() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + + // Try to register same validator again + let result = client.try_register_validator(&admin, &validator, &80_u32); + assert_eq!(result, Err(Ok(Error::ValidatorAlreadyExists))); + } + + // --------------------------------------------------------------------------- + // Validator Activation/Deactivation Authentication + // --------------------------------------------------------------------------- + + #[test] + fn test_deactivated_validator_cannot_submit_reports() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + client.deactivate_validator(&admin, &validator); + + let reason = String::from_str(&env, "Report from inactive validator"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::ValidatorNotActive))); + } + + #[test] + fn test_validator_deactivation_persists_across_operations() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + client.deactivate_validator(&admin, &validator); + + // Verify validator is still deactivated + let validator_info = client.get_validator(&validator); + assert!(!validator_info.is_active); + + // Try to submit report + let reason = String::from_str(&env, "Test report"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::ValidatorNotActive))); + } + + #[test] + fn test_only_admin_can_reactivate_validator() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let attacker = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + client.deactivate_validator(&admin, &validator); + + // Non-admin cannot reactivate (would require new function, but test the pattern) + // For now, verify that only admin can update validator state + let result = client.try_update_validator_reputation(&attacker, &validator, &90_u32); + assert_eq!(result, Err(Ok(Error::Unauthorized))); + } + + // --------------------------------------------------------------------------- + // Reputation-Based Authentication + // --------------------------------------------------------------------------- + + #[test] + fn test_low_reputation_validator_cannot_submit_reports() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + // Register with reputation below minimum (50) + client.register_validator(&admin, &validator, &30_u32); + + let reason = String::from_str(&env, "Low reputation attempt"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::InsufficientReputation))); + } + + #[test] + fn test_reputation_update_affects_authentication() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + // Register with low reputation + client.register_validator(&admin, &validator, &30_u32); + + // Should fail to report + let reason = String::from_str(&env, "Test report"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::InsufficientReputation))); + + // Admin updates reputation to meet threshold + client.update_validator_reputation(&admin, &validator, &60_u32); + + // Should now succeed + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert!(result.is_ok()); + } + + #[test] + fn test_reputation_boundary_at_minimum_threshold() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + // Register with exactly minimum reputation (50) + client.register_validator(&admin, &validator, &50_u32); + + let reason = String::from_str(&env, "Boundary test"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert!(result.is_ok(), "Reputation at minimum threshold should be accepted"); + } + + // --------------------------------------------------------------------------- + // Confidence-Based Authentication + // --------------------------------------------------------------------------- + + #[test] + fn test_low_confidence_report_rejected() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + + // Try to report with confidence below minimum (60) + let reason = String::from_str(&env, "Low confidence report"); + let result = client.try_report_fraud(&validator, &target, &reason, &40_u32, &None::); + assert_eq!(result, Err(Ok(Error::InsufficientConfidence))); + } + + #[test] + fn test_confidence_boundary_at_minimum_threshold() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + + // Report with exactly minimum confidence (60) + let reason = String::from_str(&env, "Boundary test"); + let result = client.try_report_fraud(&validator, &target, &reason, &60_u32, &None::); + assert!(result.is_ok(), "Confidence at minimum threshold should be accepted"); + } + + // --------------------------------------------------------------------------- + // Unregistered Address Authentication + // --------------------------------------------------------------------------- + + #[test] + fn test_unregistered_address_cannot_submit_reports() { + let env = Env::default(); + let (client, _admin) = setup_contract(&env); + + let unregistered = Address::generate(&env); + let target = Address::generate(&env); + + let reason = String::from_str(&env, "Unregistered attempt"); + let result = client.try_report_fraud(&unregistered, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::ValidatorNotFound))); + } + + #[test] + fn test_unregistered_address_cannot_be_queried() { + let env = Env::default(); + let (client, _admin) = setup_contract(&env); + + let unregistered = Address::generate(&env); + let result = client.try_get_validator(&unregistered); + assert_eq!(result, Err(Ok(Error::ValidatorNotFound))); + } + + // --------------------------------------------------------------------------- + // Session-Like Behavior (Validator State Persistence) + // --------------------------------------------------------------------------- + + #[test] + fn test_validator_state_persists_across_operations() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target1 = Address::generate(&env); + let target2 = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + + // Submit first report + let reason1 = String::from_str(&env, "First report"); + client.report_fraud(&validator, &target1, &reason1, &80_u32, &None::); + + // Verify report count increased + let validator_info = client.get_validator(&validator); + assert_eq!(validator_info.report_count, 1); + + // Submit second report to different target + let reason2 = String::from_str(&env, "Second report"); + client.report_fraud(&validator, &target2, &reason2, &75_u32, &None::); + + // Verify report count increased again + let validator_info = client.get_validator(&validator); + assert_eq!(validator_info.report_count, 2); + } + + #[test] + fn test_validator_registration_timestamp_persists() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + + let validator_info = client.get_validator(&validator); + let timestamp = validator_info.registration_timestamp; + + // Timestamp should be non-zero (set during registration) + assert!(timestamp > 0, "Registration timestamp should be set"); + } + + // --------------------------------------------------------------------------- + // Configuration-Based Authentication + // --------------------------------------------------------------------------- + + #[test] + fn test_config_change_affects_authentication_requirements() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + // Register with reputation 60 (above default minimum of 50) + client.register_validator(&admin, &validator, &60_u32); + + // Should be able to report + let reason = String::from_str(&env, "Test report"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert!(result.is_ok()); + + // Admin raises minimum reputation to 70 + client.update_config(&admin, &Some(70_u32), &None::, &None::); + + // Should now fail due to new minimum + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::InsufficientReputation))); + } + + #[test] + fn test_config_change_affects_confidence_requirements() { + let env = Env::default(); + let (client, admin) = setup_contract(&env); + + let validator = Address::generate(&env); + let target = Address::generate(&env); + + client.register_validator(&admin, &validator, &75_u32); + + // Admin raises minimum confidence to 90 + client.update_config(&admin, &None::, &Some(90_u32), &None::); + + // Report with confidence 80 should fail + let reason = String::from_str(&env, "Test report"); + let result = client.try_report_fraud(&validator, &target, &reason, &80_u32, &None::); + assert_eq!(result, Err(Ok(Error::InsufficientConfidence))); + } +} diff --git a/src/lib.rs b/src/lib.rs index fd294d2..d291f66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -391,3 +391,6 @@ mod test; #[cfg(test)] mod security_tests; + +#[cfg(test)] +mod auth_tests; diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..d7c921e --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,5 @@ +"""Integration tests for AstroML. + +This package contains end-to-end integration tests that verify +the complete workflows across multiple components. +""" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..a50b00f --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,563 @@ +"""Shared fixtures for integration tests. + +This module provides fixtures for setting up test databases, +sample data, and common test scenarios for integration testing. +""" +from __future__ import annotations + +import os +import tempfile +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pytest +import yaml +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +from astroml.db.schema import ( + Account, + Asset, + Effect, + GraphAccount, + GraphEdge, + Ledger, + NormalizedTransaction, + Operation, + Transaction, + Base, +) + + +# --------------------------------------------------------------------------- +# Database fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="function") +def test_db_url(tmp_path: Path) -> str: + """Provide an in-memory SQLite database URL for testing.""" + return f"sqlite:///{tmp_path / 'test.db'}" + + +@pytest.fixture(scope="function") +def test_engine(test_db_url: str): + """Create a test database engine.""" + engine = create_engine(test_db_url, echo=False) + Base.metadata.create_all(engine) + yield engine + Base.metadata.drop_all(engine) + engine.dispose() + + +@pytest.fixture(scope="function") +def test_session(test_engine) -> Session: + """Create a test database session.""" + factory = sessionmaker(bind=test_engine) + session = factory() + yield session + session.close() + + +@pytest.fixture(scope="function") +def mock_config(tmp_path: Path): + """Create a mock configuration file.""" + config_dir = tmp_path / "config" + config_dir.mkdir() + + config = { + "database": { + "host": "localhost", + "port": 5432, + "name": "astroml_test", + "user": "test_user", + "password": "test_pass", + }, + "horizon": { + "url": "https://horizon-testnet.stellar.org", + }, + } + + config_file = config_dir / "database.yaml" + with open(config_file, "w") as f: + yaml.dump(config, f) + + # Change to temp directory for the test + original_cwd = os.getcwd() + os.chdir(tmp_path) + yield tmp_path + os.chdir(original_cwd) + + +# --------------------------------------------------------------------------- +# Sample data fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_ledger_data() -> List[Dict[str, Any]]: + """Sample ledger data for testing.""" + base_time = datetime(2024, 1, 1, 0, 0, 0) + return [ + { + "sequence": 1000, + "hash": "a" * 64, + "prev_hash": "b" * 64, + "closed_at": base_time, + "successful_transaction_count": 5, + "failed_transaction_count": 0, + "operation_count": 10, + "total_coins": 1000000000.0, + "fee_pool": 1000000.0, + "base_fee_in_stroops": 100, + "protocol_version": 20, + }, + { + "sequence": 1001, + "hash": "c" * 64, + "prev_hash": "a" * 64, + "closed_at": base_time + timedelta(seconds=5), + "successful_transaction_count": 3, + "failed_transaction_count": 1, + "operation_count": 8, + "total_coins": 1000000005.0, + "fee_pool": 1000005.0, + "base_fee_in_stroops": 100, + "protocol_version": 20, + }, + ] + + +@pytest.fixture +def sample_transaction_data() -> List[Dict[str, Any]]: + """Sample transaction data for testing.""" + base_time = datetime(2024, 1, 1, 0, 0, 0) + return [ + { + "hash": "tx1" + "a" * 60, + "ledger_sequence": 1000, + "source_account": "G" + "A" * 55, + "created_at": base_time, + "fee": 100, + "operation_count": 2, + "successful": True, + "memo_type": "none", + "memo": None, + }, + { + "hash": "tx2" + "b" * 60, + "ledger_sequence": 1000, + "source_account": "G" + "B" * 55, + "created_at": base_time + timedelta(seconds=1), + "fee": 200, + "operation_count": 1, + "successful": True, + "memo_type": "text", + "memo": "test", + }, + { + "hash": "tx3" + "c" * 60, + "ledger_sequence": 1001, + "source_account": "G" + "C" * 55, + "created_at": base_time + timedelta(seconds=6), + "fee": 150, + "operation_count": 3, + "successful": False, + "memo_type": "none", + "memo": None, + }, + ] + + +@pytest.fixture +def sample_operation_data() -> List[Dict[str, Any]]: + """Sample operation data for testing.""" + base_time = datetime(2024, 1, 1, 0, 0, 0) + return [ + { + "transaction_hash": "tx1" + "a" * 60, + "application_order": 0, + "type": "payment", + "source_account": "G" + "A" * 55, + "destination_account": "G" + "B" * 55, + "amount": 100.0, + "asset_code": "XLM", + "asset_issuer": None, + "created_at": base_time, + "details": {"type": "payment"}, + }, + { + "transaction_hash": "tx1" + "a" * 60, + "application_order": 1, + "type": "payment", + "source_account": "G" + "A" * 55, + "destination_account": "G" + "C" * 55, + "amount": 50.0, + "asset_code": "USDC", + "asset_issuer": "G" + "D" * 55, + "created_at": base_time, + "details": {"type": "payment"}, + }, + { + "transaction_hash": "tx2" + "b" * 60, + "application_order": 0, + "type": "create_account", + "source_account": "G" + "B" * 55, + "destination_account": "G" + "E" * 55, + "amount": None, + "asset_code": None, + "asset_issuer": None, + "created_at": base_time + timedelta(seconds=1), + "details": {"type": "create_account", "starting_balance": "100.0"}, + }, + ] + + +@pytest.fixture +def sample_account_data() -> List[Dict[str, Any]]: + """Sample account data for testing.""" + base_time = datetime(2024, 1, 1, 0, 0, 0) + return [ + { + "account_id": "G" + "A" * 55, + "balance": 1000.0, + "sequence": 100, + "home_domain": "example.com", + "flags": 0, + "last_modified_ledger": 1000, + "created_at": base_time - timedelta(days=30), + "updated_at": base_time, + }, + { + "account_id": "G" + "B" * 55, + "balance": 500.0, + "sequence": 50, + "home_domain": None, + "flags": 1, + "last_modified_ledger": 1000, + "created_at": base_time - timedelta(days=15), + "updated_at": base_time, + }, + ] + + +@pytest.fixture +def sample_asset_data() -> List[Dict[str, Any]]: + """Sample asset data for testing.""" + return [ + { + "asset_type": "native", + "asset_code": "XLM", + "asset_issuer": None, + "first_seen_ledger": 1000, + }, + { + "asset_type": "credit_alphanum4", + "asset_code": "USDC", + "asset_issuer": "G" + "D" * 55, + "first_seen_ledger": 1000, + }, + ] + + +@pytest.fixture +def sample_effect_data() -> List[Dict[str, Any]]: + """Sample effect data for testing.""" + base_time = datetime(2024, 1, 1, 0, 0, 0) + return [ + { + "account": "G" + "A" * 55, + "type": "account_debited", + "amount": -100.0, + "asset_code": "XLM", + "asset_issuer": None, + "destination_account": None, + "created_at": base_time, + "details": {"effect_type": "account_debited"}, + }, + { + "account": "G" + "B" * 55, + "type": "account_credited", + "amount": 100.0, + "asset_code": "XLM", + "asset_issuer": None, + "destination_account": None, + "created_at": base_time, + "details": {"effect_type": "account_credited"}, + }, + ] + + +@pytest.fixture +def sample_graph_edges() -> List[Dict[str, Any]]: + """Sample graph edge data for testing.""" + base_time = datetime(2024, 1, 1, 0, 0, 0) + return [ + { + "edge_type": "transaction", + "source_account_id": 1, + "destination_account_id": 2, + "asset_id": 1, + "occurred_at": base_time, + "ledger_sequence": 1000, + "event_index": 0, + "transaction_hash": "tx1" + "a" * 60, + "external_event_id": "evt1", + "amount": 100.0, + "status": "completed", + }, + { + "edge_type": "payment", + "source_account_id": 2, + "destination_account_id": 3, + "asset_id": 2, + "occurred_at": base_time + timedelta(seconds=1), + "ledger_sequence": 1000, + "event_index": 1, + "transaction_hash": "tx2" + "b" * 60, + "external_event_id": "evt2", + "amount": 50.0, + "status": "completed", + }, + ] + + +# --------------------------------------------------------------------------- +# Populated database fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def populated_test_db( + test_session: Session, + sample_ledger_data: List[Dict[str, Any]], + sample_transaction_data: List[Dict[str, Any]], + sample_operation_data: List[Dict[str, Any]], + sample_account_data: List[Dict[str, Any]], + sample_asset_data: List[Dict[str, Any]], + sample_effect_data: List[Dict[str, Any]], +) -> Session: + """Populate test database with sample data.""" + # Add ledgers + for ledger_data in sample_ledger_data: + ledger = Ledger(**ledger_data) + test_session.add(ledger) + + # Add assets + for asset_data in sample_asset_data: + asset = Asset(**asset_data) + test_session.add(asset) + + test_session.flush() + + # Add accounts + for account_data in sample_account_data: + account = Account(**account_data) + test_session.add(account) + + # Add transactions + for tx_data in sample_transaction_data: + transaction = Transaction(**tx_data) + test_session.add(transaction) + + test_session.flush() + + # Add operations + for op_data in sample_operation_data: + operation = Operation(**op_data) + test_session.add(operation) + + # Add effects + for effect_data in sample_effect_data: + effect = Effect(**effect_data) + test_session.add(effect) + + test_session.commit() + yield test_session + test_session.rollback() + + +@pytest.fixture +def populated_graph_db( + test_session: Session, + sample_asset_data: List[Dict[str, Any]], + sample_graph_edges: List[Dict[str, Any]], +) -> Session: + """Populate test database with graph data.""" + # Add assets + for asset_data in sample_asset_data: + asset = Asset(**asset_data) + test_session.add(asset) + + test_session.flush() + + # Add graph accounts + accounts = [ + GraphAccount( + id=1, + account_address="G" + "A" * 55, + account_type="user", + first_seen_at=datetime(2024, 1, 1), + last_seen_at=datetime(2024, 1, 2), + ), + GraphAccount( + id=2, + account_address="G" + "B" * 55, + account_type="user", + first_seen_at=datetime(2024, 1, 1), + last_seen_at=datetime(2024, 1, 2), + ), + GraphAccount( + id=3, + account_address="G" + "C" * 55, + account_type="user", + first_seen_at=datetime(2024, 1, 1), + last_seen_at=datetime(2024, 1, 2), + ), + ] + for account in accounts: + test_session.add(account) + + test_session.flush() + + # Add graph edges + for edge_data in sample_graph_edges: + edge = GraphEdge(**edge_data) + test_session.add(edge) + + test_session.commit() + yield test_session + test_session.rollback() + + +# --------------------------------------------------------------------------- +# Synthetic fraud data fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def synthetic_fraud_patterns() -> Dict[str, Any]: + """Synthetic fraud pattern configurations for testing.""" + return { + "sybil_clusters": [ + { + "cluster_id": "cluster_1", + "accounts": [f"G{'A' * i}{'B' * (55-i)}" for i in range(5)], + "coordinator": "G" + "X" * 55, + "behavior": "circular_transactions", + } + ], + "wash_trading_loops": [ + { + "loop_id": "loop_1", + "accounts": [f"G{'C' * i}{'D' * (55-i)}" for i in range(3)], + "asset": "USDC", + "frequency": "high", + } + ], + } + + +@pytest.fixture +def fraud_labels() -> np.ndarray: + """Sample fraud labels for testing.""" + np.random.seed(42) + # 10% fraud rate + labels = np.zeros(1000) + fraud_indices = np.random.choice(1000, size=100, replace=False) + labels[fraud_indices] = 1 + return labels + + +@pytest.fixture +def fraud_scores() -> np.ndarray: + """Sample fraud scores for testing.""" + np.random.seed(42) + scores = np.random.beta(2, 5, 1000) + return scores + + +# --------------------------------------------------------------------------- +# ML fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_node_features() -> Dict[str, np.ndarray]: + """Sample node features for ML testing.""" + np.random.seed(42) + features = { + f"node_{i}": np.random.randn(16).astype(np.float32) + for i in range(10) + } + return features + + +@pytest.fixture +def sample_edge_list() -> List[tuple]: + """Sample edge list for graph testing.""" + edges = [ + ("node_0", "node_1", 1.0, 1000.0), + ("node_1", "node_2", 0.5, 2000.0), + ("node_2", "node_3", 2.0, 3000.0), + ("node_3", "node_4", 1.5, 4000.0), + ("node_4", "node_0", 0.8, 5000.0), + ] + return edges + + +@pytest.fixture +def sample_training_data() -> tuple: + """Sample training data for model testing.""" + np.random.seed(42) + num_samples = 100 + num_features = 16 + + X = np.random.randn(num_samples, num_features).astype(np.float32) + y = np.random.randint(0, 2, num_samples) + + return X, y + + +# --------------------------------------------------------------------------- +# Temporary directory fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def temp_data_dir(tmp_path: Path) -> Path: + """Create a temporary data directory.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + return data_dir + + +@pytest.fixture +def temp_output_dir(tmp_path: Path) -> Path: + """Create a temporary output directory.""" + output_dir = tmp_path / "outputs" + output_dir.mkdir() + return output_dir + + +# --------------------------------------------------------------------------- +# Mock Horizon API fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_horizon_response(): + """Mock Horizon API response data.""" + return { + "hash": "x" * 64, + "ledger": 1000, + "source_account": "G" + "A" * 55, + "created_at": "2024-01-01T00:00:00Z", + "fee_charged": 100, + "operation_count": 2, + "successful": True, + "memo_type": "none", + "paging_token": "12345", + } diff --git a/tests/integration/test_authentication.py b/tests/integration/test_authentication.py new file mode 100644 index 0000000..43dce55 --- /dev/null +++ b/tests/integration/test_authentication.py @@ -0,0 +1,548 @@ +"""Integration tests for authentication and authorization in AstroML. + +These tests verify the complete authentication flow including: +- Admin initialization and authorization +- Validator registration and lifecycle +- Access control for privileged operations +- Session-like behavior through validator state +- Configuration-based authentication changes +""" +from __future__ import annotations + +import pytest +from typing import Any, Dict +from unittest.mock import MagicMock, patch + + +class TestAdminAuthenticationFlow: + """Integration tests for complete admin authentication flow.""" + + def test_admin_initialization_to_validator_registration_flow( + self, + ) -> None: + """Test complete flow from admin initialization to validator registration.""" + # This would test the Rust contract integration + # For now, we'll create a Python mock that mirrors the contract behavior + + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + self.config = { + "min_reputation": 50, + "min_confidence": 60, + "consensus_threshold": 3, + } + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + if validator_address in self.validators: + raise ValueError("ValidatorAlreadyExists") + if not (0 <= reputation <= 100): + raise ValueError("InvalidInput") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + "report_count": 0, + } + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + + # Initialize contract with admin + contract.initialize(admin) + assert contract.admin == admin + + # Register validator as admin + contract.register_validator(admin, validator, 75) + assert validator in contract.validators + assert contract.validators[validator]["reputation"] == 75 + + def test_non_admin_registration_failure_flow( + self, + ) -> None: + """Test that non-admin cannot register validators.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + } + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + attacker = "GATTACKER1234567890123456789012345678901234567890123456789" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + + # Try to register as attacker + with pytest.raises(PermissionError, match="Unauthorized"): + contract.register_validator(attacker, validator, 75) + + def test_admin_config_update_flow( + self, + ) -> None: + """Test admin can update configuration which affects authentication.""" + class MockContract: + def __init__(self): + self.admin = None + self.config = { + "min_reputation": 50, + "min_confidence": 60, + "consensus_threshold": 3, + } + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def update_config( + self, + admin_address: str, + min_reputation: int | None = None, + min_confidence: int | None = None, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + if min_reputation is not None: + if not (0 <= min_reputation <= 100): + raise ValueError("InvalidInput") + self.config["min_reputation"] = min_reputation + if min_confidence is not None: + if not (0 <= min_confidence <= 100): + raise ValueError("InvalidInput") + self.config["min_confidence"] = min_confidence + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + + contract.initialize(admin) + assert contract.config["min_reputation"] == 50 + + # Update config as admin + contract.update_config(admin, min_reputation=70, min_confidence=80) + assert contract.config["min_reputation"] == 70 + assert contract.config["min_confidence"] == 80 + + +class TestValidatorLifecycleIntegration: + """Integration tests for complete validator lifecycle authentication.""" + + def test_validator_registration_to_deactivation_flow( + self, + ) -> None: + """Test complete flow from registration to deactivation.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + "report_count": 0, + } + + def deactivate_validator( + self, + admin_address: str, + validator_address: str, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + if validator_address not in self.validators: + raise LookupError("ValidatorNotFound") + self.validators[validator_address]["is_active"] = False + + def submit_report( + self, + validator_address: str, + target_address: str, + confidence: int, + ) -> None: + validator = self.validators.get(validator_address) + if validator is None: + raise LookupError("ValidatorNotFound") + if not validator["is_active"]: + raise PermissionError("ValidatorNotActive") + if validator["reputation"] < 50: + raise PermissionError("InsufficientReputation") + if confidence < 60: + raise ValueError("InsufficientConfidence") + validator["report_count"] += 1 + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + target = "GTARGET1234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + contract.register_validator(admin, validator, 75) + + # Validator can submit reports + contract.submit_report(validator, target, 80) + assert contract.validators[validator]["report_count"] == 1 + + # Admin deactivates validator + contract.deactivate_validator(admin, validator) + assert not contract.validators[validator]["is_active"] + + # Validator can no longer submit reports + with pytest.raises(PermissionError, match="ValidatorNotActive"): + contract.submit_report(validator, target, 80) + + def test_reputation_update_affects_authentication_flow( + self, + ) -> None: + """Test that reputation updates affect authentication capabilities.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + self.config = {"min_reputation": 50} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + } + + def update_reputation( + self, + admin_address: str, + validator_address: str, + new_reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address]["reputation"] = new_reputation + + def submit_report( + self, + validator_address: str, + confidence: int, + ) -> None: + validator = self.validators[validator_address] + if validator["reputation"] < self.config["min_reputation"]: + raise PermissionError("InsufficientReputation") + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + + # Register with low reputation + contract.register_validator(admin, validator, 30) + + # Cannot submit reports + with pytest.raises(PermissionError, match="InsufficientReputation"): + contract.submit_report(validator, 80) + + # Admin updates reputation + contract.update_reputation(admin, validator, 75) + + # Can now submit reports + contract.submit_report(validator, 80) + + +class TestAuthorizationScenarios: + """Integration tests for complex authorization scenarios.""" + + def test_config_change_affects_all_validators_flow( + self, + ) -> None: + """Test that config changes affect authentication for all validators.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + self.config = {"min_reputation": 50} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + } + + def update_config( + self, + admin_address: str, + min_reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.config["min_reputation"] = min_reputation + + def submit_report( + self, + validator_address: str, + ) -> None: + validator = self.validators[validator_address] + if validator["reputation"] < self.config["min_reputation"]: + raise PermissionError("InsufficientReputation") + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + validator1 = "GVALIDATOR11234567890123456789012345678901234567890123456789" + validator2 = "GVALIDATOR21234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + + # Register validators with reputation 60 + contract.register_validator(admin, validator1, 60) + contract.register_validator(admin, validator2, 60) + + # Both can submit reports + contract.submit_report(validator1) + contract.submit_report(validator2) + + # Admin raises minimum to 70 + contract.update_config(admin, 70) + + # Neither can submit reports now + with pytest.raises(PermissionError, match="InsufficientReputation"): + contract.submit_report(validator1) + with pytest.raises(PermissionError, match="InsufficientReputation"): + contract.submit_report(validator2) + + def test_cascading_authorization_failures( + self, + ) -> None: + """Test that authorization failures cascade properly through operations.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + } + + def deactivate_validator( + self, + admin_address: str, + validator_address: str, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address]["is_active"] = False + + def update_reputation( + self, + admin_address: str, + validator_address: str, + new_reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address]["reputation"] = new_reputation + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + attacker = "GATTACKER1234567890123456789012345678901234567890123456789" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + contract.register_validator(admin, validator, 75) + + # Attacker tries multiple unauthorized operations + with pytest.raises(PermissionError, match="Unauthorized"): + contract.register_validator(attacker, validator, 75) + + with pytest.raises(PermissionError, match="Unauthorized"): + contract.deactivate_validator(attacker, validator) + + with pytest.raises(PermissionError, match="Unauthorized"): + contract.update_reputation(attacker, validator, 50) + + +class TestSessionLikeBehavior: + """Integration tests for session-like behavior through validator state.""" + + def test_validator_state_persists_across_multiple_operations( + self, + ) -> None: + """Test that validator state persists like a session across operations.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + self.reports = {} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + "report_count": 0, + "registration_timestamp": 1234567890, + } + + def submit_report( + self, + validator_address: str, + target_address: str, + ) -> None: + validator = self.validators[validator_address] + validator["report_count"] += 1 + if target_address not in self.reports: + self.reports[target_address] = [] + self.reports[target_address].append({ + "validator": validator_address, + "timestamp": 1234567890, + }) + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + target1 = "GTARGET11234567890123456789012345678901234567890123456789" + target2 = "GTARGET21234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + contract.register_validator(admin, validator, 75) + + # Submit multiple reports + contract.submit_report(validator, target1) + contract.submit_report(validator, target2) + contract.submit_report(validator, target1) + + # Verify state persistence + assert contract.validators[validator]["report_count"] == 3 + assert len(contract.reports[target1]) == 2 + assert len(contract.reports[target2]) == 1 + + def test_deactivation_resets_session_like_capabilities( + self, + ) -> None: + """Test that deactivation resets session-like validator capabilities.""" + class MockContract: + def __init__(self): + self.admin = None + self.validators = {} + + def initialize(self, admin_address: str) -> None: + self.admin = admin_address + + def register_validator( + self, + admin_address: str, + validator_address: str, + reputation: int, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address] = { + "reputation": reputation, + "is_active": True, + } + + def deactivate_validator( + self, + admin_address: str, + validator_address: str, + ) -> None: + if self.admin != admin_address: + raise PermissionError("Unauthorized") + self.validators[validator_address]["is_active"] = False + + def submit_report( + self, + validator_address: str, + ) -> None: + validator = self.validators[validator_address] + if not validator["is_active"]: + raise PermissionError("ValidatorNotActive") + + contract = MockContract() + admin = "GADMIN1234567890123456789012345678901234567890123456789012345" + validator = "GVALIDATOR1234567890123456789012345678901234567890123456789" + + contract.initialize(admin) + contract.register_validator(admin, validator, 75) + + # Can submit reports + contract.submit_report(validator) + + # Deactivate + contract.deactivate_validator(admin, validator) + + # Can no longer submit reports + with pytest.raises(PermissionError, match="ValidatorNotActive"): + contract.submit_report(validator) diff --git a/tests/integration/test_feature_engineering.py b/tests/integration/test_feature_engineering.py new file mode 100644 index 0000000..12680af --- /dev/null +++ b/tests/integration/test_feature_engineering.py @@ -0,0 +1,514 @@ +"""Integration tests for the feature engineering pipeline. + +These tests verify the complete workflow from database operations +to computed features, including feature store integration and caching. +""" +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pandas as pd +import pytest +from sqlalchemy.orm import Session + +from astroml.db.schema import Operation, Transaction, Ledger +from astroml.features.node_features import compute_node_features +from astroml.features.feature_store import ( + FeatureStore, + FeatureDefinition, + FeatureType, + FeatureStatus, +) +from astroml.features.feature_engine import FeatureEngineering as FeatureEngine, ComputationTask, ComputationStatus +from astroml.features.feature_cache import FeatureCache + + +class TestNodeFeaturesIntegration: + """Integration tests for node feature computation from database.""" + + def test_compute_features_from_database_operations( + self, + populated_test_db: Session, + ) -> None: + """Test computing node features directly from database operations.""" + # Query operations from database + operations = populated_test_db.query(Operation).all() + + # Convert to edge format + edges = [] + for op in operations: + if op.destination_account: + edges.append({ + 'src': op.source_account, + 'dst': op.destination_account, + 'amount': float(op.amount) if op.amount else 0.0, + 'timestamp': op.created_at.timestamp(), + 'asset': op.asset_code or 'XLM', + }) + + # Compute features + features_df = compute_node_features(edges) + + # Verify features were computed + assert not features_df.empty + assert 'in_degree' in features_df.columns + assert 'out_degree' in features_df.columns + assert 'total_received' in features_df.columns + assert 'total_sent' in features_df.columns + assert 'account_age' in features_df.columns + + # Verify data types + assert features_df['in_degree'].dtype == np.int64 + assert features_df['out_degree'].dtype == np.int64 + assert features_df['total_received'].dtype == float + assert features_df['total_sent'].dtype == float + + def test_compute_features_with_first_seen_provided( + self, + populated_test_db: Session, + ) -> None: + """Test computing features with externally provided first_seen timestamps.""" + operations = populated_test_db.query(Operation).all() + + edges = [] + for op in operations: + if op.destination_account: + edges.append({ + 'src': op.source_account, + 'dst': op.destination_account, + 'amount': float(op.amount) if op.amount else 0.0, + 'timestamp': op.created_at.timestamp(), + 'asset': op.asset_code or 'XLM', + }) + + # Provide external first_seen data + base_time = datetime(2024, 1, 1) + nodes_first_seen = { + 'G' + 'A' * 55: (base_time - timedelta(days=30)).timestamp(), + 'G' + 'B' * 55: (base_time - timedelta(days=15)).timestamp(), + } + + features_df = compute_node_features( + edges, + nodes_first_seen=nodes_first_seen, + ref_time=base_time.timestamp(), + ) + + # Verify account age uses provided first_seen where available + assert 'account_age' in features_df.columns + assert features_df['account_age'].min() >= 0 + + def test_compute_features_with_empty_edges( + self, + ) -> None: + """Test computing features with empty edge list.""" + features_df = compute_node_features([]) + + # Should return empty DataFrame with correct columns + assert features_df.empty + expected_columns = [ + 'in_degree', 'out_degree', 'total_received', 'total_sent', + 'account_age', 'first_seen', 'unique_asset_count', 'asset_entropy' + ] + assert list(features_df.columns) == expected_columns + + +class TestFeatureStoreIntegration: + """Integration tests for feature store with database.""" + + def test_register_and_retrieve_feature( + self, + test_session: Session, + temp_data_dir: Path, + ) -> None: + """Test registering a feature definition and retrieving it.""" + store_path = temp_data_dir / "feature_store.db" + store = FeatureStore(store_path=str(store_path)) + + # Define a simple feature + def simple_feature(data: pd.DataFrame) -> pd.DataFrame: + return data[['in_degree', 'out_degree']] + + feature_def = FeatureDefinition( + name="degree_features", + description="Simple degree features", + feature_type=FeatureType.NUMERIC, + computation_function=simple_feature, + tags=["graph", "basic"], + owner="ml-team", + status=FeatureStatus.PRODUCTION, + ) + + # Register feature + store.register_feature(feature_def) + + # Retrieve feature + retrieved = store.get_feature("degree_features", version=1) + + assert retrieved is not None + assert retrieved.name == "degree_features" + assert retrieved.status == FeatureStatus.PRODUCTION + assert "graph" in retrieved.tags + + def test_compute_and_cache_features( + self, + test_session: Session, + temp_data_dir: Path, + sample_node_features: Dict[str, np.ndarray], + ) -> None: + """Test computing features and caching them.""" + cache_path = temp_data_dir / "feature_cache.db" + cache = FeatureCache(cache_path=str(cache_path)) + + # Create sample feature data + feature_data = pd.DataFrame.from_dict(sample_node_features, orient='index') + feature_data.index.name = 'node_id' + + # Cache features + cache.put_features( + feature_name="test_features", + features=feature_data, + metadata={"version": 1, "computed_at": datetime.utcnow().isoformat()}, + ) + + # Retrieve cached features + cached = cache.get_features("test_features") + + assert cached is not None + assert cached.shape == feature_data.shape + assert np.allclose(cached.values, feature_data.values) + + def test_feature_versioning( + self, + temp_data_dir: Path, + ) -> None: + """Test feature versioning in the store.""" + store_path = temp_data_dir / "feature_store.db" + store = FeatureStore(store_path=str(store_path)) + + # Register version 1 + feature_v1 = FeatureDefinition( + name="evolving_feature", + description="First version", + feature_type=FeatureType.NUMERIC, + version=1, + ) + store.register_feature(feature_v1) + + # Register version 2 + feature_v2 = FeatureDefinition( + name="evolving_feature", + description="Second version with improvements", + feature_type=FeatureType.NUMERIC, + version=2, + ) + store.register_feature(feature_v2) + + # Retrieve both versions + v1 = store.get_feature("evolving_feature", version=1) + v2 = store.get_feature("evolving_feature", version=2) + + assert v1 is not None + assert v2 is not None + assert v1.version == 1 + assert v2.version == 2 + assert v1.description != v2.description + + def test_feature_lineage_tracking( + self, + temp_data_dir: Path, + ) -> None: + """Test tracking feature lineage and dependencies.""" + store_path = temp_data_dir / "feature_store.db" + store = FeatureStore(store_path=str(store_path)) + + # Register base feature + base_feature = FeatureDefinition( + name="base_transaction_count", + description="Count of transactions", + feature_type=FeatureType.NUMERIC, + ) + store.register_feature(base_feature) + + # Register derived feature + derived_feature = FeatureDefinition( + name="normalized_transaction_count", + description="Normalized transaction count", + feature_type=FeatureType.NUMERIC, + parameters={"base_feature": "base_transaction_count"}, + metadata={"depends_on": ["base_transaction_count"]}, + ) + store.register_feature(derived_feature) + + # Retrieve lineage + lineage = store.get_feature_lineage("normalized_transaction_count") + + assert lineage is not None + assert "base_transaction_count" in lineage + + +class TestFeatureEngineIntegration: + """Integration tests for feature computation engine.""" + + def test_execute_computation_task( + self, + sample_node_features: Dict[str, np.ndarray], + ) -> None: + """Test executing a single feature computation task.""" + engine = FeatureEngine() + + # Create sample input data + input_data = pd.DataFrame.from_dict(sample_node_features, orient='index') + input_data.index.name = 'node_id' + + # Define a simple computation function + def compute_sum(data: pd.DataFrame) -> pd.DataFrame: + return data.sum(axis=1).to_frame('feature_sum') + + # Create task + task = ComputationTask( + task_id="test_task_1", + feature_name="sum_feature", + data=input_data, + parameters={}, + ) + + # Execute task + result = engine.execute_task(task, compute_sum) + + assert result is not None + assert result.status == ComputationStatus.COMPLETED + assert result.result is not None + assert 'feature_sum' in result.result.columns + + def test_parallel_feature_computation( + self, + sample_node_features: Dict[str, np.ndarray], + ) -> None: + """Test parallel computation of multiple features.""" + engine = FeatureEngine(max_workers=2) + + input_data = pd.DataFrame.from_dict(sample_node_features, orient='index') + input_data.index.name = 'node_id' + + # Define multiple computation functions + def compute_mean(data: pd.DataFrame) -> pd.DataFrame: + return data.mean(axis=1).to_frame('feature_mean') + + def compute_std(data: pd.DataFrame) -> pd.DataFrame: + return data.std(axis=1).to_frame('feature_std') + + # Create tasks + tasks = [ + ComputationTask( + task_id=f"task_{i}", + feature_name=f"feature_{i}", + data=input_data, + ) + for i in range(2) + ] + + # Execute in parallel + results = engine.execute_parallel( + tasks, + [compute_mean, compute_std], + ) + + assert len(results) == 2 + assert all(r.status == ComputationStatus.COMPLETED for r in results) + assert all(r.result is not None for r in results) + + def test_feature_dependency_resolution( + self, + sample_node_features: Dict[str, np.ndarray], + ) -> None: + """Test resolving feature dependencies during computation.""" + engine = FeatureEngine() + + input_data = pd.DataFrame.from_dict(sample_node_features, orient='index') + + # Define dependent features + def base_feature(data: pd.DataFrame) -> pd.DataFrame: + return data.iloc[:, :2].copy() + + def derived_feature(data: pd.DataFrame) -> pd.DataFrame: + # Depends on base_feature output + return data.sum(axis=1).to_frame('derived') + + # Create tasks with dependencies + base_task = ComputationTask( + task_id="base_task", + feature_name="base_feature", + data=input_data, + ) + + derived_task = ComputationTask( + task_id="derived_task", + feature_name="derived_feature", + data=input_data, # Will be replaced with base_task result + ) + + # Execute base task + base_result = engine.execute_task(base_task, base_feature) + + # Execute derived task with base result as input + derived_result = engine.execute_task( + derived_task, + derived_feature, + input_data=base_result.result, + ) + + assert base_result.status == ComputationStatus.COMPLETED + assert derived_result.status == ComputationStatus.COMPLETED + + +class TestEndToEndFeaturePipeline: + """Integration tests for complete feature engineering pipeline.""" + + def test_database_to_features_pipeline( + self, + populated_test_db: Session, + temp_data_dir: Path, + ) -> None: + """Test complete pipeline from database to computed features.""" + # Step 1: Extract operations from database + operations = populated_test_db.query(Operation).all() + + # Step 2: Convert to edge format + edges = [] + for op in operations: + if op.destination_account: + edges.append({ + 'src': op.source_account, + 'dst': op.destination_account, + 'amount': float(op.amount) if op.amount else 0.0, + 'timestamp': op.created_at.timestamp(), + 'asset': op.asset_code or 'XLM', + }) + + # Step 3: Compute node features + features_df = compute_node_features(edges) + + # Step 4: Cache features + cache_path = temp_data_dir / "feature_cache.db" + cache = FeatureCache(cache_path=str(cache_path)) + cache.put_features( + feature_name="node_features", + features=features_df, + metadata={"source": "database", "computed_at": datetime.utcnow().isoformat()}, + ) + + # Step 5: Retrieve cached features + cached_features = cache.get_features("node_features") + + # Verify pipeline + assert not features_df.empty + assert cached_features is not None + assert cached_features.equals(features_df) + + def test_feature_store_workflow( + self, + temp_data_dir: Path, + sample_node_features: Dict[str, np.ndarray], + ) -> None: + """Test complete feature store workflow.""" + store_path = temp_data_dir / "feature_store.db" + store = FeatureStore(store_path=str(store_path)) + + # Step 1: Register feature definition + def aggregate_features(data: pd.DataFrame) -> pd.DataFrame: + return data.agg(['mean', 'std']).T + + feature_def = FeatureDefinition( + name="aggregate_stats", + description="Aggregate statistics for node features", + feature_type=FeatureType.NUMERIC, + computation_function=aggregate_features, + status=FeatureStatus.PRODUCTION, + ) + store.register_feature(feature_def) + + # Step 2: Prepare input data + input_data = pd.DataFrame.from_dict(sample_node_features, orient='index') + + # Step 3: Compute feature + computed = feature_def.computation_function(input_data) + + # Step 4: Store computed feature + cache_path = temp_data_dir / "feature_cache.db" + cache = FeatureCache(cache_path=str(cache_path)) + cache.put_features( + feature_name="aggregate_stats", + features=computed, + metadata={"feature_id": feature_def.feature_id}, + ) + + # Step 5: Retrieve and verify + retrieved = cache.get_features("aggregate_stats") + + assert retrieved is not None + assert not retrieved.empty + assert 'mean' in retrieved.columns or 'std' in retrieved.columns + + def test_incremental_feature_update( + self, + populated_test_db: Session, + temp_data_dir: Path, + ) -> None: + """Test incremental feature updates as new data arrives.""" + cache_path = temp_data_dir / "feature_cache.db" + cache = FeatureCache(cache_path=str(cache_path)) + + # Initial computation + operations = populated_test_db.query(Operation).limit(2).all() + edges = [ + { + 'src': op.source_account, + 'dst': op.destination_account, + 'amount': float(op.amount) if op.amount else 0.0, + 'timestamp': op.created_at.timestamp(), + 'asset': op.asset_code or 'XLM', + } + for op in operations + if op.destination_account + ] + + initial_features = compute_node_features(edges) + cache.put_features("node_features", initial_features) + + # Add new operation + new_op = Operation( + id=999, + transaction_hash="tx_new", + application_order=0, + type="payment", + source_account="G" + "X" * 55, + destination_account="G" + "Y" * 55, + amount=150.0, + asset_code="XLM", + created_at=datetime(2024, 1, 2), + ) + populated_test_db.add(new_op) + populated_test_db.commit() + + # Recompute with new data + all_operations = populated_test_db.query(Operation).all() + edges = [ + { + 'src': op.source_account, + 'dst': op.destination_account, + 'amount': float(op.amount) if op.amount else 0.0, + 'timestamp': op.created_at.timestamp(), + 'asset': op.asset_code or 'XLM', + } + for op in all_operations + if op.destination_account + ] + + updated_features = compute_node_features(edges) + + # Verify update + assert len(updated_features) >= len(initial_features) diff --git a/tests/integration/test_full_pipeline.py b/tests/integration/test_full_pipeline.py new file mode 100644 index 0000000..c1ddb4f --- /dev/null +++ b/tests/integration/test_full_pipeline.py @@ -0,0 +1,577 @@ +"""Comprehensive end-to-end pipeline integration tests. + +These tests verify the complete AstroML workflow from raw ledger data +to trained models, including all intermediate steps: ingestion, +feature engineering, graph construction, model training, and validation. +""" +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pandas as pd +import pytest +import torch +from sqlalchemy.orm import Session + +from astroml.db.schema import Ledger, Transaction, Operation, Account, Asset +from astroml.ingestion.service import IngestionService +from astroml.ingestion.parsers import parse_ledger, parse_transaction, parse_operation +from astroml.features.node_features import compute_node_features +from astroml.features.graph.snapshot import Edge, window_snapshot +from astroml.features.transaction_graph import TransactionGraph +from astroml.models.gcn import GCN +from astroml.validation.calibration import CalibrationAnalyzer +from astroml.validation.validator import TransactionValidator + + +class TestFullPipelineIntegration: + """Integration tests for the complete end-to-end pipeline.""" + + def test_ledger_to_model_pipeline( + self, + test_session: Session, + temp_output_dir: Path, + ) -> None: + """Test complete pipeline from ledger ingestion to model training.""" + # Step 1: Ingest ledger data + ledger_data = { + "sequence": 1000, + "hash": "a" * 64, + "prev_hash": "b" * 64, + "closed_at": datetime(2024, 1, 1), + "successful_transaction_count": 2, + "failed_transaction_count": 0, + "operation_count": 4, + } + ledger = parse_ledger(ledger_data) + test_session.add(ledger) + test_session.commit() + + # Step 2: Ingest transactions + tx_data_1 = { + "hash": "tx1" + "a" * 60, + "ledger": 1000, + "source_account": "G" + "A" * 55, + "created_at": datetime(2024, 1, 1), + "fee_charged": 100, + "operation_count": 2, + "successful": True, + "memo_type": "none", + } + tx_data_2 = { + "hash": "tx2" + "b" * 60, + "ledger": 1000, + "source_account": "G" + "B" * 55, + "created_at": datetime(2024, 1, 1), + "fee_charged": 200, + "operation_count": 2, + "successful": True, + "memo_type": "none", + } + + tx1 = parse_transaction(tx_data_1) + tx2 = parse_transaction(tx_data_2) + test_session.add(tx1) + test_session.add(tx2) + test_session.commit() + + # Step 3: Ingest operations + op_data_1 = { + "id": 1, + "transaction_hash": "tx1" + "a" * 60, + "source_account": "G" + "A" * 55, + "type": "payment", + "to": "G" + "B" * 55, + "amount": "100.0", + "asset_type": "native", + "created_at": datetime(2024, 1, 1), + } + op_data_2 = { + "id": 2, + "transaction_hash": "tx1" + "a" * 60, + "source_account": "G" + "A" * 55, + "type": "payment", + "to": "G" + "C" * 55, + "amount": "50.0", + "asset_type": "native", + "created_at": datetime(2024, 1, 1), + } + op_data_3 = { + "id": 3, + "transaction_hash": "tx2" + "b" * 60, + "source_account": "G" + "B" * 55, + "type": "payment", + "to": "G" + "C" * 55, + "amount": "75.0", + "asset_type": "native", + "created_at": datetime(2024, 1, 1), + } + + op1 = parse_operation(op_data_1, application_order=0) + op2 = parse_operation(op_data_2, application_order=1) + op3 = parse_operation(op_data_3, application_order=0) + test_session.add(op1) + test_session.add(op2) + test_session.add(op3) + test_session.commit() + + # Step 4: Extract operations and compute features + operations = test_session.query(Operation).all() + edges = [] + for op in operations: + if op.destination_account: + edges.append({ + 'src': op.source_account, + 'dst': op.destination_account, + 'amount': float(op.amount) if op.amount else 0.0, + 'timestamp': op.created_at.timestamp(), + 'asset': op.asset_code or 'XLM', + }) + + features_df = compute_node_features(edges) + + # Verify features computed + assert not features_df.empty + assert len(features_df) == 3 # A, B, C + + # Step 5: Build graph + graph = TransactionGraph() + for op in operations: + if op.destination_account: + graph.add_transaction( + from_account=op.source_account, + to_account=op.destination_account, + amount=float(op.amount) if op.amount else 0.0, + asset=op.asset_code or 'XLM', + ) + + # Verify graph + summary = graph.summary() + assert summary["node_count"] == 3 + assert summary["transaction_count"] == 3 + + # Step 6: Train simple model + # Convert features to tensor + feature_matrix = features_df.values.astype(np.float32) + num_nodes = feature_matrix.shape[0] + + # Create simple edge index + node_to_idx = {node: i for i, node in enumerate(features_df.index)} + edge_index = [] + for op in operations: + if op.destination_account: + src_idx = node_to_idx.get(op.source_account) + dst_idx = node_to_idx.get(op.destination_account) + if src_idx is not None and dst_idx is not None: + edge_index.append([src_idx, dst_idx]) + + if len(edge_index) == 0: + edge_index = [[0, 1], [1, 2]] + + edge_index = torch.tensor(edge_index, dtype=torch.long).t() + + # Create and train model + model = GCN( + input_dim=feature_matrix.shape[1], + hidden_dim=8, + output_dim=2, + dropout=0.0, + ) + + # Create dummy labels + labels = torch.randint(0, 2, (num_nodes,)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = torch.nn.NLLLoss() + + model.train() + for _ in range(3): + optimizer.zero_grad() + out = model(torch.tensor(feature_matrix), edge_index) + loss = criterion(out, labels) + loss.backward() + optimizer.step() + + # Verify training completed + assert loss.item() is not None + + # Step 7: Validate predictions + model.eval() + with torch.no_grad(): + predictions = model(torch.tensor(feature_matrix), edge_index) + predicted_probs = torch.softmax(predictions, dim=1)[:, 1].numpy() + + # Verify predictions + assert len(predicted_probs) == num_nodes + assert all(0 <= p <= 1 for p in predicted_probs) + + def test_ingestion_to_validation_pipeline( + self, + test_session: Session, + temp_output_dir: Path, + ) -> None: + """Test pipeline from ingestion through validation.""" + # Step 1: Ingest and validate transactions + transactions = [ + { + "id": "tx1", + "source_account": "G" + "A" * 55, + "amount": 100.0, + "created_at": "2024-01-01T00:00:00Z", + }, + { + "id": "tx2", + "source_account": "G" + "B" * 55, + "amount": 50.0, + "created_at": "2024-01-01T00:01:00Z", + }, + ] + + validator = TransactionValidator( + required_fields={"id", "source_account", "amount"}, + ) + + results = validator.validate_batch(transactions) + + # Verify validation + assert len(results) == 2 + assert all(r.is_valid for r in results) + + # Step 2: Store valid transactions in database + for tx_data in transactions: + # Create ledger + ledger = Ledger( + sequence=1000, + hash="a" * 64, + closed_at=datetime(2024, 1, 1), + successful_transaction_count=1, + failed_transaction_count=0, + operation_count=1, + ) + test_session.add(ledger) + + # Create transaction + tx = Transaction( + hash=tx_data["id"] + "a" * 60, + ledger_sequence=1000, + source_account=tx_data["source_account"], + created_at=datetime.fromisoformat(tx_data["created_at"].replace("Z", "+00:00")), + fee=100, + operation_count=1, + successful=True, + memo_type="none", + ) + test_session.add(tx) + + test_session.commit() + + # Step 3: Verify database state + tx_count = test_session.query(Transaction).count() + assert tx_count == 2 + + def test_synthetic_fraud_to_detection_pipeline( + self, + test_session: Session, + temp_data_dir: Path, + temp_output_dir: Path, + ) -> None: + """Test pipeline from synthetic fraud injection to detection.""" + # Step 1: Create clean ledger + clean_transactions = [ + { + "source_account": "G" + "A" * 55, + "destination_account": "G" + "B" * 55, + "amount": 100.0, + "created_at": "2024-01-01T00:00:00Z", + } + ] + + input_file = temp_data_dir / "clean.jsonl" + output_file = temp_data_dir / "with_fraud.jsonl" + + with open(input_file, "w") as f: + for tx in clean_transactions: + f.write(tx.__str__() + "\n") + + # Step 2: Inject synthetic fraud + from astroml.ingestion.synthetic_fraud_injector import ( + inject_synthetic_fraud, + SybilConfig, + ) + + augmented, summary = inject_synthetic_fraud( + clean_transactions, + seed=42, + sybil=SybilConfig(clusters=1, cluster_size=2, tx_per_member=1), + ) + + # Verify injection + assert len(augmented) > len(clean_transactions) + assert summary.sybil_transactions > 0 + + # Step 3: Store in database + for tx in augmented: + if tx.get("synthetic_fraud"): + # Store fraud pattern metadata + pass + + # Step 4: Verify fraud detection capability + fraud_txs = [tx for tx in augmented if tx.get("synthetic_fraud")] + assert len(fraud_txs) > 0 + + def test_graph_snapshot_to_model_pipeline( + self, + test_session: Session, + temp_output_dir: Path, + ) -> None: + """Test pipeline from graph snapshot to model training.""" + # Step 1: Create normalized transactions + base_time = datetime(2024, 1, 1) + + for i in range(10): + tx = test_session.query(Transaction).first() + if not tx: + # Create transaction if none exists + ledger = Ledger( + sequence=1000 + i, + hash="a" * 64, + closed_at=base_time + timedelta(hours=i), + successful_transaction_count=1, + failed_transaction_count=0, + operation_count=1, + ) + test_session.add(ledger) + + tx = Transaction( + hash=f"tx{i}" + "a" * 60, + ledger_sequence=1000 + i, + source_account=f"G{'A' * i}{'B' * (55-i)}", + created_at=base_time + timedelta(hours=i), + fee=100, + operation_count=1, + successful=True, + memo_type="none", + ) + test_session.add(tx) + + test_session.commit() + + # Step 2: Create graph snapshot + from astroml.features.graph.snapshot import Edge, snapshot_last_n_days + + base_ts = int(base_time.timestamp()) + edges = [ + Edge(src=f"node_{i}", dst=f"node_{(i+1)%5}", timestamp=base_ts + i * 3600) + for i in range(10) + ] + + now_ts = base_ts + 86400 # 1 day later + nodes, window_edges = snapshot_last_n_days(edges, now_ts, days=1) + + # Verify snapshot + assert len(window_edges) > 0 + assert len(nodes) > 0 + + # Step 3: Compute features from snapshot + edge_dicts = [ + { + 'src': e.src, + 'dst': e.dst, + 'amount': 100.0, + 'timestamp': e.timestamp, + 'asset': 'XLM', + } + for e in window_edges + ] + + features_df = compute_node_features(edge_dicts) + + # Verify features + assert not features_df.empty + + def test_feature_store_to_training_pipeline( + self, + temp_output_dir: Path, + sample_node_features: Dict[str, np.ndarray], + ) -> None: + """Test pipeline from feature store to model training.""" + # Step 1: Store features in feature store + from astroml.features.feature_store import FeatureStore, FeatureDefinition, FeatureType + from astroml.features.feature_cache import FeatureCache + + store_path = temp_output_dir / "feature_store.db" + cache_path = temp_output_dir / "feature_cache.db" + + store = FeatureStore(store_path=str(store_path)) + cache = FeatureCache(cache_path=str(cache_path)) + + # Register feature + feature_def = FeatureDefinition( + name="node_embeddings", + description="Node embedding features", + feature_type=FeatureType.VECTOR, + ) + store.register_feature(feature_def) + + # Cache features + features_df = pd.DataFrame.from_dict(sample_node_features, orient='index') + cache.put_features( + feature_name="node_embeddings", + features=features_df, + metadata={"version": 1}, + ) + + # Step 2: Retrieve features for training + cached_features = cache.get_features("node_embeddings") + + # Verify retrieval + assert cached_features is not None + assert cached_features.shape == features_df.shape + + # Step 3: Train model with cached features + feature_matrix = cached_features.values.astype(np.float32) + num_nodes = feature_matrix.shape[0] + + # Simple model + import torch.nn as nn + model = nn.Sequential( + nn.Linear(feature_matrix.shape[1], 16), + nn.ReLU(), + nn.Linear(16, 2), + ) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.CrossEntropyLoss() + + labels = torch.randint(0, 2, (num_nodes,)) + + model.train() + for _ in range(3): + optimizer.zero_grad() + predictions = model(torch.tensor(feature_matrix)) + loss = criterion(predictions, labels) + loss.backward() + optimizer.step() + + # Verify training + assert loss.item() is not None + + def test_end_to_end_data_quality_pipeline( + self, + test_session: Session, + temp_output_dir: Path, + ) -> None: + """Test complete data quality validation pipeline.""" + # Step 1: Ingest data with potential quality issues + transactions = [ + {"id": "tx1", "source_account": "GAAA", "amount": 100.0, "timestamp": "2024-01-01T00:00:00Z"}, + {"id": "tx2", "source_account": "GBBB", "amount": 50.0, "timestamp": "2024-01-01T00:01:00Z"}, + {"id": "tx3", "source_account": None, "amount": 75.0, "timestamp": "2024-01-01T00:02:00Z"}, # Invalid + {"id": "tx4", "source_account": "GDDD", "amount": "invalid", "timestamp": "2024-01-01T00:03:00Z"}, # Invalid + ] + + # Step 2: Validate data quality + validator = TransactionValidator( + required_fields={"id", "source_account", "amount"}, + field_types={"amount": (int, float)}, + ) + + results = validator.validate_batch(transactions) + + # Step 3: Filter valid transactions + valid_transactions = [ + tx for tx, result in zip(transactions, results) if result.is_valid + ] + + # Verify filtering + assert len(valid_transactions) == 2 + + # Step 4: Store only valid transactions + for tx in valid_transactions: + ledger = Ledger( + sequence=1000, + hash="a" * 64, + closed_at=datetime.fromisoformat(tx["timestamp"].replace("Z", "+00:00")), + successful_transaction_count=1, + failed_transaction_count=0, + operation_count=1, + ) + test_session.add(ledger) + + transaction = Transaction( + hash=tx["id"] + "a" * 60, + ledger_sequence=1000, + source_account=tx["source_account"], + created_at=datetime.fromisoformat(tx["timestamp"].replace("Z", "+00:00")), + fee=100, + operation_count=1, + successful=True, + memo_type="none", + ) + test_session.add(transaction) + + test_session.commit() + + # Step 5: Verify only valid data in database + tx_count = test_session.query(Transaction).count() + assert tx_count == 2 + + def test_model_deployment_pipeline( + self, + sample_training_data: tuple, + temp_output_dir: Path, + ) -> None: + """Test complete model deployment pipeline.""" + X, y = sample_training_data + + # Step 1: Train model + model = GCN( + input_dim=X.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.5, + ) + + edge_index = torch.randint(0, len(X), (2, len(X) * 2)) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = torch.nn.NLLLoss() + + model.train() + for _ in range(5): + optimizer.zero_grad() + out = model(torch.tensor(X, dtype=torch.float32), edge_index) + loss = criterion(out, torch.tensor(y, dtype=torch.long)) + loss.backward() + optimizer.step() + + # Step 2: Save model + model_path = temp_output_dir / "deployed_model.pt" + torch.save({ + 'model_state_dict': model.state_dict(), + 'input_dim': X.shape[1], + 'hidden_dim': 16, + 'output_dim': 2, + 'training_loss': loss.item(), + 'deployed_at': datetime.utcnow().isoformat(), + }, model_path) + + # Step 3: Load model for inference + checkpoint = torch.load(model_path) + loaded_model = GCN( + input_dim=checkpoint['input_dim'], + hidden_dim=checkpoint['hidden_dim'], + output_dim=checkpoint['output_dim'], + ) + loaded_model.load_state_dict(checkpoint['model_state_dict']) + + # Step 4: Perform inference + loaded_model.eval() + with torch.no_grad(): + predictions = loaded_model(torch.tensor(X, dtype=torch.float32), edge_index) + + # Verify deployment pipeline + assert model_path.exists() + assert predictions.shape[0] == len(X) diff --git a/tests/integration/test_graph_construction.py b/tests/integration/test_graph_construction.py new file mode 100644 index 0000000..5da2caf --- /dev/null +++ b/tests/integration/test_graph_construction.py @@ -0,0 +1,435 @@ +"""Integration tests for graph construction and snapshot pipeline. + +These tests verify the complete workflow from database operations +to graph construction, snapshot creation, and graph analysis. +""" +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pytest +from sqlalchemy.orm import Session + +from astroml.db.schema import Operation, NormalizedTransaction +from astroml.features.graph.snapshot import ( + Edge, + window_snapshot, + snapshot_last_n_days, + SnapshotWindow, + iter_db_snapshots, +) +from astroml.features.transaction_graph import TransactionGraph + + +class TestGraphConstructionIntegration: + """Integration tests for graph construction from database.""" + + def test_build_graph_from_database_operations( + self, + populated_test_db: Session, + ) -> None: + """Test building a transaction graph from database operations.""" + # Query operations from database + operations = populated_test_db.query(Operation).all() + + # Build graph + graph = TransactionGraph() + for op in operations: + if op.destination_account: + graph.add_transaction( + from_account=op.source_account, + to_account=op.destination_account, + amount=float(op.amount) if op.amount else 0.0, + asset=op.asset_code or "XLM", + metadata={"operation_type": op.type}, + ) + + # Verify graph structure + assert len(graph.nodes) > 0 + summary = graph.summary() + assert summary["node_count"] > 0 + assert summary["transaction_count"] > 0 + + def test_graph_with_multiple_assets( + self, + ) -> None: + """Test graph construction with multiple asset types.""" + graph = TransactionGraph() + + # Add transactions with different assets + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("B", "C", 50.0, "USDC") + graph.add_transaction("C", "A", 25.0, "BTC") + graph.add_transaction("A", "C", 75.0, "XLM") + + # Verify multiple assets + assets = graph.get_assets() + assert len(assets) == 3 + assert "XLM" in assets + assert "USDC" in assets + assert "BTC" in assets + + def test_graph_edge_aggregation( + self, + ) -> None: + """Test edge weight aggregation methods.""" + graph = TransactionGraph() + + # Add multiple transactions between same accounts + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("A", "B", 50.0, "XLM") + graph.add_transaction("A", "B", 25.0, "XLM") + + # Test different aggregations + sum_weight = graph.get_edge_weight("A", "B", aggregation="sum") + mean_weight = graph.get_edge_weight("A", "B", aggregation="mean") + count_weight = graph.get_edge_weight("A", "B", aggregation="count") + max_weight = graph.get_edge_weight("A", "B", aggregation="max") + min_weight = graph.get_edge_weight("A", "B", aggregation="min") + + assert sum_weight == 175.0 + assert mean_weight == 175.0 / 3 + assert count_weight == 3.0 + assert max_weight == 100.0 + assert min_weight == 25.0 + + def test_graph_to_networkx_export( + self, + ) -> None: + """Test exporting graph to NetworkX format.""" + graph = TransactionGraph() + + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("B", "C", 50.0, "USDC") + graph.add_transaction("C", "A", 25.0, "XLM") + + # Export to NetworkX + nx_graph = graph.to_networkx() + + # Verify structure + assert nx_graph.number_of_nodes() == 3 + assert nx_graph.number_of_edges() == 3 + + # Verify edge weights + assert nx_graph["A"]["B"]["weight"] == 100.0 + assert nx_graph["B"]["C"]["weight"] == 50.0 + + def test_graph_summary_statistics( + self, + ) -> None: + """Test graph summary statistics computation.""" + graph = TransactionGraph() + + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("B", "C", 50.0, "USDC") + graph.add_transaction("A", "C", 25.0, "XLM") + graph.add_transaction("C", "A", 75.0, "BTC") + + summary = graph.summary() + + assert summary["node_count"] == 3 + assert summary["edge_count"] == 4 + assert summary["transaction_count"] == 4 + assert summary["asset_count"] == 3 + assert "XLM" in summary["assets"] + assert summary["assets"]["XLM"] == 2 + + +class TestGraphSnapshotIntegration: + """Integration tests for graph snapshot creation.""" + + def test_window_snapshot_creation( + self, + ) -> None: + """Test creating a time-windowed graph snapshot.""" + base_time = int(datetime(2024, 1, 1).timestamp()) + + edges = [ + Edge(src="A", dst="B", timestamp=base_time), + Edge(src="B", dst="C", timestamp=base_time + 3600), # +1 hour + Edge(src="C", dst="D", timestamp=base_time + 7200), # +2 hours + Edge(src="D", dst="E", timestamp=base_time + 86400), # +1 day + ] + + # Create 12-hour window + start_ts = base_time + end_ts = base_time + 12 * 3600 + + nodes, window_edges = window_snapshot(edges, start_ts, end_ts) + + # Should include first 3 edges (within 12 hours) + assert len(window_edges) == 3 + assert len(nodes) == 4 # A, B, C, D + assert "E" not in nodes + + def test_snapshot_last_n_days( + self, + ) -> None: + """Test snapshot creation for last N days.""" + now_ts = int(datetime(2024, 1, 15).timestamp()) + + edges = [ + Edge(src="A", dst="B", timestamp=now_ts - 86400), # 1 day ago + Edge(src="B", dst="C", timestamp=now_ts - 172800), # 2 days ago + Edge(src="C", dst="D", timestamp=now_ts - 259200), # 3 days ago + Edge(src="D", dst="E", timestamp=now_ts - 432000), # 5 days ago + ] + + # Get last 3 days + nodes, window_edges = snapshot_last_n_days(edges, now_ts, days=3) + + # Should include edges from last 3 days + assert len(window_edges) == 3 + assert len(nodes) == 4 + + def test_snapshot_with_presorted_edges( + self, + ) -> None: + """Test snapshot creation with pre-sorted edges.""" + base_time = int(datetime(2024, 1, 1).timestamp()) + + edges = [ + Edge(src="A", dst="B", timestamp=base_time), + Edge(src="B", dst="C", timestamp=base_time + 3600), + Edge(src="C", dst="D", timestamp=base_time + 7200), + ] + + # With presorted=True (should be faster) + nodes1, edges1 = window_snapshot(edges, base_time, base_time + 7200, presorted=True) + + # With presorted=False (should sort first) + nodes2, edges2 = window_snapshot(edges, base_time, base_time + 7200, presorted=False) + + # Results should be identical + assert len(nodes1) == len(nodes2) + assert len(edges1) == len(edges2) + + def test_empty_snapshot_window( + self, + ) -> None: + """Test snapshot creation when no edges fall in window.""" + base_time = int(datetime(2024, 1, 1).timestamp()) + + edges = [ + Edge(src="A", dst="B", timestamp=base_time), + Edge(src="B", dst="C", timestamp=base_time + 3600), + ] + + # Window with no edges + nodes, window_edges = window_snapshot( + edges, base_time + 7200, base_time + 10800 + ) + + # Should be empty + assert len(nodes) == 0 + assert len(window_edges) == 0 + + +class TestDatabaseSnapshotIntegration: + """Integration tests for database-backed snapshot creation.""" + + def test_db_snapshot_from_normalized_transactions( + self, + test_session: Session, + ) -> None: + """Test creating snapshots from normalized transactions in database.""" + # Add normalized transactions + base_time = datetime(2024, 1, 1) + + transactions = [ + NormalizedTransaction( + transaction_hash="tx1", + sender="G" + "A" * 55, + receiver="G" + "B" * 55, + asset="XLM", + amount=100.0, + timestamp=base_time, + ), + NormalizedTransaction( + transaction_hash="tx2", + sender="G" + "B" * 55, + receiver="G" + "C" * 55, + asset="USDC", + amount=50.0, + timestamp=base_time + timedelta(hours=1), + ), + NormalizedTransaction( + transaction_hash="tx3", + sender="G" + "C" * 55, + receiver="G" + "A" * 55, + asset="XLM", + amount=25.0, + timestamp=base_time + timedelta(hours=2), + ), + ] + + for tx in transactions: + test_session.add(tx) + test_session.commit() + + # Create snapshot + t0 = base_time + t_now = base_time + timedelta(hours=3) + + snapshots = list(iter_db_snapshots( + window="1h", + t0=t0, + t_now=t_now, + session=test_session, + )) + + # Should have 3 hourly snapshots + assert len(snapshots) == 3 + + # Verify snapshot structure + for snapshot in snapshots: + assert isinstance(snapshot, SnapshotWindow) + assert isinstance(snapshot.index, int) + assert isinstance(snapshot.start, datetime) + assert isinstance(snapshot.end, datetime) + assert isinstance(snapshot.edges, list) + assert isinstance(snapshot.nodes, set) + + def test_db_snapshot_with_rolling_window( + self, + test_session: Session, + ) -> None: + """Test creating rolling window snapshots from database.""" + base_time = datetime(2024, 1, 1) + + # Add transactions + for i in range(10): + tx = NormalizedTransaction( + transaction_hash=f"tx{i}", + sender=f"G{'A' * i}{'B' * (55-i)}", + receiver=f"G{'C' * i}{'D' * (55-i)}", + asset="XLM", + amount=10.0 * i, + timestamp=base_time + timedelta(hours=i), + ) + test_session.add(tx) + test_session.commit() + + # Create rolling snapshots (2-hour window, 1-hour step) + t0 = base_time + t_now = base_time + timedelta(hours=10) + + snapshots = list(iter_db_snapshots( + window="2h", + step="1h", + t0=t0, + t_now=t_now, + session=test_session, + )) + + # Should have 10 snapshots (rolling with overlap) + assert len(snapshots) == 10 + + +class TestGraphConstructionPipelineIntegration: + """Integration tests for complete graph construction pipeline.""" + + def test_database_to_graph_to_snapshot_pipeline( + self, + populated_test_db: Session, + ) -> None: + """Test complete pipeline from database to graph snapshot.""" + # Step 1: Extract operations from database + operations = populated_test_db.query(Operation).all() + + # Step 2: Build transaction graph + graph = TransactionGraph() + for op in operations: + if op.destination_account: + graph.add_transaction( + from_account=op.source_account, + to_account=op.destination_account, + amount=float(op.amount) if op.amount else 0.0, + asset=op.asset_code or "XLM", + ) + + # Step 3: Convert to edge format for snapshot + base_time = int(datetime(2024, 1, 1).timestamp()) + edges = [] + for src, dsts in graph.edges.items(): + for dst in dsts: + for txn in graph.edges[src][dst]: + edges.append(Edge(src=src, dst=dst, timestamp=base_time)) + + # Step 4: Create snapshot + nodes, window_edges = window_snapshot(edges, base_time, base_time + 86400) + + # Verify pipeline + assert len(graph.nodes) > 0 + assert len(edges) > 0 + assert len(nodes) > 0 + + def test_incremental_graph_construction( + self, + test_session: Session, + ) -> None: + """Test incremental graph construction as new data arrives.""" + # Initial graph + graph = TransactionGraph() + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("B", "C", 50.0, "USDC") + + initial_summary = graph.summary() + assert initial_summary["transaction_count"] == 2 + + # Add new transactions + graph.add_transaction("C", "D", 25.0, "BTC") + graph.add_transaction("D", "A", 75.0, "XLM") + + updated_summary = graph.summary() + assert updated_summary["transaction_count"] == 4 + assert updated_summary["node_count"] == 4 + + def test_graph_filtering_by_asset( + self, + ) -> None: + """Test filtering graph by specific asset.""" + graph = TransactionGraph() + + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("B", "C", 50.0, "USDC") + graph.add_transaction("C", "A", 25.0, "XLM") + graph.add_transaction("A", "D", 75.0, "BTC") + + # Filter by XLM + xlm_txns = graph.get_transactions(asset="XLM") + assert len(xlm_txns) == 2 + + # Filter by USDC + usdc_txns = graph.get_transactions(asset="USDC") + assert len(usdc_txns) == 1 + + def test_graph_persistence_workflow( + self, + temp_output_dir: Path, + ) -> None: + """Test saving and loading graph data.""" + graph = TransactionGraph() + + graph.add_transaction("A", "B", 100.0, "XLM") + graph.add_transaction("B", "C", 50.0, "USDC") + + # Save graph summary + summary = graph.summary() + import json + summary_path = temp_output_dir / "graph_summary.json" + with open(summary_path, 'w') as f: + json.dump(summary, f) + + # Verify file exists + assert summary_path.exists() + + # Load and verify + with open(summary_path, 'r') as f: + loaded_summary = json.load(f) + + assert loaded_summary["node_count"] == 3 + assert loaded_summary["transaction_count"] == 2 diff --git a/tests/integration/test_ingestion_pipeline.py b/tests/integration/test_ingestion_pipeline.py new file mode 100644 index 0000000..311cd74 --- /dev/null +++ b/tests/integration/test_ingestion_pipeline.py @@ -0,0 +1,444 @@ +"""End-to-end integration tests for the ingestion pipeline. + +These tests verify the complete workflow from fetching ledger data +to storing it in the database, including parsing and state management. +""" +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import pytest +from sqlalchemy.orm import Session + +from astroml.db.schema import Ledger, Transaction, Operation, Account, Asset, Effect +from astroml.ingestion.service import IngestionService, IngestionResult +from astroml.ingestion.parsers import ( + parse_ledger, + parse_transaction, + parse_operation, + parse_effect, +) +from astroml.ingestion.synthetic_fraud_injector import ( + inject_synthetic_fraud, + SybilConfig, + WashLoopConfig, + InjectionSummary, + run_injection, +) + + +class TestIngestionServiceIntegration: + """Integration tests for IngestionService with database persistence.""" + + def test_ingest_ledgers_to_database( + self, + test_session: Session, + sample_ledger_data: List[Dict[str, Any]], + ) -> None: + """Test complete ingestion workflow from ledger data to database.""" + service = IngestionService() + + # Mock fetch function that returns ledger data + def fetch_ledger(ledger_id: int) -> Dict[str, Any]: + return sample_ledger_data[ledger_id - 1000] + + # Mock process function that stores in database + def process_ledger(ledger_id: int, payload: Dict[str, Any]) -> None: + ledger = parse_ledger(payload) + test_session.add(ledger) + test_session.commit() + + # Ingest ledgers + result = service.ingest( + start_ledger=1000, + end_ledger=1001, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + + # Verify results + assert result.attempted == [1000, 1001] + assert result.processed == [1000, 1001] + assert result.skipped == [] + + # Verify database state + ledgers = test_session.query(Ledger).all() + assert len(ledgers) == 2 + assert ledgers[0].sequence == 1000 + assert ledgers[1].sequence == 1001 + + def test_ingest_with_idempotency( + self, + test_session: Session, + sample_ledger_data: List[Dict[str, Any]], + ) -> None: + """Test that ingestion is idempotent - re-processing skips already processed ledgers.""" + service = IngestionService() + + def fetch_ledger(ledger_id: int) -> Dict[str, Any]: + return sample_ledger_data[ledger_id - 1000] + + def process_ledger(ledger_id: int, payload: Dict[str, Any]) -> None: + ledger = parse_ledger(payload) + test_session.add(ledger) + test_session.commit() + + # First ingestion + result1 = service.ingest( + start_ledger=1000, + end_ledger=1001, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + assert result1.processed == [1000, 1001] + + # Second ingestion - should skip already processed + result2 = service.ingest( + start_ledger=1000, + end_ledger=1001, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + assert result2.attempted == [1000, 1001] + assert result2.processed == [] + assert result2.skipped == [1000, 1001] + + # Verify no duplicates in database + ledgers = test_session.query(Ledger).all() + assert len(ledgers) == 2 + + def test_ingest_with_partial_failure( + self, + test_session: Session, + sample_ledger_data: List[Dict[str, Any]], + ) -> None: + """Test ingestion continues even if one ledger fails to process.""" + service = IngestionService() + + def fetch_ledger(ledger_id: int) -> Dict[str, Any]: + return sample_ledger_data[ledger_id - 1000] + + call_count = [0] + + def process_ledger(ledger_id: int, payload: Dict[str, Any]) -> None: + call_count[0] += 1 + if ledger_id == 1000: + raise ValueError("Simulated failure") + ledger = parse_ledger(payload) + test_session.add(ledger) + test_session.commit() + + # Should fail on first ledger + with pytest.raises(ValueError): + service.ingest( + start_ledger=1000, + end_ledger=1001, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + + # State should not have marked ledger 1000 as processed + # Retry without the failing ledger + result = service.ingest( + start_ledger=1001, + end_ledger=1001, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + assert result.processed == [1001] + + # Verify only successful ledger is in database + ledgers = test_session.query(Ledger).all() + assert len(ledgers) == 1 + assert ledgers[0].sequence == 1001 + + +class TestParserIntegration: + """Integration tests for parsers with database storage.""" + + def test_parse_and_store_complete_transaction( + self, + test_session: Session, + sample_transaction_data: List[Dict[str, Any]], + sample_operation_data: List[Dict[str, Any]], + ) -> None: + """Test parsing and storing a complete transaction with operations.""" + # First, add a ledger + ledger = Ledger( + sequence=1000, + hash="a" * 64, + closed_at=datetime(2024, 1, 1), + successful_transaction_count=1, + failed_transaction_count=0, + operation_count=2, + ) + test_session.add(ledger) + test_session.commit() + + # Parse and store transaction + tx_data = sample_transaction_data[0] + transaction = parse_transaction(tx_data) + test_session.add(transaction) + test_session.commit() + + # Parse and store operations + for i, op_data in enumerate(sample_operation_data): + if op_data["transaction_hash"] == tx_data["hash"]: + operation = parse_operation(op_data, application_order=i) + test_session.add(operation) + test_session.commit() + + # Verify transaction was stored + stored_tx = test_session.query(Transaction).filter_by(hash=tx_data["hash"]).first() + assert stored_tx is not None + assert stored_tx.source_account == tx_data["source_account"] + assert stored_tx.ledger_sequence == 1000 + + # Verify operations were stored and linked + operations = test_session.query(Operation).filter_by(transaction_hash=tx_data["hash"]).all() + assert len(operations) == 2 + + def test_parse_and_store_effects( + self, + test_session: Session, + sample_effect_data: List[Dict[str, Any]], + ) -> None: + """Test parsing and storing effects.""" + for effect_data in sample_effect_data: + effect = parse_effect(effect_data) + test_session.add(effect) + test_session.commit() + + # Verify effects were stored + effects = test_session.query(Effect).all() + assert len(effects) == 2 + assert effects[0].type == "account_debited" + assert effects[1].type == "account_credited" + + +class TestSyntheticFraudInjectionIntegration: + """Integration tests for synthetic fraud injection.""" + + def test_inject_fraud_patterns_to_file( + self, + temp_data_dir: Path, + ) -> None: + """Test injecting fraud patterns and saving to file.""" + # Create sample clean ledger + clean_ledger = [ + { + "source_account": "G" + "A" * 55, + "destination_account": "G" + "B" * 55, + "amount": 100.0, + "created_at": "2024-01-01T00:00:00Z", + } + ] + + input_file = temp_data_dir / "clean_ledger.jsonl" + output_file = temp_data_dir / "augmented_ledger.jsonl" + summary_file = temp_data_dir / "summary.json" + + # Write clean ledger + with open(input_file, "w") as f: + for tx in clean_ledger: + f.write(tx.__str__() + "\n") + + # Run injection + summary = run_injection( + input_path=str(input_file), + output_path=str(output_file), + summary_path=str(summary_file), + seed=42, + sybil=SybilConfig(clusters=1, cluster_size=3, tx_per_member=2), + wash=WashLoopConfig(loops=1, loop_size=3, rounds=2), + source_field="source_account", + dest_field="destination_account", + amount_field="amount", + timestamp_field="created_at", + ) + + # Verify summary + assert summary.original_transactions == 1 + assert summary.sybil_transactions == 6 # 1 cluster * 3 members * 2 tx + assert summary.wash_loop_transactions == 6 # 1 loop * 3 accounts * 2 rounds + assert summary.injected_transactions == 12 + assert summary.total_transactions == 13 + + # Verify output file exists + assert output_file.exists() + assert summary_file.exists() + + def test_inject_fraud_in_memory( + self, + ) -> None: + """Test injecting fraud patterns in memory.""" + clean_transactions = [ + { + "source_account": "G" + "A" * 55, + "destination_account": "G" + "B" * 55, + "amount": 100.0, + "created_at": "2024-01-01T00:00:00Z", + } + ] + + augmented, summary = inject_synthetic_fraud( + clean_transactions, + seed=42, + sybil=SybilConfig(clusters=1, cluster_size=2, tx_per_member=1), + wash=WashLoopConfig(loops=0, loop_size=0, rounds=0), # No wash loops + source_field="source_account", + dest_field="destination_account", + amount_field="amount", + timestamp_field="created_at", + ) + + # Verify augmentation + assert len(augmented) == 3 # 1 original + 2 sybil transactions + assert summary.original_transactions == 1 + assert summary.sybil_transactions == 2 + assert summary.wash_loop_transactions == 0 + + # Verify synthetic transactions are tagged + synthetic_txs = [tx for tx in augmented if tx.get("synthetic_fraud")] + assert len(synthetic_txs) == 2 + assert all(tx["fraud_pattern"] == "sybil_cluster" for tx in synthetic_txs) + + def test_fraud_injection_preserves_original_data( + self, + ) -> None: + """Test that fraud injection preserves original transaction data.""" + original = [ + { + "source_account": "G" + "A" * 55, + "destination_account": "G" + "B" * 55, + "amount": 100.0, + "created_at": "2024-01-01T00:00:00Z", + "custom_field": "should_preserve", + } + ] + + augmented, _ = inject_synthetic_fraud( + original, + seed=42, + sybil=SybilConfig(clusters=0, cluster_size=0, tx_per_member=0), + wash=WashLoopConfig(loops=0, loop_size=0, rounds=0), + ) + + # Original transaction should be unchanged + assert len(augmented) == 1 + assert augmented[0]["custom_field"] == "should_preserve" + assert "synthetic_fraud" not in augmented[0] + + +class TestCompleteIngestionWorkflow: + """Integration tests for the complete ingestion workflow.""" + + def test_ledger_to_operations_workflow( + self, + test_session: Session, + ) -> None: + """Test complete workflow from ledger to operations.""" + # Create ledger + ledger_data = { + "sequence": 1000, + "hash": "a" * 64, + "prev_hash": "b" * 64, + "closed_at": datetime(2024, 1, 1), + "successful_transaction_count": 1, + "failed_transaction_count": 0, + "operation_count": 2, + } + ledger = Ledger(**ledger_data) + test_session.add(ledger) + test_session.commit() + + # Create transaction + tx_data = { + "hash": "tx1" + "a" * 60, + "ledger": 1000, + "source_account": "G" + "A" * 55, + "created_at": datetime(2024, 1, 1), + "fee_charged": 100, + "operation_count": 2, + "successful": True, + "memo_type": "none", + } + transaction = parse_transaction(tx_data) + test_session.add(transaction) + test_session.commit() + + # Create operations + op_data_1 = { + "id": 1, + "transaction_hash": "tx1" + "a" * 60, + "source_account": "G" + "A" * 55, + "type": "payment", + "to": "G" + "B" * 55, + "amount": "100.0", + "asset_type": "native", + "created_at": datetime(2024, 1, 1), + } + op_data_2 = { + "id": 2, + "transaction_hash": "tx1" + "a" * 60, + "source_account": "G" + "A" * 55, + "type": "create_account", + "account": "G" + "C" * 55, + "starting_balance": "50.0", + "created_at": datetime(2024, 1, 1), + } + + op1 = parse_operation(op_data_1, application_order=0) + op2 = parse_operation(op_data_2, application_order=1) + test_session.add(op1) + test_session.add(op2) + test_session.commit() + + # Verify complete chain + assert test_session.query(Ledger).count() == 1 + assert test_session.query(Transaction).count() == 1 + assert test_session.query(Operation).count() == 2 + + # Verify relationships + stored_tx = test_session.query(Transaction).first() + assert stored_tx.ledger_sequence == 1000 + assert len(stored_tx.operations) == 2 + + def test_incremental_ingestion_with_state( + self, + test_session: Session, + sample_ledger_data: List[Dict[str, Any]], + ) -> None: + """Test incremental ingestion with state persistence.""" + service = IngestionService() + + def fetch_ledger(ledger_id: int) -> Dict[str, Any]: + return sample_ledger_data[ledger_id - 1000] + + def process_ledger(ledger_id: int, payload: Dict[str, Any]) -> None: + ledger = parse_ledger(payload) + test_session.add(ledger) + test_session.commit() + + # First batch + result1 = service.ingest( + start_ledger=1000, + end_ledger=1000, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + assert result1.processed == [1000] + + # Second batch - should continue from where we left off + result2 = service.ingest( + start_ledger=1001, + end_ledger=1001, + fetch_fn=fetch_ledger, + process_fn=process_ledger, + ) + assert result2.processed == [1001] + + # Verify both ledgers are in database + assert test_session.query(Ledger).count() == 2 diff --git a/tests/integration/test_model_training.py b/tests/integration/test_model_training.py new file mode 100644 index 0000000..0f9206f --- /dev/null +++ b/tests/integration/test_model_training.py @@ -0,0 +1,496 @@ +"""Integration tests for the model training pipeline. + +These tests verify the complete workflow from features to trained models, +including training, evaluation, and model persistence. +""" +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pandas as pd +import pytest +import torch +import torch.nn as nn +from torch_geometric.data import Data + +from astroml.models.gcn import GCN +from astroml.models.sage_encoder import InductiveSAGEEncoder +from astroml.training.train_sage import train_epoch, build_reconstruction_target +from astroml.features.gnn.sampler import MultiHopSampler + + +class TestGCNTrainingIntegration: + """Integration tests for GCN model training.""" + + def test_gcn_training_workflow( + self, + sample_training_data: tuple, + ) -> None: + """Test complete GCN training workflow.""" + X, y = sample_training_data + + # Create simple graph structure (random edges) + num_nodes = X.shape[0] + edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) + + # Convert to PyG format + data = Data( + x=torch.tensor(X, dtype=torch.float32), + edge_index=edge_index, + y=torch.tensor(y, dtype=torch.long), + ) + + # Create model + model = GCN( + input_dim=X.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.5, + ) + + # Training setup + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + + # Train for a few epochs + model.train() + initial_loss = None + for epoch in range(5): + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = criterion(out, data.y) + loss.backward() + optimizer.step() + + if epoch == 0: + initial_loss = loss.item() + + # Verify loss decreased + final_loss = loss.item() + assert final_loss < initial_loss or final_loss == initial_loss + + def test_gcn_prediction_workflow( + self, + sample_training_data: tuple, + ) -> None: + """Test GCN prediction workflow after training.""" + X, y = sample_training_data + num_nodes = X.shape[0] + edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) + + data = Data( + x=torch.tensor(X, dtype=torch.float32), + edge_index=edge_index, + y=torch.tensor(y, dtype=torch.long), + ) + + model = GCN( + input_dim=X.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.0, # No dropout for prediction + ) + + # Train briefly + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + model.train() + for _ in range(3): + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = criterion(out, data.y) + loss.backward() + optimizer.step() + + # Predict + model.eval() + with torch.no_grad(): + predictions = model(data.x, data.edge_index) + predicted_classes = predictions.argmax(dim=1) + + # Verify predictions + assert predicted_classes.shape == (num_nodes,) + assert torch.all(predicted_classes >= 0) + assert torch.all(predicted_classes < 2) + + +class TestGraphSAGETrainingIntegration: + """Integration tests for GraphSAGE model training.""" + + def test_sage_encoder_training( + self, + sample_node_features: Dict[str, np.ndarray], + sample_edge_list: List[tuple], + ) -> None: + """Test GraphSAGE encoder training with reconstruction loss.""" + # Prepare data + node_ids = list(sample_node_features.keys()) + features = np.stack([sample_node_features[nid] for nid in node_ids]) + features_tensor = torch.tensor(features, dtype=torch.float32) + + # Create edge index + node_to_idx = {nid: i for i, nid in enumerate(node_ids)} + edge_list = [] + for src, dst, _, _ in sample_edge_list: + if src in node_to_idx and dst in node_to_idx: + edge_list.append([node_to_idx[src], node_to_idx[dst]]) + + if len(edge_list) == 0: + # Create dummy edges if none exist + edge_list = [[0, 1], [1, 2], [2, 0]] + + edge_index = torch.tensor(edge_list, dtype=torch.long).t() + + # Create encoder + encoder = InductiveSAGEEncoder( + input_dim=features.shape[1], + hidden_dim=16, + output_dim=8, + num_layers=2, + dropout=0.0, + aggregator='mean', + ) + + # Create sampler + sampler = MultiHopSampler(edge_index, num_hops=2, fanout=[5, 5]) + + # Train nodes + train_nodes = torch.arange(min(10, len(node_ids))) + + # Training setup + optimizer = torch.optim.Adam(encoder.parameters(), lr=0.01) + + # Train for one epoch + loss = train_epoch( + encoder=encoder, + sampler=sampler, + features=features_tensor, + edge_index=edge_index, + train_nodes=train_nodes, + optimizer=optimizer, + batch_size=4, + device='cpu', + ) + + # Verify loss is finite + assert isinstance(loss, float) + assert np.isfinite(loss) + + def test_reconstruction_target_computation( + self, + sample_node_features: Dict[str, np.ndarray], + sample_edge_list: List[tuple], + ) -> None: + """Test reconstruction target computation for training.""" + node_ids = list(sample_node_features.keys()) + features = np.stack([sample_node_features[nid] for nid in node_ids]) + features_tensor = torch.tensor(features, dtype=torch.float32) + + # Create edge index + node_to_idx = {nid: i for i, nid in enumerate(node_ids)} + edge_list = [] + for src, dst, _, _ in sample_edge_list: + if src in node_to_idx and dst in node_to_idx: + edge_list.append([node_to_idx[src], node_to_idx[dst]]) + + if len(edge_list) == 0: + edge_list = [[0, 1], [1, 2], [2, 0]] + + edge_index = torch.tensor(edge_list, dtype=torch.long).t() + + # Compute reconstruction targets + target_nodes = torch.arange(min(5, len(node_ids))) + targets = build_reconstruction_target( + edge_index=edge_index, + features=features_tensor, + target_nodes=target_nodes, + ) + + # Verify shape and values + assert targets.shape == (len(target_nodes), features.shape[1]) + assert torch.all(torch.isfinite(targets)) + + +class TestModelPersistenceIntegration: + """Integration tests for model persistence and loading.""" + + def test_save_and_load_gcn_model( + self, + sample_training_data: tuple, + temp_output_dir: Path, + ) -> None: + """Test saving and loading GCN model.""" + X, y = sample_training_data + num_nodes = X.shape[0] + edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) + + # Create and train model + model = GCN( + input_dim=X.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.5, + ) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + model.train() + for _ in range(3): + optimizer.zero_grad() + data = Data( + x=torch.tensor(X, dtype=torch.float32), + edge_index=edge_index, + y=torch.tensor(y, dtype=torch.long), + ) + out = model(data.x, data.edge_index) + loss = criterion(out, data.y) + loss.backward() + optimizer.step() + + # Save model + model_path = temp_output_dir / "gcn_model.pt" + torch.save({ + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'input_dim': X.shape[1], + 'hidden_dim': 16, + 'output_dim': 2, + }, model_path) + + # Verify file exists + assert model_path.exists() + + # Load model + checkpoint = torch.load(model_path) + loaded_model = GCN( + input_dim=checkpoint['input_dim'], + hidden_dim=checkpoint['hidden_dim'], + output_dim=checkpoint['output_dim'], + ) + loaded_model.load_state_dict(checkpoint['model_state_dict']) + + # Verify loaded model works + loaded_model.eval() + with torch.no_grad(): + data = Data( + x=torch.tensor(X, dtype=torch.float32), + edge_index=edge_index, + ) + predictions = loaded_model(data.x, data.edge_index) + + assert predictions.shape == (num_nodes, 2) + + def test_save_and_load_sage_encoder( + self, + sample_node_features: Dict[str, np.ndarray], + temp_output_dir: Path, + ) -> None: + """Test saving and loading GraphSAGE encoder.""" + node_ids = list(sample_node_features.keys()) + features = np.stack([sample_node_features[nid] for nid in node_ids]) + + # Create encoder + encoder = InductiveSAGEEncoder( + input_dim=features.shape[1], + hidden_dim=16, + output_dim=8, + num_layers=2, + dropout=0.0, + aggregator='mean', + ) + + # Save encoder + encoder_path = temp_output_dir / "sage_encoder.pt" + torch.save({ + 'encoder_state_dict': encoder.state_dict(), + 'input_dim': features.shape[1], + 'hidden_dim': 16, + 'output_dim': 8, + 'num_layers': 2, + 'aggregator': 'mean', + }, encoder_path) + + # Verify file exists + assert encoder_path.exists() + + # Load encoder + checkpoint = torch.load(encoder_path) + loaded_encoder = InductiveSAGEEncoder( + input_dim=checkpoint['input_dim'], + hidden_dim=checkpoint['hidden_dim'], + output_dim=checkpoint['output_dim'], + num_layers=checkpoint['num_layers'], + aggregator=checkpoint['aggregator'], + ) + loaded_encoder.load_state_dict(checkpoint['encoder_state_dict']) + + # Verify loaded encoder works + features_tensor = torch.tensor(features, dtype=torch.float32) + with torch.no_grad(): + embeddings = loaded_encoder(features_tensor, []) + + assert embeddings.shape == (len(node_ids), 8) + + +class TestTrainingPipelineIntegration: + """Integration tests for complete training pipelines.""" + + def test_features_to_model_pipeline( + self, + sample_node_features: Dict[str, np.ndarray], + sample_edge_list: List[tuple], + temp_output_dir: Path, + ) -> None: + """Test complete pipeline from features to trained model.""" + # Step 1: Prepare features + node_ids = list(sample_node_features.keys()) + features = np.stack([sample_node_features[nid] for nid in node_ids]) + features_tensor = torch.tensor(features, dtype=torch.float32) + + # Step 2: Create graph structure + node_to_idx = {nid: i for i, nid in enumerate(node_ids)} + edge_list = [] + for src, dst, _, _ in sample_edge_list: + if src in node_to_idx and dst in node_to_idx: + edge_list.append([node_to_idx[src], node_to_idx[dst]]) + + if len(edge_list) == 0: + edge_list = [[0, 1], [1, 2], [2, 0]] + + edge_index = torch.tensor(edge_list, dtype=torch.long).t() + + # Step 3: Create and train model + model = GCN( + input_dim=features.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.5, + ) + + # Create dummy labels + labels = torch.randint(0, 2, (len(node_ids),)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + + model.train() + for _ in range(5): + optimizer.zero_grad() + out = model(features_tensor, edge_index) + loss = criterion(out, labels) + loss.backward() + optimizer.step() + + # Step 4: Save model + model_path = temp_output_dir / "trained_model.pt" + torch.save({ + 'model_state_dict': model.state_dict(), + 'input_dim': features.shape[1], + 'hidden_dim': 16, + 'output_dim': 2, + 'training_loss': loss.item(), + 'trained_at': datetime.utcnow().isoformat(), + }, model_path) + + # Verify pipeline + assert model_path.exists() + checkpoint = torch.load(model_path) + assert 'training_loss' in checkpoint + assert 'trained_at' in checkpoint + + def test_incremental_training_workflow( + self, + sample_node_features: Dict[str, np.ndarray], + temp_output_dir: Path, + ) -> None: + """Test incremental training with new data.""" + node_ids = list(sample_node_features.keys()) + features = np.stack([sample_node_features[nid] for nid in node_ids]) + + # Initial training + model = GCN( + input_dim=features.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.5, + ) + + edge_index = torch.randint(0, len(node_ids), (2, len(node_ids) * 2)) + labels = torch.randint(0, 2, (len(node_ids),)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + + model.train() + for _ in range(3): + optimizer.zero_grad() + out = model(torch.tensor(features, dtype=torch.float32), edge_index) + loss = criterion(out, labels) + loss.backward() + optimizer.step() + + initial_loss = loss.item() + + # Add new data + new_features = np.random.randn(5, features.shape[1]).astype(np.float32) + updated_features = np.vstack([features, new_features]) + updated_edge_index = torch.randint(0, len(node_ids) + 5, (2, (len(node_ids) + 5) * 2)) + updated_labels = torch.randint(0, 2, (len(node_ids) + 5,)) + + # Continue training + for _ in range(3): + optimizer.zero_grad() + out = model(torch.tensor(updated_features, dtype=torch.float32), updated_edge_index) + loss = criterion(out, updated_labels) + loss.backward() + optimizer.step() + + # Verify training continued + assert loss.item() is not None + + def test_model_evaluation_workflow( + self, + sample_training_data: tuple, + ) -> None: + """Test model evaluation workflow.""" + X, y = sample_training_data + + # Split data + split_idx = int(0.8 * len(X)) + X_train, X_test = X[:split_idx], X[split_idx:] + y_train, y_test = y[:split_idx], y[split_idx:] + + # Create model + model = GCN( + input_dim=X.shape[1], + hidden_dim=16, + output_dim=2, + dropout=0.5, + ) + + # Train + edge_index = torch.randint(0, len(X_train), (2, len(X_train) * 2)) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = nn.NLLLoss() + + model.train() + for _ in range(5): + optimizer.zero_grad() + out = model(torch.tensor(X_train, dtype=torch.float32), edge_index) + loss = criterion(out, torch.tensor(y_train, dtype=torch.long)) + loss.backward() + optimizer.step() + + # Evaluate + model.eval() + with torch.no_grad(): + test_edge_index = torch.randint(0, len(X_test), (2, len(X_test) * 2)) + predictions = model(torch.tensor(X_test, dtype=torch.float32), test_edge_index) + predicted_classes = predictions.argmax(dim=1) + accuracy = (predicted_classes == torch.tensor(y_test)).float().mean() + + # Verify evaluation + assert 0.0 <= accuracy.item() <= 1.0 diff --git a/tests/integration/test_streaming.py b/tests/integration/test_streaming.py new file mode 100644 index 0000000..f2b17ec --- /dev/null +++ b/tests/integration/test_streaming.py @@ -0,0 +1,379 @@ +"""Integration tests for streaming ingestion pipeline. + +These tests verify the complete workflow from real-time streaming +to database persistence, including reconnection logic and cursor tracking. +""" +from __future__ import annotations + +import asyncio +import json +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astroml.ingestion.stream import HorizonStreamClient +from astroml.ingestion.config import StreamConfig +from astroml.ingestion.enhanced_stream import ( + EnhancedStreamConfig, + RateLimitTracker, +) + + +class TestStreamClientIntegration: + """Integration tests for Horizon streaming client.""" + + @pytest.mark.asyncio + async def test_stream_client_initialization( + self, + ) -> None: + """Test stream client initialization with configuration.""" + config = StreamConfig( + horizon_url="https://horizon-testnet.stellar.org", + stream_endpoint="/transactions", + cursor="12345", + ) + + client = HorizonStreamClient(config) + + assert client._config.horizon_url == "https://horizon-testnet.stellar.org" + assert client._config.stream_endpoint == "/transactions" + assert client._last_cursor == "12345" + + @pytest.mark.asyncio + async def test_stream_client_url_building( + self, + ) -> None: + """Test stream URL construction with cursor.""" + config = StreamConfig( + horizon_url="https://horizon-testnet.stellar.org", + stream_endpoint="/transactions", + cursor="12345", + ) + + client = HorizonStreamClient(config) + url = client._build_stream_url() + + assert "cursor=12345" in url + assert "order=asc" in url + assert url.startswith("https://horizon-testnet.stellar.org/transactions") + + @pytest.mark.asyncio + async def test_stream_client_cursor_tracking( + self, + ) -> None: + """Test cursor tracking during streaming.""" + config = StreamConfig(cursor="1000") + client = HorizonStreamClient(config) + + # Mock event with new cursor + event = MagicMock() + event.data = json.dumps({ + "hash": "x" * 64, + "paging_token": "1001", + }) + + client._running = True + + with patch.object(client, "_persist_transaction", new_callable=AsyncMock): + with patch.object(client, "_save_cursor"): + await client._process_event(event) + + assert client._last_cursor == "1001" + + @pytest.mark.asyncio + async def test_stream_client_reconnection_logic( + self, + ) -> None: + """Test exponential backoff on reconnection.""" + config = StreamConfig( + reconnect_base_seconds=0.01, + reconnect_max_seconds=0.05, + max_retries=3, + ) + client = HorizonStreamClient(config) + client._running = True + + with patch("astroml.ingestion.stream.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await client._handle_reconnect(ConnectionError("test")) + first_delay = mock_sleep.call_args[0][0] + + await client._handle_reconnect(ConnectionError("test")) + second_delay = mock_sleep.call_args[0][0] + + assert second_delay > first_delay + + @pytest.mark.asyncio + async def test_stream_client_max_retries( + self, + ) -> None: + """Test that client stops after max retries.""" + config = StreamConfig(max_retries=3) + client = HorizonStreamClient(config) + client._running = True + client._retry_count = 3 + + with patch("astroml.ingestion.stream.asyncio.sleep", new_callable=AsyncMock): + await client._handle_reconnect(ConnectionError("test")) + + assert client._running is False + + +class TestRateLimitTrackerIntegration: + """Integration tests for rate limiting in streaming.""" + + def test_rate_limit_tracker_initialization( + self, + ) -> None: + """Test rate limit tracker initialization.""" + tracker = RateLimitTracker(backoff_factor=1.5) + + assert tracker.backoff_factor == 1.5 + assert tracker.current_backoff == 1.0 + assert tracker.request_count == 0 + + def test_rate_limit_request_tracking( + self, + ) -> None: + """Test request tracking for rate limiting.""" + tracker = RateLimitTracker() + + tracker.record_request() + tracker.record_request() + tracker.record_request() + + assert tracker.request_count == 3 + + def test_rate_limit_backoff_calculation( + self, + ) -> None: + """Test backoff time calculation after rate limit.""" + tracker = RateLimitTracker(backoff_factor=2.0) + + backoff1 = tracker.handle_rate_limit() + assert backoff1 == 2.0 + + backoff2 = tracker.handle_rate_limit() + assert backoff2 == 4.0 + + def test_rate_limit_throttling_decision( + self, + ) -> None: + """Test throttling decision based on recent rate limits.""" + tracker = RateLimitTracker() + + # No rate limit yet + assert tracker.should_throttle() is False + + # Hit rate limit + tracker.handle_rate_limit() + + # Should throttle immediately after + assert tracker.should_throttle() is True + + def test_request_rate_calculation( + self, + ) -> None: + """Test request rate calculation.""" + tracker = RateLimitTracker() + + tracker.record_request() + tracker.record_request() + tracker.record_request() + + rate = tracker.get_request_rate() + assert rate > 0 + + +class TestEnhancedStreamingIntegration: + """Integration tests for enhanced streaming service.""" + + @pytest.mark.asyncio + async def test_enhanced_stream_config( + self, + ) -> None: + """Test enhanced stream configuration.""" + config = EnhancedStreamConfig( + horizon_url="https://horizon-testnet.stellar.org", + stream_type="effects", + cursor="now", + max_retries=5, + batch_size=100, + ) + + assert config.horizon_url == "https://horizon-testnet.stellar.org" + assert config.stream_type == "effects" + assert config.cursor == "now" + assert config.max_retries == 5 + assert config.batch_size == 100 + + @pytest.mark.asyncio + async def test_stream_event_processing( + self, + mock_horizon_response: Dict[str, Any], + ) -> None: + """Test processing of stream events.""" + from astroml.ingestion.parsers import parse_transaction + + # Parse mock response + transaction = parse_transaction(mock_horizon_response) + + # Verify parsing + assert transaction.hash == mock_horizon_response["hash"] + assert transaction.source_account == mock_horizon_response["source_account"] + assert transaction.ledger_sequence == mock_horizon_response["ledger"] + + @pytest.mark.asyncio + async def test_stream_batch_processing( + self, + ) -> None: + """Test batch processing of stream events.""" + events = [] + for i in range(10): + event = MagicMock() + event.data = json.dumps({ + "hash": "x" * 64, + "ledger": 1000 + i, + "source_account": f"G{'A' * 55}", + "created_at": "2024-01-01T00:00:00Z", + "fee_charged": 100, + "operation_count": 1, + "successful": True, + "memo_type": "none", + "paging_token": str(1000 + i), + }) + events.append(event) + + # Process batch + processed_count = 0 + for event in events: + data = json.loads(event.data) + if data.get("hash"): + processed_count += 1 + + assert processed_count == 10 + + +class TestStreamingPipelineIntegration: + """Integration tests for complete streaming pipeline.""" + + @pytest.mark.asyncio + async def test_stream_to_database_pipeline( + self, + test_session, + mock_horizon_response: Dict[str, Any], + ) -> None: + """Test complete pipeline from stream to database.""" + from astroml.ingestion.parsers import parse_transaction + from astroml.db.schema import Ledger, Transaction + + # Create ledger first + ledger = Ledger( + sequence=1000, + hash="a" * 64, + closed_at=datetime(2024, 1, 1), + successful_transaction_count=1, + failed_transaction_count=0, + operation_count=1, + ) + test_session.add(ledger) + test_session.commit() + + # Parse and store transaction from stream + transaction = parse_transaction(mock_horizon_response) + test_session.add(transaction) + test_session.commit() + + # Verify database state + stored_tx = test_session.query(Transaction).filter_by( + hash=mock_horizon_response["hash"] + ).first() + + assert stored_tx is not None + assert stored_tx.source_account == mock_horizon_response["source_account"] + + @pytest.mark.asyncio + async def test_stream_cursor_persistence( + self, + temp_output_dir: Path, + ) -> None: + """Test cursor persistence across stream restarts.""" + cursor_file = temp_output_dir / ".stream_cursor" + + # Save cursor + cursor = "12345" + cursor_file.write_text(cursor) + + # Load cursor + loaded_cursor = cursor_file.read_text().strip() + + assert loaded_cursor == cursor + + @pytest.mark.asyncio + async def test_stream_error_recovery( + self, + ) -> None: + """Test stream recovery from transient errors.""" + config = StreamConfig(max_retries=3) + client = HorizonStreamClient(config) + client._running = True + + # Simulate error + error_count = [0] + + async def mock_fetch(): + error_count[0] += 1 + if error_count[0] < 3: + raise ConnectionError("Transient error") + return {"data": "success"} + + # Should recover after retries + with patch.object(client, "_handle_reconnect", new_callable=AsyncMock): + try: + for _ in range(3): + await mock_fetch() + except ConnectionError: + pass + + assert error_count[0] == 3 + + @pytest.mark.asyncio + async def test_stream_metrics_tracking( + self, + ) -> None: + """Test metrics tracking during streaming.""" + from astroml.ingestion.metrics import ( + STREAM_RECORDS_PROCESSED, + STREAM_ERRORS, + ) + + # Simulate processing + STREAM_RECORDS_PROCESSED.inc() + STREAM_RECORDS_PROCESSED.inc() + STREAM_RECORDS_PROCESSED.inc() + + # Simulate error + STREAM_ERRORS.inc() + + # Verify metrics (in real scenario, would query Prometheus) + # Here we just verify the metrics can be incremented + assert STREAM_RECORDS_PROCESSED._value.get() == 3 + assert STREAM_ERRORS._value.get() == 1 + + @pytest.mark.asyncio + async def test_stream_graceful_shutdown( + self, + ) -> None: + """Test graceful shutdown of streaming client.""" + config = StreamConfig() + client = HorizonStreamClient(config) + + # Simulate running state + client._running = True + + # Trigger shutdown + client._running = False + + assert client._running is False diff --git a/tests/integration/test_validation.py b/tests/integration/test_validation.py new file mode 100644 index 0000000..2b68a7b --- /dev/null +++ b/tests/integration/test_validation.py @@ -0,0 +1,404 @@ +"""Integration tests for validation and calibration pipeline. + +These tests verify the complete workflow from model predictions +to validation, calibration, and quality assurance. +""" +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pandas as pd +import pytest + +from astroml.validation.calibration import CalibrationAnalyzer +from astroml.validation.data_quality import ( + DataQualityReport, + TemporalValidator, + ValidationResult, +) +from astroml.validation.validator import ( + TransactionValidator, + validate_transaction, + CorruptionType, +) + + +class TestCalibrationIntegration: + """Integration tests for model calibration.""" + + def test_calibration_analysis_workflow( + self, + fraud_labels: np.ndarray, + fraud_scores: np.ndarray, + ) -> None: + """Test complete calibration analysis workflow.""" + analyzer = CalibrationAnalyzer(n_bins=10, strategy='uniform') + + # Compute calibration curve + fraction_positives, mean_predicted = analyzer.compute_calibration_curve( + fraud_labels, fraud_scores + ) + + # Verify calibration data + assert len(fraction_positives) == len(mean_predicted) + assert len(fraction_positives) <= 10 + assert np.all(fraction_positives >= 0) + assert np.all(fraction_positives <= 1) + assert np.all(mean_predicted >= 0) + assert np.all(mean_predicted <= 1) + + def test_calibration_metrics_computation( + self, + fraud_labels: np.ndarray, + fraud_scores: np.ndarray, + ) -> None: + """Test comprehensive calibration metrics computation.""" + analyzer = CalibrationAnalyzer(n_bins=10) + + # Compute metrics + metrics = analyzer.compute_calibration_metrics( + fraud_labels, fraud_scores + ) + + # Verify metrics + assert 'brier_score' in metrics + assert 'log_loss' in metrics + assert metrics['brier_score'] >= 0 + assert metrics['log_loss'] >= 0 + + def test_calibration_with_perfect_predictions( + self, + ) -> None: + """Test calibration with perfectly calibrated predictions.""" + # Create perfectly calibrated data + np.random.seed(42) + n_samples = 1000 + y_true = np.random.randint(0, 2, n_samples) + y_prob = y_true.astype(float) + np.random.normal(0, 0.05, n_samples) + y_prob = np.clip(y_prob, 0.01, 0.99) + + analyzer = CalibrationAnalyzer(n_bins=10) + metrics = analyzer.compute_calibration_metrics(y_true, y_prob) + + # Perfect calibration should have low Brier score + assert metrics['brier_score'] < 0.1 + + def test_calibration_with_random_predictions( + self, + ) -> None: + """Test calibration with random (uncalibrated) predictions.""" + # Create random predictions + np.random.seed(42) + n_samples = 1000 + y_true = np.random.randint(0, 2, n_samples) + y_prob = np.random.uniform(0, 1, n_samples) + + analyzer = CalibrationAnalyzer(n_bins=10) + metrics = analyzer.compute_calibration_metrics(y_true, y_prob) + + # Random predictions should have higher Brier score + assert metrics['brier_score'] >= 0.2 + + +class TestDataQualityIntegration: + """Integration tests for data quality validation.""" + + def test_transaction_validation_workflow( + self, + sample_transaction_data: List[Dict[str, Any]], + ) -> None: + """Test complete transaction validation workflow.""" + validator = TransactionValidator( + required_fields={"hash", "source_account", "created_at", "fee"}, + field_types={"fee": int, "operation_count": int}, + ) + + # Validate transactions + results = validator.validate_batch(sample_transaction_data) + + # Verify results + assert len(results) == len(sample_transaction_data) + assert all(isinstance(r, type(results[0])) for r in results) + + def test_data_quality_report_generation( + self, + ) -> None: + """Test comprehensive data quality report generation.""" + # Create sample transactions with various issues + transactions = [ + {"id": "tx1", "source_account": "GAAA", "amount": 100.0}, + {"id": "tx2", "amount": 50.0}, # Missing source_account + {"id": "tx3", "source_account": "GBBB", "amount": "invalid"}, # Invalid type + {"id": "tx4", "source_account": "GCCC", "amount": 200.0}, + ] + + validator = TransactionValidator( + required_fields={"id", "source_account", "amount"}, + field_types={"amount": (int, float)}, + ) + + # Validate and generate report + results = validator.validate_batch(transactions) + + valid_count = sum(1 for r in results if r.is_valid) + report = DataQualityReport( + total_records=len(transactions), + valid_records=valid_count, + validation_results=[ + ValidationResult( + is_valid=r.is_valid, + error_type=r.errors[0].error_type if r.errors else None, + message=r.errors[0].message if r.errors else "Valid", + ) + for r in results + ], + ) + + # Verify report + assert report.total_records == 4 + assert report.valid_records == 2 + assert report.quality_score == 50.0 + assert len(report.error_types) > 0 + + def test_temporal_validation_workflow( + self, + ) -> None: + """Test temporal data validation workflow.""" + validator = TemporalValidator(timestamp_field="timestamp") + + # Create transactions with timestamps + base_time = datetime(2024, 1, 1) + transactions = [ + {"id": "tx1", "timestamp": base_time}, + {"id": "tx2", "timestamp": base_time + timedelta(hours=1)}, + {"id": "tx3", "timestamp": base_time + timedelta(hours=2)}, + ] + + # Validate ordering + result = validator.validate_timestamp_ordering(transactions) + + # Should be valid (monotonically increasing) + assert result.is_valid + + def test_temporal_validation_with_out_of_order( + self, + ) -> None: + """Test temporal validation with out-of-order timestamps.""" + validator = TemporalValidator(timestamp_field="timestamp") + + # Create transactions with out-of-order timestamps + base_time = datetime(2024, 1, 1) + transactions = [ + {"id": "tx1", "timestamp": base_time + timedelta(hours=2)}, + {"id": "tx2", "timestamp": base_time}, + {"id": "tx3", "timestamp": base_time + timedelta(hours=1)}, + ] + + # Validate ordering + result = validator.validate_timestamp_ordering(transactions) + + # Should be invalid + assert not result.is_valid + + +class TestValidationPipelineIntegration: + """Integration tests for complete validation pipeline.""" + + def test_model_prediction_validation_workflow( + self, + fraud_labels: np.ndarray, + fraud_scores: np.ndarray, + ) -> None: + """Test validation of model predictions before calibration.""" + # Validate prediction format + assert len(fraud_labels) == len(fraud_scores) + assert np.all((fraud_scores >= 0) & (fraud_scores <= 1)) + + # Check for NaN or infinite values + assert not np.any(np.isnan(fraud_scores)) + assert not np.any(np.isinf(fraud_scores)) + + # Proceed with calibration + analyzer = CalibrationAnalyzer(n_bins=10) + metrics = analyzer.compute_calibration_metrics(fraud_labels, fraud_scores) + + # Verify metrics are valid + assert all(np.isfinite(v) for v in metrics.values()) + + def test_end_to_end_validation_pipeline( + self, + sample_transaction_data: List[Dict[str, Any]], + fraud_labels: np.ndarray, + fraud_scores: np.ndarray, + ) -> None: + """Test complete validation pipeline from transactions to calibrated metrics.""" + # Step 1: Validate transaction data + validator = TransactionValidator( + required_fields={"hash", "source_account", "created_at"}, + ) + tx_results = validator.validate_batch(sample_transaction_data) + + # Step 2: Filter valid transactions + valid_tx_count = sum(1 for r in tx_results if r.is_valid) + assert valid_tx_count > 0 + + # Step 3: Validate prediction data + assert len(fraud_labels) == len(fraud_scores) + assert not np.any(np.isnan(fraud_scores)) + + # Step 4: Compute calibration metrics + analyzer = CalibrationAnalyzer(n_bins=10) + metrics = analyzer.compute_calibration_metrics(fraud_labels, fraud_scores) + + # Step 5: Verify pipeline results + assert 'brier_score' in metrics + assert metrics['brier_score'] >= 0 + assert valid_tx_count == len(sample_transaction_data) + + def test_validation_with_corrupted_data( + self, + ) -> None: + """Test validation pipeline with corrupted data.""" + # Create corrupted transactions + corrupted_transactions = [ + {"id": None, "source_account": "GAAA", "amount": 100.0}, # Null ID + {"id": "tx2", "amount": 50.0}, # Missing source_account + {"amount": 200.0}, # Missing both id and source_account + ] + + validator = TransactionValidator( + required_fields={"id", "source_account"}, + ) + + # Validate + results = validator.validate_batch(corrupted_transactions) + + # All should be invalid + assert all(not r.is_valid for r in results) + + # Check error types + error_types = {r.errors[0].error_type for r in results if r.errors} + assert CorruptionType.MISSING_FIELD in error_types + + def test_validation_report_persistence( + self, + temp_output_dir: Path, + ) -> None: + """Test saving and loading validation reports.""" + # Create a validation report + report = DataQualityReport( + total_records=100, + valid_records=95, + validation_results=[ + ValidationResult( + is_valid=True, + message="Valid transaction", + ) + for _ in range(95) + ] + [ + ValidationResult( + is_valid=False, + error_type="MISSING_FIELD", + message="Missing required field", + ) + for _ in range(5) + ], + ) + + # Save report + report_path = temp_output_dir / "validation_report.json" + import json + with open(report_path, 'w') as f: + json.dump({ + 'total_records': report.total_records, + 'valid_records': report.valid_records, + 'quality_score': report.quality_score, + 'error_types': list(report.error_types), + }, f) + + # Verify file exists + assert report_path.exists() + + # Load and verify + with open(report_path, 'r') as f: + loaded = json.load(f) + + assert loaded['total_records'] == 100 + assert loaded['valid_records'] == 95 + assert loaded['quality_score'] == 95.0 + + +class TestCalibrationVisualizationIntegration: + """Integration tests for calibration visualization.""" + + def test_calibration_plot_generation( + self, + fraud_labels: np.ndarray, + fraud_scores: np.ndarray, + temp_output_dir: Path, + ) -> None: + """Test calibration plot generation and saving.""" + analyzer = CalibrationAnalyzer(n_bins=10) + + # Compute calibration curve + fraction_positives, mean_predicted = analyzer.compute_calibration_curve( + fraud_labels, fraud_scores + ) + + # Generate plot + import matplotlib.pyplot as plt + + plt.figure(figsize=(8, 6)) + plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated') + plt.plot(mean_predicted, fraction_positives, 's-', label='Model') + plt.xlabel('Mean predicted probability') + plt.ylabel('Fraction of positives') + plt.title('Calibration Curve') + plt.legend() + + # Save plot + plot_path = temp_output_dir / "calibration_curve.png" + plt.savefig(plot_path, dpi=100, bbox_inches='tight') + plt.close() + + # Verify file exists + assert plot_path.exists() + + def test_calibration_metrics_report( + self, + fraud_labels: np.ndarray, + fraud_scores: np.ndarray, + temp_output_dir: Path, + ) -> None: + """Test generating comprehensive calibration metrics report.""" + analyzer = CalibrationAnalyzer(n_bins=10) + + # Compute metrics + metrics = analyzer.compute_calibration_metrics(fraud_labels, fraud_scores) + + # Generate report + report = { + 'calibration_metrics': metrics, + 'n_samples': len(fraud_labels), + 'n_bins': analyzer.n_bins, + 'strategy': analyzer.strategy, + 'generated_at': datetime.utcnow().isoformat(), + } + + # Save report + report_path = temp_output_dir / "calibration_report.json" + import json + with open(report_path, 'w') as f: + json.dump(report, f, indent=2) + + # Verify file exists and contains expected data + assert report_path.exists() + with open(report_path, 'r') as f: + loaded = json.load(f) + + assert 'calibration_metrics' in loaded + assert 'brier_score' in loaded['calibration_metrics'] + assert loaded['n_samples'] == len(fraud_labels) diff --git a/web/src/App.tsx b/web/src/App.tsx index 88b9ddb..bcd3fa2 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -1,5 +1,6 @@ import { LoyaltyDashboard } from './components/LoyaltyDashboard' import { ModelMonitoringDashboard } from './components/ModelMonitoringDashboard' +import { TransactionHistoryPage } from './components/TransactionHistory' export default function App() { return ( @@ -9,6 +10,9 @@ export default function App() {

Loyalty Dashboard

+
+

Transaction History

+ ) } diff --git a/web/src/api/transactions.ts b/web/src/api/transactions.ts new file mode 100644 index 0000000..921f932 --- /dev/null +++ b/web/src/api/transactions.ts @@ -0,0 +1,74 @@ +import type { BlockchainTransaction, TransactionHistoryResponse } from '../lib/types' + +// Mock transaction data for demonstration +const mockTransactions: BlockchainTransaction[] = Array.from({ length: 250 }).map((_, i) => { + const operationTypes = ['payment', 'create_account', 'change_trust', 'path_payment', 'manage_buy_offer'] + const assetCodes = ['XLM', 'USDC', 'EURC', 'BTC', 'ETH'] + const baseTime = Date.now() - i * 3600000 // One hour apart + + return { + hash: `tx_${'a'.repeat(56)}${i.toString().padStart(8, '0')}`, + ledgerSequence: 50000 + i, + sourceAccount: `G${'A'.repeat(28)}${'B'.repeat(27)}`, + destinationAccount: i % 3 === 0 ? undefined : `G${'C'.repeat(28)}${'D'.repeat(27)}`, + amount: i % 4 === 0 ? undefined : Math.floor(Math.random() * 10000) + 100, + assetCode: assetCodes[i % assetCodes.length], + assetIssuer: i % 2 === 0 ? undefined : `G${'E'.repeat(28)}${'F'.repeat(27)}`, + operationType: operationTypes[i % operationTypes.length], + createdAt: new Date(baseTime).toISOString(), + fee: 100 + (i % 5) * 50, + successful: i % 10 !== 0, // 10% failure rate + memoType: i % 7 === 0 ? 'text' : undefined, + } +}) + +export async function getTransactionHistory( + page: number, + pageSize: number, + filters?: { + sourceAccount?: string + operationType?: string + startDate?: string + endDate?: string + } +): Promise { + // Simulate API delay + await new Promise((resolve) => setTimeout(resolve, 300)) + + let filtered = [...mockTransactions] + + // Apply filters if provided + if (filters?.sourceAccount) { + filtered = filtered.filter((tx) => + tx.sourceAccount.toLowerCase().includes(filters.sourceAccount!.toLowerCase()) + ) + } + + if (filters?.operationType) { + filtered = filtered.filter((tx) => tx.operationType === filters.operationType) + } + + if (filters?.startDate) { + filtered = filtered.filter((tx) => new Date(tx.createdAt) >= new Date(filters.startDate!)) + } + + if (filters?.endDate) { + filtered = filtered.filter((tx) => new Date(tx.createdAt) <= new Date(filters.endDate!)) + } + + const start = page * pageSize + const end = start + pageSize + const data = filtered.slice(start, end) + + return { + data, + page, + pageSize, + total: filtered.length, + } +} + +export async function getTransactionByHash(hash: string): Promise { + await new Promise((resolve) => setTimeout(resolve, 200)) + return mockTransactions.find((tx) => tx.hash === hash) || null +} diff --git a/web/src/components/TransactionHistory/TransactionHistoryPage.tsx b/web/src/components/TransactionHistory/TransactionHistoryPage.tsx new file mode 100644 index 0000000..ed64c23 --- /dev/null +++ b/web/src/components/TransactionHistory/TransactionHistoryPage.tsx @@ -0,0 +1,155 @@ +import { useState } from 'react' +import { useTransactionHistory } from '../../hooks/useTransactionHistory' +import { TransactionHistoryTable } from './TransactionHistoryTable' + +export function TransactionHistoryPage() { + const [page, setPage] = useState(0) + const pageSize = 20 + + const [filters, setFilters] = useState<{ + sourceAccount?: string + operationType?: string + startDate?: string + endDate?: string + }>({}) + + const { data: history, isLoading: loading } = useTransactionHistory(page, pageSize, filters) + + const handleFilterChange = (key: string, value: string) => { + setFilters((prev) => ({ + ...prev, + [key]: value || undefined, + })) + setPage(0) // Reset to first page when filters change + } + + return ( +
+
+

Transaction History

+

+ View and search Stellar blockchain transactions +

+
+ +
+
+
+ + handleFilterChange('sourceAccount', e.target.value)} + style={{ + width: '100%', + padding: '8px 12px', + border: '1px solid #ddd', + borderRadius: 4, + fontSize: 14, + }} + /> +
+ +
+ + +
+ +
+ + handleFilterChange('startDate', e.target.value)} + style={{ + width: '100%', + padding: '8px 12px', + border: '1px solid #ddd', + borderRadius: 4, + fontSize: 14, + }} + /> +
+ +
+ + handleFilterChange('endDate', e.target.value)} + style={{ + width: '100%', + padding: '8px 12px', + border: '1px solid #ddd', + borderRadius: 4, + fontSize: 14, + }} + /> +
+
+ +
+ +
+
+ + + + {history && ( +
+ Showing {Math.min((page + 1) * pageSize, history.total)} of {history.total} transactions +
+ )} +
+ ) +} diff --git a/web/src/components/TransactionHistory/TransactionHistoryTable.tsx b/web/src/components/TransactionHistory/TransactionHistoryTable.tsx new file mode 100644 index 0000000..7556421 --- /dev/null +++ b/web/src/components/TransactionHistory/TransactionHistoryTable.tsx @@ -0,0 +1,136 @@ +import React from 'react' +import type { BlockchainTransaction, TransactionHistoryResponse } from '../../lib/types' + +export function TransactionHistoryTable({ + response, + loading, + page, + pageSize, + onPageChange, +}: { + response: TransactionHistoryResponse | undefined + loading: boolean + page: number + pageSize: number + onPageChange: (p: number) => void +}) { + const total = response?.total ?? 0 + const totalPages = Math.max(1, Math.ceil(total / pageSize)) + + const formatHash = (hash: string) => { + return `${hash.slice(0, 8)}...${hash.slice(-8)}` + } + + const formatAddress = (address: string) => { + return `${address.slice(0, 4)}...${address.slice(-4)}` + } + + return ( +
+
+

Transaction History

+
+ + Page {page + 1} / {totalPages} + +
+
+
+ + + + + + + + + + + + + + + + + {loading && ( + + )} + {!loading && response?.data.length === 0 && ( + + )} + {!loading && response?.data.map((tx) => ( + + + + + + + + + + + + + ))} + +
HashLedgerSourceDestinationTypeAmountAssetFeeStatusDate
Loading...
No transactions found
+ + {formatHash(tx.hash)} + + {tx.ledgerSequence} + + {formatAddress(tx.sourceAccount)} + + + {tx.destinationAccount ? ( + + {formatAddress(tx.destinationAccount)} + + ) : ( + - + )} + {tx.operationType} + {tx.amount !== undefined ? tx.amount.toLocaleString() : '-'} + {tx.assetCode || 'XLM'}{tx.fee} stroops + + {tx.successful ? 'Success' : 'Failed'} + + + {new Date(tx.createdAt).toLocaleString()} +
+
+
+ ) +} + +const th: React.CSSProperties = { + textAlign: 'left', + borderBottom: '2px solid #ddd', + padding: 12, + fontWeight: 600, + fontSize: '13px', + color: '#555' +} +const td: React.CSSProperties = { + borderBottom: '1px solid #f1f1f1', + padding: 10, + fontSize: '13px' +} diff --git a/web/src/components/TransactionHistory/index.ts b/web/src/components/TransactionHistory/index.ts new file mode 100644 index 0000000..5f86137 --- /dev/null +++ b/web/src/components/TransactionHistory/index.ts @@ -0,0 +1,2 @@ +export { TransactionHistoryPage } from './TransactionHistoryPage' +export { TransactionHistoryTable } from './TransactionHistoryTable' diff --git a/web/src/hooks/useTransactionHistory.ts b/web/src/hooks/useTransactionHistory.ts new file mode 100644 index 0000000..fd19268 --- /dev/null +++ b/web/src/hooks/useTransactionHistory.ts @@ -0,0 +1,19 @@ +import { useQuery } from '@tanstack/react-query' +import { getTransactionHistory } from '../api/transactions' +import type { TransactionHistoryResponse } from '../lib/types' + +export function useTransactionHistory( + page: number, + pageSize: number, + filters?: { + sourceAccount?: string + operationType?: string + startDate?: string + endDate?: string + } +) { + return useQuery({ + queryKey: ['transactions', page, pageSize, filters], + queryFn: () => getTransactionHistory(page, pageSize, filters), + }) +} diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index ed81212..f1e0c5f 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -82,3 +82,25 @@ export type FraudStats = { recentAlerts: FraudAlert[] riskOverTime: { date: string; score: number }[] } + +export type BlockchainTransaction = { + hash: string + ledgerSequence: number + sourceAccount: string + destinationAccount?: string + amount?: number + assetCode?: string + assetIssuer?: string + operationType: string + createdAt: string // ISO + fee: number + successful: boolean + memoType?: string +} + +export type TransactionHistoryResponse = { + data: BlockchainTransaction[] + page: number + pageSize: number + total: number +}