|
1 | 1 | # frozen_string_literal: true
|
2 | 2 |
|
3 |
| -require 'oauth2' |
4 | 3 | require 'jwt'
|
| 4 | +require 'oauth2' |
| 5 | +require 'securerandom' |
5 | 6 |
|
6 | 7 | module MCP
|
7 | 8 | module Server
|
8 | 9 | module Auth
|
9 | 10 | # OAuth 2.1 implementation for MCP server
|
10 | 11 | class OAuth
|
11 |
| - attr_reader :issuer, :client_id, :client_secret, :redirect_uri |
| 12 | + attr_reader :client_id, :client_secret |
12 | 13 |
|
13 | 14 | # Initialize OAuth
|
14 | 15 | # @param options [Hash] OAuth options
|
15 | 16 | def initialize(options = {})
|
16 |
| - @issuer = options[:issuer] |
17 | 17 | @client_id = options[:client_id]
|
18 | 18 | @client_secret = options[:client_secret]
|
19 |
| - @redirect_uri = options[:redirect_uri] |
20 |
| - @jwt_secret = options[:jwt_secret] |
21 | 19 | @token_expiry = options[:token_expiry] || 3600 # 1 hour
|
22 |
| - @authorization_endpoint = options[:authorization_endpoint] |
23 |
| - @token_endpoint = options[:token_endpoint] |
| 20 | + @jwt_secret = options[:jwt_secret] || SecureRandom.hex(32) |
| 21 | + @issuer = options[:issuer] || 'mcp_server' |
24 | 22 | @scopes = options[:scopes] || ['mcp']
|
25 | 23 | @logger = MCP.logger
|
26 |
| - |
27 |
| - validate_options! |
28 |
| - end |
29 |
| - |
30 |
| - # Create an authorization URL |
31 |
| - # @param state [String] State parameter for CSRF protection |
32 |
| - # @param scopes [Array<String>] Requested scopes |
33 |
| - # @return [String] The authorization URL |
34 |
| - def authorization_url(state, scopes = nil) |
35 |
| - client = create_client |
36 |
| - |
37 |
| - client.auth_code.authorize_url( |
38 |
| - redirect_uri: @redirect_uri, |
39 |
| - scope: scopes || @scopes, |
40 |
| - state: state |
41 |
| - ) |
42 |
| - end |
43 |
| - |
44 |
| - # Exchange an authorization code for a token |
45 |
| - # @param code [String] The authorization code |
46 |
| - # @return [OAuth2::AccessToken] The access token |
47 |
| - def exchange_code(code) |
48 |
| - client = create_client |
49 |
| - |
50 |
| - client.auth_code.get_token( |
51 |
| - code, |
52 |
| - redirect_uri: @redirect_uri |
53 |
| - ) |
54 | 24 | end
|
55 | 25 |
|
56 |
| - # Create a JWT from an access token |
57 |
| - # @param token [OAuth2::AccessToken] The access token |
| 26 | + # Create a JWT from token parameters |
| 27 | + # @param token [OAuth2::AccessToken] The token with parameters |
58 | 28 | # @return [String] The JWT
|
59 | 29 | def create_jwt(token)
|
| 30 | + # Extract user ID from token parameters |
| 31 | + user_id = token.params['user_id'] || token.params['sub'] |
| 32 | + |
| 33 | + # Extract scopes from token parameters or use default scopes |
| 34 | + scopes = if token.params['scope'] |
| 35 | + token.params['scope'].split(' ') |
| 36 | + else |
| 37 | + @scopes |
| 38 | + end |
| 39 | + |
| 40 | + # Create JWT payload with string keys for proper JWT serialization |
60 | 41 | payload = {
|
61 |
| - sub: token.params['user_id'] || token.params['sub'], |
62 |
| - exp: Time.now.to_i + @token_expiry, |
63 |
| - iat: Time.now.to_i, |
64 |
| - iss: @issuer, |
65 |
| - scopes: token.params['scope']&.split(' ') || @scopes |
| 42 | + 'sub' => user_id, |
| 43 | + 'exp' => Time.now.to_i + @token_expiry, |
| 44 | + 'iat' => Time.now.to_i, |
| 45 | + 'iss' => @issuer, |
| 46 | + 'scopes' => scopes |
66 | 47 | }
|
67 | 48 |
|
| 49 | + # Encode JWT |
68 | 50 | JWT.encode(payload, @jwt_secret, 'HS256')
|
69 | 51 | end
|
70 | 52 |
|
71 | 53 | # Verify a JWT
|
72 |
| - # @param jwt [String] The JWT to verify |
73 |
| - # @return [Hash] The decoded JWT payload |
74 |
| - def verify_jwt(jwt) |
| 54 | + # @param token [String] The token to verify |
| 55 | + # @return [Hash, nil] The payload if valid, nil if invalid |
| 56 | + def verify_jwt(token) |
75 | 57 | begin
|
76 |
| - decoded = JWT.decode(jwt, @jwt_secret, true, { algorithm: 'HS256' }) |
| 58 | + decoded = JWT.decode(token, @jwt_secret, true, { algorithm: 'HS256' }) |
77 | 59 | decoded[0] # Return the payload
|
78 |
| - rescue JWT::DecodeError => e |
79 |
| - @logger.error("JWT verification failed: #{e.message}") |
| 60 | + rescue JWT::DecodeError, JWT::ExpiredSignature => e |
| 61 | + @logger&.error("JWT verification failed: #{e.message}") if @logger |
80 | 62 | nil
|
81 | 63 | end
|
82 | 64 | end
|
83 | 65 |
|
84 |
| - # Check if a JWT has a specific scope |
85 |
| - # @param jwt [String] The JWT to check |
86 |
| - # @param scope [String] The scope to check for |
87 |
| - # @return [Boolean] true if the JWT has the scope |
88 |
| - def has_scope?(jwt, scope) |
89 |
| - payload = verify_jwt(jwt) |
90 |
| - return false unless payload |
91 |
| - |
92 |
| - scopes = payload['scopes'] || [] |
93 |
| - scopes.include?(scope) |
| 66 | + # Authenticate client credentials |
| 67 | + # @param client_id [String] The client ID |
| 68 | + # @param client_secret [String] The client secret |
| 69 | + # @return [Boolean] true if valid credentials |
| 70 | + def authenticate_client(client_id, client_secret) |
| 71 | + client_id == @client_id && client_secret == @client_secret |
94 | 72 | end
|
95 | 73 |
|
96 |
| - private |
97 |
| - |
98 |
| - # Create an OAuth client |
99 |
| - # @return [OAuth2::Client] The OAuth client |
100 |
| - def create_client |
101 |
| - OAuth2::Client.new( |
102 |
| - @client_id, |
103 |
| - @client_secret, |
104 |
| - site: @issuer, |
105 |
| - authorize_url: @authorization_endpoint, |
106 |
| - token_url: @token_endpoint |
| 74 | + # Create an OAuth2 token |
| 75 | + # @param params [Hash] Token parameters |
| 76 | + # @return [OAuth2::AccessToken] The token |
| 77 | + def create_token(params) |
| 78 | + client = OAuth2::Client.new(@client_id, @client_secret) |
| 79 | + |
| 80 | + # Create a new hash with token parameters |
| 81 | + token_params = {} |
| 82 | + |
| 83 | + # Copy all original params |
| 84 | + params.each do |key, value| |
| 85 | + token_params[key] = value |
| 86 | + end |
| 87 | + |
| 88 | + # Set default scope if not provided |
| 89 | + token_params['scope'] ||= @scopes.join(' ') |
| 90 | + |
| 91 | + # Create token |
| 92 | + OAuth2::AccessToken.new( |
| 93 | + client, |
| 94 | + SecureRandom.hex(16), |
| 95 | + refresh_token: SecureRandom.hex(16), |
| 96 | + expires_in: @token_expiry, |
| 97 | + params: token_params |
107 | 98 | )
|
108 | 99 | end
|
109 | 100 |
|
110 |
| - # Validate required options |
111 |
| - # @raise [MCP::Errors::AuthenticationError] If required options are missing |
112 |
| - def validate_options! |
113 |
| - unless @issuer && @client_id && @client_secret && @redirect_uri && @jwt_secret |
114 |
| - raise MCP::Errors::AuthenticationError, "Missing required OAuth options" |
115 |
| - end |
| 101 | + # Verify if a token has a required scope |
| 102 | + # @param token_payload [Hash] The token payload |
| 103 | + # @param required_scope [String] The required scope |
| 104 | + # @return [Boolean] true if token has the scope |
| 105 | + def verify_scope(token_payload, required_scope) |
| 106 | + return false unless token_payload && token_payload['scopes'] |
| 107 | + |
| 108 | + scopes = token_payload['scopes'] |
| 109 | + return false if scopes.nil? || scopes.empty? |
| 110 | + |
| 111 | + scopes.include?(required_scope) |
116 | 112 | end
|
117 | 113 | end
|
118 | 114 | end
|
|
0 commit comments