Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mistralai;

import java.util.List;
import java.util.Map;

import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
Expand All @@ -41,6 +42,9 @@
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.CODESTRAL_EMBED;
import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.EMBED;

/**
* Provides the Mistral AI Embedding Model.
*
Expand All @@ -53,6 +57,9 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {

private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);

private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(EMBED.getValue(), 1024,
CODESTRAL_EMBED.getValue(), 1536);

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

private final MistralAiEmbeddingOptions defaultOptions;
Expand All @@ -78,8 +85,7 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
}

public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
this(mistralAiApi, metadataMode,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
this(mistralAiApi, metadataMode, MistralAiEmbeddingOptions.builder().withModel(EMBED.getValue()).build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

Expand Down Expand Up @@ -179,6 +185,11 @@ public float[] embed(Document document) {
return this.embed(document.getFormattedContent(this.metadataMode));
}

@Override
public int dimensions() {
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -321,7 +321,8 @@ public String getName() {
public enum EmbeddingModel {

// @formatter:off
EMBED("mistral-embed");
EMBED("mistral-embed"),
CODESTRAL_EMBED("codestral-embed");
// @formatter:on

private final String value;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,50 +16,76 @@

package org.springframework.ai.mistralai;

import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = MistralAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
class MistralAiEmbeddingIT {

private static final int MISTRAL_EMBED_DIMENSIONS = 1024;

@Autowired
private MistralAiApi mistralAiApi;

@Autowired
private MistralAiEmbeddingModel mistralAiEmbeddingModel;

@Test
void defaultEmbedding() {
assertThat(this.mistralAiEmbeddingModel).isNotNull();
var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(MISTRAL_EMBED_DIMENSIONS);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS);
}

@Test
void embeddingTest() {
assertThat(this.mistralAiEmbeddingModel).isNotNull();
var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest(
List.of("Hello World", "World is big"),
MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build()));
@ParameterizedTest
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
void defaultOptionsEmbedding(String model, int dimensions) {
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
var anotherMistralAiEmbeddingModel = new MistralAiEmbeddingModel(mistralAiApi, mistralAiEmbeddingOptions);
var embeddingResponse = anotherMistralAiEmbeddingModel.embedForResponse(List.of("Hello World", "World is big"));
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
embeddingResponse.getResults().forEach(result -> {
assertThat(result).isNotNull();
assertThat(result.getOutput()).hasSize(dimensions);
});
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9);
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9);
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
assertThat(anotherMistralAiEmbeddingModel.dimensions()).isEqualTo(dimensions);
}

@ParameterizedTest
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
void calledOptionsEmbedding(String model, int dimensions) {
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big", "We are small"),
mistralAiEmbeddingOptions);
var embeddingResponse = mistralAiEmbeddingModel.call(embeddingRequest);
assertThat(embeddingResponse.getResults()).hasSize(3);
embeddingResponse.getResults().forEach(result -> {
assertThat(result).isNotNull();
assertThat(result.getOutput()).hasSize(dimensions);
});
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(14);
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(14);
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,6 @@

package org.springframework.ai.mistralai;

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
Expand All @@ -27,30 +26,28 @@
@SpringBootConfiguration
public class MistralAiTestConfiguration {

@Bean
public MistralAiApi mistralAiApi() {
private static String retrieveApiKey() {
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
}
return new MistralAiApi(apiKey);
return apiKey;
}

@Bean
public MistralAiApi mistralAiApi() {
return new MistralAiApi(retrieveApiKey());
}

@Bean
public MistralAiModerationApi mistralAiModerationApi() {
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
}
return new MistralAiModerationApi(apiKey);
return new MistralAiModerationApi(retrieveApiKey());
}

@Bean
public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
return new MistralAiEmbeddingModel(api,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
return new MistralAiEmbeddingModel(mistralAiApi);
}

@Bean
Expand Down