Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[enhancement](user) Support limiting user connections by IP #38837

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,15 @@ public long getMaxConn(String qualifiedUser) {
}
}

// Returns the max number of connections allowed for a single (user, client-ip)
// pair on this FE, as stored in the user's properties.
// Mirrors getMaxConn above: reads propertyMgr under the auth read lock so a
// concurrent SET PROPERTY cannot race with the lookup.
public long getMaxIpConn(String qualifiedUser) {
readLock();
try {
return propertyMgr.getMaxIpConn(qualifiedUser);
} finally {
// Always release the read lock, even if the lookup throws.
readUnlock();
}
}

public int getQueryTimeout(String qualifiedUser) {
readLock();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public class CommonUserProperties implements Writable, GsonPostProcessable {
// The max connections allowed for a user on one FE
@SerializedName(value = "mc", alternate = {"maxConn"})
private long maxConn = 100;
// The max connections allowed for a user and ip on one FE
@SerializedName(value = "mic", alternate = {"maxIpConn"})
private long maxIpConn = 100;
// The maximum total number of query instances that the user is allowed to send from this FE
@SerializedName(value = "mqi", alternate = {"maxQueryInstances"})
private long maxQueryInstances = -1;
Expand Down Expand Up @@ -75,6 +78,10 @@ long getMaxConn() {
return maxConn;
}

// Max connections allowed for one (user, ip) pair on one FE.
long getMaxIpConn() {
    return this.maxIpConn;
}

// Max total query instances this user may send from this FE (-1 = unlimited,
// per the field's default declared above).
long getMaxQueryInstances() {
    return this.maxQueryInstances;
}
Expand All @@ -95,6 +102,11 @@ void setMaxConn(long maxConn) {
this.maxConn = maxConn;
}

// Updates the per-(user, ip) connection cap. Validation (range, relation to
// maxConn) is expected to happen at the caller; this is a plain setter.
void setMaxIpConn(long newLimit) {
    this.maxIpConn = newLimit;
}


// Plain setter for the per-user query-instance cap; no validation here.
void setMaxQueryInstances(long newLimit) {
    this.maxQueryInstances = newLimit;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public class UserProperty implements Writable {
private static final Logger LOG = LogManager.getLogger(UserProperty.class);
// advanced properties
public static final String PROP_MAX_USER_CONNECTIONS = "max_user_connections";
public static final String PROP_MAX_USER_IP_CONNECTIONS = "max_user_ip_connections";
public static final String PROP_MAX_QUERY_INSTANCES = "max_query_instances";
public static final String PROP_PARALLEL_FRAGMENT_EXEC_INSTANCE_NUM = "parallel_fragment_exec_instance_num";
public static final String PROP_RESOURCE_TAGS = "resource_tags";
Expand Down Expand Up @@ -123,6 +124,7 @@ public class UserProperty implements Writable {

static {
ADVANCED_PROPERTIES.add(Pattern.compile("^" + PROP_MAX_USER_CONNECTIONS + "$", Pattern.CASE_INSENSITIVE));
ADVANCED_PROPERTIES.add(Pattern.compile("^" + PROP_MAX_USER_IP_CONNECTIONS + "$", Pattern.CASE_INSENSITIVE));
ADVANCED_PROPERTIES.add(Pattern.compile("^" + PROP_RESOURCE + ".", Pattern.CASE_INSENSITIVE));
ADVANCED_PROPERTIES.add(Pattern.compile("^" + PROP_LOAD_CLUSTER + "." + DppConfig.CLUSTER_NAME_REGEX + "."
+ DppConfig.PRIORITY + "$", Pattern.CASE_INSENSITIVE));
Expand Down Expand Up @@ -159,6 +161,10 @@ public long getMaxConn() {
return this.commonProperties.getMaxConn();
}

// Per-(user, ip) connection limit, delegated to the common property holder.
public long getMaxIpConn() {
    return commonProperties.getMaxIpConn();
}

// Query timeout for this user, delegated to the common property holder.
public int getQueryTimeout() {
    return commonProperties.getQueryTimeout();
}
Expand Down Expand Up @@ -207,6 +213,7 @@ public void update(List<Pair<String, String>> properties) throws UserException {
public void update(List<Pair<String, String>> properties, boolean isReplay) throws UserException {
// copy
long newMaxConn = this.commonProperties.getMaxConn();
long newMaxIpConn = this.commonProperties.getMaxIpConn();
long newMaxQueryInstances = this.commonProperties.getMaxQueryInstances();
int newParallelFragmentExecInstanceNum = this.commonProperties.getParallelFragmentExecInstanceNum();
String sqlBlockRules = this.commonProperties.getSqlBlockRules();
Expand All @@ -229,18 +236,13 @@ public void update(List<Pair<String, String>> properties, boolean isReplay) thro
String[] keyArr = key.split("\\" + SetUserPropertyVar.DOT_SEPARATOR);
if (keyArr[0].equalsIgnoreCase(PROP_MAX_USER_CONNECTIONS)) {
// set property "max_user_connections" = "1000"
if (keyArr.length != 1) {
throw new DdlException(PROP_MAX_USER_CONNECTIONS + " format error");
}

try {
newMaxConn = Long.parseLong(value);
} catch (NumberFormatException e) {
throw new DdlException(PROP_MAX_USER_CONNECTIONS + " is not number");
}

if (newMaxConn <= 0 || newMaxConn > 10000) {
throw new DdlException(PROP_MAX_USER_CONNECTIONS + " is not valid, must between 1 and 10000");
newMaxConn = getConn(key, value, keyArr);
} else if (keyArr[0].equalsIgnoreCase(PROP_MAX_USER_IP_CONNECTIONS)) {
// set property "max_user_ip_connections" = "1000"
newMaxIpConn = getConn(key, value, keyArr);
if (newMaxIpConn > newMaxConn) {
throw new DdlException(
PROP_MAX_USER_IP_CONNECTIONS + " should not be larger than " + PROP_MAX_USER_CONNECTIONS);
}
} else if (keyArr[0].equalsIgnoreCase(PROP_LOAD_CLUSTER)) {
updateLoadCluster(keyArr, value, newDppConfigs);
Expand All @@ -254,7 +256,7 @@ public void update(List<Pair<String, String>> properties, boolean isReplay) thro
}

newDefaultLoadCluster = value;
} else if (keyArr[0].equalsIgnoreCase(DEFAULT_CLOUD_CLUSTER)) {
} else if (keyArr[0].equalsIgnoreCase(DEFAULT_CLOUD_CLUSTER)) {
// set property "DEFAULT_CLOUD_CLUSTER" = "cluster1"
if (keyArr.length != 1) {
throw new DdlException(DEFAULT_CLOUD_CLUSTER + " format error");
Expand Down Expand Up @@ -372,6 +374,7 @@ public void update(List<Pair<String, String>> properties, boolean isReplay) thro

// set
this.commonProperties.setMaxConn(newMaxConn);
this.commonProperties.setMaxIpConn(newMaxIpConn);
this.commonProperties.setMaxQueryInstances(newMaxQueryInstances);
this.commonProperties.setParallelFragmentExecInstanceNum(newParallelFragmentExecInstanceNum);
this.commonProperties.setSqlBlockRules(sqlBlockRules);
Expand All @@ -390,6 +393,23 @@ public void update(List<Pair<String, String>> properties, boolean isReplay) thro
defaultCloudCluster = newDefaultCloudCluster;
}

// Parses and range-checks a connection-limit property value. Shared by the
// max_user_connections and max_user_ip_connections branches of update().
// Requirements enforced here:
//   - the property key has no dotted sub-parts (keyArr.length == 1),
//   - the value parses as a long,
//   - the value lies in [1, 10000].
// Any violation is reported as a DdlException carrying the offending key.
private long getConn(String key, String value, String[] keyArr) throws DdlException {
    if (keyArr.length != 1) {
        throw new DdlException(key + " format error");
    }
    final long parsed;
    try {
        parsed = Long.parseLong(value);
    } catch (NumberFormatException e) {
        throw new DdlException(key + " is not number");
    }
    if (parsed <= 0 || parsed > 10000) {
        throw new DdlException(key + " is not valid, must between 1 and 10000");
    }
    return parsed;
}

private long getLongProperty(String key, String value, String[] keyArr, String propName) throws DdlException {
// eg: set property "load_mem_limit" = "2147483648";
if (keyArr.length != 1) {
Expand Down Expand Up @@ -494,6 +514,9 @@ public List<List<String>> fetchProperty() {
// max user connections
result.add(Lists.newArrayList(PROP_MAX_USER_CONNECTIONS, String.valueOf(commonProperties.getMaxConn())));

// max user ip connections
result.add(Lists.newArrayList(PROP_MAX_USER_IP_CONNECTIONS, String.valueOf(commonProperties.getMaxIpConn())));

// max query instance
result.add(Lists.newArrayList(PROP_MAX_QUERY_INSTANCES,
String.valueOf(commonProperties.getMaxQueryInstances())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ public long getMaxConn(String qualifiedUser) {
return existProperty.getMaxConn();
}

// Looks up the per-(user, ip) connection limit for the given user, falling
// back to the LDAP-derived property when the user has no local entry.
// Returns 0 when no property record exists at all (same convention as
// getMaxConn above).
public long getMaxIpConn(String qualifiedUser) {
    UserProperty property = getLdapPropertyIfNull(qualifiedUser, propertyMap.get(qualifiedUser));
    return property == null ? 0 : property.getMaxIpConn();
}

public long getMaxQueryInstances(String qualifiedUser) {
UserProperty existProperty = propertyMap.get(qualifiedUser);
existProperty = getLdapPropertyIfNull(qualifiedUser, existProperty);
Expand Down
16 changes: 13 additions & 3 deletions fe/fe-core/src/main/java/org/apache/doris/qe/ConnectScheduler.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class ConnectScheduler {
private final AtomicInteger nextConnectionId;
private final Map<Integer, ConnectContext> connectionMap = Maps.newConcurrentMap();
private final Map<String, AtomicInteger> connByUser = Maps.newConcurrentMap();
private final Map<String, AtomicInteger> connByUserIp = Maps.newConcurrentMap();
private final Map<String, Integer> flightToken2ConnectionId = Maps.newConcurrentMap();

// valid trace id -> query id
Expand Down Expand Up @@ -95,13 +96,22 @@ public boolean registerConnection(ConnectContext ctx) {
return false;
}
// Check user
connByUser.putIfAbsent(ctx.getQualifiedUser(), new AtomicInteger(0));
AtomicInteger conns = connByUser.get(ctx.getQualifiedUser());
if (conns.incrementAndGet() > ctx.getEnv().getAuth().getMaxConn(ctx.getQualifiedUser())) {
// Enforce both the per-user and per-(user, ip) connection limits.
// Counters are optimistically incremented and rolled back on rejection.
String qualifiedUser = ctx.getQualifiedUser();
// Key for the per-(user, ip) counter; matches the limit configured via
// the max_user_ip_connections property.
String userIpKey = qualifiedUser + ":" + ctx.getRemoteIP();
connByUser.putIfAbsent(qualifiedUser, new AtomicInteger(0));
connByUserIp.putIfAbsent(userIpKey, new AtomicInteger(0));
AtomicInteger conns = connByUser.get(qualifiedUser);
AtomicInteger ipConns = connByUserIp.get(userIpKey);
if (conns.incrementAndGet() > ctx.getEnv().getAuth().getMaxConn(qualifiedUser)) {
    conns.decrementAndGet();
    numberConnection.decrementAndGet();
    return false;
}
if (ipConns.incrementAndGet() > ctx.getEnv().getAuth().getMaxIpConn(qualifiedUser)) {
    // BUG FIX: the per-user counter was already incremented above and must be
    // rolled back here as well; otherwise every ip-limit rejection permanently
    // leaks one slot of the user's max_user_connections budget (the connection
    // is never registered, so unregister never decrements it).
    conns.decrementAndGet();
    ipConns.decrementAndGet();
    numberConnection.decrementAndGet();
    return false;
}
// NOTE(review): connByUserIp entries are never removed here; verify that the
// unregister path decrements (and ideally prunes) the per-ip counter, or the
// map grows unboundedly with distinct client IPs.
connectionMap.put(ctx.getConnectionId(), ctx);
if (ctx.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
flightToken2ConnectionId.put(ctx.getPeerIdentity(), ctx.getConnectionId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public void testNormal() throws IOException, DdlException {
public void testUpdate() throws UserException {
List<Pair<String, String>> properties = Lists.newArrayList();
properties.add(Pair.of("MAX_USER_CONNECTIONS", "100"));
properties.add(Pair.of("max_user_ip_connections", "20"));
properties.add(Pair.of("load_cluster.dpp-cluster.hadoop_palo_path", "/user/palo2"));
properties.add(Pair.of("default_load_cluster", "dpp-cluster"));
properties.add(Pair.of("max_qUERY_instances", "3000"));
Expand All @@ -112,6 +113,7 @@ public void testUpdate() throws UserException {
UserProperty userProperty = new UserProperty();
userProperty.update(properties);
Assert.assertEquals(100, userProperty.getMaxConn());
Assert.assertEquals(20, userProperty.getMaxIpConn());
Assert.assertEquals("/user/palo2", userProperty.getLoadClusterInfo("dpp-cluster").second.getPaloPath());
Assert.assertEquals("dpp-cluster", userProperty.getDefaultLoadCluster());
Assert.assertEquals(3000, userProperty.getMaxQueryInstances());
Expand All @@ -121,6 +123,15 @@ public void testUpdate() throws UserException {
Assert.assertEquals(500, userProperty.getQueryTimeout());
Assert.assertEquals(Sets.newHashSet(), userProperty.getCopiedResourceTags());

properties.add(Pair.of("max_user_ip_connections", "120"));
try {
userProperty.update(properties);
Assert.fail();
} catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("max_user_ip_connections should not be larger than max_user_connections"));
}


// fetch property
List<List<String>> rows = userProperty.fetchProperty();
for (List<String> row : rows) {
Expand All @@ -129,6 +140,8 @@ public void testUpdate() throws UserException {

if (key.equalsIgnoreCase("max_user_connections")) {
Assert.assertEquals("100", value);
} else if (key.equalsIgnoreCase("max_user_ip_connections")) {
Assert.assertEquals("20", value);
} else if (key.equalsIgnoreCase("load_cluster.dpp-cluster.hadoop_palo_path")) {
Assert.assertEquals("/user/palo2", value);
} else if (key.equalsIgnoreCase("default_load_cluster")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public void testSerialization() throws IOException, UserException {

List<Pair<String, String>> properties = Lists.newArrayList();
properties.add(Pair.of(UserProperty.PROP_MAX_USER_CONNECTIONS, "100"));
properties.add(Pair.of(UserProperty.PROP_MAX_USER_IP_CONNECTIONS, "100"));
properties.add(Pair.of(UserProperty.PROP_MAX_QUERY_INSTANCES, "2"));
properties.add(Pair.of(UserProperty.PROP_PARALLEL_FRAGMENT_EXEC_INSTANCE_NUM, "8"));
properties.add(Pair.of(UserProperty.PROP_SQL_BLOCK_RULES, "r1,r2"));
Expand All @@ -60,6 +61,7 @@ public void testSerialization() throws IOException, UserException {
UserProperty prop2 = UserProperty.read(in);

Assert.assertEquals(100, prop2.getMaxConn());
Assert.assertEquals(100, prop2.getMaxIpConn());
Assert.assertEquals(2, prop2.getMaxQueryInstances());
Assert.assertEquals(8, prop2.getParallelFragmentExecInstanceNum());
Assert.assertEquals(Lists.newArrayList("r1", "r2"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public void test() throws Exception {
Assert.assertEquals(1000000, execMemLimit);

List<List<String>> userProps = Env.getCurrentEnv().getAuth().getUserProperties(Auth.ROOT_USER);
Assert.assertEquals(12, userProps.size());
Assert.assertEquals(13, userProps.size());

// now :
// be1 be2 be3 ==>tag1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ suite("test_auth_compatibility", "account") {

sql """SET PROPERTY FOR '${user}' 'max_user_connections'= '2048'"""

sql """SET PROPERTY FOR '${user}' 'max_user_ip_connections'= '100'"""

sql """SET PROPERTY FOR '${user}'
'load_cluster.cluster1.hadoop_palo_path' = '/user/doris/doris_path',
'load_cluster.cluster1.hadoop_configs' = 'fs.default.name=hdfs://dpp.cluster.com:port;mapred.job.tracker=dpp.cluster.com:port;hadoop.job.ugi=user,password;mapred.job.queue.name=job_queue_name_in_hadoop;mapred.job.priority=HIGH;';
Expand All @@ -83,6 +85,9 @@ suite("test_auth_compatibility", "account") {
def result = getProperty("max_user_connections", "${user}")
assertEquals(result.Value as String, "2048" as String)

result = getProperty("max_user_ip_connections", "${user}")
assertEquals(result.Value as String, "100" as String)

result = getProperty("default_load_cluster", "${user}")
assertEquals(result.Value as String, "cluster1" as String)

Expand Down
Loading