diff --git a/cmd/filebot/filebot.go b/cmd/filebot/filebot.go index 3086da6..6c9d5b5 100644 --- a/cmd/filebot/filebot.go +++ b/cmd/filebot/filebot.go @@ -11,11 +11,15 @@ import ( func runMainCommand(_ *cobra.Command, _ []string) { ctx := context.Background() - conf := setting.Flags.Config() + conf, err := setting.Flags.Config() + if err != nil { + setting.Logger().Fatal("Failed load configuration", err) + return + } w := watcher.NewWatcher(ctx) defer w.Close() - w.AddFilesToObservable(*conf) + w.AddFilesToObservable(conf) tasks.RunTaskWithInterval(ctx, 1*time.Hour, tasks.MoveToTrashTask) diff --git a/file/move.go b/file/move.go index 2264a1a..b954810 100644 --- a/file/move.go +++ b/file/move.go @@ -38,13 +38,16 @@ func MoveToDestination(dest string, paths ...string) (err error) { newPath := filepath.Join(fixedDest, filepath.Base(src)) - if _, err = os.Stat(src); !errors.Is(err, os.ErrNotExist) { - err = os.Rename(src, newPath) - if err != nil { + if stat, err := os.Stat(src); err == nil { + if err = os.Rename(src, newPath); err != nil { setting.Logger().Error(fmt.Sprintf("Failed to move file from %s to %s", src, newPath), err) continue } + if err = updateOwnerAndGroupID(stat, newPath); err != nil { + return err + } + setting.Logger().Info(fmt.Sprintf("Moved file from %s to %s", src, dest)) } } @@ -52,6 +55,13 @@ func MoveToDestination(dest string, paths ...string) (err error) { return nil } +func updateOwnerAndGroupID(ogInfo os.FileInfo, src string) (err error) { + conf, _ := setting.Flags.Config() + uid, gid := uid(src, ogInfo, conf), gid(src, ogInfo, conf) + + return os.Chown(src, int(uid), int(gid)) +} + func checkFilePermissions(stat os.FileInfo) error { writePermIndex := strings.Index(stat.Mode().String(), unixWritePermission) if writePermIndex == -1 { diff --git a/file/stat.go b/file/stat.go new file mode 100644 index 0000000..778ca11 --- /dev/null +++ b/file/stat.go @@ -0,0 +1,55 @@ +package file + +import ( + "github.com/wittano/filebot/setting" + "os" + "path/filepath" + "syscall" +) + +type uID int +type gID int + +const ( + rootID uID = 0 + rootGID gID = 0 +) + +func uid(path string, ogStat os.FileInfo, config setting.Config) uID { + baseDir := filepath.Dir(path) + + for _, dir := range config.Dirs { + if dir.Dest == baseDir { + uid := uID(os.Getuid()) + + if dir.UID > 0 { + uid = uID(dir.UID) + } else if dir.IsRoot { + uid = rootID + } + + return uid + } + } + + return uID(ogStat.Sys().(*syscall.Stat_t).Uid) +} + +func gid(path string, ogStat os.FileInfo, config setting.Config) gID { + baseDir := filepath.Dir(path) + + for _, dir := range config.Dirs { + if dir.Dest == baseDir { + uid := gID(os.Getgid()) + if dir.UID > 0 { + uid = gID(dir.GID) + } else if dir.IsRoot { + uid = rootGID + } + + return uid + } + } + + return gID(ogStat.Sys().(*syscall.Stat_t).Gid) +} diff --git a/setting/config_file.go b/setting/config_file.go index 7c8fd48..0b3db9c 100644 --- a/setting/config_file.go +++ b/setting/config_file.go @@ -21,7 +21,7 @@ type Config struct { Dirs []Directory `validate:"required"` } -var config *Config +var config Config type Directory struct { Src []string `validate:"required"` @@ -30,16 +30,12 @@ type Directory struct { MoveToTrash bool `validate:"required_without=Dest"` After uint Exceptions []string + UID uint32 + GID uint32 + IsRoot bool } func (d Directory) RealPaths() (paths []string, err error) { - v := validator.New(validator.WithRequiredStructEnabled()) - - err = v.Struct(d) - if err != nil { - return - } - for _, exp := range d.Src { if d.Recursive { paths, err = path.PathsFromPatternRecursive(exp) @@ -139,29 +135,28 @@ func isUserRoot() bool { return os.Getuid() == 0 } -func load(path string) (*Config, error) { +func load(path string) (Config, error) { bytes, err := os.ReadFile(path) if err != nil { - return nil, err + return Config{}, err } var unmarshal map[string]Directory if err := toml.Unmarshal(bytes, &unmarshal); err != nil { - return nil, err + return Config{}, err } if len(unmarshal) == 0 { - return nil, errors.New("config file is empty") + return Config{}, errors.New("config file is empty") } - config = new(Config) config.Dirs = maps.Values(unmarshal) v := validator.New(validator.WithRequiredStructEnabled()) for _, d := range config.Dirs { if err = v.Struct(d); err != nil { - return nil, err + return Config{}, err } } diff --git a/setting/flags.go b/setting/flags.go index 7d4bfd9..c67ea91 100644 --- a/setting/flags.go +++ b/setting/flags.go @@ -21,17 +21,12 @@ var Flags = Flag{ "", } -func (f Flag) Config() *Config { - if config != nil { - return config +func (f Flag) Config() (Config, error) { + if config.Dirs != nil { + return config, nil } - c, err := load(f.ConfigPath) - if err != nil { - Logger().Fatal("Failed to load config file", err) - } - - return c + return load(f.ConfigPath) } func (f Flag) LogLevel() LogLevel { diff --git a/tasks/trash.go b/tasks/trash.go index eae897d..0dd5ad1 100644 --- a/tasks/trash.go +++ b/tasks/trash.go @@ -17,7 +17,12 @@ func MoveToTrashTask(ctx context.Context) (err error) { default: } - for _, dir := range setting.Flags.Config().Dirs { + config, err := setting.Flags.Config() + if err != nil { + return err + } + + for _, dir := range config.Dirs { if dir.MoveToTrash { if err = moveFileToTrash(dir); err != nil { return diff --git a/watcher/watcher.go b/watcher/watcher.go index fdbc382..cacdd37 100644 --- a/watcher/watcher.go +++ b/watcher/watcher.go @@ -152,9 +152,9 @@ func (w *MyWatcher) updateObservableFileList(ctx context.Context) error { case <-ctx.Done(): return default: - conf := setting.Flags.Config() + conf, _ := setting.Flags.Config() - w.AddFilesToObservable(*conf) + w.AddFilesToObservable(conf) } }()