SqlDbSet.cs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Vit.Linq.ExpressionTree.CollectionsQuery;
  5. using Vit.Linq.ExpressionTree.ComponentModel;
  6. using System.Linq.Expressions;
  7. using Vit.Orm.Entity;
  8. using System.Reflection;
  9. using Vit.Linq;
  10. using Vit.Orm.Sql.DataReader;
  11. using Vit.Extensions.Linq_Extensions;
  12. namespace Vit.Orm.Sql
  13. {
  /// <summary>
  /// Factory for <see cref="SqlDbSet{Entity}"/> instances, including the case where the
  /// entity type is only available as a runtime <see cref="Type"/>.
  /// </summary>
  public class SqlDbSetConstructor
  {
      // Open generic definition of CreateDbSet<Entity>, obtained once by wrapping the closed
      // CreateDbSet<object> method in a delegate and asking for its generic definition.
      static MethodInfo _CreateDbSet = new Func<SqlDbContext, IEntityDescriptor, IDbSet>(CreateDbSet<object>)
          .Method.GetGenericMethodDefinition();

      /// <summary>
      /// Creates the db set for <paramref name="entityType"/> by closing the generic
      /// <c>CreateDbSet&lt;Entity&gt;</c> over that type and invoking it through reflection.
      /// </summary>
      /// <param name="dbContext">Context handed to the new db set.</param>
      /// <param name="entityType">Runtime entity type used as the generic argument.</param>
      /// <param name="entityDescriptor">Descriptor handed to the new db set.</param>
      /// <returns>The invocation result viewed as <see cref="IDbSet"/> (null if it is not one).</returns>
      public static IDbSet CreateDbSet(SqlDbContext dbContext, Type entityType, IEntityDescriptor entityDescriptor)
      {
          MethodInfo closedFactory = _CreateDbSet.MakeGenericMethod(entityType);
          object[] factoryArguments = { dbContext, entityDescriptor };
          object createdSet = closedFactory.Invoke(null, factoryArguments);
          return createdSet as IDbSet;
      }

      /// <summary>
      /// Creates a <see cref="SqlDbSet{Entity}"/> for a statically known entity type.
      /// </summary>
      /// <param name="dbContext">Context handed to the new db set.</param>
      /// <param name="entityDescriptor">Descriptor handed to the new db set.</param>
      /// <returns>A new <see cref="SqlDbSet{Entity}"/>.</returns>
      public static IDbSet CreateDbSet<Entity>(SqlDbContext dbContext, IEntityDescriptor entityDescriptor)
          => new SqlDbSet<Entity>(dbContext, entityDescriptor);
  }
  28. public class SqlDbSet<Entity> : Vit.Orm.DbSet<Entity>
  29. {
  /// <summary>Context this set executes its SQL through.</summary>
  protected SqlDbContext dbContext;

  /// <summary>Backing field for <see cref="entityDescriptor"/>.</summary>
  protected IEntityDescriptor _entityDescriptor;

  /// <summary>Descriptor that was supplied to the constructor.</summary>
  public override IEntityDescriptor entityDescriptor
  {
      get { return _entityDescriptor; }
  }

  /// <summary>Translator read from <see cref="dbContext"/> on every access.</summary>
  public virtual ISqlTranslator sqlTranslator
  {
      get { return dbContext.sqlTranslator; }
  }

  /// <summary>
  /// Stores the context and the entity descriptor; no other work is done here.
  /// </summary>
  /// <param name="dbContext">Context used for all later SQL execution.</param>
  /// <param name="entityDescriptor">Descriptor exposed through <see cref="entityDescriptor"/>.</param>
  public SqlDbSet(SqlDbContext dbContext, IEntityDescriptor entityDescriptor)
  {
      _entityDescriptor = entityDescriptor;
      this.dbContext = dbContext;
  }
  /// <summary>
  /// Executes the statement that <c>sqlTranslator.PrepareCreate</c> produces for this
  /// set's entity descriptor. The value returned by <c>Execute</c> is ignored.
  /// </summary>
  public override void Create()
      => dbContext.Execute(sql: sqlTranslator.PrepareCreate(entityDescriptor));
  /// <summary>
  /// Inserts a single entity using the statement prepared by <c>sqlTranslator.PrepareAdd</c>.
  /// </summary>
  /// <param name="entity">Entity whose values are turned into SQL parameters.</param>
  /// <returns>
  /// <paramref name="entity"/> when exactly one row was reported as affected;
  /// otherwise the default value of <typeparamref name="Entity"/>.
  /// </returns>
  public override Entity Add(Entity entity)
  {
      // #1 prepare the statement and the parameter builder
      (string insertSql, Func<Entity, Dictionary<string, object>> buildParams) = sqlTranslator.PrepareAdd(this);

      // #2 build the parameters for this entity
      var insertParams = buildParams(entity);

      // #3 execute and check the affected row count
      var insertedRows = dbContext.Execute(sql: insertSql, param: (object)insertParams);
      if (insertedRows == 1)
      {
          return entity;
      }
      return default;
  }
  /// <summary>
  /// Inserts every entity of <paramref name="entitys"/>, one statement execution per entity,
  /// reusing a single statement prepared by <c>sqlTranslator.PrepareAdd</c>.
  /// </summary>
  /// <param name="entitys">Entities to insert; enumerated once.</param>
  /// <remarks>
  /// The method returns nothing, so the per-row result of <c>Execute</c> is not inspected.
  /// </remarks>
  public override void AddRange(IEnumerable<Entity> entitys)
  {
      // #1 prepare sql (once for the whole batch)
      (string sql, Func<Entity, Dictionary<string, object>> GetSqlParams) = sqlTranslator.PrepareAdd(this);

      // #2 execute once per entity
      foreach (var entity in entitys)
      {
          var sqlParam = GetSqlParams(entity);
          dbContext.Execute(sql: sql, param: (object)sqlParam);
      }
  }
  /// <summary>
  /// Loads one entity by key using the statement prepared by <c>sqlTranslator.PrepareGet</c>.
  /// </summary>
  /// <param name="keyValue">Value bound to the parameter named <c>entityDescriptor.keyName</c>.</param>
  /// <returns>
  /// A new <typeparamref name="Entity"/> populated from the first row, or the default value
  /// of <typeparamref name="Entity"/> when the reader yields no row.
  /// </returns>
  public override Entity Get(object keyValue)
  {
      // #1 prepare sql
      string selectSql = sqlTranslator.PrepareGet(this);

      // #2 the key is the only parameter
      var keyParam = new Dictionary<string, object> { [entityDescriptor.keyName] = keyValue };

      // #3 execute and map the first row, if any
      using (var reader = dbContext.ExecuteReader(sql: selectSql, param: (object)keyParam))
      {
          if (!reader.Read())
          {
              return default;
          }

          var loaded = (Entity)Activator.CreateInstance(typeof(Entity));
          foreach (var column in entityDescriptor.allColumns)
          {
              // read the raw column value, convert it to the column's type, then assign it
              var rawValue = reader[column.name];
              column.Set(loaded, TypeUtil.ConvertToType(rawValue, column.type));
          }
          return loaded;
      }
  }
  /// <summary>
  /// Builds an <see cref="IQueryable{T}"/> whose executor converts the LINQ expression into
  /// an expression node, reads it into a stream, has <c>sqlTranslator</c> prepare the SQL and
  /// runs it through <c>dbContext</c>.
  /// </summary>
  /// <returns>The queryable produced by <c>QueryableBuilder.Build</c>.</returns>
  /// <remarks>
  /// What the executor returns depends on the stream:
  /// a <c>StreamToUpdate</c> and <c>ExecuteDelete</c> return the result of <c>Execute</c>;
  /// <c>ToExecuteString</c> returns the SQL text without executing it;
  /// <c>Count</c> returns the scalar converted with <c>Convert.ToInt32</c>;
  /// <c>First</c>/<c>FirstOrDefault</c>/<c>Last</c>/<c>LastOrDefault</c>/<c>ToList</c> (and an
  /// empty or null method) return <c>dataReader.ReadData(reader)</c>.
  /// Any other method name makes the executor throw <see cref="NotSupportedException"/>.
  /// </remarks>
  public override IQueryable<Entity> Query()
  {
      // Identifier derived from the context's hash code; it is passed to both
      // QueryTypeNameCompare and QueryableBuilder.Build.
      var dbContextId = "SqlDbSet_" + dbContext.GetHashCode();

      // The second lambda parameter (type) is part of the delegate signature but is not used here.
      Func<Expression, Type, object> QueryExecutor = (expression, type) =>
      {
          // #1 convert to ExpressionNode
          // (query) => query.Where().OrderBy().Skip().Take().Select().ToList();
          // (users) => users.SelectMany(
          //      user => users.Where(father => (father.id == user.fatherId)).DefaultIfEmpty(),
          //      (user, father) => new <>f__AnonymousType4`2(user = user, father = father)
          //  ).Where().Select();
          var isArgument = QueryableBuilder.QueryTypeNameCompare(dbContextId);
          ExpressionNode node = dbContext.convertService.ConvertToData(expression, autoReduce: true, isArgument: isArgument);

          // #2 convert to Streams
          // {select,left,joins,where,order,skip,take}
          var stream = StreamReader.ReadNode(node);

          // #3.1 ExecuteUpdate
          if (stream is StreamToUpdate streamToUpdate)
          {
              (string sql, Dictionary<string, object> sqlParam) = sqlTranslator.PrepareExecuteUpdate(streamToUpdate);
              return dbContext.Execute(sql: sql, param: (object)sqlParam);
          }

          // #3.2 Query
          // #3.2.1 wrap a non-combined stream so the code below handles a single shape
          var combinedStream = stream as CombinedStream;
          if (combinedStream == null) combinedStream = new CombinedStream("tmp") { source = stream };

          // #3.2.2 execute and read result
          switch (combinedStream.method)
          {
              case nameof(Queryable_Extensions.ToExecuteString):
                  {
                      // ToExecuteString: only the SQL text is needed, nothing is executed
                      (string sql, Dictionary<string, object> _, IDbDataReader _) = sqlTranslator.PrepareQuery(combinedStream, entityType: null);
                      return sql;
                  }
              case "Count":
                  {
                      // Count: the scalar result is used, so the prepared data reader is discarded
                      (string sql, Dictionary<string, object> sqlParam, IDbDataReader _) = sqlTranslator.PrepareQuery(combinedStream, entityType: null);
                      var count = dbContext.ExecuteScalar(sql: sql, param: (object)sqlParam);
                      return Convert.ToInt32(count);
                  }
              case nameof(Queryable_Extensions.ExecuteDelete):
                  {
                      // ExecuteDelete
                      (string sql, Dictionary<string, object> sqlParam) = sqlTranslator.PrepareExecuteDelete(combinedStream);
                      var count = dbContext.Execute(sql: sql, param: (object)sqlParam);
                      return count;
                  }
              case "FirstOrDefault" or "First" or "LastOrDefault" or "Last":
                  {
                      // single-element result: the expression type itself is passed as the entity type
                      var entityType = expression.Type;
                      (string sql, Dictionary<string, object> sqlParam, IDbDataReader dataReader) = sqlTranslator.PrepareQuery(combinedStream, entityType);
                      using var reader = dbContext.ExecuteReader(sql: sql, param: (object)sqlParam);
                      return dataReader.ReadData(reader);
                  }
              case "ToList":
              case "":
              case null:
                  {
                      // ToList: the first generic argument of the expression type is passed as the entity type.
                      // GetGenericArguments() yields an empty array (never null) for non-generic types,
                      // so FirstOrDefault() alone covers the "no argument" case.
                      var entityType = expression.Type.GetGenericArguments().FirstOrDefault();
                      (string sql, Dictionary<string, object> sqlParam, IDbDataReader dataReader) = sqlTranslator.PrepareQuery(combinedStream, entityType);
                      using var reader = dbContext.ExecuteReader(sql: sql, param: (object)sqlParam);
                      return dataReader.ReadData(reader);
                  }
          }
          throw new NotSupportedException("not supported query type: " + combinedStream.method);
      };
      return QueryableBuilder.Build<Entity>(QueryExecutor, dbContextId);
  }
  /// <summary>
  /// Updates a single entity using the statement prepared by <c>sqlTranslator.PrepareUpdate</c>.
  /// </summary>
  /// <param name="entity">Entity whose values are turned into SQL parameters.</param>
  /// <returns>The value returned by <c>dbContext.Execute</c> for the statement.</returns>
  public override int Update(Entity entity)
  {
      // #1 prepare the statement and the parameter builder
      (string updateSql, Func<Entity, Dictionary<string, object>> buildParams) = sqlTranslator.PrepareUpdate(this);

      // #2 build the parameters and execute in one step
      return dbContext.Execute(sql: updateSql, param: (object)buildParams(entity));
  }
  /// <summary>
  /// Updates every entity of <paramref name="entitys"/>, one statement execution per entity,
  /// reusing a single statement prepared by <c>sqlTranslator.PrepareUpdate</c>.
  /// </summary>
  /// <param name="entitys">Entities to update; enumerated once.</param>
  /// <returns>The sum of the values returned by <c>dbContext.Execute</c> for each entity.</returns>
  public override int UpdateRange(IEnumerable<Entity> entitys)
  {
      // #1 prepare the statement once for the whole batch
      (string updateSql, Func<Entity, Dictionary<string, object>> buildParams) = sqlTranslator.PrepareUpdate(this);

      // #2 execute per entity and accumulate the reported row counts
      var totalAffected = 0;
      foreach (Entity current in entitys)
      {
          totalAffected += dbContext.Execute(sql: updateSql, param: (object)buildParams(current));
      }
      return totalAffected;
  }
  /// <summary>
  /// Deletes the row for <paramref name="entity"/> by reading its key through
  /// <c>entityDescriptor.key</c> and delegating to <see cref="DeleteByKey"/>.
  /// </summary>
  /// <param name="entity">Entity whose key identifies the row to delete.</param>
  /// <returns>The value returned by <see cref="DeleteByKey"/>.</returns>
  public override int Delete(Entity entity)
      => DeleteByKey(entityDescriptor.key.Get(entity));
  /// <summary>
  /// Deletes the rows for all <paramref name="entitys"/> by collecting their keys into a list
  /// and delegating to <c>DeleteByKeys</c>.
  /// </summary>
  /// <param name="entitys">Entities whose keys identify the rows to delete; enumerated once.</param>
  /// <returns>The value returned by <c>DeleteByKeys</c>.</returns>
  public override int DeleteRange(IEnumerable<Entity> entitys)
  {
      // materialize the keys up front so DeleteByKeys receives a concrete list
      var collectedKeys = (from entity in entitys
                           select entityDescriptor.key.Get(entity)).ToList();
      return DeleteByKeys(collectedKeys);
  }
  /// <summary>
  /// Deletes by a single key using the statement prepared by <c>sqlTranslator.PrepareDelete</c>.
  /// </summary>
  /// <param name="keyValue">Value bound to the parameter named <c>entityDescriptor.keyName</c>.</param>
  /// <returns>The value returned by <c>dbContext.Execute</c> for the statement.</returns>
  public override int DeleteByKey(object keyValue)
  {
      // #1 prepare sql
      string deleteSql = sqlTranslator.PrepareDelete(this);

      // #2 the key is the only parameter
      var keyParam = new Dictionary<string, object> { [entityDescriptor.keyName] = keyValue };

      // #3 execute
      return dbContext.Execute(sql: deleteSql, param: (object)keyParam);
  }
  /// <summary>
  /// Deletes by a collection of keys using the statement prepared by
  /// <c>sqlTranslator.PrepareDeleteRange</c>.
  /// </summary>
  /// <param name="keys">Key collection bound, as a whole, to the parameter named <c>keys</c>.</param>
  /// <returns>The value returned by <c>dbContext.Execute</c> for the statement.</returns>
  public override int DeleteByKeys<Key>(IEnumerable<Key> keys)
  {
      // #1 prepare sql
      string deleteSql = sqlTranslator.PrepareDeleteRange(this);

      // #2 the key collection is passed as a single parameter
      var keysParam = new Dictionary<string, object> { ["keys"] = keys };

      // #3 execute
      return dbContext.Execute(sql: deleteSql, param: (object)keysParam);
  }
  215. }
  216. }