Skip to content

Detection of recursion loops in ReadDirRecursive* methods #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ package paths
import (
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -418,14 +417,14 @@ func (p *Path) Chtimes(atime, mtime time.Time) error {

// ReadFile reads the file pointed to by the current Path and returns its contents.
func (p *Path) ReadFile() ([]byte, error) {
return ioutil.ReadFile(p.path)
return os.ReadFile(p.path)
}

// WriteFile writes data to the file pointed to by the current Path. If the
// file does not exist, WriteFile creates it with mode 0644; otherwise
// WriteFile truncates it before writing.
func (p *Path) WriteFile(data []byte) error {
return ioutil.WriteFile(p.path, data, os.FileMode(0644))
return os.WriteFile(p.path, data, os.FileMode(0644))
}

// WriteToTempFile writes data to a newly generated temporary file.
Expand Down
89 changes: 42 additions & 47 deletions readdir.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
package paths

import (
"io/ioutil"
"errors"
"os"
"strings"
)

Expand All @@ -41,7 +42,7 @@ type ReadDirFilter func(file *Path) bool
// ReadDir returns a PathList containing the content of the directory
// pointed to by the current Path. The resulting list is filtered by chaining the given filters.
func (p *Path) ReadDir(filters ...ReadDirFilter) (PathList, error) {
infos, err := ioutil.ReadDir(p.path)
infos, err := os.ReadDir(p.path)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -69,27 +70,7 @@ func (p *Path) ReadDir(filters ...ReadDirFilter) (PathList, error) {
// ReadDirRecursive returns a PathList containing the content of the directory
// pointed to by the current Path and of all its subdirectories
func (p *Path) ReadDirRecursive() (PathList, error) {
infos, err := ioutil.ReadDir(p.path)
if err != nil {
return nil, err
}
paths := PathList{}
for _, info := range infos {
path := p.Join(info.Name())
paths.Add(path)

if isDir, err := path.IsDirCheck(); err != nil {
return nil, err
} else if isDir {
subPaths, err := path.ReadDirRecursive()
if err != nil {
return nil, err
}
paths.AddAll(subPaths)
}

}
return paths, nil
return p.ReadDirRecursiveFiltered(nil)
}

// ReadDirRecursiveFiltered returns a PathList containing the content of the directory
Expand All @@ -101,41 +82,55 @@ func (p *Path) ReadDirRecursive() (PathList, error) {
// - `filters` are the filters that are checked to determine if the entry should be
// added to the resulting PathList
func (p *Path) ReadDirRecursiveFiltered(recursionFilter ReadDirFilter, filters ...ReadDirFilter) (PathList, error) {
infos, err := ioutil.ReadDir(p.path)
if err != nil {
return nil, err
}
var search func(*Path) (PathList, error)

accept := func(p *Path) bool {
for _, filter := range filters {
if !filter(p) {
return false
}
explored := map[string]bool{}
search = func(currPath *Path) (PathList, error) {
canonical := currPath.Canonical().path
if explored[canonical] {
return nil, errors.New("directories symlink loop detected")
}
return true
}
explored[canonical] = true
defer delete(explored, canonical)

paths := PathList{}
for _, info := range infos {
path := p.Join(info.Name())
infos, err := os.ReadDir(currPath.path)
if err != nil {
return nil, err
}

if accept(path) {
paths.Add(path)
accept := func(p *Path) bool {
for _, filter := range filters {
if !filter(p) {
return false
}
}
return true
}

if recursionFilter == nil || recursionFilter(path) {
if isDir, err := path.IsDirCheck(); err != nil {
return nil, err
} else if isDir {
subPaths, err := path.ReadDirRecursiveFiltered(recursionFilter, filters...)
if err != nil {
paths := PathList{}
for _, info := range infos {
path := currPath.Join(info.Name())

if accept(path) {
paths.Add(path)
}

if recursionFilter == nil || recursionFilter(path) {
if isDir, err := path.IsDirCheck(); err != nil {
return nil, err
} else if isDir {
subPaths, err := search(path)
if err != nil {
return nil, err
}
paths.AddAll(subPaths)
}
paths.AddAll(subPaths)
}
}
return paths, nil
}
return paths, nil

return search(p)
}

// FilterDirectories is a ReadDirFilter that accepts only directories
Expand Down
69 changes: 69 additions & 0 deletions readdir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"fmt"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -245,3 +246,71 @@ func TestReadDirRecursiveFiltered(t *testing.T) {
pathEqualsTo(t, "testdata/fileset/test.txt", l[7])
pathEqualsTo(t, "testdata/fileset/test.txt.gz", l[8])
}

// TestReadDirRecursiveLoopDetection checks that ReadDirRecursiveFiltered
// reports an error for the loop_* fixtures under testdata/loops instead of
// recursing until the path becomes too long, and that the regular_* fixtures
// are still listed in full without error.
func TestReadDirRecursiveLoopDetection(t *testing.T) {
	loopsPath := New("testdata", "loops")

	// unboundedReadDir walks testdata/loops/<testdir> in a separate goroutine
	// and fails the test if the walk does not complete within 5 seconds.
	unboundedReadDir := func(testdir string) (PathList, error) {
		// This is required to unbound the recursion, otherwise it will stop
		// when the path becomes too long due to the symlink loop: this is not
		// what we want, we are looking for an early detection of the loop.
		skipBrokenLinks := func(p *Path) bool {
			_, err := p.Stat()
			return err == nil
		}

		var files PathList
		var err error
		// done is closed rather than sent on: closing never blocks, so the
		// walker goroutine can always terminate even if the poller below has
		// already given up, and every later receive succeeds immediately.
		done := make(chan struct{})
		go func() {
			files, err = loopsPath.Join(testdir).ReadDirRecursiveFiltered(
				skipBrokenLinks,
			)
			close(done)
		}()
		require.Eventually(
			t,
			func() bool {
				select {
				case <-done:
					return true
				default:
					return false
				}
			},
			5*time.Second,
			10*time.Millisecond,
			"ReadDirRecursiveFiltered did not return in time for %s", testdir,
		)
		// Reading files and err here is safe: the close of done happens
		// after they are assigned and before the receive that ended the wait.
		return files, err
	}

	for _, dir := range []string{"loop_1", "loop_2", "loop_3", "loop_4"} {
		l, err := unboundedReadDir(dir)
		require.EqualError(t, err, "directories symlink loop detected", "loop not detected in %s", dir)
		require.Nil(t, l)
	}

	{
		l, err := unboundedReadDir("regular_1")
		require.NoError(t, err)
		require.Len(t, l, 4)
		l.Sort()
		pathEqualsTo(t, "testdata/loops/regular_1/dir1", l[0])
		pathEqualsTo(t, "testdata/loops/regular_1/dir1/file1", l[1])
		pathEqualsTo(t, "testdata/loops/regular_1/dir2", l[2])
		pathEqualsTo(t, "testdata/loops/regular_1/dir2/file1", l[3])
	}

	{
		l, err := unboundedReadDir("regular_2")
		require.NoError(t, err)
		require.Len(t, l, 6)
		l.Sort()
		pathEqualsTo(t, "testdata/loops/regular_2/dir1", l[0])
		pathEqualsTo(t, "testdata/loops/regular_2/dir1/file1", l[1])
		pathEqualsTo(t, "testdata/loops/regular_2/dir2", l[2])
		pathEqualsTo(t, "testdata/loops/regular_2/dir2/dir1", l[3])
		pathEqualsTo(t, "testdata/loops/regular_2/dir2/dir1/file1", l[4])
		pathEqualsTo(t, "testdata/loops/regular_2/dir2/file2", l[5])
	}
}
1 change: 1 addition & 0 deletions testdata/loops/loop_1/dir1/loop
1 change: 1 addition & 0 deletions testdata/loops/loop_2/dir1/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_2/dir2/loop1
1 change: 1 addition & 0 deletions testdata/loops/loop_3/dir1/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_3/dir2/dir3/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_4/dir1/dir2/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_4/dir1/dir3/dir4/loop1
Empty file.
1 change: 1 addition & 0 deletions testdata/loops/regular_1/dir2
Empty file.
1 change: 1 addition & 0 deletions testdata/loops/regular_2/dir2/dir1
Empty file.