diff --git a/README.md b/README.md index d7f32b01..eb8eb0ab 100644 --- a/README.md +++ b/README.md @@ -333,6 +333,29 @@ The service can be configured using environment variables: | `MCP_REGISTRY_SEED_IMPORT` | Import `seed.json` on first run | `true` | | `MCP_REGISTRY_SERVER_ADDRESS` | Listen address for the server | `:8080` | +### Background Job Configuration + +The registry includes a background verification job that continuously validates domain ownership for registered servers: + +| Variable | Description | Default | +|----------|-------------|---------| +| `MCP_REGISTRY_BACKGROUND_JOB_ENABLED` | Enable background verification job | `true` | +| `MCP_REGISTRY_BACKGROUND_JOB_CRON_SCHEDULE` | Cron schedule for verification runs | `0 0 2 * * *` (daily at 2 AM) | +| `MCP_REGISTRY_BACKGROUND_JOB_MAX_CONCURRENT` | Maximum concurrent verifications | `10` | +| `MCP_REGISTRY_BACKGROUND_JOB_VERIFICATION_TIMEOUT_SECONDS` | Timeout for each verification (seconds) | `30` | +| `MCP_REGISTRY_BACKGROUND_JOB_FAILURE_THRESHOLD` | Consecutive failures before marking as failed | `3` | +| `MCP_REGISTRY_BACKGROUND_JOB_NOTIFICATION_COOLDOWN_HOURS` | Hours between failure notifications | `24` | +| `MCP_REGISTRY_BACKGROUND_JOB_CLEANUP_INTERVAL_DAYS` | Days between cleanup of old records | `7` | + +Example production configuration: +```bash +# Enable background job with custom schedule (every 6 hours) +MCP_REGISTRY_BACKGROUND_JOB_ENABLED=true +MCP_REGISTRY_BACKGROUND_JOB_CRON_SCHEDULE="0 0 */6 * * *" +MCP_REGISTRY_BACKGROUND_JOB_MAX_CONCURRENT=20 +MCP_REGISTRY_BACKGROUND_JOB_VERIFICATION_TIMEOUT_SECONDS=60 +``` + ## Testing diff --git a/cmd/registry/main.go b/cmd/registry/main.go index 01d2b0a2..ab48f261 100644 --- a/cmd/registry/main.go +++ b/cmd/registry/main.go @@ -17,6 +17,7 @@ import ( "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/model" "github.com/modelcontextprotocol/registry/internal/service" + "github.com/modelcontextprotocol/registry/internal/verification" ) func main() { @@ -93,6 +94,42 @@ func main() { } } + // Initialize background verification job (if enabled) + var backgroundJob *verification.BackgroundVerificationJob + if cfg.BackgroundJobEnabled { + // Convert config to verification background job config + backgroundJobConfig := &verification.BackgroundJobConfig{ + CronSchedule: cfg.BackgroundJobCronSchedule, + MaxConcurrentVerifications: cfg.BackgroundJobMaxConcurrentVerifications, + VerificationTimeout: time.Duration(cfg.BackgroundJobVerificationTimeoutSeconds) * time.Second, + FailureThreshold: cfg.BackgroundJobFailureThreshold, + RetryDelay: time.Second, + NotificationCooldown: time.Duration(cfg.BackgroundJobNotificationCooldownHours) * time.Hour, + CleanupInterval: time.Duration(cfg.BackgroundJobCleanupIntervalDays) * 24 * time.Hour, + } + + backgroundJob = verification.NewBackgroundVerificationJob(db, backgroundJobConfig, nil) + + // Start background verification job + ctx := context.Background() + if err := backgroundJob.Start(ctx); err != nil { + log.Printf("Failed to start background verification job: %v", err) + } else { + log.Println("Background verification job started successfully") + } + + // Defer stopping the background job + defer func() { + if err := backgroundJob.Stop(); err != nil { + log.Printf("Error stopping background verification job: %v", err) + } else { + log.Println("Background verification job stopped successfully") + } + }() + } else { + log.Println("Background verification job is disabled") + } + // Initialize authentication services authService := auth.NewAuthService(cfg) diff --git a/go.mod b/go.mod index 5cbf5146..27696a50 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/caarlos0/env/v11 v11.3.1 github.com/google/uuid v1.6.0 + github.com/robfig/cron/v3 v3.0.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/stretchr/testify v1.10.0 github.com/swaggo/files v1.0.1 diff --git a/go.sum b/go.sum index 4b96ee6a..e36c55a8 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8 github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4= diff --git a/internal/config/config.go b/internal/config/config.go index 445178b5..d2a2b25e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,7 @@ package config import ( - env "github.com/caarlos0/env/v11" + "github.com/caarlos0/env/v11" ) type DatabaseType string @@ -25,6 +25,15 @@ type Config struct { Version string `env:"VERSION" envDefault:"dev"` GithubClientID string `env:"GITHUB_CLIENT_ID" envDefault:""` GithubClientSecret string `env:"GITHUB_CLIENT_SECRET" envDefault:""` + + // Background verification job configuration + BackgroundJobEnabled bool `env:"BACKGROUND_JOB_ENABLED" envDefault:"true"` + BackgroundJobCronSchedule string `env:"BACKGROUND_JOB_CRON_SCHEDULE" envDefault:"0 0 2 * * *"` + BackgroundJobMaxConcurrentVerifications int `env:"BACKGROUND_JOB_MAX_CONCURRENT" envDefault:"10"` + BackgroundJobVerificationTimeoutSeconds int `env:"BACKGROUND_JOB_VERIFICATION_TIMEOUT_SECONDS" envDefault:"30"` + BackgroundJobFailureThreshold int `env:"BACKGROUND_JOB_FAILURE_THRESHOLD" envDefault:"3"` + BackgroundJobNotificationCooldownHours int `env:"BACKGROUND_JOB_NOTIFICATION_COOLDOWN_HOURS" envDefault:"24"` + BackgroundJobCleanupIntervalDays int `env:"BACKGROUND_JOB_CLEANUP_INTERVAL_DAYS" envDefault:"7"` } // NewConfig creates a new configuration with default values diff --git a/internal/database/database.go b/internal/database/database.go index 2637bbe0..23f10fff 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -3,6 +3,7 @@ package database import ( "context" "errors" + "time" "github.com/modelcontextprotocol/registry/internal/model" ) @@ -32,6 +33,16 @@ type Database interface { ImportSeed(ctx context.Context, seedFilePath string) error // Close closes the database connection Close() error + + // Domain verification methods + // GetVerifiedDomains retrieves all domains that are currently verified + GetVerifiedDomains(ctx context.Context) ([]string, error) + // GetDomainVerification retrieves domain verification details + GetDomainVerification(ctx context.Context, domain string) (*model.DomainVerification, error) + // UpdateDomainVerification updates or creates domain verification record + UpdateDomainVerification(ctx context.Context, domainVerification *model.DomainVerification) error + // CleanupOldVerifications removes old verification records before the given time + CleanupOldVerifications(ctx context.Context, before time.Time) (int, error) } // ConnectionType represents the type of database connection diff --git a/internal/database/memory.go b/internal/database/memory.go index 2440ceef..be7f014c 100644 --- a/internal/database/memory.go +++ b/internal/database/memory.go @@ -18,6 +18,7 @@ import ( type MemoryDB struct { entries map[string]*model.ServerDetail domainVerifications map[string]*model.DomainVerification // key: domain + metadata map[string]*model.Metadata // key: serverID mu sync.RWMutex } @@ -33,6 +34,7 @@ func NewMemoryDB(e map[string]*model.Server) *MemoryDB { return &MemoryDB{ entries: serverDetails, domainVerifications: make(map[string]*model.DomainVerification), + metadata: make(map[string]*model.Metadata), } } @@ -356,3 +358,85 @@ func (db *MemoryDB) GetVerificationTokens(ctx context.Context, domain string) (* return domainVerification.VerificationTokens, nil } + +// GetVerifiedDomains retrieves all domains that are currently verified +func (db *MemoryDB) GetVerifiedDomains(ctx context.Context) ([]string, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var domains []string + for _, metadata := range db.metadata { + if metadata.DomainVerification != nil && + metadata.DomainVerification.Status == model.VerificationStatusVerified { + domains = append(domains, metadata.DomainVerification.Domain) + } + } + + return domains, nil +} + +// GetDomainVerification retrieves domain verification details +func (db *MemoryDB) GetDomainVerification(ctx context.Context, domain string) (*model.DomainVerification, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + for _, metadata := range db.metadata { + if metadata.DomainVerification != nil && + metadata.DomainVerification.Domain == domain { + return metadata.DomainVerification, nil + } + } + + return nil, ErrNotFound +} + +// UpdateDomainVerification updates or creates domain verification record +func (db *MemoryDB) UpdateDomainVerification(ctx context.Context, domainVerification *model.DomainVerification) error { + db.mu.Lock() + defer db.mu.Unlock() + + // Find existing metadata entry for this domain or create a new one + var targetMetadata *model.Metadata + var targetServerID string + + // Find existing metadata for this domain + for _, metadata := range db.metadata { + if metadata.DomainVerification != nil && + metadata.DomainVerification.Domain == domainVerification.Domain { + targetMetadata = metadata + break + } + } + + if targetMetadata == nil { + // Create new metadata entry + targetServerID = uuid.New().String() + targetMetadata = &model.Metadata{ + ServerID: targetServerID, + } + db.metadata[targetServerID] = targetMetadata + } + + targetMetadata.DomainVerification = domainVerification + return nil +} + +// CleanupOldVerifications removes old verification records before the given time +func (db *MemoryDB) CleanupOldVerifications(ctx context.Context, before time.Time) (int, error) { + db.mu.Lock() + defer db.mu.Unlock() + + count := 0 + for serverID, metadata := range db.metadata { + if metadata.DomainVerification != nil { + // Remove records that are old and have failed status + if metadata.DomainVerification.Status == model.VerificationStatusFailed && + metadata.DomainVerification.LastVerificationAttempt.Before(before) { + delete(db.metadata, serverID) + count++ + } + } + } + + return count, nil +} diff --git a/internal/database/mongo.go b/internal/database/mongo.go index f6664b06..2de93175 100644 --- a/internal/database/mongo.go +++ b/internal/database/mongo.go @@ -20,6 +20,7 @@ type MongoDB struct { database *mongo.Database serverCollection *mongo.Collection verificationCollection *mongo.Collection + metadataCollection *mongo.Collection } // NewMongoDB creates a new instance of the MongoDB database @@ -40,6 +41,7 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName database := client.Database(databaseName) serverCollection := database.Collection(collectionName) verificationCollection := database.Collection(verificationCollectionName) + metadataCollection := database.Collection("metadata") // Create indexes for better query performance models := []mongo.IndexModel{ @@ -90,6 +92,7 @@ func NewMongoDB(ctx context.Context, connectionURI, databaseName, collectionName database: database, serverCollection: serverCollection, verificationCollection: verificationCollection, + metadataCollection: metadataCollection, }, nil } @@ -402,3 +405,98 @@ func (db *MongoDB) GetVerificationTokens(ctx context.Context, domain string) (*m return domainVerification.VerificationTokens, nil } + +// GetVerifiedDomains retrieves all domains that are currently verified +func (db *MongoDB) GetVerifiedDomains(ctx context.Context) ([]string, error) { + filter := bson.M{ + "domain_verification.status": model.VerificationStatusVerified, + } + + cursor, err := db.metadataCollection.Find(ctx, filter) + if err != nil { + return nil, fmt.Errorf("failed to query verified domains: %w", err) + } + defer cursor.Close(ctx) + + var domains []string + for cursor.Next(ctx) { + var metadata model.Metadata + if err := cursor.Decode(&metadata); err != nil { + log.Printf("Failed to decode metadata: %v", err) + continue + } + + if metadata.DomainVerification != nil { + domains = append(domains, metadata.DomainVerification.Domain) + } + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %w", err) + } + + return domains, nil +} + +// GetDomainVerification retrieves domain verification details +func (db *MongoDB) GetDomainVerification(ctx context.Context, domain string) (*model.DomainVerification, error) { + filter := bson.M{ + "domain_verification.domain": domain, + } + + var metadata model.Metadata + err := db.metadataCollection.FindOne(ctx, filter).Decode(&metadata) + if err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("failed to get domain verification: %w", err) + } + + if metadata.DomainVerification == nil { + return nil, ErrNotFound + } + + return metadata.DomainVerification, nil +} + +// UpdateDomainVerification updates or creates domain verification record +func (db *MongoDB) UpdateDomainVerification(ctx context.Context, domainVerification *model.DomainVerification) error { + filter := bson.M{ + "domain_verification.domain": domainVerification.Domain, + } + + update := bson.M{ + "$set": bson.M{ + "domain_verification": domainVerification, + }, + "$setOnInsert": bson.M{ + "server_id": uuid.New().String(), + }, + } + + opts := options.Update().SetUpsert(true) + _, err := db.metadataCollection.UpdateOne(ctx, filter, update, opts) + if err != nil { + return fmt.Errorf("failed to update domain verification: %w", err) + } + + return nil +} + +// CleanupOldVerifications removes old verification records before the given time +func (db *MongoDB) CleanupOldVerifications(ctx context.Context, before time.Time) (int, error) { + filter := bson.M{ + "domain_verification.status": model.VerificationStatusFailed, + "domain_verification.last_verification_attempt": bson.M{ + "$lt": before, + }, + } + + result, err := db.metadataCollection.DeleteMany(ctx, filter) + if err != nil { + return 0, fmt.Errorf("failed to cleanup old verifications: %w", err) + } + + return int(result.DeletedCount), nil +} diff --git a/internal/model/model.go b/internal/model/model.go index 10dc9bfb..7eaf7746 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -138,7 +138,61 @@ type ServerDetail struct { Remotes []Remote `json:"remotes,omitempty" bson:"remotes,omitempty"` } -// VerificationToken represents a domain verification token for a server +// VerificationStatus represents the verification status of a domain +type VerificationStatus string + +const ( + // VerificationStatusVerified indicates the domain is successfully verified + VerificationStatusVerified VerificationStatus = "verified" + // VerificationStatusWarning indicates the domain has failed verification once or twice + VerificationStatusWarning VerificationStatus = "warning" + // VerificationStatusUnverified indicates the domain has failed verification 3+ times + VerificationStatusUnverified VerificationStatus = "unverified" + // VerificationStatusFailed indicates the domain has failed verification and is no longer valid + VerificationStatusFailed VerificationStatus = "failed" + // VerificationStatusPending indicates initial verification is in progress + VerificationStatusPending VerificationStatus = "pending" +) + +// VerificationMethod represents the method used for domain verification +type VerificationMethod string + +const ( + // VerificationMethodDNS indicates DNS TXT record verification + VerificationMethodDNS VerificationMethod = "dns" + // VerificationMethodHTTP indicates HTTP-01 web challenge verification + VerificationMethodHTTP VerificationMethod = "http" +) + +// DomainVerification represents comprehensive domain verification tracking +type DomainVerification struct { + Domain string `json:"domain" bson:"domain"` + DNSToken string `json:"dns_token,omitempty" bson:"dns_token,omitempty"` + HTTPToken string `json:"http_token,omitempty" bson:"http_token,omitempty"` + Status VerificationStatus `json:"status" bson:"status"` + CreatedAt time.Time `json:"created_at" bson:"created_at"` + LastVerified time.Time `json:"last_verified" bson:"last_verified"` + LastVerificationAttempt time.Time `json:"last_verification_attempt" bson:"last_verification_attempt"` + ConsecutiveFailures int `json:"consecutive_failures" bson:"consecutive_failures"` + LastError string `json:"last_error,omitempty" bson:"last_error,omitempty"` + LastSuccessfulMethod VerificationMethod `json:"last_successful_method,omitempty" bson:"last_successful_method,omitempty"` + NextVerification time.Time `json:"next_verification" bson:"next_verification"` + LastNotificationSent time.Time `json:"last_notification_sent" bson:"last_notification_sent"` + + // Legacy compatibility field + VerificationTokens *VerificationTokens `json:"verification_tokens,omitempty" bson:"verification_tokens,omitempty"` + + // Legacy fields for backward compatibility + LastVerifiedAt *time.Time `json:"last_verified_at,omitempty" bson:"last_verified_at,omitempty"` + LastFailureAt *time.Time `json:"last_failure_at,omitempty" bson:"last_failure_at,omitempty"` + WarningNotifiedAt *time.Time `json:"warning_notified_at,omitempty" bson:"warning_notified_at,omitempty"` + DowngradeNotifiedAt *time.Time `json:"downgrade_notified_at,omitempty" bson:"downgrade_notified_at,omitempty"` + SuccessfulMethods []VerificationMethod `json:"successful_methods,omitempty" bson:"successful_methods,omitempty"` + LastDNSVerificationAt *time.Time `json:"last_dns_verification_at,omitempty" bson:"last_dns_verification_at,omitempty"` + LastHTTPVerificationAt *time.Time `json:"last_http_verification_at,omitempty" bson:"last_http_verification_at,omitempty"` +} + +// VerificationToken represents a domain verification token for a server (legacy, will be replaced by DomainVerification) type VerificationToken struct { Token string `json:"token" bson:"token"` CreatedAt time.Time `json:"created_at" bson:"created_at"` @@ -146,18 +200,19 @@ type VerificationToken struct { LastVerifiedAt *time.Time `json:"last_verified_at,omitempty" bson:"last_verified_at,omitempty"` } +// Metadata represents a metadata entry for a server +type Metadata struct { + ServerID string `json:"server_id" bson:"server_id"` + VerificationToken *VerificationToken `json:"verification_token,omitempty" bson:"verification_token,omitempty"` + DomainVerification *DomainVerification `json:"domain_verification,omitempty" bson:"domain_verification,omitempty"` +} + // VerificationTokens represents the collection of verification tokens for a domain type VerificationTokens struct { VerifiedToken *VerificationToken `json:"verified_token,omitempty" bson:"verified_token,omitempty"` PendingTokens []VerificationToken `json:"pending_tokens,omitempty" bson:"pending_tokens,omitempty"` } -// DomainVerification represents verification data for a specific domain -type DomainVerification struct { - Domain string `json:"domain" bson:"domain"` - VerificationTokens *VerificationTokens `json:"verification_tokens,omitempty" bson:"verification_tokens,omitempty"` -} - // DomainVerificationRequest represents a request to generate a verification token for a domain type DomainVerificationRequest struct { Domain string `json:"domain"` diff --git a/internal/verification/background_job.go b/internal/verification/background_job.go new file mode 100644 index 00000000..b36a5be2 --- /dev/null +++ b/internal/verification/background_job.go @@ -0,0 +1,475 @@ +package verification + +import ( + "context" + "errors" + "fmt" + "log" + "sync" + "time" + + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/model" + cron "github.com/robfig/cron/v3" +) + +// BackgroundVerificationJob handles continuous domain verification +type BackgroundVerificationJob struct { + db database.Database + cron *cron.Cron + running bool + mu sync.RWMutex + config *BackgroundJobConfig + notifyFunc NotificationFunc + stopChan chan struct{} + doneChan chan struct{} +} + +// BackgroundJobConfig contains configuration for the background verification job +type BackgroundJobConfig struct { + // CronSchedule defines when to run verification (default: "0 0 2 * * *" - daily at 2 AM) + CronSchedule string + + // MaxConcurrentVerifications limits parallel verifications (default: 10) + MaxConcurrentVerifications int + + // VerificationTimeout is the timeout for each verification attempt (default: 30s) + VerificationTimeout time.Duration + + // FailureThreshold is the number of consecutive failures before marking as failed (default: 3) + FailureThreshold int + + // RetryDelay is the delay between verification attempts (default: 1s) + RetryDelay time.Duration + + // NotificationCooldown is the minimum time between failure notifications (default: 24h) + NotificationCooldown time.Duration + + // CleanupInterval is how often to clean up old verification records (default: 7 days) + CleanupInterval time.Duration +} + +// NotificationFunc is called when domain verification fails repeatedly +type NotificationFunc func(ctx context.Context, domain string, failures int, lastError error) + +// DefaultBackgroundJobConfig returns a sensible default configuration +func DefaultBackgroundJobConfig() *BackgroundJobConfig { + return &BackgroundJobConfig{ + CronSchedule: "0 0 2 * * *", // Daily at 2 AM (with seconds) + MaxConcurrentVerifications: 10, + VerificationTimeout: 30 * time.Second, + FailureThreshold: 3, + RetryDelay: time.Second, + NotificationCooldown: 24 * time.Hour, + CleanupInterval: 7 * 24 * time.Hour, // 7 days + } +} + +// NewBackgroundVerificationJob creates a new background verification job +func NewBackgroundVerificationJob(db database.Database, config *BackgroundJobConfig, notifyFunc NotificationFunc) *BackgroundVerificationJob { + if config == nil { + config = DefaultBackgroundJobConfig() + } + + if notifyFunc == nil { + notifyFunc = defaultNotificationFunc + } + + cronWithSeconds := cron.New(cron.WithSeconds()) + + return &BackgroundVerificationJob{ + db: db, + cron: cronWithSeconds, + config: config, + notifyFunc: notifyFunc, + stopChan: make(chan struct{}), + doneChan: make(chan struct{}), + } +} + +// Start begins the background verification job +func (bvj *BackgroundVerificationJob) Start(ctx context.Context) error { + bvj.mu.Lock() + defer bvj.mu.Unlock() + + if bvj.running { + return fmt.Errorf("background verification job is already running") + } + + // Add the main verification job + _, err := bvj.cron.AddFunc(bvj.config.CronSchedule, func() { + if err := bvj.runVerificationCycle(ctx); err != nil { + log.Printf("Background verification cycle failed: %v", err) + } + }) + if err != nil { + return fmt.Errorf("failed to schedule verification job: %w", err) + } + + // Add cleanup job (run daily at 3 AM) + _, err = bvj.cron.AddFunc("0 0 3 * * *", func() { + if err := bvj.runCleanup(ctx); err != nil { + log.Printf("Background cleanup failed: %v", err) + } + }) + if err != nil { + return fmt.Errorf("failed to schedule cleanup job: %w", err) + } + + bvj.cron.Start() + bvj.running = true + + log.Printf("Background verification job started with schedule: %s", bvj.config.CronSchedule) + + // Start monitoring goroutine + go bvj.monitor(ctx) + + return nil +} + +// Stop gracefully stops the background verification job +func (bvj *BackgroundVerificationJob) Stop() error { + bvj.mu.Lock() + defer bvj.mu.Unlock() + + if !bvj.running { + return fmt.Errorf("background verification job is not running") + } + + log.Println("Stopping background verification job...") + + // Stop the cron scheduler + bvj.cron.Stop() + + // Signal monitoring goroutine to stop + close(bvj.stopChan) + + // Wait for monitoring goroutine to finish + <-bvj.doneChan + + bvj.running = false + log.Println("Background verification job stopped") + + return nil +} + +// IsRunning returns whether the background job is currently running +func (bvj *BackgroundVerificationJob) IsRunning() bool { + bvj.mu.RLock() + defer bvj.mu.RUnlock() + return bvj.running +} + +// RunNow triggers an immediate verification cycle +func (bvj *BackgroundVerificationJob) RunNow(ctx context.Context) error { + return bvj.runVerificationCycle(ctx) +} + +// monitor runs in a separate goroutine to handle graceful shutdown +func (bvj *BackgroundVerificationJob) monitor(ctx context.Context) { + defer close(bvj.doneChan) + + select { + case <-ctx.Done(): + log.Println("Background verification job context canceled") + case <-bvj.stopChan: + log.Println("Background verification job stop signal received") + } +} + +// runVerificationCycle executes a complete verification cycle for all domains +func (bvj *BackgroundVerificationJob) runVerificationCycle(ctx context.Context) error { + log.Println("Starting background domain verification cycle") + startTime := time.Now() + + // Get all domains that need verification + domains, err := bvj.db.GetVerifiedDomains(ctx) + if err != nil { + return fmt.Errorf("failed to get verified domains: %w", err) + } + + if len(domains) == 0 { + log.Println("No verified domains found for background verification") + return nil + } + + log.Printf("Found %d verified domains for background verification", len(domains)) + + // Create semaphore for concurrent verification control + semaphore := make(chan struct{}, bvj.config.MaxConcurrentVerifications) + var wg sync.WaitGroup + + successCount := 0 + failureCount := 0 + var mu sync.Mutex + + // Process each domain + for _, domain := range domains { + wg.Add(1) + go func(domain string) { + defer wg.Done() + + // Acquire semaphore + semaphore <- struct{}{} + defer func() { <-semaphore }() + + // Create verification context with timeout + verifyCtx, cancel := context.WithTimeout(ctx, bvj.config.VerificationTimeout) + defer cancel() + + success := bvj.verifyDomain(verifyCtx, domain) + + mu.Lock() + if success { + successCount++ + } else { + failureCount++ + } + mu.Unlock() + }(domain) + } + + // Wait for all verifications to complete + wg.Wait() + + duration := time.Since(startTime) + log.Printf("Background verification cycle completed in %v: %d successful, %d failed", + duration, successCount, failureCount) + + return nil +} + +// verifyDomain performs verification for a single domain +func (bvj *BackgroundVerificationJob) verifyDomain(ctx context.Context, domain string) bool { + log.Printf("Starting background verification for domain: %s", domain) + + // Get existing domain verification record to fetch stored tokens + domainVerification, err := bvj.db.GetDomainVerification(ctx, domain) + if err != nil { + log.Printf("Failed to get domain verification record for %s: %v", domain, err) + return false + } + + // Extract tokens for verification + var dnsToken, httpToken string + if domainVerification.DNSToken != "" { + dnsToken = domainVerification.DNSToken + } + if domainVerification.HTTPToken != "" { + httpToken = domainVerification.HTTPToken + } + + // If no tokens are available, we can't verify + if dnsToken == "" && httpToken == "" { + log.Printf("No verification tokens found for domain %s", domain) + return false + } + + // Try both DNS and HTTP verification methods using stored tokens + methods := []model.VerificationMethod{model.VerificationMethodDNS, model.VerificationMethodHTTP} + var lastError error + + for _, method := range methods { + var token string + switch method { + case model.VerificationMethodDNS: + token = dnsToken + case model.VerificationMethodHTTP: + token = httpToken + } + + // Skip this method if we don't have a token for it + if token == "" { + continue + } + + success, err := bvj.runSingleVerification(ctx, domain, token, method) + if err != nil { + lastError = err + log.Printf("%s verification failed for %s: %v", method, domain, err) + continue + } + + if success { + return bvj.handleVerificationSuccess(ctx, domain, method) + } + } + + // All methods failed + return bvj.handleVerificationFailure(ctx, domain, lastError) +} + +// runSingleVerification performs a single verification attempt +func (bvj *BackgroundVerificationJob) runSingleVerification( + ctx context.Context, domain, token string, method model.VerificationMethod, +) (bool, error) { + switch method { + case model.VerificationMethodDNS: + // Create a custom config with the provided context's timeout + config := DefaultDNSConfig() + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + config.Timeout = remaining + } + // If remaining <= 0, keep the default timeout + } + result, err := VerifyDNSRecordWithConfig(ctx, domain, token, config) + if err != nil { + return false, err + } + return result.Success, nil + + case model.VerificationMethodHTTP: + // Create a custom config with the provided context's timeout + config := DefaultHTTPConfig() + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + config.Timeout = remaining + } + // If remaining <= 0, keep the default timeout + } + result, err := VerifyHTTPChallengeWithConfig(ctx, domain, token, config) + if err != nil { + return false, err + } + return result.Success, nil + + default: + return false, fmt.Errorf("unknown verification method: %s", method) + } +} + +// handleVerificationSuccess processes a successful verification +func (bvj *BackgroundVerificationJob) handleVerificationSuccess(ctx context.Context, domain string, method model.VerificationMethod) bool { + err := bvj.updateVerificationSuccess(ctx, domain, method) + if err != nil { + log.Printf("Failed to update verification success for %s: %v", domain, err) + return false + } + + log.Printf("Background verification successful for domain %s using %s", domain, method) + return true +} + +// handleVerificationFailure processes a failed verification +func (bvj *BackgroundVerificationJob) handleVerificationFailure(ctx context.Context, domain string, lastError error) bool { + err := bvj.updateVerificationFailure(ctx, domain, lastError) + if err != nil { + log.Printf("Failed to update verification failure for %s: %v", domain, err) + } else { + log.Printf("Background verification failed for domain %s", domain) + } + + return false +} + +// updateVerificationSuccess updates the domain verification record on successful verification +func (bvj *BackgroundVerificationJob) updateVerificationSuccess(ctx context.Context, domain string, method model.VerificationMethod) error { + now := time.Now() + + domainVerification := &model.DomainVerification{ + Domain: domain, + Status: model.VerificationStatusVerified, + LastVerified: now, + LastSuccessfulMethod: method, + ConsecutiveFailures: 0, // Reset failure count + NextVerification: now.Add(24 * time.Hour), // Next verification in 24 hours + } + + return bvj.db.UpdateDomainVerification(ctx, domainVerification) +} + +// updateVerificationFailure updates the domain verification record on failed verification +func (bvj *BackgroundVerificationJob) updateVerificationFailure(ctx context.Context, domain string, lastError error) error { + // Get current domain verification record + domainVerification, err := bvj.db.GetDomainVerification(ctx, domain) + if err != nil { + // If record doesn't exist, create a new one + if errors.Is(err, database.ErrNotFound) { + now := time.Now() + domainVerification = &model.DomainVerification{ + Domain: domain, + Status: model.VerificationStatusPending, + CreatedAt: now, + } + } else { + return fmt.Errorf("failed to get domain verification record: %w", err) + } + } + + now := time.Now() + domainVerification.ConsecutiveFailures++ + domainVerification.LastVerificationAttempt = now + if lastError != nil { + domainVerification.LastError = lastError.Error() + } + + // Check if we've exceeded the failure threshold + if domainVerification.ConsecutiveFailures >= bvj.config.FailureThreshold { + domainVerification.Status = model.VerificationStatusFailed + + // Send notification if cooldown period has passed + if domainVerification.LastNotificationSent.IsZero() || + now.Sub(domainVerification.LastNotificationSent) >= bvj.config.NotificationCooldown { + // Send notification when threshold is exceeded, regardless of whether we have a specific error + bvj.notifyFunc(ctx, domain, domainVerification.ConsecutiveFailures, lastError) + domainVerification.LastNotificationSent = now + } + } + + // Calculate next verification time (exponential backoff) + nextVerification := now.Add(time.Duration(domainVerification.ConsecutiveFailures) * time.Hour) + if nextVerification.Sub(now) > 24*time.Hour { + nextVerification = now.Add(24 * time.Hour) // Cap at 24 hours + } + domainVerification.NextVerification = nextVerification + + return bvj.db.UpdateDomainVerification(ctx, domainVerification) +} + +// runCleanup removes old verification records and performs maintenance +func (bvj *BackgroundVerificationJob) runCleanup(ctx context.Context) error { + log.Println("Starting background verification cleanup") + + cutoffTime := time.Now().Add(-bvj.config.CleanupInterval) + + // Clean up old failed verification records + count, err := bvj.db.CleanupOldVerifications(ctx, cutoffTime) + if err != nil { + return fmt.Errorf("failed to cleanup old verifications: %w", err) + } + + log.Printf("Background cleanup completed: removed %d old verification records", count) + return nil +} + +// defaultNotificationFunc is a simple notification function that logs failures +func defaultNotificationFunc(ctx context.Context, domain string, failures int, lastError error) { + if lastError != nil { + log.Printf("ALERT: Domain %s has failed verification %d times consecutively. Last error: %v", + domain, failures, lastError) + } else { + log.Printf("ALERT: Domain %s has failed verification %d times consecutively. No specific error available.", + domain, failures) + } +} + +// GetStatus returns the current status of the background verification job +func (bvj *BackgroundVerificationJob) GetStatus() map[string]any { + bvj.mu.RLock() + defer bvj.mu.RUnlock() + + status := map[string]any{ + "running": bvj.running, + "cron_schedule": bvj.config.CronSchedule, + "config": bvj.config, + } + + if bvj.running { + status["cron_entries"] = len(bvj.cron.Entries()) + } + + return status +} diff --git a/internal/verification/background_job_test.go b/internal/verification/background_job_test.go new file mode 100644 index 00000000..ffbb20ae --- /dev/null +++ b/internal/verification/background_job_test.go @@ -0,0 +1,469 @@ +package verification //nolint:testpackage // Need access to unexported fields for testing + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/model" +) + +const ( + defaultCronScheduleBackgroundJob = "0 0 2 * * *" + testDomainBackgroundJob = "example.com" + errMsgGetDomainVerifyBackgroundJob = "Failed to get domain verification: %v" +) + +// mockDatabase is a mock implementation of the Database interface for testing +type mockDatabase struct { + mu sync.RWMutex + verifiedDomains []string + domainVerifications map[string]*model.DomainVerification + cleanupCount int +} + +func newMockDatabase() *mockDatabase { + return &mockDatabase{ + domainVerifications: make(map[string]*model.DomainVerification), + } +} + +func (m *mockDatabase) List(ctx context.Context, filter map[string]any, cursor string, limit int) ([]*model.Server, string, error) { + return nil, "", nil +} + +func (m *mockDatabase) GetByID(ctx context.Context, id string) (*model.ServerDetail, error) { + return nil, database.ErrNotFound +} + +func (m *mockDatabase) Publish(ctx context.Context, serverDetail *model.ServerDetail) error { + return nil +} + +func (m *mockDatabase) StoreVerificationToken(ctx context.Context, serverID string, token *model.VerificationToken) error { + return nil +} + +func (m *mockDatabase) GetVerificationToken(ctx context.Context, serverID string) (*model.VerificationToken, error) { + return nil, database.ErrNotFound +} + +func (m *mockDatabase) ImportSeed(ctx context.Context, seedFilePath string) error { + return nil +} + +func (m *mockDatabase) Close() error { + return nil +} + +func (m *mockDatabase) GetVerifiedDomains(ctx context.Context) ([]string, error) { + m.mu.RLock() + defer m.mu.RUnlock() + return append([]string{}, m.verifiedDomains...), nil +} + +func (m *mockDatabase) GetDomainVerification(ctx context.Context, domain string) (*model.DomainVerification, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + dv, exists := m.domainVerifications[domain] + if !exists { + return nil, database.ErrNotFound + } + + // Return a copy to avoid race conditions + result := *dv + return &result, nil +} + +func (m *mockDatabase) UpdateDomainVerification(ctx context.Context, domainVerification *model.DomainVerification) error { + m.mu.Lock() + defer m.mu.Unlock() + + // Make a copy to avoid sharing memory + dv := *domainVerification + m.domainVerifications[domainVerification.Domain] = &dv + return nil +} + +func (m *mockDatabase) CleanupOldVerifications(ctx context.Context, before time.Time) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.cleanupCount++ + return 1, nil +} + +func (m *mockDatabase) GetVerificationTokens(ctx context.Context, domain string) (*model.VerificationTokens, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + dv, exists := m.domainVerifications[domain] + if !exists || dv.VerificationTokens == nil { + return nil, database.ErrNotFound + } + + // Return a copy to avoid race conditions + result := *dv.VerificationTokens + return &result, nil +} + +func (m *mockDatabase) addVerifiedDomain(domain string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.verifiedDomains = append(m.verifiedDomains, domain) + now := time.Now() + + // Generate test tokens for the domain + dnsToken, _ := GenerateVerificationToken() + httpToken, _ := GenerateVerificationToken() + + m.domainVerifications[domain] = &model.DomainVerification{ + Domain: domain, + Status: model.VerificationStatusVerified, + CreatedAt: now, + LastVerified: now, + ConsecutiveFailures: 0, + DNSToken: dnsToken, + HTTPToken: httpToken, + } +} + +// mockNotificationFunc captures notifications for testing +type mockNotificationFunc struct { + mu sync.RWMutex + notifications []notificationRecord +} + +type notificationRecord struct { + domain string + failures int + err error +} + +func (m *mockNotificationFunc) notify(ctx context.Context, domain string, failures int, lastError error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.notifications = append(m.notifications, notificationRecord{ + domain: domain, + failures: failures, + err: lastError, + }) +} + +func (m *mockNotificationFunc) getNotifications() []notificationRecord { + m.mu.RLock() + defer m.mu.RUnlock() + + return append([]notificationRecord{}, m.notifications...) +} + +func TestNewBackgroundVerificationJob(t *testing.T) { + db := newMockDatabase() + config := DefaultBackgroundJobConfig() + mockNotify := &mockNotificationFunc{} + + job := NewBackgroundVerificationJob(db, config, mockNotify.notify) + + if job == nil { + t.Fatal("NewBackgroundVerificationJob returned nil") + } + + if job.db != db { + t.Error("Database not set correctly") + } + + if job.config != config { + t.Error("Config not set correctly") + } + + if job.running { + t.Error("Job should not be running initially") + } +} + +func TestNewBackgroundVerificationJobWithDefaults(t *testing.T) { + db := newMockDatabase() + + job := NewBackgroundVerificationJob(db, nil, nil) + + if job == nil { + t.Fatal("NewBackgroundVerificationJob returned nil") + } + + if job.config == nil { + t.Error("Default config not set") + } + + if job.config.CronSchedule != defaultCronScheduleBackgroundJob { + t.Errorf("Default cron schedule = %s, want %s", job.config.CronSchedule, defaultCronScheduleBackgroundJob) + } +} + +func TestDefaultBackgroundJobConfig(t *testing.T) { + config := DefaultBackgroundJobConfig() + + if config == nil { + t.Fatal("DefaultBackgroundJobConfig returned nil") + } + + if config.CronSchedule != defaultCronScheduleBackgroundJob { + t.Errorf("CronSchedule = %s, want %s", config.CronSchedule, defaultCronScheduleBackgroundJob) + } + + if config.MaxConcurrentVerifications != 10 { + t.Errorf("MaxConcurrentVerifications = %d, want %d", config.MaxConcurrentVerifications, 10) + } + + if config.VerificationTimeout != 30*time.Second { + t.Errorf("VerificationTimeout = %v, want %v", config.VerificationTimeout, 30*time.Second) + } + + if config.FailureThreshold != 3 { + t.Errorf("FailureThreshold = %d, want %d", config.FailureThreshold, 3) + } +} + +func TestBackgroundJobStartStop(t *testing.T) { + db := newMockDatabase() + config := &BackgroundJobConfig{ + CronSchedule: "0 0 * * * *", // Every minute for testing (with seconds) + MaxConcurrentVerifications: 1, + VerificationTimeout: 1 * time.Second, + FailureThreshold: 1, + RetryDelay: 100 * time.Millisecond, + NotificationCooldown: 1 * time.Second, + CleanupInterval: 1 * time.Hour, + } + mockNotify := &mockNotificationFunc{} + + job := NewBackgroundVerificationJob(db, config, mockNotify.notify) + ctx := context.Background() + + // Test starting the job + err := job.Start(ctx) + if err != nil { + t.Fatalf("Failed to start job: %v", err) + } + + if !job.IsRunning() { + t.Error("Job should be running after start") + } + + // Test that starting again fails + err = job.Start(ctx) + if err == nil { + t.Error("Starting already running job should fail") + } + + // Test stopping the job + err = job.Stop() + if err != nil { + t.Fatalf("Failed to stop job: %v", err) + } + + if job.IsRunning() { + t.Error("Job should not be running after stop") + } + + // Test that stopping again fails + err = job.Stop() + if err == nil { + t.Error("Stopping already stopped job should fail") + } +} + +func TestRunNow(t *testing.T) { + db := newMockDatabase() + db.addVerifiedDomain(testDomainBackgroundJob) + + mockNotify := &mockNotificationFunc{} + job := NewBackgroundVerificationJob(db, DefaultBackgroundJobConfig(), mockNotify.notify) + + ctx := context.Background() + err := job.RunNow(ctx) + if err != nil { + t.Errorf("RunNow failed: %v", err) + } +} + +func TestRunNowWithNoDomains(t *testing.T) { + db := newMockDatabase() + mockNotify := &mockNotificationFunc{} + job := NewBackgroundVerificationJob(db, DefaultBackgroundJobConfig(), mockNotify.notify) + + ctx := context.Background() + err := job.RunNow(ctx) + if err != nil { + t.Errorf("RunNow with no domains failed: %v", err) + } +} + +func TestGetStatus(t *testing.T) { + db := newMockDatabase() + config := DefaultBackgroundJobConfig() + mockNotify := &mockNotificationFunc{} + + job := NewBackgroundVerificationJob(db, config, mockNotify.notify) + + status := job.GetStatus() + if status == nil { + t.Fatal("GetStatus returned nil") + } + + if status["running"] != false { + t.Error("Status should show not running initially") + } + + if status["cron_schedule"] != config.CronSchedule { + t.Errorf("Status cron_schedule = %v, want %v", status["cron_schedule"], config.CronSchedule) + } + + // Start job and check status again + ctx := context.Background() + err := job.Start(ctx) + if err != nil { + t.Fatalf("Failed to start job: %v", err) + } + defer job.Stop() + + status = job.GetStatus() + if status["running"] != true { + t.Error("Status should show running after start") + } + + if status["cron_entries"] == nil { + t.Error("Status should include cron_entries when running") + } +} + +func TestUpdateVerificationSuccess(t *testing.T) { + db := newMockDatabase() + mockNotify := &mockNotificationFunc{} + job := NewBackgroundVerificationJob(db, DefaultBackgroundJobConfig(), mockNotify.notify) + + ctx := context.Background() + domain := testDomainBackgroundJob + method := model.VerificationMethodDNS + + err := job.updateVerificationSuccess(ctx, domain, method) + if err != nil { + t.Errorf("updateVerificationSuccess failed: %v", err) + } + + // Verify the domain verification was updated + dv, err := db.GetDomainVerification(ctx, domain) + if err != nil { + t.Fatalf(errMsgGetDomainVerifyBackgroundJob, err) + } + + if dv.Status != model.VerificationStatusVerified { + t.Errorf("Status = %s, want %s", dv.Status, model.VerificationStatusVerified) + } + + if dv.LastSuccessfulMethod != method { + t.Errorf("LastSuccessfulMethod = %s, want %s", dv.LastSuccessfulMethod, method) + } + + if dv.ConsecutiveFailures != 0 { + t.Errorf("ConsecutiveFailures = %d, want %d", dv.ConsecutiveFailures, 0) + } +} + +func TestUpdateVerificationFailure(t *testing.T) { + db := newMockDatabase() + config := &BackgroundJobConfig{ + FailureThreshold: 3, + NotificationCooldown: 1 * time.Hour, + } + mockNotify := &mockNotificationFunc{} + job := NewBackgroundVerificationJob(db, config, mockNotify.notify) + + ctx := context.Background() + domain := testDomainBackgroundJob + testError := errors.New("verification failed") + + // Set up initial domain verification + now := time.Now() + initialDV := &model.DomainVerification{ + Domain: domain, + Status: model.VerificationStatusVerified, + CreatedAt: now, + ConsecutiveFailures: 0, + } + db.UpdateDomainVerification(ctx, initialDV) + + // First failure + err := job.updateVerificationFailure(ctx, domain, testError) + if err != nil { + t.Errorf("updateVerificationFailure failed: %v", err) + } + + dv, err := db.GetDomainVerification(ctx, domain) + if err != nil { + t.Fatalf(errMsgGetDomainVerifyBackgroundJob, err) + } + + if dv.ConsecutiveFailures != 1 { + t.Errorf("ConsecutiveFailures = %d, want %d", dv.ConsecutiveFailures, 1) + } + + if dv.LastError != testError.Error() { + t.Errorf("LastError = %s, want %s", dv.LastError, testError.Error()) + } + + // Third failure (should trigger notification) + dv.ConsecutiveFailures = 2 + db.UpdateDomainVerification(ctx, dv) + + err = job.updateVerificationFailure(ctx, domain, testError) + if err != nil { + t.Errorf("updateVerificationFailure failed: %v", err) + } + + dv, err = db.GetDomainVerification(ctx, domain) + if err != nil { + t.Fatalf(errMsgGetDomainVerifyBackgroundJob, err) + } + + if dv.Status != model.VerificationStatusFailed { + t.Errorf("Status = %s, want %s", dv.Status, model.VerificationStatusFailed) + } + + // Check notification was sent + notifications := mockNotify.getNotifications() + if len(notifications) != 1 { + t.Errorf("Expected 1 notification, got %d", len(notifications)) + } + + if len(notifications) > 0 { + if notifications[0].domain != domain { + t.Errorf("Notification domain = %s, want %s", notifications[0].domain, domain) + } + if notifications[0].failures != 3 { + t.Errorf("Notification failures = %d, want %d", notifications[0].failures, 3) + } + } +} + +func TestRunSingleVerification(t *testing.T) { + db := newMockDatabase() + job := NewBackgroundVerificationJob(db, DefaultBackgroundJobConfig(), nil) + + domain := testDomainBackgroundJob + token := "test-token" + ctx := context.Background() + + // Test with valid method + success, err := job.runSingleVerification(ctx, domain, token, model.VerificationMethodHTTP) + if err != nil { + t.Logf("Expected error for domain verification: %v", err) + } + // Note: success can be false since we don't have a real server running + _ = success +} diff --git a/internal/verification/dns.go b/internal/verification/dns.go index 9dc1c2e2..83675fe8 100644 --- a/internal/verification/dns.go +++ b/internal/verification/dns.go @@ -6,7 +6,6 @@ import ( "fmt" "log" "net" - "strings" "time" ) @@ -112,50 +111,36 @@ func DefaultDNSConfig() *DNSVerificationConfig { // log.Printf("Domain %s verification failed: %s", result.Domain, result.Message) // } func VerifyDNSRecord(domain, expectedToken string) (*DNSVerificationResult, error) { - return VerifyDNSRecordWithConfig(domain, expectedToken, DefaultDNSConfig()) + return VerifyDNSRecordWithConfig(context.Background(), domain, expectedToken, DefaultDNSConfig()) } // VerifyDNSRecordWithConfig performs DNS verification with custom configuration -func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) { +func VerifyDNSRecordWithConfig(ctx context.Context, domain, expectedToken string, config *DNSVerificationConfig) (*DNSVerificationResult, error) { startTime := time.Now() - // Input validation - if domain == "" { - return nil, &DNSVerificationError{ - Domain: domain, - Token: expectedToken, - Message: "domain cannot be empty", - } - } - - if expectedToken == "" { - return nil, &DNSVerificationError{ - Domain: domain, - Token: expectedToken, - Message: "token cannot be empty", - } - } - - // Validate token format - if !ValidateTokenFormat(expectedToken) { - return nil, &DNSVerificationError{ - Domain: domain, - Token: expectedToken, - Message: "invalid token format", + // Validate inputs and normalize domain + normalizedDomain, err := ValidateVerificationInputs(domain, expectedToken) + if err != nil { + var validationErr *ValidationError + if errors.As(err, &validationErr) { + return nil, &DNSVerificationError{ + Domain: validationErr.Domain, + Token: validationErr.Token, + Message: validationErr.Message, + } } + return nil, err } - - // Normalize domain (remove trailing dots, convert to lowercase) - domain = strings.ToLower(strings.TrimSuffix(domain, ".")) + domain = normalizedDomain log.Printf("Starting DNS verification for domain: %s with token: %s", domain, expectedToken) - // Create context with timeout - ctx, cancel := context.WithTimeout(context.Background(), config.Timeout) + // Create context with timeout based on the passed context + timeoutCtx, cancel := context.WithTimeout(ctx, config.Timeout) defer cancel() // Perform verification with retries - result, err := performDNSVerificationWithRetries(ctx, domain, expectedToken, config) + result, err := performDNSVerificationWithRetries(timeoutCtx, domain, expectedToken, config) // Calculate duration duration := time.Since(startTime) @@ -170,6 +155,8 @@ func VerifyDNSRecordWithConfig(domain, expectedToken string, config *DNSVerifica } // performDNSVerificationWithRetries implements the retry logic for DNS verification +// This function handles DNS-specific retry patterns including exponential backoff +// and DNS error classification for domain ownership verification via TXT records. func performDNSVerificationWithRetries( ctx context.Context, domain, expectedToken string, @@ -179,51 +166,54 @@ func performDNSVerificationWithRetries( var lastResult *DNSVerificationResult retryDelay := config.RetryDelay + maxRetries := config.MaxRetries + dnsRetryCount := 0 - for attempt := 0; attempt <= config.MaxRetries; attempt++ { + for attempt := 0; attempt <= maxRetries; attempt++ { + dnsRetryCount++ if attempt > 0 { - log.Printf("DNS verification retry %d/%d for domain %s after %v delay", - attempt+1, config.MaxRetries, domain, retryDelay) + log.Printf("DNS TXT record verification retry %d/%d for domain %s after %v delay", + attempt+1, maxRetries, domain, retryDelay) // Wait before retry with context cancellation support - timer := time.NewTimer(retryDelay) - select { - case <-timer.C: - // Timer fired normally, continue with retry - case <-ctx.Done(): - // Context canceled, stop timer to prevent leak - timer.Stop() + if !WaitWithContext(ctx, retryDelay) { return nil, &DNSVerificationError{ Domain: domain, Token: expectedToken, - Message: "verification canceled", + Message: "DNS verification canceled", Cause: ctx.Err(), } } - // Exponential backoff + // Exponential backoff with DNS-specific multiplier retryDelay *= 2 } + // Perform DNS TXT record lookup result, err := performDNSVerification(ctx, domain, expectedToken, config) if err == nil { + log.Printf("DNS verification succeeded on attempt %d for domain %s", dnsRetryCount, domain) return result, nil } lastErr = err lastResult = result - // Check if error is retryable + // Check if DNS error is retryable if !IsRetryableDNSError(err) { - log.Printf("Non-retryable DNS error for domain %s: %v", domain, err) + log.Printf("Non-retryable DNS TXT record error for domain %s: %v", domain, err) break } - log.Printf("Retryable DNS error for domain %s (attempt %d/%d): %v", - domain, attempt+1, config.MaxRetries, err) + log.Printf("Retryable DNS TXT record error for domain %s (attempt %d/%d): %v", + domain, attempt+1, maxRetries, err) } - // All retries exhausted + // All retries exhausted for DNS verification + if lastResult != nil { + log.Printf("DNS verification completed with %d total attempts and %d failures for domain %s", + dnsRetryCount, maxRetries+1, domain) + } return lastResult, lastErr } diff --git a/internal/verification/dns_mock_test.go b/internal/verification/dns_mock_test.go index e2433859..8b925639 100644 --- a/internal/verification/dns_mock_test.go +++ b/internal/verification/dns_mock_test.go @@ -27,7 +27,7 @@ func TestVerifyDNSRecordWithMockSuccess(t *testing.T) { config.Resolver = mockResolver config.Timeout = 1 * time.Second - result, err := verification.VerifyDNSRecordWithConfig(testDomain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), testDomain, token, config) if err != nil { t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) @@ -70,7 +70,7 @@ func TestVerifyDNSRecordWithMockTokenNotFound(t *testing.T) { config := verification.DefaultDNSConfig() config.Resolver = mockResolver - result, err := verification.VerifyDNSRecordWithConfig(testDomain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), testDomain, token, config) if err != nil { t.Errorf("Unexpected error: %v", err) @@ -112,7 +112,7 @@ func TestVerifyDNSRecordWithMockDNSError(t *testing.T) { config.Resolver = mockResolver config.MaxRetries = 0 - result, err := verification.VerifyDNSRecordWithConfig(testDomain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), testDomain, token, config) var dnsErr *verification.DNSVerificationError if !errors.As(err, &dnsErr) { @@ -147,7 +147,7 @@ func TestVerifyDNSRecordWithMockTimeout(t *testing.T) { config.Timeout = 50 * time.Millisecond config.MaxRetries = 0 - _, err = verification.VerifyDNSRecordWithConfig(testDomain, token, config) + _, err = verification.VerifyDNSRecordWithConfig(context.Background(), testDomain, token, config) if err == nil { t.Error("Expected timeout error") diff --git a/internal/verification/dns_resolver.go b/internal/verification/dns_resolver.go index ae79d86c..903f875e 100644 --- a/internal/verification/dns_resolver.go +++ b/internal/verification/dns_resolver.go @@ -22,7 +22,7 @@ func (d *DefaultDNSResolver) LookupTXT(ctx context.Context, name string) ([]stri // NewDefaultDNSResolver creates a DNS resolver with the given configuration // -//nolint:ireturn // Factory function returning interface is acceptable for dependency injection +//nolint:ireturn // Factory function intentionally returns interface for dependency injection func NewDefaultDNSResolver(config *DNSVerificationConfig) DNSResolver { if config.UseSecureResolvers && len(config.CustomResolvers) > 0 { // Create custom dialer for secure resolvers diff --git a/internal/verification/dns_test.go b/internal/verification/dns_test.go index 709a99fa..79691f4b 100644 --- a/internal/verification/dns_test.go +++ b/internal/verification/dns_test.go @@ -28,7 +28,7 @@ func TestVerifyDNSRecordSuccess(t *testing.T) { config := verification.DefaultDNSConfig() config.Resolver = mockResolver - result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), domain, token, config) if err != nil { t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) } @@ -77,7 +77,7 @@ func TestVerifyDNSRecordTokenNotFound(t *testing.T) { config := verification.DefaultDNSConfig() config.Resolver = mockResolver - result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), domain, token, config) if err != nil { t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) } @@ -178,7 +178,7 @@ func TestVerifyDNSRecordTokenFormatValidation(t *testing.T) { config := verification.DefaultDNSConfig() config.Resolver = mockResolver - result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), domain, token, config) if err != nil { var dnsErr *verification.DNSVerificationError @@ -226,7 +226,7 @@ func TestVerifyDNSRecordWithConfigTimeout(t *testing.T) { } domain := testDomain - result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), domain, token, config) if err == nil { t.Error("Expected timeout error but got none") @@ -298,7 +298,7 @@ func TestVerifyDNSRecordWithCustomPrefix(t *testing.T) { config.Resolver = mockResolver config.RecordPrefix = customPrefix - result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), domain, token, config) if err != nil { t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) } @@ -350,7 +350,7 @@ func TestVerifyDNSRecordCustomPrefixFailsWithWrongRecord(t *testing.T) { config.Resolver = mockResolver config.RecordPrefix = customPrefix - result, err := verification.VerifyDNSRecordWithConfig(domain, token, config) + result, err := verification.VerifyDNSRecordWithConfig(context.Background(), domain, token, config) if err != nil { t.Errorf("VerifyDNSRecord returned unexpected error: %v", err) } diff --git a/internal/verification/http.go b/internal/verification/http.go new file mode 100644 index 00000000..c6dd662b --- /dev/null +++ b/internal/verification/http.go @@ -0,0 +1,450 @@ +package verification + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "strings" + "time" +) + +// HTTPVerificationError represents errors that can occur during HTTP verification +type HTTPVerificationError struct { + Domain string + Token string + URL string + Message string + Cause error +} + +func (e *HTTPVerificationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("HTTP verification failed for domain %s: %s (cause: %v)", + e.Domain, e.Message, e.Cause) + } + return fmt.Sprintf("HTTP verification failed for domain %s: %s", e.Domain, e.Message) +} + +func (e *HTTPVerificationError) Unwrap() error { + return e.Cause +} + +// HTTPVerificationResult represents the result of an HTTP verification attempt +type HTTPVerificationResult struct { + Success bool `json:"success"` + Domain string `json:"domain"` + Token string `json:"token"` + URL string `json:"url"` + Message string `json:"message"` + StatusCode int `json:"status_code,omitempty"` + ResponseBody string `json:"response_body,omitempty"` + Duration string `json:"duration"` +} + +// HTTPVerificationConfig holds configuration for HTTP verification +type HTTPVerificationConfig struct { + // Timeout for HTTP requests (default: 10 seconds) + Timeout time.Duration + + // MaxRetries for transient failures (default: 3) + MaxRetries int + + // RetryDelay base delay between retries (default: 1 second) + RetryDelay time.Duration + + // FollowRedirects whether to follow HTTP redirects (default: true) + FollowRedirects bool + + // UserAgent to use for HTTP requests (default: "MCP-Registry-Verifier/1.0") + UserAgent string + + // AllowHTTP whether to allow non-HTTPS URLs (default: false for security) + AllowHTTP bool + + // MaxResponseSize maximum size of response body to read (default: 1KB) + MaxResponseSize int64 + + // CustomTransport allows injecting a custom HTTP transport (primarily for testing) + CustomTransport http.RoundTripper +} + +// DefaultHTTPConfig returns the default configuration for HTTP verification +func DefaultHTTPConfig() *HTTPVerificationConfig { + return &HTTPVerificationConfig{ + Timeout: 10 * time.Second, + MaxRetries: 3, + RetryDelay: 1 * time.Second, + FollowRedirects: true, + UserAgent: "MCP-Registry-Verifier/1.0", + AllowHTTP: false, + MaxResponseSize: 1024, // 1KB should be enough for a token + } +} + +// VerifyHTTPChallenge verifies domain ownership by checking for a specific token +// at the well-known HTTP-01 challenge URL: https://domain/.well-known/mcp-challenge/token +// +// This function implements the HTTP-01 web challenge verification method described +// in the Server Name Verification system. It fetches the well-known URL and verifies +// that the response body exactly matches the expected token. +// +// Security considerations: +// - Only allows HTTPS by default to prevent man-in-the-middle attacks +// - Uses a short timeout to prevent hanging on slow responses +// - Limits response body size to prevent memory exhaustion attacks +// - Implements retry logic with exponential backoff for transient failures +// - Validates token format before making the HTTP request +// +// Parameters: +// - domain: The domain name to verify (e.g., "example.com") +// - expectedToken: The 128-bit token that should be served at the challenge URL +// +// Returns: +// - HTTPVerificationResult with verification status and details +// - An error if the verification process fails critically +// +// The default configuration uses HTTPS-only. To allow HTTP (not recommended for production), +// use VerifyHTTPChallengeWithConfig with AllowHTTP set to true. +// +// Example usage: +// +// result, err := VerifyHTTPChallenge("example.com", "TBeVXe_X4npM6p8vpzStnA") +// if err != nil { +// log.Printf("HTTP verification error: %v", err) +// return err +// } +// if result.Success { +// log.Printf("Domain %s verified successfully via HTTP", result.Domain) +// } else { +// log.Printf("Domain %s verification failed: %s", result.Domain, result.Message) +// } +func VerifyHTTPChallenge(domain, expectedToken string) (*HTTPVerificationResult, error) { + return VerifyHTTPChallengeWithConfig(context.Background(), domain, expectedToken, DefaultHTTPConfig()) +} + +// VerifyHTTPChallengeWithConfig performs HTTP verification with custom configuration +func VerifyHTTPChallengeWithConfig( + ctx context.Context, domain, expectedToken string, config *HTTPVerificationConfig, +) (*HTTPVerificationResult, error) { + startTime := time.Now() + + // Validate inputs and normalize domain + normalizedDomain, err := ValidateVerificationInputs(domain, expectedToken) + if err != nil { + var validationErr *ValidationError + if errors.As(err, &validationErr) { + return nil, &HTTPVerificationError{ + Domain: validationErr.Domain, + Token: validationErr.Token, + Message: validationErr.Message, + } + } + return nil, err + } + domain = normalizedDomain + + log.Printf("Starting HTTP verification for domain: %s with token: %s", domain, expectedToken) + + // Perform verification with retries using the provided context + result, err := performHTTPVerificationWithRetries(ctx, domain, expectedToken, config) + + // Calculate duration + duration := time.Since(startTime) + if result != nil { + result.Duration = duration.String() + } + + log.Printf("HTTP verification completed for domain %s in %v: success=%t", + domain, duration, result != nil && result.Success) + + return result, err +} + +// performHTTPVerificationWithRetries implements the retry logic for HTTP verification +// This function handles HTTP-01 challenge verification with retry patterns including +// exponential backoff and HTTP error classification for web-based domain verification. +func performHTTPVerificationWithRetries( + ctx context.Context, + domain, expectedToken string, + config *HTTPVerificationConfig, +) (*HTTPVerificationResult, error) { + var lastErr error + var lastResult *HTTPVerificationResult + + initialDelay := config.RetryDelay + currentDelay := initialDelay + httpAttempts := 0 + + for attempt := 0; attempt <= config.MaxRetries; attempt++ { + httpAttempts++ + if attempt > 0 { + log.Printf("HTTP-01 challenge verification retry %d/%d for domain %s after %v delay", + attempt+1, config.MaxRetries, domain, currentDelay) + + // Wait before retry with context cancellation support + if !WaitWithContext(ctx, currentDelay) { + return nil, &HTTPVerificationError{ + Domain: domain, + Token: expectedToken, + Message: "HTTP verification canceled", + Cause: ctx.Err(), + } + } + + // Exponential backoff with HTTP-specific multiplier + currentDelay *= 2 + } + + // Perform HTTP-01 challenge request + result, err := performHTTPVerification(ctx, domain, expectedToken, config) + if err == nil { + log.Printf("HTTP verification succeeded on attempt %d for domain %s", httpAttempts, domain) + return result, nil + } + + lastErr = err + lastResult = result + + // Check if HTTP error is retryable + if !IsRetryableHTTPError(err) { + log.Printf("Non-retryable HTTP-01 challenge error for domain %s: %v", domain, err) + break + } + + log.Printf("Retryable HTTP-01 challenge error for domain %s (attempt %d/%d): %v", + domain, attempt+1, config.MaxRetries, err) + } + + // All retries exhausted for HTTP verification + if lastResult != nil { + log.Printf("HTTP verification completed with %d total attempts and %d failures for domain %s", + httpAttempts, config.MaxRetries+1, domain) + } + return lastResult, lastErr +} + +// performHTTPVerification performs a single HTTP verification attempt +func performHTTPVerification(ctx context.Context, domain, expectedToken string, config *HTTPVerificationConfig) (*HTTPVerificationResult, error) { + // Construct the challenge URL + scheme := "https" + if config.AllowHTTP { + scheme = "http" + } + challengeURL := fmt.Sprintf("%s://%s/.well-known/mcp-challenge/%s", scheme, domain, expectedToken) + + // Create HTTP client + client := createHTTPClient(config) + + // Create request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, challengeURL, nil) + if err != nil { + httpErr := &HTTPVerificationError{ + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: "failed to create HTTP request", + Cause: err, + } + + result := &HTTPVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: httpErr.Message, + } + + return result, httpErr + } + + // Set User-Agent + req.Header.Set("User-Agent", config.UserAgent) + + log.Printf("Making HTTP request to: %s", challengeURL) + + // Make the request + resp, err := client.Do(req) + if err != nil { + httpErr := &HTTPVerificationError{ + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: "failed to make HTTP request", + Cause: err, + } + + result := &HTTPVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: httpErr.Message, + } + + return result, httpErr + } + defer resp.Body.Close() + + log.Printf("HTTP response status: %d for URL: %s", resp.StatusCode, challengeURL) + + // Check status code + if resp.StatusCode != http.StatusOK { + result := &HTTPVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: fmt.Sprintf("HTTP request failed with status %d", resp.StatusCode), + StatusCode: resp.StatusCode, + } + + log.Printf("HTTP verification failed for domain %s: unexpected status code %d", domain, resp.StatusCode) + return result, nil + } + + // Read response body with size limit + limitedReader := io.LimitReader(resp.Body, config.MaxResponseSize) + body, err := io.ReadAll(limitedReader) + if err != nil { + httpErr := &HTTPVerificationError{ + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: "failed to read response body", + Cause: err, + } + + result := &HTTPVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: httpErr.Message, + StatusCode: resp.StatusCode, + } + + return result, httpErr + } + + responseBody := strings.TrimSpace(string(body)) + log.Printf("HTTP response body: '%s' (expected: '%s')", responseBody, expectedToken) + + // Check if response body matches expected token + if responseBody == expectedToken { + result := &HTTPVerificationResult{ + Success: true, + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: "domain verification successful", + StatusCode: resp.StatusCode, + ResponseBody: responseBody, + } + + log.Printf("HTTP verification successful for domain %s", domain) + return result, nil + } + + // Token mismatch + result := &HTTPVerificationResult{ + Success: false, + Domain: domain, + Token: expectedToken, + URL: challengeURL, + Message: fmt.Sprintf("token mismatch: expected '%s', got '%s'", expectedToken, responseBody), + StatusCode: resp.StatusCode, + ResponseBody: responseBody, + } + + log.Printf("HTTP verification failed for domain %s: token mismatch", domain) + return result, nil +} + +// createHTTPClient creates an HTTP client with the specified configuration +func createHTTPClient(config *HTTPVerificationConfig) *http.Client { + // Use custom transport if provided (for testing) + if config.CustomTransport != nil { + return &http.Client{ + Transport: config.CustomTransport, + Timeout: config.Timeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if !config.FollowRedirects { + return http.ErrUseLastResponse + } + // Limit redirects to prevent infinite loops + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + return nil + }, + } + } + + // Create custom transport with security settings + transport := &http.Transport{ + Dial: (&net.Dialer{ + Timeout: 5 * time.Second, + }).Dial, + TLSHandshakeTimeout: 5 * time.Second, + TLSClientConfig: &tls.Config{ + // Require valid certificates (no self-signed) + InsecureSkipVerify: false, + // Set minimum TLS version to 1.2 for security + MinVersion: tls.VersionTLS12, + }, + DisableKeepAlives: true, // Don't reuse connections for verification requests + } + + return &http.Client{ + Transport: transport, + Timeout: config.Timeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if !config.FollowRedirects { + return http.ErrUseLastResponse + } + // Limit redirects to prevent infinite loops + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + return nil + }, + } +} + +// IsRetryableHTTPError determines if an HTTP error should be retried +func IsRetryableHTTPError(err error) bool { + if err == nil { + return false + } + + // Check for network timeouts and temporary failures + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + + // Check for context timeout + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + // Check for DNS errors (might be temporary) + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return dnsErr.Timeout() + } + + // Don't retry on validation errors or permanent failures + var httpErr *HTTPVerificationError + if errors.As(err, &httpErr) { + return false + } + + // Default to not retryable for unknown errors + return false +} diff --git a/internal/verification/http_test.go b/internal/verification/http_test.go new file mode 100644 index 00000000..f19dd996 --- /dev/null +++ b/internal/verification/http_test.go @@ -0,0 +1,392 @@ +package verification_test + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/registry/internal/verification" +) + +const ( + errMsgGenTokenHTTP = "Failed to generate test token: %v" + wellKnownChallengePathHTTP = "/.well-known/mcp-challenge/%s" + httpsScheme = "https://" + errMsgUnexpectedHTTP = "VerifyHTTPChallenge returned unexpected error: %v" + errMsgNilResultHTTP = "VerifyHTTPChallenge returned nil result" + logMsgResultHTTP = "HTTP verification result: %+v" + testDomainHTTP = "example.com" + wrongTokenHTTP = "wrong-token" + resultStatusCodeHTTP = "Result status code = %d, want %d" + resultResponseBodyHTTP = "Result response body = %s, want %s" +) + +func TestVerifyHTTPChallenge(t *testing.T) { + // Generate a test token + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenHTTP, err) + } + + // Create test server + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := fmt.Sprintf(wellKnownChallengePathHTTP, token) + if r.URL.Path == expectedPath { + w.WriteHeader(http.StatusOK) + w.Write([]byte(token)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + // Extract domain from test server URL + domain := strings.TrimPrefix(server.URL, httpsScheme) + + // Create custom config with test server transport + config := verification.DefaultHTTPConfig() + config.CustomTransport = server.Client().Transport + + result, err := verification.VerifyHTTPChallengeWithConfig(context.Background(), domain, token, config) + + if err != nil { + t.Errorf(errMsgUnexpectedHTTP, err) + } + + if result == nil { + t.Fatal(errMsgNilResultHTTP) + } + + if !result.Success { + t.Errorf("Expected verification to succeed, got: %s", result.Message) + } + + if result.Domain != domain { + t.Errorf("Result domain = %s, want %s", result.Domain, domain) + } + + if result.Token != token { + t.Errorf("Result token = %s, want %s", result.Token, token) + } + + if result.StatusCode != http.StatusOK { + t.Errorf(resultStatusCodeHTTP, result.StatusCode, http.StatusOK) + } + + if result.ResponseBody != token { + t.Errorf(resultResponseBodyHTTP, result.ResponseBody, token) + } + + t.Logf(logMsgResultHTTP, result) +} + +func TestVerifyHTTPChallengeTokenNotFound(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + // Create test server that returns 404 for all requests + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("Not Found")) + })) + defer server.Close() + + domain := strings.TrimPrefix(server.URL, "https://") + + config := verification.DefaultHTTPConfig() + config.CustomTransport = server.Client().Transport + + result, err := verification.VerifyHTTPChallengeWithConfig(context.Background(), domain, token, config) + + if err != nil { + t.Errorf("VerifyHTTPChallenge returned unexpected error: %v", err) + } + + if result == nil { + t.Fatal("VerifyHTTPChallenge returned nil result") + } + + if result.Success { + t.Error("Expected verification to fail when token is not found") + } + + if !strings.Contains(result.Message, "404") { + t.Errorf("Expected '404' in message, got: %s", result.Message) + } + + if result.StatusCode != http.StatusNotFound { + t.Errorf("Result status code = %d, want %d", result.StatusCode, http.StatusNotFound) + } + + t.Logf("HTTP verification result: %+v", result) +} + +func TestVerifyHTTPChallengeTokenMismatch(t *testing.T) { + // Generate a test token + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenHTTP, err) + } + + // Create test server that returns wrong token + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := fmt.Sprintf(wellKnownChallengePathHTTP, token) + if r.URL.Path == expectedPath { + w.WriteHeader(http.StatusOK) + w.Write([]byte(wrongTokenHTTP)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + // Extract domain from test server URL + domain := strings.TrimPrefix(server.URL, httpsScheme) + + // Create custom config with test server transport + config := verification.DefaultHTTPConfig() + config.CustomTransport = server.Client().Transport + + result, err := verification.VerifyHTTPChallengeWithConfig(context.Background(), domain, token, config) + + if err != nil { + t.Errorf(errMsgUnexpectedHTTP, err) + } + + if result == nil { + t.Fatal(errMsgNilResultHTTP) + } + + if result.Success { + t.Error("Expected verification to fail due to token mismatch") + } + + if result.StatusCode != http.StatusOK { + t.Errorf(resultStatusCodeHTTP, result.StatusCode, http.StatusOK) + } + + if result.ResponseBody != wrongTokenHTTP { + t.Errorf(resultResponseBodyHTTP, result.ResponseBody, wrongTokenHTTP) + } + + t.Logf(logMsgResultHTTP, result) +} + +func TestVerifyHTTPChallengeInvalidInputs(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenHTTP, err) + } + + testCases := []struct { + name string + domain string + token string + }{ + {"empty domain", "", token}, + {"empty token", testDomainHTTP, ""}, + {"invalid token", testDomainHTTP, "invalid-token"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := verification.VerifyHTTPChallenge(tc.domain, tc.token) + + if err == nil { + t.Error("Expected error for invalid input") + } + + var httpErr *verification.HTTPVerificationError + if !errors.As(err, &httpErr) { + t.Errorf("Expected HTTPVerificationError, got: %T", err) + } + + // Result should be nil for validation errors + if result != nil { + t.Error("Expected nil result for validation error") + } + }) + } +} + +func TestVerifyHTTPChallengeWithTimeout(t *testing.T) { + token, err := verification.GenerateVerificationToken() + if err != nil { + t.Fatalf(errMsgGenTokenHTTP, err) + } + + // Create test server that responds slowly + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) // Longer than the config timeout + expectedPath := fmt.Sprintf(wellKnownChallengePathHTTP, token) + if r.URL.Path == expectedPath { + w.WriteHeader(http.StatusOK) + w.Write([]byte(token)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + domain := strings.TrimPrefix(server.URL, httpsScheme) + + config := &verification.HTTPVerificationConfig{ + Timeout: 100 * time.Millisecond, + MaxRetries: 0, + RetryDelay: 0, + FollowRedirects: true, + UserAgent: "test", + AllowHTTP: false, + MaxResponseSize: 1024, + CustomTransport: server.Client().Transport, + } + + result, err := verification.VerifyHTTPChallengeWithConfig(context.Background(), domain, token, config) + + if err == nil { + t.Error("Expected timeout error but got none") + } else { + t.Logf("HTTP request failed as expected: %v", err) + // Verify it's a context timeout or network error + if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("Expected timeout-related error, got: %v", err) + } + } + + if result == nil { + t.Fatal("Expected result but got nil") + } + + if result.Duration == "" { + t.Error("Expected duration to be populated") + } + + t.Logf("Verification completed in: %s", result.Duration) +} + +func TestDefaultHTTPConfig(t *testing.T) { + config := verification.DefaultHTTPConfig() + + if config == nil { + t.Fatal("DefaultHTTPConfig returned nil") + } + + if config.Timeout <= 0 { + t.Error("Default timeout should be positive") + } + + if config.MaxRetries < 0 { + t.Error("Default max retries should be non-negative") + } + + if config.RetryDelay <= 0 { + t.Error("Default retry delay should be positive") + } + + if !config.FollowRedirects { + t.Error("Default should follow redirects") + } + + if config.UserAgent == "" { + t.Error("Default should have user agent") + } + + if config.AllowHTTP { + t.Error("Default should not allow HTTP (HTTPS only)") + } + + if config.MaxResponseSize <= 0 { + t.Error("Default max response size should be positive") + } + + t.Logf("Default HTTP config: %+v", config) +} + +func TestHTTPVerificationError(t *testing.T) { + // Test error without cause + err1 := &verification.HTTPVerificationError{ + Domain: testDomainHTTP, + Token: "test-token", + URL: "https://example.com/.well-known/mcp-challenge/test-token", + Message: "test error", + } + + expectedMsg1 := "HTTP verification failed for domain example.com: test error" + if err1.Error() != expectedMsg1 { + t.Errorf("Error() = %q, want %q", err1.Error(), expectedMsg1) + } + + // Test error with cause + cause := errors.New("network error") + err2 := &verification.HTTPVerificationError{ + Domain: testDomainHTTP, + Token: "test-token", + URL: "https://example.com/.well-known/mcp-challenge/test-token", + Message: "request failed", + Cause: cause, + } + + expectedMsg2 := "HTTP verification failed for domain example.com: request failed (cause: network error)" + if err2.Error() != expectedMsg2 { + t.Errorf("Error() = %q, want %q", err2.Error(), expectedMsg2) + } + + // Test Unwrap + if !errors.Is(err2, cause) { + t.Errorf("Expected error to wrap cause, but errors.Is returned false") + } +} + +func TestIsRetryableHTTPError(t *testing.T) { + testCases := []struct { + name string + err error + retry bool + }{ + {"nil error", nil, false}, + {"context timeout", context.DeadlineExceeded, true}, + {"validation error", &verification.HTTPVerificationError{Message: "validation failed"}, false}, + {"network error", &mockNetError{timeout: true, temporary: false}, true}, + { + "non-retryable temporary network error (Temporary() deprecated)", + &mockNetError{timeout: false, temporary: true}, + false, // Not retryable: Temporary() is deprecated + }, + {"permanent network error", &mockNetError{timeout: false, temporary: false}, false}, + {"unknown error", errors.New("unknown"), false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := verification.IsRetryableHTTPError(tc.err) + if result != tc.retry { + t.Errorf("IsRetryableHTTPError(%v) = %t, want %t", tc.err, result, tc.retry) + } + }) + } +} + +// mockNetError implements net.Error for testing +type mockNetError struct { + timeout bool + temporary bool +} + +func (e *mockNetError) Error() string { + return "mock network error" +} + +func (e *mockNetError) Timeout() bool { + return e.timeout +} + +func (e *mockNetError) Temporary() bool { + return e.temporary +} diff --git a/internal/verification/validation.go b/internal/verification/validation.go new file mode 100644 index 00000000..d6b35471 --- /dev/null +++ b/internal/verification/validation.go @@ -0,0 +1,51 @@ +package verification + +import ( + "strings" +) + +// ValidationError represents a validation error that can be used by both DNS and HTTP verification +type ValidationError struct { + Domain string + Token string + Message string +} + +func (e *ValidationError) Error() string { + return e.Message +} + +// ValidateVerificationInputs performs common validation for both DNS and HTTP verification +// Returns the normalized domain and any validation error +func ValidateVerificationInputs(domain, token string) (string, error) { + // Input validation + if domain == "" { + return "", &ValidationError{ + Domain: domain, + Token: token, + Message: "domain cannot be empty", + } + } + + if token == "" { + return "", &ValidationError{ + Domain: domain, + Token: token, + Message: "token cannot be empty", + } + } + + // Validate token format + if !ValidateTokenFormat(token) { + return "", &ValidationError{ + Domain: domain, + Token: token, + Message: "invalid token format", + } + } + + // Normalize domain (remove trailing dots, convert to lowercase) + normalizedDomain := strings.ToLower(strings.TrimSuffix(domain, ".")) + + return normalizedDomain, nil +} diff --git a/internal/verification/wait.go b/internal/verification/wait.go new file mode 100644 index 00000000..a70dcadd --- /dev/null +++ b/internal/verification/wait.go @@ -0,0 +1,21 @@ +package verification + +import ( + "context" + "time" +) + +// WaitWithContext waits for the specified duration with context cancellation support +// Returns true if the timer completed normally, false if context was canceled +func WaitWithContext(ctx context.Context, duration time.Duration) bool { + timer := time.NewTimer(duration) + defer timer.Stop() + select { + case <-timer.C: + // Timer fired normally + return true + case <-ctx.Done(): + // Context canceled + return false + } +} diff --git a/registry b/registry deleted file mode 100755 index 55dbc976..00000000 Binary files a/registry and /dev/null differ