Skip to content

Commit 23331f7

Browse files
committed
Add Mistral AI Codestral Embed model
Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 72f7c63 commit 23331f7

File tree

4 files changed

+71
-36
lines changed

4 files changed

+71
-36
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.mistralai;
1818

1919
import java.util.List;
20+
import java.util.Map;
2021

2122
import io.micrometer.observation.ObservationRegistry;
2223
import org.slf4j.Logger;
@@ -41,6 +42,9 @@
4142
import org.springframework.retry.support.RetryTemplate;
4243
import org.springframework.util.Assert;
4344

45+
import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.CODESTRAL_EMBED;
46+
import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.EMBED;
47+
4448
/**
4549
* Provides the Mistral AI Embedding Model.
4650
*
@@ -53,6 +57,9 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5357

5458
private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);
5559

60+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(EMBED.getValue(), 1024,
61+
CODESTRAL_EMBED.getValue(), 1536);
62+
5663
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
5764

5865
private final MistralAiEmbeddingOptions defaultOptions;
@@ -78,8 +85,7 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
7885
}
7986

8087
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
81-
this(mistralAiApi, metadataMode,
82-
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
88+
this(mistralAiApi, metadataMode, MistralAiEmbeddingOptions.builder().withModel(EMBED.getValue()).build(),
8389
RetryUtils.DEFAULT_RETRY_TEMPLATE);
8490
}
8591

@@ -179,6 +185,11 @@ public float[] embed(Document document) {
179185
return this.embed(document.getFormattedContent(this.metadataMode));
180186
}
181187

188+
@Override
189+
public int dimensions() {
190+
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
191+
}
192+
182193
/**
183194
* Use the provided convention for reporting observation data
184195
* @param observationConvention The provided convention

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -321,7 +321,8 @@ public String getName() {
321321
public enum EmbeddingModel {
322322

323323
// @formatter:off
324-
EMBED("mistral-embed");
324+
EMBED("mistral-embed"),
325+
CODESTRAL_EMBED("codestral-embed");
325326
// @formatter:on
326327

327328
private final String value;
Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,50 +16,76 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19-
import java.util.List;
20-
2119
import org.junit.jupiter.api.Test;
2220
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23-
21+
import org.junit.jupiter.params.ParameterizedTest;
22+
import org.junit.jupiter.params.provider.CsvSource;
2423
import org.springframework.ai.embedding.EmbeddingRequest;
24+
import org.springframework.ai.mistralai.api.MistralAiApi;
2525
import org.springframework.beans.factory.annotation.Autowired;
2626
import org.springframework.boot.test.context.SpringBootTest;
2727

28+
import java.util.List;
29+
2830
import static org.assertj.core.api.Assertions.assertThat;
2931

3032
@SpringBootTest(classes = MistralAiTestConfiguration.class)
3133
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
3234
class MistralAiEmbeddingIT {
3335

36+
private static final int MISTRAL_EMBED_DIMENSIONS = 1024;
37+
38+
@Autowired
39+
private MistralAiApi mistralAiApi;
40+
3441
@Autowired
3542
private MistralAiEmbeddingModel mistralAiEmbeddingModel;
3643

3744
@Test
3845
void defaultEmbedding() {
39-
assertThat(this.mistralAiEmbeddingModel).isNotNull();
40-
var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
46+
var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
4147
assertThat(embeddingResponse.getResults()).hasSize(1);
4248
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
43-
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
49+
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(MISTRAL_EMBED_DIMENSIONS);
4450
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
4551
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
4652
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
47-
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
53+
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS);
4854
}
4955

50-
@Test
51-
void embeddingTest() {
52-
assertThat(this.mistralAiEmbeddingModel).isNotNull();
53-
var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest(
54-
List.of("Hello World", "World is big"),
55-
MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build()));
56+
@ParameterizedTest
57+
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
58+
void defaultOptionsEmbedding(String model, int dimensions) {
59+
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
60+
var anotherMistralAiEmbeddingModel = new MistralAiEmbeddingModel(mistralAiApi, mistralAiEmbeddingOptions);
61+
var embeddingResponse = anotherMistralAiEmbeddingModel.embedForResponse(List.of("Hello World", "World is big"));
5662
assertThat(embeddingResponse.getResults()).hasSize(2);
57-
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
58-
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
59-
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
63+
embeddingResponse.getResults().forEach(result -> {
64+
assertThat(result).isNotNull();
65+
assertThat(result.getOutput()).hasSize(dimensions);
66+
});
67+
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
6068
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9);
6169
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9);
62-
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
70+
assertThat(anotherMistralAiEmbeddingModel.dimensions()).isEqualTo(dimensions);
71+
}
72+
73+
@ParameterizedTest
74+
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
75+
void calledOptionsEmbedding(String model, int dimensions) {
76+
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
77+
var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big", "We are small"),
78+
mistralAiEmbeddingOptions);
79+
var embeddingResponse = mistralAiEmbeddingModel.call(embeddingRequest);
80+
assertThat(embeddingResponse.getResults()).hasSize(3);
81+
embeddingResponse.getResults().forEach(result -> {
82+
assertThat(result).isNotNull();
83+
assertThat(result.getOutput()).hasSize(dimensions);
84+
});
85+
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
86+
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(14);
87+
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(14);
88+
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS);
6389
}
6490

6591
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19-
import org.springframework.ai.embedding.EmbeddingModel;
2019
import org.springframework.ai.mistralai.api.MistralAiApi;
2120
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
2221
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
@@ -27,30 +26,28 @@
2726
@SpringBootConfiguration
2827
public class MistralAiTestConfiguration {
2928

30-
@Bean
31-
public MistralAiApi mistralAiApi() {
29+
private static String retrieveApiKey() {
3230
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
3331
if (!StringUtils.hasText(apiKey)) {
3432
throw new IllegalArgumentException(
3533
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
3634
}
37-
return new MistralAiApi(apiKey);
35+
return apiKey;
36+
}
37+
38+
@Bean
39+
public MistralAiApi mistralAiApi() {
40+
return new MistralAiApi(retrieveApiKey());
3841
}
3942

4043
@Bean
4144
public MistralAiModerationApi mistralAiModerationApi() {
42-
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
43-
if (!StringUtils.hasText(apiKey)) {
44-
throw new IllegalArgumentException(
45-
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
46-
}
47-
return new MistralAiModerationApi(apiKey);
45+
return new MistralAiModerationApi(retrieveApiKey());
4846
}
4947

5048
@Bean
51-
public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
52-
return new MistralAiEmbeddingModel(api,
53-
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
49+
public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
50+
return new MistralAiEmbeddingModel(mistralAiApi);
5451
}
5552

5653
@Bean

0 commit comments

Comments
 (0)