Skip to content

Commit 13736dc

Browse files
committed
mcp: add client-side OAuth flow (preliminary)
This is a preliminary implementation of OAuth 2.1 for the client. When a StreamableClientTransport encounters a 401 Unauthorized response from the server, it initiates the OAuth flow described in thec authorization section of the MCP spec (https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization). On success, the transport obtains an access token which it passes to all subsequent requests. Much remains to be done here: - Dynamic client registration is not implemented. Since it is optional, we also need another way of supplying the client ID and secret to this code. - Resource Indicators, as described in section 2.5.1 of the MCP spec. - There is no way for the user to provide a redirect URL. - All of this is unexported, so it is available only to our own StreamingClientTransport. We should add API so people can use it with their own transports. - And, of course, tests. We should test against fake implementations but also, if we can find any, real reference implementations.
1 parent e7f1487 commit 13736dc

File tree

5 files changed

+158
-42
lines changed

5 files changed

+158
-42
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ go 1.23.0
55
require (
66
github.com/google/go-cmp v0.7.0
77
github.com/yosida95/uritemplate/v3 v3.0.2
8+
golang.org/x/oauth2 v0.30.0
89
golang.org/x/tools v0.34.0
910
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
22
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
33
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
44
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
5+
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
6+
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
57
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
68
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=

internal/oauthex/resource_meta.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string,
145145
// If there is no URL in the request, it returns nil, nil.
146146
func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) {
147147
defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader")
148-
headers := header[http.CanonicalHeaderKey("WWW-Authenticate")]
149-
if len(headers) == 0 {
148+
authHeaders := header[http.CanonicalHeaderKey("WWW-Authenticate")]
149+
if len(authHeaders) == 0 {
150150
return nil, nil
151151
}
152-
cs, err := parseWWWAuthenticate(headers)
152+
cs, err := parseWWWAuthenticate(authHeaders)
153153
if err != nil {
154154
return nil, err
155155
}

mcp/auth.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package mcp
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"net/http"
11+
12+
"github.com/modelcontextprotocol/go-sdk/internal/oauthex"
13+
"golang.org/x/oauth2"
14+
"golang.org/x/oauth2/authhandler"
15+
)
16+
17+
// newAuthClient returns a shallow copy of c with its tranport replaced by one that
18+
// authorizes with the token source.
19+
func newAuthClient(c *http.Client, ts oauth2.TokenSource) *http.Client {
20+
c2 := *c
21+
c2.Transport = &oauth2.Transport{
22+
Base: c.Transport,
23+
Source: ts,
24+
}
25+
return &c2
26+
}
27+
28+
// doOauth runs the OAuth 2.1 flow for MCP as described in
29+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization.
30+
// It returns the resulting TokenSource.
31+
func doOauth(ctx context.Context, header http.Header, c *http.Client, oauthHandler authhandler.AuthorizationHandler) (oauth2.TokenSource, error) {
32+
prm, err := oauthex.GetProtectedResourceMetadataFromHeader(ctx, header, c)
33+
if err != nil {
34+
return nil, err
35+
}
36+
if len(prm.AuthorizationServers) == 0 {
37+
return nil, fmt.Errorf("resource %s provided no authorization servers", prm.Resource)
38+
}
39+
// TODO: try more than one?
40+
authServer := prm.AuthorizationServers[0]
41+
// TODO: which scopes to ask for? All of them?
42+
scopes := prm.ScopesSupported
43+
asm, err := oauthex.GetAuthServerMeta(ctx, authServer, c)
44+
if err != nil {
45+
return nil, err
46+
}
47+
// TODO: register the client with the auth server if not registered yet,
48+
// or find another way to get the client ID and secret.
49+
50+
// Get an access token from the auth server.
51+
config := &oauth2.Config{
52+
ClientID: "TODO: from registration",
53+
ClientSecret: "TODO: from registration",
54+
Endpoint: oauth2.Endpoint{
55+
AuthURL: asm.AuthorizationEndpoint,
56+
TokenURL: asm.TokenEndpoint,
57+
// DeviceAuthURL: "",
58+
// AuthStyle: "from auth meta?",
59+
},
60+
RedirectURL: "", // ???
61+
Scopes: scopes,
62+
}
63+
v := oauth2.GenerateVerifier()
64+
pkceParams := authhandler.PKCEParams{
65+
ChallengeMethod: "S256",
66+
Challenge: oauth2.S256ChallengeFromVerifier(v),
67+
Verifier: v,
68+
}
69+
state := randText()
70+
return authhandler.TokenSourceWithPKCE(ctx, config, state, oauthHandler, &pkceParams), nil
71+
}

mcp/streamable.go

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020
"time"
2121

2222
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
23+
"github.com/modelcontextprotocol/go-sdk/internal/util"
2324
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
25+
"golang.org/x/oauth2/authhandler"
2426
)
2527

