Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip directory symlink recursion on TarFile archive creation #74376

Merged
merged 3 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions src/libraries/System.Formats.Tar/src/System/Formats/Tar/TarFile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Enumeration;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -278,12 +279,20 @@ private static void CreateFromDirectoryInternal(string sourceDirectoryName, Stre
DirectoryInfo di = new(sourceDirectoryName);
string basePath = GetBasePathForCreateFromDirectory(di, includeBaseDirectory);

bool skipBaseDirRecursion = false;
if (includeBaseDirectory)
{
writer.WriteEntry(di.FullName, GetEntryNameForBaseDirectory(di.Name));
skipBaseDirRecursion = (di.Attributes & FileAttributes.ReparsePoint) != 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need skipBaseDirRecursion.

Suggested change
skipBaseDirRecursion = (di.Attributes & FileAttributes.ReparsePoint) != 0;
if ((di.Attributes & FileAttributes.ReparsePoint) != 0)
{
return;
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both suggestions can be fixed in my next PR.

}

foreach (FileSystemInfo file in di.EnumerateFileSystemInfos("*", SearchOption.AllDirectories))
if (skipBaseDirRecursion)
{
// The base directory is a symlink, do not recurse into it
return;
}

foreach (FileSystemInfo file in GetFileSystemEnumerationForCreation(sourceDirectoryName))
{
writer.WriteEntry(file.FullName, GetEntryNameForFileSystemInfo(file, basePath.Length));
}
Expand Down Expand Up @@ -325,18 +334,44 @@ private static async Task CreateFromDirectoryInternalAsync(string sourceDirector
DirectoryInfo di = new(sourceDirectoryName);
string basePath = GetBasePathForCreateFromDirectory(di, includeBaseDirectory);

bool skipBaseDirRecursion = false;
if (includeBaseDirectory)
{
await writer.WriteEntryAsync(di.FullName, GetEntryNameForBaseDirectory(di.Name), cancellationToken).ConfigureAwait(false);
skipBaseDirRecursion = (di.Attributes & FileAttributes.ReparsePoint) != 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}

foreach (FileSystemInfo file in di.EnumerateFileSystemInfos("*", SearchOption.AllDirectories))
if (skipBaseDirRecursion)
{
// The base directory is a symlink, do not recurse into it
return;
}

foreach (FileSystemInfo file in GetFileSystemEnumerationForCreation(sourceDirectoryName))
{
await writer.WriteEntryAsync(file.FullName, GetEntryNameForFileSystemInfo(file, basePath.Length), cancellationToken).ConfigureAwait(false);
}
}
}

// Generates a recursive enumeration of the filesystem entries inside the specified source directory, while
// making sure that directory symlinks do not get recursed.
private static IEnumerable<FileSystemInfo> GetFileSystemEnumerationForCreation(string sourceDirectoryName)
{
return new FileSystemEnumerable<FileSystemInfo>(
directory: sourceDirectoryName,
transform: (ref FileSystemEntry entry) => entry.ToFileSystemInfo(),
options: new EnumerationOptions()
{
RecurseSubdirectories = true
})
{
ShouldRecursePredicate = IsNotADirectorySymlink
};

static bool IsNotADirectorySymlink(ref FileSystemEntry entry) => entry.IsDirectory && (entry.Attributes & FileAttributes.ReparsePoint) == 0;
}

// Determines what should be the base path for all the entries when creating an archive.
private static string GetBasePathForCreateFromDirectory(DirectoryInfo di, bool includeBaseDirectory) =>
includeBaseDirectory && di.Parent != null ? di.Parent.FullName : di.FullName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,65 @@ public void IncludeAllSegmentsOfPath(bool includeBaseDirectory)

Assert.Null(reader.GetNextEntry());
}

[Fact]
public void SkipRecursionIntoDirectorySymlinks()
{
using TempDirectory root = new TempDirectory();

string destinationArchive = Path.Join(root.Path, "destination.tar");

string externalDirectory = Path.Join(root.Path, "externalDirectory");
Directory.CreateDirectory(externalDirectory);

File.Create(Path.Join(externalDirectory, "file.txt")).Dispose();

string sourceDirectoryName = Path.Join(root.Path, "baseDirectory");
Directory.CreateDirectory(sourceDirectoryName);

string subDirectory = Path.Join(sourceDirectoryName, "subDirectory");
Directory.CreateSymbolicLink(subDirectory, externalDirectory); // Should not recurse here

TarFile.CreateFromDirectory(sourceDirectoryName, destinationArchive, includeBaseDirectory: false);

using FileStream archiveStream = File.OpenRead(destinationArchive);
using TarReader reader = new(archiveStream, leaveOpen: false);

TarEntry entry = reader.GetNextEntry();
Assert.NotNull(entry);
Assert.Equal("subDirectory/", entry.Name);
Assert.Equal(TarEntryType.SymbolicLink, entry.EntryType);

Assert.Null(reader.GetNextEntry()); // file.txt should not be found
}

