Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions formatter.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,27 @@
<setting id="org.eclipse.jdt.core.formatter.indent_body_declarations_compare_to_record_header" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.indent_statements_compare_to_body" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.indent_statements_compare_to_block" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.indent_switchstatements_compare_to_switch" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_switch_case_expressions" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_switch_case_expressions" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_before_switch_case_arrow_operator" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_switch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_switchstatements_compare_to_switch" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_arrow_in_switch_case" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_switch_statement" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_arrow_in_switch_default" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_arrow_in_switch_case" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.keep_switch_body_block_on_one_line" value="one_line_never"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_switch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_expressions_in_switch_case_with_arrow" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_switch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.keep_switch_case_with_arrow_on_one_line" value="one_line_never"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_expressions_in_switch_case_with_colon" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_arrow_in_switch_default" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_switchstatements_compare_to_cases" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_switch" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_switch_case_with_arrow" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_between_statement_group_in_switch" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_switch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_breaks_compare_to_cases" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.indent_empty_lines" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.align_type_members_on_columns" value="false"/>
Expand All @@ -33,7 +52,6 @@
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_annotation_type_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_block" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_block_in_case" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_switch" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_array_initializer" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.keep_empty_array_initializer_on_one_line" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_lambda_body" value="end_of_line"/>
Expand All @@ -45,7 +63,6 @@
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_lambda_declaration" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_if_while_statement" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_for_statment" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_switch_statement" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_try_clause" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.parentheses_positions_in_catch_clause" value="common_lines"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_type_declaration" value="insert"/>
Expand Down Expand Up @@ -126,17 +143,7 @@
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_for" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_case" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_default" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_arrow_in_switch_case" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_arrow_in_switch_case" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_arrow_in_switch_default" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_arrow_in_switch_default" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_case" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_switch_case_expressions" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_switch_case_expressions" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_switch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_switch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_switch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_switch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_while" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_while" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_while" value="do not insert"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ public abstract class ArtifactProvider {