2628
const (
@@ -656,7 +658,7 @@ type StreamableReconnectOptions struct {
656658
}
657659

658660
// DefaultReconnectOptions provides sensible defaults for reconnect logic.
659-
var DefaultReconnectOptions = &StreamableReconnectOptions{
661+
var DefaultReconnectOptions = StreamableReconnectOptions{
660662
MaxRetries: 5,
661663
growFactor: 1.5,
662664
initialDelay: 1 * time.Second,
@@ -666,10 +668,18 @@ var DefaultReconnectOptions = &StreamableReconnectOptions{
666668
// StreamableClientTransportOptions provides options for the
667669
// [NewStreamableClientTransport] constructor.
668670
type StreamableClientTransportOptions struct {
669-
// HTTPClient is the client to use for making HTTP requests. If nil,
670-
// http.DefaultClient is used.
671-
HTTPClient *http.Client
672-
ReconnectOptions *StreamableReconnectOptions
671+
// ReconnectOptions control the transport's behavior when it is disconnected
672+
// from the server.
673+
ReconnectOptions StreamableReconnectOptions
674+
// HTTPClient is the client to use for making unauthenticaed HTTP requests.
675+
// If nil, http.DefaultClient is used.
676+
// For authenticated requests, a shallow clone of the client will be used,
677+
// with a different transport. The cookie jar will not be copied.
678+
HTTPClient *http.Client
679+
// AuthHandler is a function that handles the user interaction part of the OAuth 2.1 flow.
680+
// It should prompt the user at the given URL and return the expected OAuth values.
681+
// See [authhandler.AuthorizationHandler] for more.
682+
AuthHandler authhandler.AuthorizationHandler
673683
}
674684

675685
// NewStreamableClientTransport returns a new client transport that connects to
@@ -679,6 +689,12 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
679689
if opts != nil {
680690
t.opts = *opts
681691
}
692+
if t.opts.HTTPClient == nil {
693+
t.opts.HTTPClient = http.DefaultClient
694+
}
695+
if t.opts.ReconnectOptions == (StreamableReconnectOptions{}) {
696+
t.opts.ReconnectOptions = DefaultReconnectOptions
697+
}
682698
return t
683699
}
684700

@@ -691,36 +707,26 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt
691707
// When closed, the connection issues a DELETE request to terminate the logical
692708
// session.
693709
func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) {
694-
client := t.opts.HTTPClient
695-
if client == nil {
696-
client = http.DefaultClient
697-
}
698-
reconnOpts := t.opts.ReconnectOptions
699-
if reconnOpts == nil {
700-
reconnOpts = DefaultReconnectOptions
701-
}
702710
// Create a new cancellable context that will manage the connection's lifecycle.
703711
// This is crucial for cleanly shutting down the background SSE listener by
704712
// cancelling its blocking network operations, which prevents hangs on exit.
705713
connCtx, cancel := context.WithCancel(context.Background())
706-
conn := &streamableClientConn{
707-
url: t.url,
708-
client: client,
709-
incoming: make(chan []byte, 100),
710-
done: make(chan struct{}),
711-
ReconnectOptions: reconnOpts,
712-
ctx: connCtx,
713-
cancel: cancel,
714-
}
715-
return conn, nil
714+
return &streamableClientConn{
715+
url: t.url,
716+
opts: t.opts,
717+
incoming: make(chan []byte, 100),
718+
done: make(chan struct{}),
719+
ctx: connCtx,
720+
cancel: cancel,
721+
}, nil
716722
}
717723

718724
type streamableClientConn struct {
719-
url string
720-
client *http.Client
721-
incoming chan []byte
722-
done chan struct{}
723-
ReconnectOptions *StreamableReconnectOptions
725+
url string
726+
opts StreamableClientTransportOptions
727+
authClient *http.Client
728+
incoming chan []byte
729+
done chan struct{}
724730

725731
closeOnce sync.Once
726732
closeErr error
@@ -800,7 +806,11 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
800806
return nil
801807
}
802808

803-
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
809+
// postMessage makes a POST request to the server with msg as the body.
810+
// It returns the session ID.
811+
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (_ string, err error) {
812+
defer util.Wrapf(&err, "MCP client posting message, session ID %q", sessionID)
813+
804814
data, err := jsonrpc2.EncodeMessage(msg)
805815
if err != nil {
806816
return "", err
@@ -819,28 +829,59 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
819829
req.Header.Set("Content-Type", "application/json")
820830
req.Header.Set("Accept", "application/json, text/event-stream")
821831

822-
resp, err := s.client.Do(req)
832+
// Use an HTTP client that does authentication, if there is one.
833+
// Otherwise, use the one provided by the user.
834+
client := s.authClient
835+
if client == nil {
836+
client = s.opts.HTTPClient
837+
}
838+
// TODO: Resource Indicators, as in
839+
// https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#resource-parameter-implementation
840+
resp, err := client.Do(req)
823841
if err != nil {
824842
return "", err
825843
}
844+
bodyClosed := false // avoid a second call to Close: undefined behavior (see [io.Closer])
845+
defer func() {
846+
if resp != nil && !bodyClosed {
847+
resp.Body.Close()
848+
}
849+
}()
850+
851+
if resp.StatusCode == http.StatusUnauthorized {
852+
if client == s.authClient {
853+
return "", errors.New("got StatusUnauthorized when already authorized")
854+
}
855+
tokenSource, err := doOauth(ctx, resp.Header, s.opts.HTTPClient, s.opts.AuthHandler)
856+
if err != nil {
857+
return "", err
858+
}
859+
s.authClient = newAuthClient(s.opts.HTTPClient, tokenSource)
860+
resp.Body.Close() // because we're about to replace resp
861+
resp, err = s.authClient.Do(req)
862+
if err != nil {
863+
return "", err
864+
}
865+
if resp.StatusCode == http.StatusUnauthorized {
866+
return "", errors.New("got StatusUnauthorized just after authorization")
867+
}
868+
}
826869

827870
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
828871
// TODO: do a best effort read of the body here, and format it in the error.
829-
resp.Body.Close()
830872
return "", fmt.Errorf("broken session: %v", resp.Status)
831873
}
832874

833875
sessionID = resp.Header.Get(sessionIDHeader)
834876
switch ct := resp.Header.Get("Content-Type"); ct {
835877
case "text/event-stream":
836878
// Section 2.1: The SSE stream is initiated after a POST.
879+
bodyClosed = true // handleSSE will close.
837880
go s.handleSSE(resp)
838881
case "application/json":
839882
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
840-
resp.Body.Close()
841883
return "", fmt.Errorf("streamable HTTP client does not yet support raw JSON responses")
842884
default:
843-
resp.Body.Close()
844885
return "", fmt.Errorf("unsupported content type %q", ct)
845886
}
846887
return sessionID, nil
@@ -912,12 +953,13 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s
912953
// an error if all retries are exhausted.
913954
func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
914955
var finalErr error
956+
maxRetries := s.opts.ReconnectOptions.MaxRetries
915957

916-
for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ {
958+
for attempt := 0; attempt < maxRetries; attempt++ {
917959
select {
918960
case <-s.done:
919961
return nil, fmt.Errorf("connection closed by client during reconnect")
920-
case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)):
962+
case <-time.After(calculateReconnectDelay(&s.opts.ReconnectOptions, attempt)):
921963
resp, err := s.establishSSE(lastEventID)
922964
if err != nil {
923965
finalErr = err // Store the error and try again.
@@ -935,9 +977,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
935977
}
936978
// If the loop completes, all retries have failed.
937979
if finalErr != nil {
938-
return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr)
980+
return nil, fmt.Errorf("connection failed after %d attempts: %w", maxRetries, finalErr)
939981
}
940-
return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries)
982+
return nil, fmt.Errorf("connection failed after %d attempts", maxRetries)
941983
}
942984

943985
// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
@@ -966,7 +1008,7 @@ func (s *streamableClientConn) Close() error {
9661008
req.Header.Set(protocolVersionHeader, s.protocolVersion)
9671009
}
9681010
req.Header.Set(sessionIDHeader, s._sessionID)
969-
if _, err := s.client.Do(req); err != nil {
1011+
if _, err := s.opts.HTTPClient.Do(req); err != nil {
9701012
s.closeErr = err
9711013
}
9721014
}
@@ -992,7 +1034,7 @@ func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
9921034
}
9931035
req.Header.Set("Accept", "text/event-stream")
9941036

995-
return s.client.Do(req)
1037+
return s.opts.HTTPClient.Do(req)
9961038
}
9971039

9981040
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.

0 commit comments

Comments
 (0)