diff --git a/embeddedwebserver/src/main/java/org/apache/ranger/server/tomcat/EmbeddedServer.java b/embeddedwebserver/src/main/java/org/apache/ranger/server/tomcat/EmbeddedServer.java index dbdbd2a8c2..c1a7726498 100644 --- a/embeddedwebserver/src/main/java/org/apache/ranger/server/tomcat/EmbeddedServer.java +++ b/embeddedwebserver/src/main/java/org/apache/ranger/server/tomcat/EmbeddedServer.java @@ -226,7 +226,7 @@ public void start() { valve.setMaxDays(EmbeddedServerUtil.getIntConfig(ACCESS_LOG_ROTATE_MAX_DAYS, 15)); valve.setRenameOnRotate(EmbeddedServerUtil.getBooleanConfig(ACCESS_LOG_ROTATE_RENAME_ON_ROTATE, false)); - String defaultAccessLogPattern = servername.equalsIgnoreCase(KMS_SERVER_NAME) ? "%h %l %u %t \"%m %U\" %s %b %D" : "%h %l %u %t \"%r\" %s %b %D"; + String defaultAccessLogPattern = servername.equalsIgnoreCase(KMS_SERVER_NAME) ? "%h %l %u %t \"%m %U\" %s %b %D %{eek_op}r" : "%h %l %u %t \"%r\" %s %b %D"; String logPattern = EmbeddedServerUtil.getConfig(ACCESS_LOG_PATTERN, defaultAccessLogPattern); valve.setPattern(logPattern); diff --git a/kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSMDCFilter.java b/kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSMDCFilter.java index 3116b46712..203f9078d5 100644 --- a/kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSMDCFilter.java +++ b/kms/src/main/java/org/apache/hadoop/crypto/key/kms/server/KMSMDCFilter.java @@ -20,6 +20,8 @@ import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.delegation.web.HttpUserGroupInformation; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.servlet.Filter; import javax.servlet.FilterChain; @@ -31,6 +33,9 @@ import javax.servlet.http.HttpServletResponse; import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; /** * Servlet filter that captures context of the HTTP request to be use in the @@ -38,8 +43,12 @@ */ @InterfaceAudience.Private public class KMSMDCFilter implements Filter { + static final Logger logger = LoggerFactory.getLogger(KMSMDCFilter.class); + static final String RANGER_KMS_REST_API_PATH = "/kms/api/status"; + private static final String EEK_OP_CODE = "eek_op"; + private static final ThreadLocal DATA_TL = new ThreadLocal<>(); public static UserGroupInformation getUgi() { @@ -54,6 +63,10 @@ public static String getURL() { return DATA_TL.get().url; } + public static String getOperation() { + return DATA_TL.get().operation; + } + @Override public void init(FilterConfig config) throws ServletException { } @@ -62,6 +75,7 @@ public void init(FilterConfig config) throws ServletException { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { try { String path = ((HttpServletRequest) request).getRequestURI(); + HttpServletRequest req = (HttpServletRequest) request; HttpServletResponse resp = (HttpServletResponse) response; if (path.startsWith(RANGER_KMS_REST_API_PATH)) { @@ -70,15 +84,37 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha DATA_TL.remove(); UserGroupInformation ugi = HttpUserGroupInformation.get(); - String method = ((HttpServletRequest) request).getMethod(); - StringBuffer requestURL = ((HttpServletRequest) request).getRequestURL(); - String queryString = ((HttpServletRequest) request).getQueryString(); + String method = req.getMethod(); + StringBuffer requestURL = req.getRequestURL(); + String queryString = req.getQueryString(); + + // Extract operation from query parameters if present + String operation = null; + if (path.contains("/_eek") && queryString != null) { + for (String param : queryString.split("&")) { + String[] kv = param.split("=", 2); + if (kv.length == 2 && "eek_op".equals(kv[0])) { + try { + operation = URLDecoder.decode(kv[1], StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException | IllegalArgumentException e) { + logger.error("Failed to decode eek_op parameter value using UTF-8 encoding: {}", kv[1], e); + throw new ServletException("Failed to decode eek_op parameter: '" + kv[1] + "'. " + e.getClass().getSimpleName() + ": " + e.getMessage(), e); + } + break; + } + } + } if (queryString != null) { requestURL.append("?").append(queryString); } - DATA_TL.set(new Data(ugi, method, requestURL.toString())); + // Store opCode in request attribute for Tomcat access logs + if (operation != null) { + req.setAttribute(EEK_OP_CODE, operation); + } + + DATA_TL.set(new Data(ugi, method, requestURL.toString(), operation)); chain.doFilter(request, resp); } @@ -95,11 +131,13 @@ private static class Data { private final UserGroupInformation ugi; private final String method; private final String url; + private final String operation; - private Data(UserGroupInformation ugi, String method, String url) { - this.ugi = ugi; - this.method = method; - this.url = url; + private Data(UserGroupInformation ugi, String method, String url, String operation) { + this.ugi = ugi; + this.method = method; + this.url = url; + this.operation = operation; } } } diff --git a/kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMSMDCFilter.java b/kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMSMDCFilter.java index c83fbb512d..f739346688 100644 --- a/kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMSMDCFilter.java +++ b/kms/src/test/java/org/apache/hadoop/crypto/key/kms/server/TestKMSMDCFilter.java @@ -29,7 +29,11 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -86,4 +90,79 @@ public void testDoFilter_withOtherPath() throws Exception { verify(chain).doFilter(request, response); } + + @Test + public void testDoFilter_withEekOpParameter() throws Exception { + KMSMDCFilter filter = new KMSMDCFilter(); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain chain = mock(FilterChain.class); + + when(request.getRequestURI()).thenReturn("/kms/v1/_eek"); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/_eek")); + when(request.getQueryString()).thenReturn("eek_op=generate"); + + filter.doFilter(request, response, chain); + + verify(request).setAttribute("eek_op", "generate"); + verify(chain).doFilter(request, response); + } + + @Test + public void testDoFilter_withEekPathWithoutEekOp() throws Exception { + KMSMDCFilter filter = new KMSMDCFilter(); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain chain = mock(FilterChain.class); + + when(request.getRequestURI()).thenReturn("/kms/v1/_eek"); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/_eek")); + when(request.getQueryString()).thenReturn("other_param=value"); + + filter.doFilter(request, response, chain); + + verify(request, never()).setAttribute(eq("eek_op"), any()); + verify(chain).doFilter(request, response); + } + + @Test + public void testDoFilter_withMalformedEekOpEncoding() throws Exception { + KMSMDCFilter filter = new KMSMDCFilter(); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain chain = mock(FilterChain.class); + + when(request.getRequestURI()).thenReturn("/kms/v1/_eek"); + when(request.getMethod()).thenReturn("POST"); + when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/_eek")); + // Malformed percent encoding + when(request.getQueryString()).thenReturn("eek_op=%E0%A4%A"); + + assertThrows(ServletException.class, () -> { + filter.doFilter(request, response, chain); + }); + } + + @Test + public void testDoFilter_withOtherPathAndQueryString() throws Exception { + KMSMDCFilter filter = new KMSMDCFilter(); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + FilterChain chain = mock(FilterChain.class); + + when(request.getRequestURI()).thenReturn("/kms/v1/keys"); + when(request.getMethod()).thenReturn("GET"); + when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost/kms/v1/keys")); + when(request.getQueryString()).thenReturn("foo=bar"); + + filter.doFilter(request, response, chain); + + verify(chain).doFilter(request, response); + } }