Lith 10 місяців тому
батько
коміт
039f38538c

+ 5 - 40
src/Vitorm.MySql/SqlTranslateService.cs

@@ -216,49 +216,14 @@ CREATE TABLE {DelimitTableName(entityDescriptor)} (
         }
         #endregion
 
-
-        public override (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
+        public override (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareIdentityAdd(SqlTranslateArgument arg)
         {
-            /* //sql
-             insert into user(name,birth,fatherId,motherId) values('','','');
-             select seq from sqlite_sequence where name='user';
-              */
-            var entityDescriptor = arg.entityDescriptor;
-
-            var columns = entityDescriptor.columns;
+            var result = PrepareAdd(arg, arg.entityDescriptor.columns);
 
-            // #1 GetSqlParams 
-            Func<object, Dictionary<string, object>> GetSqlParams = (entity) =>
-            {
-                var sqlParam = new Dictionary<string, object>();
-                foreach (var column in columns)
-                {
-                    var columnName = column.name;
-                    var value = column.GetValue(entity);
-
-                    sqlParam[columnName] = value;
-                }
-                return sqlParam;
-            };
-
-            #region #2 columns 
-            List<string> columnNames = new List<string>();
-            List<string> valueParams = new List<string>();
-            string columnName;
-
-            foreach (var column in columns)
-            {
-                columnName = column.name;
-
-                columnNames.Add(DelimitIdentifier(columnName));
-                valueParams.Add(GenerateParameterName(columnName));
-            }
-            #endregion
+            // get generated id
+            result.sql += "select last_insert_id();";
 
-            // #3 build sql
-            string sql = $@"insert into {DelimitTableName(entityDescriptor)}({string.Join(",", columnNames)}) values({string.Join(",", valueParams)});";
-            sql += "select last_insert_id();";
-            return (sql, GetSqlParams);
+            return result;
         }
 
         public override (string sql, Dictionary<string, object> sqlParam, IDbDataReader dataReader) PrepareQuery(QueryTranslateArgument arg, CombinedStream combinedStream)

+ 21 - 2
src/Vitorm.SqlServer/SqlTranslateService.cs

@@ -226,9 +226,28 @@ CREATE TABLE {DelimitTableName(entityDescriptor)} (
 
 
 
-        public override (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
+        public override EAddType Entity_GetAddType(SqlTranslateArgument arg, object entity)
         {
-            var result = base.PrepareAdd(arg);
+            var key = arg.entityDescriptor.key;
+            if (key == null) return EAddType.noKeyColumn;
+
+            var keyValue = key.GetValue(entity);
+            var keyIsEmpty = keyValue is null || keyValue.Equals(TypeUtil.DefaultValue(arg.entityDescriptor.key.type));
+
+            var keyIsIdentity = key.databaseGenerated == System.ComponentModel.DataAnnotations.Schema.DatabaseGeneratedOption.Identity;
+
+            if (keyIsIdentity)
+            {
+                return keyIsEmpty ? EAddType.identityKey : throw new ArgumentException("Cannot insert explicit value for identity column.");
+            }
+            else
+            {
+                return !keyIsEmpty ? EAddType.keyWithValue : throw new ArgumentException("Key could not be empty.");
+            }
+        }
+        public override (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareIdentityAdd(SqlTranslateArgument arg)
+        {
+            var result = PrepareAdd(arg, arg.entityDescriptor.columns);
 
             // get generated id
             result.sql += "select convert(int,isnull(SCOPE_IDENTITY(),-1));";

+ 0 - 45
src/Vitorm.Sqlite/SqlTranslateService.cs

@@ -192,51 +192,6 @@ CREATE TABLE {DelimitTableName(entityDescriptor)} (
         #endregion
 
 
-        public override (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
-        {
-            /* //sql
-             insert into user(name,birth,fatherId,motherId) values('','','');
-             select seq from sqlite_sequence where name='user';
-              */
-            var entityDescriptor = arg.entityDescriptor;
-
-            var columns = entityDescriptor.allColumns;
-
-            // #1 GetSqlParams 
-            Func<object, Dictionary<string, object>> GetSqlParams = (entity) =>
-            {
-                var sqlParam = new Dictionary<string, object>();
-                foreach (var column in columns)
-                {
-                    var columnName = column.name;
-                    var value = column.GetValue(entity);
-
-                    sqlParam[columnName] = value;
-                }
-                return sqlParam;
-            };
-
-            #region #2 columns 
-            List<string> columnNames = new List<string>();
-            List<string> valueParams = new List<string>();
-            string columnName;
-
-            foreach (var column in columns)
-            {
-                columnName = column.name;
-
-                columnNames.Add(DelimitIdentifier(columnName));
-                valueParams.Add(GenerateParameterName(columnName));
-            }
-            #endregion
-
-            // #3 build sql
-            string sql = $@"insert into {DelimitTableName(entityDescriptor)}({string.Join(",", columnNames)}) values({string.Join(",", valueParams)});";
-            //sql+=$"select seq from sqlite_sequence where name = '{tableName}'; ";
-            sql += "select null;";
-            return (sql, GetSqlParams);
-        }
-
         public override (string sql, Dictionary<string, object> sqlParam, IDbDataReader dataReader) PrepareQuery(QueryTranslateArgument arg, CombinedStream combinedStream)
         {
             string sql = queryTranslateService.BuildQuery(arg, combinedStream);

+ 64 - 25
src/Vitorm/Sql/SqlDbContext.cs

@@ -9,6 +9,7 @@ using Vitorm.Sql.Transaction;
 using Vitorm.Sql.SqlTranslate;
 using Vitorm.StreamQuery;
 using Vit.Extensions.Vitorm_Extensions;
+using Vitorm.Entity;
 
 namespace Vitorm.Sql
 {
@@ -87,17 +88,25 @@ namespace Vitorm.Sql
             var entityDescriptor = GetEntityDescriptor(typeof(Entity));
             SqlTranslateArgument arg = new SqlTranslateArgument(this, entityDescriptor);
 
-            // #1 prepare sql
-            (string sql, Func<object, Dictionary<string, object>> GetSqlParams) = sqlTranslateService.PrepareAdd(arg);
+            var addType = sqlTranslateService.Entity_GetAddType(arg,entity);
+            //if (addType == EAddType.unexpectedEmptyKey) throw new ArgumentException("Key could not be empty.");
 
-            // #2 get sql params
-            var sqlParam = GetSqlParams(entity);
 
-            // #3 execute
-            if (entityDescriptor.key.databaseGenerated == System.ComponentModel.DataAnnotations.Schema.DatabaseGeneratedOption.Identity)
+
+
+            if (addType == EAddType.identityKey)
             {
-                var keyType = TypeUtil.GetUnderlyingType(entityDescriptor.key.type);
+                // #1 prepare sql
+                (string sql, Func<object, Dictionary<string, object>> GetSqlParams) = sqlTranslateService.PrepareIdentityAdd(arg);
+
+                // #2 get sql params
+                var sqlParam = GetSqlParams(entity);
+
+                // #3 add
                 var newKeyValue = ExecuteScalar(sql: sql, param: sqlParam);
+
+                // #4 set key value to entity
+                var keyType = TypeUtil.GetUnderlyingType(entityDescriptor.key.type);
                 newKeyValue = TypeUtil.ConvertToUnderlyingType(newKeyValue, keyType);
                 if (newKeyValue != null)
                 {
@@ -106,6 +115,13 @@ namespace Vitorm.Sql
             }
             else
             {
+                // #1 prepare sql
+                (string sql, Func<object, Dictionary<string, object>> GetSqlParams) = sqlTranslateService.PrepareAdd(arg);
+
+                // #2 get sql params
+                var sqlParam = GetSqlParams(entity);
+
+                // #3 add
                 Execute(sql: sql, param: sqlParam);
             }
 
@@ -116,38 +132,61 @@ namespace Vitorm.Sql
             // #0 get arg
             var entityDescriptor = GetEntityDescriptor(typeof(Entity));
             SqlTranslateArgument arg = new SqlTranslateArgument(this, entityDescriptor);
+            List<(Entity entity, EAddType addType)> entityAndTypes = entities.Select(entity => (entity, sqlTranslateService.Entity_GetAddType(arg, entity))).ToList();
+            //if (entityAndTypes.Any(row => row.addType == EAddType.unexpectedEmptyKey)) throw new ArgumentException("Key could not be empty.");
 
-            // #1 prepare sql
-            (string sql, Func<object, Dictionary<string, object>> GetSqlParams) = sqlTranslateService.PrepareAdd(arg);
 
-            // #2 execute
             var affectedRowCount = 0;
 
-            if (entityDescriptor.key.databaseGenerated == System.ComponentModel.DataAnnotations.Schema.DatabaseGeneratedOption.Identity)
+            // #2 keyWithValue
             {
-                var keyType = TypeUtil.GetUnderlyingType(entityDescriptor.key.type);
-                foreach (var entity in entities)
+                var rows = entityAndTypes.Where(row => row.addType == EAddType.keyWithValue);
+                if (rows.Any())
                 {
-                    var sqlParam = GetSqlParams(entity);
-                    var newKeyValue = ExecuteScalar(sql: sql, param: sqlParam);
-                    newKeyValue = TypeUtil.ConvertToUnderlyingType(newKeyValue, keyType);
-                    if (newKeyValue != null)
+                    // ##1 prepare sql
+                    (string sql, Func<object, Dictionary<string, object>> GetSqlParams) = sqlTranslateService.PrepareAdd(arg);
+
+                    foreach ((var entity, _) in rows)
                     {
-                        entityDescriptor.key.SetValue(entity, newKeyValue);
+                        // #2 get sql params
+                        var sqlParam = GetSqlParams(entity);
+
+                        // #3 add
+                        Execute(sql: sql, param: sqlParam);
+                        affectedRowCount++;
                     }
-                    affectedRowCount++;
                 }
             }
-            else
+
+            // #3 identityKey
             {
-                foreach (var entity in entities)
+                var rows = entityAndTypes.Where(row => row.addType == EAddType.identityKey);
+                if (rows.Any())
                 {
-                    var sqlParam = GetSqlParams(entity);
-                    Execute(sql: sql, param: sqlParam);
-                    affectedRowCount++;
+                    var keyType = TypeUtil.GetUnderlyingType(entityDescriptor.key.type);
+
+                    // ##1 prepare sql
+                    (string sql, Func<object, Dictionary<string, object>> GetSqlParams) = sqlTranslateService.PrepareIdentityAdd(arg);
+
+                    foreach ((var entity, _) in rows)
+                    {
+                        // ##2 get sql params
+                        var sqlParam = GetSqlParams(entity);
+
+                        // ##3 add
+                        var newKeyValue = ExecuteScalar(sql: sql, param: sqlParam);
+
+                        // ##4 set key value to entity
+                        newKeyValue = TypeUtil.ConvertToUnderlyingType(newKeyValue, keyType);
+                        if (newKeyValue != null)
+                        {
+                            entityDescriptor.key.SetValue(entity, newKeyValue);
+                        }
+
+                        affectedRowCount++;
+                    }
                 }
             }
-
         }
 
         #endregion

+ 24 - 0
src/Vitorm/Sql/SqlTranslate/EAddType.cs

@@ -0,0 +1,24 @@
+namespace Vitorm.Sql.SqlTranslate
+{
+    public enum EAddType
+    {
+        /// <summary>
+        /// no key column
+        /// </summary>
+        noKeyColumn,
+        /// <summary>
+        /// keyValue is not empty
+        /// </summary>
+        keyWithValue,
+        /// <summary>
+        /// not Identity && keyValue is empty
+        /// </summary>
+        unexpectedEmptyKey,
+        /// <summary>
+        /// Identity && keyValue is empty
+        /// </summary>
+        identityKey,
+
+        unexpectedType,
+    }
+}

+ 2 - 1
src/Vitorm/Sql/SqlTranslate/ISqlTranslateService.cs

@@ -37,8 +37,9 @@ namespace Vitorm.Sql.SqlTranslate
 
 
         // #1 Create :  PrepareAdd
-
+        EAddType Entity_GetAddType(SqlTranslateArgument arg, object entity);
         (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg);
+        (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareIdentityAdd(SqlTranslateArgument arg);
 
 
         // #2 Retrieve : PrepareGet PrepareQuery

+ 29 - 20
src/Vitorm/Sql/SqlTranslate/SqlTranslateService.cs

@@ -8,16 +8,13 @@ using Vitorm.StreamQuery;
 using System.Collections;
 using System.Text;
 using System.Linq.Expressions;
+using static Vitorm.Sql.SqlDbContext;
+using System.Data;
 
 namespace Vitorm.Sql.SqlTranslate
 {
     public abstract class SqlTranslateService : ISqlTranslateService
     {
-        public SqlTranslateService()
-        {
-        }
-
-
 
         #region DelimitIdentifier
         /// <summary>
@@ -296,15 +293,27 @@ namespace Vitorm.Sql.SqlTranslate
 
 
         #region #1 Create :  PrepareAdd
-        public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
+        public virtual EAddType Entity_GetAddType(SqlTranslateArgument arg, object entity)
+        {
+            var key = arg.entityDescriptor.key;
+            if (key == null) return EAddType.noKeyColumn;
+
+            var keyValue = key.GetValue(entity);
+            if (keyValue is not null && !keyValue.Equals(TypeUtil.DefaultValue(arg.entityDescriptor.key.type))) return EAddType.keyWithValue;
+
+            if (key.databaseGenerated == System.ComponentModel.DataAnnotations.Schema.DatabaseGeneratedOption.Identity) return EAddType.identityKey;
+
+            throw new ArgumentException("Key could not be empty.");
+            //return EAddType.unexpectedEmptyKey;
+        }
+
+        protected virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg, IColumnDescriptor[] columns)
         {
             /* //sql
-             insert into user(name,birth,fatherId,motherId) values('','','');
-             select seq from sqlite_sequence where name='user';
+             insert into user(name,fatherId,motherId) values('',0,0);
               */
-            var entityDescriptor = arg.entityDescriptor;
 
-            var columns = entityDescriptor.columns;
+            var entityDescriptor = arg.entityDescriptor;
 
             // #1 GetSqlParams 
             Func<object, Dictionary<string, object>> GetSqlParams = (entity) =>
@@ -312,10 +321,7 @@ namespace Vitorm.Sql.SqlTranslate
                 var sqlParam = new Dictionary<string, object>();
                 foreach (var column in columns)
                 {
-                    var columnName = column.name;
-                    var value = column.GetValue(entity);
-
-                    sqlParam[columnName] = value;
+                    sqlParam[column.name] = column.GetValue(entity);
                 }
                 return sqlParam;
             };
@@ -323,23 +329,26 @@ namespace Vitorm.Sql.SqlTranslate
             #region #2 columns 
             List<string> columnNames = new List<string>();
             List<string> valueParams = new List<string>();
-            string columnName;
 
             foreach (var column in columns)
             {
-                columnName = column.name;
-
-                columnNames.Add(DelimitIdentifier(columnName));
-                valueParams.Add(GenerateParameterName(columnName));
+                columnNames.Add(DelimitIdentifier(column.name));
+                valueParams.Add(GenerateParameterName(column.name));
             }
             #endregion
 
             // #3 build sql
             string sql = $@"insert into {DelimitTableName(entityDescriptor)}({string.Join(",", columnNames)}) values({string.Join(",", valueParams)});";
-            //sql+=$"select seq from sqlite_sequence where name = '{tableName}'; ";
 
             return (sql, GetSqlParams);
         }
+
+        public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareAdd(SqlTranslateArgument arg)
+        {
+            return PrepareAdd(arg, arg.entityDescriptor.allColumns);
+        }
+
+        public virtual (string sql, Func<object, Dictionary<string, object>> GetSqlParams) PrepareIdentityAdd(SqlTranslateArgument arg) => throw new NotImplementedException();
         #endregion
 
 

+ 1 - 0
test/Vitorm.MySql.MsTest/CustomTest/ExpressionTreeTest/Query_Test.cs

@@ -11,6 +11,7 @@ namespace Vitorm.MsTest.CustomTest
         public void TestQueryable()
         {
             var initUsers = ExpressionTester.GetSourceData();
+            initUsers.ForEach(u => u.id = 0);
 
             using var dbContext = DataSource.CreateDbContextForWriting();
             var dbSet = dbContext.DbSet<ExpressionTester.User>();

+ 3 - 3
test/Vitorm.MySql.MsTest/DataSource.cs

@@ -21,11 +21,11 @@ namespace Vitorm.MsTest
         public string test { get; set; }
 
 
-        public static User NewUser(int id) => new User { id = id, name = "testUser" + id };
+        public static User NewUser(int id, bool forAdd = false) => new User { id = forAdd ? 0 : id, name = "testUser" + id };
 
-        public static List<User> NewUsers(int startId, int count = 1)
+        public static List<User> NewUsers(int startId, int count = 1, bool forAdd = false)
         {
-            return Enumerable.Range(startId, count).Select(NewUser).ToList();
+            return Enumerable.Range(startId, count).Select(id => NewUser(id, forAdd)).ToList();
         }
     }
 

+ 1 - 0
test/Vitorm.SqlServer.MsTest/CustomTest/ExpressionTreeTest/Query_Test.cs

@@ -11,6 +11,7 @@ namespace Vitorm.MsTest.CustomTest
         public void TestQueryable()
         {
             var initUsers = ExpressionTester.GetSourceData();
+            initUsers.ForEach(u => u.id = 0);
 
             using var dbContext = DataSource.CreateDbContextForWriting();
             var dbSet = dbContext.DbSet<ExpressionTester.User>();

+ 3 - 3
test/Vitorm.SqlServer.MsTest/DataSource.cs

@@ -25,11 +25,11 @@ namespace Vitorm.MsTest
         [System.ComponentModel.DataAnnotations.Schema.NotMapped]
         public string test{ get; set; }
 
-        public static User NewUser(int id) => new User { id = id, name = "testUser" + id };
+        public static User NewUser(int id, bool forAdd = false) => new User { id = forAdd ? 0 : id, name = "testUser" + id };
 
-        public static List<User> NewUsers(int startId, int count = 1)
+        public static List<User> NewUsers(int startId, int count = 1, bool forAdd = false)
         {
-            return Enumerable.Range(startId, count).Select(NewUser).ToList();
+            return Enumerable.Range(startId, count).Select(id => NewUser(id, forAdd)).ToList();
         }
     }
 

+ 2 - 2
test/Vitorm.Sqlite.MsTest/CommonTest/CRUD_Test.cs

@@ -19,7 +19,7 @@ namespace Vitorm.MsTest.CommonTest
         {
             using var dbContext = CreateDbContext();
 
-            var newUserList = User.NewUsers(7, 4);
+            var newUserList = User.NewUsers(7, 4, forAdd: true);
 
 
             // #1 Add
@@ -76,7 +76,7 @@ namespace Vitorm.MsTest.CommonTest
 
             // assert
             {
-                var newUserList = User.NewUsers(4, 3);
+                var newUserList = User.NewUsers(4, 3, forAdd: false);
                 var userList = dbContext.Query<User>().Where(m => m.id >= 4).ToList();
                 Assert.AreEqual(newUserList.Count, userList.Count());
                 Assert.AreEqual(0, userList.Select(m => m.id).Except(newUserList.Select(m => m.id)).Count());

+ 3 - 3
test/Vitorm.Sqlite.MsTest/DataSource.cs

@@ -18,11 +18,11 @@ namespace Vitorm.MsTest
         public string test { get; set; }
 
 
-        public static User NewUser(int id) => new User { id = id, name = "testUser" + id };
+        public static User NewUser(int id, bool forAdd = false) => new User { id = id, name = "testUser" + id };
 
-        public static List<User> NewUsers(int startId, int count = 1)
+        public static List<User> NewUsers(int startId, int count = 1, bool forAdd = false)
         {
-            return Enumerable.Range(startId, count).Select(NewUser).ToList();
+            return Enumerable.Range(startId, count).Select(id => NewUser(id, forAdd)).ToList();
         }
     }