Skip to content

Introduce server sessions #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ public open class Client(private val clientInfo: Implementation, options: Client

notification(InitializedNotification())
} catch (error: Throwable) {
logger.error(error) { "Failed to initialize client: ${error.message}" }
close()

if (error !is CancellationException) {
throw IllegalStateException("Error connecting to transport: ${error.message}")
throw IllegalStateException("Error connecting to transport: ${error.message}", error)
}

throw error
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.plugins.sse.ClientSSESession
import io.ktor.client.plugins.sse.sseSession
Expand Down Expand Up @@ -46,6 +47,8 @@ public class SseClientTransport(
private val reconnectionTime: Duration? = null,
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
) : AbstractTransport() {
private val logger = KotlinLogging.logger {}

private val initialized: AtomicBoolean = AtomicBoolean(false)
private val endpoint = CompletableDeferred<String>()

Expand Down Expand Up @@ -111,6 +114,8 @@ public class SseClientTransport(
val text = response.bodyAsText()
error("Error POSTing to endpoint (HTTP ${response.status}): $text")
}

logger.debug { "Client successfully sent message via SSE $endpoint" }
} catch (e: Throwable) {
_onError(e)
throw e
Expand Down Expand Up @@ -158,6 +163,7 @@ public class SseClientTransport(
val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData
val endpointUrl = Url("$baseUrl/$path")
endpoint.complete(endpointUrl.toString())
logger.debug { "Client connected to endpoint: $endpointUrl" }
} catch (e: Throwable) {
_onError(e)
endpoint.completeExceptionally(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.plugins.websocket.webSocketSession
import io.ktor.client.request.HttpRequestBuilder
Expand All @@ -10,6 +11,8 @@ import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL
import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport
import kotlin.properties.Delegates

private val logger = KotlinLogging.logger {}

/**
* Client transport for WebSocket: this will connect to a server over the WebSocket protocol.
*/
Expand All @@ -21,6 +24,8 @@ public class WebSocketClientTransport(
override var session: WebSocketSession by Delegates.notNull()

override suspend fun initializeSession() {
logger.debug { "Websocket session initialization started..." }

session = urlString?.let {
client.webSocketSession(it) {
requestBuilder()
Expand All @@ -32,5 +37,7 @@ public class WebSocketClientTransport(

header(HttpHeaders.SecWebSocketProtocol, MCP_SUBPROTOCOL)
}

logger.debug { "Websocket session initialization finished" }
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.request.HttpRequestBuilder
import io.modelcontextprotocol.kotlin.sdk.Implementation
import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION
import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME

private val logger = KotlinLogging.logger {}

/**
* Returns a new WebSocket transport for the Model Context Protocol using the provided HttpClient.
*
Expand Down Expand Up @@ -36,6 +39,8 @@ public suspend fun HttpClient.mcpWebSocket(
version = LIB_VERSION,
),
)
logger.debug { "Client started to connect to server" }
client.connect(transport)
logger.debug { "Client finished to connect to server" }
return client
}
29 changes: 29 additions & 0 deletions kotlin-sdk-core/api/kotlin-sdk-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -2328,6 +2328,35 @@ public final class io/modelcontextprotocol/kotlin/sdk/ResourceListChangedNotific
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/ResourceReference : io/modelcontextprotocol/kotlin/sdk/Reference {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ResourceReference$Companion;
public static final field TYPE Ljava/lang/String;
public fun <init> (Ljava/lang/String;)V
public final fun component1 ()Ljava/lang/String;
public final fun copy (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;
public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;Ljava/lang/String;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;
public fun equals (Ljava/lang/Object;)Z
public fun getType ()Ljava/lang/String;
public final fun getUri ()Ljava/lang/String;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final synthetic class io/modelcontextprotocol/kotlin/sdk/ResourceReference$$serializer : kotlinx/serialization/internal/GeneratedSerializer {
public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/ResourceReference$$serializer;
public final fun childSerializers ()[Lkotlinx/serialization/KSerializer;
public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;
public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object;
public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor;
public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/ResourceReference;)V
public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V
public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/ResourceReference$Companion {
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/ResourceTemplate {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ResourceTemplate$Companion;
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/Annotations;)V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import kotlin.reflect.typeOf
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds

private val LOGGER = KotlinLogging.logger { }
private val logger = KotlinLogging.logger { }

public const val IMPLEMENTATION_NAME: String = "mcp-ktor"

Expand Down Expand Up @@ -212,6 +212,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
}
}

logger.info { "Starting transport" }
return transport.start()
}

Expand All @@ -229,29 +230,29 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
}

private suspend fun onNotification(notification: JSONRPCNotification) {
LOGGER.trace { "Received notification: ${notification.method}" }
logger.trace { "Received notification: ${notification.method}" }

val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler

if (handler == null) {
LOGGER.trace { "No handler found for notification: ${notification.method}" }
logger.trace { "No handler found for notification: ${notification.method}" }
return
}
try {
handler(notification)
} catch (cause: Throwable) {
LOGGER.error(cause) { "Error handling notification: ${notification.method}" }
logger.error(cause) { "Error handling notification: ${notification.method}" }
onError(cause)
}
}

private suspend fun onRequest(request: JSONRPCRequest) {
LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" }
logger.trace { "Received request: ${request.method} (id: ${request.id})" }

val handler = requestHandlers[request.method] ?: fallbackRequestHandler

if (handler === null) {
LOGGER.trace { "No handler found for request: ${request.method}" }
logger.trace { "No handler found for request: ${request.method}" }
try {
transport?.send(
JSONRPCResponse(
Expand All @@ -263,15 +264,15 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
),
)
} catch (cause: Throwable) {
LOGGER.error(cause) { "Error sending method not found response" }
logger.error(cause) { "Error sending method not found response" }
onError(cause)
}
return
}

try {
val result = handler(request, RequestHandlerExtra())
LOGGER.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }

transport?.send(
JSONRPCResponse(
Expand All @@ -280,7 +281,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
),
)
} catch (cause: Throwable) {
LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }
logger.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" }

try {
transport?.send(
Expand All @@ -293,7 +294,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
),
)
} catch (sendError: Throwable) {
LOGGER.error(sendError) {
logger.error(sendError) {
"Failed to send error response for request: ${request.method} (id: ${request.id})"
}
// Optionally implement fallback behavior here
Expand All @@ -302,7 +303,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
}

private fun onProgress(notification: ProgressNotification) {
LOGGER.trace {
logger.trace {
"Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}"
}
val progress = notification.params.progress
Expand All @@ -315,7 +316,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
val error = Error(
"Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}",
)
LOGGER.error { error.message }
logger.error { error.message }
onError(error)
return
}
Expand Down Expand Up @@ -390,9 +391,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
* Do not use this method to emit notifications! Use notification() instead.
*/
public suspend fun <T : RequestResult> request(request: Request, options: RequestOptions? = null): T {
LOGGER.trace { "Sending request: ${request.method}" }
logger.trace { "Sending request: ${request.method}" }
val result = CompletableDeferred<T>()
val transport = this@Protocol.transport ?: throw Error("Not connected")
val transport = transport ?: throw Error("Not connected")

if (this@Protocol.options?.enforceStrictCapabilities == true) {
assertCapabilityForMethod(request.method)
Expand All @@ -402,7 +403,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
val messageId = message.id

if (options?.onProgress != null) {
LOGGER.trace { "Registering progress handler for request id: $messageId" }
logger.trace { "Registering progress handler for request id: $messageId" }
_progressHandlers.update { current ->
current.put(messageId, options.onProgress)
}
Expand Down Expand Up @@ -452,12 +453,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT
try {
withTimeout(timeout) {
LOGGER.trace { "Sending request message with id: $messageId" }
logger.trace { "Sending request message with id: $messageId" }
this@Protocol.transport?.send(message)
}
return result.await()
} catch (cause: TimeoutCancellationException) {
LOGGER.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" }
cancel(
McpError(
ErrorCode.Defined.RequestTimeout.code,
Expand All @@ -474,7 +475,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
* Emits a notification, which is a one-way message that does not expect a response.
*/
public suspend fun notification(notification: Notification) {
LOGGER.trace { "Sending notification: ${notification.method}" }
logger.trace { "Sending notification: ${notification.method}" }
val transport = this.transport ?: error("Not connected")
assertNotificationCapability(notification.method)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.shared

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.websocket.Frame
import io.ktor.websocket.WebSocketSession
import io.ktor.websocket.close
Expand All @@ -17,6 +18,8 @@ import kotlin.concurrent.atomics.ExperimentalAtomicApi

public const val MCP_SUBPROTOCOL: String = "mcp"

private val logger = KotlinLogging.logger {}

/**
* Abstract class representing a WebSocket transport for the Model Context Protocol (MCP).
* Handles communication over a WebSocket session.
Expand All @@ -40,6 +43,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
protected abstract suspend fun initializeSession()

override suspend fun start() {
logger.debug { "Starting websocket transport" }

if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
error(
"WebSocketClientTransport already started! " +
Expand All @@ -53,7 +58,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
while (true) {
val message = try {
session.incoming.receive()
} catch (_: ClosedReceiveChannelException) {
} catch (e: ClosedReceiveChannelException) {
logger.debug { "Closed receive channel, exiting" }
return@launch
}

Expand Down Expand Up @@ -84,6 +90,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
}

override suspend fun send(message: JSONRPCMessage) {
logger.debug { "Sending message" }
if (!initialized.load()) {
error("Not connected")
}
Expand All @@ -96,6 +103,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
error("Not connected")
}

logger.debug { "Closing websocket session" }
session.close()
session.coroutineContext.job.join()
}
Expand Down
Loading
Loading