diff --git a/formatter.xml b/formatter.xml
index b17fa783..547869d3 100644
--- a/formatter.xml
+++ b/formatter.xml
@@ -13,8 +13,27 @@
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -33,7 +52,6 @@
-
@@ -45,7 +63,6 @@
-
@@ -126,17 +143,7 @@
-
-
-
-
-
-
-
-
-
-
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java
index f66ed720..7bc9553f 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/artifactprovider/ArtifactProvider.java
@@ -12,9 +12,9 @@ public abstract class ArtifactProvider {
public static ArtifactProvider createArtifactProvider(Configuration.ModuleConfiguration configuration) {
return switch (configuration.name()) {
- case "text" -> new TextArtifactProvider(configuration);
- case "recursive_text" -> new RecursiveTextArtifactProvider(configuration);
- default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
+ case "text" -> new TextArtifactProvider(configuration);
+ case "recursive_text" -> new RecursiveTextArtifactProvider(configuration);
+ default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java
index 3cbe65c4..f5a712d1 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/cache/Cache.java
@@ -65,6 +65,7 @@ public synchronized void put(String key, String value) {
}
}
+ @SuppressWarnings("unchecked")
public synchronized T get(String key, Class clazz) {
try {
var jsonData = this.data.get(key);
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java
new file mode 100644
index 00000000..74349b65
--- /dev/null
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java
@@ -0,0 +1,71 @@
+package edu.kit.kastel.sdq.lissa.ratlr.classifier;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.ollama.OllamaChatModel;
+import dev.langchain4j.model.openai.OpenAiChatModel;
+import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
+import edu.kit.kastel.sdq.lissa.ratlr.Environment;
+import okhttp3.Credentials;
+
+import java.time.Duration;
+import java.util.Map;
+import java.util.Objects;
+
+public class ChatLanguageModelProvider {
+ public static final String OPENAI = "openai";
+ public static final String OLLAMA = "ollama";
+
+ private final String platform;
+ private String model;
+
+ public ChatLanguageModelProvider(Configuration.ModuleConfiguration configuration) {
+ String[] modeXplatform = configuration.name().split(Classifier.CONFIG_NAME_SEPARATOR, 2);
+ if (modeXplatform.length == 1) {
+ this.platform = null;
+ return;
+ }
+ this.platform = modeXplatform[1];
+ initModelPlatform(configuration);
+ }
+
+ public ChatLanguageModel createChatModel() {
+ return switch (platform) {
+ case OPENAI -> createOpenAiChatModel(model);
+ case OLLAMA -> createOllamaChatModel(model);
+ default -> throw new IllegalArgumentException("Unsupported platform: " + platform);
+ };
+ }
+
+ private void initModelPlatform(Configuration.ModuleConfiguration configuration) {
+ this.model = switch (platform) {
+ case OPENAI -> configuration.argumentAsString("model", "gpt-4o-mini");
+ case OLLAMA -> configuration.argumentAsString("model", "llama3:8b");
+ default -> throw new IllegalArgumentException("Unsupported platform: " + platform);
+ };
+ }
+
+ public String modelName() {
+ return Objects.requireNonNull(model, "Model not initialized");
+ }
+
+ private static OllamaChatModel createOllamaChatModel(String model) {
+ String host = Environment.getenv("OLLAMA_HOST");
+ String user = Environment.getenv("OLLAMA_USER");
+ String password = Environment.getenv("OLLAMA_PASSWORD");
+
+ var ollama = OllamaChatModel.builder().baseUrl(host).modelName(model).timeout(Duration.ofMinutes(5)).temperature(0.0);
+ if (user != null && password != null && !user.isEmpty() && !password.isEmpty()) {
+ ollama.customHeaders(Map.of("Authorization", Credentials.basic(user, password)));
+ }
+ return ollama.build();
+ }
+
+ private static OpenAiChatModel createOpenAiChatModel(String model) {
+ String openAiOrganizationId = Environment.getenv("OPENAI_ORGANIZATION_ID");
+ String openAiApiKey = Environment.getenv("OPENAI_API_KEY");
+ if (openAiOrganizationId == null || openAiApiKey == null) {
+ throw new IllegalStateException("OPENAI_ORGANIZATION_ID or OPENAI_API_KEY environment variable not set");
+ }
+ return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).organizationId(openAiOrganizationId).apiKey(openAiApiKey).temperature(0.0).build();
+ }
+}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatModels.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatModels.java
deleted file mode 100644
index 538ab4b5..00000000
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatModels.java
+++ /dev/null
@@ -1,32 +0,0 @@
-package edu.kit.kastel.sdq.lissa.ratlr.classifier;
-
-import dev.langchain4j.model.ollama.OllamaChatModel;
-import dev.langchain4j.model.openai.OpenAiChatModel;
-import edu.kit.kastel.sdq.lissa.ratlr.Environment;
-import okhttp3.Credentials;
-
-import java.time.Duration;
-import java.util.Map;
-
-public final class ChatModels {
- public static OllamaChatModel createOllamaChatModel(String model) {
- String host = Environment.getenv("OLLAMA_HOST");
- String user = Environment.getenv("OLLAMA_USER");
- String password = Environment.getenv("OLLAMA_PASSWORD");
-
- var ollama = OllamaChatModel.builder().baseUrl(host).modelName(model).timeout(Duration.ofMinutes(5)).temperature(0.0);
- if (user != null && password != null && !user.isEmpty() && !password.isEmpty()) {
- ollama.customHeaders(Map.of("Authorization", Credentials.basic(user, password)));
- }
- return ollama.build();
- }
-
- public static OpenAiChatModel createOpenAiChatModel(String model) {
- String openAiOrganizationId = Environment.getenv("OPENAI_ORGANIZATION_ID");
- String openAiApiKey = Environment.getenv("OPENAI_API_KEY");
- if (openAiOrganizationId == null || openAiApiKey == null) {
- throw new IllegalStateException("OPENAI_ORGANIZATION_ID or OPENAI_API_KEY environment variable not set");
- }
- return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).organizationId(openAiOrganizationId).apiKey(openAiApiKey).temperature(0.0).build();
- }
-}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ClassificationResult.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ClassificationResult.java
new file mode 100644
index 00000000..79be4562
--- /dev/null
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ClassificationResult.java
@@ -0,0 +1,19 @@
+package edu.kit.kastel.sdq.lissa.ratlr.classifier;
+
+import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element;
+
+public record ClassificationResult(Element source, Element target, double confidence) {
+ public ClassificationResult {
+ if (confidence < 0 || confidence > 1) {
+ throw new IllegalArgumentException("Confidence must be between 0 and 1");
+ }
+ }
+
+ public static ClassificationResult of(Element source, Element target) {
+ return new ClassificationResult(source, target, 1.0);
+ }
+
+ public static ClassificationResult of(Element source, Element target, double confidence) {
+ return new ClassificationResult(source, target, confidence);
+ }
+}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java
index 16d25b5c..15696167 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/Classifier.java
@@ -7,6 +7,7 @@
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -15,11 +16,12 @@
public abstract class Classifier {
private static final int THREADS = 100;
+ static final String CONFIG_NAME_SEPARATOR = "_";
protected final Logger logger = LoggerFactory.getLogger(this.getClass());
public List classify(ElementStore sourceStore, ElementStore targetStore) {
- List> futureResults = new ArrayList<>();
+ List>> futureResults = new ArrayList<>();
ExecutorService executor = Executors.newFixedThreadPool(THREADS);
for (var query : sourceStore.getAllElements(true)) {
var targetCandidates = targetStore.findSimilar(query.second());
@@ -36,24 +38,20 @@ public List classify(ElementStore sourceStore, ElementStor
throw new IllegalStateException(e);
}
- return futureResults.stream().map(Future::resultNow).toList();
+ return futureResults.stream().map(Future::resultNow).flatMap(Collection::stream).toList();
}
- protected abstract ClassificationResult classify(Element source, List targets);
+ protected abstract List classify(Element source, List targets);
protected abstract Classifier copyOf();
public static Classifier createClassifier(Configuration.ModuleConfiguration configuration) {
- return switch (configuration.name()) {
- case "mock" -> new MockClassifier();
- case "simple_ollama" -> new SimpleOllamaClassifier(configuration);
- case "simple_openai" -> new SimpleOpenAiClassifier(configuration);
- case "reasoning_ollama" -> new ReasoningOllamaClassifier(configuration);
- case "reasoning_openai" -> new ReasoningOpenAiClassifier(configuration);
- default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
+ return switch (configuration.name().split(CONFIG_NAME_SEPARATOR)[0]) {
+ case "mock" -> new MockClassifier();
+ case "simple" -> new SimpleClassifier(configuration);
+ case "reasoning" -> new ReasoningClassifier(configuration);
+ default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}
- public record ClassificationResult(Element source, List targets) {
- }
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java
index d0ae730f..3efcb612 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/MockClassifier.java
@@ -6,8 +6,8 @@
public class MockClassifier extends Classifier {
@Override
- protected ClassificationResult classify(Element source, List targets) {
- return new ClassificationResult(source, targets);
+ protected List classify(Element source, List targets) {
+ return targets.stream().map(target -> ClassificationResult.of(source, target)).toList();
}
@Override
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java
index ca487525..26c6847e 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningClassifier.java
@@ -18,44 +18,40 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
-public abstract class ReasoningClassifier extends Classifier {
+public class ReasoningClassifier extends Classifier {
private final Cache cache;
+ private final ChatLanguageModelProvider provider;
+
private final ChatLanguageModel llm;
private final String prompt;
private final boolean useOriginalArtifacts;
private final boolean useSystemMessage;
- private final String model;
- protected ReasoningClassifier(Configuration.ModuleConfiguration configuration, String model) {
- this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + model);
+ public ReasoningClassifier(Configuration.ModuleConfiguration configuration) {
+ this.provider = new ChatLanguageModelProvider(configuration);
+ this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + provider.modelName());
this.prompt = Prompt.values()[configuration.argumentAsInt("prompt_id", 0)].prompt;
this.useOriginalArtifacts = configuration.argumentAsBoolean("use_original_artifacts", false);
this.useSystemMessage = configuration.argumentAsBoolean("use_system_message", true);
- this.model = model;
- this.llm = createChatModel(model);
+ this.llm = this.provider.createChatModel();
}
- protected ReasoningClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, boolean useSystemMessage) {
+ private ReasoningClassifier(Cache cache, ChatLanguageModelProvider provider, String prompt, boolean useOriginalArtifacts, boolean useSystemMessage) {
this.cache = cache;
- this.model = model;
- this.llm = llm;
+ this.provider = provider;
this.prompt = prompt;
this.useOriginalArtifacts = useOriginalArtifacts;
this.useSystemMessage = useSystemMessage;
+ this.llm = this.provider.createChatModel();
}
@Override
protected final Classifier copyOf() {
- return copyOf(cache, model, createChatModel(model), prompt, useOriginalArtifacts, useSystemMessage);
+ return new ReasoningClassifier(cache, provider, prompt, useOriginalArtifacts, useSystemMessage);
}
- protected abstract ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts,
- boolean useSystemMessage);
-
- protected abstract ChatLanguageModel createChatModel(String model);
-
@Override
- protected final ClassificationResult classify(Element source, List targets) {
+ protected final List classify(Element source, List targets) {
List relatedTargets = new ArrayList<>();
var targetsToConsider = targets;
@@ -87,7 +83,7 @@ protected final ClassificationResult classify(Element source, List targ
relatedTargets.add(target);
}
}
- return new ClassificationResult(source, relatedTargets);
+ return relatedTargets.stream().map(relatedTarget -> ClassificationResult.of(source, relatedTarget)).toList();
}
private boolean isRelated(String llmResponse) {
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOllamaClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOllamaClassifier.java
deleted file mode 100644
index d73f2f43..00000000
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOllamaClassifier.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package edu.kit.kastel.sdq.lissa.ratlr.classifier;
-
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
-import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache;
-
-public class ReasoningOllamaClassifier extends ReasoningClassifier {
- public ReasoningOllamaClassifier(Configuration.ModuleConfiguration configuration) {
- super(configuration, configuration.argumentAsString("model", "llama3:8b"));
- }
-
- protected ReasoningOllamaClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts,
- boolean useSystemMessage) {
- super(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage);
- }
-
- @Override
- protected ChatLanguageModel createChatModel(String model) {
- return ChatModels.createOllamaChatModel(model);
- }
-
- @Override
- protected ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts,
- boolean useSystemMessage) {
- return new ReasoningOllamaClassifier(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage);
- }
-}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOpenAiClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOpenAiClassifier.java
deleted file mode 100644
index be89c4a2..00000000
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ReasoningOpenAiClassifier.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package edu.kit.kastel.sdq.lissa.ratlr.classifier;
-
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
-import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache;
-
-public class ReasoningOpenAiClassifier extends ReasoningClassifier {
- public ReasoningOpenAiClassifier(Configuration.ModuleConfiguration configuration) {
- super(configuration, configuration.argumentAsString("model", "gpt-4o-mini"));
- }
-
- protected ReasoningOpenAiClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts,
- boolean useSystemMessage) {
- super(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage);
- }
-
- @Override
- protected ChatLanguageModel createChatModel(String model) {
- return ChatModels.createOpenAiChatModel(model);
- }
-
- @Override
- protected ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts,
- boolean useSystemMessage) {
- return new ReasoningOpenAiClassifier(cache, model, llm, prompt, useOriginalArtifacts, useSystemMessage);
- }
-}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java
index eb90af2c..549e9133 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java
@@ -10,7 +10,7 @@
import java.util.List;
import java.util.UUID;
-public abstract class SimpleClassifier extends Classifier {
+public class SimpleClassifier extends Classifier {
private static final String DEFAULT_TEMPLATE = """
Question: Here are two parts of software development artifacts. \n
@@ -21,35 +21,32 @@ public abstract class SimpleClassifier extends Classifier {
""";
private final Cache cache;
+ private final ChatLanguageModelProvider provider;
+
private final ChatLanguageModel llm;
private final String template;
- private final String model;
- protected SimpleClassifier(Configuration.ModuleConfiguration configuration, String model) {
+ public SimpleClassifier(Configuration.ModuleConfiguration configuration) {
+ this.provider = new ChatLanguageModelProvider(configuration);
this.template = configuration.argumentAsString("template", DEFAULT_TEMPLATE);
- this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + model);
- this.model = model;
- this.llm = createChatModel(model);
+ this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + provider.modelName());
+ this.llm = provider.createChatModel();
}
- protected SimpleClassifier(Cache cache, String model, ChatLanguageModel llm, String template) {
+ private SimpleClassifier(Cache cache, ChatLanguageModelProvider provider, String template) {
this.cache = cache;
- this.model = model;
- this.llm = llm;
+ this.provider = provider;
this.template = template;
+ this.llm = provider.createChatModel();
}
@Override
protected final Classifier copyOf() {
- return copyOf(cache, model, createChatModel(model), template);
+ return new SimpleClassifier(cache, provider, template);
}
- protected abstract SimpleClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String template);
-
- protected abstract ChatLanguageModel createChatModel(String model);
-
@Override
- protected final ClassificationResult classify(Element source, List targets) {
+ protected final List classify(Element source, List targets) {
List relatedTargets = new ArrayList<>();
for (var target : targets) {
@@ -59,7 +56,7 @@ protected final ClassificationResult classify(Element source, List targ
relatedTargets.add(target);
}
}
- return new ClassificationResult(source, relatedTargets);
+ return relatedTargets.stream().map(relatedTarget -> ClassificationResult.of(source, relatedTarget)).toList();
}
private String classify(Element source, Element target) {
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOllamaClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOllamaClassifier.java
deleted file mode 100644
index bbff00b9..00000000
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOllamaClassifier.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package edu.kit.kastel.sdq.lissa.ratlr.classifier;
-
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
-import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache;
-
-public class SimpleOllamaClassifier extends SimpleClassifier {
-
- public SimpleOllamaClassifier(Configuration.ModuleConfiguration configuration) {
- super(configuration, configuration.argumentAsString("model", "llama3:8b"));
- }
-
- public SimpleOllamaClassifier(Cache cache, String model, ChatLanguageModel llm, String template) {
- super(cache, model, llm, template);
- }
-
- @Override
- protected ChatLanguageModel createChatModel(String model) {
- return ChatModels.createOllamaChatModel(model);
- }
-
- @Override
- protected SimpleClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String template) {
- return new SimpleOllamaClassifier(cache, model, llm, template);
- }
-}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOpenAiClassifier.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOpenAiClassifier.java
deleted file mode 100644
index 6bbf8fad..00000000
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleOpenAiClassifier.java
+++ /dev/null
@@ -1,25 +0,0 @@
-package edu.kit.kastel.sdq.lissa.ratlr.classifier;
-
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
-import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache;
-
-public class SimpleOpenAiClassifier extends SimpleClassifier {
- public SimpleOpenAiClassifier(Configuration.ModuleConfiguration configuration) {
- super(configuration, configuration.argumentAsString("model", "gpt-4o-mini"));
- }
-
- public SimpleOpenAiClassifier(Cache cache, String model, ChatLanguageModel llm, String template) {
- super(cache, model, llm, template);
- }
-
- @Override
- protected ChatLanguageModel createChatModel(String model) {
- return ChatModels.createOpenAiChatModel(model);
- }
-
- @Override
- protected SimpleClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String template) {
- return new SimpleOpenAiClassifier(cache, model, llm, template);
- }
-}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java
index 0e36063f..cca554d7 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/embeddingcreator/EmbeddingCreator.java
@@ -14,10 +14,10 @@ public float[] calculateEmbedding(Element element) {
public static EmbeddingCreator createEmbeddingCreator(Configuration.ModuleConfiguration configuration) {
return switch (configuration.name()) {
- case "ollama" -> new OllamaEmbeddingCreator(configuration);
- case "openai" -> new OpenAiEmbeddingCreator(configuration);
- case "onnx" -> new OnnxEmbeddingCreator(configuration);
- default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
+ case "ollama" -> new OllamaEmbeddingCreator(configuration);
+ case "openai" -> new OpenAiEmbeddingCreator(configuration);
+ case "onnx" -> new OnnxEmbeddingCreator(configuration);
+ default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java
index 5dddad1d..7e911560 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/postprocessor/TraceLinkIdPostprocessor.java
@@ -8,12 +8,12 @@
public abstract class TraceLinkIdPostprocessor {
public static TraceLinkIdPostprocessor createTraceLinkIdPostprocessor(Configuration.ModuleConfiguration moduleConfiguration) {
return switch (moduleConfiguration.name()) {
- case "req2code" -> new ReqCodePostprocessor();
- case "sad2code" -> new SadCodePostprocessor();
+ case "req2code" -> new ReqCodePostprocessor();
+ case "sad2code" -> new SadCodePostprocessor();
- case "identity" -> new IdentityPostprocessor();
- case null -> new IdentityPostprocessor();
- default -> throw new IllegalStateException("Unexpected value: " + moduleConfiguration.name());
+ case "identity" -> new IdentityPostprocessor();
+ case null -> new IdentityPostprocessor();
+ default -> throw new IllegalStateException("Unexpected value: " + moduleConfiguration.name());
};
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java
index da6f32b5..e6fbdb76 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeMethodPreprocessor.java
@@ -42,7 +42,7 @@ protected List preprocess(Artifact artifact) {
elements.add(artifactAsElement);
var newElements = switch (language) {
- case JAVA -> splitJava(artifactAsElement);
+ case JAVA -> splitJava(artifactAsElement);
};
elements.addAll(newElements);
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java
index 0ea8607c..bcb43d8b 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/CodeTreePreprocessor.java
@@ -17,7 +17,7 @@ public CodeTreePreprocessor(Configuration.ModuleConfiguration configuration) {
@Override
public List preprocess(List artifacts) {
return switch (language) {
- case JAVA -> createJavaTree(artifacts);
+ case JAVA -> createJavaTree(artifacts);
};
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java
index 374f5738..c8570897 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/Preprocessor.java
@@ -13,13 +13,13 @@ public abstract class Preprocessor {
public static Preprocessor createPreprocessor(Configuration.ModuleConfiguration configuration) {
return switch (configuration.name()) {
- case "sentence" -> new SentencePreprocessor(configuration);
- case "code_chunking" -> new CodeChunkingPreprocessor(configuration);
- case "code_method" -> new CodeMethodPreprocessor(configuration);
- case "code_tree" -> new CodeTreePreprocessor(configuration);
- case "model_uml" -> new ModelUMLPreprocessor(configuration);
- case "artifact" -> new SingleArtifactPreprocessor();
- default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
+ case "sentence" -> new SentencePreprocessor(configuration);
+ case "code_chunking" -> new CodeChunkingPreprocessor(configuration);
+ case "code_method" -> new CodeMethodPreprocessor(configuration);
+ case "code_tree" -> new CodeTreePreprocessor(configuration);
+ case "model_uml" -> new ModelUMLPreprocessor(configuration);
+ case "artifact" -> new SingleArtifactPreprocessor();
+ default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java
index 11fd9c93..d31b64ad 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/preprocessor/RecursiveSplitter.java
@@ -155,10 +155,10 @@ private String joinDocs(List docs, String separator) {
public static List getSeparatorsForLanguage(Language language) {
// Taken from LangChain (Python)
return switch (language) {
- case JAVA -> List.of("\nclass ", "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\n\n",
- "\n", " ", "");
- // Add other languages as needed...
- default -> throw new IllegalArgumentException("Unsupported language: " + language);
+ case JAVA -> List.of("\nclass ", "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
+ "\n\n", "\n", " ", "");
+ // Add other languages as needed...
+ default -> throw new IllegalArgumentException("Unsupported language: " + language);
};
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java
index 8c555955..a623ad94 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/AnyResultAggregator.java
@@ -1,7 +1,7 @@
package edu.kit.kastel.sdq.lissa.ratlr.resultaggregator;
import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
-import edu.kit.kastel.sdq.lissa.ratlr.classifier.Classifier;
+import edu.kit.kastel.sdq.lissa.ratlr.classifier.ClassificationResult;
import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element;
import edu.kit.kastel.sdq.lissa.ratlr.knowledge.TraceLink;
@@ -19,16 +19,15 @@ public AnyResultAggregator(Configuration.ModuleConfiguration configuration) {
}
@Override
- public Set aggregate(List sourceElements, List targetElements, List classificationResults) {
+ public Set aggregate(List sourceElements, List targetElements, List classificationResults) {
Set traceLinks = new LinkedHashSet<>();
for (var result : classificationResults) {
var sourceElementsForTraceLink = buildListOfValidElements(result.source(), sourceGranularity, sourceElements);
- for (var target : result.targets()) {
- var targetElementsForTraceLink = buildListOfValidElements(target, targetGranularity, targetElements);
- for (var sourceElement : sourceElementsForTraceLink) {
- for (var targetElement : targetElementsForTraceLink) {
- traceLinks.add(new TraceLink(sourceElement.getIdentifier(), targetElement.getIdentifier()));
- }
+ var target = result.target();
+ var targetElementsForTraceLink = buildListOfValidElements(target, targetGranularity, targetElements);
+ for (var sourceElement : sourceElementsForTraceLink) {
+ for (var targetElement : targetElementsForTraceLink) {
+ traceLinks.add(new TraceLink(sourceElement.getIdentifier(), targetElement.getIdentifier()));
}
}
}
diff --git a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java
index 42d9fc3c..9f6f417c 100644
--- a/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java
+++ b/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/resultaggregator/ResultAggregator.java
@@ -1,7 +1,7 @@
package edu.kit.kastel.sdq.lissa.ratlr.resultaggregator;
import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
-import edu.kit.kastel.sdq.lissa.ratlr.classifier.Classifier;
+import edu.kit.kastel.sdq.lissa.ratlr.classifier.ClassificationResult;
import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element;
import edu.kit.kastel.sdq.lissa.ratlr.knowledge.TraceLink;
@@ -9,13 +9,12 @@
import java.util.Set;
public abstract class ResultAggregator {
- public abstract Set aggregate(List sourceElements, List targetElements,
- List classificationResults);
+ public abstract Set aggregate(List sourceElements, List targetElements, List classificationResults);
public static ResultAggregator createResultAggregator(Configuration.ModuleConfiguration configuration) {
return switch (configuration.name()) {
- case "any_connection" -> new AnyResultAggregator(configuration);
- default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
+ case "any_connection" -> new AnyResultAggregator(configuration);
+ default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}
}