123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567 |
- using System;
- using System.Collections;
- using System.Collections.Generic;
- using System.Linq;
- using System.Linq.Expressions;
- using System.Runtime.CompilerServices;
- using System.Text;
- using System.Text.RegularExpressions;
- using Vit.Linq.ExpressionTree.ComponentModel;
- using Vitorm.Entity;
- //using System.Range;
- using Vitorm.StreamQuery;
namespace Vitorm.Sql.SqlTranslate
{
    /// <summary>
    /// Base SQL translator shared by dialects: identifier delimiting, translation of expression nodes into SQL
    /// fragments (with values bound as parameters), and the statement templates for add/get/update/delete.
    /// Dialect-specific pieces (column types, schema DDL, query/update/delete builders) are abstract.
    /// </summary>
    public abstract class SqlTranslateService : ISqlTranslateService
    {
        #region DelimitIdentifier

        /// <summary>
        /// Generates the delimited SQL representation of an identifier (column name, table name, etc.).
        /// </summary>
        /// <param name="identifier">The identifier to delimit.</param>
        /// <returns>The escaped identifier wrapped in double quotes.</returns>
        public virtual string DelimitIdentifier(string identifier) => $"\"{EscapeIdentifier(identifier)}\""; // Interpolation okay; strings

        /// <summary>
        /// Generates the escaped SQL representation of an identifier (column name, table name, etc.).
        /// </summary>
        /// <param name="identifier">The identifier to be escaped.</param>
        /// <returns>The identifier with every double quote doubled, or null when <paramref name="identifier"/> is null.</returns>
        public virtual string EscapeIdentifier(string identifier) => identifier?.Replace("\"", "\"\"");

        /// <summary>
        /// Generates a valid parameter name for the given candidate name.
        /// </summary>
        /// <param name="name">The candidate name for the parameter.</param>
        /// <returns><paramref name="name"/> prefixed with "@", unless it already starts with "@".</returns>
        public virtual string GenerateParameterName(string name) => name.StartsWith("@", StringComparison.Ordinal) ? name : "@" + name;

        /// <summary>
        /// Returns the delimited table name of the entity.
        /// </summary>
        public virtual string DelimitTableName(IEntityDescriptor entityDescriptor) => DelimitIdentifier(entityDescriptor.tableName);

        #endregion

        /// <summary>
        /// Builds a qualified column reference of the form "table"."column".
        /// </summary>
        /// <param name="tableName">Table (or alias) name, delimited before use.</param>
        /// <param name="columnName">Column name, delimited before use.</param>
        public virtual string GetSqlField(string tableName, string columnName)
        {
            return $"{DelimitIdentifier(tableName)}.{DelimitIdentifier(columnName)}";
        }

        /// <summary>
        /// Translates a member node into a qualified column reference, for example user.id.
        /// </summary>
        /// <param name="member">
        /// Either a bare parameter reference (empty memberName), which resolves to the entity's key column,
        /// or a property access on a parameter, whose property name is mapped to its database column name.
        /// </param>
        /// <param name="dbContext">Context used to look up entity descriptors.</param>
        /// <returns>The qualified, delimited column reference.</returns>
        /// <exception cref="NotSupportedException">The database column for the member cannot be resolved.</exception>
        public virtual string GetSqlField(ExpressionNode_Member member, DbContext dbContext)
        {
            var memberName = member.memberName;
            if (string.IsNullOrWhiteSpace(memberName))
            {
                // A bare entity reference stands for the entity's key column.
                var entityType = member.Member_GetType();
                var entityDescriptor = dbContext.GetEntityDescriptor(entityType);
                memberName = entityDescriptor?.keyName;

                // Without a key there is no column to reference; fail instead of emitting an empty identifier.
                if (string.IsNullOrEmpty(memberName))
                {
                    throw new NotSupportedException("[QueryTranslator] can not find database key column for entity : " + entityType);
                }
            }
            else if (member.objectValue != null)
            {
                var entityType = member.objectValue.Member_GetType();
                if (entityType != null)
                {
                    var entityDescriptor = dbContext.GetEntityDescriptor(entityType);
                    if (entityDescriptor != null)
                    {
                        var columnName = entityDescriptor.GetColumnNameByPropertyName(memberName);
                        if (string.IsNullOrEmpty(columnName)) throw new NotSupportedException("[QueryTranslator] can not find database column name for property : " + memberName);
                        memberName = columnName;
                    }
                }
            }

            // 1: {"nodeType":"Member","parameterName":"a0","memberName":"id"}
            // 2: {"nodeType":"Member","objectValue":{"parameterName":"a0","nodeType":"Member"},"memberName":"id"}
            return GetSqlField(member.objectValue?.parameterName ?? member.parameterName, memberName);
        }

