Skip to content

Commit c1a2031

Browse files
committed
Add tests
1 parent cea55c7 commit c1a2031

File tree

3 files changed

+233
-53
lines changed

3 files changed

+233
-53
lines changed

internal/cmd/backup.go

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
303303
}
304304
}(&err)
305305

306-
c, err := client.NewClient(cmd)
306+
spiceClient, err := client.NewClient(cmd)
307307
if err != nil {
308308
return fmt.Errorf("unable to initialize client: %w", err)
309309
}
@@ -316,7 +316,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
316316
return fmt.Errorf("error creating backup file encoder: %w", err)
317317
}
318318
} else {
319-
encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile)
319+
encoder, zedToken, err = encoderForNewBackup(cmd, spiceClient, backupFile)
320320
if err != nil {
321321
return err
322322
}
@@ -343,17 +343,13 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
343343
}
344344

345345
ctx := cmd.Context()
346-
relationshipStream, err := c.ExportBulkRelationships(ctx, req)
347-
if err != nil {
348-
return fmt.Errorf("error exporting relationships: %w", err)
349-
}
350346

351347
relationshipReadStart := time.Now()
352348
tick := time.Tick(5 * time.Second)
353-
bar := console.CreateProgressBar("processing backup")
349+
progressBar := console.CreateProgressBar("processing backup")
354350
var relsFilteredOut, relsProcessed uint64
355351
defer func() {
356-
_ = bar.Finish()
352+
_ = progressBar.Finish()
357353

358354
evt := log.Info().
359355
Uint64("filtered", relsFilteredOut).
@@ -369,6 +365,51 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
369365
}
370366
}()
371367

368+
err = takeBackup(ctx, spiceClient, req, func(response *v1.ExportBulkRelationshipsResponse) error {
369+
for _, rel := range response.Relationships {
370+
if hasRelPrefix(rel, prefixFilter) {
371+
if err := encoder.Append(rel); err != nil {
372+
return fmt.Errorf("error storing relationship: %w", err)
373+
}
374+
} else {
375+
relsFilteredOut++
376+
}
377+
378+
relsProcessed++
379+
if err := progressBar.Add(1); err != nil {
380+
return fmt.Errorf("error incrementing progress bar: %w", err)
381+
}
382+
383+
// progress fallback in case there is no TTY
384+
if !isatty.IsTerminal(os.Stderr.Fd()) {
385+
select {
386+
case <-tick:
387+
log.Info().
388+
Uint64("filtered", relsFilteredOut).
389+
Uint64("processed", relsProcessed).
390+
Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))).
391+
Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)).
392+
Msg("backup progress")
393+
default:
394+
}
395+
}
396+
}
397+
398+
if err := writeProgress(progressFile, response); err != nil {
399+
return err
400+
}
401+
return nil
402+
})
403+
404+
backupCompleted = true
405+
return nil
406+
}
407+
408+
func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBulkRelationshipsRequest, processResponse func(*v1.ExportBulkRelationshipsResponse) error) error {
409+
relationshipStream, err := spiceClient.ExportBulkRelationships(ctx, req)
410+
if err != nil {
411+
return fmt.Errorf("error exporting relationships: %w", err)
412+
}
372413
var lastResponse *v1.ExportBulkRelationshipsResponse
373414
for {
374415
if err := ctx.Err(); err != nil {
@@ -386,15 +427,16 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
386427
}
387428

388429
if isRetryableError(err) {
389-
// TODO: do we need to clean up the existing stream in some way?
390430
// TODO: best way to test this?
391431
// If the error is retryable, we overwrite the existing stream with a new
392432
// stream based on a new request that starts at the cursor location of the
393433
// last received response.
394-
relationshipStream, err = c.ExportBulkRelationships(ctx, &v1.ExportBulkRelationshipsRequest{
395-
OptionalLimit: pageLimit,
396-
OptionalCursor: lastResponse.AfterResultCursor,
397-
})
434+
435+
// Clone the request to ensure that we are keeping all other fields the same
436+
newReq := req.CloneVT()
437+
newReq.OptionalCursor = lastResponse.AfterResultCursor
438+
439+
relationshipStream, err = spiceClient.ExportBulkRelationships(ctx, newReq)
398440
log.Info().Err(err).Str("cursor token", lastResponse.AfterResultCursor.Token).Msg("encountered retryable error, resuming stream after token")
399441
// Bounce to the top of the loop
400442
continue
@@ -410,41 +452,12 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
410452
// starting at its cursor
411453
lastResponse = relsResp
412454

413-
for _, rel := range relsResp.Relationships {
414-
if hasRelPrefix(rel, prefixFilter) {
415-
if err := encoder.Append(rel); err != nil {
416-
return fmt.Errorf("error storing relationship: %w", err)
417-
}
418-
} else {
419-
relsFilteredOut++
420-
}
421-
422-
relsProcessed++
423-
if err := bar.Add(1); err != nil {
424-
return fmt.Errorf("error incrementing progress bar: %w", err)
425-
}
426-
427-
// progress fallback in case there is no TTY
428-
if !isatty.IsTerminal(os.Stderr.Fd()) {
429-
select {
430-
case <-tick:
431-
log.Info().
432-
Uint64("filtered", relsFilteredOut).
433-
Uint64("processed", relsProcessed).
434-
Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))).
435-
Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)).
436-
Msg("backup progress")
437-
default:
438-
}
439-
}
440-
}
441-
442-
if err := writeProgress(progressFile, relsResp); err != nil {
455+
// Process the response using the provided function
456+
err = processResponse(relsResp)
457+
if err != nil {
443458
return err
444459
}
445460
}
446-
447-
backupCompleted = true
448461
return nil
449462
}
450463

