Skip to content

Commit

Permalink
Header value null check (#2241)
Browse files Browse the repository at this point in the history
* Throw an exception when adding a header with `null` value
* Add multiple header values properly
* Default headers should allow multiple values
  • Loading branch information
alexeyzimarev authored Jul 11, 2024
1 parent dd52ff6 commit 4ddda24
Show file tree
Hide file tree
Showing 22 changed files with 172 additions and 94 deletions.
14 changes: 7 additions & 7 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
<PackageVersion Include="Newtonsoft.Json" Version="13.0.3" />
<PackageVersion Include="CsvHelper" Version="33.0.1" />
<PackageVersion Include="PolySharp" Version="1.14.1" />
<PackageVersion Include="System.Text.Json" Version="8.0.3" />
<PackageVersion Include="WireMock.Net" Version="1.5.51" />
<PackageVersion Include="System.Text.Json" Version="8.0.4" />
<PackageVersion Include="WireMock.Net" Version="1.5.60" />
<PackageVersion Include="WireMock.Net.FluentAssertions" Version="1.5.51" />
</ItemGroup>
<ItemGroup Label="Compile dependencies">
Expand All @@ -28,22 +28,22 @@
<PackageVersion Include="Nullable" Version="1.3.1" />
<PackageVersion Include="Microsoft.NETFramework.ReferenceAssemblies.net472" Version="1.0.3" />
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="8.0.0" />
<PackageVersion Include="JetBrains.Annotations" Version="2023.3.0" />
<PackageVersion Include="JetBrains.Annotations" Version="2024.2.0" />
</ItemGroup>
<ItemGroup Label="Testing dependencies">
<PackageVersion Include="AutoFixture" Version="4.18.1" />
<PackageVersion Include="coverlet.collector" Version="6.0.2" />
<PackageVersion Include="FluentAssertions" Version="6.12.0" />
<PackageVersion Include="HttpTracer" Version="2.1.1" />
<PackageVersion Include="Microsoft.AspNetCore.TestHost" Version="$(MicrosoftTestHostVer)" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageVersion Include="Moq" Version="4.20.70" />
<PackageVersion Include="Polly" Version="8.3.1" />
<PackageVersion Include="Polly" Version="8.4.1" />
<PackageVersion Include="rest-mock-core" Version="0.7.12" />
<PackageVersion Include="RichardSzalay.MockHttp" Version="7.0.0" />
<PackageVersion Include="System.Net.Http.Json" Version="8.0.0" />
<PackageVersion Include="Xunit.Extensions.Logging" Version="1.1.0" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.5.7" PrivateAssets="All" />
<PackageVersion Include="xunit" Version="2.8.1" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.8.2" PrivateAssets="All" />
<PackageVersion Include="xunit" Version="2.9.0" />
</ItemGroup>
</Project>
25 changes: 14 additions & 11 deletions gen/SourceGenerator/ImmutableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,24 @@ public void Execute(GeneratorExecutionContext context) {

static string GenerateImmutableClass(TypeDeclarationSyntax mutableClass, Compilation compilation) {
var containingNamespace = compilation.GetSemanticModel(mutableClass.SyntaxTree).GetDeclaredSymbol(mutableClass)!.ContainingNamespace;

var namespaceName = containingNamespace.ToDisplayString();

var className = mutableClass.Identifier.Text;

var usings = mutableClass.SyntaxTree.GetCompilationUnitRoot().Usings.Select(u => u.ToString());
var namespaceName = containingNamespace.ToDisplayString();
var className = mutableClass.Identifier.Text;
var usings = mutableClass.SyntaxTree.GetCompilationUnitRoot().Usings.Select(u => u.ToString());

var properties = GetDefinitions(SyntaxKind.SetKeyword)
.Select(prop => $" public {prop.Type} {prop.Identifier.Text} {{ get; }}")
.Select(
prop => {
var xml = prop.GetLeadingTrivia().FirstOrDefault(x => x.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia)).GetStructure();
return $"/// {xml} public {prop.Type} {prop.Identifier.Text} {{ get; }}";
}
)
.ToArray();

var props = GetDefinitions(SyntaxKind.SetKeyword).ToArray();

const string argName = "inner";
var mutableProperties = props
.Select(prop => $" {prop.Identifier.Text} = {argName}.{prop.Identifier.Text};");

var mutableProperties = props.Select(prop => $" {prop.Identifier.Text} = {argName}.{prop.Identifier.Text};");

var constructor = $$"""
public ReadOnly{{className}}({{className}} {{argName}}) {
Expand Down Expand Up @@ -85,7 +87,8 @@ IEnumerable<PropertyDeclarationSyntax> GetDefinitions(SyntaxKind kind)
.OfType<PropertyDeclarationSyntax>()
.Where(
prop =>
prop.AccessorList!.Accessors.Any(accessor => accessor.Keyword.IsKind(kind)) && prop.AttributeLists.All(list => list.Attributes.All(attr => attr.Name.ToString() != "Exclude"))
prop.AccessorList!.Accessors.Any(accessor => accessor.Keyword.IsKind(kind)) &&
prop.AttributeLists.All(list => list.Attributes.All(attr => attr.Name.ToString() != "Exclude"))
);
}
}
}
5 changes: 3 additions & 2 deletions src/RestSharp/Authenticators/JwtAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

namespace RestSharp.Authenticators;
namespace RestSharp.Authenticators;

/// <summary>
/// JSON WEB TOKEN (JWT) Authenticator class.
Expand All @@ -26,7 +26,8 @@ public class JwtAuthenticator(string accessToken) : AuthenticatorBase(GetToken(a
[PublicAPI]
public void SetBearerToken(string accessToken) => Token = GetToken(accessToken);

static string GetToken(string accessToken) => Ensure.NotEmpty(accessToken, nameof(accessToken)).StartsWith("Bearer ") ? accessToken : $"Bearer {accessToken}";
static string GetToken(string accessToken)
=> Ensure.NotEmptyString(accessToken, nameof(accessToken)).StartsWith("Bearer ") ? accessToken : $"Bearer {accessToken}";

protected override ValueTask<Parameter> GetAuthenticationParameter(string accessToken)
=> new(new HeaderParameter(KnownHeaders.Authorization, accessToken));
Expand Down
12 changes: 6 additions & 6 deletions src/RestSharp/Authenticators/OAuth/OAuthWorkflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ sealed class OAuthWorkflow {
/// <param name="parameters">Any existing, non-OAuth query parameters desired in the request</param>
/// <returns></returns>
public OAuthParameters BuildRequestTokenSignature(string method, WebPairCollection parameters) {
Ensure.NotEmpty(ConsumerKey, nameof(ConsumerKey));
Ensure.NotEmptyString(ConsumerKey, nameof(ConsumerKey));

var allParameters = new WebPairCollection();
allParameters.AddRange(parameters);
Expand Down Expand Up @@ -76,8 +76,8 @@ public OAuthParameters BuildRequestTokenSignature(string method, WebPairCollecti
/// <param name="method">The HTTP method for the intended request</param>
/// <param name="parameters">Any existing, non-OAuth query parameters desired in the request</param>
public OAuthParameters BuildAccessTokenSignature(string method, WebPairCollection parameters) {
Ensure.NotEmpty(ConsumerKey, nameof(ConsumerKey));
Ensure.NotEmpty(Token, nameof(Token));
Ensure.NotEmptyString(ConsumerKey, nameof(ConsumerKey));
Ensure.NotEmptyString(Token, nameof(Token));

var allParameters = new WebPairCollection();
allParameters.AddRange(parameters);
Expand Down Expand Up @@ -105,8 +105,8 @@ public OAuthParameters BuildAccessTokenSignature(string method, WebPairCollectio
/// <param name="method">The HTTP method for the intended request</param>
/// <param name="parameters">Any existing, non-OAuth query parameters desired in the request</param>
public OAuthParameters BuildClientAuthAccessTokenSignature(string method, WebPairCollection parameters) {
Ensure.NotEmpty(ConsumerKey, nameof(ConsumerKey));
Ensure.NotEmpty(ClientUsername, nameof(ClientUsername));
Ensure.NotEmptyString(ConsumerKey, nameof(ConsumerKey));
Ensure.NotEmptyString(ClientUsername, nameof(ClientUsername));

var allParameters = new WebPairCollection();
allParameters.AddRange(parameters);
Expand All @@ -127,7 +127,7 @@ public OAuthParameters BuildClientAuthAccessTokenSignature(string method, WebPai
}

public OAuthParameters BuildProtectedResourceSignature(string method, WebPairCollection parameters) {
Ensure.NotEmpty(ConsumerKey, nameof(ConsumerKey));
Ensure.NotEmptyString(ConsumerKey, nameof(ConsumerKey));

var allParameters = new WebPairCollection();
allParameters.AddRange(parameters);
Expand Down
7 changes: 3 additions & 4 deletions src/RestSharp/Ensure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ namespace RestSharp;
static class Ensure {
public static T NotNull<T>(T? value, [InvokerParameterName] string name) => value ?? throw new ArgumentNullException(name);

public static string NotEmpty(string? value, [InvokerParameterName] string name)
=> string.IsNullOrWhiteSpace(value) ? throw new ArgumentNullException(name) : value!;

public static string NotEmptyString(object? value, [InvokerParameterName] string name) {
var s = value as string ?? value?.ToString();
return string.IsNullOrWhiteSpace(s) ? throw new ArgumentNullException(name) : s!;
if (s == null) throw new ArgumentNullException(name);

return string.IsNullOrWhiteSpace(s) ? throw new ArgumentException("Parameter cannot be an empty string", name) : s;
}
}
1 change: 1 addition & 0 deletions src/RestSharp/Options/RestClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ public int MaxTimeout {

/// <summary>
/// Set to true to allow multiple default parameters with the same name. Default is false.
/// This setting doesn't apply to headers as multiple header values for the same key is allowed.
/// </summary>
public bool AllowMultipleDefaultParametersWithSameName { get; set; }

Expand Down
9 changes: 4 additions & 5 deletions src/RestSharp/Parameters/DefaultParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ public sealed class DefaultParameters(ReadOnlyRestClientOptions options) : Param
[MethodImpl(MethodImplOptions.Synchronized)]
public DefaultParameters AddParameter(Parameter parameter) {
if (parameter.Type == ParameterType.RequestBody)
throw new NotSupportedException(
"Cannot set request body using default parameters. Use Request.AddBody() instead."
);
throw new NotSupportedException("Cannot set request body using default parameters. Use Request.AddBody() instead.");

if (!options.AllowMultipleDefaultParametersWithSameName &&
!MultiParameterTypes.Contains(parameter.Type) &&
parameter.Type != ParameterType.HttpHeader &&
!MultiParameterTypes.Contains(parameter.Type) &&
this.Any(x => x.Name == parameter.Name)) {
throw new ArgumentException("A default parameters with the same name has already been added", nameof(parameter));
}
Expand Down Expand Up @@ -70,4 +69,4 @@ public DefaultParameters ReplaceParameter(Parameter parameter)
.AddParameter(parameter);

static readonly ParameterType[] MultiParameterTypes = [ParameterType.QueryString, ParameterType.GetOrPost];
}
}
11 changes: 10 additions & 1 deletion src/RestSharp/Parameters/HeaderParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,14 @@ public record HeaderParameter : Parameter {
/// </summary>
/// <param name="name">Parameter name</param>
/// <param name="value">Parameter value</param>
public HeaderParameter(string? name, string? value) : base(name, value, ParameterType.HttpHeader, false) { }
public HeaderParameter(string name, string value)
: base(
Ensure.NotEmptyString(name, nameof(name)),
Ensure.NotNull(value, nameof(value)),
ParameterType.HttpHeader,
false
) { }

public new string Name => base.Name!;
public new string Value => (string)base.Value!;
}
40 changes: 37 additions & 3 deletions src/RestSharp/Parameters/Parameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.

using System.Diagnostics;

namespace RestSharp;

/// <summary>
/// Parameter container for REST requests
/// </summary>
public abstract record Parameter(string? Name, object? Value, ParameterType Type, bool Encode) {
[DebuggerDisplay($"{{{nameof(DebuggerDisplay)}()}}")]
public abstract record Parameter {
/// <summary>
/// Parameter container for REST requests
/// </summary>
protected Parameter(string? name, object? value, ParameterType type, bool encode) {
Name = name;
Value = value;
Type = type;
Encode = encode;
}

/// <summary>
/// MIME content type of the parameter
/// </summary>
public ContentType ContentType { get; protected init; } = ContentType.Undefined;
public string? Name { get; }
public object? Value { get; }
public ParameterType Type { get; }
public bool Encode { get; }

/// <summary>
/// Return a human-readable representation of this parameter
/// </summary>
/// <returns>String</returns>
public sealed override string ToString() => Value == null ? $"{Name}" : $"{Name}={Value}";
public sealed override string ToString() => Value == null ? $"{Name}" : $"{Name}={ValueString}";

protected virtual string ValueString => Value?.ToString() ?? "null";

public static Parameter CreateParameter(string? name, object? value, ParameterType type, bool encode = true)
// ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
=> type switch {
ParameterType.GetOrPost => new GetOrPostParameter(Ensure.NotEmptyString(name, nameof(name)), value?.ToString(), encode),
ParameterType.UrlSegment => new UrlSegmentParameter(Ensure.NotEmptyString(name, nameof(name)), value?.ToString()!, encode),
ParameterType.HttpHeader => new HeaderParameter(name, value?.ToString()),
ParameterType.HttpHeader => new HeaderParameter(name!, value?.ToString()!),
ParameterType.QueryString => new QueryParameter(Ensure.NotEmptyString(name, nameof(name)), value?.ToString(), encode),
_ => throw new ArgumentOutOfRangeException(nameof(type), type, null)
};

[PublicAPI]
public void Deconstruct(out string? name, out object? value, out ParameterType type, out bool encode) {
name = Name;
value = Value;
type = Type;
encode = Encode;
}

/// <summary>
/// Assists with debugging by displaying in the debugger output
/// </summary>
/// <returns></returns>
[UsedImplicitly]
protected string DebuggerDisplay() => $"{GetType().Name.Replace("Parameter", "")} {ToString()}";
}

public record NamedParameter : Parameter {
Expand Down
20 changes: 11 additions & 9 deletions src/RestSharp/Parameters/ParametersCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,27 @@

namespace RestSharp;

public abstract class ParametersCollection : IReadOnlyCollection<Parameter> {
protected readonly List<Parameter> Parameters = [];
public abstract class ParametersCollection<T> : IReadOnlyCollection<T> where T : Parameter {
protected readonly List<T> Parameters = [];

// public ParametersCollection(IEnumerable<Parameter> parameters) => _parameters.AddRange(parameters);

static readonly Func<Parameter, string?, bool> SearchPredicate = (p, name)
static readonly Func<T, string?, bool> SearchPredicate = (p, name)
=> p.Name != null && p.Name.Equals(name, StringComparison.InvariantCultureIgnoreCase);

public bool Exists(Parameter parameter) => Parameters.Any(p => SearchPredicate(p, parameter.Name) && p.Type == parameter.Type);
public bool Exists(T parameter) => Parameters.Any(p => SearchPredicate(p, parameter.Name) && p.Type == parameter.Type);

public Parameter? TryFind(string parameterName) => Parameters.FirstOrDefault(x => SearchPredicate(x, parameterName));
public T? TryFind(string parameterName) => Parameters.FirstOrDefault(x => SearchPredicate(x, parameterName));

public IEnumerable<Parameter> GetParameters(ParameterType parameterType) => Parameters.Where(x => x.Type == parameterType);

public IEnumerable<T> GetParameters<T>() where T : class => Parameters.OfType<T>();
public IEnumerable<TParameter> GetParameters<TParameter>() where TParameter : class, T => Parameters.OfType<TParameter>();

public IEnumerator<Parameter> GetEnumerator() => Parameters.GetEnumerator();
public IEnumerator<T> GetEnumerator() => Parameters.GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

public int Count => Parameters.Count;
}

public abstract class ParametersCollection : ParametersCollection<Parameter> {
public IEnumerable<Parameter> GetParameters(ParameterType parameterType) => Parameters.Where(x => x.Type == parameterType);
}
2 changes: 1 addition & 1 deletion src/RestSharp/Parameters/RequestParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public ParametersCollection AddParameters(IEnumerable<Parameter> parameters) {
}

/// <summary>
/// Add parameters from another parameters collection
/// Add parameters from another parameter collection
/// </summary>
/// <param name="parameters"></param>
/// <returns></returns>
Expand Down
2 changes: 1 addition & 1 deletion src/RestSharp/Parameters/UrlSegmentParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public partial record UrlSegmentParameter : NamedParameter {
public UrlSegmentParameter(string name, string value, bool encode = true)
: base(
name,
RegexPattern.Replace(Ensure.NotEmpty(value, nameof(value)), "/"),
RegexPattern.Replace(Ensure.NotEmptyString(value, nameof(value)), "/"),
ParameterType.UrlSegment,
encode
) { }
Expand Down
12 changes: 6 additions & 6 deletions src/RestSharp/Request/HttpRequestMessageExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ namespace RestSharp;

static class HttpRequestMessageExtensions {
public static void AddHeaders(this HttpRequestMessage message, RequestHeaders headers) {
var headerParameters = headers.Parameters.Where(x => !KnownHeaders.IsContentHeader(x.Name!));
var headerParameters = headers.Where(x => !KnownHeaders.IsContentHeader(x.Name));

headerParameters.ForEach(x => AddHeader(x, message.Headers));
headerParameters.GroupBy(x => x.Name).ForEach(x => AddHeader(x, message.Headers));
return;

void AddHeader(Parameter parameter, HttpHeaders httpHeaders) {
var parameterStringValue = parameter.Value!.ToString();
void AddHeader(IGrouping<string, HeaderParameter> group, HttpHeaders httpHeaders) {
var parameterStringValues = group.Select(x => x.Value);

httpHeaders.Remove(parameter.Name!);
httpHeaders.TryAddWithoutValidation(parameter.Name!, parameterStringValue);
httpHeaders.Remove(group.Key);
httpHeaders.TryAddWithoutValidation(group.Key, parameterStringValues);
}
}
}
16 changes: 7 additions & 9 deletions src/RestSharp/Request/RequestHeaders.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,17 @@

namespace RestSharp;

class RequestHeaders {
public RequestParameters Parameters { get; } = new();

class RequestHeaders : ParametersCollection<HeaderParameter> {
public RequestHeaders AddHeaders(ParametersCollection parameters) {
Parameters.AddParameters(parameters.GetParameters<HeaderParameter>());
Parameters.AddRange(parameters.GetParameters<HeaderParameter>());
return this;
}

// Add Accept header based on registered deserializers if the caller has set none.
public RequestHeaders AddAcceptHeader(string[] acceptedContentTypes) {
if (Parameters.TryFind(KnownHeaders.Accept) == null) {
if (TryFind(KnownHeaders.Accept) == null) {
var accepts = acceptedContentTypes.JoinToString(", ");
Parameters.AddParameter(new HeaderParameter(KnownHeaders.Accept, accepts));
Parameters.Add(new(KnownHeaders.Accept, accepts));
}

return this;
Expand All @@ -46,13 +44,13 @@ public RequestHeaders AddCookieHeaders(Uri uri, CookieContainer? cookieContainer
if (string.IsNullOrWhiteSpace(cookies)) return this;

var newCookies = SplitHeader(cookies);
var existing = Parameters.GetParameters<HeaderParameter>().FirstOrDefault(x => x.Name == KnownHeaders.Cookie);
var existing = GetParameters<HeaderParameter>().FirstOrDefault(x => x.Name == KnownHeaders.Cookie);

if (existing?.Value != null) {
newCookies = newCookies.Union(SplitHeader(existing.Value.ToString()!));
newCookies = newCookies.Union(SplitHeader(existing.Value!));
}

Parameters.AddParameter(new HeaderParameter(KnownHeaders.Cookie, string.Join("; ", newCookies)));
Parameters.Add(new(KnownHeaders.Cookie, string.Join("; ", newCookies)));

return this;

Expand Down
Loading

0 comments on commit 4ddda24

Please sign in to comment.