        /// <summary>
        /// Returns the dialect-specific database column type for a CLR type.
        /// </summary>
        protected abstract string GetColumnDbType(Type type);

        #region EvalExpression

        /// <summary>
        /// Evaluates a column in a select list, for example "select (u.id + 100) as newId".
        /// </summary>
        /// <param name="arg">Translation state (db context and collected SQL parameters).</param>
        /// <param name="data">The expression to evaluate.</param>
        /// <param name="columnType">Target column type; unused by the base implementation.</param>
        /// <returns>The SQL fragment.</returns>
        public virtual string EvalSelectExpression(QueryTranslateArgument arg, ExpressionNode data, Type columnType = null)
        {
            return EvalExpression(arg, data);
        }

        /// <summary>
        /// Translates a where / value / on expression into a SQL fragment. Constant values are bound as
        /// parameters in <c>arg.sqlParam</c> rather than inlined.
        /// </summary>
        /// <param name="arg">Translation state (db context and collected SQL parameters).</param>
        /// <param name="data">The expression node to translate.</param>
        /// <returns>The SQL fragment.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        /// <exception cref="NotSupportedException">The node type, method call or format string cannot be translated.</exception>
        public virtual string EvalExpression(QueryTranslateArgument arg, ExpressionNode data)
        {
            if (data == null) throw new ArgumentNullException(nameof(data));

            switch (data.nodeType)
            {
                case NodeType.AndAlso:
                    {
                        ExpressionNode_AndAlso and = data;
                        return $"({EvalExpression(arg, and.left)} and {EvalExpression(arg, and.right)})";
                    }
                case NodeType.OrElse:
                    {
                        ExpressionNode_OrElse or = data;
                        return $"({EvalExpression(arg, or.left)} or {EvalExpression(arg, or.right)})";
                    }
                case NodeType.Not:
                    {
                        ExpressionNode_Not not = data;
                        return $"(not {EvalExpression(arg, not.body)})";
                    }
                case NodeType.ArrayIndex:
                    // Array indexing has no translation in the base dialect.
                    throw new NotSupportedException("[QueryTranslator] not suported nodeType: " + data.nodeType);
                case NodeType.Equal:
                case NodeType.NotEqual:
                    {
                        ExpressionNode_Binary binary = data;
                        // "= null" is never true in SQL, so comparisons against a null constant become "is null" / "is not null".
                        var nullCheck = data.nodeType == NodeType.Equal ? "is null" : "is not null";
                        if (binary.right.nodeType == NodeType.Constant && binary.right.value == null)
                        {
                            return $"{EvalExpression(arg, binary.left)} {nullCheck}";
                        }
                        if (binary.left.nodeType == NodeType.Constant && binary.left.value == null)
                        {
                            return $"{EvalExpression(arg, binary.right)} {nullCheck}";
                        }
                        return $"({EvalExpression(arg, binary.left)} {operatorMap[data.nodeType]} {EvalExpression(arg, binary.right)})";
                    }
                case NodeType.LessThan:
                case NodeType.LessThanOrEqual:
                case NodeType.GreaterThan:
                case NodeType.GreaterThanOrEqual:
                case nameof(ExpressionType.Divide):
                case nameof(ExpressionType.Modulo):
                case nameof(ExpressionType.Multiply):
                case nameof(ExpressionType.Power):
                case nameof(ExpressionType.Subtract):
                    {
                        ExpressionNode_Binary binary = data;
                        return $"({EvalExpression(arg, binary.left)} {operatorMap[data.nodeType]} {EvalExpression(arg, binary.right)})";
                    }
                case nameof(ExpressionType.Negate):
                    {
                        ExpressionNode_Unary unary = data;
                        return $"(-{EvalExpression(arg, unary.body)})";
                    }
                case NodeType.MethodCall:
                    {
                        ExpressionNode_MethodCall methodCall = data;
                        switch (methodCall.methodName)
                        {
                            // ##1 in
                            case nameof(Enumerable.Contains):
                                {
                                    var values = methodCall.arguments[0];
                                    var member = methodCall.arguments[1];
                                    // Evaluate the member first so parameters are allocated in the same order as the SQL text.
                                    var memberSql = EvalExpression(arg, member);
                                    // The source of "in" is always a list, even a byte[] (which is otherwise bound as one binary value).
                                    var valuesSql = values.nodeType == NodeType.Constant && values.value is byte[] byteList
                                        ? EvalParameterList(arg, byteList)
                                        : EvalExpression(arg, values);
                                    return $"{memberSql} in {valuesSql}";
                                }

                            // ##2 db primitive function
                            case nameof(DbFunction.Call):
                                {
                                    var functionName = methodCall.arguments[0].value as string;
                                    if (string.IsNullOrWhiteSpace(functionName))
                                    {
                                        throw new NotSupportedException("[QueryTranslator] DbFunction.Call requires a constant function name");
                                    }
                                    // Plain LINQ-to-objects: arguments are translated left to right.
                                    var funcArgs = string.Join(",", methodCall.arguments.Skip(1).Select(argNode => EvalExpression(arg, argNode)));
                                    return $"{functionName}({funcArgs})";
                                }

                            #region ##3 Aggregate
                            case nameof(Enumerable.Count) when methodCall.arguments.Length == 1:
                                return "Count(*)";

                            case nameof(Enumerable.Max) or nameof(Enumerable.Min) or nameof(Enumerable.Sum) or nameof(Enumerable.Average) when methodCall.arguments.Length == 2:
                                {
                                    // The field-selector lambda body is re-rooted onto the aggregated source node, then translated.
                                    var source = methodCall.arguments[0];
                                    if (source?.nodeType != NodeType.Member) break;
                                    if (methodCall.arguments[1] is not ExpressionNode_Lambda lambdaFieldSelect) break;

                                    var entityType = methodCall.MethodCall_GetParamTypes()[0].GetGenericArguments()[0];
                                    source = TypeUtil.Clone(source).Member_SetType(entityType);

                                    var parameterName = lambdaFieldSelect.parameterNames[0];
                                    var parameterValue = source;

                                    // Replaces references to the lambda parameter with the source node (or a member access on it).
                                    ExpressionNode GetParameter(ExpressionNode_Member member)
                                    {
                                        if (member.nodeType == NodeType.Member && member.parameterName == parameterName)
                                        {
                                            if (string.IsNullOrWhiteSpace(member.memberName))
                                            {
                                                return parameterValue;
                                            }
                                            return ExpressionNode.Member(objectValue: parameterValue, memberName: member.memberName).Member_SetType(member.Member_GetType());
                                        }
                                        return default;
                                    }

                                    var body = StreamReader.DeepClone(lambdaFieldSelect.body, GetParameter);
                                    var funcName = methodCall.methodName == nameof(Enumerable.Average) ? "AVG" : methodCall.methodName;
                                    return $"{funcName}({EvalExpression(arg, body)})";
                                }
                            #endregion

                            // ##4 String.Format(format: "{0}_{1}_{2}", "0", "1", "2")
                            case nameof(String.Format):
                                {
                                    var format = methodCall.arguments[0].value as string;
                                    if (format == null)
                                    {
                                        throw new NotSupportedException("[QueryTranslator] String.Format requires a constant format string");
                                    }
                                    var formatArgs = methodCall.arguments.Skip(1).ToArray();

                                    // Fold the parts into a left-deep chain of string Add nodes; the Add translation concatenates them.
                                    ExpressionNode nodeForAdd = null;
                                    foreach (var node in SplitFormatToNodeParts(format, formatArgs))
                                    {
                                        if (nodeForAdd == null) nodeForAdd = node;
                                        else nodeForAdd = ExpressionNode.Add(left: nodeForAdd, right: node, typeof(string));
                                    }

                                    // An empty format string formats to the empty string.
                                    nodeForAdd ??= ExpressionNode.Constant("", typeof(string));
                                    return $"({EvalExpression(arg, nodeForAdd)})";
                                }
                        }
                        throw new NotSupportedException("[QueryTranslator] not suported MethodCall: " + methodCall.methodName);
                    }

                #region Read Value
                case NodeType.Member:
                    return GetSqlField(data, arg.dbContext);
                case NodeType.Constant:
                    {
                        ExpressionNode_Constant constant = data;
                        var value = constant.value;
                        if (value == null)
                        {
                            return "null";
                        }
                        // Collections (e.g. the source of "in") expand to a parameter list;
                        // strings and byte[] are single scalar values.
                        if (value is IEnumerable enumerable && value is not string && value is not byte[])
                        {
                            return EvalParameterList(arg, enumerable);
                        }
                        return AddSqlParameter(arg, value);
                    }
                #endregion
            }
            throw new NotSupportedException("[QueryTranslator] not suported nodeType: " + data.nodeType);
        }

