Skip to content

Commit

Permalink
add partial flag
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 19, 2024
1 parent 2a93a4c commit 157dd0e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ public interface BaseDao<Entity> {
*/
@UpdateProvider(type = BaseSqlProvider.class, method = "partialUpdateByIds")
int partialUpdateByIds(List<Entity> entities);
//
// /**
// * Fully update the entity by primary key.
// */
// @UpdateProvider(type = BaseSqlProvider.class, method = "updateByIds")
// int updateByIds(List<Entity> entities);

/**
* Fully update the entity by primary key.
*/
@UpdateProvider(type = BaseSqlProvider.class, method = "updateByIds")
int updateByIds(List<Entity> entities);

/**
* Query the entity by primary key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public <Entity> String partialUpdateById(Entity entity, ProviderContext context)
Class<?> entityClass = entity.getClass();
TableMetaData tableMetaData = TableMetaData.forClass(entityClass);

return SQLBuilder.partialUpdate(tableMetaData, entity, databaseId);
return SQLBuilder.update(tableMetaData, entity, databaseId, true);
}

public <Entity> String partialUpdateByIds(List<Entity> entities, ProviderContext context) {
Expand All @@ -81,7 +81,7 @@ public <Entity> String partialUpdateByIds(List<Entity> entities, ProviderContext

TableMetaData tableMetaData = TableMetaData.forClass(entityClass);

return SQLBuilder.partialUpdateList(tableMetaData, entities, databaseId);
return SQLBuilder.updateList(tableMetaData, entities, databaseId, true);
}

public <Entity> String updateById(Entity entity, ProviderContext context) {
Expand All @@ -92,7 +92,20 @@ public <Entity> String updateById(Entity entity, ProviderContext context) {
Class<?> entityClass = entity.getClass();
TableMetaData tableMetaData = TableMetaData.forClass(entityClass);

return SQLBuilder.update(tableMetaData, entity, databaseId);
return SQLBuilder.update(tableMetaData, entity, databaseId, false);
}

public <Entity> String updateByIds(List<Entity> entities, ProviderContext context) {
Assert.notNull(entities, "entities must not be null");
Assert.notEmpty(entities, "entities list must not be empty");

String databaseId = context.getDatabaseId();

Class<?> entityClass = entities.get(0).getClass();

TableMetaData tableMetaData = TableMetaData.forClass(entityClass);

return SQLBuilder.updateList(tableMetaData, entities, databaseId, false);
}

public String selectById(Serializable id, ProviderContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ public static <Entity> String insertList(TableMetaData tableMetaData, List<Entit
return sql.toString();
}

public static <Entity> String partialUpdate(TableMetaData tableMetaData, Entity entity, String databaseId) {
public static <Entity> String update(
TableMetaData tableMetaData, Entity entity, String databaseId, boolean partial) {
Class<?> entityClass = entity.getClass();
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();

Expand All @@ -211,59 +212,9 @@ public static <Entity> String partialUpdate(TableMetaData tableMetaData, Entity
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (!ObjectUtils.isEmpty(value)) {
sql.SET(getEquals(entry.getValue(), entry.getKey()));
}
}

sql.WHERE(getEquals(tableMetaData.getPkColumn(), tableMetaData.getPkProperty()));
break;
}
case POSTGRESQL: {
sql.UPDATE("\"" + tableMetaData.getTableName() + "\"");
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (!ObjectUtils.isEmpty(value)) {
sql.SET("\"" + getEquals(entry.getValue() + "\"", entry.getKey()));
}
}
sql.WHERE(getEquals(tableMetaData.getPkColumn(), tableMetaData.getPkProperty()));
break;
}
default: {
log.error("Unsupported data source");
}
}

return sql.toString();
}

public static <Entity> String update(TableMetaData tableMetaData, Entity entity, String databaseId) {
Class<?> entityClass = entity.getClass();
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();

SQL sql = new SQL();
switch (DBType.toType(databaseId)) {
case MYSQL: {
sql.UPDATE(tableMetaData.getTableName());
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
if (ObjectUtils.isEmpty(value) && partial) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field != null) {
Column column = field.getAnnotation(Column.class);
Expand All @@ -289,6 +240,9 @@ public static <Entity> String update(TableMetaData tableMetaData, Entity entity,
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (ObjectUtils.isEmpty(value) && partial) {
continue;
}
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field != null) {
Column column = field.getAnnotation(Column.class);
Expand All @@ -309,35 +263,25 @@ public static <Entity> String update(TableMetaData tableMetaData, Entity entity,
return sql.toString();
}

public static String escapeSingleQuote(String input) {
if (input != null) {
return input.replace("'", "''");
}
return null;
}

public static <Entity> String partialUpdateList(
TableMetaData tableMetaData, List<Entity> entities, String databaseId) {
public static <Entity> String updateList(
TableMetaData tableMetaData, List<Entity> entities, String databaseId, boolean partial) {
if (entities == null || entities.isEmpty()) {
throw new IllegalArgumentException("Entities list must not be null or empty");
}

Class<?> entityClass = entities.get(0).getClass();
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();

SQL sql = new SQL();
StringBuilder sqlBuilder = new StringBuilder();
switch (DBType.toType(databaseId)) {
case MYSQL: {
StringBuilder sqlBuilder = new StringBuilder();
sqlBuilder
.append("UPDATE ")
.append(tableMetaData.getTableName())
.append(" SET ");
Map<String, StringBuilder> setClauses = new LinkedHashMap<>();
String primaryKey = "id";
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
log.info("entry: {}", entry);
log.info("primaryKey: {}", tableMetaData.getPkProperty());
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
primaryKey = entry.getValue();
Expand All @@ -346,30 +290,122 @@ public static <Entity> String partialUpdateList(

StringBuilder caseClause = new StringBuilder();
caseClause.append(entry.getValue()).append(" = CASE ");
log.info(caseClause.toString());
for (Entity entity : entities) {
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}

Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
PropertyDescriptor pkPs =
BeanUtils.getPropertyDescriptor(entityClass, tableMetaData.getPkProperty());
if (pkPs == null || pkPs.getReadMethod() == null) {
continue;
}
Object pkValue = ReflectionUtils.invokeMethod(pkPs.getReadMethod(), entity);

if (!ObjectUtils.isEmpty(value)) {
caseClause
.append("WHEN ")
.append(primaryKey)
.append(" = '")
.append(pkValue)
.append("' THEN '")
.append(escapeSingleQuote(value.toString()))
.append("' ");
} else if (!partial) {
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field != null) {
Column column = field.getAnnotation(Column.class);
if (column != null && !column.nullable() && value == null) {
continue;
}
}
caseClause
.append("WHEN ")
.append(primaryKey)
.append(" = '")
.append(pkValue)
.append("' THEN NULL ");
}
}
caseClause.append("ELSE ").append(entry.getValue()).append(" ");
caseClause.append("END");
setClauses.put(entry.getValue(), caseClause);
}
sqlBuilder.append(String.join(", ", setClauses.values()));

sqlBuilder.append(" WHERE ").append(primaryKey).append(" IN (");
String pkValues = entities.stream()
.map(entity -> {
PropertyDescriptor pkPs =
BeanUtils.getPropertyDescriptor(entityClass, tableMetaData.getPkProperty());
Object pkValue = ReflectionUtils.invokeMethod(pkPs.getReadMethod(), entity);
return "'" + pkValue.toString() + "'";
})
.collect(Collectors.joining(", "));

sqlBuilder.append(pkValues).append(")");
break;
}
case POSTGRESQL: {
sqlBuilder
.append("UPDATE ")
.append("\"")
.append(tableMetaData.getTableName())
.append("\"")
.append(" SET ");
Map<String, StringBuilder> setClauses = new LinkedHashMap<>();
String primaryKey = "\"id\"";
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
primaryKey = "\"" + entry.getValue() + "\"";
continue;
}

StringBuilder caseClause = new StringBuilder();
caseClause.append(entry.getValue()).append(" = CASE ");

for (Entity entity : entities) {
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
PropertyDescriptor pkPs =
BeanUtils.getPropertyDescriptor(entityClass, tableMetaData.getPkProperty());
if (pkPs == null || pkPs.getReadMethod() == null) {
continue;
}
Object pkValue = ReflectionUtils.invokeMethod(pkPs.getReadMethod(), entity);

if (!ObjectUtils.isEmpty(value)) {
caseClause
.append("WHEN ")
.append(primaryKey)
.append(" = '")
.append(pkValue.toString())
.append(pkValue)
.append("' THEN '")
.append(escapeSingleQuote(value.toString()))
.append("' ");
} else if (!partial) {
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field != null) {
Column column = field.getAnnotation(Column.class);
if (column != null && !column.nullable() && value == null) {
continue;
}
}
caseClause
.append("WHEN ")
.append(primaryKey)
.append(" = '")
.append(pkValue)
.append("' THEN NULL ");
}
}

caseClause.append("ELSE \"").append(entry.getValue()).append("\" ");
caseClause.append("END");
setClauses.put(entry.getValue(), caseClause);
}
Expand All @@ -386,17 +422,14 @@ public static <Entity> String partialUpdateList(
.collect(Collectors.joining(", "));

sqlBuilder.append(pkValues).append(")");
return sqlBuilder.toString();
}
case POSTGRESQL: {
break;
}
default: {
log.error("Unsupported data source");
}
}

return sql.toString();
return sqlBuilder.toString();
}

public static String selectById(TableMetaData tableMetaData, String databaseId, Serializable id) {
Expand Down Expand Up @@ -575,6 +608,13 @@ private static String getTokenParam(String property) {
return "#{" + property + "}";
}

private static String escapeSingleQuote(String input) {
if (input != null) {
return input.replace("'", "''");
}
return null;
}

private static <Condition> SQL mysqlCondition(Condition condition, TableMetaData tableMetaData)
throws IllegalAccessException {

Expand Down

0 comments on commit 157dd0e

Please sign in to comment.