diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index bd1bba21..50f3a510 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -142,7 +142,10 @@ public class StreamableHttpClientTransport( ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> runCatching { McpJson.decodeFromString(json) } .onSuccess { _onMessage(it) } - .onFailure(_onError) + .onFailure { + _onError(it) + throw it + } } ContentType.Text.EventStream -> handleInlineSse( @@ -313,7 +316,10 @@ public class StreamableHttpClientTransport( _onMessage(msg) } } - .onFailure(_onError) + .onFailure { + _onError(it) + throw it + } } // reset id = null diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt index 12d20905..fdd895c6 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt @@ -12,13 +12,16 @@ import io.ktor.http.HttpStatusCode import io.ktor.http.content.TextContent import io.ktor.http.headersOf import io.ktor.utils.io.ByteReadChannel +import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.RequestId import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.delay import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject @@ -27,6 +30,7 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNull import kotlin.test.assertTrue +import kotlin.test.fail import kotlin.time.Duration.Companion.seconds class StreamableHttpClientTransportTest { @@ -380,4 +384,44 @@ class StreamableHttpClientTransportTest { assertEquals("resume-100", resumptionTokenReceived) transport.close() } + + @Test + fun testClientConnectWithInvalidJson() = runTest { + // Transport under test: respond with invalid JSON for the initialize request + val transport = createTransport { _ -> + respond( + "this is not valid json", + status = HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + + val client = Client( + clientInfo = Implementation( + name = "test-client", + version = "1.0", + ), + ) + + runCatching { + // Real time-keeping is needed; otherwise Protocol will always throw TimeoutCancellationException in tests + withTimeout(5.seconds) { + client.connect(transport) + } + + }.onSuccess { + fail("Expected client.connect to fail on invalid JSON response") + }.onFailure { e -> + when (e) { + is TimeoutCancellationException -> fail("Client connect caused a hang", e) + is IllegalStateException -> { + // Expected behavior: connect finishes and fails with an exception. + } + + else -> fail("Unexpected exception during client.connect", e) + } + }.also { + transport.close() + } + } }