Skip to content

Commit c1a88a3

Browse files
authored
Fix OllamaAgentIntegrationTest and AIAgentIntegrationTest (#629)
1 parent 4d6604e commit c1a88a3

File tree

3 files changed

+45
-26
lines changed

3 files changed

+45
-26
lines changed

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/AIAgentIntegrationTest.kt

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ import ai.koog.prompt.message.Message
4444
import ai.koog.prompt.message.ResponseMetaInfo
4545
import ai.koog.prompt.params.LLMParams
4646
import ai.koog.prompt.params.LLMParams.ToolChoice
47-
import kotlinx.coroutines.runBlocking
4847
import kotlinx.coroutines.test.runTest
4948
import kotlinx.serialization.Serializable
5049
import org.junit.jupiter.api.Assumptions.assumeTrue
@@ -315,16 +314,16 @@ class AIAgentIntegrationTest {
315314
}
316315

317316
@BeforeTest
318-
fun setupTest() = runBlocking {
317+
fun setupTest() = runTest {
319318
cleanUp()
320319
}
321320

322321
@AfterTest
323-
fun teardownTest() = runBlocking {
322+
fun teardownTest() = runTest {
324323
cleanUp()
325324
}
326325

327-
private fun runMultipleToolsTest(model: LLModel, runMode: ToolCalls) = runBlocking {
326+
private fun runMultipleToolsTest(model: LLModel, runMode: ToolCalls) = runTest(timeout = 300.seconds) {
328327
Models.assumeAvailable(model.provider)
329328
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
330329

@@ -337,8 +336,9 @@ class AIAgentIntegrationTest {
337336
getSingleRunAgentWithRunMode(model, runMode, eventHandlerConfig = eventHandlerConfig)
338337
multiToolAgent.run(twoToolsPrompt)
339338

340-
assertTrue(
341-
parallelToolCalls.size == 2,
339+
assertEquals(
340+
2,
341+
parallelToolCalls.size,
342342
"There should be exactly 2 tool calls in a Multiple tool calls scenario"
343343
)
344344
assertTrue(
@@ -359,14 +359,14 @@ class AIAgentIntegrationTest {
359359
)
360360
}
361361

362-
assertTrue(firstCall.tool == CalculatorTool.name, "First tool call should be ${CalculatorTool.name}")
363-
assertTrue(secondCall.tool == DelayTool.name, "Second tool call should be ${DelayTool.name}")
362+
assertEquals(CalculatorTool.name, firstCall.tool, "First tool call should be ${CalculatorTool.name}")
363+
assertEquals(DelayTool.name, secondCall.tool, "Second tool call should be ${DelayTool.name}")
364364
}
365365
}
366366

367367
@ParameterizedTest
368368
@MethodSource("openAIModels", "anthropicModels", "googleModels")
369-
fun integration_AIAgentShouldNotCallToolsByDefault(model: LLModel) = runBlocking {
369+
fun integration_AIAgentShouldNotCallToolsByDefault(model: LLModel) = runTest {
370370
Models.assumeAvailable(model.provider)
371371
withRetry {
372372
val executor = getExecutor(model)
@@ -387,7 +387,7 @@ class AIAgentIntegrationTest {
387387

388388
@ParameterizedTest
389389
@MethodSource("openAIModels", "anthropicModels", "googleModels")
390-
fun integration_AIAgentShouldCallCustomTool(model: LLModel) = runBlocking {
390+
fun integration_AIAgentShouldCallCustomTool(model: LLModel) = runTest {
391391
Models.assumeAvailable(model.provider)
392392
val systemPromptForSmallLLM = systemPrompt + "You MUST use tools."
393393
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
@@ -426,7 +426,7 @@ class AIAgentIntegrationTest {
426426

427427
@ParameterizedTest
428428
@MethodSource("modelsWithVisionCapability")
429-
fun integration_AIAgentWithImageCapabilityTest(model: LLModel) = runTest(timeout = 120.seconds) {
429+
fun integration_AIAgentWithImageCapabilityTest(model: LLModel) = runTest(timeout = 300.seconds) {
430430
Models.assumeAvailable(model.provider)
431431
assumeTrue(model.capabilities.contains(LLMCapability.Vision.Image), "Model must support vision capability")
432432

@@ -477,7 +477,7 @@ class AIAgentIntegrationTest {
477477

478478
@ParameterizedTest
479479
@MethodSource("openAIModels", "anthropicModels", "googleModels")
480-
fun integration_testRequestLLMWithoutToolsTest(model: LLModel) = runTest(timeout = 120.seconds) {
480+
fun integration_testRequestLLMWithoutToolsTest(model: LLModel) = runTest(timeout = 180.seconds) {
481481
Models.assumeAvailable(model.provider)
482482
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
483483

@@ -632,8 +632,9 @@ class AIAgentIntegrationTest {
632632
}
633633
}
634634

635-
assertTrue(
636-
reasoningCallsCount == expectedReasoningCalls,
635+
assertEquals(
636+
expectedReasoningCalls,
637+
reasoningCallsCount,
637638
"With reasoningInterval=$interval and ${toolExecutionCounter.size} tool calls, " +
638639
"expected $expectedReasoningCalls reasoning calls but got $reasoningCallsCount"
639640
)
@@ -642,7 +643,7 @@ class AIAgentIntegrationTest {
642643

643644
@ParameterizedTest
644645
@MethodSource("openAIModels", "anthropicModels", "googleModels")
645-
fun integration_AgentCreateAndRestoreTest(model: LLModel) = runTest(timeout = 120.seconds) {
646+
fun integration_AgentCreateAndRestoreTest(model: LLModel) = runTest(timeout = 180.seconds) {
646647
val checkpointStorageProvider = InMemoryPersistencyStorageProvider("integration_AgentCreateAndRestoreTest")
647648
val sayHello = "Hello World!"
648649
val hello = "Hello"
@@ -730,7 +731,7 @@ class AIAgentIntegrationTest {
730731

731732
@ParameterizedTest
732733
@MethodSource("openAIModels", "anthropicModels", "googleModels")
733-
fun integration_AgentCheckpointRollbackTest(model: LLModel) = runTest(timeout = 120.seconds) {
734+
fun integration_AgentCheckpointRollbackTest(model: LLModel) = runTest(timeout = 180.seconds) {
734735
val checkpointStorageProvider = InMemoryPersistencyStorageProvider("integration_AgentCheckpointRollbackTest")
735736

736737
val hello = "Hello"
@@ -845,7 +846,7 @@ class AIAgentIntegrationTest {
845846

846847
@ParameterizedTest
847848
@MethodSource("openAIModels", "anthropicModels", "googleModels")
848-
fun integration_AgentCheckpointContinuousPersistenceTest(model: LLModel) = runTest(timeout = 120.seconds) {
849+
fun integration_AgentCheckpointContinuousPersistenceTest(model: LLModel) = runTest(timeout = 180.seconds) {
849850
val checkpointStorageProvider =
850851
InMemoryPersistencyStorageProvider("integration_AgentCheckpointContinuousPersistenceTest")
851852

@@ -922,7 +923,7 @@ class AIAgentIntegrationTest {
922923

923924
@ParameterizedTest
924925
@MethodSource("openAIModels", "anthropicModels", "googleModels")
925-
fun integration_AgentCheckpointStorageProvidersTest(model: LLModel) = runTest(timeout = 120.seconds) {
926+
fun integration_AgentCheckpointStorageProvidersTest(model: LLModel) = runTest(timeout = 180.seconds) {
926927
val strategyName = "storage-providers-strategy"
927928

928929
val hello = "Hello"
@@ -991,7 +992,7 @@ class AIAgentIntegrationTest {
991992

992993
@ParameterizedTest
993994
@MethodSource("openAIModels", "anthropicModels", "googleModels")
994-
fun integration_AgentWithToolsWithoutParamsTest(model: LLModel) = runTest(timeout = 120.seconds) {
995+
fun integration_AgentWithToolsWithoutParamsTest(model: LLModel) = runTest(timeout = 180.seconds) {
995996
assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools")
996997
val flakyModels = listOf(
997998
GoogleModels.Gemini2_0Flash.id,
@@ -1045,7 +1046,7 @@ class AIAgentIntegrationTest {
10451046

10461047
@ParameterizedTest
10471048
@MethodSource("openAIModels", "anthropicModels", "googleModels")
1048-
fun integration_ParallelNodesExecutionTest(model: LLModel) = runTest(timeout = 120.seconds) {
1049+
fun integration_ParallelNodesExecutionTest(model: LLModel) = runTest(timeout = 180.seconds) {
10491050
Models.assumeAvailable(model.provider)
10501051

10511052
val parallelStrategy = strategy<String, String>("parallel-nodes-strategy") {
@@ -1120,7 +1121,7 @@ class AIAgentIntegrationTest {
11201121

11211122
@ParameterizedTest
11221123
@MethodSource("openAIModels", "anthropicModels", "googleModels")
1123-
fun integration_ParallelNodesWithSelectionTest(model: LLModel) = runTest(timeout = 120.seconds) {
1124+
fun integration_ParallelNodesWithSelectionTest(model: LLModel) = runTest(timeout = 180.seconds) {
11241125
Models.assumeAvailable(model.provider)
11251126

11261127
val selectionStrategy = strategy<String, String>("parallel-selection-strategy") {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/OllamaAgentIntegrationTest.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class OllamaAgentIntegrationTest {
8686
val definePrompt by node<Unit, Unit> {
8787
llm.writeSession {
8888
model = OllamaModels.Meta.LLAMA_3_2
89-
rewritePrompt {
89+
updatePrompt {
9090
prompt("test-ollama") {
9191
system(
9292
""""
@@ -147,7 +147,7 @@ class OllamaAgentIntegrationTest {
147147
agentConfig = AIAgentConfig(
148148
prompt("test-ollama", LLMParams(temperature = 0.0)) {},
149149
model,
150-
15
150+
20
151151
),
152152
toolRegistry = toolRegistry
153153
) {
@@ -173,7 +173,7 @@ class OllamaAgentIntegrationTest {
173173
promptsAndResponses.add("RESPONSE: $responseText")
174174
}
175175

176-
onAgentFinished { eventContext ->
176+
onAgentFinished { _ ->
177177
println("Agent execution finished")
178178
}
179179
}

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/annotations/RetryExtension.kt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,18 @@ class RetryExtension : InvocationInterceptor {
4545
}
4646

4747
var lastException: Throwable? = null
48+
var attempt = 0
4849

49-
for (attempt in 1..retry.times) {
50+
while
51+
(attempt < retry.times) {
52+
attempt++
5053
try {
5154
println("[DEBUG_LOG] Test '${extensionContext.displayName}' - attempt $attempt of ${retry.times}")
52-
invocation.proceed()
55+
if (attempt == 1) {
56+
invocation.proceed()
57+
} else {
58+
invokeTestMethodDirectly(invocationContext, extensionContext)
59+
}
5360
println("[DEBUG_LOG] Test '${extensionContext.displayName}' succeeded on attempt $attempt")
5461
return
5562
} catch (throwable: Throwable) {
@@ -93,4 +100,15 @@ class RetryExtension : InvocationInterceptor {
93100

94101
throw lastException!!
95102
}
103+
104+
/**
 * Invokes the intercepted test method reflectively on the current test instance.
 *
 * The retry loop calls this for every attempt after the first one; the first attempt
 * goes through `invocation.proceed()` instead.
 *
 * `Method.invoke` reports anything thrown by the invoked method wrapped in an
 * `InvocationTargetException`. The wrapper is stripped here so that the caller observes the
 * same throwable type it would get from `invocation.proceed()` (assertion failures, aborted
 * assumptions, etc.), keeping retry attempts consistent with the first attempt.
 *
 * @param invocationContext supplies the test method ([Method]) and the arguments to pass to it.
 * @param extensionContext supplies the test instance the method is invoked on.
 * @throws Throwable whatever the test method itself throws, unwrapped.
 */
private fun invokeTestMethodDirectly(
    invocationContext: ReflectiveInvocationContext<Method>,
    extensionContext: ExtensionContext
) {
    val testInstance = extensionContext.requiredTestInstance
    val testMethod = invocationContext.executable
    val arguments = invocationContext.arguments

    try {
        testMethod.invoke(testInstance, *arguments.toTypedArray())
    } catch (wrapped: java.lang.reflect.InvocationTargetException) {
        // Surface the test's own failure rather than the reflection wrapper; fall back to the
        // wrapper only in the degenerate case where it carries no target.
        throw wrapped.targetException ?: wrapped
    }
}
96114
}

0 commit comments

Comments
 (0)