        /// <summary>
        /// Binds <paramref name="value"/> under a fresh parameter name in <paramref name="arg"/> and returns
        /// the placeholder to embed in the SQL text.
        /// </summary>
        private string AddSqlParameter(QueryTranslateArgument arg, object value)
        {
            var paramName = arg.NewParamName();
            arg.sqlParam[paramName] = value;
            return GenerateParameterName(paramName);
        }

        /// <summary>
        /// Binds each non-null item as its own parameter and returns a parenthesized, comma-separated list of
        /// placeholders. An empty (or all-null) sequence yields "(null)", which keeps "x in (null)" valid SQL
        /// that matches nothing.
        /// </summary>
        private string EvalParameterList(QueryTranslateArgument arg, IEnumerable items)
        {
            var placeholders = new List<string>();
            foreach (var item in items)
            {
                // A null item can never produce a match inside SQL "in (...)", so it is not bound.
                if (item == null) continue;
                placeholders.Add(AddSqlParameter(arg, item));
            }
            return placeholders.Count == 0 ? "(null)" : "(" + string.Join(",", placeholders) + ")";
        }

        /// <summary>
        /// Matches, in order of precedence: an escaped "{{", an escaped "}}", a positional format item "{n}",
        /// a run of literal text, or a stray brace (unsupported syntax such as "{0:N2}" or "{0,5}").
        /// </summary>
        private static readonly Regex formatPartRegex = new Regex(@"\{\{|\}\}|\{([0-9]+)\}|[^{}]+|[{}]", RegexOptions.CultureInvariant);

