Skip to content

Commit fdde3f8

Browse files
committed
Fix websocket ktor server implementation, add test and logs
1 parent db729db commit fdde3f8

File tree

13 files changed

+917
-76
lines changed

13 files changed

+917
-76
lines changed

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ public open class Client(private val clientInfo: Implementation, options: Client
156156
serverVersion = result.serverInfo
157157

158158
notification(InitializedNotification())
159+
} catch (error: CancellationException) {
160+
throw IllegalStateException("Error connecting to transport: ${error.message}")
159161
} catch (error: Throwable) {
162+
logger.error(error) { "Failed to initialize client" }
160163
close()
161-
if (error !is CancellationException) {
162-
throw IllegalStateException("Error connecting to transport: ${error.message}")
163-
}
164164

165165
throw error
166166
}

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3-
import io.ktor.client.HttpClient
4-
import io.ktor.client.plugins.websocket.webSocketSession
5-
import io.ktor.client.request.HttpRequestBuilder
6-
import io.ktor.client.request.header
7-
import io.ktor.http.HttpHeaders
8-
import io.ktor.websocket.WebSocketSession
3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import io.ktor.client.*
5+
import io.ktor.client.plugins.websocket.*
6+
import io.ktor.client.request.*
7+
import io.ktor.http.*
8+
import io.ktor.websocket.*
99
import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL
1010
import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport
1111
import kotlin.properties.Delegates
1212

