Skip to content

Commit

Permalink
NIFI-5628 Added content length check to OkHttpReplicationClient.
Browse files Browse the repository at this point in the history
Added unit tests.
  • Loading branch information
alopresto committed Sep 27, 2018
1 parent ca70dbb commit 1baead6
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,35 @@
import com.fasterxml.jackson.annotation.JsonInclude.Value;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.security.KeyStore;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import okhttp3.Call;
import okhttp3.ConnectionPool;
import okhttp3.Headers;
Expand All @@ -42,36 +71,6 @@
import org.slf4j.LoggerFactory;
import org.springframework.util.StreamUtils;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.security.KeyStore;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;

public class OkHttpReplicationClient implements HttpReplicationClient {
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClient.class);
private static final Set<String> gzipEncodings = Stream.of("gzip", "x-gzip").collect(Collectors.toSet());
Expand All @@ -95,12 +94,35 @@ public OkHttpReplicationClient(final NiFiProperties properties) {
@Override
public PreparedRequest prepareRequest(final String method, final Map<String, String> headers, final Object entity) {
final boolean gzip = isUseGzip(headers);
checkContentLengthHeader(method, headers);
final RequestBody requestBody = createRequestBody(headers, entity, gzip);

final Map<String, String> updatedHeaders = gzip ? updateHeadersForGzip(headers) : headers;
return new OkHttpPreparedRequest(method, updatedHeaders, entity, requestBody);
}

/**
* Checks the content length header on DELETE requests to ensure it is set to '0', avoiding request timeouts on replicated requests.
* @param method the HTTP method of the request
* @param headers the header keys and values
*/
private void checkContentLengthHeader(String method, Map<String, String> headers) {
// Only applies to DELETE requests
if (HttpMethod.DELETE.equalsIgnoreCase(method)) {
// Find the Content-Length header if present
final String CONTENT_LENGTH_HEADER_KEY = "Content-Length";
Map.Entry<String, String> contentLengthEntry = headers.entrySet().stream().filter(entry -> entry.getKey().equalsIgnoreCase(CONTENT_LENGTH_HEADER_KEY)).findFirst().orElse(null);
// If no CL header, do nothing
if (contentLengthEntry != null) {
// If the provided CL value is non-zero, override it
if (contentLengthEntry.getValue() != null && !contentLengthEntry.getValue().equalsIgnoreCase("0")) {
logger.warn("This is a DELETE request; the provided Content-Length was {}; setting Content-Length to 0", contentLengthEntry.getValue());
headers.put(CONTENT_LENGTH_HEADER_KEY, "0");
}
}
}
}

@Override
public Response replicate(final PreparedRequest request, final String uri) throws IOException {
if (!(Objects.requireNonNull(request) instanceof OkHttpPreparedRequest)) {
Expand Down Expand Up @@ -140,7 +162,7 @@ private byte[] getResponseBytes(final okhttp3.Response callResponse) throws IOEx
final String contentEncoding = callResponse.header("Content-Encoding");
if (gzipEncodings.contains(contentEncoding)) {
try (final InputStream gzipIn = new GZIPInputStream(new ByteArrayInputStream(rawBytes));
final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {

StreamUtils.copy(gzipIn, baos);
return baos.toByteArray();
Expand Down Expand Up @@ -183,7 +205,7 @@ private Call createCall(final OkHttpPreparedRequest request, final String uri) {

@SuppressWarnings("unchecked")
private HttpUrl buildUrl(final OkHttpPreparedRequest request, final String uri) {
HttpUrl.Builder urlBuilder = HttpUrl.parse(uri.toString()).newBuilder();
HttpUrl.Builder urlBuilder = HttpUrl.parse(uri).newBuilder();
switch (request.getMethod().toUpperCase()) {
case HttpMethod.DELETE:
case HttpMethod.HEAD:
Expand Down Expand Up @@ -226,7 +248,7 @@ private String getContentType(final Map<String, String> headers, final String de

private byte[] serializeEntity(final Object entity, final String contentType, final boolean gzip) {
try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) {
final OutputStream out = gzip ? new GZIPOutputStream(baos, 1) : baos) {

getSerializer(contentType).serialize(entity, out);
out.close();
Expand Down Expand Up @@ -269,10 +291,10 @@ private boolean isUseGzip(final Map<String, String> headers) {
} else {
final String[] acceptEncodingTokens = rawAcceptEncoding.split(",");
return Stream.of(acceptEncodingTokens)
.map(String::trim)
.filter(StringUtils::isNotEmpty)
.map(String::toLowerCase)
.anyMatch(gzipEncodings::contains);
.map(String::trim)
.filter(StringUtils::isNotEmpty)
.map(String::toLowerCase)
.anyMatch(gzipEncodings::contains);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.nifi.cluster.coordination.http.replication.okhttp

import org.apache.nifi.properties.StandardNiFiProperties
import org.apache.nifi.util.NiFiProperties
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.slf4j.Logger
import org.slf4j.LoggerFactory

@RunWith(JUnit4.class)
class OkHttpReplicationClientTest extends GroovyTestCase {
private static final Logger logger = LoggerFactory.getLogger(OkHttpReplicationClientTest.class)

@BeforeClass
static void setUpOnce() throws Exception {
logger.metaClass.methodMissing = { String name, args ->
logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
}
}

private static StandardNiFiProperties mockNiFiProperties() {
[getClusterNodeConnectionTimeout: { -> "10 ms" },
getClusterNodeReadTimeout : { -> "10 ms" },
getProperty : { String prop ->
logger.mock("Requested getProperty(${prop}) -> \"\"")
""
}] as StandardNiFiProperties
}

@Test
void testShouldReplaceNonZeroContentLengthHeader() {
// Arrange
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
String method = "DELETE"
logger.info("Original headers: ${headers}")

NiFiProperties mockProperties = mockNiFiProperties()

OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)

// Act
client.checkContentLengthHeader(method, headers)
logger.info("Checked headers: ${headers}")

// Assert
assert headers.size() == 2
assert headers."Content-Length" == "0"
}

@Test
void testShouldReplaceNonZeroContentLengthHeaderOnDeleteCaseInsensitive() {
// Arrange
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
String method = "delete"
logger.info("Original headers: ${headers}")

NiFiProperties mockProperties = mockNiFiProperties()

OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)

// Act
client.checkContentLengthHeader(method, headers)
logger.info("Checked headers: ${headers}")

// Assert
assert headers.size() == 2
assert headers."Content-Length" == "0"
}

@Test
void testShouldNotReplaceContentLengthHeaderWhenZeroOrNull() {
// Arrange
String method = "DELETE"
def zeroOrNullContentLengths = [null, "0"]

NiFiProperties mockProperties = mockNiFiProperties()

OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)

// Act
zeroOrNullContentLengths.each { String contentLength ->
def headers = ["Content-Length": contentLength, "Other-Header": "arbitrary value"]
logger.info("Original headers: ${headers}")

logger.info("Trying method ${method}")
client.checkContentLengthHeader(method, headers)
logger.info("Checked headers: ${headers}")

// Assert
assert headers.size() == 2
assert headers."Content-Length" == contentLength
}
}

@Test
void testShouldNotReplaceNonZeroContentLengthHeaderOnOtherMethod() {
// Arrange
def headers = ["Content-Length": "123", "Other-Header": "arbitrary value"]
logger.info("Original headers: ${headers}")

NiFiProperties mockProperties = mockNiFiProperties()

OkHttpReplicationClient client = new OkHttpReplicationClient(mockProperties)

def nonDeleteMethods = ["POST", "PUT", "GET", "HEAD"]

// Act
nonDeleteMethods.each { String method ->
logger.info("Trying method ${method}")
client.checkContentLengthHeader(method, headers)
logger.info("Checked headers: ${headers}")

// Assert
assert headers.size() == 2
assert headers."Content-Length" == "123"
}
}
}

0 comments on commit 1baead6

Please sign in to comment.