        /// <summary>
        /// Splits a composite format string into expression nodes: literal text (including escaped braces)
        /// becomes string constants, and each "{n}" item becomes the n-th format argument.
        /// </summary>
        /// <exception cref="NotSupportedException">The format uses alignment/format specifiers or has unbalanced braces.</exception>
        /// <exception cref="FormatException">A format item refers to an argument that does not exist.</exception>
        private static IEnumerable<ExpressionNode> SplitFormatToNodeParts(string format, ExpressionNode[] args)
        {
            foreach (Match match in formatPartRegex.Matches(format))
            {
                var part = match.Value;
                if (part == "{{")
                {
                    yield return ExpressionNode.Constant("{", typeof(string));
                }
                else if (part == "}}")
                {
                    yield return ExpressionNode.Constant("}", typeof(string));
                }
                else if (match.Groups[1].Success)
                {
                    var argIndex = int.Parse(match.Groups[1].Value, System.Globalization.CultureInfo.InvariantCulture);
                    if (argIndex >= args.Length)
                    {
                        throw new FormatException("[QueryTranslator] String.Format item index out of range: " + part);
                    }
                    yield return args[argIndex];
                }
                else if (part == "{" || part == "}")
                {
                    throw new NotSupportedException("[QueryTranslator] not supported String.Format format: " + format);
                }
                else
                {
                    yield return ExpressionNode.Constant(part, typeof(string));
                }
            }
        }

        /// <summary>
        /// SQL operator text for binary node types.
        /// NOTE(review): "^" is exponentiation in PostgreSQL but bitwise XOR in SQL Server / MySQL;
        /// dialects where that matters must translate Power themselves.
        /// </summary>
        protected static readonly Dictionary<string, string> operatorMap = new Dictionary<string, string>
        {
            [NodeType.Equal] = "=",
            [NodeType.NotEqual] = "!=",
            [NodeType.LessThan] = "<",
            [NodeType.LessThanOrEqual] = "<=",
            [NodeType.GreaterThan] = ">",
            [NodeType.GreaterThanOrEqual] = ">=",
            [nameof(ExpressionType.Divide)] = "/",
            [nameof(ExpressionType.Modulo)] = "%",
            [nameof(ExpressionType.Multiply)] = "*",
            [nameof(ExpressionType.Power)] = "^",
            [nameof(ExpressionType.Subtract)] = "-",
        };

        #endregion

