using JsonRPC4.Common;
using JsonRPC4.Router.Abstractions;
using JsonRPC4.Router.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;

namespace JsonRPC4.Router.Defaults
{
	/// <summary>
	/// Default <see cref="IRpcInvoker"/>. Resolves a request to a method via the route's method provider
	/// and <see cref="IRpcRequestMatcher"/>, applies ASP.NET Core <see cref="IAuthorizeData"/> /
	/// <see cref="IAllowAnonymous"/> attributes, invokes the method and converts the outcome
	/// (value, <see cref="IRpcMethodResult"/> or exception) into an <see cref="RpcResponse"/>.
	/// </summary>
	public class DefaultRpcInvoker : IRpcInvoker
	{
		private ILogger logger { get; }

		private IAuthorizationService authorizationService { get; }

		private IAuthorizationPolicyProvider policyProvider { get; }

		private IOptions<RpcServerConfiguration> serverConfig { get; }

		private IRpcRequestMatcher rpcRequestMatcher { get; }

		// One activation factory per declaring type; building the factory is reflection-heavy, invoking it is cheap.
		private ConcurrentDictionary<Type, ObjectFactory> objectFactoryCache { get; } = new ConcurrentDictionary<Type, ObjectFactory>();

		// Authorization attributes found on a declaring type: the IAuthorizeData entries and whether IAllowAnonymous is present.
		private ConcurrentDictionary<Type, (List<IAuthorizeData> Data, bool AllowAnonymous)> classAttributeCache { get; }
			= new ConcurrentDictionary<Type, (List<IAuthorizeData> Data, bool AllowAnonymous)>();

		// Same information for methods. Keyed by the reflected MethodInfo (stable for the process lifetime), not by the
		// RpcMethodInfo produced for each request, so the cache holds one entry per method instead of growing per request.
		private ConcurrentDictionary<MethodInfo, (List<IAuthorizeData> Data, bool AllowAnonymous)> methodAttributeCache { get; }
			= new ConcurrentDictionary<MethodInfo, (List<IAuthorizeData> Data, bool AllowAnonymous)>();

		public DefaultRpcInvoker(
			IAuthorizationService authorizationService,
			IAuthorizationPolicyProvider policyProvider,
			ILogger<DefaultRpcInvoker> logger,
			IOptions<RpcServerConfiguration> serverConfig,
			IRpcRequestMatcher rpcRequestMatcher)
		{
			this.authorizationService = authorizationService;
			this.policyProvider = policyProvider;
			this.logger = logger;
			this.serverConfig = serverConfig;
			this.rpcRequestMatcher = rpcRequestMatcher;
		}

		/// <summary>
		/// Invokes every request of a batch concurrently and waits for all of them, notifications included.
		/// </summary>
		/// <param name="requests">The batch. Must not be null.</param>
		/// <param name="routeContext">Route context shared by all requests of the batch.</param>
		/// <param name="path">Optional route path used to look up candidate methods.</param>
		/// <returns>The non-null responses, in request order. Requests without an id yield no response.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="requests"/> is null.</exception>
		public async Task<List<RpcResponse>> InvokeBatchRequestAsync(IList<RpcRequest> requests, IRouteContext routeContext, RpcPath path = null)
		{
			if (requests == null)
			{
				throw new ArgumentNullException(nameof(requests));
			}
			logger.InvokingBatchRequests(requests.Count);

			// Notifications are awaited too: otherwise the batch could complete (and the request's service scope be
			// disposed) while they are still running, and any failure of theirs would go unobserved.
			Task<RpcResponse>[] invokingTasks = requests
				.Select(request => InvokeRequestAsync(request, routeContext, path))
				.ToArray();
			RpcResponse[] responses = await Task.WhenAll(invokingTasks);

			// InvokeRequestAsync returns null for notifications; only real responses go back to the client.
			List<RpcResponse> result = responses.Where(response => response != null).ToList();
			logger.BatchRequestsComplete();
			return result;
		}