[Fact]
public void SkipRecursionIntoBaseDirectorySymlink()
{
using TempDirectory root = new TempDirectory();

string destinationArchive = Path.Join(root.Path, "destination.tar");

string externalDirectory = Path.Join(root.Path, "externalDirectory");
Directory.CreateDirectory(externalDirectory);

string subDirectory = Path.Join(externalDirectory, "subDirectory");
Directory.CreateDirectory(subDirectory);

string sourceDirectoryName = Path.Join(root.Path, "baseDirectory");
Directory.CreateSymbolicLink(sourceDirectoryName, externalDirectory);

TarFile.CreateFromDirectory(sourceDirectoryName, destinationArchive, includeBaseDirectory: true); // Base directory is a symlink, do not recurse
jozkee marked this conversation as resolved.
Show resolved Hide resolved

using FileStream archiveStream = File.OpenRead(destinationArchive);
using TarReader reader = new(archiveStream, leaveOpen: false);

TarEntry entry = reader.GetNextEntry();
Assert.NotNull(entry);
Assert.Equal("baseDirectory/", entry.Name);
Assert.Equal(TarEntryType.SymbolicLink, entry.EntryType);

Assert.Null(reader.GetNextEntry());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,5 +237,65 @@ public async Task IncludeAllSegmentsOfPath_Async(bool includeBaseDirectory)
}
}
}

[Fact]
public async Task SkipRecursionIntoDirectorySymlinksAsync()
{
using TempDirectory root = new TempDirectory();

string destinationArchive = Path.Join(root.Path, "destination.tar");

string externalDirectory = Path.Join(root.Path, "externalDirectory");
Directory.CreateDirectory(externalDirectory);

File.Create(Path.Join(externalDirectory, "file.txt")).Dispose();

string sourceDirectoryName = Path.Join(root.Path, "baseDirectory");
Directory.CreateDirectory(sourceDirectoryName);

string subDirectory = Path.Join(sourceDirectoryName, "subDirectory");
Directory.CreateSymbolicLink(subDirectory, externalDirectory); // Should not recurse here

await TarFile.CreateFromDirectoryAsync(sourceDirectoryName, destinationArchive, includeBaseDirectory: false);

await using FileStream archiveStream = File.OpenRead(destinationArchive);
await using TarReader reader = new(archiveStream, leaveOpen: false);

TarEntry entry = await reader.GetNextEntryAsync();
Assert.NotNull(entry);
Assert.Equal("subDirectory/", entry.Name);
Assert.Equal(TarEntryType.SymbolicLink, entry.EntryType);

Assert.Null(await reader.GetNextEntryAsync()); // file.txt should not be found
}

[Fact]
public async Task SkipRecursionIntoBaseDirectorySymlinkAsync()
{
using TempDirectory root = new TempDirectory();

string destinationArchive = Path.Join(root.Path, "destination.tar");

string externalDirectory = Path.Join(root.Path, "externalDirectory");
Directory.CreateDirectory(externalDirectory);

string subDirectory = Path.Join(externalDirectory, "subDirectory");
Directory.CreateDirectory(subDirectory);

string sourceDirectoryName = Path.Join(root.Path, "baseDirectory");
Directory.CreateSymbolicLink(sourceDirectoryName, externalDirectory);

await TarFile.CreateFromDirectoryAsync(sourceDirectoryName, destinationArchive, includeBaseDirectory: true); // Base directory is a symlink, do not recurse

await using FileStream archiveStream = File.OpenRead(destinationArchive);
await using TarReader reader = new(archiveStream, leaveOpen: false);

TarEntry entry = await reader.GetNextEntryAsync();
Assert.NotNull(entry);
Assert.Equal("baseDirectory/", entry.Name);
Assert.Equal(TarEntryType.SymbolicLink, entry.EntryType);

Assert.Null(await reader.GetNextEntryAsync()); // subDirectory should not be found
}
}
}