Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ public void start() {
valve.setMaxDays(EmbeddedServerUtil.getIntConfig(ACCESS_LOG_ROTATE_MAX_DAYS, 15));
valve.setRenameOnRotate(EmbeddedServerUtil.getBooleanConfig(ACCESS_LOG_ROTATE_RENAME_ON_ROTATE, false));

String defaultAccessLogPattern = servername.equalsIgnoreCase(KMS_SERVER_NAME) ? "%h %l %u %t \"%m %U\" %s %b %D" : "%h %l %u %t \"%r\" %s %b %D";
String defaultAccessLogPattern = servername.equalsIgnoreCase(KMS_SERVER_NAME) ? "%h %l %u %t \"%m %U\" %s %b %D %{eek_op}r" : "%h %l %u %t \"%r\" %s %b %D";
String logPattern = EmbeddedServerUtil.getConfig(ACCESS_LOG_PATTERN, defaultAccessLogPattern);

valve.setPattern(logPattern);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.delegation.web.HttpUserGroupInformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
Expand All @@ -31,15 +33,22 @@
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;

/**
* Servlet filter that captures context of the HTTP request to be use in the
* scope of KMS calls on the server side.
*/
@InterfaceAudience.Private
public class KMSMDCFilter implements Filter {
static final Logger logger = LoggerFactory.getLogger(KMSMDCFilter.class);

static final String RANGER_KMS_REST_API_PATH = "/kms/api/status";

private static final String EEK_OP_CODE = "eek_op";

private static final ThreadLocal<Data> DATA_TL = new ThreadLocal<>();

public static UserGroupInformation getUgi() {
Expand All @@ -54,6 +63,10 @@ public static String getURL() {
return DATA_TL.get().url;
}

public static String getOperation() {
return DATA_TL.get().operation;
}

@Override
public void init(FilterConfig config) throws ServletException {
}
Expand All @@ -62,6 +75,7 @@ public void init(FilterConfig config) throws ServletException {
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
try {
String path = ((HttpServletRequest) request).getRequestURI();
HttpServletRequest req = (HttpServletRequest) request;
HttpServletResponse resp = (HttpServletResponse) response;

if (path.startsWith(RANGER_KMS_REST_API_PATH)) {
Expand All @@ -70,15 +84,37 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
DATA_TL.remove();

UserGroupInformation ugi = HttpUserGroupInformation.get();
String method = ((HttpServletRequest) request).getMethod();
StringBuffer requestURL = ((HttpServletRequest) request).getRequestURL();
String queryString = ((HttpServletRequest) request).getQueryString();
String method = req.getMethod();
StringBuffer requestURL = req.getRequestURL();
String queryString = req.getQueryString();

// Extract operation from query parameters if present
String operation = null;
if (path.contains("/_eek") && queryString != null) {
for (String param : queryString.split("&")) {
String[] kv = param.split("=", 2);
if (kv.length == 2 && "eek_op".equals(kv[0])) {
try {
operation = URLDecoder.decode(kv[1], StandardCharsets.UTF_8.name());
} catch (UnsupportedEncodingException | IllegalArgumentException e) {
logger.error("Failed to decode eek_op parameter value using UTF-8 encoding: {}", kv[1], e);
throw new ServletException("Failed to decode eek_op parameter: '" + kv[1] + "'. " + e.getClass().getSimpleName() + ": " + e.getMessage(), e);
}
break;
}
}
}

if (queryString != null) {
requestURL.append("?").append(queryString);
}

DATA_TL.set(new Data(ugi, method, requestURL.toString()));
// Store opCode in request attribute for Tomcat access logs
if (operation != null) {
req.setAttribute(EEK_OP_CODE, operation);
}

DATA_TL.set(new Data(ugi, method, requestURL.toString(), operation));

chain.doFilter(request, resp);
}
Expand All @@ -95,11 +131,13 @@ private static class Data {
private final UserGroupInformation ugi;
private final String method;
private final String url;
private final String operation;

private Data(UserGroupInformation ugi, String method, String url) {
this.ugi = ugi;
this.method = method;
this.url = url;
private Data(UserGroupInformation ugi, String method, String url, String operation) {
this.ugi = ugi;
this.method = method;
this.url = url;
this.operation = operation;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -86,4 +90,79 @@ public void testDoFilter_withOtherPath() throws Exception {

verify(chain).doFilter(request, response);
}

@Test
public void testDoFilter_withEekOpParameter() throws Exception {
KMSMDCFilter filter = new KMSMDCFilter();

HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain chain = mock(FilterChain.class);

when(request.getRequestURI()).thenReturn("/kms/v1/_eek");
when(request.getMethod()).thenReturn("POST");
when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/_eek"));
when(request.getQueryString()).thenReturn("eek_op=generate");

filter.doFilter(request, response, chain);

verify(request).setAttribute("eek_op", "generate");
verify(chain).doFilter(request, response);
}

@Test
public void testDoFilter_withEekPathWithoutEekOp() throws Exception {
KMSMDCFilter filter = new KMSMDCFilter();

HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain chain = mock(FilterChain.class);

when(request.getRequestURI()).thenReturn("/kms/v1/_eek");
when(request.getMethod()).thenReturn("POST");
when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/_eek"));
when(request.getQueryString()).thenReturn("other_param=value");

filter.doFilter(request, response, chain);

verify(request, never()).setAttribute(eq("eek_op"), any());
verify(chain).doFilter(request, response);
}

@Test
public void testDoFilter_withMalformedEekOpEncoding() throws Exception {
KMSMDCFilter filter = new KMSMDCFilter();

HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain chain = mock(FilterChain.class);

when(request.getRequestURI()).thenReturn("/kms/v1/_eek");
when(request.getMethod()).thenReturn("POST");
when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/_eek"));
// Malformed percent encoding
when(request.getQueryString()).thenReturn("eek_op=%E0%A4%A");

assertThrows(ServletException.class, () -> {
filter.doFilter(request, response, chain);
});
}

@Test
public void testDoFilter_withOtherPathAndQueryString() throws Exception {
KMSMDCFilter filter = new KMSMDCFilter();

HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain chain = mock(FilterChain.class);

when(request.getRequestURI()).thenReturn("/kms/v1/keys");
when(request.getMethod()).thenReturn("GET");
when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/keys"));
when(request.getQueryString()).thenReturn("foo=bar");

filter.doFilter(request, response, chain);

verify(chain).doFilter(request, response);
}
}