Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use code generator for cloning responses #2223

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions gen/SourceGenerator/Extensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

namespace SourceGenerator;

static class Extensions {
public static IEnumerable<ClassDeclarationSyntax> FindClasses(this Compilation compilation, Func<ClassDeclarationSyntax, bool> predicate)
=> compilation.SyntaxTrees
.Select(tree => compilation.GetSemanticModel(tree))
.SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>())
.Where(predicate);

public static IEnumerable<ClassDeclarationSyntax> FindAnnotatedClass(this Compilation compilation, string attributeName, bool strict) {
return compilation.FindClasses(
syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(CheckAttribute))
);

bool CheckAttribute(AttributeSyntax attr) {
var name = attr.Name.ToString();
return strict ? name == attributeName : name.StartsWith(attributeName);
}
}

public static IEnumerable<ITypeSymbol> GetBaseTypesAndThis(this ITypeSymbol type) {
var current = type;

while (current != null) {
yield return current;

current = current.BaseType;
}
}
}
11 changes: 1 addition & 10 deletions gen/SourceGenerator/ImmutableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@
// limitations under the License.
//

using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

namespace SourceGenerator;

[Generator]
Expand All @@ -28,10 +22,7 @@ public void Initialize(GeneratorInitializationContext context) { }
public void Execute(GeneratorExecutionContext context) {
var compilation = context.Compilation;

var mutableClasses = compilation.SyntaxTrees
.Select(tree => compilation.GetSemanticModel(tree))
.SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>())
.Where(syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(attr => attr.Name.ToString() == "GenerateImmutable")));
var mutableClasses = compilation.FindAnnotatedClass("GenerateImmutable", strict: true);

