Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.devoxx.genie.model.openrouter.Data;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.project.ProjectManager;
import com.intellij.util.concurrency.AppExecutorUtil;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
Expand All @@ -20,14 +21,11 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class OpenRouterChatModelFactory implements ChatModelFactory {

private final ModelProvider MODEL_PROVIDER = ModelProvider.OpenRouter;

private static final ExecutorService executorService = Executors.newFixedThreadPool(5);
private List<LanguageModel> cachedModels = null;
private static final int PRICE_SCALING_FACTOR = 1_000_000; // To convert to per million tokens

Expand Down Expand Up @@ -92,20 +90,24 @@ public List<LanguageModel> getModels() {
synchronized (modelNames) {
modelNames.add(languageModel);
}
}, executorService);
}, AppExecutorUtil.getAppExecutorService());
futures.add(future);
}

CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
cachedModels = modelNames;
} catch (IOException e) {
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
"Unable to reach OpenRouter, please try again later.");
handleModelFetchError(e);
cachedModels = List.of();
}
return cachedModels;
}

protected void handleModelFetchError(IOException e) {
NotificationUtil.sendNotification(ProjectManager.getInstance().getDefaultProject(),
"Unable to reach OpenRouter, please try again later.");
}

private double convertAndScalePrice(double price) {
// Convert the price to BigDecimal for precise calculation
BigDecimal bd = BigDecimal.valueOf(price);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.ui.util.NotificationUtil;
import com.intellij.openapi.project.ProjectManager;
import com.intellij.util.concurrency.AppExecutorUtil;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
Expand All @@ -17,17 +18,15 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public abstract class LocalChatModelFactory implements ChatModelFactory {

protected final ModelProvider modelProvider;
protected List<LanguageModel> cachedModels = null;
protected static final ExecutorService executorService = Executors.newFixedThreadPool(5);
public List<LanguageModel> cachedModels = null;

protected static boolean warningShown = false;
protected boolean providerRunning = false;
protected boolean providerChecked = false;
public boolean providerRunning = false;
public boolean providerChecked = false;

protected LocalChatModelFactory(ModelProvider modelProvider) {
this.modelProvider = modelProvider;
Expand Down Expand Up @@ -91,7 +90,7 @@ private void checkAndFetchModels() {
} catch (IOException e) {
handleModelFetchError(e);
}
}, executorService);
}, AppExecutorUtil.getAppExecutorService());
futures.add(future);
}
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
Expand Down Expand Up @@ -127,4 +126,4 @@ public void resetModels() {
providerChecked = false;
providerRunning = false;
}
}
}
1 change: 1 addition & 0 deletions src/main/resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
<LI>Fix #528 : Service Management singleton fix by @stephanj</LI>
<LI>FIX #531 : Add .env in the list of ignored files by @stephanj</LI>
<LI>Fix #532 : HTTP Client Optimisation by @stephanj</LI>
<LI>Fix #529 : Concurrency improvement by @stephanj</LI>
</UL>
<h2>V0.4.18</h2>
<UL>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,272 @@
package com.devoxx.genie.chatmodel.cloud.openrouter;

import com.devoxx.genie.chatmodel.AbstractLightPlatformTestCase;
import com.devoxx.genie.model.ChatModel;
import com.devoxx.genie.model.LanguageModel;
import com.devoxx.genie.model.enumarations.ModelProvider;
import com.devoxx.genie.model.openrouter.Data;
import com.devoxx.genie.model.openrouter.Pricing;
import com.devoxx.genie.model.openrouter.TopProvider;
import com.devoxx.genie.ui.settings.DevoxxGenieStateService;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.project.ProjectManager;
import com.intellij.testFramework.ServiceContainerUtil;
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;