        // #0 Schema : PrepareCreate PrepareDrop
        public abstract string PrepareCreate(IEntityDescriptor entityDescriptor);
        public abstract string PrepareDrop(IEntityDescriptor entityDescriptor);

        #region #1 Create : PrepareAdd

        /// <summary>
        /// Decides how an entity should be inserted based on its key.
        /// </summary>
        /// <returns>
        /// <see cref="EAddType.noKeyColumn"/> when the entity has no key column;
        /// <see cref="EAddType.keyWithValue"/> when the key holds a non-default value;
        /// <see cref="EAddType.identityKey"/> when the key is empty and generated by the database.
        /// </returns>
        /// <exception cref="ArgumentException">The key is empty and is not an identity column.</exception>
        public virtual EAddType Entity_GetAddType(SqlTranslateArgument arg, object entity)
        {
            var key = arg.entityDescriptor.key;
            if (key == null) return EAddType.noKeyColumn;

            var keyValue = key.GetValue(entity);
            if (keyValue is not null && !keyValue.Equals(TypeUtil.DefaultValue(key.type))) return EAddType.keyWithValue;

            if (key.isIdentity) return EAddType.identityKey;

            throw new ArgumentException("Key could not be empty.");
        }

        /// <summary>
        /// Builds an insert statement for the given columns, e.g.
        /// insert into "user"("name","fatherId") values(@name,@fatherId);
        /// </summary>
        /// <param name="arg">Translation state holding the entity descriptor.</param>
        /// <param name="columns">Columns to insert; each is bound as a parameter named after its column.</param>
        /// <returns>The SQL text and a function that extracts the parameter values from an entity.</returns>
        protected virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg, IColumnDescriptor[] columns)
        {
            var entityDescriptor = arg.entityDescriptor;

            // #1 GetSqlParams: parameter names must match the placeholders generated below.
            Dictionary<string, object> GetSqlParams(object entity)
            {
                var sqlParam = new Dictionary<string, object>();
                foreach (var column in columns)
                {
                    sqlParam[column.columnName] = column.GetValue(entity);
                }
                return sqlParam;
            }

            // #2 columns
            var columnNames = string.Join(",", columns.Select(column => DelimitIdentifier(column.columnName)));
            var valueParams = string.Join(",", columns.Select(column => GenerateParameterName(column.columnName)));

            // #3 build sql
            string sql = $"insert into {DelimitTableName(entityDescriptor)}({columnNames}) values({valueParams});";
            return (sql, GetSqlParams);
        }

        /// <summary>
        /// Builds an insert statement covering all columns of the entity.
        /// </summary>
        public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
        {
            return PrepareAdd(arg, arg.entityDescriptor.allColumns);
        }

        /// <summary>
        /// Builds an insert statement for an identity-keyed entity; must be provided by dialects that support it.
        /// </summary>
        public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareIdentityAdd(SqlTranslateArgument arg) => throw new NotImplementedException();

        #endregion

        #region #2 Retrieve : PrepareGet PrepareQuery

        /// <summary>
        /// Builds a select-by-key statement; the key value is bound as a parameter named after the key column.
        /// </summary>
        public virtual string PrepareGet(SqlTranslateArgument arg)
        {
            var entityDescriptor = arg.entityDescriptor;
            string sql = $"select * from {DelimitTableName(entityDescriptor)} where {DelimitIdentifier(entityDescriptor.keyName)}={GenerateParameterName(entityDescriptor.keyName)};";
            return sql;
        }

        protected abstract BaseQueryTranslateService queryTranslateService { get; }

        /// <summary>
        /// Builds the SQL for a query stream; parameters collected during translation are returned with it.
        /// </summary>
        public virtual (string sql, Dictionary<string, object> sqlParam, IDbDataReader dataReader) PrepareQuery(QueryTranslateArgument arg, CombinedStream combinedStream)
        {
            string sql = queryTranslateService.BuildQuery(arg, combinedStream);
            return (sql, arg.sqlParam, arg.dataReader);
        }

        /// <summary>
        /// Builds the count SQL for a query stream.
        /// </summary>
        public virtual (string sql, Dictionary<string, object> sqlParam) PrepareCountQuery(QueryTranslateArgument arg, CombinedStream combinedStream)
        {
            string sql = queryTranslateService.BuildCountQuery(arg, combinedStream);
            return (sql, arg.sqlParam);
        }

        #endregion

        #region #3 Update: PrepareUpdate PrepareExecuteUpdate

