add partialUpdateList
lhpqaq committed Sep 19, 2024
1 parent 83c0940 commit 2a93a4c
Showing 9 changed files with 116 additions and 36 deletions.
@@ -60,6 +60,18 @@ public interface BaseDao<Entity> {
@UpdateProvider(type = BaseSqlProvider.class, method = "updateById")
int updateById(Entity entity);

/**
* Partially update a list of entities by primary key.
*/
@UpdateProvider(type = BaseSqlProvider.class, method = "partialUpdateByIds")
int partialUpdateByIds(List<Entity> entities);
//
// /**
// * Fully update the entity by primary key.
// */
// @UpdateProvider(type = BaseSqlProvider.class, method = "updateByIds")
// int updateByIds(List<Entity> entities);

/**
* Query the entity by primary key.
*/
@@ -28,8 +28,6 @@

public interface HostDao extends BaseDao<HostPO> {

int saveAll(@Param("hosts") List<HostPO> hosts);

HostPO findByHostname(@Param("hostname") String hostname);

List<HostPO> findAllByHostnameIn(@Param("hostnames") Collection<String> hostnames);
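
For orientation, a hypothetical caller of the new batch method through HostDao could look like the sketch below; the hostnames, the state value, and the setter are invented for illustration, and only non-empty properties end up in the generated UPDATE.

// Hypothetical usage sketch: batch partial update through the inherited BaseDao method
List<HostPO> hosts = hostDao.findAllByHostnameIn(List.of("node1", "node2"));
for (HostPO host : hosts) {
    host.setState("INSTALLED"); // assumed setter; only non-empty properties are written
}
int updated = hostDao.partialUpdateByIds(hosts); // one batched UPDATE for all rows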
@@ -30,7 +30,5 @@ public interface RepoDao extends BaseDao<RepoPO> {

Optional<RepoPO> findByRepoName(@Param("repoName") String repoName);

int saveAll(@Param("clusters") List<RepoPO> repos);

List<RepoPO> findAllByClusterId(@Param("clusterId") Long clusterId);
}
@@ -71,6 +71,19 @@ public <Entity> String partialUpdateById(Entity entity, ProviderContext context)
return SQLBuilder.partialUpdate(tableMetaData, entity, databaseId);
}

public <Entity> String partialUpdateByIds(List<Entity> entities, ProviderContext context) {
Assert.notNull(entities, "entities must not be null");
Assert.notEmpty(entities, "entities list must not be empty");

String databaseId = context.getDatabaseId();

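// Metadata is derived from the first element; all entities are assumed to share one class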
Class<?> entityClass = entities.get(0).getClass();

TableMetaData tableMetaData = TableMetaData.forClass(entityClass);

return SQLBuilder.partialUpdateList(tableMetaData, entities, databaseId);
}

public <Entity> String updateById(Entity entity, ProviderContext context) {
Assert.notNull(entity, "entity must not be null");

@@ -37,6 +37,7 @@
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -308,6 +309,96 @@ public static <Entity> String update(TableMetaData tableMetaData, Entity entity,
return sql.toString();
}

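/**
* Double any single quotes so a value can be inlined into a SQL string literal.
*/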
public static String escapeSingleQuote(String input) {
if (input != null) {
return input.replace("'", "''");
}
return null;
}

public static <Entity> String partialUpdateList(
TableMetaData tableMetaData, List<Entity> entities, String databaseId) {
if (entities == null || entities.isEmpty()) {
throw new IllegalArgumentException("Entities list must not be null or empty");
}

Class<?> entityClass = entities.get(0).getClass();
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();

SQL sql = new SQL();
switch (DBType.toType(databaseId)) {
case MYSQL: {
StringBuilder sqlBuilder = new StringBuilder();
sqlBuilder
.append("UPDATE ")
.append(tableMetaData.getTableName())
.append(" SET ");
Map<String, StringBuilder> setClauses = new LinkedHashMap<>();
String primaryKey = "id";
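// Build one CASE expression per non-primary-key column: WHEN pk = <id> THEN <new value>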
for (Map.Entry<String, String> entry : fieldColumnMap.entrySet()) {
// Ignore primary key
if (Objects.equals(entry.getKey(), tableMetaData.getPkProperty())) {
primaryKey = entry.getValue();
continue;
}

StringBuilder caseClause = new StringBuilder();
caseClause.append(entry.getValue()).append(" = CASE ");
for (Entity entity : entities) {
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(entityClass, entry.getKey());
if (ps == null || ps.getReadMethod() == null) {
continue;
}

Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (!ObjectUtils.isEmpty(value)) {
PropertyDescriptor pkPs =
BeanUtils.getPropertyDescriptor(entityClass, tableMetaData.getPkProperty());
Object pkValue = ReflectionUtils.invokeMethod(pkPs.getReadMethod(), entity);

caseClause
.append("WHEN ")
.append(primaryKey)
.append(" = '")
.append(pkValue.toString())
.append("' THEN '")
.append(escapeSingleQuote(value.toString()))
.append("' ");
}
}

caseClause.append("END");
setClauses.put(entry.getValue(), caseClause);
}
sqlBuilder.append(String.join(", ", setClauses.values()));

sqlBuilder.append(" WHERE ").append(primaryKey).append(" IN (");
String pkValues = entities.stream()
.map(entity -> {
PropertyDescriptor pkPs =
BeanUtils.getPropertyDescriptor(entityClass, tableMetaData.getPkProperty());
Object pkValue = ReflectionUtils.invokeMethod(pkPs.getReadMethod(), entity);
return "'" + pkValue.toString() + "'";
})
.collect(Collectors.joining(", "));

sqlBuilder.append(pkValues).append(")");
return sqlBuilder.toString();
}
case POSTGRESQL: {
break;
}
default: {
log.error("Unsupported data source");
}
}

return sql.toString();
}
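
To make the CASE WHEN construction concrete, the statement this method returns for two hypothetical HostPO rows whose only populated property is state would look roughly like this (table and column names follow the mappers below; the ids and values are invented):

UPDATE host SET state = CASE WHEN id = '1' THEN 'INSTALLED' WHEN id = '2' THEN 'FAILED' ELSE state END WHERE id IN ('1', '2')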

public static String selectById(TableMetaData tableMetaData, String databaseId, Serializable id) {

SQL sql = new SQL();
@@ -100,12 +100,4 @@
</where>
</select>

<insert id="saveAll" useGeneratedKeys="true" keyProperty="id" keyColumn="id">
insert into host (hostname, ipv4, ipv6, os, arch, available_processors, free_memory_size, total_memory_size, free_disk, total_disk, state, cluster_id, create_by, update_by, create_time, update_time)
values
<foreach collection='hosts' item='host' separator=','>
(#{host.hostname}, #{host.ipv4}, #{host.ipv6}, #{host.os}, #{host.arch}, #{host.availableProcessors}, #{host.freeMemorySize}, #{host.totalMemorySize}, #{host.freeDisk}, #{host.totalDisk}, #{host.state}, #{host.clusterId} ,#{host.createBy},#{host.updateBy},#{host.createTime},#{host.updateTime})
</foreach>
</insert>

</mapper>
@@ -40,14 +40,6 @@
limit 1
</select>

<insert id="saveAll" useGeneratedKeys="true" keyProperty="id" keyColumn="id">
insert into repo (base_url, os, arch, repo_id, repo_name, repo_type, cluster_id, create_by, update_by, create_time, update_time)
values
<foreach collection='clusters' item='cluster' separator=','>
(#{cluster.baseUrl},#{cluster.os},#{cluster.arch},#{cluster.repoId},#{cluster.repoName},#{cluster.repoType},#{cluster.clusterId},#{cluster.createBy},#{cluster.updateBy},#{cluster.createTime},#{cluster.updateTime})
</foreach>
</insert>

<select id="findAllByClusterId" parameterType="java.lang.Long"
resultType="org.apache.bigtop.manager.dao.po.RepoPO">
select
@@ -100,12 +100,4 @@
</where>
</select>

<insert id="saveAll" useGeneratedKeys="true" keyProperty="id" keyColumn="id">
insert into host (hostname, ipv4, ipv6, os, arch, available_processors, free_memory_size, total_memory_size, free_disk, total_disk, state, cluster_id, create_by, update_by, create_time, update_time)
values
<foreach collection='hosts' item='host' separator=','>
(#{host.hostname}, #{host.ipv4}, #{host.ipv6}, #{host.os}, #{host.arch}, #{host.availableProcessors}, #{host.freeMemorySize}, #{host.totalMemorySize}, #{host.freeDisk}, #{host.totalDisk}, #{host.state}, #{host.clusterId} ,#{host.createBy},#{host.updateBy},#{host.createTime},#{host.updateTime})
</foreach>
</insert>

</mapper>
@@ -40,14 +40,6 @@
limit 1
</select>

<insert id="saveAll" useGeneratedKeys="true" keyProperty="id" keyColumn="id">
insert into repo (base_url, os, arch, repo_id, repo_name, repo_type, cluster_id, create_by, update_by, create_time, update_time)
values
<foreach collection='clusters' item='cluster' separator=','>
(#{cluster.baseUrl},#{cluster.os},#{cluster.arch},#{cluster.repoId},#{cluster.repoName},#{cluster.repoType},#{cluster.clusterId},#{cluster.createBy},#{cluster.updateBy},#{cluster.createTime},#{cluster.updateTime})
</foreach>
</insert>

<select id="findAllByClusterId" parameterType="java.lang.Long"
resultType="org.apache.bigtop.manager.dao.po.RepoPO">
select
