diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index ea63150f..e21696bc 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -157,9 +157,11 @@ public open class Client(private val clientInfo: Implementation, options: Client notification(InitializedNotification()) } catch (error: Throwable) { + logger.error(error) { "Failed to initialize client: ${error.message}" } close() + if (error !is CancellationException) { - throw IllegalStateException("Error connecting to transport: ${error.message}") + throw IllegalStateException("Error connecting to transport: ${error.message}", error) } throw error diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index 950f37fa..c0bcd546 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.plugins.sse.ClientSSESession import io.ktor.client.plugins.sse.sseSession @@ -46,6 +47,8 @@ public class SseClientTransport( private val reconnectionTime: Duration? = null, private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val initialized: AtomicBoolean = AtomicBoolean(false) private val endpoint = CompletableDeferred() @@ -111,6 +114,8 @@ public class SseClientTransport( val text = response.bodyAsText() error("Error POSTing to endpoint (HTTP ${response.status}): $text") } + + logger.debug { "Client successfully sent message via SSE $endpoint" } } catch (e: Throwable) { _onError(e) throw e @@ -158,6 +163,7 @@ public class SseClientTransport( val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData val endpointUrl = Url("$baseUrl/$path") endpoint.complete(endpointUrl.toString()) + logger.debug { "Client connected to endpoint: $endpointUrl" } } catch (e: Throwable) { _onError(e) endpoint.completeExceptionally(e) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt index 45719073..ec3f9470 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketClientTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.plugins.websocket.webSocketSession import io.ktor.client.request.HttpRequestBuilder @@ -10,6 +11,8 @@ import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport import kotlin.properties.Delegates +private val logger = KotlinLogging.logger {} + /** * Client transport for WebSocket: this will connect to a server over the WebSocket protocol. */ @@ -21,6 +24,8 @@ public class WebSocketClientTransport( override var session: WebSocketSession by Delegates.notNull() override suspend fun initializeSession() { + logger.debug { "Websocket session initialization started..." } + session = urlString?.let { client.webSocketSession(it) { requestBuilder() @@ -32,5 +37,7 @@ public class WebSocketClientTransport( header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL) } + + logger.debug { "Websocket session initialization finished" } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt index 77062ab1..86479533 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt @@ -1,11 +1,14 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import io.ktor.client.request.HttpRequestBuilder import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +private val logger = KotlinLogging.logger {} + /** * Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient. * @@ -36,6 +39,8 @@ public suspend fun HttpClient.mcpWebSocket( version = LIB_VERSION, ), ) + logger.debug { "Client started to connect to server" } client.connect(transport) + logger.debug { "Client finished to connect to server" } return client } diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index ed9a268c..333b9f8c 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -2328,6 +2328,35 @@ public final class io/modelcontextprotocol/kotlin/sdk/ResourceListChangedNotific public final fun serializer ()Lkotlinx/serialization/KSerializer; } +public final class io/modelcontextprotocol/kotlin/sdk/ResourceReference : io/modelcontextprotocol/kotlin/sdk/Reference { + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ResourceReference$Companion; + public static final field TYPE Ljava/lang/String; + public fun (Ljava/lang/String;)V + public final fun component1 ()Ljava/lang/String; + public final fun copy (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/ResourceReference; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;Ljava/lang/String;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/ResourceReference; + public fun equals (Ljava/lang/Object;)Z + public fun getType ()Ljava/lang/String; + public final fun getUri ()Ljava/lang/String; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final synthetic class io/modelcontextprotocol/kotlin/sdk/ResourceReference$$serializer : kotlinx/serialization/internal/GeneratedSerializer { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/ResourceReference$$serializer; + public final fun childSerializers ()[Lkotlinx/serialization/KSerializer; + public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/ResourceReference; + public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object; + public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor; + public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;)V + public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V + public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/ResourceReference$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/modelcontextprotocol/kotlin/sdk/ResourceTemplate { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ResourceTemplate$Companion; public fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Annotations;)V diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 580c3982..b05cbd57 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -42,7 +42,7 @@ import kotlin.reflect.typeOf import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds -private val LOGGER = KotlinLogging.logger { } +private val logger = KotlinLogging.logger { } public const val IMPLEMENTATION_NAME: String = "mcp-ktor" @@ -212,6 +212,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } } + logger.info { "Starting transport" } return transport.start() } @@ -229,29 +230,29 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } private suspend fun onNotification(notification: JSONRPCNotification) { - LOGGER.trace { "Received notification: ${notification.method}" } + logger.trace { "Received notification: ${notification.method}" } val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler if (handler == null) { - LOGGER.trace { "No handler found for notification: ${notification.method}" } + logger.trace { "No handler found for notification: ${notification.method}" } return } try { handler(notification) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling notification: ${notification.method}" } + logger.error(cause) { "Error handling notification: ${notification.method}" } onError(cause) } } private suspend fun onRequest(request: JSONRPCRequest) { - LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" } + logger.trace { "Received request: ${request.method} (id: ${request.id})" } val handler = requestHandlers[request.method] ?: fallbackRequestHandler if (handler === null) { - LOGGER.trace { "No handler found for request: ${request.method}" } + logger.trace { "No handler found for request: ${request.method}" } try { transport?.send( JSONRPCResponse( @@ -263,7 +264,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio ), ) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error sending method not found response" } + logger.error(cause) { "Error sending method not found response" } onError(cause) } return @@ -271,7 +272,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio try { val result = handler(request, RequestHandlerExtra()) - LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } + logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } transport?.send( JSONRPCResponse( @@ -280,7 +281,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio ), ) } catch (cause: Throwable) { - LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } + logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } try { transport?.send( @@ -293,7 +294,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio ), ) } catch (sendError: Throwable) { - LOGGER.error(sendError) { + logger.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } // Optionally implement fallback behavior here @@ -302,7 +303,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } private fun onProgress(notification: ProgressNotification) { - LOGGER.trace { + logger.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" } val progress = notification.params.progress @@ -315,7 +316,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio val error = Error( "Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}", ) - LOGGER.error { error.message } + logger.error { error.message } onError(error) return } @@ -390,9 +391,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio * Do not use this method to emit notifications! Use notification() instead. */ public suspend fun request(request: Request, options: RequestOptions? = null): T { - LOGGER.trace { "Sending request: ${request.method}" } + logger.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() - val transport = this@Protocol.transport ?: throw Error("Not connected") + val transport = transport ?: throw Error("Not connected") if (this@Protocol.options?.enforceStrictCapabilities == true) { assertCapabilityForMethod(request.method) @@ -402,7 +403,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio val messageId = message.id if (options?.onProgress != null) { - LOGGER.trace { "Registering progress handler for request id: $messageId" } + logger.trace { "Registering progress handler for request id: $messageId" } _progressHandlers.update { current -> current.put(messageId, options.onProgress) } @@ -452,12 +453,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT try { withTimeout(timeout) { - LOGGER.trace { "Sending request message with id: $messageId" } + logger.trace { "Sending request message with id: $messageId" } this@Protocol.transport?.send(message) } return result.await() } catch (cause: TimeoutCancellationException) { - LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } + logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } cancel( McpError( ErrorCode.Defined.RequestTimeout.code, @@ -474,7 +475,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio * Emits a notification, which is a one-way message that does not expect a response. */ public suspend fun notification(notification: Notification) { - LOGGER.trace { "Sending notification: ${notification.method}" } + logger.trace { "Sending notification: ${notification.method}" } val transport = this.transport ?: error("Not connected") assertNotificationCapability(notification.method) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index 4a936768..85c60c05 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.websocket.Frame import io.ktor.websocket.WebSocketSession import io.ktor.websocket.close @@ -17,6 +18,8 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi public const val MCP_SUBPROTOCOL: String = "mcp" +private val logger = KotlinLogging.logger {} + /** * Abstract class representing a WebSocket transport for the Model Context Protocol (MCP). * Handles communication over a WebSocket session. @@ -40,6 +43,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { protected abstract suspend fun initializeSession() override suspend fun start() { + logger.debug { "Starting websocket transport" } + if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error( "WebSocketClientTransport already started! " + @@ -53,7 +58,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { while (true) { val message = try { session.incoming.receive() - } catch (_: ClosedReceiveChannelException) { + } catch (e: ClosedReceiveChannelException) { + logger.debug { "Closed receive channel, exiting" } return@launch } @@ -84,6 +90,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } override suspend fun send(message: JSONRPCMessage) { + logger.debug { "Sending message" } if (!initialized.load()) { error("Not connected") } @@ -96,6 +103,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { error("Not connected") } + logger.debug { "Closing websocket session" } session.close() session.coroutineContext.job.join() } diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 7e2ed4e1..93605a33 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -48,7 +48,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredTool { public fun toString ()Ljava/lang/String; } -public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { +public class io/modelcontextprotocol/kotlin/sdk/server/Server { public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;)V public final fun addPrompt (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V public final fun addPrompt (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;)V @@ -61,35 +61,20 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextp public final fun addTool (Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Input;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Output;Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations;Lkotlin/jvm/functions/Function2;)V public static synthetic fun addTool$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Input;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Tool$Output;Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V public final fun addTools (Ljava/util/List;)V - protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V - protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V - public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V - public final fun createElicitation (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; - public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; - public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; - public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun getPrompts ()Ljava/util/Map; public final fun getResources ()Ljava/util/Map; public final fun getTools ()Ljava/util/Map; - public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; - public fun onClose ()V public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onConnect (Lkotlin/jvm/functions/Function0;)V public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V - public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun removePrompt (Ljava/lang/String;)Z public final fun removePrompts (Ljava/util/List;)I public final fun removeResource (Ljava/lang/String;)Z public final fun removeResources (Ljava/util/List;)I public final fun removeTool (Ljava/lang/String;)Z public final fun removeTools (Ljava/util/List;)I - public final fun sendLoggingMessage (Lio/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun sendPromptListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun sendResourceListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun sendResourceUpdated (Lio/modelcontextprotocol/kotlin/sdk/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { @@ -98,6 +83,30 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/ public final fun getCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities; } +public class io/modelcontextprotocol/kotlin/sdk/server/ServerSession : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;)V + protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V + public final fun createElicitation (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createElicitation$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/CreateElicitationRequest$RequestedSchema;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun createMessage (Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; + public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public fun onClose ()V + public final fun onClose (Lkotlin/jvm/functions/Function0;)V + public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V + public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendLoggingMessage (Lio/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendPromptListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendResourceUpdated (Lio/modelcontextprotocol/kotlin/sdk/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public fun (Ljava/lang/String;Lio/ktor/server/sse/ServerSSESession;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -116,8 +125,13 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor } public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { + public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V + public static final fun mcpWebSocket (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function0;)V public static synthetic fun mcpWebSocket$default (Lio/ktor/server/routing/Route;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V public static synthetic fun mcpWebSocket$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V public static final fun mcpWebSocketTransport (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function2;)V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 57bae05f..934ba049 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -13,11 +13,29 @@ import io.ktor.server.routing.routing import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession import io.ktor.server.sse.sse -import io.ktor.util.collections.ConcurrentMap import io.ktor.utils.io.KtorDsl +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentMap +import kotlinx.collections.immutable.toPersistentMap private val logger = KotlinLogging.logger {} +internal class SseTransportManager(transports: Map = emptyMap()) { + private val transports: AtomicRef> = atomic(transports.toPersistentMap()) + + fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId] + + fun addTransport(transport: SseServerTransport) { + transports.update { it.put(transport.sessionId, transport) } + } + + fun removeTransport(sessionId: String) { + transports.update { it.remove(sessionId) } + } +} + @KtorDsl public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) { route(path) { @@ -25,85 +43,76 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) { } } -/** - * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). - */ +/* +* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). +*/ @KtorDsl public fun Routing.mcp(block: ServerSSESession.() -> Server) { - val transports = ConcurrentMap() + val sseTransportManager = SseTransportManager() sse { - mcpSseEndpoint("", transports, block) + mcpSseEndpoint("", sseTransportManager, block) } post { - mcpPostEndpoint(transports) + mcpPostEndpoint(sseTransportManager) } } @Suppress("FunctionName") -@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.WARNING) +@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.ERROR) public fun Application.MCP(block: ServerSSESession.() -> Server) { mcp(block) } @KtorDsl public fun Application.mcp(block: ServerSSESession.() -> Server) { - val transports = ConcurrentMap() - install(SSE) routing { - sse("/sse") { - mcpSseEndpoint("/message", transports, block) - } - - post("/message") { - mcpPostEndpoint(transports) - } + mcp(block) } } -private suspend fun ServerSSESession.mcpSseEndpoint( +internal suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, - transports: ConcurrentMap, + sseTransportManager: SseTransportManager, block: ServerSSESession.() -> Server, ) { - val transport = mcpSseTransport(postEndpoint, transports) + val transport = mcpSseTransport(postEndpoint, sseTransportManager) val server = block() server.onClose { logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } - transports.remove(transport.sessionId) + sseTransportManager.removeTransport(transport.sessionId) } server.connect(transport) + logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } } internal fun ServerSSESession.mcpSseTransport( postEndpoint: String, - transports: ConcurrentMap, + sseTransportManager: SseTransportManager, ): SseServerTransport { val transport = SseServerTransport(postEndpoint, this) - transports[transport.sessionId] = transport - + sseTransportManager.addTransport(transport) logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } return transport } -internal suspend fun RoutingContext.mcpPostEndpoint(transports: ConcurrentMap) { - val sessionId: String = call.request.queryParameters["sessionId"] - ?: run { - call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") - return - } +internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTransportManager) { + val sessionId: String = call.request.queryParameters["sessionId"] ?: run { + call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") + return + } logger.debug { "Received message for sessionId: $sessionId" } - val transport = transports[sessionId] + val transport = sseTransportManager.getTransport(sessionId) if (transport == null) { logger.warn { "Session not found for sessionId: $sessionId" } call.respond(HttpStatusCode.NotFound, "Session not found") diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index ac71b5fe..3e5aad8d 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -3,58 +3,35 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema -import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult -import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest -import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult -import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject -import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.InitializeRequest -import io.modelcontextprotocol.kotlin.sdk.InitializeResult -import io.modelcontextprotocol.kotlin.sdk.InitializedNotification -import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION import io.modelcontextprotocol.kotlin.sdk.ListPromptsRequest import io.modelcontextprotocol.kotlin.sdk.ListPromptsResult import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult -import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest -import io.modelcontextprotocol.kotlin.sdk.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.ListToolsResult -import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.Method -import io.modelcontextprotocol.kotlin.sdk.PingRequest import io.modelcontextprotocol.kotlin.sdk.Prompt import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult import io.modelcontextprotocol.kotlin.sdk.Resource -import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification -import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations -import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification -import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions -import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import io.modelcontextprotocol.kotlin.sdk.shared.Transport import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus +import kotlinx.collections.immutable.persistentListOf import kotlinx.collections.immutable.persistentMapOf import kotlinx.collections.immutable.toPersistentSet -import kotlinx.coroutines.CompletableDeferred -import kotlinx.serialization.json.JsonObject private val logger = KotlinLogging.logger {} @@ -77,26 +54,17 @@ public class ServerOptions(public val capabilities: ServerCapabilities, enforceS * @param serverInfo Information about this server implementation (name, version). * @param options Configuration options for the server. */ -public open class Server(private val serverInfo: Implementation, options: ServerOptions) : Protocol(options) { +public open class Server(private val serverInfo: Implementation, private val options: ServerOptions) { + private val sessions = atomic(persistentListOf()) + @Suppress("ktlint:standard:backing-property-naming") private var _onInitialized: (() -> Unit) = {} @Suppress("ktlint:standard:backing-property-naming") - private var _onClose: () -> Unit = {} - - /** - * The client's reported capabilities after initialization. - */ - public var clientCapabilities: ClientCapabilities? = null - private set - - /** - * The client's version information after initialization. - */ - public var clientVersion: Implementation? = null - private set + private var _onConnect: (() -> Unit) = {} - private val capabilities: ServerCapabilities = options.capabilities + @Suppress("ktlint:standard:backing-property-naming") + private var _onClose: () -> Unit = {} private val _tools = atomic(persistentMapOf()) private val _prompts = atomic(persistentMapOf()) @@ -108,55 +76,83 @@ public open class Server(private val serverInfo: Implementation, options: Server public val resources: Map get() = _resources.value - init { - logger.debug { "Initializing MCP server with capabilities: $capabilities" } + public suspend fun close() { + logger.debug { "Closing MCP server" } + sessions.value.forEach { it.close() } + _onClose() + } - // Core protocol handlers - setRequestHandler(Method.Defined.Initialize) { request, _ -> - handleInitialize(request) - } - setNotificationHandler(Method.Defined.NotificationsInitialized) { - _onInitialized() - CompletableDeferred(Unit) - } + /** + * Starts a new server session with the given transport and initializes + * internal request handlers based on the server's capabilities. + * + * @param transport The transport layer to connect the session with. + * @return The initialized and connected server session. + */ + public suspend fun connect(transport: Transport): ServerSession { + val session = ServerSession(serverInfo, options) // Internal handlers for tools - if (capabilities.tools != null) { - setRequestHandler(Method.Defined.ToolsList) { _, _ -> + if (options.capabilities.tools != null) { + session.setRequestHandler(Method.Defined.ToolsList) { _, _ -> handleListTools() } - setRequestHandler(Method.Defined.ToolsCall) { request, _ -> + session.setRequestHandler(Method.Defined.ToolsCall) { request, _ -> handleCallTool(request) } } // Internal handlers for prompts - if (capabilities.prompts != null) { - setRequestHandler(Method.Defined.PromptsList) { _, _ -> + if (options.capabilities.prompts != null) { + session.setRequestHandler(Method.Defined.PromptsList) { _, _ -> handleListPrompts() } - setRequestHandler(Method.Defined.PromptsGet) { request, _ -> + session.setRequestHandler(Method.Defined.PromptsGet) { request, _ -> handleGetPrompt(request) } } // Internal handlers for resources - if (capabilities.resources != null) { - setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + if (options.capabilities.resources != null) { + session.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> handleListResources() } - setRequestHandler(Method.Defined.ResourcesRead) { request, _ -> + session.setRequestHandler(Method.Defined.ResourcesRead) { request, _ -> handleReadResource(request) } - setRequestHandler(Method.Defined.ResourcesTemplatesList) { _, _ -> + session.setRequestHandler(Method.Defined.ResourcesTemplatesList) { _, _ -> handleListResourceTemplates() } } + + logger.debug { "Server session connecting to transport" } + session.connect(transport) + logger.debug { "Server session successfully connected to transport" } + sessions.update { it.add(session) } + + _onConnect() + return session + } + + /** + * Registers a callback to be invoked when the new server session connected. + */ + public fun onConnect(block: () -> Unit) { + val old = _onConnect + _onConnect = { + old() + block() + } } /** * Registers a callback to be invoked when the server has completed initialization. */ + @Deprecated( + "Initialization moved to ServerSession, use ServerSession.onInitialized instead.", + ReplaceWith("ServerSession.onInitialized"), + DeprecationLevel.WARNING, + ) public fun onInitialized(block: () -> Unit) { val old = _onInitialized _onInitialized = { @@ -176,14 +172,6 @@ public open class Server(private val serverInfo: Implementation, options: Server } } - /** - * Called when the server connection is closing. - */ - override fun onClose() { - logger.info { "Server connection closing" } - _onClose() - } - /** * Registers a single tool. The client can then call this tool. * @@ -192,7 +180,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support tools. */ public fun addTool(tool: Tool, handler: suspend (CallToolRequest) -> CallToolResult) { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to add tool '${tool.name}': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.") } @@ -232,7 +220,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support tools. */ public fun addTools(toolsToAdd: List) { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to add tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -248,7 +236,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support tools. */ public fun removeTool(name: String): Boolean { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to remove tool '$name': Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -275,7 +263,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support tools. */ public fun removeTools(toolNames: List): Int { - if (capabilities.tools == null) { + if (options.capabilities.tools == null) { logger.error { "Failed to remove tools: Server does not support tools capability" } throw IllegalStateException("Server does not support tools capability.") } @@ -302,7 +290,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompt(prompt: Prompt, promptProvider: suspend (GetPromptRequest) -> GetPromptResult) { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to add prompt '${prompt.name}': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -336,7 +324,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support prompts. */ public fun addPrompts(promptsToAdd: List) { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to add prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -352,7 +340,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompt(name: String): Boolean { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to remove prompt '$name': Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -379,7 +367,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support prompts. */ public fun removePrompts(promptNames: List): Int { - if (capabilities.prompts == null) { + if (options.capabilities.prompts == null) { logger.error { "Failed to remove prompts: Server does not support prompts capability" } throw IllegalStateException("Server does not support prompts capability.") } @@ -416,7 +404,7 @@ public open class Server(private val serverInfo: Implementation, options: Server mimeType: String = "text/html", readHandler: suspend (ReadResourceRequest) -> ReadResourceResult, ) { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -436,7 +424,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support resources. */ public fun addResources(resourcesToAdd: List) { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to add resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -452,7 +440,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support resources. */ public fun removeResource(uri: String): Boolean { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to remove resource '$uri': Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -479,7 +467,7 @@ public open class Server(private val serverInfo: Implementation, options: Server * @throws IllegalStateException If the server does not support resources. */ public fun removeResources(uris: List): Int { - if (capabilities.resources == null) { + if (options.capabilities.resources == null) { logger.error { "Failed to remove resources: Server does not support resources capability" } throw IllegalStateException("Server does not support resources capability.") } @@ -499,123 +487,7 @@ public open class Server(private val serverInfo: Implementation, options: Server return removedCount } - /** - * Sends a ping request to the client to check connectivity. - * - * @return The result of the ping request. - * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. - */ - public suspend fun ping(): EmptyRequestResult = request(PingRequest()) - - /** - * Creates a message using the server's sampling capability. - * - * @param params The parameters for creating a message. - * @param options Optional request options. - * @return The created message result. - * @throws IllegalStateException If the server does not support sampling or if the request fails. - */ - public suspend fun createMessage( - params: CreateMessageRequest, - options: RequestOptions? = null, - ): CreateMessageResult { - logger.debug { "Creating message with params: $params" } - return request(params, options) - } - - /** - * Lists the available "roots" from the client's perspective (if supported). - * - * @param params JSON parameters for the request, usually empty. - * @param options Optional request options. - * @return The list of roots. - * @throws IllegalStateException If the server or client does not support roots. - */ - public suspend fun listRoots( - params: JsonObject = EmptyJsonObject, - options: RequestOptions? = null, - ): ListRootsResult { - logger.debug { "Listing roots with params: $params" } - return request(ListRootsRequest(params), options) - } - - public suspend fun createElicitation( - message: String, - requestedSchema: RequestedSchema, - options: RequestOptions? = null, - ): CreateElicitationResult { - logger.debug { "Creating elicitation with message: $message" } - return request(CreateElicitationRequest(message, requestedSchema), options) - } - - /** - * Sends a logging message notification to the client. - * - * @param params The logging message notification parameters. - */ - public suspend fun sendLoggingMessage(params: LoggingMessageNotification) { - logger.trace { "Sending logging message: ${params.params.data}" } - notification(params) - } - - /** - * Sends a resource-updated notification to the client, indicating that a specific resource has changed. - * - * @param params Details of the updated resource. - */ - public suspend fun sendResourceUpdated(params: ResourceUpdatedNotification) { - logger.debug { "Sending resource updated notification for: ${params.params.uri}" } - notification(params) - } - - /** - * Sends a notification to the client indicating that the list of resources has changed. - */ - public suspend fun sendResourceListChanged() { - logger.debug { "Sending resource list changed notification" } - notification(ResourceListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of tools has changed. - */ - public suspend fun sendToolListChanged() { - logger.debug { "Sending tool list changed notification" } - notification(ToolListChangedNotification()) - } - - /** - * Sends a notification to the client indicating that the list of prompts has changed. - */ - public suspend fun sendPromptListChanged() { - logger.debug { "Sending prompt list changed notification" } - notification(PromptListChangedNotification()) - } - // --- Internal Handlers --- - - private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { - logger.info { "Handling initialize request from client ${request.clientInfo}" } - clientCapabilities = request.capabilities - clientVersion = request.clientInfo - - val requestedVersion = request.protocolVersion - val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { - requestedVersion - } else { - logger.warn { - "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" - } - LATEST_PROTOCOL_VERSION - } - - return InitializeResult( - protocolVersion = protocolVersion, - capabilities = capabilities, - serverInfo = serverInfo, - ) - } - private suspend fun handleListTools(): ListToolsResult { val toolList = tools.values.map { it.tool } return ListToolsResult(tools = toolList, nextCursor = null) @@ -666,147 +538,6 @@ public open class Server(private val serverInfo: Implementation, options: Server // If you have resource templates, return them here. For now, return empty. return ListResourceTemplatesResult(listOf()) } - - /** - * Asserts that the client supports the capability required for the given [method]. - * - * This method is automatically called by the [Protocol] framework before handling requests. - * Throws [IllegalStateException] if the capability is not supported. - * - * @param method The method for which we are asserting capability. - */ - override fun assertCapabilityForMethod(method: Method) { - logger.trace { "Asserting capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (clientCapabilities?.sampling == null) { - logger.error { "Client capability assertion failed: sampling not supported" } - throw IllegalStateException("Client does not support sampling (required for ${method.value})") - } - } - - "roots/list" -> { - if (clientCapabilities?.roots == null) { - throw IllegalStateException("Client does not support listing roots (required for ${method.value})") - } - } - - "elicitation/create" -> { - if (clientCapabilities?.elicitation == null) { - throw IllegalStateException("Client does not support elicitation (required for ${method.value})") - } - } - - "ping" -> { - // No specific capability required - } - } - } - - /** - * Asserts that the server can handle the specified notification method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. - * - * @param method The notification method. - */ - override fun assertNotificationCapability(method: Method) { - logger.trace { "Asserting notification capability for method: ${method.value}" } - when (method.value) { - "notifications/message" -> { - if (capabilities.logging == null) { - logger.error { "Server capability assertion failed: logging not supported" } - throw IllegalStateException("Server does not support logging (required for ${method.value})") - } - } - - "notifications/resources/updated", - "notifications/resources/list_changed", - -> { - if (capabilities.resources == null) { - throw IllegalStateException( - "Server does not support notifying about resources (required for ${method.value})", - ) - } - } - - "notifications/tools/list_changed" -> { - if (capabilities.tools == null) { - throw IllegalStateException( - "Server does not support notifying of tool list changes (required for ${method.value})", - ) - } - } - - "notifications/prompts/list_changed" -> { - if (capabilities.prompts == null) { - throw IllegalStateException( - "Server does not support notifying of prompt list changes (required for ${method.value})", - ) - } - } - - "notifications/cancelled", - "notifications/progress", - -> { - // Always allowed - } - } - } - - /** - * Asserts that the server can handle the specified request method. - * - * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. - * - * @param method The request method. - */ - override fun assertRequestHandlerCapability(method: Method) { - logger.trace { "Asserting request handler capability for method: ${method.value}" } - when (method.value) { - "sampling/createMessage" -> { - if (capabilities.sampling == null) { - logger.error { "Server capability assertion failed: sampling not supported" } - throw IllegalStateException("Server does not support sampling (required for $method)") - } - } - - "logging/setLevel" -> { - if (capabilities.logging == null) { - throw IllegalStateException("Server does not support logging (required for $method)") - } - } - - "prompts/get", - "prompts/list", - -> { - if (capabilities.prompts == null) { - throw IllegalStateException("Server does not support prompts (required for $method)") - } - } - - "resources/list", - "resources/templates/list", - "resources/read", - -> { - if (capabilities.resources == null) { - throw IllegalStateException("Server does not support resources (required for $method)") - } - } - - "tools/call", - "tools/list", - -> { - if (capabilities.tools == null) { - throw IllegalStateException("Server does not support tools (required for $method)") - } - } - - "ping", "initialize" -> { - // No capability required - } - } - } } /** diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt new file mode 100644 index 00000000..471ebf06 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -0,0 +1,371 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationRequest.RequestedSchema +import io.modelcontextprotocol.kotlin.sdk.CreateElicitationResult +import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest +import io.modelcontextprotocol.kotlin.sdk.ListRootsResult +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.Method.Defined +import io.modelcontextprotocol.kotlin.sdk.PingRequest +import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.ResourceUpdatedNotification +import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS +import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification +import io.modelcontextprotocol.kotlin.sdk.shared.Protocol +import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import kotlinx.coroutines.CompletableDeferred +import kotlinx.serialization.json.JsonObject + +private val logger = KotlinLogging.logger {} + +public open class ServerSession(private val serverInfo: Implementation, options: ServerOptions) : Protocol(options) { + @Suppress("ktlint:standard:backing-property-naming") + private var _onInitialized: (() -> Unit) = {} + + @Suppress("ktlint:standard:backing-property-naming") + private var _onClose: () -> Unit = {} + + init { + // Core protocol handlers + setRequestHandler(Method.Defined.Initialize) { request, _ -> + handleInitialize(request) + } + setNotificationHandler(Method.Defined.NotificationsInitialized) { + _onInitialized() + CompletableDeferred(Unit) + } + } + + /** + * The capabilities supported by the server, related to the session. + */ + private val serverCapabilities = options.capabilities + + /** + * The client's reported capabilities after initialization. + */ + public var clientCapabilities: ClientCapabilities? = null + private set + + /** + * The client's version information after initialization. + */ + public var clientVersion: Implementation? = null + private set + + /** + * Registers a callback to be invoked when the server has completed initialization. + */ + public fun onInitialized(block: () -> Unit) { + val old = _onInitialized + _onInitialized = { + old() + block() + } + } + + /** + * Registers a callback to be invoked when the server session is closing. + */ + public fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + /** + * Called when the server session is closing. + */ + override fun onClose() { + logger.debug { "Server connection closing" } + _onClose() + } + + /** + * Sends a ping request to the client to check connectivity. + * + * @return The result of the ping request. + * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. + */ + public suspend fun ping(): EmptyRequestResult = request(PingRequest()) + + /** + * Creates a message using the server's sampling capability. + * + * @param params The parameters for creating a message. + * @param options Optional request options. + * @return The created message result. + * @throws IllegalStateException If the server does not support sampling or if the request fails. + */ + public suspend fun createMessage( + params: CreateMessageRequest, + options: RequestOptions? = null, + ): CreateMessageResult { + logger.debug { "Creating message with params: $params" } + return request(params, options) + } + + /** + * Lists the available "roots" from the client's perspective (if supported). + * + * @param params JSON parameters for the request, usually empty. + * @param options Optional request options. + * @return The list of roots. + * @throws IllegalStateException If the server or client does not support roots. + */ + public suspend fun listRoots( + params: JsonObject = EmptyJsonObject, + options: RequestOptions? = null, + ): ListRootsResult { + logger.debug { "Listing roots with params: $params" } + return request(ListRootsRequest(params), options) + } + + public suspend fun createElicitation( + message: String, + requestedSchema: RequestedSchema, + options: RequestOptions? = null, + ): CreateElicitationResult { + logger.debug { "Creating elicitation with message: $message" } + return request(CreateElicitationRequest(message, requestedSchema), options) + } + + /** + * Sends a logging message notification to the client. + * + * @param notification The logging message notification. + */ + public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) { + logger.trace { "Sending logging message: ${notification.params.data}" } + notification(notification) + } + + /** + * Sends a resource-updated notification to the client, indicating that a specific resource has changed. + * + * @param notification Details of the updated resource. + */ + public suspend fun sendResourceUpdated(notification: ResourceUpdatedNotification) { + logger.debug { "Sending resource updated notification for: ${notification.params.uri}" } + notification(notification) + } + + /** + * Sends a notification to the client indicating that the list of resources has changed. + */ + public suspend fun sendResourceListChanged() { + logger.debug { "Sending resource list changed notification" } + notification(ResourceListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of tools has changed. + */ + public suspend fun sendToolListChanged() { + logger.debug { "Sending tool list changed notification" } + notification(ToolListChangedNotification()) + } + + /** + * Sends a notification to the client indicating that the list of prompts has changed. + */ + public suspend fun sendPromptListChanged() { + logger.debug { "Sending prompt list changed notification" } + notification(PromptListChangedNotification()) + } + + /** + * Asserts that the client supports the capability required for the given [method]. + * + * This method is automatically called by the [Protocol] framework before handling requests. + * Throws [IllegalStateException] if the capability is not supported. + * + * @param method The method for which we are asserting capability. + */ + override fun assertCapabilityForMethod(method: Method) { + logger.trace { "Asserting capability for method: ${method.value}" } + when (method) { + Defined.SamplingCreateMessage -> { + if (clientCapabilities?.sampling == null) { + logger.error { "Client capability assertion failed: sampling not supported" } + throw IllegalStateException("Client does not support sampling (required for ${method.value})") + } + } + + Defined.RootsList -> { + if (clientCapabilities?.roots == null) { + logger.error { "Client capability assertion failed: listing roots not supported" } + throw IllegalStateException("Client does not support listing roots (required for ${method.value})") + } + } + + Defined.ElicitationCreate -> { + if (clientCapabilities?.elicitation == null) { + logger.error { "Client capability assertion failed: elicitation not supported" } + throw IllegalStateException("Client does not support elicitation (required for ${method.value})") + } + } + + Defined.Ping -> { + // No specific capability required + } + + else -> { + // For notifications not specifically listed, no assertion by default + } + } + } + + /** + * Asserts that the server can handle the specified notification method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this notification. + * + * @param method The notification method. + */ + override fun assertNotificationCapability(method: Method) { + logger.trace { "Asserting notification capability for method: ${method.value}" } + when (method) { + Defined.NotificationsMessage -> { + if (serverCapabilities.logging == null) { + logger.error { "Server capability assertion failed: logging not supported" } + throw IllegalStateException("Server does not support logging (required for ${method.value})") + } + } + + Defined.NotificationsResourcesUpdated, + Defined.NotificationsResourcesListChanged, + -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException( + "Server does not support notifying about resources (required for ${method.value})", + ) + } + } + + Defined.NotificationsToolsListChanged -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException( + "Server does not support notifying of tool list changes (required for ${method.value})", + ) + } + } + + Defined.NotificationsPromptsListChanged -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException( + "Server does not support notifying of prompt list changes (required for ${method.value})", + ) + } + } + + Defined.NotificationsCancelled, + Defined.NotificationsProgress, + -> { + // Always allowed + } + + else -> { + // For notifications not specifically listed, no assertion by default + } + } + } + + /** + * Asserts that the server can handle the specified request method. + * + * Throws [IllegalStateException] if the server does not have the capabilities required to handle this request. + * + * @param method The request method. + */ + override fun assertRequestHandlerCapability(method: Method) { + logger.trace { "Asserting request handler capability for method: ${method.value}" } + when (method) { + Defined.SamplingCreateMessage -> { + if (serverCapabilities.sampling == null) { + logger.error { "Server capability assertion failed: sampling not supported" } + throw IllegalStateException("Server does not support sampling (required for $method)") + } + } + + Defined.LoggingSetLevel -> { + if (serverCapabilities.logging == null) { + throw IllegalStateException("Server does not support logging (required for $method)") + } + } + + Defined.PromptsGet, + Defined.PromptsList, + -> { + if (serverCapabilities.prompts == null) { + throw IllegalStateException("Server does not support prompts (required for $method)") + } + } + + Defined.ResourcesList, + Defined.ResourcesTemplatesList, + Defined.ResourcesRead, + Defined.ResourcesSubscribe, + Defined.ResourcesUnsubscribe, + -> { + if (serverCapabilities.resources == null) { + throw IllegalStateException("Server does not support resources (required for $method)") + } + } + + Defined.ToolsCall, + Defined.ToolsList, + -> { + if (serverCapabilities.tools == null) { + throw IllegalStateException("Server does not support tools (required for $method)") + } + } + + Defined.Ping, Defined.Initialize -> { + // No capability required + } + + else -> { + // For notifications not specifically listed, no assertion by default + } + } + } + + private suspend fun handleInitialize(request: InitializeRequest): InitializeResult { + logger.debug { "Handling initialization request from client" } + clientCapabilities = request.capabilities + clientVersion = request.clientInfo + + val requestedVersion = request.protocolVersion + val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { + requestedVersion + } else { + logger.warn { + "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" + } + LATEST_PROTOCOL_VERSION + } + + return InitializeResult( + protocolVersion = protocolVersion, + capabilities = serverCapabilities, + serverInfo = serverInfo, + ) + } +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index a3d2fd34..1ceb764b 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -1,12 +1,83 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.application.Application +import io.ktor.server.application.install import io.ktor.server.routing.Route +import io.ktor.server.routing.Routing +import io.ktor.server.routing.routing import io.ktor.server.websocket.WebSocketServerSession +import io.ktor.server.websocket.WebSockets import io.ktor.server.websocket.webSocket +import io.ktor.utils.io.CancellationException +import io.ktor.utils.io.KtorDsl import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME +import kotlinx.coroutines.awaitCancellation + +private val logger = KotlinLogging.logger {} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Routing.mcpWebSocket(block: () -> Server) { + webSocket { + mcpWebSocketEndpoint(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Routing.mcpWebSocket(path: String, block: () -> Server) { + webSocket(path) { + mcpWebSocketEndpoint(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket. + */ +@KtorDsl +public fun Application.mcpWebSocket(block: () -> Server) { + install(WebSockets) + + routing { + mcpWebSocket(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over WebSocket at the specified path. + */ +@KtorDsl +public fun Application.mcpWebSocket(path: String, block: () -> Server) { + install(WebSockets) + + routing { + mcpWebSocket(path, block) + } +} + +internal suspend fun WebSocketServerSession.mcpWebSocketEndpoint(block: () -> Server) { + logger.info { "Ktor Server establishing new connection" } + val transport = createMcpTransport(this) + val server = block() + var session: ServerSession? = null + try { + session = server.connect(transport) + awaitCancellation() + } catch (e: CancellationException) { + session?.close() + } +} + +private fun createMcpTransport(webSocketSession: WebSocketServerSession): WebSocketMcpServerTransport = + WebSocketMcpServerTransport(webSocketSession) /** * Registers a WebSocket route that establishes an MCP (Model Context Protocol) server session. @@ -14,12 +85,28 @@ import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("Routing.mcpWebSocket"), + DeprecationLevel.WARNING, +) public fun Route.mcpWebSocket(options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}) { webSocket { createMcpServer(this, options, handler) } } +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("Routing.mcpWebSocket"), + DeprecationLevel.WARNING, +) +public fun Route.mcpWebSocket(block: () -> Server) { + webSocket { + block().connect(createMcpTransport(this)) + } +} + /** * Registers a WebSocket route at the specified [path] that establishes an MCP server session. * @@ -27,6 +114,11 @@ public fun Route.mcpWebSocket(options: ServerOptions? = null, handler: suspend S * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ +@Deprecated( + "Use mcpWebSocket with a path and a lambda that returns a Server instance instead", + ReplaceWith("Routing.mcpWebSocket"), + DeprecationLevel.WARNING, +) public fun Route.mcpWebSocket(path: String, options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}) { webSocket(path) { createMcpServer(this, options, handler) @@ -38,6 +130,11 @@ public fun Route.mcpWebSocket(path: String, options: ServerOptions? = null, hand * * @param handler A suspend function that defines the behavior of the transport layer. */ +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("Routing.mcpWebSocket"), + DeprecationLevel.WARNING, +) public fun Route.mcpWebSocketTransport(handler: suspend WebSocketMcpServerTransport.() -> Unit = {}) { webSocket { val transport = createMcpTransport(this) @@ -53,6 +150,11 @@ public fun Route.mcpWebSocketTransport(handler: suspend WebSocketMcpServerTransp * @param path The URL path at which to register the WebSocket route. * @param handler A suspend function that defines the behavior of the transport layer. */ +@Deprecated( + "Use mcpWebSocket with a path and a lambda that returns a Server instance instead", + ReplaceWith("Routing.mcpWebSocket"), + DeprecationLevel.WARNING, +) public fun Route.mcpWebSocketTransport(path: String, handler: suspend WebSocketMcpServerTransport.() -> Unit = {}) { webSocket(path) { val transport = createMcpTransport(this) @@ -62,6 +164,11 @@ public fun Route.mcpWebSocketTransport(path: String, handler: suspend WebSocketM } } +@Deprecated( + "Use mcpWebSocket with a lambda that returns a Server instance instead", + ReplaceWith("mcpWebSocket"), + DeprecationLevel.WARNING, +) private suspend fun Route.createMcpServer( session: WebSocketServerSession, options: ServerOptions?, @@ -87,6 +194,3 @@ private suspend fun Route.createMcpServer( handler(server) server.close() } - -private fun createMcpTransport(session: WebSocketServerSession): WebSocketMcpServerTransport = - WebSocketMcpServerTransport(session) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt index 877fda58..0c0cc78c 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt @@ -1,10 +1,13 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.http.HttpHeaders import io.ktor.server.websocket.WebSocketServerSession import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport +private val logger = KotlinLogging.logger {} + /** * Server-side implementation of the MCP (Model Context Protocol) transport over WebSocket. * @@ -12,6 +15,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport */ public class WebSocketMcpServerTransport(override val session: WebSocketServerSession) : WebSocketMcpTransport() { override suspend fun initializeSession() { + logger.debug { "Checking session headers" } val subprotocol = session.call.request.headers[HttpHeaders.SecWebSocketProtocol] if (subprotocol != MCP_SUBPROTOCOL) { error("Invalid subprotocol: $subprotocol, expected $MCP_SUBPROTOCOL") diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index 0294620b..d132ae5f 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -30,6 +30,7 @@ import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.ServerSession import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport import kotlinx.coroutines.CompletableDeferred @@ -241,25 +242,6 @@ class ClientTest { serverOptions, ) - server.setRequestHandler(Method.Defined.Initialize) { _, _ -> - InitializeResult( - protocolVersion = LATEST_PROTOCOL_VERSION, - capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null), - ), - serverInfo = Implementation(name = "test", version = "1.0"), - ) - } - - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - ListResourcesResult(resources = emptyList(), nextCursor = null) - } - - server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> - ListToolsResult(tools = emptyList(), nextCursor = null) - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( @@ -269,15 +251,36 @@ class ClientTest { ), ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) }, ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> + InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null), + ), + serverInfo = Implementation(name = "test", version = "1.0"), + ) + } + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + ListResourcesResult(resources = emptyList(), nextCursor = null) + } + + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + ListToolsResult(tools = emptyList(), nextCursor = null) + } // Server supports resources and tools, but not prompts val caps = client.serverCapabilities assertEquals(ServerCapabilities.Resources(null, null), caps?.resources) @@ -368,24 +371,27 @@ class ClientTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") }, ).joinAll() + val serverSession = serverSessionResult.await() // These should not throw val jsonObject = buildJsonObject { put("name", "John") put("age", 30) put("isStudent", false) } - server.sendLoggingMessage( + serverSession.sendLoggingMessage( LoggingMessageNotification( params = LoggingMessageNotification.Params( level = LoggingLevel.info, @@ -393,11 +399,11 @@ class ClientTest { ), ), ) - server.sendResourceListChanged() + serverSession.sendResourceListChanged() // This should fail because the server doesn't have the tools capability val ex = assertFailsWith { - server.sendToolListChanged() + serverSession.sendToolListChanged() } assertTrue(ex.message?.contains("Server does not support notifying of tool list changes") == true) } @@ -418,19 +424,6 @@ class ClientTest { val def = CompletableDeferred() val defTimeOut = CompletableDeferred() - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - // Simulate delay - def.complete(Unit) - try { - delay(1000) - } catch (e: CancellationException) { - defTimeOut.complete(Unit) - throw e - } - ListResourcesResult(resources = emptyList()) - fail("Shouldn't have been called") - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( @@ -438,17 +431,34 @@ class ClientTest { options = ClientOptions(capabilities = ClientCapabilities()), ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") }, ).joinAll() + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate delay + def.complete(Unit) + try { + delay(1000) + } catch (e: CancellationException) { + defTimeOut.complete(Unit) + throw e + } + ListResourcesResult(resources = emptyList()) + fail("Shouldn't have been called") + } + val defCancel = CompletableDeferred() val job = launch { try { @@ -478,37 +488,40 @@ class ClientTest { ), ) - server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> - // Simulate a delayed response - // Wait ~100ms unless canceled - try { - withTimeout(100L) { - // Just delay here, if timeout is 0 on the client side, this won't return in time - delay(100) - } - } catch (_: Exception) { - // If aborted, just rethrow or return early - } - ListResourcesResult(resources = emptyList()) - } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions(capabilities = ClientCapabilities()), ) + val serverSessionResult = CompletableDeferred() + listOf( launch { client.connect(clientTransport) println("Client connected") }, launch { - server.connect(serverTransport) + serverSessionResult.complete(server.connect(serverTransport)) println("Server connected") }, ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> + // Simulate a delayed response + // Wait ~100ms unless canceled + try { + withTimeout(100L) { + // Just delay here, if timeout is 0 on the client side, this won't return in time + delay(100) + } + } catch (_: Exception) { + // If aborted, just rethrow or return early + } + ListResourcesResult(resources = emptyList()) + } + // Request with 1 msec timeout should fail immediately val ex = assertFailsWith { withTimeout(1) { @@ -559,7 +572,36 @@ class ClientTest { serverOptions, ) - server.setRequestHandler(Method.Defined.Initialize) { _, _ -> + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(sampling = EmptyJsonObject), + ), + ) + + var receivedMessage: JSONRPCMessage? = null + clientTransport.onMessage { msg -> + receivedMessage = msg + } + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + serverSession.setRequestHandler(Method.Defined.Initialize) { _, _ -> InitializeResult( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( @@ -569,6 +611,7 @@ class ClientTest { serverInfo = Implementation(name = "test", version = "1.0"), ) } + val serverListToolsResult = ListToolsResult( tools = listOf( Tool( @@ -583,33 +626,10 @@ class ClientTest { nextCursor = null, ) - server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> + serverSession.setRequestHandler(Method.Defined.ToolsList) { _, _ -> serverListToolsResult } - val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() - - val client = Client( - clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions( - capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ), - ) - - var receivedMessage: JSONRPCMessage? = null - clientTransport.onMessage { msg -> - receivedMessage = msg - } - - listOf( - launch { - client.connect(clientTransport) - }, - launch { - server.connect(serverTransport) - }, - ).joinAll() - val serverCapabilities = client.serverCapabilities assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) @@ -652,15 +672,25 @@ class ClientTest { ), ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) }, + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, ).joinAll() - val clientCapabilities = server.clientCapabilities + val serverSession = serverSessionResult.await() + + val clientCapabilities = serverSession.clientCapabilities assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots) - val listRootsResult = server.listRoots() + val listRootsResult = serverSession.listRoots() assertEquals(listRootsResult.roots, clientRoots) } @@ -773,16 +803,28 @@ class ClientTest { // Track notifications var rootListChangedNotificationReceived = false - server.setNotificationHandler(Method.Defined.NotificationsRootsListChanged) { - rootListChangedNotificationReceived = true - CompletableDeferred(Unit) - } + + val serverSessionResult = CompletableDeferred() listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) }, + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, ).joinAll() + val serverSession = serverSessionResult.await() + serverSession.setNotificationHandler( + Method.Defined.NotificationsRootsListChanged, + ) { + rootListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + client.sendRootsListChanged() assertTrue( @@ -809,14 +851,24 @@ class ClientTest { ), ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) }, + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, ).joinAll() + val serverSession = serverSessionResult.await() + // Verify that creating an elicitation throws an exception val exception = assertFailsWith { - server.createElicitation( + serverSession.createElicitation( message = "Please provide your GitHub username", requestedSchema = CreateElicitationRequest.RequestedSchema( properties = buildJsonObject { @@ -879,12 +931,22 @@ class ClientTest { ), ) + val serverSessionResult = CompletableDeferred() + listOf( - launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) }, + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, ).joinAll() - val result = server.createElicitation( + val serverSession = serverSessionResult.await() + + val result = serverSession.createElicitation( message = elicitationMessage, requestedSchema = requestedSchema, ) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index 4b49e1e8..597a3cb2 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -70,19 +70,6 @@ class SseTransportTest : BaseTransportTest() { install(ServerSSE) routing { mcp { mcpServer } -// sse { -// mcpSseTransport("", transports).apply { -// onMessage { -// send(it) -// } -// -// start() -// } -// } -// -// post { -// mcpPostEndpoint(transports) -// } } }.startSuspend(wait = false) @@ -110,22 +97,7 @@ class SseTransportTest : BaseTransportTest() { val server = embeddedServer(CIO, port = 0) { install(ServerSSE) routing { - mcp("/sse") { mcpServer } -// route("/sse") { -// sse { -// mcpSseTransport("", transports).apply { -// onMessage { -// send(it) -// } -// -// start() -// } -// } -// -// post { -// mcpPostEndpoint(transports) -// } -// } + mcp { mcpServer } } }.startSuspend(wait = false) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 2d1f25c2..dda2c0b5 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -1,63 +1,206 @@ package io.modelcontextprotocol.kotlin.sdk.integration import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.ApplicationStopped import io.ktor.server.application.install import io.ktor.server.cio.CIOApplicationEngine import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpSse +import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.mcp import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout import kotlin.test.Test -import kotlin.test.fail +import kotlin.test.assertTrue import io.ktor.client.engine.cio.CIO as ClientCIO -import io.ktor.client.plugins.sse.SSE as ClientSSE import io.ktor.server.cio.CIO as ServerCIO import io.ktor.server.sse.SSE as ServerSSE -private const val URL = "127.0.0.1" - class SseIntegrationTest { @Test fun `client should be able to connect to sse server`() = runTest { - val serverEngine = initServer() + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + client = initClient(serverPort=port) + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open SSE from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single sse connection`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + client = initClient("Client A", port) + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open SSE connection #1 from Client A and note the sessionId= value. + * 2. Open SSE connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple sse connections`() = runTest { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + try { withContext(Dispatchers.Default) { - val port = serverEngine.engine.resolvedConnectors().first().port - val client = initClient(port) - client.close() + withTimeout(1000) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + + clientA = initClient("Client A", port) + clientB = initClient("Client B", port) + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } } - } catch (e: Exception) { - fail("Failed to connect client: $e") } finally { - // Make sure to stop the server - serverEngine.stopSuspend(1000, 2000) + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) } } - private suspend fun initClient(port: Int): Client = HttpClient(ClientCIO) { - install(ClientSSE) - }.mcpSse("http://$URL:$port") + private suspend fun initClient(name: String = "", serverPort: Int): Client { + val client = Client( + Implementation(name = name, version = "1.0.0"), + ) + + val httpClient = HttpClient(ClientCIO) { + install(SSE) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpSseTransport { + url { + host = URL + port = serverPort + } + } + + client.connect(transport) + + return client + } private suspend fun initServer(): EmbeddedServer { val server = Server( - Implementation(name = "sse-e2e-test", version = "1.0.0"), - ServerOptions(capabilities = ServerCapabilities()), + Implementation(name = "sse-server", version = "1.0.0"), + ServerOptions( + capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), + ), ) - return embeddedServer(ServerCIO, host = URL, port = 0) { + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true, + ), + ), + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}"), + ), + ), + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { install(ServerSSE) routing { mcp { server } } - }.startSuspend(wait = false) + } + + ktorServer.monitor.subscribe(ApplicationStopped) { + println("SD -- [T] ktor server has been stopped") + } + + return ktorServer.startSuspend(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName), + ), + ) + + return (response.messages.first().content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } + + companion object { + private const val URL = "127.0.0.1" + private const val PORT = 0 } } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt new file mode 100644 index 00000000..45f97044 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/WebSocketIntegrationTest.kt @@ -0,0 +1,207 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.client.HttpClient +import io.ktor.server.application.ApplicationStopped +import io.ktor.server.application.install +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpWebSocketTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcpWebSocket +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlin.test.Test +import kotlin.test.assertTrue +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.client.plugins.websocket.WebSockets as ClientWebSocket +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.websocket.WebSockets as ServerWebSockets + +class WebSocketIntegrationTest { + + @Test + fun `client should be able to connect to websocket server 2`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + client = initClient(serverPort = port) + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open WebSocket from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single websocket connection`() = runTest { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + client = initClient("Client A", port) + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } + } finally { + client?.close() + server?.stop(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open WebSocket connection #1 from Client A and note the sessionId= value. + * 2. Open WebSocket connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple websocket connections`() = runTest { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + withTimeout(1000) { + server = initServer() + val port = server.engine.resolvedConnectors().first().port + clientA = initClient("Client A", port) + clientB = initClient("Client B",port) + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } + } finally { + clientA?.close() + clientB?.close() + server?.stop(1000, 2000) + } + } + + private suspend fun initClient(name: String = "", serverPort: Int): Client { + val client = Client( + Implementation(name = name, version = "1.0.0"), + ) + + val httpClient = HttpClient(ClientCIO) { + install(ClientWebSocket) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpWebSocketTransport { + url { + host = URL + port = serverPort + } + } + + client.connect(transport) + + return client + } + + private suspend fun initServer(): EmbeddedServer { + val server = Server( + Implementation(name = "websocket-server", version = "1.0.0"), + ServerOptions( + capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), + ), + ) + + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true, + ), + ), + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}"), + ), + ), + ) + } + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerWebSockets) + routing { + mcpWebSocket(block = { server }) + } + } + + ktorServer.monitor.subscribe(ApplicationStopped) { + println("SD -- [T] ktor server has been stopped") + } + + return ktorServer.startSuspend(wait = false) + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName), + ), + ) + + return (response.messages.first().content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } + + companion object { + private const val PORT = 0 + private const val URL = "127.0.0.1" + } +}