@@ -10,6 +10,7 @@ import (
10
10
"encoding/base64"
11
11
"encoding/gob"
12
12
"encoding/json"
13
+ "errors"
13
14
"fmt"
14
15
"iter"
15
16
"maps"
@@ -513,7 +514,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns
513
514
// If no tools have been added, the server will not have the tool capability.
514
515
// The same goes for other features like prompts and resources.
515
516
func (s * Server ) Run (ctx context.Context , t Transport ) error {
516
- ss , err := s .Connect (ctx , t )
517
+ // TODO: provide a way to pass ServerSessionOptions?
518
+ ss , err := s .Connect (ctx , t , nil )
517
519
if err != nil {
518
520
return err
519
521
}
@@ -535,7 +537,7 @@ func (s *Server) Run(ctx context.Context, t Transport) error {
535
537
// bind implements the binder[*ServerSession] interface, so that Servers can
536
538
// be connected using [connect].
537
539
func (s * Server ) bind (conn * jsonrpc2.Connection ) * ServerSession {
538
- ss := & ServerSession {conn : conn , server : s , state : & SessionState {} }
540
+ ss := & ServerSession {conn : conn , server : s }
539
541
s .mu .Lock ()
540
542
s .sessions = append (s .sessions , ss )
541
543
s .mu .Unlock ()
@@ -556,14 +558,33 @@ func (s *Server) disconnect(cc *ServerSession) {
556
558
}
557
559
}
558
560
561
+ type ServerSessionOptions struct {
562
+ SessionID string
563
+ SessionState * SessionState
564
+ SessionStore SessionStore
565
+ }
566
+
559
567
// Connect connects the MCP server over the given transport and starts handling
560
568
// messages.
561
569
//
562
570
// It returns a connection object that may be used to terminate the connection
563
571
// (with [Connection.Close]), or await client termination (with
564
572
// [Connection.Wait]).
565
- func (s * Server ) Connect (ctx context.Context , t Transport ) (* ServerSession , error ) {
566
- return connect (ctx , t , s )
573
+ func (s * Server ) Connect (ctx context.Context , t Transport , opts * ServerSessionOptions ) (* ServerSession , error ) {
574
+ if opts != nil && opts .SessionState == nil && opts .SessionStore != nil {
575
+ return nil , errors .New ("ServerSessionOptions has store but no state" )
576
+ }
577
+ ss , err := connect (ctx , t , s )
578
+ if err != nil {
579
+ return nil , err
580
+ }
581
+ if opts != nil {
582
+ ss .opts = * opts
583
+ }
584
+ if ss .opts .SessionState == nil {
585
+ ss .opts .SessionState = & SessionState {}
586
+ }
587
+ return ss , nil
567
588
}
568
589
569
590
func (s * Server ) callInitializedHandler (ctx context.Context , ss * ServerSession , params * InitializedParams ) (Result , error ) {
@@ -596,32 +617,20 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
596
617
// Call [ServerSession.Close] to close the connection, or await client
597
618
// termination with [ServerSession.Wait].
598
619
type ServerSession struct {
599
- server * Server
600
- conn * jsonrpc2.Connection
620
+ server * Server
621
+ conn * jsonrpc2.Connection
622
+ opts ServerSessionOptions
623
+
601
624
mu sync.Mutex
602
625
initialized bool
603
626
keepaliveCancel context.CancelFunc
604
-
605
- sessionID string
606
- state * SessionState
607
- store SessionStore
608
627
}
609
628
610
629
func (ss * ServerSession ) setConn (c Connection ) {
611
630
}
612
631
613
- func (ss * ServerSession ) ID () string { return ss .sessionID }
614
-
615
- // InitSession initializes the session with a session ID, state, and store.
616
- // If called, it must be called immediately after the session is connected.
617
- // If never called, the session will begin with a zero SessionState and no session
618
- // ID or store.
619
- func (ss * ServerSession ) InitSession (sessionID string , state * SessionState , store SessionStore ) {
620
- ss .mu .Lock ()
621
- defer ss .mu .Unlock ()
622
- ss .sessionID = sessionID
623
- ss .state = state
624
- ss .store = store
632
+ func (ss * ServerSession ) ID () string {
633
+ return ss .opts .SessionID
625
634
}
626
635
627
636
// Ping pings the client.
@@ -645,7 +654,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag
645
654
// is below that of the last SetLevel.
646
655
func (ss * ServerSession ) Log (ctx context.Context , params * LoggingMessageParams ) error {
647
656
ss .mu .Lock ()
648
- logLevel := ss .state .LogLevel
657
+ logLevel := ss .opts . SessionState .LogLevel
649
658
ss .mu .Unlock ()
650
659
if logLevel == "" {
651
660
// The spec is unclear, but seems to imply that no log messages are sent until the client
@@ -759,10 +768,10 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
759
768
return nil , fmt .Errorf ("%w: \" params\" must be be provided" , jsonrpc2 .ErrInvalidParams )
760
769
}
761
770
ss .mu .Lock ()
762
- ss .state .InitializeParams = params
771
+ ss .opts . SessionState .InitializeParams = params
763
772
ss .mu .Unlock ()
764
- if ss .store != nil {
765
- if err := ss . store .Store (ctx , ss .sessionID , ss .state ); err != nil {
773
+ if store := ss .opts . SessionStore ; store != nil {
774
+ if err := store .Store (ctx , ss .opts . SessionID , ss .opts . SessionState ); err != nil {
766
775
return nil , fmt .Errorf ("storing session state: %w" , err )
767
776
}
768
777
}
@@ -802,9 +811,9 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
802
811
func (ss * ServerSession ) setLevel (ctx context.Context , params * SetLevelParams ) (* emptyResult , error ) {
803
812
ss .mu .Lock ()
804
813
defer ss .mu .Unlock ()
805
- ss .state .LogLevel = params .Level
806
- if ss .store != nil {
807
- if err := ss . store .Store (ctx , ss .sessionID , ss .state ); err != nil {
814
+ ss .opts . SessionState .LogLevel = params .Level
815
+ if store := ss .opts . SessionStore ; store != nil {
816
+ if err := store .Store (ctx , ss .opts . SessionID , ss .opts . SessionState ); err != nil {
808
817
return nil , err
809
818
}
810
819
}
0 commit comments