foreach (var mutableClass in mutableClasses) {
var immutableClass = GenerateImmutableClass(mutableClass, compilation);
Expand Down
105 changes: 105 additions & 0 deletions gen/SourceGenerator/InheritedCloneGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

namespace SourceGenerator;

[Generator]
public class InheritedCloneGenerator : ISourceGenerator {
const string AttributeName = "GenerateClone";

public void Initialize(GeneratorInitializationContext context) { }

public void Execute(GeneratorExecutionContext context) {
var compilation = context.Compilation;

var candidates = compilation.FindAnnotatedClass(AttributeName, false);

foreach (var candidate in candidates) {
var semanticModel = compilation.GetSemanticModel(candidate.SyntaxTree);
var genericClassSymbol = semanticModel.GetDeclaredSymbol(candidate);
if (genericClassSymbol == null) continue;

// Get the method name from the attribute Name argument
var attributeData = genericClassSymbol.GetAttributes().FirstOrDefault(a => a.AttributeClass?.Name == $"{AttributeName}Attribute");
var methodName = (string)attributeData.NamedArguments.FirstOrDefault(arg => arg.Key == "Name").Value.Value;

// Get the generic argument type where properties need to be copied from
var attributeSyntax = candidate.AttributeLists
.SelectMany(l => l.Attributes)
.FirstOrDefault(a => a.Name.ToString().StartsWith(AttributeName));
if (attributeSyntax == null) continue; // This should never happen

var typeArgumentSyntax = ((GenericNameSyntax)attributeSyntax.Name).TypeArgumentList.Arguments[0];
var typeSymbol = (INamedTypeSymbol)semanticModel.GetSymbolInfo(typeArgumentSyntax).Symbol;

var code = GenerateMethod(candidate, genericClassSymbol, typeSymbol, methodName);
context.AddSource($"{genericClassSymbol.Name}.Clone.g.cs", SourceText.From(code, Encoding.UTF8));
}
}

static string GenerateMethod(
TypeDeclarationSyntax classToExtendSyntax,
INamedTypeSymbol classToExtendSymbol,
INamedTypeSymbol classToClone,
string methodName
) {
var namespaceName = classToExtendSymbol.ContainingNamespace.ToDisplayString();
var className = classToExtendSyntax.Identifier.Text;
var genericTypeParameters = string.Join(", ", classToExtendSymbol.TypeParameters.Select(tp => tp.Name));
var classDeclaration = classToExtendSymbol.TypeParameters.Length > 0 ? $"{className}<{genericTypeParameters}>" : className;

var all = classToClone.GetBaseTypesAndThis();
var props = all.SelectMany(x => x.GetMembers().OfType<IPropertySymbol>()).ToArray();
var usings = classToExtendSyntax.SyntaxTree.GetCompilationUnitRoot().Usings.Select(u => u.ToString());

var constructorParams = classToExtendSymbol.Constructors.First().Parameters.ToArray();
var constructorArgs = string.Join(", ", constructorParams.Select(p => $"original.{GetPropertyName(p.Name, props)}"));
var constructorParamNames = constructorParams.Select(p => p.Name).ToArray();

var properties = props
// ReSharper disable once PossibleUnintendedLinearSearchInSet
.Where(prop => !constructorParamNames.Contains(prop.Name, StringComparer.OrdinalIgnoreCase) && prop.SetMethod != null)
.Select(prop => $" {prop.Name} = original.{prop.Name},")
.ToArray();

const string template = """
{Usings}

namespace {Namespace};

public partial class {ClassDeclaration} {
public static {ClassDeclaration} {MethodName}({OriginalClassName} original)
=> new {ClassDeclaration}({ConstructorArgs}) {
{Properties}
};
}
""";

var code = template
.Replace("{Usings}", string.Join("\n", usings))
.Replace("{Namespace}", namespaceName)
.Replace("{ClassDeclaration}", classDeclaration)
.Replace("{OriginalClassName}", classToClone.Name)
.Replace("{MethodName}", methodName)
.Replace("{ConstructorArgs}", constructorArgs)
.Replace("{Properties}", string.Join("\n", properties).TrimEnd(','));

return code;

static string GetPropertyName(string parameterName, IPropertySymbol[] properties) {
var property = properties.FirstOrDefault(p => string.Equals(p.Name, parameterName, StringComparison.OrdinalIgnoreCase));
return property?.Name ?? parameterName;
}
}
}
9 changes: 9 additions & 0 deletions gen/SourceGenerator/Properties/launchSettings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"$schema": "http://json.schemastore.org/launchsettings.json",
"profiles": {
"Generators": {
"commandName": "DebugRoslynComponent",
"targetProject": "../../src/RestSharp/RestSharp.csproj"
}
}
}
14 changes: 11 additions & 3 deletions gen/SourceGenerator/SourceGenerator.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" PrivateAssets="All" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" PrivateAssets="All" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" PrivateAssets="All"/>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" PrivateAssets="All"/>
</ItemGroup>
<ItemGroup>
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false"/>
</ItemGroup>
<ItemGroup>

<Using Include="System.Text"/>
<Using Include="Microsoft.CodeAnalysis"/>
<Using Include="Microsoft.CodeAnalysis.CSharp"/>
<Using Include="Microsoft.CodeAnalysis.CSharp.Syntax"/>
<Using Include="Microsoft.CodeAnalysis.Text"/>
</ItemGroup>
</Project>
9 changes: 7 additions & 2 deletions src/RestSharp/Extensions/GenerateImmutableAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
namespace RestSharp.Extensions;

[AttributeUsage(AttributeTargets.Class)]
class GenerateImmutableAttribute : Attribute { }
class GenerateImmutableAttribute : Attribute;

[AttributeUsage(AttributeTargets.Class)]
class GenerateCloneAttribute<T> : Attribute where T : class {
public string? Name { get; set; }
};

[AttributeUsage(AttributeTargets.Property)]
class Exclude : Attribute { }
class Exclude : Attribute;
7 changes: 3 additions & 4 deletions src/RestSharp/Extensions/HttpResponseExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ static class HttpResponseExtensions {
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}");
#endif

public static string GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
public static async Task<string> GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
var encodingString = response.Content.Headers.ContentType?.CharSet;
var encoding = encodingString != null ? TryGetEncoding(encodingString) : clientEncoding;

