Skip to content

Commit 1f5950b

Browse files
pcarletondomdomegg
andauthored
[auth] OAuth protected-resource-metadata: fallback on 4xx not just 404 (#879)
Co-authored-by: adam jones <domdomegg+git@gmail.com>
1 parent 64f7cdd commit 1f5950b

File tree

3 files changed

+57
-37
lines changed

3 files changed

+57
-37
lines changed

src/client/auth.test.ts

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,11 @@ describe("OAuth Authorization", () => {
212212
expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path?param=value");
213213
});
214214

215-
it("falls back to root discovery when path-aware discovery returns 404", async () => {
216-
// First call (path-aware) returns 404
215+
it.each([400, 401, 403, 404, 410, 422, 429])("falls back to root discovery when path-aware discovery returns %d", async (statusCode) => {
216+
// First call (path-aware) returns 4xx
217217
mockFetch.mockResolvedValueOnce({
218218
ok: false,
219-
status: 404,
219+
status: statusCode,
220220
});
221221

222222
// Second call (root fallback) succeeds
@@ -267,6 +267,20 @@ describe("OAuth Authorization", () => {
267267
expect(calls.length).toBe(2);
268268
});
269269

270+
it("throws error on 500 status and does not fallback", async () => {
271+
// First call (path-aware) returns 500
272+
mockFetch.mockResolvedValueOnce({
273+
ok: false,
274+
status: 500,
275+
});
276+
277+
await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name"))
278+
.rejects.toThrow();
279+
280+
const calls = mockFetch.mock.calls;
281+
expect(calls.length).toBe(1); // Should not attempt fallback
282+
});
283+
270284
it("does not fallback when the original URL is already at root path", async () => {
271285
// First call (path-aware for root) returns 404
272286
mockFetch.mockResolvedValueOnce({
@@ -907,7 +921,7 @@ describe("OAuth Authorization", () => {
907921
const metadata = await discoverAuthorizationServerMetadata("https://auth.example.com/tenant1");
908922

909923
expect(metadata).toBeUndefined();
910-
924+
911925
// Verify that all discovery URLs were attempted
912926
expect(mockFetch).toHaveBeenCalledTimes(8); // 4 URLs × 2 attempts each (with and without headers)
913927
});

src/client/auth.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ async function tryMetadataDiscovery(
571571
* Determines if fallback to root discovery should be attempted
572572
*/
573573
function shouldAttemptFallback(response: Response | undefined, pathname: string): boolean {
574-
return !response || response.status === 404 && pathname !== '/';
574+
return !response || (response.status >= 400 && response.status < 500) && pathname !== '/';
575575
}
576576

577577
/**

src/client/streamableHttp.test.ts

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ describe("StreamableHTTPClientTransport", () => {
465465

466466
// Verify custom fetch was used
467467
expect(customFetch).toHaveBeenCalled();
468-
468+
469469
// Global fetch should never have been called
470470
expect(global.fetch).not.toHaveBeenCalled();
471471
});
@@ -589,32 +589,32 @@ describe("StreamableHTTPClientTransport", () => {
589589
await expect(transport.send(message)).rejects.toThrow(UnauthorizedError);
590590
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
591591
});
592-
592+
593593
describe('Reconnection Logic', () => {
594594
let transport: StreamableHTTPClientTransport;
595-
595+
596596
// Use fake timers to control setTimeout and make the test instant.
597597
beforeEach(() => jest.useFakeTimers());
598598
afterEach(() => jest.useRealTimers());
599-
599+
600600
it('should reconnect a GET-initiated notification stream that fails', async () => {
601601
// ARRANGE
602602
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
603603
reconnectionOptions: {
604-
initialReconnectionDelay: 10,
605-
maxRetries: 1,
604+
initialReconnectionDelay: 10,
605+
maxRetries: 1,
606606
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
607607
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
608608
}
609609
});
610-
610+
611611
const errorSpy = jest.fn();
612612
transport.onerror = errorSpy;
613-
613+
614614
const failingStream = new ReadableStream({
615615
start(controller) { controller.error(new Error("Network failure")); }
616616
});
617-
617+
618618
const fetchMock = global.fetch as jest.Mock;
619619
// Mock the initial GET request, which will fail.
620620
fetchMock.mockResolvedValueOnce({
@@ -628,13 +628,13 @@ describe("StreamableHTTPClientTransport", () => {
628628
headers: new Headers({ "content-type": "text/event-stream" }),
629629
body: new ReadableStream(),
630630
});
631-
631+
632632
// ACT
633633
await transport.start();
634634
// Trigger the GET stream directly using the internal method for a clean test.
635635
await transport["_startOrAuthSse"]({});
636636
await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout
637-
637+
638638
// ASSERT
639639
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
640640
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
@@ -644,47 +644,47 @@ describe("StreamableHTTPClientTransport", () => {
644644
expect(fetchMock.mock.calls[0][1]?.method).toBe('GET');
645645
expect(fetchMock.mock.calls[1][1]?.method).toBe('GET');
646646
});
647-
647+
648648
it('should NOT reconnect a POST-initiated stream that fails', async () => {
649649
// ARRANGE
650650
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
651-
reconnectionOptions: {
652-
initialReconnectionDelay: 10,
653-
maxRetries: 1,
651+
reconnectionOptions: {
652+
initialReconnectionDelay: 10,
653+
maxRetries: 1,
654654
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
655655
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
656656
}
657657
});
658-
658+
659659
const errorSpy = jest.fn();
660660
transport.onerror = errorSpy;
661-
661+
662662
const failingStream = new ReadableStream({
663663
start(controller) { controller.error(new Error("Network failure")); }
664664
});
665-
665+
666666
const fetchMock = global.fetch as jest.Mock;
667667
// Mock the POST request. It returns a streaming content-type but a failing body.
668668
fetchMock.mockResolvedValueOnce({
669669
ok: true, status: 200,
670670
headers: new Headers({ "content-type": "text/event-stream" }),
671671
body: failingStream,
672672
});
673-
673+
674674
// A dummy request message to trigger the `send` logic.
675675
const requestMessage: JSONRPCRequest = {
676676
jsonrpc: '2.0',
677677
method: 'long_running_tool',
678678
id: 'request-1',
679679
params: {},
680680
};
681-
681+
682682
// ACT
683683
await transport.start();
684684
// Use the public `send` method to initiate a POST that gets a stream response.
685685
await transport.send(requestMessage);
686686
await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections
687-
687+
688688
// ASSERT
689689
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
690690
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
@@ -718,7 +718,9 @@ describe("StreamableHTTPClientTransport", () => {
718718
(global.fetch as jest.Mock)
719719
// Initial connection
720720
.mockResolvedValueOnce(unauthedResponse)
721-
// Resource discovery
721+
// Resource discovery, path aware
722+
.mockResolvedValueOnce(unauthedResponse)
723+
// Resource discovery, root
722724
.mockResolvedValueOnce(unauthedResponse)
723725
// OAuth metadata discovery
724726
.mockResolvedValueOnce({
@@ -770,7 +772,9 @@ describe("StreamableHTTPClientTransport", () => {
770772
(global.fetch as jest.Mock)
771773
// Initial connection
772774
.mockResolvedValueOnce(unauthedResponse)
773-
// Resource discovery
775+
// Resource discovery, path aware
776+
.mockResolvedValueOnce(unauthedResponse)
777+
// Resource discovery, root
774778
.mockResolvedValueOnce(unauthedResponse)
775779
// OAuth metadata discovery
776780
.mockResolvedValueOnce({
@@ -822,7 +826,9 @@ describe("StreamableHTTPClientTransport", () => {
822826
(global.fetch as jest.Mock)
823827
// Initial connection
824828
.mockResolvedValueOnce(unauthedResponse)
825-
// Resource discovery
829+
// Resource discovery, path aware
830+
.mockResolvedValueOnce(unauthedResponse)
831+
// Resource discovery, root
826832
.mockResolvedValueOnce(unauthedResponse)
827833
// OAuth metadata discovery
828834
.mockResolvedValueOnce({
@@ -888,7 +894,7 @@ describe("StreamableHTTPClientTransport", () => {
888894
ok: false,
889895
status: 404
890896
});
891-
897+
892898
// Create transport instance
893899
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
894900
authProvider: mockAuthProvider,
@@ -901,14 +907,14 @@ describe("StreamableHTTPClientTransport", () => {
901907

902908
// Verify custom fetch was used
903909
expect(customFetch).toHaveBeenCalled();
904-
910+
905911
// Verify specific OAuth endpoints were called with custom fetch
906912
const customFetchCalls = customFetch.mock.calls;
907913
const callUrls = customFetchCalls.map(([url]) => url.toString());
908-
914+
909915
// Should have called resource metadata discovery
910916
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);
911-
917+
912918
// Should have called OAuth authorization server metadata discovery
913919
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);
914920

@@ -966,19 +972,19 @@ describe("StreamableHTTPClientTransport", () => {
966972

967973
// Verify custom fetch was used
968974
expect(customFetch).toHaveBeenCalled();
969-
975+
970976
// Verify specific OAuth endpoints were called with custom fetch
971977
const customFetchCalls = customFetch.mock.calls;
972978
const callUrls = customFetchCalls.map(([url]) => url.toString());
973-
979+
974980
// Should have called resource metadata discovery
975981
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);
976-
982+
977983
// Should have called OAuth authorization server metadata discovery
978984
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);
979985

980986
// Should have called token endpoint for authorization code exchange
981-
const tokenCalls = customFetchCalls.filter(([url, options]) =>
987+
const tokenCalls = customFetchCalls.filter(([url, options]) =>
982988
url.toString().includes('/token') && options?.method === "POST"
983989
);
984990
expect(tokenCalls.length).toBeGreaterThan(0);

0 commit comments

Comments
 (0)