diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a59a66f9ea..a37c4d940d 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -103,7 +103,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv if cfg.RepoAccessTTL != nil { opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) } - repoAccessCache = lockdown.GetInstance(gqlClient, restClient, opts...) + repoAccessCache = lockdown.NewRepoAccessCache(gqlClient, restClient, opts...) } return &githubClients{ diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index e3a031f999..1141fbce89 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -399,7 +399,7 @@ func (d *RequestDeps) GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAcc } // Create repo access cache - instance := lockdown.GetInstance(gqlClient, restClient, d.RepoAccessOpts...) + instance := lockdown.NewRepoAccessCache(gqlClient, restClient, d.RepoAccessOpts...) return instance, nil } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c2be1984f7..b04370976e 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -70,16 +70,15 @@ func (rt *repoAccessMockTransport) RoundTrip(req *http.Request) (*http.Response, value = repoAccessValue{isPrivate: false} } - responseBody, err := json.Marshal(map[string]any{ - "data": map[string]any{ - "viewer": map[string]any{ - "login": "test-viewer", - }, - "repository": map[string]any{ - "isPrivate": value.isPrivate, - }, - }, - }) + data := map[string]any{} + if strings.Contains(payload.Query, "viewer") { + data["viewer"] = map[string]any{"login": "test-viewer"} + } + if strings.Contains(payload.Query, "repository") { + data["repository"] = map[string]any{"isPrivate": value.isPrivate} + } + + responseBody, err := json.Marshal(map[string]any{"data": data}) if err != nil { return nil, err } diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index f787875b2e..238ccb06ee 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "maps" "strings" "sync" "time" @@ -15,27 +16,29 @@ import ( // RepoAccessCache caches repository metadata related to lockdown checks so that // multiple tools can reuse the same access information safely across goroutines. +// In HTTP mode each request must construct its own instance so viewer-scoped +// lookups run under the requesting user's credentials. type RepoAccessCache struct { client *githubv4.Client restClient *github.Client - mu sync.Mutex cache *cache2go.CacheTable ttl time.Duration logger *slog.Logger trustedBotLogins map[string]struct{} + + viewerMu sync.Mutex + viewerLogin string } type repoAccessCacheEntry struct { - isPrivate bool - knownUsers map[string]bool // normalized login -> has push access - viewerLogin string + isPrivate bool + knownUsers map[string]bool // normalized login -> has push access } // RepoAccessInfo captures repository metadata needed for lockdown decisions. type RepoAccessInfo struct { IsPrivate bool HasPushAccess bool - ViewerLogin string } const ( @@ -43,11 +46,6 @@ const ( defaultRepoAccessCacheKey = "repo-access-cache" ) -var ( - instance *RepoAccessCache - instanceMu sync.Mutex -) - // RepoAccessOption configures RepoAccessCache at construction time. type RepoAccessOption func(*RepoAccessCache) @@ -66,8 +64,8 @@ func WithLogger(logger *slog.Logger) RepoAccessOption { } } -// WithCacheName overrides the cache table name used for storing entries. This option is intended for tests -// that need isolated cache instances. +// WithCacheName overrides the cache table name used for storing entries. +// Use this to isolate cache entries between tenants or in tests. func WithCacheName(name string) RepoAccessOption { return func(c *RepoAccessCache) { if name != "" { @@ -76,25 +74,8 @@ func WithCacheName(name string) RepoAccessOption { } } -// GetInstance returns the singleton instance of RepoAccessCache. -// It initializes the instance on first call with the provided client and options. -// Subsequent calls ignore the client and options parameters and return the existing instance. -// This is the preferred way to access the cache in production code. -func GetInstance(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache { - instanceMu.Lock() - defer instanceMu.Unlock() - if instance == nil { - instance = newRepoAccessCache(client, restClient, opts...) - } - return instance -} - -// NewRepoAccessCache creates a standalone cache instance, used for tests. +// NewRepoAccessCache creates a RepoAccessCache bound to the supplied clients. func NewRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache { - return newRepoAccessCache(client, restClient, opts...) -} - -func newRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache { c := &RepoAccessCache{ client: client, restClient: restClient, @@ -113,13 +94,6 @@ func newRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts return c } -// SetLogger updates the logger used for cache diagnostics. -func (c *RepoAccessCache) SetLogger(logger *slog.Logger) { - c.mu.Lock() - c.logger = logger - c.mu.Unlock() -} - // CacheStats summarizes cache activity counters. type CacheStats struct { Hits int64 @@ -150,10 +124,55 @@ func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, re c.logDebug(ctx, fmt.Sprintf("evaluated repo access for user %s to %s/%s for content filtering, result: hasPushAccess=%t, isPrivate=%t", username, owner, repo, repoInfo.HasPushAccess, repoInfo.IsPrivate)) - if repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) { + if repoInfo.IsPrivate { + return true, nil + } + if repoInfo.HasPushAccess { return true, nil } - return repoInfo.HasPushAccess, nil + + viewerLogin, err := c.viewerLoginFor(ctx) + if err != nil { + return false, err + } + return viewerLogin == strings.ToLower(username), nil +} + +func (c *RepoAccessCache) viewerLoginFor(ctx context.Context) (string, error) { + c.viewerMu.Lock() + defer c.viewerMu.Unlock() + if c.viewerLogin != "" { + return c.viewerLogin, nil + } + if c.client == nil { + return "", fmt.Errorf("nil GraphQL client") + } + var query struct { + Viewer struct { + Login githubv4.String + } + } + if err := c.client.Query(ctx, &query, nil); err != nil { + return "", fmt.Errorf("failed to query viewer login: %w", err) + } + login := strings.ToLower(string(query.Viewer.Login)) + if login == "" { + return "", fmt.Errorf("viewer login returned empty") + } + c.viewerLogin = login + return c.viewerLogin, nil +} + +// setViewerLogin seeds the cached viewer login from a piggy-backed query response. +func (c *RepoAccessCache) setViewerLogin(login string) { + if login == "" { + return + } + c.viewerMu.Lock() + defer c.viewerMu.Unlock() + if c.viewerLogin == "" { + c.viewerLogin = strings.ToLower(login) + } } func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) { @@ -163,19 +182,16 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner key := cacheKey(owner, repo) userKey := strings.ToLower(username) - c.mu.Lock() - defer c.mu.Unlock() - // Try to get entry from cache - this will keep the item alive if it exists - cacheItem, err := c.cache.Value(key) - if err == nil { + // Entries are immutable once added: the cache table is shared across instances, + // so we publish a fresh entry with a cloned knownUsers map on every miss. + if cacheItem, err := c.cache.Value(key); err == nil { entry := cacheItem.Data().(*repoAccessCacheEntry) if cachedHasPush, known := entry.knownUsers[userKey]; known { c.logDebug(ctx, fmt.Sprintf("repo access cache hit for user %s to %s/%s", username, owner, repo)) return RepoAccessInfo{ IsPrivate: entry.isPrivate, HasPushAccess: cachedHasPush, - ViewerLogin: entry.viewerLogin, }, nil } @@ -186,41 +202,48 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner return RepoAccessInfo{}, pushErr } - entry.knownUsers[userKey] = hasPush - c.cache.Add(key, c.ttl, entry) + users := make(map[string]bool, len(entry.knownUsers)+1) + maps.Copy(users, entry.knownUsers) + users[userKey] = hasPush + c.cache.Add(key, c.ttl, &repoAccessCacheEntry{ + isPrivate: entry.isPrivate, + knownUsers: users, + }) return RepoAccessInfo{ IsPrivate: entry.isPrivate, HasPushAccess: hasPush, - ViewerLogin: entry.viewerLogin, }, nil } c.logDebug(ctx, fmt.Sprintf("repo access cache miss for user %s to %s/%s", username, owner, repo)) - info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) + isPrivate, viewerLogin, queryErr := c.queryRepoAccessInfo(ctx, owner, repo) if queryErr != nil { return RepoAccessInfo{}, queryErr } + c.setViewerLogin(viewerLogin) - // Create new entry - entry := &repoAccessCacheEntry{ - knownUsers: map[string]bool{userKey: info.HasPushAccess}, - isPrivate: info.IsPrivate, - viewerLogin: info.ViewerLogin, + hasPush, pushErr := c.checkPushAccess(ctx, username, owner, repo) + if pushErr != nil { + return RepoAccessInfo{}, pushErr } - c.cache.Add(key, c.ttl, entry) + + c.cache.Add(key, c.ttl, &repoAccessCacheEntry{ + knownUsers: map[string]bool{userKey: hasPush}, + isPrivate: isPrivate, + }) return RepoAccessInfo{ - IsPrivate: entry.isPrivate, - HasPushAccess: entry.knownUsers[userKey], - ViewerLogin: entry.viewerLogin, + IsPrivate: isPrivate, + HasPushAccess: hasPush, }, nil } -func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) { +// queryRepoAccessInfo fetches repository visibility and the viewer login in a single GraphQL round-trip. +func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, owner, repo string) (bool, string, error) { if c.client == nil { - return RepoAccessInfo{}, fmt.Errorf("nil GraphQL client") + return false, "", fmt.Errorf("nil GraphQL client") } var query struct { @@ -238,22 +261,12 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own } if err := c.client.Query(ctx, &query, variables); err != nil { - return RepoAccessInfo{}, fmt.Errorf("failed to query repository metadata: %w", err) - } - - hasPush, err := c.checkPushAccess(ctx, username, owner, repo) - if err != nil { - return RepoAccessInfo{}, err + return false, "", fmt.Errorf("failed to query repository metadata: %w", err) } - c.logDebug(ctx, fmt.Sprintf("queried repo access info for user %s to %s/%s: isPrivate=%t, hasPushAccess=%t, viewerLogin=%s", - username, owner, repo, bool(query.Repository.IsPrivate), hasPush, query.Viewer.Login)) + c.logDebug(ctx, fmt.Sprintf("queried repo access info for %s/%s: isPrivate=%t", owner, repo, bool(query.Repository.IsPrivate))) - return RepoAccessInfo{ - IsPrivate: bool(query.Repository.IsPrivate), - HasPushAccess: hasPush, - ViewerLogin: string(query.Viewer.Login), - }, nil + return bool(query.Repository.IsPrivate), string(query.Viewer.Login), nil } // checkPushAccess checks if the user has push access to the repository via the REST permission endpoint. diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index bb8887e709..f16d6a062c 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -2,6 +2,7 @@ package lockdown import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "sync" @@ -20,7 +21,13 @@ const ( testUser = "octocat" ) -type repoMetadataQuery struct { +type viewerLoginQuery struct { + Viewer struct { + Login githubv4.String + } +} + +type repoAccessQuery struct { Viewer struct { Login githubv4.String } @@ -48,42 +55,59 @@ func (c *countingTransport) CallCount() int { return c.calls } -func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache, *countingTransport) { - t.Helper() - - var query repoMetadataQuery - +func newMockGQLClient(viewerLogin string, isPrivate bool) (*githubv4.Client, *countingTransport) { variables := map[string]any{ "owner": githubv4.String(testOwner), "name": githubv4.String(testRepo), } - response := githubv4mock.DataResponse(map[string]any{ - "viewer": map[string]any{ - "login": testUser, - }, - "repository": map[string]any{ - "isPrivate": false, - }, - }) - - httpClient := githubv4mock.NewMockedHTTPClient(githubv4mock.NewQueryMatcher(query, variables, response)) + httpClient := githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + viewerLoginQuery{}, + nil, + githubv4mock.DataResponse(map[string]any{ + "viewer": map[string]any{"login": viewerLogin}, + }), + ), + githubv4mock.NewQueryMatcher( + repoAccessQuery{}, + variables, + githubv4mock.DataResponse(map[string]any{ + "viewer": map[string]any{"login": viewerLogin}, + "repository": map[string]any{"isPrivate": isPrivate}, + }), + ), + ) counting := &countingTransport{next: httpClient.Transport} httpClient.Transport = counting gqlClient := githubv4.NewClient(httpClient) + return gqlClient, counting +} +func newMockRESTServer(t *testing.T, permission string) *gogithub.Client { + t.Helper() restServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - resp := gogithub.RepositoryPermissionLevel{ - Permission: gogithub.Ptr("write"), - } + resp := gogithub.RepositoryPermissionLevel{Permission: gogithub.Ptr(permission)} w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(resp) })) t.Cleanup(restServer.Close) restClient, err := gogithub.NewClient(gogithub.WithEnterpriseURLs(restServer.URL+"/", restServer.URL+"/")) require.NoError(t, err) + return restClient +} - return NewRepoAccessCache(gqlClient, restClient, WithTTL(ttl)), counting +func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache, *countingTransport) { + t.Helper() + gqlClient, counting := newMockGQLClient(testUser, false) + restClient := newMockRESTServer(t, "write") + cache := NewRepoAccessCache( + gqlClient, + restClient, + WithTTL(ttl), + WithCacheName(t.Name()), + ) + return cache, counting } func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { @@ -92,7 +116,7 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) info, err := cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo) require.NoError(t, err) - require.Equal(t, testUser, info.ViewerLogin) + require.False(t, info.IsPrivate) require.True(t, info.HasPushAccess) require.EqualValues(t, 1, transport.CallCount()) @@ -100,7 +124,95 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { info, err = cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo) require.NoError(t, err) - require.Equal(t, testUser, info.ViewerLogin) + require.False(t, info.IsPrivate) require.True(t, info.HasPushAccess) require.EqualValues(t, 2, transport.CallCount()) } + +func TestRepoAccessCacheIsolatesViewerPerInstance(t *testing.T) { + ctx := t.Context() + + cacheName := t.Name() + restClient := newMockRESTServer(t, "read") + + attackerGQL, _ := newMockGQLClient("attacker", false) + attackerCache := NewRepoAccessCache(attackerGQL, restClient, WithCacheName(cacheName)) + safe, err := attackerCache.IsSafeContent(ctx, "attacker", testOwner, testRepo) + require.NoError(t, err) + require.True(t, safe) + + victimGQL, _ := newMockGQLClient("victim", false) + victimCache := NewRepoAccessCache(victimGQL, restClient, WithCacheName(cacheName)) + safe, err = victimCache.IsSafeContent(ctx, "attacker", testOwner, testRepo) + require.NoError(t, err) + require.False(t, safe, "attacker-authored content must not be safe for the victim") + + safe, err = victimCache.IsSafeContent(ctx, "victim", testOwner, testRepo) + require.NoError(t, err) + require.True(t, safe) +} + +type flakyTransport struct { + mu sync.Mutex + failN int + calls int + next http.RoundTripper +} + +func (f *flakyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + f.mu.Lock() + f.calls++ + shouldFail := f.calls <= f.failN + f.mu.Unlock() + if shouldFail { + return nil, errors.New("simulated transient failure") + } + return f.next.RoundTrip(req) +} + +func TestRepoAccessCacheRetriesViewerLoginAfterTransientError(t *testing.T) { + ctx := t.Context() + + httpClient := githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + viewerLoginQuery{}, + nil, + githubv4mock.DataResponse(map[string]any{ + "viewer": map[string]any{"login": testUser}, + }), + ), + ) + flaky := &flakyTransport{next: httpClient.Transport, failN: 1} + httpClient.Transport = flaky + gqlClient := githubv4.NewClient(httpClient) + + cache := NewRepoAccessCache(gqlClient, nil, WithCacheName(t.Name())) + + _, err := cache.viewerLoginFor(ctx) + require.Error(t, err, "first call should surface the transient failure") + + login, err := cache.viewerLoginFor(ctx) + require.NoError(t, err, "second call must retry, not return the cached error") + require.Equal(t, testUser, login) +} + +func TestRepoAccessCacheRejectsEmptyViewerLogin(t *testing.T) { + ctx := t.Context() + + httpClient := githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + viewerLoginQuery{}, + nil, + githubv4mock.DataResponse(map[string]any{ + "viewer": map[string]any{"login": ""}, + }), + ), + ) + gqlClient := githubv4.NewClient(httpClient) + + cache := NewRepoAccessCache(gqlClient, nil, WithCacheName(t.Name())) + + _, err := cache.viewerLoginFor(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "empty") +}