Skip to content

Commit

Permalink
add saveAll
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 19, 2024
1 parent 1107671 commit 83c0940
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ public Object intercept(Invocation invocation) throws Throwable {
Collection<Object> objects;
if (parameter instanceof MapperMethod.ParamMap) {
MapperMethod.ParamMap<Object> paramMap = ((MapperMethod.ParamMap<Object>) parameter);
if (paramMap.get("param1") instanceof Collection) {
if (!paramMap.containsKey("param1") && paramMap.containsKey("arg0")) {
objects = ((Collection<Object>) paramMap.get("arg0"));
} else if (paramMap.get("param1") instanceof Collection) {
objects = ((Collection<Object>) paramMap.get("param1"));
} else {
objects = Collections.singletonList(paramMap.get("param1"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ public interface BaseDao<Entity> {
@InsertProvider(type = BaseSqlProvider.class, method = "insert")
int save(Entity entity);

/**
* Insert all of the entity.
*/
@Options(useGeneratedKeys = true, keyProperty = "id", keyColumn = "id")
@InsertProvider(type = BaseSqlProvider.class, method = "insertList")
int saveAll(List<Entity> entities);

/**
* Partially update the entity by primary key.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.Collection;
import java.util.List;

@Slf4j
public class BaseSqlProvider {
Expand All @@ -46,6 +47,19 @@ public <Entity> String insert(Entity entity, ProviderContext context) {
return SQLBuilder.insert(tableMetaData, entity, databaseId);
}

public <Entity> String insertList(List<Entity> entities, ProviderContext context) {
Assert.notNull(entities, "entities must not be null");
Assert.notEmpty(entities, "entities list must not be empty");

String databaseId = context.getDatabaseId();

Class<?> entityClass = entities.get(0).getClass();

TableMetaData tableMetaData = TableMetaData.forClass(entityClass);

return SQLBuilder.insertList(tableMetaData, entities, databaseId);
}

public <Entity> String partialUpdateById(Entity entity, ProviderContext context) {
Assert.notNull(entity, "entity must not null");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -98,6 +99,99 @@ public static <Entity> String insert(TableMetaData tableMetaData, Entity entity,
return sql.toString();
}

public static <Entity> String insertList(TableMetaData tableMetaData, List<Entity> entities, String databaseId) {
if (entities == null || entities.isEmpty()) {
throw new IllegalArgumentException("Entities list must not be null or empty");
}

Class<?> entityClass = entities.get(0).getClass();
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();

SQL sql = new SQL();

switch (DBType.toType(databaseId)) {
case MYSQL: {
sql.INSERT_INTO(tableMetaData.getTableName());

boolean firstRow = true;
List<String> columns = new ArrayList<>();
int idx = 0;
for (Entity entity : entities) {
List<String> values = new ArrayList<>();
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (!ObjectUtils.isEmpty(value)) {
if (firstRow) {
sql.VALUES(
"`" + entry.getValue() + "`",
getTokenParam("arg0[" + idx + "]." + entry.getKey()));
}
values.add(getTokenParam("arg0[" + idx + "]." + entry.getKey()));
}
}
if (firstRow) {
firstRow = false;
} else {
sql.ADD_ROW();
sql.INTO_VALUES(values.toArray(new String[0]));
}
idx++;
}
break;
}
case POSTGRESQL: {
sql.INSERT_INTO("\"" + tableMetaData.getTableName() + "\"");

boolean firstRow = true;
List<String> columns = new ArrayList<>();
int idx = 0;
for (Entity entity : entities) {
List<String> values = new ArrayList<>();
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (!ObjectUtils.isEmpty(value)) {
if (firstRow) {
sql.VALUES(
"\"" + entry.getValue() + "\"",
getTokenParam("arg0[" + idx + "]." + entry.getKey()));
}
values.add(getTokenParam("arg0[" + idx + "]." + entry.getKey()));
}
}
if (firstRow) {
firstRow = false;
} else {
sql.ADD_ROW();
sql.INTO_VALUES(values.toArray(new String[0]));
}
idx++;
}
}

default: {
log.error("Unsupported data source");
}
}

return sql.toString();
}

public static <Entity> String partialUpdate(TableMetaData tableMetaData, Entity entity, String databaseId) {
Class<?> entityClass = entity.getClass();
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();
Expand Down Expand Up @@ -165,6 +259,9 @@ public static <Entity> String update(TableMetaData tableMetaData, Entity entity,
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field != null) {
Expand All @@ -187,6 +284,9 @@ public static <Entity> String update(TableMetaData tableMetaData, Entity entity,
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field != null) {
Expand Down

0 comments on commit 83c0940

Please sign in to comment.