		/// <summary>
		/// Resolves, authorizes and invokes a single request.
		/// </summary>
		/// <param name="request">The request. Must not be null.</param>
		/// <param name="routeContext">Supplies the method provider, the user and the request services.</param>
		/// <param name="path">Optional route path used to look up candidate methods.</param>
		/// <returns>
		/// The response (success or error) when the request has an id; null for a notification.
		/// Failures during resolution/invocation are converted into error responses rather than thrown.
		/// </returns>
		/// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
		public async Task<RpcResponse> InvokeRequestAsync(RpcRequest request, IRouteContext routeContext, RpcPath path = null)
		{
			if (request == null)
			{
				throw new ArgumentNullException(nameof(request));
			}
			logger.InvokingRequest(request.Id);
			RpcResponse response;
			try
			{
				if (!routeContext.MethodProvider.TryGetByPath(path, out var methods))
				{
					throw new RpcException(RpcErrorCode.MethodNotFound, $"No methods found with the path: {path}");
				}
				RpcMethodInfo rpcMethod = rpcRequestMatcher.GetMatchingMethod(request, methods);
				if (await IsAuthorizedAsync(rpcMethod, routeContext))
				{
					logger.InvokeMethod(request.Method);
					object returnValue = await InvokeAsync(rpcMethod, path, routeContext.RequestServices);
					logger.InvokeMethodComplete(request.Method);
					// A method may build its own response (IRpcMethodResult); anything else is the plain result value.
					response = returnValue is IRpcMethodResult methodResult
						? methodResult.ToRpcResponse(request.Id)
						: new RpcResponse(request.Id, returnValue);
				}
				else
				{
					RpcError unauthorizedError = new RpcError(RpcErrorCode.InvalidRequest, "Unauthorized");
					response = new RpcResponse(request.Id, unauthorizedError);
				}
			}
			catch (Exception ex)
			{
				const string errorMessage = "An Rpc error occurred while trying to invoke request.";
				logger.LogException(ex, errorMessage);
				bool showServerExceptions = serverConfig.Value.ShowServerExceptions;
				// Unexpected exceptions are only attached to the error payload when the server is configured to
				// expose exception details; otherwise the client gets the generic message only.
				RpcError error = ex is RpcException rpcException
					? rpcException.ToRpcError(showServerExceptions)
					: new RpcError(RpcErrorCode.InternalError, errorMessage, showServerExceptions ? ex : null);
				response = new RpcResponse(request.Id, error);
			}

			if (request.Id.HasValue)
			{
				logger.FinishedRequest(request.Id.ToString());
				return response;
			}
			// Notifications never produce a response.
			logger.FinishedRequestNoId();
			return null;
		}

		/// <summary>
		/// Applies the authorization attributes of the method and of its declaring type.
		/// No attributes: allowed. Any IAllowAnonymous on either level: allowed. Otherwise the class-level
		/// requirements and then the method-level requirements must both succeed.
		/// </summary>
		private async Task<bool> IsAuthorizedAsync(RpcMethodInfo methodInfo, IRouteContext routeContext)
		{
			(List<IAuthorizeData> classData, bool classAllowsAnonymous) =
				classAttributeCache.GetOrAdd(methodInfo.Method.DeclaringType, GetClassAttributeInfo);
			(List<IAuthorizeData> methodData, bool methodAllowsAnonymous) =
				methodAttributeCache.GetOrAdd(methodInfo.Method, GetMethodAttributeInfo);

			if (classData.Count == 0 && methodData.Count == 0)
			{
				logger.NoConfiguredAuth();
				return true;
			}
			if (classAllowsAnonymous || methodAllowsAnonymous)
			{
				logger.SkippingAuth();
				return true;
			}

			logger.RunningAuth();
			AuthorizationResult authResult = await CheckAuthorize(classData, routeContext);
			if (authResult.Succeeded)
			{
				authResult = await CheckAuthorize(methodData, routeContext);
			}
			if (!authResult.Succeeded)
			{
				logger.AuthFailed();
				return false;
			}
			logger.AuthSuccessful();
			return true;
		}

		/// <summary>Collects the authorization attributes declared on a type.</summary>
		private static (List<IAuthorizeData> Data, bool AllowAnonymous) GetClassAttributeInfo(Type type)
		{
			return GetAttributeInfo(type.GetCustomAttributes());
		}

		/// <summary>Collects the authorization attributes declared on a method.</summary>
		private static (List<IAuthorizeData> Data, bool AllowAnonymous) GetMethodAttributeInfo(MethodInfo method)
		{
			return GetAttributeInfo(method.GetCustomAttributes());
		}

