Skip to content

Commit 16f60b4

Browse files
authored
Accept requestID as a string #25 (#26)
1 parent 512e7b6 commit 16f60b4

File tree

4 files changed

+60
-5
lines changed

4 files changed

+60
-5
lines changed

src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import io.modelcontextprotocol.kotlin.sdk.PingRequest
1515
import io.modelcontextprotocol.kotlin.sdk.Progress
1616
import io.modelcontextprotocol.kotlin.sdk.ProgressNotification
1717
import io.modelcontextprotocol.kotlin.sdk.Request
18+
import io.modelcontextprotocol.kotlin.sdk.RequestId
1819
import io.modelcontextprotocol.kotlin.sdk.RequestResult
1920
import io.modelcontextprotocol.kotlin.sdk.fromJSON
2021
import io.modelcontextprotocol.kotlin.sdk.toJSON
@@ -119,11 +120,11 @@ public abstract class Protocol<SendRequestT : Request, SendNotificationT : Notif
119120
mutableMapOf()
120121

121122
@PublishedApi
122-
internal val responseHandlers: MutableMap<Long, (response: JSONRPCResponse?, error: Exception?) -> Unit> =
123+
internal val responseHandlers: MutableMap<RequestId, (response: JSONRPCResponse?, error: Exception?) -> Unit> =
123124
mutableMapOf()
124125

125126
@PublishedApi
126-
internal val progressHandlers: MutableMap<Long, ProgressCallback> = mutableMapOf()
127+
internal val progressHandlers: MutableMap<RequestId, ProgressCallback> = mutableMapOf()
127128

128129
/**
129130
* Callback for when the connection is closed for any reason.

src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ private val REQUEST_MESSAGE_ID = AtomicLong(0L)
2525
* A progress token, used to associate progress notifications with the original request.
2626
* Stores message ID.
2727
*/
28-
public typealias ProgressToken = Long
28+
public typealias ProgressToken = RequestId
2929

3030
/**
3131
* An opaque token used to represent a cursor for pagination.
@@ -191,7 +191,14 @@ public data class EmptyRequestResult(
191191
/**
192192
* A uniquely identifying ID for a request in JSON-RPC.
193193
*/
194-
public typealias RequestId = Long
194+
@Serializable(with = RequestIdSerializer::class)
195+
public sealed interface RequestId {
196+
@Serializable
197+
public data class StringId(val value: String) : RequestId
198+
199+
@Serializable
200+
public data class NumberId(val value: Long) : RequestId
201+
}
195202

196203
/**
197204
* Represents a JSON-RPC message in the protocol.
@@ -204,7 +211,7 @@ public sealed interface JSONRPCMessage
204211
*/
205212
@Serializable
206213
public data class JSONRPCRequest(
207-
val id: RequestId = REQUEST_MESSAGE_ID.incrementAndGet(),
214+
val id: RequestId = RequestId.NumberId(REQUEST_MESSAGE_ID.incrementAndGet()),
208215
val method: String,
209216
val params: JsonElement? = null,
210217
val jsonrpc: String = JSONRPC_VERSION,

src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import kotlinx.serialization.KSerializer
77
import kotlinx.serialization.descriptors.PrimitiveKind
88
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
99
import kotlinx.serialization.descriptors.SerialDescriptor
10+
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
1011
import kotlinx.serialization.encoding.Decoder
1112
import kotlinx.serialization.encoding.Encoder
1213
import kotlinx.serialization.json.*
@@ -270,3 +271,29 @@ internal object JSONRPCMessagePolymorphicSerializer :
270271
}
271272

272273
internal val EmptyJsonObject = JsonObject(emptyMap())
274+
275+
public class RequestIdSerializer : KSerializer<RequestId> {
276+
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("RequestId")
277+
278+
override fun deserialize(decoder: Decoder): RequestId {
279+
val jsonDecoder = decoder as? JsonDecoder ?: error("Can only deserialize JSON")
280+
val element = jsonDecoder.decodeJsonElement()
281+
282+
return when (element) {
283+
is JsonPrimitive -> when {
284+
element.isString -> RequestId.StringId(element.content)
285+
element.longOrNull != null -> RequestId.NumberId(element.long)
286+
else -> error("Invalid RequestId type")
287+
}
288+
else -> error("Invalid RequestId format")
289+
}
290+
}
291+
292+
override fun serialize(encoder: Encoder, value: RequestId) {
293+
val jsonEncoder = encoder as? JsonEncoder ?: error("Can only serialize JSON")
294+
when (value) {
295+
is RequestId.StringId -> jsonEncoder.encodeString(value.value)
296+
is RequestId.NumberId -> jsonEncoder.encodeLong(value.value)
297+
}
298+
}
299+
}

src/test/kotlin/client/TypesTest.kt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,24 @@ class TypesTest {
2424
val line = "{\"result\":{\"content\":[{\"type\":\"text\"}],\"isError\":false},\"jsonrpc\":\"2.0\",\"id\":4}"
2525
McpJson.decodeFromString<JSONRPCMessage>(line)
2626
}
27+
28+
@Test
29+
fun testJSONRPCMessageWithStringId() {
30+
val line = """
31+
{
32+
"jsonrpc": "2.0",
33+
"method": "initialize",
34+
"id": "ebf9f64a-0",
35+
"params": {
36+
"protocolVersion": "2024-11-05",
37+
"capabilities": {},
38+
"clientInfo": {
39+
"name": "mcp-java-client",
40+
"version": "0.2.0"
41+
}
42+
}
43+
}
44+
""".trimIndent()
45+
McpJson.decodeFromString<JSONRPCMessage>(line)
46+
}
2747
}

0 commit comments

Comments
 (0)