From 5169ac6b50f8aa9b31603d8e43d363d682418fd5 Mon Sep 17 00:00:00 2001 From: Wittano Bonarotti Date: Sun, 30 Jun 2024 20:48:24 +0200 Subject: [PATCH] test(api): added tests for downloading file --- Makefile | 9 ++- audio/path.go | 46 +++++++++++-- bot/command/spock.go | 4 +- server/audio.go | 28 +++----- server/audio_test.go | 27 +++++--- server/file.go | 34 ++++++---- server/file_test.go | 157 +++++++++++++++++++++++++++++++++++++++++++ test/assets.go | 6 +- 8 files changed, 260 insertions(+), 51 deletions(-) create mode 100644 server/file_test.go diff --git a/Makefile b/Makefile index 0e7f517..a410a48 100644 --- a/Makefile +++ b/Makefile @@ -28,8 +28,13 @@ protobuf: mkdir -p $(PROTOBUF_API_DEST) protoc --go_out=./api --go_opt=paths=source_relative --go-grpc_out=./api --go-grpc_opt=paths=source_relative proto/* -test: protobuf - CGO_CFLAGS="-w" find . -name go.mod -execdir go test ./... \; +test: test-bot test-server + +test-bot: protobuf + CGO_CFLAGS="-w" go test ./bot/...; + +test-server: protobuf + CGO_CFLAGS="-w" go test ./server/...; all: bot-prod sever tui test diff --git a/audio/path.go b/audio/path.go index 304ae44..b6dfa89 100644 --- a/audio/path.go +++ b/audio/path.go @@ -2,11 +2,13 @@ package audio import ( "errors" + "fmt" "io" "math/rand" "os" "os/exec" "path/filepath" + "strings" "syscall" "time" ) @@ -16,8 +18,41 @@ const ( assetsDirKey = "ASSETS_DIR" ) -func Path(name string) string { - return filepath.Join(AssertDir(), name) +func Path(name string) (path string, err error) { + assertDir := AssertDir() + path = filepath.Join(assertDir, name) + _, err = os.Stat(path) + if err != nil { + path, err = searchPathByNameOrUUID(name) + } + + return +} + +func searchPathByNameOrUUID(prefix string) (p string, err error) { + var paths []string + paths, err = Paths() + if err != nil { + return + } + + for _, p = range paths { + base := filepath.Base(p) + if strings.HasPrefix(base, prefix) { + return + } else { + split := strings.Split(base, "-") + if len(split) < 2 { + continue + } + + if strings.HasPrefix(strings.Join(split[1:], "-"), prefix) { + return + } + } + } + + return "", fmt.Errorf("path with prefix %s wasn't found", prefix) } func AssertDir() (path string) { @@ -30,7 +65,8 @@ func AssertDir() (path string) { } func Paths() (paths []string, err error) { - dirs, err := os.ReadDir(AssertDir()) + assertDir := AssertDir() + dirs, err := os.ReadDir(assertDir) if err != nil { return nil, err } @@ -42,7 +78,7 @@ func Paths() (paths []string, err error) { paths = make([]string, 0, len(dirs)) for _, dir := range dirs { if dir.Type() != os.ModeDir { - paths = append(paths, dir.Name()) + paths = append(paths, filepath.Join(assertDir, dir.Name())) } } @@ -77,7 +113,7 @@ func PathsWithPagination(page uint32, size uint32) (paths []string, err error) { return } -func RandomPath() (string, error) { +func RandomAudioName() (string, error) { paths, err := Paths() if err != nil { return "", err diff --git a/bot/command/spock.go b/bot/command/spock.go index 57032e6..c29c8b7 100644 --- a/bot/command/spock.go +++ b/bot/command/spock.go @@ -134,9 +134,9 @@ func audioPath(data discordgo.ApplicationCommandInteractionData) (path string, e } if name == "" { - path, err = audio.RandomPath() + path, err = audio.RandomAudioName() } else { - path = audio.Path(path) + path, err = audio.Path(path) } return diff --git a/server/audio.go b/server/audio.go index 7cd3657..9db4606 100644 --- a/server/audio.go +++ b/server/audio.go @@ -35,21 +35,22 @@ func (a audioServer) List(pagination *pb.Pagination, server pb.AudioService_List return } +// TODO Add file validation func (a audioServer) Add(server pb.AudioService_AddServer) error { id := uuid.NewString() var path string au, err := server.Recv() - if err != nil && !errors.Is(err, io.EOF) { + if err != nil { return status.Error(codes.Internal, err.Error()) } - if fileExistsInAssetsDir(au.Info.Name) { + if _, err := audio.Path(au.Info.Name); err == nil { return status.Error(codes.AlreadyExists, fmt.Sprintf("file %s already exists", au.Info.Name)) } if path == "" { - path = audio.Path(fmt.Sprintf("%s-%s.%s", au.Info.Name, id, strings.ToLower(au.Info.Type.String()))) + path = filepath.Join(audio.AssertDir(), fmt.Sprintf("%s-%s.%s", au.Info.Name, id, strings.ToLower(au.Info.Type.String()))) } f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600) @@ -95,24 +96,13 @@ func (a audioServer) Remove(_ context.Context, req *pb.RemoveAudio) (e *emptypb. } for _, query := range req.Name { - rmErr := os.Remove(audio.Path(query)) - if rmErr != nil { - err = errors.Join(err, status.Error(codes.NotFound, rmErr.Error())) + path, err := audio.Path(query) + if err != nil { + return nil, status.Error(codes.NotFound, err.Error()) } - } - - return -} - -func fileExistsInAssetsDir(filename string) (exists bool) { - dir, err := os.ReadDir(audio.AssertDir()) - if err != nil { - return - } - for _, d := range dir { - if strings.HasPrefix(d.Name(), filename) { - return true + if err = os.Remove(path); err != nil { + return nil, status.Error(codes.Internal, err.Error()) } } diff --git a/server/audio_test.go b/server/audio_test.go index 601dd88..d74afc5 100644 --- a/server/audio_test.go +++ b/server/audio_test.go @@ -35,7 +35,7 @@ func (c closers) Close() (err error) { return } -func createClient() (client pb.AudioServiceClient, server io.Closer, err error) { +func createAudioClient() (client pb.AudioServiceClient, server io.Closer, err error) { s, err := New(port) if err != nil { return nil, nil, err @@ -59,7 +59,7 @@ func createClient() (client pb.AudioServiceClient, server io.Closer, err error) } func TestRemoveAudio(t *testing.T) { - client, closer, err := createClient() + client, closer, err := createAudioClient() if err != nil { t.Fatal(err) } @@ -74,13 +74,17 @@ func TestRemoveAudio(t *testing.T) { t.Fatal(err) } + for i, p := range paths { + paths[i] = filepath.Base(p) + } + if _, err := client.Remove(context.Background(), &pb.RemoveAudio{Name: paths}); err != nil { t.Fatal(err) } } func TestUploadFile(t *testing.T) { - client, closer, err := createClient() + client, closer, err := createAudioClient() if err != nil { t.Fatal(err) } @@ -116,13 +120,18 @@ func TestUploadFile(t *testing.T) { t.Fatal(err) } - if s, err := os.Stat(audio.Path(res.Filename)); err != nil || s.Size() == 0 { + path, err = audio.Path(res.Filename) + if err != nil { + t.Fatal(err) + } + + if s, err := os.Stat(path); err != nil || s.Size() == 0 { t.Fatalf("failed upload file: %v", err) } } func TestUploadFile_FileAlreadyExists(t *testing.T) { - client, closer, err := createClient() + client, closer, err := createAudioClient() if err != nil { t.Fatal(err) } @@ -132,9 +141,9 @@ func TestUploadFile_FileAlreadyExists(t *testing.T) { t.Fatal(err) } - path := audio.Path("test") - f, err := os.Create(path) - if err != nil { + path := filepath.Join(audio.AssertDir(), "test") + f, createErr := os.Create(path) + if errors.Join(err, createErr) != nil { t.Fatal(err) } f.Close() @@ -164,7 +173,7 @@ func fillTempFile(t *testing.T, path string) error { } defer f.Close() - for i := 0; i < 1000; i++ { + for i := 0; i < 100; i++ { if _, err := f.WriteString(strconv.Itoa(rand.Int()) + "\n"); err != nil { return err } diff --git a/server/file.go b/server/file.go index d53eb7d..d3181f9 100644 --- a/server/file.go +++ b/server/file.go @@ -6,36 +6,41 @@ import ( "fmt" komputer "github.com/wittano/komputer/api/proto" "github.com/wittano/komputer/audio" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "io" "log/slog" "os" ) +const downloadBufSize = 1024 * 1024 + type fileServer struct { komputer.UnimplementedAudioFileServiceServer } -// TODO Verification of service structe func (fs fileServer) Download(req *komputer.DownloadFile, server komputer.AudioFileService_DownloadServer) error { if req == nil { - return errors.New("download: missing req data") + return status.Error(codes.InvalidArgument, "download: missing required data") + } + + path, err := audio.Path(filename(req)) + if err != nil { + return status.Error(codes.NotFound, err.Error()) } - path := audio.Path(filename(req)) f, err := os.Open(path) if err != nil { slog.Error("failed find f "+path, err) - return err + return status.Error(codes.NotFound, err.Error()) } defer f.Close() - ctx := server.Context() - - buf := make([]byte, 1024) + buf := make([]byte, downloadBufSize) for { select { - case <-ctx.Done(): - return context.Canceled + case <-server.Context().Done(): + return status.Error(codes.Canceled, context.Canceled.Error()) default: } @@ -43,11 +48,12 @@ func (fs fileServer) Download(req *komputer.DownloadFile, server komputer.AudioF if errors.Is(err, io.EOF) { break } else if err != nil { - return err + return status.Error(codes.Internal, err.Error()) } if err = server.Send(&komputer.FileBuffer{Content: buf, Size: uint64(n)}); err != nil { - return err + slog.Error("failed send chunk of file", err) + return status.Error(codes.Internal, err.Error()) } } @@ -66,5 +72,9 @@ func filename(req *komputer.DownloadFile) (name string) { return } - return fmt.Sprintf("%s-%s", name, uuid) + if name == "" { + return string(uuid.Uuid) + } + + return fmt.Sprintf("%s-%s", name, uuid.Uuid) } diff --git a/server/file_test.go b/server/file_test.go new file mode 100644 index 0000000..8f64d1c --- /dev/null +++ b/server/file_test.go @@ -0,0 +1,157 @@ +package server + +import ( + "context" + "errors" + "fmt" + pb "github.com/wittano/komputer/api/proto" + "github.com/wittano/komputer/audio" + "github.com/wittano/komputer/test" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "io" + "log" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" +) + +func createDownloadClient() (client pb.AudioFileServiceClient, server io.Closer, err error) { + s, err := New(port) + if err != nil { + return nil, nil, err + } + + go func() { + if err := s.Start(); err != nil { + log.Fatalf("Server exited with error: %v", err) + } + }() + + conn, err := grpc.NewClient("localhost:"+strconv.Itoa(port), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + s.Close() + return + } + + server = closers{conn, s} + client = pb.NewAudioFileServiceClient(conn) + return +} + +func TestFileServer_Download_ButFileDoesNotExists(t *testing.T) { + client, closer, err := createDownloadClient() + if err != nil { + t.Fatal(err) + } + defer closer.Close() + + if err = test.CreateAssertDir(t, 1); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + name := "invalid_name" + stream, err := client.Download(ctx, &pb.DownloadFile{Name: &name}) + if err != nil { + t.Fatal(err) + } + defer stream.CloseSend() + + if _, err := stream.Recv(); status.Code(err) != codes.NotFound { + t.Fatal("file invalid_name wasn't found in assets dir. Status code: " + strconv.Itoa(int(status.Code(err)))) + } +} + +func TestFileServer_Download(t *testing.T) { + t.Parallel() + + client, closer, err := createDownloadClient() + if err != nil { + t.Fatal(err) + } + defer closer.Close() + + if err = test.CreateAssertDir(t, 1); err != nil { + t.Fatal(err) + } + + assertDir := audio.AssertDir() + dir, err := os.ReadDir(assertDir) + if err != nil { + t.Fatal(err) + } + + path := filepath.Join(assertDir, dir[0].Name()) + if err := fillTempFile(t, path); err != nil { + t.Fatal(err) + } + + base := filepath.Base(path) + split := strings.Split(base, "-") + if len(split) < 2 { + t.Fatal("invalid split " + base) + } + + uuid := pb.UUID{Uuid: []byte(strings.Join(split[1:], "-"))} + data := []*pb.DownloadFile{ + { + Name: &split[0], + }, + { + Uuid: &uuid, + }, + { + Name: &split[0], + Uuid: &uuid, + }, + } + + for _, d := range data { + var name string + if d.Name != nil { + name = *d.Name + } + + t.Run(fmt.Sprintf("download file with name: %s and uuid: %s", name, d.Uuid), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + stream, err := client.Download(ctx, d) + if err != nil { + t.Fatal(err) + } + defer stream.CloseSend() + + for { + select { + case <-ctx.Done(): + return + default: + } + + recv, err := stream.Recv() + if err != nil && !errors.Is(err, io.EOF) { + t.Fatal(err) + } else if errors.Is(err, io.EOF) || (recv != nil && recv.Size < downloadBufSize) { + return + } + + if recv.Size == 0 { + t.Fatal("server didn't send chunk of file") + } + + if len(recv.Content) == 0 { + t.Fatalf("invalid size and number of bytes in content. Want: %d, Got: %d", recv.Size, len(recv.Content)) + } + } + }) + } +} diff --git a/test/assets.go b/test/assets.go index f871896..cc77e3b 100644 --- a/test/assets.go +++ b/test/assets.go @@ -2,7 +2,9 @@ package test import ( "fmt" + "github.com/google/uuid" "os" + "path/filepath" "testing" ) @@ -16,9 +18,9 @@ func CreateAssertDir(t *testing.T, n int) (err error) { } for i := 0; i < n; i++ { - f, err := os.CreateTemp(dir, fmt.Sprintf("test-%d.*.mp3", i)) + f, err := os.Create(filepath.Join(dir, fmt.Sprintf("test-%s.mp3", uuid.NewString()))) if err != nil { - t.Fatal(err) + return err } err = f.Close() }