using var reader = new StreamReader(new MemoryStream(bytes), encoding);
return reader.ReadToEnd();

return await reader.ReadToEndAsync();
Encoding TryGetEncoding(string es) {
try {
return Encoding.GetEncoding(es);
Expand Down Expand Up @@ -69,4 +68,4 @@ Encoding TryGetEncoding(string es) {
return original == null ? null : streamWriter(original);
}
}
}
}
30 changes: 4 additions & 26 deletions src/RestSharp/Response/RestResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

using System.Diagnostics;
using System.Net;
using System.Text;
using RestSharp.Extensions;

Expand All @@ -25,34 +24,13 @@ namespace RestSharp;
/// Container for data sent back from API including deserialized data
/// </summary>
/// <typeparam name="T">Type of data to deserialize to</typeparam>
[DebuggerDisplay("{" + nameof(DebuggerDisplay) + "()}")]
public class RestResponse<T>(RestRequest request) : RestResponse(request) {
[GenerateClone<RestResponse>(Name = "FromResponse")]
[DebuggerDisplay($"{{{nameof(DebuggerDisplay)}()}}")]
public partial class RestResponse<T>(RestRequest request) : RestResponse(request) {
/// <summary>
/// Deserialized entity data
/// </summary>
public T? Data { get; set; }

public static RestResponse<T> FromResponse(RestResponse response)
=> new(response.Request) {
Content = response.Content,
ContentEncoding = response.ContentEncoding,
ContentHeaders = response.ContentHeaders,
ContentLength = response.ContentLength,
ContentType = response.ContentType,
Cookies = response.Cookies,
ErrorException = response.ErrorException,
ErrorMessage = response.ErrorMessage,
Headers = response.Headers,
IsSuccessStatusCode = response.IsSuccessStatusCode,
RawBytes = response.RawBytes,
ResponseStatus = response.ResponseStatus,
ResponseUri = response.ResponseUri,
RootElement = response.RootElement,
Server = response.Server,
StatusCode = response.StatusCode,
StatusDescription = response.StatusDescription,
Version = response.Version
};
}

/// <summary>
Expand All @@ -78,7 +56,7 @@ async Task<RestResponse> GetDefaultResponse() {
#endif

var bytes = stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
var content = bytes == null ? null : httpResponse.GetResponseString(bytes, encoding);
var content = bytes == null ? null : await httpResponse.GetResponseString(bytes, encoding);

return new RestResponse(request) {
Content = content,
Expand Down
7 changes: 4 additions & 3 deletions src/RestSharp/Response/RestResponseBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

using System.Diagnostics;
using System.Net;
// ReSharper disable PropertyCanBeMadeInitOnly.Global

namespace RestSharp;

Expand Down Expand Up @@ -65,12 +65,13 @@ protected RestResponseBase(RestRequest request) {
public HttpStatusCode StatusCode { get; set; }

/// <summary>
/// Whether or not the HTTP response status code indicates success
/// Whether the HTTP response status code indicates success
/// </summary>
public bool IsSuccessStatusCode { get; set; }

/// <summary>
/// Whether or not the HTTP response status code indicates success and no other error occurred (deserialization, timeout, ...)
/// Whether the HTTP response status code indicates success and no other error occurred
/// (deserialization, timeout, ...)
/// </summary>
public bool IsSuccessful => IsSuccessStatusCode && ResponseStatus == ResponseStatus.Completed;

Expand Down
2 changes: 1 addition & 1 deletion src/RestSharp/RestClient.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task<RestResponse> ExecuteAsync(RestRequest request, CancellationTo
/// <inheritdoc />
[PublicAPI]
public async Task<Stream?> DownloadStreamAsync(RestRequest request, CancellationToken cancellationToken = default) {
// Make sure we only read the headers so we can stream the content body efficiently
// Make sure we only read the headers, so we can stream the content body efficiently
request.CompletionOption = HttpCompletionOption.ResponseHeadersRead;
var response = await ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);

Expand Down
Loading