        /// <summary>
        /// Builds an update-by-key statement, e.g. update "user" set "name"=@name where "id"=@id;
        /// </summary>
        /// <returns>The SQL text and a function that extracts the parameter values from an entity.</returns>
        public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareUpdate(SqlTranslateArgument arg)
        {
            var entityDescriptor = arg.entityDescriptor;

            // #1 GetSqlParams: binds allColumns, so the key parameter used by the where clause is included
            // (assumes allColumns contains the key column — TODO confirm against IEntityDescriptor).
            Dictionary<string, object> GetSqlParams(object entity)
            {
                var sqlParam = new Dictionary<string, object>();
                foreach (var column in entityDescriptor.allColumns)
                {
                    sqlParam[column.columnName] = column.GetValue(entity);
                }
                return sqlParam;
            }

            // #2 columns: only entityDescriptor.columns are assigned in the set clause.
            var columnsToUpdate = entityDescriptor.columns.Select(column => $"{DelimitIdentifier(column.columnName)}={GenerateParameterName(column.columnName)}");

            // #3 build sql
            string sql = $"update {DelimitTableName(entityDescriptor)} set {string.Join(",", columnsToUpdate)} where {DelimitIdentifier(entityDescriptor.keyName)}={GenerateParameterName(entityDescriptor.keyName)};";
            return (sql, GetSqlParams);
        }

        protected abstract BaseQueryTranslateService executeUpdateTranslateService { get; }

        /// <summary>
        /// Builds the SQL for a set-based update over a query stream.
        /// </summary>
        public virtual (string sql, Dictionary<string, object> sqlParam) PrepareExecuteUpdate(QueryTranslateArgument arg, CombinedStream combinedStream)
        {
            string sql = executeUpdateTranslateService.BuildQuery(arg, combinedStream);
            return (sql, arg.sqlParam);
        }

        #endregion

        #region #4 Delete: PrepareDelete PrepareDeleteRange PrepareExecuteDelete

        /// <summary>
        /// Builds a delete-by-key statement, e.g. delete from "user" where "id"=@id ;
        /// </summary>
        public virtual string PrepareDelete(SqlTranslateArgument arg)
        {
            var entityDescriptor = arg.entityDescriptor;
            string sql = $"delete from {DelimitTableName(entityDescriptor)} where {DelimitIdentifier(entityDescriptor.keyName)}={GenerateParameterName(entityDescriptor.keyName)} ; ";
            return sql;
        }

        /// <summary>
        /// Builds a delete statement for a set of keys: delete from "user" where "id" in (@p0,@p1);
        /// </summary>
        /// <param name="arg">Translation state holding the entity descriptor.</param>
        /// <param name="keys">Key values; each is bound as parameter p0, p1, ...</param>
        /// <returns>The SQL text and its parameters. An empty key set yields "in (null)", which deletes nothing.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="keys"/> is null.</exception>
        public virtual (string sql, Dictionary<string, object> sqlParam) PrepareDeleteByKeys<Key>(SqlTranslateArgument arg, IEnumerable<Key> keys)
        {
            if (keys == null) throw new ArgumentNullException(nameof(keys));

            var entityDescriptor = arg.entityDescriptor;
            StringBuilder sql = new StringBuilder();
            Dictionary<string, object> sqlParam = new();

            sql.Append("delete from ").Append(DelimitTableName(entityDescriptor)).Append(" where ").Append(DelimitIdentifier(entityDescriptor.keyName)).Append(" in (");

            int keyIndex = 0;
            foreach (var key in keys)
            {
                var paramName = "p" + (keyIndex++);
                sql.Append(GenerateParameterName(paramName)).Append(",");
                sqlParam[paramName] = key;
            }

            if (keyIndex == 0)
            {
                sql.Append("null);");
            }
            else
            {
                // Drop the trailing comma.
                sql.Length--;
                sql.Append(");");
            }
            return (sql.ToString(), sqlParam);
        }

        protected abstract BaseQueryTranslateService executeDeleteTranslateService { get; }

        /// <summary>
        /// Builds the SQL for a set-based delete over a query stream.
        /// </summary>
        public virtual (string sql, Dictionary<string, object> sqlParam) PrepareExecuteDelete(QueryTranslateArgument arg, CombinedStream combinedStream)
        {
            string sql = executeDeleteTranslateService.BuildQuery(arg, combinedStream);
            return (sql, arg.sqlParam);
        }

        #endregion
    }
}
|