diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/balancer/BalancerUtil.java b/core/src/main/java/com/predic8/membrane/core/interceptor/balancer/BalancerUtil.java index d75bcff266..cf92c33665 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/balancer/BalancerUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/balancer/BalancerUtil.java @@ -15,71 +15,121 @@ package com.predic8.membrane.core.interceptor.balancer; import com.predic8.membrane.core.interceptor.*; +import com.predic8.membrane.core.interceptor.flow.AbstractFlowInterceptor; +import com.predic8.membrane.core.interceptor.flow.IfInterceptor; +import com.predic8.membrane.core.interceptor.flow.choice.AbstractCaseOtherwise; +import com.predic8.membrane.core.interceptor.flow.choice.ChooseInterceptor; import com.predic8.membrane.core.proxies.*; -import com.predic8.membrane.core.router.*; +import com.predic8.membrane.core.router.Router; import java.util.*; +import java.util.stream.Stream; public class BalancerUtil { - public static List collectClusters(Router router) { - ArrayList result = new ArrayList<>(); - for (Proxy r : router.getRuleManager().getRules()) { - List interceptors = r.getFlow(); - if (interceptors != null) - for (Interceptor i : interceptors) - if (i instanceof LoadBalancingInterceptor) - result.addAll(((LoadBalancingInterceptor)i).getClusterManager().getClusters()); + /** + * The various getFlow() methods only expose the direct child flow of an interceptor. + * Branching interceptors such as if/else and choose/case keep additional flow lists + * outside of getFlow(), so those branches must be added explicitly while walking the flow tree. + */ + private static Stream> allFlows(Router router) { + Set visited = Collections.newSetFromMap(new IdentityHashMap<>()); + return Stream.concat(ruleFlows(router), globalFlows(router)) + .flatMap(flow -> allFlows(flow, visited)); + } + + private static Stream> ruleFlows(Router router) { + return router.getRuleManager() + .getRules() + .stream() + .map(Proxy::getFlow); + } + + private static Stream> globalFlows(Router router) { + return Optional.ofNullable(router.getRegistry()) + .stream() + .flatMap(registry -> registry.getBean(GlobalInterceptor.class).stream()) + .map(GlobalInterceptor::getFlow); + } + + private static Stream> allFlows(List flow, Set visited) { + if (flow == null) { + return Stream.empty(); } - return result; + return Stream.concat( + Stream.of(flow), + flow.stream().flatMap(interceptor -> childFlows(interceptor, visited)) + ); } - public static List collectBalancers(Router router) { - ArrayList result = new ArrayList<>(); - for (Proxy r : router.getRuleManager().getRules()) { - List interceptors = r.getFlow(); - if (interceptors != null) - for (Interceptor i : interceptors) - if (i instanceof LoadBalancingInterceptor) - result.add((LoadBalancingInterceptor)i); + private static Stream> childFlows(Interceptor interceptor, Set visited) { + if (interceptor == null || !visited.add(interceptor)) { + return Stream.empty(); } - return result; + return directChildFlows(interceptor) + .flatMap(flow -> allFlows(flow, visited)); } - public static Balancer lookupBalancer(Router router, String name) { - for (Proxy r : router.getRuleManager().getRules()) { - List interceptors = r.getFlow(); - if (interceptors != null) - for (Interceptor i : interceptors) - if (i instanceof LoadBalancingInterceptor) - if (((LoadBalancingInterceptor)i).getName().equalsIgnoreCase(name)) - return ((LoadBalancingInterceptor) i).getClusterManager(); + private static Stream> directChildFlows(Interceptor interceptor) { + if (interceptor instanceof ChooseInterceptor chooseInterceptor) { + return chooseInterceptor.getChoices().stream() + .map(AbstractCaseOtherwise::getFlow); } - throw new RuntimeException("balancer with name \"" + name + "\" not found."); + if (interceptor instanceof IfInterceptor ifInterceptor) { + return Stream.of(ifInterceptor.getFlow(), ifInterceptor.getElseInterceptor()); + } + if (interceptor instanceof AbstractFlowInterceptor flowInterceptor) { + return Stream.of(flowInterceptor.getFlow()); + } + return Stream.empty(); + } + + private static Stream balancerBeans(Router router) { + return Stream.concat( + Optional.ofNullable(router.getRegistry()) + .stream() + .flatMap(registry -> registry.getBeans(Balancer.class).stream()), + Optional.ofNullable(router.getBeanFactory()) + .map(ctx -> ctx.getBeansOfType(Balancer.class).values().stream()) + .orElseGet(Stream::empty) + ).distinct(); + } + + public static List collectBalancers(Router router) { + return allFlows(router) + .filter(Objects::nonNull) + .flatMap(List::stream) + .filter(LoadBalancingInterceptor.class::isInstance) + .map(LoadBalancingInterceptor.class::cast) + .distinct() + .toList(); + } + + public static List collectClusters(Router router) { + return Stream.concat( + collectBalancers(router).stream() + .flatMap(lbi -> lbi.getClusterManager().getClusters().stream()), + balancerBeans(router).flatMap(b -> b.getClusters().stream()) + ).distinct().toList(); + } + + public static Balancer lookupBalancer(Router router, String name) { + return collectBalancers(router).stream() + .filter(lbi -> lbi.getName() != null && lbi.getName().equalsIgnoreCase(name)) + .map(LoadBalancingInterceptor::getClusterManager) + .findFirst() + .orElseThrow(() -> new RuntimeException("balancer with name %s not found.".formatted(name))); } public static LoadBalancingInterceptor lookupBalancerInterceptor(Router router, String name) { - for (Proxy r : router.getRuleManager().getRules()) { - List interceptors = r.getFlow(); - if (interceptors != null) - for (Interceptor i : interceptors) - if (i instanceof LoadBalancingInterceptor) - if (((LoadBalancingInterceptor)i).getName().equalsIgnoreCase(name)) - return (LoadBalancingInterceptor) i; - } - throw new RuntimeException("balancer with name \"" + name + "\" not found."); + return collectBalancers(router).stream() + .filter(lbi -> lbi.getName() != null && lbi.getName().equalsIgnoreCase(name)) + .findFirst() + .orElseThrow(() -> new RuntimeException("balancer with name %s not found.".formatted(name))); } public static boolean hasLoadBalancing(Router router) { - for (Proxy r : router.getRuleManager().getRules()) { - List interceptors = r.getFlow(); - if (interceptors == null) - continue; - for (Interceptor i : interceptors) - if (i instanceof LoadBalancingInterceptor) - return true; - } - return false; + return !collectBalancers(router).isEmpty(); } public static void up(Router router, String balancerName, String cName, String host, int port) { @@ -122,8 +172,8 @@ public static List getSessionsByNode(Router router, String balancerName return lookupBalancer(router, balancerName).getSessionsByNode(cName, node); } - public static String getSingleClusterNameOrDefault(Balancer balancer){ - if(balancer.getClusters().size() == 1) + public static String getSingleClusterNameOrDefault(Balancer balancer) { + if (balancer.getClusters().size() == 1) return balancer.getClusters().getFirst().getName(); return Cluster.DEFAULT_NAME; }