SqlTranslateService.cs 25 KB


  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Linq.Expressions;
  6. using System.Runtime.CompilerServices;
  7. using System.Text;
  8. using System.Text.RegularExpressions;
  9. using Vit.Linq.ExpressionTree.ComponentModel;
  10. using Vitorm.Entity;
  11. //using System.Range;
  12. using Vitorm.StreamQuery;
  namespace Vitorm.Sql.SqlTranslate
  {
      /// <summary>
      /// Base implementation of <see cref="ISqlTranslateService"/>: renders expression nodes and entity metadata
      /// as SQL text, collecting parameter values into argument dictionaries along the way.
      /// Derived classes supply column types, schema DDL and the query / execute-update / execute-delete builders.
      /// </summary>
      public abstract class SqlTranslateService : ISqlTranslateService
      {
          #region DelimitIdentifier

          /// <summary>
          /// Generates the delimited SQL representation of an identifier (column name, table name, etc.).
          /// The base implementation wraps the escaped identifier in double quotes.
          /// </summary>
          /// <param name="identifier">The identifier to delimit.</param>
          /// <returns>The delimited identifier, e.g. <c>user</c> becomes <c>"user"</c>.</returns>
          public virtual string DelimitIdentifier(string identifier) => $"\"{EscapeIdentifier(identifier)}\"";

          /// <summary>
          /// Generates the escaped SQL representation of an identifier (column name, table name, etc.)
          /// by doubling every embedded double quote.
          /// </summary>
          /// <param name="identifier">The identifier to be escaped; may be null.</param>
          /// <returns>The escaped identifier, or null when <paramref name="identifier"/> is null.</returns>
          public virtual string EscapeIdentifier(string identifier) => identifier?.Replace("\"", "\"\"");

          /// <summary>
          /// Generates a valid parameter name for the given candidate name by ensuring it starts with "@".
          /// </summary>
          /// <param name="name">The candidate name for the parameter.</param>
          /// <returns>The name itself if it already starts with "@", otherwise "@" + name.</returns>
          public virtual string GenerateParameterName(string name) => name.StartsWith("@", StringComparison.Ordinal) ? name : "@" + name;

          /// <summary>Delimits the table name of the given entity.</summary>
          public virtual string DelimitTableName(IEntityDescriptor entityDescriptor) => DelimitIdentifier(entityDescriptor.tableName);

          #endregion

          /// <summary>
          /// Builds a qualified column reference, e.g. <c>"u"."id"</c>.
          /// </summary>
          public virtual string GetSqlField(string tableName, string columnName)
          {
              return $"{DelimitIdentifier(tableName)}.{DelimitIdentifier(columnName)}";
          }

          /// <summary>
          /// Builds the qualified column reference for a member node, e.g. <c>user.id</c>.
          /// A member without a member name refers to the entity itself and resolves to its key column;
          /// a property accessed on an entity is mapped to its database column name.
          /// </summary>
          /// <param name="member">Member node to translate.</param>
          /// <param name="dbContext">Context used to look up entity descriptors.</param>
          /// <returns>The delimited "table"."column" reference.</returns>
          /// <exception cref="NotSupportedException">The property has no mapped column.</exception>
          public virtual string GetSqlField(ExpressionNode_Member member, DbContext dbContext)
          {
              var memberName = member.memberName;
              if (string.IsNullOrWhiteSpace(memberName))
              {
                  // the entity itself: use its key column
                  // NOTE(review): if no descriptor is found memberName stays null and an empty column name is emitted
                  var entityType = member.Member_GetType();
                  var entityDescriptor = dbContext.GetEntityDescriptor(entityType);
                  memberName = entityDescriptor?.keyName;
              }
              else if (member.objectValue != null)
              {
                  // property of an entity: translate property name -> column name
                  var entityType = member.objectValue.Member_GetType();
                  var entityDescriptor = entityType == null ? null : dbContext.GetEntityDescriptor(entityType);
                  if (entityDescriptor != null)
                  {
                      var columnName = entityDescriptor.GetColumnNameByPropertyName(memberName);
                      if (string.IsNullOrEmpty(columnName)) throw new NotSupportedException("[QueryTranslator] can not find database column name for property : " + memberName);
                      memberName = columnName;
                  }
              }

              // 1: {"nodeType":"Member","parameterName":"a0","memberName":"id"}
              // 2: {"nodeType":"Member","objectValue":{"parameterName":"a0","nodeType":"Member"},"memberName":"id"}
              return GetSqlField(member.objectValue?.parameterName ?? member.parameterName, memberName);
          }

          /// <summary>
          /// Returns the database column type for a CLR type. Implemented by derived classes
          /// (not called by this base class).
          /// </summary>
          protected abstract string GetColumnDbType(Type type);

          #region EvalExpression

          /// <summary>
          /// Evaluates a column in a select clause, for example : "select (u.id + 100) as newId".
          /// The base implementation ignores <paramref name="columnType"/> and defers to <see cref="EvalExpression"/>.
          /// </summary>
          /// <param name="arg">Translation state (parameters, db context).</param>
          /// <param name="data">Expression to evaluate.</param>
          /// <param name="columnType">Expected CLR type of the column; unused here.</param>
          /// <returns>SQL text for the expression.</returns>
          public virtual string EvalSelectExpression(QueryTranslateArgument arg, ExpressionNode data, Type columnType = null)
          {
              return EvalExpression(arg, data);
          }

          /// <summary>
          /// Translates an expression node used in a where / on clause or as a value into SQL text.
          /// Constant values are never inlined: each one is stored in <c>arg.sqlParam</c> under a fresh
          /// name and referenced through a parameter placeholder.
          /// </summary>
          /// <param name="arg">Translation state; receives the generated parameters.</param>
          /// <param name="data">Expression to translate.</param>
          /// <returns>SQL text for the expression.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
          /// <exception cref="NotSupportedException">The node type or method call cannot be translated.</exception>
          public virtual string EvalExpression(QueryTranslateArgument arg, ExpressionNode data)
          {
              if (data is null) throw new ArgumentNullException(nameof(data));

              switch (data.nodeType)
              {
                  case NodeType.AndAlso:
                      {
                          ExpressionNode_AndAlso and = data;
                          return $"({EvalExpression(arg, and.left)} and {EvalExpression(arg, and.right)})";
                      }
                  case NodeType.OrElse:
                      {
                          ExpressionNode_OrElse or = data;
                          return $"({EvalExpression(arg, or.left)} or {EvalExpression(arg, or.right)})";
                      }
                  case NodeType.Not:
                      {
                          ExpressionNode_Not not = data;
                          return $"(not {EvalExpression(arg, not.body)})";
                      }
                  case NodeType.ArrayIndex:
                      // array indexing is not translated by this base class
                      throw new NotSupportedException("[QueryTranslator] not suported nodeType: " + data.nodeType);
                  case NodeType.Equal:
                  case NodeType.NotEqual:
                      {
                          ExpressionNode_Binary binary = data;
                          // "= null" -> "is null" , "!= null" -> "is not null":
                          // a comparison against NULL evaluates to UNKNOWN under standard SQL
                          if (IsNullConstant(binary.right) || IsNullConstant(binary.left))
                          {
                              var operand = IsNullConstant(binary.right) ? binary.left : binary.right;
                              var nullTest = data.nodeType == NodeType.Equal ? "is null" : "is not null";
                              return $"{EvalExpression(arg, operand)} " + nullTest;
                          }
                          var @operator = operatorMap[data.nodeType];
                          return $"({EvalExpression(arg, binary.left)} {@operator} {EvalExpression(arg, binary.right)})";
                      }
                  case NodeType.LessThan:
                  case NodeType.LessThanOrEqual:
                  case NodeType.GreaterThan:
                  case NodeType.GreaterThanOrEqual:
                  case nameof(ExpressionType.Divide):
                  case nameof(ExpressionType.Modulo):
                  case nameof(ExpressionType.Multiply):
                  case nameof(ExpressionType.Power):
                  case nameof(ExpressionType.Subtract):
                      {
                          ExpressionNode_Binary binary = data;
                          var @operator = operatorMap[data.nodeType];
                          return $"({EvalExpression(arg, binary.left)} {@operator} {EvalExpression(arg, binary.right)})";
                      }
                  case nameof(ExpressionType.Negate):
                      {
                          ExpressionNode_Unary unary = data;
                          return $"(-{EvalExpression(arg, unary.body)})";
                      }
                  case NodeType.MethodCall:
                      {
                          ExpressionNode_MethodCall methodCall = data;
                          switch (methodCall.methodName)
                          {
                              // ##1 in : Contains(values, member) -> member in (values)
                              case nameof(Enumerable.Contains):
                                  {
                                      var values = methodCall.arguments[0];
                                      var member = methodCall.arguments[1];
                                      return $"{EvalExpression(arg, member)} in {EvalExpression(arg, values)}";
                                  }
                              // ##2 db primitive function : DbFunction.Call("name", args...) -> name(args...)
                              case nameof(DbFunction.Call):
                                  {
                                      var functionName = methodCall.arguments[0].value as string;
                                      var funcArgs = string.Join(",", methodCall.arguments.Skip(1).Select(argNode => EvalExpression(arg, argNode)));
                                      return $"{functionName}({funcArgs})";
                                  }

                              #region ##3 Aggregate
                              case nameof(Enumerable.Count) when methodCall.arguments.Length == 1:
                                  // the only argument is the source stream; it is not inspected
                                  return "Count(*)";
                              case nameof(Enumerable.Max) or nameof(Enumerable.Min) or nameof(Enumerable.Sum) or nameof(Enumerable.Average) when methodCall.arguments.Length == 2:
                                  {
                                      var source = methodCall.arguments[0];
                                      if (source?.nodeType != NodeType.Member) break;
                                      if (methodCall.arguments[1] is not ExpressionNode_Lambda lambdaFieldSelect) break;

                                      var entityType = methodCall.MethodCall_GetParamTypes()[0].GetGenericArguments()[0];
                                      source = TypeUtil.Clone(source).Member_SetType(entityType);
                                      var parameterName = lambdaFieldSelect.parameterNames[0];
                                      var parameterValue = source;

                                      // replaces references to the lambda parameter with the source stream node
                                      ExpressionNode GetParameter(ExpressionNode_Member member)
                                      {
                                          if (member.nodeType == NodeType.Member && member.parameterName == parameterName)
                                          {
                                              if (string.IsNullOrWhiteSpace(member.memberName))
                                              {
                                                  return parameterValue;
                                              }
                                              return ExpressionNode.Member(objectValue: parameterValue, memberName: member.memberName).Member_SetType(member.Member_GetType());
                                          }
                                          return default;
                                      }

                                      var body = StreamReader.DeepClone(lambdaFieldSelect.body, GetParameter);
                                      var funcName = methodCall.methodName;
                                      if (funcName == nameof(Enumerable.Average)) funcName = "AVG";
                                      return $"{funcName}({EvalExpression(arg, body)})";
                                  }
                              #endregion

                              // ##4 String.Format(format: "{0}_{1}_{2}", "0", "1", "2")
                              case nameof(String.Format):
                                  {
                                      // rewritten as a chain of ExpressionNode.Add evaluated by EvalExpression
                                      // (this base class has no Add case; presumably derived classes handle it)
                                      if (methodCall.arguments[0].value is not string format)
                                          throw new NotSupportedException("[QueryTranslator] String.Format requires a constant format string");

                                      var formatArgs = methodCall.arguments.Skip(1).ToArray();
                                      ExpressionNode nodeForAdd = null;
                                      foreach (var node in SplitFormatToNodeParts(format, formatArgs))
                                      {
                                          if (nodeForAdd == null) nodeForAdd = node;
                                          else nodeForAdd = ExpressionNode.Add(left: nodeForAdd, right: node, typeof(string));
                                      }
                                      // an empty format string still produces a (empty) string value
                                      if (nodeForAdd == null) nodeForAdd = ExpressionNode.Constant(string.Empty, typeof(string));
                                      return $"({EvalExpression(arg, nodeForAdd)})";
                                  }
                          }
                          throw new NotSupportedException("[QueryTranslator] not suported MethodCall: " + methodCall.methodName);
                      }

                  #region Read Value
                  case NodeType.Member:
                      return GetSqlField(data, arg.dbContext);
                  case NodeType.Constant:
                      {
                          ExpressionNode_Constant constant = data;
                          var value = constant.value;
                          if (value == null) return "null";

                          // collections become a parameter list "(@p0,@p1,...)";
                          // string and byte[] are scalar values even though they implement IEnumerable
                          if (value is not string && value is not byte[] && value is IEnumerable enumerable)
                          {
                              var placeholders = new List<string>();
                              foreach (var item in enumerable)
                              {
                                  // null items are skipped
                                  if (item == null) continue;
                                  var itemParamName = arg.NewParamName();
                                  arg.sqlParam[itemParamName] = item;
                                  placeholders.Add(GenerateParameterName(itemParamName));
                              }
                              return placeholders.Count == 0 ? "(null)" : "(" + string.Join(",", placeholders) + ")";
                          }

                          var scalarParamName = arg.NewParamName();
                          arg.sqlParam[scalarParamName] = value;
                          return GenerateParameterName(scalarParamName);
                      }
                  #endregion
              }
              throw new NotSupportedException("[QueryTranslator] not suported nodeType: " + data.nodeType);
          }

          /// <summary>True when the node is a constant whose value is null.</summary>
          private static bool IsNullConstant(ExpressionNode node) => node.nodeType == NodeType.Constant && node.value == null;

          /// <summary>
          /// Splits a composite format string such as "{0}_{1}" into ordered nodes. Literal text becomes string
          /// constants, and each "{index}" placeholder becomes the matching argument node. "{{" and "}}" are
          /// unescaped to literal braces.
          /// </summary>
          /// <exception cref="NotSupportedException">
          /// The format has unbalanced braces, an out-of-range index, or alignment / format specifiers
          /// such as "{0,5}" or "{0:N2}", which cannot be translated to SQL.
          /// </exception>
          private static List<ExpressionNode> SplitFormatToNodeParts(string format, ExpressionNode[] args)
          {
              var parts = new List<ExpressionNode>();
              var literal = new StringBuilder();
              for (var i = 0; i < format.Length; i++)
              {
                  var ch = format[i];
                  if (ch != '{' && ch != '}')
                  {
                      literal.Append(ch);
                      continue;
                  }

                  // "{{" or "}}" is an escaped literal brace
                  if (i + 1 < format.Length && format[i + 1] == ch)
                  {
                      literal.Append(ch);
                      i++;
                      continue;
                  }

                  // a lone "}" or an unterminated "{" is malformed
                  var close = ch == '{' ? format.IndexOf('}', i + 1) : -1;
                  if (close < 0)
                      throw new NotSupportedException("[QueryTranslator] invalid String.Format format string: " + format);

                  var placeholder = format.Substring(i + 1, close - i - 1);
                  if (!int.TryParse(placeholder, System.Globalization.NumberStyles.None, System.Globalization.CultureInfo.InvariantCulture, out var argIndex) || argIndex >= args.Length)
                      throw new NotSupportedException("[QueryTranslator] not supported String.Format placeholder: {" + placeholder + "}");

                  if (literal.Length > 0)
                  {
                      parts.Add(ExpressionNode.Constant(literal.ToString(), typeof(string)));
                      literal.Clear();
                  }
                  parts.Add(args[argIndex]);
                  i = close;
              }
              if (literal.Length > 0) parts.Add(ExpressionNode.Constant(literal.ToString(), typeof(string)));
              return parts;
          }

          /// <summary>
          /// SQL operator text for binary node types handled by <see cref="EvalExpression"/>.
          /// NOTE: "^" is exponentiation in PostgreSQL but bitwise XOR in SQL Server / MySQL;
          /// dialects that differ should override the Power translation.
          /// </summary>
          protected static readonly Dictionary<string, string> operatorMap = new Dictionary<string, string>
          {
              [NodeType.Equal] = "=",
              [NodeType.NotEqual] = "!=",
              [NodeType.LessThan] = "<",
              [NodeType.LessThanOrEqual] = "<=",
              [NodeType.GreaterThan] = ">",
              [NodeType.GreaterThanOrEqual] = ">=",
              [nameof(ExpressionType.Divide)] = "/",
              [nameof(ExpressionType.Modulo)] = "%",
              [nameof(ExpressionType.Multiply)] = "*",
              [nameof(ExpressionType.Power)] = "^",
              [nameof(ExpressionType.Subtract)] = "-",
          };

          #endregion

          // #0 Schema : PrepareCreate PrepareDrop
          /// <summary>Builds the DDL that creates the entity's table.</summary>
          public abstract string PrepareCreate(IEntityDescriptor entityDescriptor);
          /// <summary>Builds the DDL that drops the entity's table.</summary>
          public abstract string PrepareDrop(IEntityDescriptor entityDescriptor);

          #region #1 Create : PrepareAdd

          /// <summary>
          /// Decides how an entity should be inserted based on its key.
          /// </summary>
          /// <returns>
          /// <c>noKeyColumn</c> when the entity has no key, <c>keyWithValue</c> when the key holds a non-default value,
          /// <c>identityKey</c> when the key is empty but database-generated.
          /// </returns>
          /// <exception cref="ArgumentException">The key is empty and not an identity column.</exception>
          public virtual EAddType Entity_GetAddType(SqlTranslateArgument arg, object entity)
          {
              var key = arg.entityDescriptor.key;
              if (key == null) return EAddType.noKeyColumn;

              var keyValue = key.GetValue(entity);
              if (keyValue is not null && !keyValue.Equals(TypeUtil.DefaultValue(key.type))) return EAddType.keyWithValue;
              if (key.isIdentity) return EAddType.identityKey;

              throw new ArgumentException("Key could not be empty.");
          }

          /// <summary>
          /// Builds an insert statement for the given columns, e.g.
          /// <c>insert into "user"("name","fatherId") values(@name,@fatherId);</c>
          /// </summary>
          /// <returns>The SQL text and a function that reads the parameter values (keyed by column name) from an entity.</returns>
          protected virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg, IColumnDescriptor[] columns)
          {
              var entityDescriptor = arg.entityDescriptor;

              // #1 GetSqlParams
              Dictionary<string, object> GetSqlParams(object entity)
              {
                  var sqlParam = new Dictionary<string, object>();
                  foreach (var column in columns)
                  {
                      sqlParam[column.columnName] = column.GetValue(entity);
                  }
                  return sqlParam;
              }

              // #2 columns
              var columnNames = new List<string>();
              var valueParams = new List<string>();
              foreach (var column in columns)
              {
                  columnNames.Add(DelimitIdentifier(column.columnName));
                  valueParams.Add(GenerateParameterName(column.columnName));
              }

              // #3 build sql
              string sql = $@"insert into {DelimitTableName(entityDescriptor)}({string.Join(",", columnNames)}) values({string.Join(",", valueParams)});";
              return (sql, GetSqlParams);
          }

          /// <summary>Builds an insert statement covering every column of the entity.</summary>
          public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
          {
              return PrepareAdd(arg, arg.entityDescriptor.allColumns);
          }

          /// <summary>Builds an insert for identity-key entities; not implemented in the base class.</summary>
          public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareIdentityAdd(SqlTranslateArgument arg) => throw new NotImplementedException();

          #endregion

          #region #2 Retrieve : PrepareGet PrepareQuery

          /// <summary>
          /// Builds a select-by-key statement: <c>select * from "user" where "id"=@id;</c>
          /// </summary>
          public virtual string PrepareGet(SqlTranslateArgument arg)
          {
              var entityDescriptor = arg.entityDescriptor;
              string sql = $@"select * from {DelimitTableName(entityDescriptor)} where {DelimitIdentifier(entityDescriptor.keyName)}={GenerateParameterName(entityDescriptor.keyName)};";
              return sql;
          }

          /// <summary>Builder used for select and count queries.</summary>
          protected abstract BaseQueryTranslateService queryTranslateService { get; }

          /// <summary>Translates a query stream; parameters and the data reader are taken from <paramref name="arg"/>.</summary>
          public virtual (string sql, Dictionary<string, object> sqlParam, IDbDataReader dataReader) PrepareQuery(QueryTranslateArgument arg, CombinedStream combinedStream)
          {
              string sql = queryTranslateService.BuildQuery(arg, combinedStream);
              return (sql, arg.sqlParam, arg.dataReader);
          }

          /// <summary>Translates a query stream into a count query.</summary>
          public virtual (string sql, Dictionary<string, object> sqlParam) PrepareCountQuery(QueryTranslateArgument arg, CombinedStream combinedStream)
          {
              string sql = queryTranslateService.BuildCountQuery(arg, combinedStream);
              return (sql, arg.sqlParam);
          }

          #endregion

          #region #3 Update: PrepareUpdate PrepareExecuteUpdate

          /// <summary>
          /// Builds an update-by-key statement, e.g. <c>update "user" set "name"=@name where "id"=@id;</c>
          /// </summary>
          /// <returns>The SQL text and a function that reads every column value (keyed by column name) from an entity.</returns>
          public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareUpdate(SqlTranslateArgument arg)
          {
              var entityDescriptor = arg.entityDescriptor;

              // #1 GetSqlParams : values of allColumns (presumably includes the key bound by the where clause)
              Dictionary<string, object> GetSqlParams(object entity)
              {
                  var sqlParam = new Dictionary<string, object>();
                  foreach (var column in entityDescriptor.allColumns)
                  {
                      sqlParam[column.columnName] = column.GetValue(entity);
                  }
                  return sqlParam;
              }

              // #2 set-clause columns (entityDescriptor.columns — presumably the non-key columns)
              var columnsToUpdate = new List<string>();
              foreach (var column in entityDescriptor.columns)
              {
                  var columnName = column.columnName;
                  columnsToUpdate.Add($"{DelimitIdentifier(columnName)}={GenerateParameterName(columnName)}");
              }

              // #3 build sql
              string sql = $@"update {DelimitTableName(entityDescriptor)} set {string.Join(",", columnsToUpdate)} where {DelimitIdentifier(entityDescriptor.keyName)}={GenerateParameterName(entityDescriptor.keyName)};";
              return (sql, GetSqlParams);
          }

          /// <summary>Builder used for set-based updates.</summary>
          protected abstract BaseQueryTranslateService executeUpdateTranslateService { get; }

          /// <summary>Translates a query stream into a set-based update statement.</summary>
          public virtual (string sql, Dictionary<string, object> sqlParam) PrepareExecuteUpdate(QueryTranslateArgument arg, CombinedStream combinedStream)
          {
              string sql = executeUpdateTranslateService.BuildQuery(arg, combinedStream);
              return (sql, arg.sqlParam);
          }

          #endregion

          #region #4 Delete: PrepareDelete PrepareDeleteRange PrepareExecuteDelete

          /// <summary>
          /// Builds a delete-by-key statement: <c>delete from "user" where "id"=@id ;</c>
          /// </summary>
          public virtual string PrepareDelete(SqlTranslateArgument arg)
          {
              var entityDescriptor = arg.entityDescriptor;
              string sql = $@"delete from {DelimitTableName(entityDescriptor)} where {DelimitIdentifier(entityDescriptor.keyName)}={GenerateParameterName(entityDescriptor.keyName)} ; ";
              return sql;
          }

          /// <summary>
          /// Builds <c>delete from "user" where "id" in (@p0,@p1,...);</c> for the given keys.
          /// With no keys the list is <c>(null)</c>, which matches no rows.
          /// </summary>
          /// <returns>The SQL text and the parameters p0, p1, ... holding the keys in enumeration order.</returns>
          public virtual (string sql, Dictionary<string, object> sqlParam) PrepareDeleteByKeys<Key>(SqlTranslateArgument arg, IEnumerable<Key> keys)
          {
              var entityDescriptor = arg.entityDescriptor;
              var sql = new StringBuilder();
              Dictionary<string, object> sqlParam = new();

              sql.Append("delete from ").Append(DelimitTableName(entityDescriptor)).Append(" where ").Append(DelimitIdentifier(entityDescriptor.keyName)).Append(" in (");

              int keyIndex = 0;
              foreach (var key in keys)
              {
                  var paramName = "p" + (keyIndex++);
                  sql.Append(GenerateParameterName(paramName)).Append(",");
                  sqlParam[paramName] = key;
              }

              if (keyIndex == 0)
              {
                  sql.Append("null);");
              }
              else
              {
                  // drop the trailing comma
                  sql.Length--;
                  sql.Append(");");
              }
              return (sql.ToString(), sqlParam);
          }

          /// <summary>Builder used for set-based deletes.</summary>
          protected abstract BaseQueryTranslateService executeDeleteTranslateService { get; }

          /// <summary>Translates a query stream into a set-based delete statement.</summary>
          public virtual (string sql, Dictionary<string, object> sqlParam) PrepareExecuteDelete(QueryTranslateArgument arg, CombinedStream combinedStream)
          {
              string sql = executeDeleteTranslateService.BuildQuery(arg, combinedStream);
              return (sql, arg.sqlParam);
          }

          #endregion
      }
  }