@@ -20,7 +20,9 @@ import (
20
20
"time"
21
21
22
22
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
23
+ "github.com/modelcontextprotocol/go-sdk/internal/util"
23
24
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
25
+ "golang.org/x/oauth2/authhandler"
24
26
)
25
27
26
28
const (
@@ -656,7 +658,7 @@ type StreamableReconnectOptions struct {
656
658
}
657
659
658
660
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
659
- var DefaultReconnectOptions = & StreamableReconnectOptions {
661
+ var DefaultReconnectOptions = StreamableReconnectOptions {
660
662
MaxRetries : 5 ,
661
663
growFactor : 1.5 ,
662
664
initialDelay : 1 * time .Second ,
@@ -666,10 +668,18 @@ var DefaultReconnectOptions = &StreamableReconnectOptions{
666
668
// StreamableClientTransportOptions provides options for the
667
669
// [NewStreamableClientTransport] constructor.
668
670
type StreamableClientTransportOptions struct {
669
- // HTTPClient is the client to use for making HTTP requests. If nil,
670
- // http.DefaultClient is used.
671
- HTTPClient * http.Client
672
- ReconnectOptions * StreamableReconnectOptions
671
+ // ReconnectOptions control the transport's behavior when it is disconnected
672
+ // from the server.
673
+ ReconnectOptions StreamableReconnectOptions
674
+ // HTTPClient is the client to use for making unauthenticaed HTTP requests.
675
+ // If nil, http.DefaultClient is used.
676
+ // For authenticated requests, a shallow clone of the client will be used,
677
+ // with a different transport. The cookie jar will not be copied.
678
+ HTTPClient * http.Client
679
+ // AuthHandler is a function that handles the user interaction part of the OAuth 2.1 flow.
680
+ // It should prompt the user at the given URL and return the expected OAuth values.
681
+ // See [authhandler.AuthorizationHandler] for more.
682
+ AuthHandler authhandler.AuthorizationHandler
673
683
}
674
684
675
685
// NewStreamableClientTransport returns a new client transport that connects to
@@ -679,6 +689,12 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
679
689
if opts != nil {
680
690
t .opts = * opts
681
691
}
692
+ if t .opts .HTTPClient == nil {
693
+ t .opts .HTTPClient = http .DefaultClient
694
+ }
695
+ if t .opts .ReconnectOptions == (StreamableReconnectOptions {}) {
696
+ t .opts .ReconnectOptions = DefaultReconnectOptions
697
+ }
682
698
return t
683
699
}
684
700
@@ -691,36 +707,26 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
691
707
// When closed, the connection issues a DELETE request to terminate the logical
692
708
// session.
693
709
func (t * StreamableClientTransport ) Connect (ctx context.Context ) (Connection , error ) {
694
- client := t .opts .HTTPClient
695
- if client == nil {
696
- client = http .DefaultClient
697
- }
698
- reconnOpts := t .opts .ReconnectOptions
699
- if reconnOpts == nil {
700
- reconnOpts = DefaultReconnectOptions
701
- }
702
710
// Create a new cancellable context that will manage the connection's lifecycle.
703
711
// This is crucial for cleanly shutting down the background SSE listener by
704
712
// cancelling its blocking network operations, which prevents hangs on exit.
705
713
connCtx , cancel := context .WithCancel (context .Background ())
706
- conn := & streamableClientConn {
707
- url : t .url ,
708
- client : client ,
709
- incoming : make (chan []byte , 100 ),
710
- done : make (chan struct {}),
711
- ReconnectOptions : reconnOpts ,
712
- ctx : connCtx ,
713
- cancel : cancel ,
714
- }
715
- return conn , nil
714
+ return & streamableClientConn {
715
+ url : t .url ,
716
+ opts : t .opts ,
717
+ incoming : make (chan []byte , 100 ),
718
+ done : make (chan struct {}),
719
+ ctx : connCtx ,
720
+ cancel : cancel ,
721
+ }, nil
716
722
}
717
723
718
724
type streamableClientConn struct {
719
- url string
720
- client * http. Client
721
- incoming chan [] byte
722
- done chan struct {}
723
- ReconnectOptions * StreamableReconnectOptions
725
+ url string
726
+ opts StreamableClientTransportOptions
727
+ authClient * http. Client
728
+ incoming chan [] byte
729
+ done chan struct {}
724
730
725
731
closeOnce sync.Once
726
732
closeErr error
@@ -800,7 +806,11 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
800
806
return nil
801
807
}
802
808
803
- func (s * streamableClientConn ) postMessage (ctx context.Context , sessionID string , msg jsonrpc.Message ) (string , error ) {
809
+ // postMessage makes a POST request to the server with msg as the body.
810
+ // It returns the session ID.
811
+ func (s * streamableClientConn ) postMessage (ctx context.Context , sessionID string , msg jsonrpc.Message ) (_ string , err error ) {
812
+ defer util .Wrapf (& err , "MCP client posting message, session ID %q" , sessionID )
813
+
804
814
data , err := jsonrpc2 .EncodeMessage (msg )
805
815
if err != nil {
806
816
return "" , err
@@ -819,28 +829,59 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
819
829
req .Header .Set ("Content-Type" , "application/json" )
820
830
req .Header .Set ("Accept" , "application/json, text/event-stream" )
821
831
822
- resp , err := s .client .Do (req )
832
+ // Use an HTTP client that does authentication, if there is one.
833
+ // Otherwise, use the one provided by the user.
834
+ client := s .authClient
835
+ if client == nil {
836
+ client = s .opts .HTTPClient
837
+ }
838
+ // TODO: Resource Indicators, as in
839
+ // https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#resource-parameter-implementation
840
+ resp , err := client .Do (req )
823
841
if err != nil {
824
842
return "" , err
825
843
}
844
+ bodyClosed := false // avoid a second call to Close: undefined behavior (see [io.Closer])
845
+ defer func () {
846
+ if resp != nil && ! bodyClosed {
847
+ resp .Body .Close ()
848
+ }
849
+ }()
850
+
851
+ if resp .StatusCode == http .StatusUnauthorized {
852
+ if client == s .authClient {
853
+ return "" , errors .New ("got StatusUnauthorized when already authorized" )
854
+ }
855
+ tokenSource , err := doOauth (ctx , resp .Header , s .opts .HTTPClient , s .opts .AuthHandler )
856
+ if err != nil {
857
+ return "" , err
858
+ }
859
+ s .authClient = newAuthClient (s .opts .HTTPClient , tokenSource )
860
+ resp .Body .Close () // because we're about to replace resp
861
+ resp , err = s .authClient .Do (req )
862
+ if err != nil {
863
+ return "" , err
864
+ }
865
+ if resp .StatusCode == http .StatusUnauthorized {
866
+ return "" , errors .New ("got StatusUnauthorized just after authorization" )
867
+ }
868
+ }
826
869
827
870
if resp .StatusCode < 200 || resp .StatusCode >= 300 {
828
871
// TODO: do a best effort read of the body here, and format it in the error.
829
- resp .Body .Close ()
830
872
return "" , fmt .Errorf ("broken session: %v" , resp .Status )
831
873
}
832
874
833
875
sessionID = resp .Header .Get (sessionIDHeader )
834
876
switch ct := resp .Header .Get ("Content-Type" ); ct {
835
877
case "text/event-stream" :
836
878
// Section 2.1: The SSE stream is initiated after a POST.
879
+ bodyClosed = true // handleSSE will close.
837
880
go s .handleSSE (resp )
838
881
case "application/json" :
839
882
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
840
- resp .Body .Close ()
841
883
return "" , fmt .Errorf ("streamable HTTP client does not yet support raw JSON responses" )
842
884
default :
843
- resp .Body .Close ()
844
885
return "" , fmt .Errorf ("unsupported content type %q" , ct )
845
886
}
846
887
return sessionID , nil
@@ -912,12 +953,13 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
912
953
// an error if all retries are exhausted.
913
954
func (s * streamableClientConn ) reconnect (lastEventID string ) (* http.Response , error ) {
914
955
var finalErr error
956
+ maxRetries := s .opts .ReconnectOptions .MaxRetries
915
957
916
- for attempt := 0 ; attempt < s . ReconnectOptions . MaxRetries ; attempt ++ {
958
+ for attempt := 0 ; attempt < maxRetries ; attempt ++ {
917
959
select {
918
960
case <- s .done :
919
961
return nil , fmt .Errorf ("connection closed by client during reconnect" )
920
- case <- time .After (calculateReconnectDelay (s .ReconnectOptions , attempt )):
962
+ case <- time .After (calculateReconnectDelay (& s . opts .ReconnectOptions , attempt )):
921
963
resp , err := s .establishSSE (lastEventID )
922
964
if err != nil {
923
965
finalErr = err // Store the error and try again.
@@ -935,9 +977,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
935
977
}
936
978
// If the loop completes, all retries have failed.
937
979
if finalErr != nil {
938
- return nil , fmt .Errorf ("connection failed after %d attempts: %w" , s . ReconnectOptions . MaxRetries , finalErr )
980
+ return nil , fmt .Errorf ("connection failed after %d attempts: %w" , maxRetries , finalErr )
939
981
}
940
- return nil , fmt .Errorf ("connection failed after %d attempts" , s . ReconnectOptions . MaxRetries )
982
+ return nil , fmt .Errorf ("connection failed after %d attempts" , maxRetries )
941
983
}
942
984
943
985
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
@@ -966,7 +1008,7 @@ func (s *streamableClientConn) Close() error {
966
1008
req .Header .Set (protocolVersionHeader , s .protocolVersion )
967
1009
}
968
1010
req .Header .Set (sessionIDHeader , s ._sessionID )
969
- if _ , err := s .client .Do (req ); err != nil {
1011
+ if _ , err := s .opts . HTTPClient .Do (req ); err != nil {
970
1012
s .closeErr = err
971
1013
}
972
1014
}
@@ -992,7 +1034,7 @@ func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
992
1034
}
993
1035
req .Header .Set ("Accept" , "text/event-stream" )
994
1036
995
- return s .client .Do (req )
1037
+ return s .opts . HTTPClient .Do (req )
996
1038
}
997
1039
998
1040
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
0 commit comments