Skip to content

Commit

Permalink
Attempt to fix InAnYan#98
Browse files Browse the repository at this point in the history
  • Loading branch information
InAnYan committed Aug 6, 2024
1 parent 182ab00 commit 2468fec
Show file tree
Hide file tree
Showing 10 changed files with 451 additions and 122 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/jabref/gui/entryeditor/AiChatTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ protected void handleFocus() {
protected void bindToEntry(BibEntry entry) {
if (!aiService.getPreferences().getEnableAi()) {
showPrivacyNotice(entry);
} else if (aiService.getPreferences().getApiToken().isEmpty()) {
} else if (aiService.getPreferences().getSelectedApiToken().isEmpty()) {
showApiKeyMissing();
} else if (entry.getFiles().isEmpty()) {
showErrorNoFiles();
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected void handleFocus() {
protected void bindToEntry(BibEntry entry) {
if (!aiService.getPreferences().getEnableAi()) {
showPrivacyNotice(entry);
} else if (aiService.getPreferences().getApiToken().isEmpty()) {
} else if (aiService.getPreferences().getSelectedApiToken().isEmpty()) {
showApiKeyMissing();
} else if (bibDatabaseContext.getDatabasePath().isEmpty()) {
showErrorNoDatabasePath();
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<?import org.controlsfx.control.textfield.*?>

<?import org.jabref.gui.icon.JabRefIconView?>
<?import javafx.geometry.Insets?>
<fx:root spacing="10.0" type="VBox" xmlns="http://javafx.com/javafx/17.0.2-ea" xmlns:fx="http://javafx.com/fxml/1" fx:controller="org.jabref.gui.preferences.ai.AiTab">
<children>
<Label styleClass="titleHeader" text="%AI" />
Expand Down Expand Up @@ -35,6 +36,9 @@
<ComboBox fx:id="chatModelComboBox" editable="true" maxWidth="1.7976931348623157E308" HBox.hgrow="ALWAYS" />
<Button fx:id="chatModelHelp" prefWidth="20.0" />
</children>
<padding>
<Insets left="20.0"/>
</padding>
</HBox>

<HBox alignment="CENTER_LEFT" spacing="10.0">
Expand All @@ -43,6 +47,9 @@
<CustomPasswordField fx:id="apiKeyTextField" HBox.hgrow="ALWAYS" />
<Button fx:id="apiTokenHelp" prefWidth="20.0" />
</children>
<padding>
<Insets left="20.0"/>
</padding>
</HBox>

<Label styleClass="sectionHeader" text="%Expert settings" />
Expand Down
161 changes: 120 additions & 41 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,32 @@ public class AiTabViewModel implements PreferenceTabViewModel {

private final ListProperty<String> chatModelsList =
new SimpleListProperty<>(FXCollections.observableArrayList());
private final StringProperty selectedChatModel = new SimpleStringProperty();

private final StringProperty apiToken = new SimpleStringProperty();
private final StringProperty currentChatModel = new SimpleStringProperty();

private final StringProperty openAiChatModel = new SimpleStringProperty();
private final StringProperty mistralAiChatModel = new SimpleStringProperty();
private final StringProperty huggingFaceChatModel = new SimpleStringProperty();

private final StringProperty currentApiToken = new SimpleStringProperty();

private final StringProperty openAiApiToken = new SimpleStringProperty();
private final StringProperty mistralAiApiToken = new SimpleStringProperty();
private final StringProperty huggingFaceApiToken = new SimpleStringProperty();

private final BooleanProperty customizeExpertSettings = new SimpleBooleanProperty();

private final ListProperty<AiPreferences.EmbeddingModel> embeddingModelsList =
new SimpleListProperty<>(FXCollections.observableArrayList(AiPreferences.EmbeddingModel.values()));
private final ObjectProperty<AiPreferences.EmbeddingModel> selectedEmbeddingModel = new SimpleObjectProperty<>();

private final StringProperty apiBaseUrl = new SimpleStringProperty();
private final StringProperty currentApiBaseUrl = new SimpleStringProperty();
private final BooleanProperty disableApiBaseUrl = new SimpleBooleanProperty(true); // {@link HuggingFaceChatModel} doesn't support setting API base URL

private final StringProperty openAiApiBaseUrl = new SimpleStringProperty();
private final StringProperty mistralAiApiBaseUrl = new SimpleStringProperty();
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();

private final StringProperty instruction = new SimpleStringProperty();
private final DoubleProperty temperature = new SimpleDoubleProperty();
private final IntegerProperty contextWindowSize = new SimpleIntegerProperty();
Expand Down Expand Up @@ -89,45 +102,103 @@ public AiTabViewModel(PreferencesService preferencesService) {
);

this.selectedAiProvider.addListener((observable, oldValue, newValue) -> {
List<String> models = AiDefaultPreferences.CHAT_MODELS.get(newValue);
List<String> models = AiDefaultPreferences.AVAILABLE_CHAT_MODELS.get(newValue);

// When we setAll on Hugging Face, models are empty, and currentChatModel become null.
// It becomes null beause currentChatModel is binded to combobox, and this combobox becomes empty.
// For some reason, custom edited value in the combobox will be erased, so we need to store the old value.
String oldChatModel = currentChatModel.get();
chatModelsList.setAll(models);
if (!models.isEmpty()) {
selectedChatModel.setValue(chatModelsList.getFirst());

disableApiBaseUrl.set(newValue == AiPreferences.AiProvider.HUGGING_FACE);

if (oldValue != null) {
switch (oldValue) {
case OPEN_AI -> {
openAiChatModel.set(oldChatModel);
openAiApiToken.set(currentApiToken.get());
openAiApiBaseUrl.set(currentApiBaseUrl.get());
}

case MISTRAL_AI -> {
mistralAiChatModel.set(oldChatModel);
mistralAiApiToken.set(currentApiToken.get());
mistralAiApiBaseUrl.set(currentApiBaseUrl.get()); }

case HUGGING_FACE -> {
huggingFaceChatModel.set(oldChatModel);
huggingFaceApiToken.set(currentApiToken.get());
huggingFaceApiBaseUrl.set(currentApiBaseUrl.get());
}
}
}

apiBaseUrl.set(AiDefaultPreferences.PROVIDERS_API_URLS.get(newValue));
apiToken.set("");
switch (newValue) {
case OPEN_AI -> {
currentChatModel.set(openAiChatModel.get());
currentApiToken.set(openAiApiToken.get());
currentApiBaseUrl.set(openAiApiBaseUrl.get());
}

case MISTRAL_AI -> {
currentChatModel.set(mistralAiChatModel.get());
currentApiToken.set(mistralAiApiToken.get());
currentApiBaseUrl.set(mistralAiApiBaseUrl.get());
}

case HUGGING_FACE -> {
currentChatModel.set(huggingFaceChatModel.get());
currentApiToken.set(huggingFaceApiToken.get());
currentApiBaseUrl.set(huggingFaceApiBaseUrl.get());
}
}
});

this.selectedChatModel.addListener((observable, oldValue, newValue) -> {
this.currentChatModel.addListener((observable, oldValue, newValue) -> {
switch (selectedAiProvider.get()) {
case OPEN_AI -> openAiChatModel.set(newValue);
case MISTRAL_AI -> mistralAiChatModel.set(newValue);
case HUGGING_FACE -> huggingFaceChatModel.set(newValue);
}

Map<String, Integer> modelContextWindows = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.get(selectedAiProvider.get());

if (modelContextWindows == null) {
contextWindowSize.set(AiDefaultPreferences.CONTEXT_WINDOW_SIZE);
return;
}

Integer value = modelContextWindows.get(newValue);
contextWindowSize.set(modelContextWindows.getOrDefault(newValue, AiDefaultPreferences.CONTEXT_WINDOW_SIZE));
});

contextWindowSize.set(value == null ? AiDefaultPreferences.CONTEXT_WINDOW_SIZE : value);
this.currentApiToken.addListener((observable, oldValue, newValue) -> {
switch (selectedAiProvider.get()) {
case OPEN_AI -> openAiApiToken.set(newValue);
case MISTRAL_AI -> mistralAiApiToken.set(newValue);
case HUGGING_FACE -> huggingFaceApiToken.set(newValue);
}
});

this.selectedAiProvider.addListener((observable, oldValue, newValue) ->
disableApiBaseUrl.set(newValue == AiPreferences.AiProvider.HUGGING_FACE)
);
this.currentApiBaseUrl.addListener((observable, oldValue, newValue) -> {
switch (selectedAiProvider.get()) {
case OPEN_AI -> openAiApiBaseUrl.set(newValue);
case MISTRAL_AI -> mistralAiApiBaseUrl.set(newValue);
case HUGGING_FACE -> huggingFaceApiBaseUrl.set(newValue);
}
});

this.apiTokenValidator = new FunctionBasedValidator<>(
apiToken,
currentApiToken,
token -> !StringUtil.isBlank(token),
ValidationMessage.error(Localization.lang("An OpenAI token has to be provided")));

this.chatModelValidator = new FunctionBasedValidator<>(
selectedChatModel,
currentChatModel,
chatModel -> !StringUtil.isBlank(chatModel),
ValidationMessage.error(Localization.lang("Chat model has to be provided")));

this.apiBaseUrlValidator = new FunctionBasedValidator<>(
apiBaseUrl,
currentApiBaseUrl,
token -> !StringUtil.isBlank(token),
ValidationMessage.error(Localization.lang("API base URL has to be provided")));

Expand Down Expand Up @@ -172,35 +243,55 @@ public AiTabViewModel(PreferencesService preferencesService) {
public void setValues() {
enableAi.setValue(aiPreferences.getEnableAi());

selectedAiProvider.setValue(aiPreferences.getAiProvider());
selectedChatModel.setValue(aiPreferences.getChatModel());
apiToken.setValue(aiPreferences.getApiToken());
openAiChatModel.setValue(aiPreferences.getOpenAiChatModel());
mistralAiChatModel.setValue(aiPreferences.getMistralAiChatModel());
huggingFaceChatModel.setValue(aiPreferences.getHuggingFaceChatModel());

openAiApiToken.setValue(aiPreferences.getOpenAiApiToken());
mistralAiApiToken.setValue(aiPreferences.getMistralAiApiToken());
huggingFaceApiToken.setValue(aiPreferences.getHuggingFaceApiToken());

customizeExpertSettings.setValue(aiPreferences.getCustomizeExpertSettings());

selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());
apiBaseUrl.setValue(aiPreferences.getApiBaseUrl());

openAiApiBaseUrl.setValue(aiPreferences.getOpenAiApiBaseUrl());
mistralAiApiBaseUrl.setValue(aiPreferences.getMistralAiApiBaseUrl());
huggingFaceApiBaseUrl.setValue(aiPreferences.getHuggingFaceApiBaseUrl());

instruction.setValue(aiPreferences.getInstruction());
temperature.setValue(aiPreferences.getTemperature());
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
documentSplitterChunkSize.setValue(aiPreferences.getDocumentSplitterChunkSize());
documentSplitterOverlapSize.setValue(aiPreferences.getDocumentSplitterOverlapSize());
ragMaxResultsCount.setValue(aiPreferences.getRagMaxResultsCount());
ragMinScore.setValue(aiPreferences.getRagMinScore());

selectedAiProvider.setValue(aiPreferences.getAiProvider());
}

@Override
public void storeSettings() {
aiPreferences.setEnableAi(enableAi.get());

aiPreferences.setAiProvider(selectedAiProvider.get());
aiPreferences.setChatModel(selectedChatModel.get());
aiPreferences.setApiToken(apiToken.get());

aiPreferences.setOpenAiChatModel(openAiChatModel.get());
aiPreferences.setMistralAiChatModel(mistralAiChatModel.get());
aiPreferences.setHuggingFaceChatModel(huggingFaceChatModel.get());

aiPreferences.setOpenAiApiToken(openAiApiToken.get());
aiPreferences.setMistralAiApiToken(mistralAiApiToken.get());
aiPreferences.setHuggingFaceApiToken(huggingFaceApiToken.get());

aiPreferences.setCustomizeExpertSettings(customizeExpertSettings.get());

aiPreferences.setEmbeddingModel(selectedEmbeddingModel.get());
aiPreferences.setApiBaseUrl(apiBaseUrl.get());

aiPreferences.setOpenAiApiBaseUrl(openAiApiBaseUrl.get());
aiPreferences.setMistralAiApiBaseUrl(mistralAiApiBaseUrl.get());
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get());

aiPreferences.setInstruction(instruction.get());
aiPreferences.setTemperature(temperature.get());
aiPreferences.setContextWindowSize(contextWindowSize.get());
Expand All @@ -212,29 +303,17 @@ public void storeSettings() {

public void resetExpertSettings() {
String resetApiBaseUrl = AiDefaultPreferences.PROVIDERS_API_URLS.get(selectedAiProvider.get());
aiPreferences.setApiBaseUrl(resetApiBaseUrl);
apiBaseUrl.setValue(resetApiBaseUrl);
currentApiBaseUrl.set(resetApiBaseUrl);

aiPreferences.setInstruction(AiDefaultPreferences.SYSTEM_MESSAGE);
instruction.set(AiDefaultPreferences.SYSTEM_MESSAGE);

int resetContextWindowSize = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.getOrDefault(selectedAiProvider.get(), Map.of()).getOrDefault(selectedChatModel.get(), 0);
aiPreferences.setContextWindowSize(resetContextWindowSize);
int resetContextWindowSize = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.getOrDefault(selectedAiProvider.get(), Map.of()).getOrDefault(currentChatModel.get(), 0);
contextWindowSize.set(resetContextWindowSize);

aiPreferences.setTemperature(AiDefaultPreferences.TEMPERATURE);
temperature.set(AiDefaultPreferences.TEMPERATURE);

aiPreferences.setDocumentSplitterChunkSize(AiDefaultPreferences.DOCUMENT_SPLITTER_CHUNK_SIZE);
documentSplitterChunkSize.set(AiDefaultPreferences.DOCUMENT_SPLITTER_CHUNK_SIZE);

aiPreferences.setDocumentSplitterOverlapSize(AiDefaultPreferences.DOCUMENT_SPLITTER_OVERLAP);
documentSplitterOverlapSize.set(AiDefaultPreferences.DOCUMENT_SPLITTER_OVERLAP);

aiPreferences.setRagMaxResultsCount(AiDefaultPreferences.RAG_MAX_RESULTS_COUNT);
ragMaxResultsCount.set(AiDefaultPreferences.RAG_MAX_RESULTS_COUNT);

aiPreferences.setRagMinScore(AiDefaultPreferences.RAG_MIN_SCORE);
ragMinScore.set(AiDefaultPreferences.RAG_MIN_SCORE);
}

Expand Down Expand Up @@ -296,11 +375,11 @@ public ReadOnlyListProperty<String> chatModelsProperty() {
}

public StringProperty selectedChatModelProperty() {
return selectedChatModel;
return currentChatModel;
}

public StringProperty apiTokenProperty() {
return apiToken;
return currentApiToken;
}

public BooleanProperty customizeExpertSettingsProperty() {
Expand All @@ -316,7 +395,7 @@ public ObjectProperty<AiPreferences.EmbeddingModel> selectedEmbeddingModelProper
}

public StringProperty apiBaseUrlProperty() {
return apiBaseUrl;
return currentApiBaseUrl;
}

public BooleanProperty disableApiBaseUrlProperty() {
Expand Down
44 changes: 28 additions & 16 deletions src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,19 @@
import org.jabref.preferences.AiPreferences;

public class AiDefaultPreferences {
public static final boolean ENABLE_CHAT = false;

public static final AiPreferences.AiProvider PROVIDER = AiPreferences.AiProvider.OPEN_AI;
public static final Map<AiPreferences.AiProvider, List<String>> CHAT_MODELS = Map.of(
public static final Map<AiPreferences.AiProvider, List<String>> AVAILABLE_CHAT_MODELS = Map.of(
AiPreferences.AiProvider.OPEN_AI, List.of("gpt-4o-mini", "gpt-4o", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"),
// "mistral" and "mixtral" are not language mistakes.
AiPreferences.AiProvider.MISTRAL_AI, List.of("open-mistral-nemo", "open-mistral-7b", "open-mixtral-8x7b", "open-mixtral-8x22b", "mistral-large-latest"),
AiPreferences.AiProvider.HUGGING_FACE, List.of()
);
public static final String CHAT_MODEL = CHAT_MODELS.get(PROVIDER).getFirst();

public static final boolean CUSTOMIZE_SETTINGS = false;

public static final AiPreferences.EmbeddingModel EMBEDDING_MODEL = AiPreferences.EmbeddingModel.ALL_MINILM_L6_V2;
public static final String SYSTEM_MESSAGE = "You are an AI assistant that analyses research papers.";
public static final double TEMPERATURE = 0.7;
public static final int DOCUMENT_SPLITTER_CHUNK_SIZE = 300;
public static final int DOCUMENT_SPLITTER_OVERLAP = 100;
public static final int RAG_MAX_RESULTS_COUNT = 10;
public static final double RAG_MIN_SCORE = 0.3;

public static final int CONTEXT_WINDOW_SIZE = 8196;
public static final Map<AiPreferences.AiProvider, String> PROVIDERS_API_URLS = Map.of(
AiPreferences.AiProvider.OPEN_AI, "https://api.openai.com/v1",
AiPreferences.AiProvider.MISTRAL_AI, "https://api.mistral.ai/v1",
AiPreferences.AiProvider.HUGGING_FACE, "https://huggingface.co/api"
);

public static final Map<AiPreferences.AiProvider, Map<String, Integer>> CONTEXT_WINDOW_SIZES = Map.of(
AiPreferences.AiProvider.OPEN_AI, Map.of(
"gpt-4o-mini", 128000,
Expand All @@ -49,4 +35,30 @@ public class AiDefaultPreferences {
"open-mixtral-8x22b", 64000
)
);

public static final boolean ENABLE_CHAT = false;

public static final AiPreferences.AiProvider PROVIDER = AiPreferences.AiProvider.OPEN_AI;

public static final Map<AiPreferences.AiProvider, String> CHAT_MODELS = Map.of(
AiPreferences.AiProvider.OPEN_AI, "gpt-4o-mini",
AiPreferences.AiProvider.MISTRAL_AI, "open-mixtral-8x22b",
AiPreferences.AiProvider.HUGGING_FACE, ""
);

public static final boolean CUSTOMIZE_SETTINGS = false;

public static final AiPreferences.EmbeddingModel EMBEDDING_MODEL = AiPreferences.EmbeddingModel.ALL_MINILM_L6_V2;
public static final String SYSTEM_MESSAGE = "You are an AI assistant that analyses research papers.";
public static final double TEMPERATURE = 0.7;
public static final int DOCUMENT_SPLITTER_CHUNK_SIZE = 300;
public static final int DOCUMENT_SPLITTER_OVERLAP = 100;
public static final int RAG_MAX_RESULTS_COUNT = 10;
public static final double RAG_MIN_SCORE = 0.3;

public static final int CONTEXT_WINDOW_SIZE = 8196;

public static int getContextWindowSize(AiPreferences.AiProvider aiProvider, String model) {
return CONTEXT_WINDOW_SIZES.getOrDefault(aiProvider, Map.of()).getOrDefault(model, 0);
}
}
Loading

0 comments on commit 2468fec

Please sign in to comment.