diff --git a/formatter.xml b/formatter.xml index b17fa783..547869d3 100644 --- a/formatter.xml +++ b/formatter.xml @@ -13,8 +13,27 @@ - + + + + + + + + + + + + + + + + + + + + @@ -33,7 +52,6 @@ - @@ -45,7 +63,6 @@ - @@ -126,17 +143,7 @@ - - - - - - - - - - diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java index f66ed720..7bc9553f 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java @@ -12,9 +12,9 @@ public abstract class ArtifactProvider { public static ArtifactProvider createArtifactProvider(Configuration.ModuleConfiguration configuration) { return switch (configuration.name()) { - case "text" -> new TextArtifactProvider(configuration); - case "recursive_text" -> new RecursiveTextArtifactProvider(configuration); - default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); + case "text" -> new TextArtifactProvider(configuration); + case "recursive_text" -> new RecursiveTextArtifactProvider(configuration); + default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); }; } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java index 3cbe65c4..f5a712d1 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java @@ -65,6 +65,7 @@ public synchronized void put(String key, String value) { } } + @SuppressWarnings("unchecked") public synchronized T get(String key, Class clazz) { try { var jsonData = this.data.get(key); diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java new file mode 100644 index 00000000..74349b65 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java @@ -0,0 +1,71 @@ +package edu.kit.kastel.sdq.lissa.ratlr.classifier; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.ollama.OllamaChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import edu.kit.kastel.sdq.lissa.ratlr.Configuration; +import edu.kit.kastel.sdq.lissa.ratlr.Environment; +import okhttp3.Credentials; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; + +public class ChatLanguageModelProvider { + public static final String OPENAI = "openai"; + public static final String OLLAMA = "ollama"; + + private final String platform; + private String model; + + public ChatLanguageModelProvider(Configuration.ModuleConfiguration configuration) { + String[] modeXplatform = configuration.name().split(Classifier.CONFIG_NAME_SEPARATOR, 2); + if (modeXplatform.length == 1) { + this.platform = null; + return; + } + this.platform = modeXplatform[1]; + initModelPlatform(configuration); + } + + public ChatLanguageModel createChatModel() { + return switch (platform) { + case OPENAI -> createOpenAiChatModel(model); + case OLLAMA -> createOllamaChatModel(model); + default -> throw new IllegalArgumentException("Unsupported platform: " + platform); + }; + } + + private void initModelPlatform(Configuration.ModuleConfiguration configuration) { + this.model = switch (platform) { + case OPENAI -> configuration.argumentAsString("model", "gpt-4o-mini"); + case OLLAMA -> configuration.argumentAsString("model", "llama3:8b"); + default -> throw new IllegalArgumentException("Unsupported platform: " + platform); + }; + } + + public String modelName() { + return Objects.requireNonNull(model, "Model not initialized"); + } + + private static OllamaChatModel createOllamaChatModel(String model) { + String host = Environment.getenv("OLLAMA_HOST"); + String user = Environment.getenv("OLLAMA_USER"); + String password = Environment.getenv("OLLAMA_PASSWORD"); + + var ollama = OllamaChatModel.builder().baseUrl(host).modelName(model).timeout(Duration.ofMinutes(5)).temperature(0.0); + if (user != null && password != null && !user.isEmpty() && !password.isEmpty()) { + ollama.customHeaders(Map.of("Authorization", Credentials.basic(user, password))); + } + return ollama.build(); + } + + private static OpenAiChatModel createOpenAiChatModel(String model) { + String openAiOrganizationId = Environment.getenv("OPENAI_ORGANIZATION_ID"); + String openAiApiKey = Environment.getenv("OPENAI_API_KEY"); + if (openAiOrganizationId == null || openAiApiKey == null) { + throw new IllegalStateException("OPENAI_ORGANIZATION_ID or OPENAI_API_KEY environment variable not set"); + } + return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).organizationId(openAiOrganizationId).apiKey(openAiApiKey).temperature(0.0).build(); + } +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatModels.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatModels.java deleted file mode 100644 index 538ab4b5..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatModels.java +++ /dev/null @@ -1,32 +0,0 @@ -package edu.kit.kastel.sdq.lissa.ratlr.classifier; - -import dev.langchain4j.model.ollama.OllamaChatModel; -import dev.langchain4j.model.openai.OpenAiChatModel; -import edu.kit.kastel.sdq.lissa.ratlr.Environment; -import okhttp3.Credentials; - -import java.time.Duration; -import java.util.Map; - -public final class ChatModels { - public static OllamaChatModel createOllamaChatModel(String model) { - String host = Environment.getenv("OLLAMA_HOST"); - String user = Environment.getenv("OLLAMA_USER"); - String password = Environment.getenv("OLLAMA_PASSWORD"); - - var ollama = OllamaChatModel.builder().baseUrl(host).modelName(model).timeout(Duration.ofMinutes(5)).temperature(0.0); - if (user != null && password != null && !user.isEmpty() && !password.isEmpty()) { - ollama.customHeaders(Map.of("Authorization", Credentials.basic(user, password))); - } - return ollama.build(); - } - - public static OpenAiChatModel createOpenAiChatModel(String model) { - String openAiOrganizationId = Environment.getenv("OPENAI_ORGANIZATION_ID"); - String openAiApiKey = Environment.getenv("OPENAI_API_KEY"); - if (openAiOrganizationId == null || openAiApiKey == null) { - throw new IllegalStateException("OPENAI_ORGANIZATION_ID or OPENAI_API_KEY environment variable not set"); - } - return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).organizationId(openAiOrganizationId).apiKey(openAiApiKey).temperature(0.0).build(); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ClassificationResult.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ClassificationResult.java new file mode 100644 index 00000000..79be4562 --- /dev/null +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ClassificationResult.java @@ -0,0 +1,19 @@ +package edu.kit.kastel.sdq.lissa.ratlr.classifier; + +import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element; + +public record ClassificationResult(Element source, Element target, double confidence) { + public ClassificationResult { + if (confidence < 0 || confidence > 1) { + throw new IllegalArgumentException("Confidence must be between 0 and 1"); + } + } + + public static ClassificationResult of(Element source, Element target) { + return new ClassificationResult(source, target, 1.0); + } + + public static ClassificationResult of(Element source, Element target, double confidence) { + return new ClassificationResult(source, target, confidence); + } +} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java index 16d25b5c..15696167 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java @@ -7,6 +7,7 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -15,11 +16,12 @@ public abstract class Classifier { private static final int THREADS = 100; + static final String CONFIG_NAME_SEPARATOR = "_"; protected final Logger logger = LoggerFactory.getLogger(this.getClass()); public List classify(ElementStore sourceStore, ElementStore targetStore) { - List> futureResults = new ArrayList<>(); + List>> futureResults = new ArrayList<>(); ExecutorService executor = Executors.newFixedThreadPool(THREADS); for (var query : sourceStore.getAllElements(true)) { var targetCandidates = targetStore.findSimilar(query.second()); @@ -36,24 +38,20 @@ public List classify(ElementStore sourceStore, ElementStor throw new IllegalStateException(e); } - return futureResults.stream().map(Future::resultNow).toList(); + return futureResults.stream().map(Future::resultNow).flatMap(Collection::stream).toList(); } - protected abstract ClassificationResult classify(Element source, List targets); + protected abstract List classify(Element source, List targets); protected abstract Classifier copyOf(); public static Classifier createClassifier(Configuration.ModuleConfiguration configuration) { - return switch (configuration.name()) { - case "mock" -> new MockClassifier(); - case "simple_ollama" -> new SimpleOllamaClassifier(configuration); - case "simple_openai" -> new SimpleOpenAiClassifier(configuration); - case "reasoning_ollama" -> new ReasoningOllamaClassifier(configuration); - case "reasoning_openai" -> new ReasoningOpenAiClassifier(configuration); - default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); + return switch (configuration.name().split(CONFIG_NAME_SEPARATOR)[0]) { + case "mock" -> new MockClassifier(); + case "simple" -> new SimpleClassifier(configuration); + case "reasoning" -> new ReasoningClassifier(configuration); + default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); }; } - public record ClassificationResult(Element source, List targets) { - } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java index d0ae730f..3efcb612 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java @@ -6,8 +6,8 @@ public class MockClassifier extends Classifier { @Override - protected ClassificationResult classify(Element source, List targets) { - return new ClassificationResult(source, targets); + protected List classify(Element source, List targets) { + return targets.stream().map(target -> ClassificationResult.of(source, target)).toList(); } @Override diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java index ca487525..26c6847e 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java @@ -18,44 +18,40 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -public abstract class ReasoningClassifier extends Classifier { +public class ReasoningClassifier extends Classifier { private final Cache cache; + private final ChatLanguageModelProvider provider; + private final ChatLanguageModel llm; private final String prompt; private final boolean useOriginalArtifacts; private final boolean useSystemMessage; - private final String model; - protected ReasoningClassifier(Configuration.ModuleConfiguration configuration, String model) { - this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + model); + public ReasoningClassifier(Configuration.ModuleConfiguration configuration) { + this.provider = new ChatLanguageModelProvider(configuration); + this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + provider.modelName()); this.prompt = Prompt.values()[configuration.argumentAsInt("prompt_id", 0)].prompt; this.useOriginalArtifacts = configuration.argumentAsBoolean("use_original_artifacts", false); this.useSystemMessage = configuration.argumentAsBoolean("use_system_message", true); - this.model = model; - this.llm = createChatModel(model); + this.llm = this.provider.createChatModel(); } - protected ReasoningClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, boolean useSystemMessage) { + private ReasoningClassifier(Cache cache, ChatLanguageModelProvider provider, String prompt, boolean useOriginalArtifacts, boolean useSystemMessage) { this.cache = cache; - this.model = model; - this.llm = llm; + this.provider = provider; this.prompt = prompt; this.useOriginalArtifacts = useOriginalArtifacts; this.useSystemMessage = useSystemMessage; + this.llm = this.provider.createChatModel(); } @Override protected final Classifier copyOf() { - return copyOf(cache, model, createChatModel(model), prompt, useOriginalArtifacts, useSystemMessage); + return new ReasoningClassifier(cache, provider, prompt, useOriginalArtifacts, useSystemMessage); } - protected abstract ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, - boolean useSystemMessage); - - protected abstract ChatLanguageModel createChatModel(String model); - @Override - protected final ClassificationResult classify(Element source, List targets) { + protected final List classify(Element source, List targets) { List relatedTargets = new ArrayList<>(); var targetsToConsider = targets; @@ -87,7 +83,7 @@ protected final ClassificationResult classify(Element source, List targ relatedTargets.add(target); } } - return new ClassificationResult(source, relatedTargets); + return relatedTargets.stream().map(relatedTarget -> ClassificationResult.of(source, relatedTarget)).toList(); } private boolean isRelated(String llmResponse) { diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOllamaClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOllamaClassifier.java deleted file mode 100644 index d73f2f43..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOllamaClassifier.java +++ /dev/null @@ -1,27 +0,0 @@ -package edu.kit.kastel.sdq.lissa.ratlr.classifier; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import edu.kit.kastel.sdq.lissa.ratlr.Configuration; -import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; - -public class ReasoningOllamaClassifier extends ReasoningClassifier { - public ReasoningOllamaClassifier(Configuration.ModuleConfiguration configuration) { - super(configuration, configuration.argumentAsString("model", "llama3:8b")); - } - - protected ReasoningOllamaClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, - boolean useSystemMessage) { - super(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage); - } - - @Override - protected ChatLanguageModel createChatModel(String model) { - return ChatModels.createOllamaChatModel(model); - } - - @Override - protected ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, - boolean useSystemMessage) { - return new ReasoningOllamaClassifier(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOpenAiClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOpenAiClassifier.java deleted file mode 100644 index be89c4a2..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOpenAiClassifier.java +++ /dev/null @@ -1,27 +0,0 @@ -package edu.kit.kastel.sdq.lissa.ratlr.classifier; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import edu.kit.kastel.sdq.lissa.ratlr.Configuration; -import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; - -public class ReasoningOpenAiClassifier extends ReasoningClassifier { - public ReasoningOpenAiClassifier(Configuration.ModuleConfiguration configuration) { - super(configuration, configuration.argumentAsString("model", "gpt-4o-mini")); - } - - protected ReasoningOpenAiClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, - boolean useSystemMessage) { - super(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage); - } - - @Override - protected ChatLanguageModel createChatModel(String model) { - return ChatModels.createOpenAiChatModel(model); - } - - @Override - protected ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, - boolean useSystemMessage) { - return new ReasoningOpenAiClassifier(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java index eb90af2c..549e9133 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java @@ -10,7 +10,7 @@ import java.util.List; import java.util.UUID; -public abstract class SimpleClassifier extends Classifier { +public class SimpleClassifier extends Classifier { private static final String DEFAULT_TEMPLATE = """ Question: Here are two parts of software development artifacts. \n @@ -21,35 +21,32 @@ public abstract class SimpleClassifier extends Classifier { """; private final Cache cache; + private final ChatLanguageModelProvider provider; + private final ChatLanguageModel llm; private final String template; - private final String model; - protected SimpleClassifier(Configuration.ModuleConfiguration configuration, String model) { + public SimpleClassifier(Configuration.ModuleConfiguration configuration) { + this.provider = new ChatLanguageModelProvider(configuration); this.template = configuration.argumentAsString("template", DEFAULT_TEMPLATE); - this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + model); - this.model = model; - this.llm = createChatModel(model); + this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + provider.modelName()); + this.llm = provider.createChatModel(); } - protected SimpleClassifier(Cache cache, String model, ChatLanguageModel llm, String template) { + private SimpleClassifier(Cache cache, ChatLanguageModelProvider provider, String template) { this.cache = cache; - this.model = model; - this.llm = llm; + this.provider = provider; this.template = template; + this.llm = provider.createChatModel(); } @Override protected final Classifier copyOf() { - return copyOf(cache, model, createChatModel(model), template); + return new SimpleClassifier(cache, provider, template); } - protected abstract SimpleClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String template); - - protected abstract ChatLanguageModel createChatModel(String model); - @Override - protected final ClassificationResult classify(Element source, List targets) { + protected final List classify(Element source, List targets) { List relatedTargets = new ArrayList<>(); for (var target : targets) { @@ -59,7 +56,7 @@ protected final ClassificationResult classify(Element source, List targ relatedTargets.add(target); } } - return new ClassificationResult(source, relatedTargets); + return relatedTargets.stream().map(relatedTarget -> ClassificationResult.of(source, relatedTarget)).toList(); } private String classify(Element source, Element target) { diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOllamaClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOllamaClassifier.java deleted file mode 100644 index bbff00b9..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOllamaClassifier.java +++ /dev/null @@ -1,26 +0,0 @@ -package edu.kit.kastel.sdq.lissa.ratlr.classifier; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import edu.kit.kastel.sdq.lissa.ratlr.Configuration; -import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; - -public class SimpleOllamaClassifier extends SimpleClassifier { - - public SimpleOllamaClassifier(Configuration.ModuleConfiguration configuration) { - super(configuration, configuration.argumentAsString("model", "llama3:8b")); - } - - public SimpleOllamaClassifier(Cache cache, String model, ChatLanguageModel llm, String template) { - super(cache, model, llm, template); - } - - @Override - protected ChatLanguageModel createChatModel(String model) { - return ChatModels.createOllamaChatModel(model); - } - - @Override - protected SimpleClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String template) { - return new SimpleOllamaClassifier(cache, model, llm, template); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOpenAiClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOpenAiClassifier.java deleted file mode 100644 index 6bbf8fad..00000000 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOpenAiClassifier.java +++ /dev/null @@ -1,25 +0,0 @@ -package edu.kit.kastel.sdq.lissa.ratlr.classifier; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import edu.kit.kastel.sdq.lissa.ratlr.Configuration; -import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache; - -public class SimpleOpenAiClassifier extends SimpleClassifier { - public SimpleOpenAiClassifier(Configuration.ModuleConfiguration configuration) { - super(configuration, configuration.argumentAsString("model", "gpt-4o-mini")); - } - - public SimpleOpenAiClassifier(Cache cache, String model, ChatLanguageModel llm, String template) { - super(cache, model, llm, template); - } - - @Override - protected ChatLanguageModel createChatModel(String model) { - return ChatModels.createOpenAiChatModel(model); - } - - @Override - protected SimpleClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String template) { - return new SimpleOpenAiClassifier(cache, model, llm, template); - } -} diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java index 0e36063f..cca554d7 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java @@ -14,10 +14,10 @@ public float[] calculateEmbedding(Element element) { public static EmbeddingCreator createEmbeddingCreator(Configuration.ModuleConfiguration configuration) { return switch (configuration.name()) { - case "ollama" -> new OllamaEmbeddingCreator(configuration); - case "openai" -> new OpenAiEmbeddingCreator(configuration); - case "onnx" -> new OnnxEmbeddingCreator(configuration); - default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); + case "ollama" -> new OllamaEmbeddingCreator(configuration); + case "openai" -> new OpenAiEmbeddingCreator(configuration); + case "onnx" -> new OnnxEmbeddingCreator(configuration); + default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); }; } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java index 5dddad1d..7e911560 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java @@ -8,12 +8,12 @@ public abstract class TraceLinkIdPostprocessor { public static TraceLinkIdPostprocessor createTraceLinkIdPostprocessor(Configuration.ModuleConfiguration moduleConfiguration) { return switch (moduleConfiguration.name()) { - case "req2code" -> new ReqCodePostprocessor(); - case "sad2code" -> new SadCodePostprocessor(); + case "req2code" -> new ReqCodePostprocessor(); + case "sad2code" -> new SadCodePostprocessor(); - case "identity" -> new IdentityPostprocessor(); - case null -> new IdentityPostprocessor(); - default -> throw new IllegalStateException("Unexpected value: " + moduleConfiguration.name()); + case "identity" -> new IdentityPostprocessor(); + case null -> new IdentityPostprocessor(); + default -> throw new IllegalStateException("Unexpected value: " + moduleConfiguration.name()); }; } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java index da6f32b5..e6fbdb76 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java @@ -42,7 +42,7 @@ protected List preprocess(Artifact artifact) { elements.add(artifactAsElement); var newElements = switch (language) { - case JAVA -> splitJava(artifactAsElement); + case JAVA -> splitJava(artifactAsElement); }; elements.addAll(newElements); diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java index 0ea8607c..bcb43d8b 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java @@ -17,7 +17,7 @@ public CodeTreePreprocessor(Configuration.ModuleConfiguration configuration) { @Override public List preprocess(List artifacts) { return switch (language) { - case JAVA -> createJavaTree(artifacts); + case JAVA -> createJavaTree(artifacts); }; } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java index 374f5738..c8570897 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java @@ -13,13 +13,13 @@ public abstract class Preprocessor { public static Preprocessor createPreprocessor(Configuration.ModuleConfiguration configuration) { return switch (configuration.name()) { - case "sentence" -> new SentencePreprocessor(configuration); - case "code_chunking" -> new CodeChunkingPreprocessor(configuration); - case "code_method" -> new CodeMethodPreprocessor(configuration); - case "code_tree" -> new CodeTreePreprocessor(configuration); - case "model_uml" -> new ModelUMLPreprocessor(configuration); - case "artifact" -> new SingleArtifactPreprocessor(); - default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); + case "sentence" -> new SentencePreprocessor(configuration); + case "code_chunking" -> new CodeChunkingPreprocessor(configuration); + case "code_method" -> new CodeMethodPreprocessor(configuration); + case "code_tree" -> new CodeTreePreprocessor(configuration); + case "model_uml" -> new ModelUMLPreprocessor(configuration); + case "artifact" -> new SingleArtifactPreprocessor(); + default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); }; } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java index 11fd9c93..d31b64ad 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java @@ -155,10 +155,10 @@ private String joinDocs(List docs, String separator) { public static List getSeparatorsForLanguage(Language language) { // Taken from LangChain (Python) return switch (language) { - case JAVA -> List.of("\nclass ", "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\n\n", - "\n", " ", ""); - // Add other languages as needed... - default -> throw new IllegalArgumentException("Unsupported language: " + language); + case JAVA -> List.of("\nclass ", "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", + "\n\n", "\n", " ", ""); + // Add other languages as needed... + default -> throw new IllegalArgumentException("Unsupported language: " + language); }; } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java index 8c555955..a623ad94 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java @@ -1,7 +1,7 @@ package edu.kit.kastel.sdq.lissa.ratlr.resultaggregator; import edu.kit.kastel.sdq.lissa.ratlr.Configuration; -import edu.kit.kastel.sdq.lissa.ratlr.classifier.Classifier; +import edu.kit.kastel.sdq.lissa.ratlr.classifier.ClassificationResult; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.TraceLink; @@ -19,16 +19,15 @@ public AnyResultAggregator(Configuration.ModuleConfiguration configuration) { } @Override - public Set aggregate(List sourceElements, List targetElements, List classificationResults) { + public Set aggregate(List sourceElements, List targetElements, List classificationResults) { Set traceLinks = new LinkedHashSet<>(); for (var result : classificationResults) { var sourceElementsForTraceLink = buildListOfValidElements(result.source(), sourceGranularity, sourceElements); - for (var target : result.targets()) { - var targetElementsForTraceLink = buildListOfValidElements(target, targetGranularity, targetElements); - for (var sourceElement : sourceElementsForTraceLink) { - for (var targetElement : targetElementsForTraceLink) { - traceLinks.add(new TraceLink(sourceElement.getIdentifier(), targetElement.getIdentifier())); - } + var target = result.target(); + var targetElementsForTraceLink = buildListOfValidElements(target, targetGranularity, targetElements); + for (var sourceElement : sourceElementsForTraceLink) { + for (var targetElement : targetElementsForTraceLink) { + traceLinks.add(new TraceLink(sourceElement.getIdentifier(), targetElement.getIdentifier())); } } } diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java index 42d9fc3c..9f6f417c 100644 --- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java +++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java @@ -1,7 +1,7 @@ package edu.kit.kastel.sdq.lissa.ratlr.resultaggregator; import edu.kit.kastel.sdq.lissa.ratlr.Configuration; -import edu.kit.kastel.sdq.lissa.ratlr.classifier.Classifier; +import edu.kit.kastel.sdq.lissa.ratlr.classifier.ClassificationResult; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element; import edu.kit.kastel.sdq.lissa.ratlr.knowledge.TraceLink; @@ -9,13 +9,12 @@ import java.util.Set; public abstract class ResultAggregator { - public abstract Set aggregate(List sourceElements, List targetElements, - List classificationResults); + public abstract Set aggregate(List sourceElements, List targetElements, List classificationResults); public static ResultAggregator createResultAggregator(Configuration.ModuleConfiguration configuration) { return switch (configuration.name()) { - case "any_connection" -> new AnyResultAggregator(configuration); - default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); + case "any_connection" -> new AnyResultAggregator(configuration); + default -> throw new IllegalStateException("Unexpected value: " + configuration.name()); }; } }