// PostSharp.Samples / PostSharp.Samples.StoredProcedure / StoredProcedureAttribute.cs
using AutoMapper;
using PostSharp.Aspects;
using PostSharp.Extensibility;
using PostSharp.Serialization;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Reflection;
using System.Threading.Tasks;
 
namespace PostSharp.Samples.StoredProcedure
{
  [PSerializable]
  [MulticastAttributeUsage( MulticastTargets.Method, TargetMemberAttributes =MulticastAttributes.Instance)]
  internal class StoredProcedureAttribute : MethodInterceptionAspect
  {
    MethodInfo mapDataReaderMethod;
    MethodInfo mapDataReaderAsyncMethod;
 
    /// <summary>
    /// Caches the reflection handles of the private generic row-mapping helpers
    /// (<c>MapDataReader</c> and <c>MapDataReaderAsync</c>) so that the invocation handlers
    /// can close them over the element type of the intercepted method's result.
    /// </summary>
    /// <param name="method">The method the aspect is applied to. Not used by this implementation.</param>
    public override void RuntimeInitialize(MethodBase method)
    {
      // Resolve the helpers on the declaring type instead of this.GetType(): reflection does not
      // return private members declared on a base class, so the lookup would yield null if this
      // aspect were ever subclassed.
      const BindingFlags helperFlags = BindingFlags.NonPublic | BindingFlags.Static;
      var declaringType = typeof(StoredProcedureAttribute);

      this.mapDataReaderMethod = declaringType.GetMethod(nameof(MapDataReader), helperFlags);
      this.mapDataReaderAsyncMethod = declaringType.GetMethod(nameof(MapDataReaderAsync), helperFlags);
    }
 
    /// <summary>
    /// Checks whether the aspect can be applied to <paramref name="method"/> and reports
    /// build messages for unsupported declarations.
    /// </summary>
    /// <param name="method">The candidate target method.</param>
    /// <returns>
    /// <c>true</c> when all checks pass; <c>false</c> when the method's implementation flags are
    /// not exactly <c>InternalCall</c> (ignored without a message) or when one of the errors
    /// SP001, SP002 or SP003 has been written.
    /// </returns>
    public override bool CompileTimeValidate(MethodBase method)
    {
      if (method.MethodImplementationFlags != MethodImplAttributes.InternalCall)
      {
        // We silently ignore any non-extern method.
        return false;
      }

      var methodInfo = (MethodInfo) method;

      // Validate the base type.
      if (!typeof(BaseDbApi).IsAssignableFrom(methodInfo.DeclaringType))
      {
        Message.Write(method, SeverityType.Error, "SP003", "Cannot apply the [StoredProcedure] aspect to {0} because the method is not declared in a type derived from BaseDbApi.", method);
        return false;
      }

      var success = true;

      // Validate the parameter types. All offending parameters are reported, not just the first.
      foreach (var parameter in methodInfo.GetParameters())
      {
        if (MapType(parameter.ParameterType) == null)
        {
          Message.Write(parameter, SeverityType.Error, "SP001", "The type of parameter {0} cannot be mapped to a database type.", parameter);
          success = false;
        }
      }

      // Validate the return type. For Task-based signatures the awaited result is what matters:
      // Task counts as void and Task<T> counts as T.
      var returnType = methodInfo.ReturnType;

      if (returnType == typeof(Task))
      {
        returnType = typeof(void);
      }
      else if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>))
      {
        returnType = returnType.GetGenericArguments()[0];
      }

      // The checks below must look at the unwrapped type; testing methodInfo.ReturnType instead
      // would reject every Task<IEnumerable<T>> and Task<IAsyncEnumerable<T>> signature.
      var returnTypeDefinition = returnType.IsGenericType ? returnType.GetGenericTypeDefinition() : null;

      if (returnType != typeof(void) &&
          returnTypeDefinition != typeof(IEnumerable<>) &&
          returnTypeDefinition != typeof(IAsyncEnumerable<>))
      {
        Message.Write(methodInfo, SeverityType.Error, "SP002", "The return type of method {0} must be void, IEnumerable<> or IAsyncEnumerable<>.", methodInfo);
        success = false;
      }