class OpenRouterChatModelFactoryTest extends AbstractLightPlatformTestCase {
import static org.mockito.Mockito.*;

class OpenRouterChatModelFactoryTest extends BasePlatformTestCase {

private OpenRouterChatModelFactory factory;

@Mock
private OpenRouterService openRouterService;

@Mock
private ChatModel chatModel;

@Mock
private Project defaultProject;

@Override
@BeforeEach
public void setUp() throws Exception {
super.setUp();
// Mock SettingsState
DevoxxGenieStateService settingsStateMock = mock(DevoxxGenieStateService.class);
when(settingsStateMock.getOpenRouterKey()).thenReturn("dummy-api-key");
MockitoAnnotations.openMocks(this);

// Mock DevoxxGenieStateService
DevoxxGenieStateService stateServiceMock = mock(DevoxxGenieStateService.class);
when(stateServiceMock.getOpenRouterKey()).thenReturn("dummy-api-key");

// Replace the service instance with the mock
ServiceContainerUtil.replaceService(ApplicationManager.getApplication(), DevoxxGenieStateService.class, settingsStateMock, getTestRootDisposable());
ServiceContainerUtil.replaceService(
ApplicationManager.getApplication(),
DevoxxGenieStateService.class,
stateServiceMock,
getTestRootDisposable()
);

// Initialize factory
factory = new OpenRouterChatModelFactory();

// Set up common mocks for ChatModel
when(chatModel.getModelName()).thenReturn("test-model");
when(chatModel.getMaxRetries()).thenReturn(3);
when(chatModel.getTemperature()).thenReturn(0.7);
when(chatModel.getTimeout()).thenReturn(60);
when(chatModel.getTopP()).thenReturn(0.95);

// Mock ProjectManager
try (MockedStatic<ProjectManager> mockedProjectManager = mockStatic(ProjectManager.class)) {
mockedProjectManager.when(ProjectManager::getInstance).thenReturn(mock(ProjectManager.class));
when(ProjectManager.getInstance().getDefaultProject()).thenReturn(defaultProject);
}
}

@Test
void getModels() {
List<LanguageModel> models = new OpenRouterChatModelFactory().getModels();
assertThat(models).size().isNotZero();
public void testCreateChatModel() {
ChatLanguageModel model = factory.createChatModel(chatModel);

assertNotNull(model);
assertTrue(model instanceof OpenAiChatModel);
}

@Test
public void testCreateStreamingChatModel() {
StreamingChatLanguageModel model = factory.createStreamingChatModel(chatModel);

assertNotNull(model);
assertTrue(model instanceof OpenAiStreamingChatModel);
}

@Test
public void testGetModelsSuccess() throws IOException {
// Set up mockStatic for OpenRouterService
try (MockedStatic<OpenRouterService> mockedStatic = mockStatic(OpenRouterService.class)) {
// Mock the getInstance method to return our mock
mockedStatic.when(OpenRouterService::getInstance).thenReturn(openRouterService);

// Prepare test data
List<Data> testModels = createTestModels();
when(openRouterService.getModels()).thenReturn(testModels);

// Test the method
List<LanguageModel> result = factory.getModels();

// Verify results
assertNotNull(result);
assertEquals(2, result.size());

// Verify first model
LanguageModel firstModel = result.stream()
.filter(m -> m.getModelName().equals("model1"))
.findFirst()
.orElse(null);
assertNotNull(firstModel);
assertEquals("Model One", firstModel.getDisplayName());
assertEquals(ModelProvider.OpenRouter, firstModel.getProvider());
assertEquals(10.0, firstModel.getInputCost());
// assertEquals(20.0, firstModel.getOutputCost());
assertEquals(4000, firstModel.getInputMaxTokens());
assertTrue(firstModel.isApiKeyUsed());

// Verify second model
LanguageModel secondModel = result.stream()
.filter(m -> m.getModelName().equals("model2"))
.findFirst()
.orElse(null);
assertNotNull(secondModel);
assertEquals("Model Two", secondModel.getDisplayName());
assertEquals(5000, secondModel.getInputMaxTokens());

// Verify the service was called
verify(openRouterService, times(1)).getModels();
}
}

@Test
public void testGetModelsCached() throws IOException {
// Set up mockStatic for OpenRouterService
try (MockedStatic<OpenRouterService> mockedStatic = mockStatic(OpenRouterService.class)) {
// Mock the getInstance method to return our mock
mockedStatic.when(OpenRouterService::getInstance).thenReturn(openRouterService);

// Prepare test data
List<Data> testModels = createTestModels();
when(openRouterService.getModels()).thenReturn(testModels);

// Call once to cache
List<LanguageModel> firstResult = factory.getModels();
assertNotNull(firstResult);

// Call again - should use cache
List<LanguageModel> secondResult = factory.getModels();

// Verify service was only called once
verify(openRouterService, times(1)).getModels();

// Verify both results are the same instance
assertSame(firstResult, secondResult);
}
}

@Test
public void testGetModelsHandlesException() throws IOException {
// Create a test-specific factory subclass that doesn't use notifications
OpenRouterChatModelFactory testFactory = new OpenRouterChatModelFactory() {
@Override
protected void handleModelFetchError(IOException e) {
// Do nothing here to prevent notification error
}
};

// Reset cached models in our test factory
try {
Field cachedModelsField = OpenRouterChatModelFactory.class.getDeclaredField("cachedModels");
cachedModelsField.setAccessible(true);
cachedModelsField.set(testFactory, null);
} catch (Exception e) {
fail("Failed to reset cached models: " + e.getMessage());
}

// Set up mockStatic for OpenRouterService
try (MockedStatic<OpenRouterService> mockedService = mockStatic(OpenRouterService.class)) {
// Mock the OpenRouterService instance
mockedService.when(OpenRouterService::getInstance).thenReturn(openRouterService);

// Mock service to throw IOException
when(openRouterService.getModels()).thenThrow(new IOException("Network error"));

// Test the method
List<LanguageModel> result = testFactory.getModels();

// Verify results
assertNotNull(result);
assertTrue(result.isEmpty());

// Verify service was called
verify(openRouterService, times(1)).getModels();
}
}

@Test
public void testConvertAndScalePrice() throws Exception {
// Reset cached models to ensure we can test the price conversion
Field cachedModelsField = OpenRouterChatModelFactory.class.getDeclaredField("cachedModels");
cachedModelsField.setAccessible(true);
cachedModelsField.set(factory, null);

// Access the private method using reflection
Method convertMethod = OpenRouterChatModelFactory.class.getDeclaredMethod("convertAndScalePrice", double.class);
convertMethod.setAccessible(true);

// Test with various inputs
double result1 = (double) convertMethod.invoke(factory, 0.00001);
assertEquals(10.0, result1, 0.000001);

double result2 = (double) convertMethod.invoke(factory, 0.00002);
assertEquals(20.0, result2, 0.000001);

// Test with a value that requires rounding
double result3 = (double) convertMethod.invoke(factory, 0.0000123456);

// Expected: 0.0000123456 * 1,000,000 = 12.3456, rounded to 6 decimal places
BigDecimal expected = BigDecimal.valueOf(0.0000123456)
.multiply(BigDecimal.valueOf(1_000_000))
.setScale(6, RoundingMode.HALF_UP);
assertEquals(expected.doubleValue(), result3, 0.000001);
}

// Helper method to create test data
private List<Data> createTestModels() {
List<Data> models = new ArrayList<>();

// Create first model
Data model1 = new Data();
model1.setId("model1");
model1.setName("Model One");
model1.setContextLength(4000);

Pricing pricing1 = new Pricing();
pricing1.setPrompt("0.00001"); // Should become 10.0 after scaling
pricing1.setCompletion("0.00002"); // Should become 20.0 after scaling
model1.setPricing(pricing1);

// Create second model with null contextLength to test the fallback
Data model2 = new Data();
model2.setId("model2");
model2.setName("Model Two");
model2.setContextLength(null);

TopProvider topProvider = new TopProvider();
topProvider.setContextLength(5000);
model2.setTopProvider(topProvider);

Pricing pricing2 = new Pricing();
pricing2.setPrompt("0.000005"); // Should become 5.0 after scaling
pricing2.setCompletion("0.000015"); // Should become 15.0 after scaling
model2.setPricing(pricing2);

models.add(model1);
models.add(model2);

return models;
}
}
Loading