13+
private val logger = KotlinLogging.logger {}
14+
1315
/**
1416
* Client transport for WebSocket: this will connect to a server over the WebSocket protocol.
1517
*/
@@ -21,6 +23,8 @@ public class WebSocketClientTransport(
2123
override var session: WebSocketSession by Delegates.notNull()
2224

2325
override suspend fun initializeSession() {
26+
logger.debug { "Websocket session initialization started..." }
27+
2428
session = urlString?.let {
2529
client.webSocketSession(it) {
2630
requestBuilder()
@@ -32,5 +36,7 @@ public class WebSocketClientTransport(
3236

3337
header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL)
3438
}
39+
40+
logger.debug { "Websocket session initialization finished" }
3541
}
3642
}

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.ktor.client.HttpClient
45
import io.ktor.client.request.HttpRequestBuilder
56
import io.modelcontextprotocol.kotlin.sdk.Implementation
67
import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION
78
import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME
89

10+
private val logger = KotlinLogging.logger {}
11+
12+
913
/**
1014
* Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient.
1115
*
@@ -36,6 +40,8 @@ public suspend fun HttpClient.mcpWebSocket(
3640
version = LIB_VERSION,
3741
),
3842
)
43+
logger.debug { "Client started to connect to server" }
3944
client.connect(transport)
45+
logger.debug { "Client finished to connect to server" }
4046
return client
4147
}

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ import kotlin.reflect.typeOf
4242
import kotlin.time.Duration
4343
import kotlin.time.Duration.Companion.milliseconds
4444

45-
private val LOGGER = KotlinLogging.logger { }
45+
private val logger = KotlinLogging.logger { }
4646

4747
public const val IMPLEMENTATION_NAME: String = "mcp-ktor"
4848

@@ -212,6 +212,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
212212
}
213213
}
214214

215+
logger.info { "Starting transport" }
215216
return transport.start()
216217
}
217218

@@ -229,29 +230,29 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
229230
}
230231

231232
private suspend fun onNotification(notification: JSONRPCNotification) {
232-
LOGGER.trace { "Received notification: ${notification.method}" }
233+
logger.trace { "Received notification: ${notification.method}" }
233234

234235
val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler
235236

236237
if (handler == null) {
237-
LOGGER.trace { "No handler found for notification: ${notification.method}" }
238+
logger.trace { "No handler found for notification: ${notification.method}" }
238239
return
239240
}
240241
try {
241242
handler(notification)
242243
} catch (cause: Throwable) {
243-
LOGGER.error(cause) { "Error handling notification: ${notification.method}" }
244+
logger.error(cause) { "Error handling notification: ${notification.method}" }
244245
onError(cause)
245246
}
246247
}
247248

248249
private suspend fun onRequest(request: JSONRPCRequest) {
249-
LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" }
250+
logger.trace { "Received request: ${request.method} (id: ${request.id})" }
250251

251252
val handler = requestHandlers[request.method] ?: fallbackRequestHandler
252253

253254
if (handler === null) {
254-
LOGGER.trace { "No handler found for request: ${request.method}" }
255+
logger.trace { "No handler found for request: ${request.method}" }
255256
try {
256257
transport?.send(
257258
JSONRPCResponse(
@@ -263,15 +264,15 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
263264
),
264265
)
265266
} catch (cause: Throwable) {
266-
LOGGER.error(cause) { "Error sending method not found response" }
267+
logger.error(cause) { "Error sending method not found response" }
267268
onError(cause)
268269
}
269270
return
270271
}
271272

272273
try {
273274
val result = handler(request, RequestHandlerExtra())
274-
LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
275+
logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
275276

276277
transport?.send(
277278
JSONRPCResponse(
@@ -280,7 +281,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
280281
),
281282
)
282283
} catch (cause: Throwable) {
283-
LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }
284+
logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }
284285

285286
try {
286287
transport?.send(
@@ -293,7 +294,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
293294
),
294295
)
295296
} catch (sendError: Throwable) {
296-
LOGGER.error(sendError) {
297+
logger.error(sendError) {
297298
"Failed to send error response for request: ${request.method} (id: ${request.id})"
298299
}
299300
// Optionally implement fallback behavior here
@@ -302,7 +303,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
302303
}
303304

304305
private fun onProgress(notification: ProgressNotification) {
305-
LOGGER.trace {
306+
logger.trace {
306307
"Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}"
307308
}
308309
val progress = notification.params.progress
@@ -315,7 +316,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
315316
val error = Error(
316317
"Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}",
317318
)
318-
LOGGER.error { error.message }
319+
logger.error { error.message }
319320
onError(error)
320321
return
321322
}
@@ -390,9 +391,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
390391
* Do not use this method to emit notifications! Use notification() instead.
391392
*/
392393
public suspend fun <T : RequestResult> request(request: Request, options: RequestOptions? = null): T {
393-
LOGGER.trace { "Sending request: ${request.method}" }
394+
logger.trace { "Sending request: ${request.method}" }
394395
val result = CompletableDeferred<T>()
395-
val transport = this@Protocol.transport ?: throw Error("Not connected")
396+
val transport = transport ?: throw Error("Not connected")
396397

397398
if (this@Protocol.options?.enforceStrictCapabilities == true) {
398399
assertCapabilityForMethod(request.method)
@@ -402,7 +403,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
402403
val messageId = message.id
403404

404405
if (options?.onProgress != null) {
405-
LOGGER.trace { "Registering progress handler for request id: $messageId" }
406+
logger.trace { "Registering progress handler for request id: $messageId" }
406407
_progressHandlers.update { current ->
407408
current.put(messageId, options.onProgress)
408409
}
@@ -452,12 +453,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
452453
val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT
453454
try {
454455
withTimeout(timeout) {
455-
LOGGER.trace { "Sending request message with id: $messageId" }
456+
logger.trace { "Sending request message with id: $messageId" }
456457
this@Protocol.transport?.send(message)
457458
}
458459
return result.await()
459460
} catch (cause: TimeoutCancellationException) {
460-
LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
461+
logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
461462
cancel(
462463
McpError(
463464
ErrorCode.Defined.RequestTimeout.code,
@@ -474,7 +475,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
474475
* Emits a notification, which is a one-way message that does not expect a response.
475476
*/
476477
public suspend fun notification(notification: Notification) {
477-
LOGGER.trace { "Sending notification: ${notification.method}" }
478+
logger.trace { "Sending notification: ${notification.method}" }
478479
val transport = this.transport ?: error("Not connected")
479480
assertNotificationCapability(notification.method)
480481

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.ktor.websocket.Frame
45
import io.ktor.websocket.WebSocketSession
56
import io.ktor.websocket.close
@@ -17,6 +18,9 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi
1718

1819
public const val MCP_SUBPROTOCOL: String = "mcp"
1920

21+
private val logger = KotlinLogging.logger {}
22+
23+
2024
/**
2125
* Abstract class representing a WebSocket transport for the Model Context Protocol (MCP).
2226
* Handles communication over a WebSocket session.
@@ -40,6 +44,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
4044
protected abstract suspend fun initializeSession()
4145

4246
override suspend fun start() {
47+
logger.debug { "Starting websocket transport" }
48+
4349
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
4450
error(
4551
"WebSocketClientTransport already started! " +
@@ -53,7 +59,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
5359
while (true) {
5460
val message = try {
5561
session.incoming.receive()
56-
} catch (_: ClosedReceiveChannelException) {
62+
} catch (e: ClosedReceiveChannelException) {
63+
logger.debug { "Closed receive channel, exiting" }
5764
return@launch
5865
}
5966

@@ -84,6 +91,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
8491
}
8592

8693
override suspend fun send(message: JSONRPCMessage) {
94+
logger.debug { "Sending message" }
8795
if (!initialized.load()) {
8896
error("Not connected")
8997
}
@@ -96,6 +104,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
96104
error("Not connected")
97105
}
98106

107+
logger.debug { "Closing websocket session" }
99108
session.close()
100109
session.coroutineContext.job.join()
101110
}

0 commit comments

Comments
 (0)