      return success;
    }
 
    /// <summary>
    /// Synchronous interception handler: builds the stored-procedure command for the intercepted
    /// method and executes it according to the method's return type.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item><description><c>void</c>: the command is executed as a non-query and then disposed.</description></item>
    /// <item><description><c>IEnumerable&lt;T&gt;</c>: a reader is opened and the return value is set to the sequence produced by <c>MapDataReader</c>.</description></item>
    /// <item><description><c>IAsyncEnumerable&lt;T&gt;</c>: a reader is opened and the return value is set to the sequence produced by <c>MapDataReaderAsync</c>.</description></item>
    /// </list>
    /// Any other return type leaves <c>args.ReturnValue</c> untouched.
    /// </remarks>
    /// <param name="args">Interception arguments giving the target method, instance and argument values.</param>
    public override void OnInvoke(MethodInterceptionArgs args)
    {
      var method = (MethodInfo) args.Method;
      var instance = (BaseDbApi) args.Instance;
      var command = CreateCommand(method, instance, args.Arguments);

      var returnType = method.ReturnType;

      if (returnType == typeof(void))
      {
        // No reader outlives this call, so the command can be released as soon as it has run.
        using (command)
        {
          command.ExecuteNonQuery();
        }
      }
      else if (returnType.IsGenericType)
      {
        var returnTypeDefinition = returnType.GetGenericTypeDefinition();

        // Pick the row-mapping helper that produces the declared kind of sequence.
        // IAsyncEnumerable<T> is accepted by the compile-time validation, so it must be served
        // here as well; otherwise the caller would receive no return value.
        MethodInfo mapMethod = null;
        if (returnTypeDefinition == typeof(IEnumerable<>))
        {
          mapMethod = this.mapDataReaderMethod;
        }
        else if (returnTypeDefinition == typeof(IAsyncEnumerable<>))
        {
          mapMethod = this.mapDataReaderAsyncMethod;
        }

        if (mapMethod != null)
        {
          // The mapping helper closes the reader once enumeration finishes.
          var reader = command.ExecuteReader();
          var elementType = returnType.GetGenericArguments()[0];
          args.ReturnValue = mapMethod.MakeGenericMethod(elementType).Invoke(null, new object[] { reader, instance.Mapper });
        }
      }

      // TODO: Map out parameters back to method arguments.

    }
 
    /// <summary>
    /// Asynchronous interception handler: builds the stored-procedure command for the intercepted
    /// method and executes it according to the method's <c>Task</c>-based return type.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item><description><c>Task</c>: the command is executed as a non-query and then disposed.</description></item>
    /// <item><description><c>Task&lt;IEnumerable&lt;T&gt;&gt;</c>: a reader is opened asynchronously and the return value is set to the sequence produced by <c>MapDataReader</c>.</description></item>
    /// <item><description><c>Task&lt;IAsyncEnumerable&lt;T&gt;&gt;</c>: a reader is opened asynchronously and the return value is set to the sequence produced by <c>MapDataReaderAsync</c>.</description></item>
    /// </list>
    /// Any other <c>Task&lt;T&gt;</c> leaves <c>args.ReturnValue</c> untouched.
    /// </remarks>
    /// <param name="args">Interception arguments giving the target method, instance and argument values.</param>
    /// <returns>A task that completes when the command has been executed (and, for result sets, the reader opened).</returns>
    public override async Task OnInvokeAsync(MethodInterceptionArgs args)
    {
      var method = (MethodInfo) args.Method;
      var instance = (BaseDbApi) args.Instance;
      var command = CreateCommand(method, instance, args.Arguments);

      var returnType = method.ReturnType;

      if (returnType == typeof(Task))
      {
        // No reader outlives this call, so the command can be released as soon as it has run.
        using (command)
        {
          await command.ExecuteNonQueryAsync();
        }
      }
      else
      {
        // Must be a Task<T>; T decides how the result set is exposed.
        var resultType = returnType.GetGenericArguments()[0];
        var resultTypeDefinition = resultType.IsGenericType ? resultType.GetGenericTypeDefinition() : null;

        // Pick the row-mapping helper that produces the declared kind of sequence.
        MethodInfo mapMethod = null;
        if (resultTypeDefinition == typeof(IEnumerable<>))
        {
          mapMethod = this.mapDataReaderMethod;
        }
        else if (resultTypeDefinition == typeof(IAsyncEnumerable<>))
        {
          mapMethod = this.mapDataReaderAsyncMethod;
        }

        if (mapMethod != null)
        {
          // The mapping helper closes the reader once enumeration finishes.
          var reader = await command.ExecuteReaderAsync();
          var elementType = resultType.GetGenericArguments()[0];
          args.ReturnValue = mapMethod.MakeGenericMethod(elementType).Invoke(null, new object[] { reader, instance.Mapper });
        }
      }

      // TODO: Map out parameters back to method arguments.

    }
 
    /// <summary>
    /// Lazily converts each row of <paramref name="reader"/> into a <typeparamref name="T"/>
    /// through <paramref name="mapper"/>.
    /// </summary>
    /// <typeparam name="T">Type each row is mapped to.</typeparam>
    /// <param name="reader">Open data reader; it is closed when enumeration completes, fails or is disposed early.</param>
    /// <param name="mapper">Mapper used to turn the current <see cref="IDataRecord"/> into a <typeparamref name="T"/>.</param>
    /// <returns>A deferred sequence yielding one mapped item per row.</returns>
    private static IEnumerable<T> MapDataReader<T>(SqlDataReader reader, IMapper mapper)
    {
      try
      {
        while (true)
        {
          // Stop as soon as the reader has no further row.
          if (!reader.Read())
          {
            yield break;
          }

          var item = mapper.Map<IDataRecord, T>(reader);
          yield return item;
        }
      }
      finally
      {
        // Runs on normal completion, on exception, and when the enumerator is disposed.
        reader.Close();
      }
    }
 
    /// <summary>
    /// Asynchronously streams the rows of <paramref name="reader"/>, converting each one into a
    /// <typeparamref name="T"/> through <paramref name="mapper"/>.
    /// </summary>
    /// <typeparam name="T">Type each row is mapped to.</typeparam>
    /// <param name="reader">Open data reader; it is closed when enumeration completes, fails or is disposed early.</param>
    /// <param name="mapper">Mapper used to turn the current <see cref="IDataRecord"/> into a <typeparamref name="T"/>.</param>
    /// <returns>A deferred asynchronous sequence yielding one mapped item per row.</returns>
    private static async IAsyncEnumerable<T> MapDataReaderAsync<T>(SqlDataReader reader, IMapper mapper)
    {
      try
      {
        while (true)
        {
          // Advance to the next row without blocking; stop when the result set is exhausted.
          var hasRow = await reader.ReadAsync();
          if (!hasRow)
          {
            yield break;
          }

          var item = mapper.Map<IDataRecord, T>(reader);
          yield return item;
        }
      }
      finally
      {
        // Runs on normal completion, on exception, and when the enumerator is disposed.
        reader.Close();
      }
    }
 
    /// <summary>
    /// Builds the <see cref="SqlCommand"/> that calls the stored procedure corresponding to
    /// <paramref name="method"/>.
    /// </summary>
    /// <remarks>
    /// The procedure name is the method name with a trailing "Async" removed. Every method
    /// parameter becomes a command parameter named "@" + parameter name: out parameters are
    /// declared as output parameters, all others are passed by value.
    /// </remarks>
    /// <param name="method">The intercepted method.</param>
    /// <param name="instance">The API object supplying the connection and transaction.</param>
    /// <param name="arguments">The argument values of the current invocation, indexed by parameter position.</param>
    /// <returns>A command bound to the connection and transaction of <paramref name="instance"/>; it has not been executed.</returns>
    /// <exception cref="InvalidOperationException">An out parameter has a type that <c>MapType</c> cannot map.</exception>
    private static SqlCommand CreateCommand(MethodInfo method, BaseDbApi instance, Arguments arguments)
    {
      const string asyncSuffix = "Async";

      // Identifiers are compared ordinally so the result does not depend on the current culture.
      var methodName = method.Name;
      if (methodName.EndsWith(asyncSuffix, StringComparison.Ordinal))
      {
        methodName = methodName.Substring(0, methodName.Length - asyncSuffix.Length);
      }

      var command = new SqlCommand(methodName)
      {
        Connection = instance.Connection,
        CommandType = CommandType.StoredProcedure,
        Transaction = instance.Transaction
      };

      foreach (var methodParameter in method.GetParameters())
      {
        var parameterName = "@" + methodParameter.Name;

        if (!methodParameter.IsOut)
        {
          // A null value would make the provider omit the parameter altogether;
          // SQL NULL has to be passed as DBNull.Value.
          command.Parameters.AddWithValue(parameterName, arguments[methodParameter.Position] ?? DBNull.Value);
        }
        else
        {
          // Out parameters are reflected as by-ref types (T&); map the underlying T.
          var parameterType = methodParameter.ParameterType;
          if (parameterType.IsByRef)
          {
            parameterType = parameterType.GetElementType();
          }

          var commandParameter = command.CreateParameter();
          commandParameter.ParameterName = parameterName;
          commandParameter.SqlDbType = MapType(parameterType).Value;
          commandParameter.Direction = ParameterDirection.Output;

          command.Parameters.Add(commandParameter);
        }
      }

      return command;
    }
 
    /// <summary>
    /// Maps a CLR type to the <see cref="SqlDbType"/> used for the corresponding command parameter.
    /// </summary>
    /// <remarks>
    /// By-ref types (as reflected for out and ref parameters) and <see cref="Nullable{T}"/> are
    /// unwrapped first, so <c>int</c>, <c>int?</c> and <c>out int</c> all map to
    /// <see cref="SqlDbType.Int"/>.
    /// </remarks>
    /// <param name="type">The CLR type to map; may be <c>null</c>.</param>
    /// <returns>
    /// The matching <see cref="SqlDbType"/>, or <c>null</c> when <paramref name="type"/> is
    /// <c>null</c> or has no supported mapping.
    /// </returns>
    private static SqlDbType? MapType(Type type)
    {
      if (type == null)
      {
        return null;
      }

      // Out/ref parameters are reflected as T&, whose type code is Object; look at T itself.
      if (type.IsByRef)
      {
        type = type.GetElementType();
      }

      // Nullable<T> also reports TypeCode.Object; the database type is that of T.
      type = Nullable.GetUnderlyingType(type) ?? type;

      switch (Type.GetTypeCode(type))
      {
        case TypeCode.Boolean:
          return SqlDbType.Bit;

        case TypeCode.Byte:
          return SqlDbType.TinyInt;

        case TypeCode.Int16:
          return SqlDbType.SmallInt;

        case TypeCode.Int32:
          return SqlDbType.Int;

        case TypeCode.Int64:
          return SqlDbType.BigInt;

        case TypeCode.Double:
          return SqlDbType.Float;

        case TypeCode.Decimal:
          return SqlDbType.Decimal;

        case TypeCode.Char:
          return SqlDbType.NChar;

        case TypeCode.String:
          return SqlDbType.NVarChar;

        case TypeCode.DateTime:
          return SqlDbType.DateTime;

        // Explicitly unsupported: SByte, UInt16, UInt32, UInt64 and Single.
        case TypeCode.SByte:
        case TypeCode.UInt16:
        case TypeCode.UInt32:
        case TypeCode.UInt64:
        case TypeCode.Single:
          return null;

        default:
          return null;
      }
    }
 
 
  }
}