internal/cmd/backup_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package cmd
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"fmt"
8+
"io"
79
"os"
810
"path/filepath"
911
"strings"
@@ -12,6 +14,7 @@ import (
1214
"github.com/google/uuid"
1315
"github.com/rs/zerolog"
1416
"github.com/stretchr/testify/require"
17+
"google.golang.org/grpc"
1518
"google.golang.org/grpc/codes"
1619
"google.golang.org/grpc/status"
1720

@@ -22,6 +25,7 @@ import (
2225
"github.com/authzed/zed/internal/client"
2326
"github.com/authzed/zed/internal/storage"
2427
zedtesting "github.com/authzed/zed/internal/testing"
28+
"github.com/authzed/zed/pkg/backupformat"
2529
)
2630

2731
func init() {
@@ -606,3 +610,166 @@ func TestAddSizeErrInfo(t *testing.T) {
606610
})
607611
}
608612
}
613+
614+
func TestTakeBackupMockWorksAsExpected(t *testing.T) {
615+
rels := []*v1.Relationship{
616+
{
617+
Resource: &v1.ObjectReference{
618+
ObjectType: "resource",
619+
ObjectId: "foo",
620+
},
621+
Relation: "view",
622+
Subject: &v1.SubjectReference{
623+
Object: &v1.ObjectReference{
624+
ObjectType: "user",
625+
ObjectId: "jim",
626+
},
627+
},
628+
},
629+
}
630+
client := &mockClientForBackup{
631+
t: t,
632+
recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
633+
func() (*v1.ExportBulkRelationshipsResponse, error) {
634+
return &v1.ExportBulkRelationshipsResponse{
635+
Relationships: rels,
636+
}, nil
637+
},
638+
},
639+
}
640+
641+
err := takeBackup(t.Context(), client, &v1.ExportBulkRelationshipsRequest{}, func(response *v1.ExportBulkRelationshipsResponse) error {
642+
require.Len(t, response.Relationships, 1, "expecting 1 rel in the list")
643+
return nil
644+
})
645+
require.NoError(t, err)
646+
647+
client.assertAllRecvCalls()
648+
}
649+
650+
func TestTakeBackupRecoversFromRetryableErrors(t *testing.T) {
651+
firstRels := []*v1.Relationship{
652+
{
653+
Resource: &v1.ObjectReference{
654+
ObjectType: "resource",
655+
ObjectId: "foo",
656+
},
657+
Relation: "view",
658+
Subject: &v1.SubjectReference{
659+
Object: &v1.ObjectReference{
660+
ObjectType: "user",
661+
ObjectId: "jim",
662+
},
663+
},
664+
},
665+
}
666+
cursor := &v1.Cursor{
667+
Token: "an token",
668+
}
669+
secondRels := []*v1.Relationship{
670+
{
671+
Resource: &v1.ObjectReference{
672+
ObjectType: "resource",
673+
ObjectId: "bar",
674+
},
675+
Relation: "view",
676+
Subject: &v1.SubjectReference{
677+
Object: &v1.ObjectReference{
678+
ObjectType: "user",
679+
ObjectId: "jim",
680+
},
681+
},
682+
},
683+
}
684+
client := &mockClientForBackup{
685+
t: t,
686+
recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
687+
func() (*v1.ExportBulkRelationshipsResponse, error) {
688+
return &v1.ExportBulkRelationshipsResponse{
689+
Relationships: firstRels,
690+
// Need to test that this cursor is supplied
691+
AfterResultCursor: cursor,
692+
}, nil
693+
},
694+
func() (*v1.ExportBulkRelationshipsResponse, error) {
695+
// Return a retryable error
696+
return nil, status.Error(codes.Unavailable, "i fell over")
697+
},
698+
func() (*v1.ExportBulkRelationshipsResponse, error) {
699+
return &v1.ExportBulkRelationshipsResponse{
700+
Relationships: secondRels,
701+
AfterResultCursor: &v1.Cursor{
702+
Token: "some other token",
703+
},
704+
}, nil
705+
},
706+
},
707+
exportCalls: []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest){
708+
// Initial request
709+
func(_ *testing.T, _ *v1.ExportBulkRelationshipsRequest) {
710+
},
711+
// The retried request - asserting that it's called with the cursor
712+
func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) {
713+
require.Equal(t, req.OptionalCursor.Token, cursor.Token, "cursor token does not match expected")
714+
},
715+
},
716+
}
717+
718+
actualRels := make([]*v1.Relationship, 0)
719+
720+
err := takeBackup(t.Context(), client, &v1.ExportBulkRelationshipsRequest{}, func(response *v1.ExportBulkRelationshipsResponse) error {
721+
actualRels = append(actualRels, response.Relationships...)
722+
return nil
723+
})
724+
require.NoError(t, err)
725+
726+
require.Len(t, actualRels, 2, "expecting two rels in the realized list")
727+
require.Equal(t, actualRels[0].Resource.ObjectId, "foo")
728+
require.Equal(t, actualRels[1].Resource.ObjectId, "bar")
729+
730+
client.assertAllRecvCalls()
731+
}
732+
733+
type mockClientForBackup struct {
734+
client.Client
735+
grpc.ServerStreamingClient[v1.ExportBulkRelationshipsResponse]
736+
t *testing.T
737+
backupformat.Encoder
738+
recvCalls []func() (*v1.ExportBulkRelationshipsResponse, error)
739+
recvCallIndex int
740+
// exportCalls provides a handle on the calls made to ExportBulkRelationships,
741+
// allowing for assertions to be made against those calls.
742+
exportCalls []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest)
743+
exportCallsIndex int
744+
}
745+
746+
func (m *mockClientForBackup) Recv() (*v1.ExportBulkRelationshipsResponse, error) {
747+
// If we've run through all our calls, return an EOF
748+
if m.recvCallIndex == len(m.recvCalls) {
749+
return nil, io.EOF
750+
}
751+
recvCall := m.recvCalls[m.recvCallIndex]
752+
m.recvCallIndex++
753+
return recvCall()
754+
}
755+
756+
func (m *mockClientForBackup) ExportBulkRelationships(_ context.Context, req *v1.ExportBulkRelationshipsRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[v1.ExportBulkRelationshipsResponse], error) {
757+
if m.exportCalls == nil {
758+
// If the caller doesn't supply exportCalls, pass through
759+
return m, nil
760+
}
761+
if m.exportCallsIndex == len(m.exportCalls) {
762+
// If invoked too many times, fail the test
763+
m.t.FailNow()
764+
return m, nil
765+
}
766+
exportCall := m.exportCalls[m.exportCallsIndex]
767+
m.exportCallsIndex++
768+
exportCall(m.t, req)
769+
return m, nil
770+
}
771+
772+
// assertAllRecvCalls asserts that the number of invocations is as expected
773+
func (m *mockClientForBackup) assertAllRecvCalls() {
774+
require.Equal(m.t, len(m.recvCalls), m.recvCallIndex, "the number of provided recvCalls should match the number of invocations")
775+
}

0 commit comments

Comments
 (0)