public static ArtifactProvider createArtifactProvider(Configuration.ModuleConfiguration configuration) {
return switch (configuration.name()) {
case "text" -> new TextArtifactProvider(configuration);
case "recursive_text" -> new RecursiveTextArtifactProvider(configuration);
default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
case "text" -> new TextArtifactProvider(configuration);
case "recursive_text" -> new RecursiveTextArtifactProvider(configuration);
default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public synchronized void put(String key, String value) {
}
}

@SuppressWarnings("unchecked")
public synchronized <T> T get(String key, Class<T> clazz) {
try {
var jsonData = this.data.get(key);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package edu.kit.kastel.sdq.lissa.ratlr.classifier;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import edu.kit.kastel.sdq.lissa.ratlr.Configuration;
import edu.kit.kastel.sdq.lissa.ratlr.Environment;
import okhttp3.Credentials;

import java.time.Duration;
import java.util.Map;
import java.util.Objects;

public class ChatLanguageModelProvider {
public static final String OPENAI = "openai";
public static final String OLLAMA = "ollama";

private final String platform;
private String model;

public ChatLanguageModelProvider(Configuration.ModuleConfiguration configuration) {
String[] modeXplatform = configuration.name().split(Classifier.CONFIG_NAME_SEPARATOR, 2);
if (modeXplatform.length == 1) {
this.platform = null;
return;
}
this.platform = modeXplatform[1];
initModelPlatform(configuration);
}

public ChatLanguageModel createChatModel() {
return switch (platform) {
case OPENAI -> createOpenAiChatModel(model);
case OLLAMA -> createOllamaChatModel(model);
default -> throw new IllegalArgumentException("Unsupported platform: " + platform);
};
}

private void initModelPlatform(Configuration.ModuleConfiguration configuration) {
this.model = switch (platform) {
case OPENAI -> configuration.argumentAsString("model", "gpt-4o-mini");
case OLLAMA -> configuration.argumentAsString("model", "llama3:8b");
default -> throw new IllegalArgumentException("Unsupported platform: " + platform);
};
}

public String modelName() {
return Objects.requireNonNull(model, "Model not initialized");
}

private static OllamaChatModel createOllamaChatModel(String model) {
String host = Environment.getenv("OLLAMA_HOST");
String user = Environment.getenv("OLLAMA_USER");
String password = Environment.getenv("OLLAMA_PASSWORD");

var ollama = OllamaChatModel.builder().baseUrl(host).modelName(model).timeout(Duration.ofMinutes(5)).temperature(0.0);
if (user != null && password != null && !user.isEmpty() && !password.isEmpty()) {
ollama.customHeaders(Map.of("Authorization", Credentials.basic(user, password)));
}
return ollama.build();
}

private static OpenAiChatModel createOpenAiChatModel(String model) {
String openAiOrganizationId = Environment.getenv("OPENAI_ORGANIZATION_ID");
String openAiApiKey = Environment.getenv("OPENAI_API_KEY");
if (openAiOrganizationId == null || openAiApiKey == null) {
throw new IllegalStateException("OPENAI_ORGANIZATION_ID or OPENAI_API_KEY environment variable not set");
}
return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).organizationId(openAiOrganizationId).apiKey(openAiApiKey).temperature(0.0).build();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package edu.kit.kastel.sdq.lissa.ratlr.classifier;

import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element;

public record ClassificationResult(Element source, Element target, double confidence) {
public ClassificationResult {
if (confidence < 0 || confidence > 1) {
throw new IllegalArgumentException("Confidence must be between 0 and 1");
}
}

public static ClassificationResult of(Element source, Element target) {
return new ClassificationResult(source, target, 1.0);
}

public static ClassificationResult of(Element source, Element target, double confidence) {
return new ClassificationResult(source, target, confidence);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -15,11 +16,12 @@

public abstract class Classifier {
private static final int THREADS = 100;
static final String CONFIG_NAME_SEPARATOR = "_";

protected final Logger logger = LoggerFactory.getLogger(this.getClass());

public List<ClassificationResult> classify(ElementStore sourceStore, ElementStore targetStore) {
List<Future<ClassificationResult>> futureResults = new ArrayList<>();
List<Future<List<ClassificationResult>>> futureResults = new ArrayList<>();
ExecutorService executor = Executors.newFixedThreadPool(THREADS);
for (var query : sourceStore.getAllElements(true)) {
var targetCandidates = targetStore.findSimilar(query.second());
Expand All @@ -36,24 +38,20 @@ public List<ClassificationResult> classify(ElementStore sourceStore, ElementStor
throw new IllegalStateException(e);
}

return futureResults.stream().map(Future::resultNow).toList();
return futureResults.stream().map(Future::resultNow).flatMap(Collection::stream).toList();
}

protected abstract ClassificationResult classify(Element source, List<Element> targets);
protected abstract List<ClassificationResult> classify(Element source, List<Element> targets);

protected abstract Classifier copyOf();

public static Classifier createClassifier(Configuration.ModuleConfiguration configuration) {
return switch (configuration.name()) {
case "mock" -> new MockClassifier();
case "simple_ollama" -> new SimpleOllamaClassifier(configuration);
case "simple_openai" -> new SimpleOpenAiClassifier(configuration);
case "reasoning_ollama" -> new ReasoningOllamaClassifier(configuration);
case "reasoning_openai" -> new ReasoningOpenAiClassifier(configuration);
default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
return switch (configuration.name().split(CONFIG_NAME_SEPARATOR)[0]) {
case "mock" -> new MockClassifier();
case "simple" -> new SimpleClassifier(configuration);
case "reasoning" -> new ReasoningClassifier(configuration);
default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
};
}

public record ClassificationResult(Element source, List<Element> targets) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

public class MockClassifier extends Classifier {
@Override
protected ClassificationResult classify(Element source, List<Element> targets) {
return new ClassificationResult(source, targets);
protected List<ClassificationResult> classify(Element source, List<Element> targets) {
return targets.stream().map(target -> ClassificationResult.of(source, target)).toList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,40 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public abstract class ReasoningClassifier extends Classifier {
public class ReasoningClassifier extends Classifier {
private final Cache cache;
private final ChatLanguageModelProvider provider;

private final ChatLanguageModel llm;
private final String prompt;
private final boolean useOriginalArtifacts;
private final boolean useSystemMessage;
private final String model;

protected ReasoningClassifier(Configuration.ModuleConfiguration configuration, String model) {
this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + model);
public ReasoningClassifier(Configuration.ModuleConfiguration configuration) {
this.provider = new ChatLanguageModelProvider(configuration);
this.cache = CacheManager.getDefaultInstance().getCache(this.getClass().getSimpleName() + "_" + provider.modelName());
this.prompt = Prompt.values()[configuration.argumentAsInt("prompt_id", 0)].prompt;
this.useOriginalArtifacts = configuration.argumentAsBoolean("use_original_artifacts", false);
this.useSystemMessage = configuration.argumentAsBoolean("use_system_message", true);
this.model = model;
this.llm = createChatModel(model);
this.llm = this.provider.createChatModel();
}

protected ReasoningClassifier(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts, boolean useSystemMessage) {
private ReasoningClassifier(Cache cache, ChatLanguageModelProvider provider, String prompt, boolean useOriginalArtifacts, boolean useSystemMessage) {
this.cache = cache;
this.model = model;
this.llm = llm;
this.provider = provider;
this.prompt = prompt;
this.useOriginalArtifacts = useOriginalArtifacts;
this.useSystemMessage = useSystemMessage;
this.llm = this.provider.createChatModel();
}

@Override
protected final Classifier copyOf() {
return copyOf(cache, model, createChatModel(model), prompt, useOriginalArtifacts, useSystemMessage);
return new ReasoningClassifier(cache, provider, prompt, useOriginalArtifacts, useSystemMessage);
}

protected abstract ReasoningClassifier copyOf(Cache cache, String model, ChatLanguageModel llm, String prompt, boolean useOriginalArtifacts,
boolean useSystemMessage);

protected abstract ChatLanguageModel createChatModel(String model);

@Override
protected final ClassificationResult classify(Element source, List<Element> targets) {
protected final List<ClassificationResult> classify(Element source, List<Element> targets) {
List<Element> relatedTargets = new ArrayList<>();

var targetsToConsider = targets;
Expand Down Expand Up @@ -87,7 +83,7 @@ protected final ClassificationResult classify(Element source, List<Element> targ
relatedTargets.add(target);
}
}
return new ClassificationResult(source, relatedTargets);
return relatedTargets.stream().map(relatedTarget -> ClassificationResult.of(source, relatedTarget)).toList();
}

private boolean isRelated(String llmResponse) {
Expand Down
Loading