diff --git a/paths.go b/paths.go index 7fb644d..c42ca31 100644 --- a/paths.go +++ b/paths.go @@ -32,7 +32,6 @@ package paths import ( "fmt" "io" - "io/ioutil" "os" "path/filepath" "strings" @@ -418,14 +417,14 @@ func (p *Path) Chtimes(atime, mtime time.Time) error { // ReadFile reads the file named by filename and returns the contents func (p *Path) ReadFile() ([]byte, error) { - return ioutil.ReadFile(p.path) + return os.ReadFile(p.path) } // WriteFile writes data to a file named by filename. If the file // does not exist, WriteFile creates it otherwise WriteFile truncates // it before writing. func (p *Path) WriteFile(data []byte) error { - return ioutil.WriteFile(p.path, data, os.FileMode(0644)) + return os.WriteFile(p.path, data, os.FileMode(0644)) } // WriteToTempFile writes data to a newly generated temporary file. diff --git a/readdir.go b/readdir.go index 53341b9..42f5d73 100644 --- a/readdir.go +++ b/readdir.go @@ -30,7 +30,8 @@ package paths import ( - "io/ioutil" + "errors" + "os" "strings" ) @@ -41,7 +42,7 @@ type ReadDirFilter func(file *Path) bool // ReadDir returns a PathList containing the content of the directory // pointed by the current Path. The resulting list is filtered by the given filters chained. func (p *Path) ReadDir(filters ...ReadDirFilter) (PathList, error) { - infos, err := ioutil.ReadDir(p.path) + infos, err := os.ReadDir(p.path) if err != nil { return nil, err } @@ -69,27 +70,7 @@ func (p *Path) ReadDir(filters ...ReadDirFilter) (PathList, error) { // ReadDirRecursive returns a PathList containing the content of the directory // and its subdirectories pointed by the current Path func (p *Path) ReadDirRecursive() (PathList, error) { - infos, err := ioutil.ReadDir(p.path) - if err != nil { - return nil, err - } - paths := PathList{} - for _, info := range infos { - path := p.Join(info.Name()) - paths.Add(path) - - if isDir, err := path.IsDirCheck(); err != nil { - return nil, err - } else if isDir { - subPaths, err := path.ReadDirRecursive() - if err != nil { - return nil, err - } - paths.AddAll(subPaths) - } - - } - return paths, nil + return p.ReadDirRecursiveFiltered(nil) } // ReadDirRecursiveFiltered returns a PathList containing the content of the directory @@ -101,41 +82,55 @@ func (p *Path) ReadDirRecursive() (PathList, error) { // - `filters` are the filters that are checked to determine if the entry should be // added to the resulting PathList func (p *Path) ReadDirRecursiveFiltered(recursionFilter ReadDirFilter, filters ...ReadDirFilter) (PathList, error) { - infos, err := ioutil.ReadDir(p.path) - if err != nil { - return nil, err - } + var search func(*Path) (PathList, error) - accept := func(p *Path) bool { - for _, filter := range filters { - if !filter(p) { - return false - } + explored := map[string]bool{} + search = func(currPath *Path) (PathList, error) { + canonical := currPath.Canonical().path + if explored[canonical] { + return nil, errors.New("directories symlink loop detected") } - return true - } + explored[canonical] = true + defer delete(explored, canonical) - paths := PathList{} - for _, info := range infos { - path := p.Join(info.Name()) + infos, err := os.ReadDir(currPath.path) + if err != nil { + return nil, err + } - if accept(path) { - paths.Add(path) + accept := func(p *Path) bool { + for _, filter := range filters { + if !filter(p) { + return false + } + } + return true } - if recursionFilter == nil || recursionFilter(path) { - if isDir, err := path.IsDirCheck(); err != nil { - return nil, err - } else if isDir { - subPaths, err := path.ReadDirRecursiveFiltered(recursionFilter, filters...) - if err != nil { + paths := PathList{} + for _, info := range infos { + path := currPath.Join(info.Name()) + + if accept(path) { + paths.Add(path) + } + + if recursionFilter == nil || recursionFilter(path) { + if isDir, err := path.IsDirCheck(); err != nil { return nil, err + } else if isDir { + subPaths, err := search(path) + if err != nil { + return nil, err + } + paths.AddAll(subPaths) } - paths.AddAll(subPaths) } } + return paths, nil } - return paths, nil + + return search(p) } // FilterDirectories is a ReadDirFilter that accepts only directories diff --git a/readdir_test.go b/readdir_test.go index d0ec927..c22d542 100644 --- a/readdir_test.go +++ b/readdir_test.go @@ -33,6 +33,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -245,3 +246,71 @@ func TestReadDirRecursiveFiltered(t *testing.T) { pathEqualsTo(t, "testdata/fileset/test.txt", l[7]) pathEqualsTo(t, "testdata/fileset/test.txt.gz", l[8]) } + +func TestReadDirRecursiveLoopDetection(t *testing.T) { + loopsPath := New("testdata", "loops") + unbuondedReaddir := func(testdir string) (PathList, error) { + // This is required to unbound the recursion, otherwise it will stop + // when the paths becomes too long due to the symlink loop: this is not + // what we want, we are looking for an early detection of the loop. + skipBrokenLinks := func(p *Path) bool { + _, err := p.Stat() + return err == nil + } + + var files PathList + var err error + done := make(chan bool) + go func() { + files, err = loopsPath.Join(testdir).ReadDirRecursiveFiltered( + skipBrokenLinks, + ) + done <- true + }() + require.Eventually( + t, + func() bool { + select { + case <-done: + return true + default: + return false + } + }, + 5*time.Second, + 10*time.Millisecond, + "Infinite symlink loop while loading sketch", + ) + return files, err + } + + for _, dir := range []string{"loop_1", "loop_2", "loop_3", "loop_4"} { + l, err := unbuondedReaddir(dir) + require.EqualError(t, err, "directories symlink loop detected", "loop not detected in %s", dir) + require.Nil(t, l) + } + + { + l, err := unbuondedReaddir("regular_1") + require.NoError(t, err) + require.Len(t, l, 4) + l.Sort() + pathEqualsTo(t, "testdata/loops/regular_1/dir1", l[0]) + pathEqualsTo(t, "testdata/loops/regular_1/dir1/file1", l[1]) + pathEqualsTo(t, "testdata/loops/regular_1/dir2", l[2]) + pathEqualsTo(t, "testdata/loops/regular_1/dir2/file1", l[3]) + } + + { + l, err := unbuondedReaddir("regular_2") + require.NoError(t, err) + require.Len(t, l, 6) + l.Sort() + pathEqualsTo(t, "testdata/loops/regular_2/dir1", l[0]) + pathEqualsTo(t, "testdata/loops/regular_2/dir1/file1", l[1]) + pathEqualsTo(t, "testdata/loops/regular_2/dir2", l[2]) + pathEqualsTo(t, "testdata/loops/regular_2/dir2/dir1", l[3]) + pathEqualsTo(t, "testdata/loops/regular_2/dir2/dir1/file1", l[4]) + pathEqualsTo(t, "testdata/loops/regular_2/dir2/file2", l[5]) + } +} diff --git a/testdata/loops/loop_1/dir1/loop b/testdata/loops/loop_1/dir1/loop new file mode 120000 index 0000000..c9f3ab1 --- /dev/null +++ b/testdata/loops/loop_1/dir1/loop @@ -0,0 +1 @@ +../dir1 \ No newline at end of file diff --git a/testdata/loops/loop_2/dir1/loop2 b/testdata/loops/loop_2/dir1/loop2 new file mode 120000 index 0000000..d014eb4 --- /dev/null +++ b/testdata/loops/loop_2/dir1/loop2 @@ -0,0 +1 @@ +../dir2 \ No newline at end of file diff --git a/testdata/loops/loop_2/dir2/loop1 b/testdata/loops/loop_2/dir2/loop1 new file mode 120000 index 0000000..c9f3ab1 --- /dev/null +++ b/testdata/loops/loop_2/dir2/loop1 @@ -0,0 +1 @@ +../dir1 \ No newline at end of file diff --git a/testdata/loops/loop_3/dir1/loop2 b/testdata/loops/loop_3/dir1/loop2 new file mode 120000 index 0000000..d014eb4 --- /dev/null +++ b/testdata/loops/loop_3/dir1/loop2 @@ -0,0 +1 @@ +../dir2 \ No newline at end of file diff --git a/testdata/loops/loop_3/dir2/dir3/loop2 b/testdata/loops/loop_3/dir2/dir3/loop2 new file mode 120000 index 0000000..85babfd --- /dev/null +++ b/testdata/loops/loop_3/dir2/dir3/loop2 @@ -0,0 +1 @@ +../../dir1/ \ No newline at end of file diff --git a/testdata/loops/loop_4/dir1/dir2/loop2 b/testdata/loops/loop_4/dir1/dir2/loop2 new file mode 120000 index 0000000..3fd50ca --- /dev/null +++ b/testdata/loops/loop_4/dir1/dir2/loop2 @@ -0,0 +1 @@ +../dir3 \ No newline at end of file diff --git a/testdata/loops/loop_4/dir1/dir3/dir4/loop1 b/testdata/loops/loop_4/dir1/dir3/dir4/loop1 new file mode 120000 index 0000000..4f388a6 --- /dev/null +++ b/testdata/loops/loop_4/dir1/dir3/dir4/loop1 @@ -0,0 +1 @@ +../../../dir1 \ No newline at end of file diff --git a/testdata/loops/regular_1/dir1/file1 b/testdata/loops/regular_1/dir1/file1 new file mode 100644 index 0000000..e69de29 diff --git a/testdata/loops/regular_1/dir2 b/testdata/loops/regular_1/dir2 new file mode 120000 index 0000000..df490f8 --- /dev/null +++ b/testdata/loops/regular_1/dir2 @@ -0,0 +1 @@ +dir1 \ No newline at end of file diff --git a/testdata/loops/regular_2/dir1/file1 b/testdata/loops/regular_2/dir1/file1 new file mode 100644 index 0000000..e69de29 diff --git a/testdata/loops/regular_2/dir2/dir1 b/testdata/loops/regular_2/dir2/dir1 new file mode 120000 index 0000000..c9f3ab1 --- /dev/null +++ b/testdata/loops/regular_2/dir2/dir1 @@ -0,0 +1 @@ +../dir1 \ No newline at end of file diff --git a/testdata/loops/regular_2/dir2/file2 b/testdata/loops/regular_2/dir2/file2 new file mode 100644 index 0000000..e69de29