		/// <summary>Splits attributes into IAuthorizeData entries and an IAllowAnonymous presence flag.</summary>
		private static (List<IAuthorizeData> Data, bool AllowAnonymous) GetAttributeInfo(IEnumerable<Attribute> attributes)
		{
			List<IAuthorizeData> authorizeData = new List<IAuthorizeData>();
			bool allowAnonymous = false;
			foreach (Attribute attribute in attributes)
			{
				if (attribute is IAuthorizeData data)
				{
					authorizeData.Add(data);
				}
				if (attribute is IAllowAnonymous)
				{
					allowAnonymous = true;
				}
			}
			return (authorizeData, allowAnonymous);
		}

		/// <summary>
		/// Combines the given authorization data into one policy and evaluates it for the route's user.
		/// An empty list succeeds without consulting the authorization service.
		/// </summary>
		private async Task<AuthorizationResult> CheckAuthorize(List<IAuthorizeData> authorizeDataList, IRouteContext routeContext)
		{
			if (authorizeDataList.Count == 0)
			{
				return AuthorizationResult.Success();
			}
			AuthorizationPolicy policy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeDataList);
			return await authorizationService.AuthorizeAsync(routeContext.User, policy);
		}

		/// <summary>
		/// Instantiates the declaring type (through DI when a service provider is available), invokes the matched
		/// method with its bound parameters, and unwraps Task results.
		/// </summary>
		/// <exception cref="RpcException">
		/// InternalError when the target method throws (unless an <see cref="RpcErrorFilterAttribute"/> handles it);
		/// InvalidParams when the invocation itself fails.
		/// </exception>
		private async Task<object> InvokeAsync(RpcMethodInfo methodInfo, RpcPath path, IServiceProvider serviceProvider)
		{
			Type declaringType = methodInfo.Method.DeclaringType;
			object instance = null;
			if (serviceProvider != null)
			{
				ObjectFactory factory = objectFactoryCache.GetOrAdd(declaringType, t => ActivatorUtilities.CreateFactory(t, Type.EmptyTypes));
				instance = factory(serviceProvider, null);
			}
			if (instance == null)
			{
				instance = Activator.CreateInstance(declaringType);
			}

			try
			{
				return await HandleAsyncResponses(methodInfo.Method.Invoke(instance, methodInfo.Parameters));
			}
			catch (TargetInvocationException ex)
			{
				// The target method itself threw; give a class-level error filter the chance to handle it.
				RpcErrorFilterAttribute errorFilter = declaringType.GetTypeInfo().GetCustomAttribute<RpcErrorFilterAttribute>();
				if (errorFilter != null)
				{
					RpcRouteInfo routeInfo = new RpcRouteInfo(methodInfo, path, serviceProvider);
					OnExceptionResult filterResult = errorFilter.OnException(routeInfo, ex.InnerException);
					if (!filterResult.ThrowException)
					{
						return filterResult.ResponseObject;
					}
					if (filterResult.ResponseObject is Exception replacement)
					{
						throw replacement;
					}
				}
				throw new RpcException(RpcErrorCode.InternalError, "Exception occurred from target method execution.", ex);
			}
			catch (Exception innerException)
			{
				// Reflection could not invoke the method at all, most likely because the bound arguments don't fit.
				throw new RpcException(RpcErrorCode.InvalidParams, "Exception from attempting to invoke method. Possibly invalid parameters for method.", innerException);
			}
			finally
			{
				// Instances built here (ActivatorUtilities / Activator) are not tracked by the DI container,
				// so disposing them is this invoker's responsibility.
				(instance as IDisposable)?.Dispose();
			}
		}

		/// <summary>
		/// Awaits a Task-returning method's result. Non-Task values pass through unchanged, Task&lt;T&gt; yields its
		/// Result, and a plain Task yields null. A faulted task is rethrown wrapped in a
		/// <see cref="TargetInvocationException"/> so it is handled like a synchronous exception from the target.
		/// </summary>
		private static async Task<object> HandleAsyncResponses(object returnObj)
		{
			if (!(returnObj is Task task))
			{
				return returnObj;
			}
			try
			{
				await task;
			}
			catch (Exception inner)
			{
				throw new TargetInvocationException(inner);
			}
			PropertyInfo resultProperty = task.GetType().GetProperty("Result");
			// Tasks from `async Task` methods are Task<VoidTaskResult> at runtime; that internal placeholder is not
			// a real result and must not be returned to the client.
			if (resultProperty == null || resultProperty.PropertyType.FullName == "System.Threading.Tasks.VoidTaskResult")
			{
				return null;
			}
			return resultProperty.GetValue(task);
		}
	}
}