summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Earle <chris.earle@elastic.co>2016-08-24 19:28:32 -0400
committerChris Earle <chris.earle@elastic.co>2016-08-30 18:02:07 -0400
commitb8f4c92d41411e17ec45f3c83dc81d1f12d39751 (patch)
tree60cad765aac955cc639d0ece8d2575e536c74f43
parent2a7a187bf872ebe440f2916250e1f82d1b3259e1 (diff)
Allow RestClient to send array-based headers
This enables the RestClient to send array-based (multi-valued) header values, rather than only sending whatever happened to be the _last_ value of the header.
-rw-r--r--client/rest/src/main/java/org/elasticsearch/client/RestClient.java13
-rw-r--r--client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java4
-rw-r--r--client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java62
-rw-r--r--client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java79
-rw-r--r--client/test/build.gradle1
-rw-r--r--client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java76
6 files changed, 153 insertions, 82 deletions
diff --git a/client/rest/src/main/java/org/elasticsearch/client/RestClient.java b/client/rest/src/main/java/org/elasticsearch/client/RestClient.java
index 26af479f66..d2301e1e8e 100644
--- a/client/rest/src/main/java/org/elasticsearch/client/RestClient.java
+++ b/client/rest/src/main/java/org/elasticsearch/client/RestClient.java
@@ -362,12 +362,17 @@ public class RestClient implements Closeable {
private void setHeaders(HttpRequest httpRequest, Header[] requestHeaders) {
Objects.requireNonNull(requestHeaders, "request headers must not be null");
- for (Header defaultHeader : defaultHeaders) {
- httpRequest.setHeader(defaultHeader);
- }
+ // request headers override default headers, so we don't add default headers if they exist as request headers
+ final Set<String> requestNames = new HashSet<>(requestHeaders.length);
for (Header requestHeader : requestHeaders) {
Objects.requireNonNull(requestHeader, "request header must not be null");
- httpRequest.setHeader(requestHeader);
+ httpRequest.addHeader(requestHeader);
+ requestNames.add(requestHeader.getName());
+ }
+ for (Header defaultHeader : defaultHeaders) {
+ if (requestNames.contains(defaultHeader.getName()) == false) {
+ httpRequest.addHeader(defaultHeader);
+ }
}
}
diff --git a/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java b/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java
index 4f9f379d08..d342d59ade 100644
--- a/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java
+++ b/client/rest/src/main/java/org/elasticsearch/client/RestClientBuilder.java
@@ -71,7 +71,9 @@ public final class RestClientBuilder {
}
/**
- * Sets the default request headers, which will be sent along with each request
+ * Sets the default request headers, which will be sent along with each request.
+ * <p>
+ * Request-time headers will always overwrite any default headers.
*
* @throws NullPointerException if {@code defaultHeaders} or any header is {@code null}.
*/
diff --git a/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java b/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java
index e7d7852de0..9c5c50946d 100644
--- a/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java
+++ b/client/rest/src/test/java/org/elasticsearch/client/RestClientIntegTests.java
@@ -19,8 +19,6 @@
package org.elasticsearch.client;
-import com.carrotsearch.randomizedtesting.generators.RandomInts;
-import com.carrotsearch.randomizedtesting.generators.RandomStrings;
import com.sun.net.httpserver.Headers;
import com.sun.net.httpserver.HttpContext;
import com.sun.net.httpserver.HttpExchange;
@@ -28,10 +26,8 @@ import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
import org.apache.http.Consts;
import org.apache.http.Header;
-import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.entity.StringEntity;
-import org.apache.http.message.BasicHeader;
import org.apache.http.util.EntityUtils;
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
import org.junit.AfterClass;
@@ -83,13 +79,8 @@ public class RestClientIntegTests extends RestClientTestCase {
for (int statusCode : getAllStatusCodes()) {
createStatusCodeContext(httpServer, statusCode);
}
- int numHeaders = RandomInts.randomIntBetween(getRandom(), 0, 3);
- defaultHeaders = new Header[numHeaders];
- for (int i = 0; i < numHeaders; i++) {
- String headerName = "Header-default" + (getRandom().nextBoolean() ? i : "");
- String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10);
- defaultHeaders[i] = new BasicHeader(headerName, headerValue);
- }
+ int numHeaders = randomIntBetween(0, 5);
+ defaultHeaders = generateHeaders("Header-default", "Header-array", numHeaders);
restClient = RestClient.builder(new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort()))
.setDefaultHeaders(defaultHeaders).build();
}
@@ -148,44 +139,43 @@ public class RestClientIntegTests extends RestClientTestCase {
*/
public void testHeaders() throws IOException {
for (String method : getHttpMethods()) {
- Set<String> standardHeaders = new HashSet<>(
- Arrays.asList("Connection", "Host", "User-agent", "Date"));
+ final Set<String> standardHeaders = new HashSet<>(Arrays.asList("Connection", "Host", "User-agent", "Date"));
if (method.equals("HEAD") == false) {
standardHeaders.add("Content-length");
}
- int numHeaders = RandomInts.randomIntBetween(getRandom(), 1, 5);
- Map<String, String> expectedHeaders = new HashMap<>();
- for (Header defaultHeader : defaultHeaders) {
- expectedHeaders.put(defaultHeader.getName(), defaultHeader.getValue());
- }
- Header[] headers = new Header[numHeaders];
- for (int i = 0; i < numHeaders; i++) {
- String headerName = "Header" + (getRandom().nextBoolean() ? i : "");
- String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10);
- headers[i] = new BasicHeader(headerName, headerValue);
- expectedHeaders.put(headerName, headerValue);
- }
- int statusCode = randomStatusCode(getRandom());
+ final int numHeaders = randomIntBetween(1, 5);
+ final Header[] headers = generateHeaders("Header", "Header-array", numHeaders);
+ final Map<String, List<String>> expectedHeaders = new HashMap<>();
+
+ addHeaders(expectedHeaders, defaultHeaders, headers);
+
+ final int statusCode = randomStatusCode(getRandom());
Response esResponse;
try {
- esResponse = restClient.performRequest(method, "/" + statusCode, Collections.<String, String>emptyMap(),
- (HttpEntity)null, headers);
+ esResponse = restClient.performRequest(method, "/" + statusCode, Collections.<String, String>emptyMap(), headers);
} catch(ResponseException e) {
esResponse = e.getResponse();
}
assertThat(esResponse.getStatusLine().getStatusCode(), equalTo(statusCode));
- for (Header responseHeader : esResponse.getHeaders()) {
- if (responseHeader.getName().startsWith("Header")) {
- String headerValue = expectedHeaders.remove(responseHeader.getName());
- assertNotNull("found response header [" + responseHeader.getName() + "] that wasn't originally sent", headerValue);
+ for (final Header responseHeader : esResponse.getHeaders()) {
+ final String name = responseHeader.getName();
+ final String value = responseHeader.getValue();
+ if (name.startsWith("Header")) {
+ final List<String> values = expectedHeaders.get(name);
+ assertNotNull("found response header [" + name + "] that wasn't originally sent: " + value, values);
+ assertTrue("found incorrect response header [" + name + "]: " + value, values.remove(value));
+
+ // we've collected them all
+ if (values.isEmpty()) {
+ expectedHeaders.remove(name);
+ }
} else {
- assertTrue("unknown header was returned " + responseHeader.getName(),
- standardHeaders.remove(responseHeader.getName()));
+ assertTrue("unknown header was returned " + name, standardHeaders.remove(name));
}
}
- assertEquals("some headers that were sent weren't returned: " + expectedHeaders, 0, expectedHeaders.size());
- assertEquals("some expected standard headers weren't returned: " + standardHeaders, 0, standardHeaders.size());
+ assertTrue("some headers that were sent weren't returned: " + expectedHeaders, expectedHeaders.isEmpty());
+ assertTrue("some expected standard headers weren't returned: " + standardHeaders, standardHeaders.isEmpty());
}
}
diff --git a/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java b/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java
index a6ae30b01e..92e2b0da97 100644
--- a/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java
+++ b/client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java
@@ -19,8 +19,6 @@
package org.elasticsearch.client;
-import com.carrotsearch.randomizedtesting.generators.RandomInts;
-import com.carrotsearch.randomizedtesting.generators.RandomStrings;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
@@ -41,7 +39,6 @@ import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
-import org.apache.http.message.BasicHeader;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicStatusLine;
import org.apache.http.nio.protocol.HttpAsyncRequestProducer;
@@ -58,7 +55,10 @@ import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.Future;
import static org.elasticsearch.client.RestClientTestUtil.getAllErrorStatusCodes;
@@ -132,13 +132,8 @@ public class RestClientSingleHostTests extends RestClientTestCase {
});
- int numHeaders = RandomInts.randomIntBetween(getRandom(), 0, 3);
- defaultHeaders = new Header[numHeaders];
- for (int i = 0; i < numHeaders; i++) {
- String headerName = "Header-default" + (getRandom().nextBoolean() ? i : "");
- String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10);
- defaultHeaders[i] = new BasicHeader(headerName, headerValue);
- }
+ int numHeaders = randomIntBetween(0, 3);
+ defaultHeaders = generateHeaders("Header-default", "Header-array", numHeaders);
httpHost = new HttpHost("localhost", 9200);
failureListener = new HostsTrackingFailureListener();
restClient = new RestClient(httpClient, 10000, defaultHeaders, new HttpHost[]{httpHost}, null, failureListener);
@@ -333,20 +328,13 @@ public class RestClientSingleHostTests extends RestClientTestCase {
*/
public void testHeaders() throws IOException {
for (String method : getHttpMethods()) {
- Map<String, String> expectedHeaders = new HashMap<>();
- for (Header defaultHeader : defaultHeaders) {
- expectedHeaders.put(defaultHeader.getName(), defaultHeader.getValue());
- }
- int numHeaders = RandomInts.randomIntBetween(getRandom(), 1, 5);
- Header[] headers = new Header[numHeaders];
- for (int i = 0; i < numHeaders; i++) {
- String headerName = "Header" + (getRandom().nextBoolean() ? i : "");
- String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10);
- headers[i] = new BasicHeader(headerName, headerValue);
- expectedHeaders.put(headerName, headerValue);
- }
+ final int numHeaders = randomIntBetween(1, 5);
+ final Header[] headers = generateHeaders("Header", null, numHeaders);
+ final Map<String, List<String>> expectedHeaders = new HashMap<>();
- int statusCode = randomStatusCode(getRandom());
+ addHeaders(expectedHeaders, defaultHeaders, headers);
+
+ final int statusCode = randomStatusCode(getRandom());
Response esResponse;
try {
esResponse = restClient.performRequest(method, "/" + statusCode, headers);
@@ -355,10 +343,18 @@ public class RestClientSingleHostTests extends RestClientTestCase {
}
assertThat(esResponse.getStatusLine().getStatusCode(), equalTo(statusCode));
for (Header responseHeader : esResponse.getHeaders()) {
- String headerValue = expectedHeaders.remove(responseHeader.getName());
- assertNotNull("found response header [" + responseHeader.getName() + "] that wasn't originally sent", headerValue);
+ final String name = responseHeader.getName();
+ final String value = responseHeader.getValue();
+ final List<String> values = expectedHeaders.get(name);
+ assertNotNull("found response header [" + name + "] that wasn't originally sent: " + value, values);
+ assertTrue("found incorrect response header [" + name + "]: " + value, values.remove(value));
+
+ // we've collected them all
+ if (values.isEmpty()) {
+ expectedHeaders.remove(name);
+ }
}
- assertEquals("some headers that were sent weren't returned " + expectedHeaders, 0, expectedHeaders.size());
+ assertTrue("some headers that were sent weren't returned " + expectedHeaders, expectedHeaders.isEmpty());
}
}
@@ -368,11 +364,11 @@ public class RestClientSingleHostTests extends RestClientTestCase {
Map<String, String> params = Collections.emptyMap();
boolean hasParams = randomBoolean();
if (hasParams) {
- int numParams = RandomInts.randomIntBetween(getRandom(), 1, 3);
+ int numParams = randomIntBetween(1, 3);
params = new HashMap<>(numParams);
for (int i = 0; i < numParams; i++) {
String paramKey = "param-" + i;
- String paramValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10);
+ String paramValue = randomAsciiOfLengthBetween(3, 10);
params.put(paramKey, paramValue);
uriBuilder.addParameter(paramKey, paramValue);
}
@@ -412,24 +408,24 @@ public class RestClientSingleHostTests extends RestClientTestCase {
HttpEntity entity = null;
boolean hasBody = request instanceof HttpEntityEnclosingRequest && getRandom().nextBoolean();
if (hasBody) {
- entity = new StringEntity(RandomStrings.randomAsciiOfLengthBetween(getRandom(), 10, 100));
+ entity = new StringEntity(randomAsciiOfLengthBetween(10, 100));
((HttpEntityEnclosingRequest) request).setEntity(entity);
}
Header[] headers = new Header[0];
- for (Header defaultHeader : defaultHeaders) {
- //default headers are expected but not sent for each request
- request.setHeader(defaultHeader);
+ final int numHeaders = randomIntBetween(1, 5);
+ final Set<String> uniqueNames = new HashSet<>(numHeaders);
+ if (randomBoolean()) {
+ headers = generateHeaders("Header", "Header-array", numHeaders);
+ for (Header header : headers) {
+ request.addHeader(header);
+ uniqueNames.add(header.getName());
+ }
}
- if (getRandom().nextBoolean()) {
- int numHeaders = RandomInts.randomIntBetween(getRandom(), 1, 5);
- headers = new Header[numHeaders];
- for (int i = 0; i < numHeaders; i++) {
- String headerName = "Header" + (getRandom().nextBoolean() ? i : "");
- String headerValue = RandomStrings.randomAsciiOfLengthBetween(getRandom(), 3, 10);
- BasicHeader basicHeader = new BasicHeader(headerName, headerValue);
- headers[i] = basicHeader;
- request.setHeader(basicHeader);
+ for (Header defaultHeader : defaultHeaders) {
+ // request level headers override default headers
+ if (uniqueNames.contains(defaultHeader.getName()) == false) {
+ request.addHeader(defaultHeader);
}
}
@@ -459,4 +455,5 @@ public class RestClientSingleHostTests extends RestClientTestCase {
throw new UnsupportedOperationException();
}
}
+
}
diff --git a/client/test/build.gradle b/client/test/build.gradle
index 05d044504e..a7ffe79ac5 100644
--- a/client/test/build.gradle
+++ b/client/test/build.gradle
@@ -30,6 +30,7 @@ install.enabled = false
uploadArchives.enabled = false
dependencies {
+ compile "org.apache.httpcomponents:httpcore:${versions.httpcore}"
compile "com.carrotsearch.randomizedtesting:randomizedtesting-runner:${versions.randomizedrunner}"
compile "junit:junit:${versions.junit}"
compile "org.hamcrest:hamcrest-all:${versions.hamcrest}"
diff --git a/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java b/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java
index 8c506beb5a..4296932a00 100644
--- a/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java
+++ b/client/test/src/main/java/org/elasticsearch/client/RestClientTestCase.java
@@ -31,6 +31,15 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakZombies;
import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite;
+import org.apache.http.Header;
+import org.apache.http.message.BasicHeader;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
@TestMethodProviders({
JUnit3MethodProvider.class
})
@@ -43,4 +52,71 @@ import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite;
@TimeoutSuite(millis = 2 * 60 * 60 * 1000)
public abstract class RestClientTestCase extends RandomizedTest {
+ /**
+ * Create the specified number of {@link Header}s.
+ * <p>
+ * Generated header names will be the {@code baseName} plus its index or, rarely, the {@code arrayName} if it's supplied.
+ *
+ * @param baseName The base name to use for all headers.
+ * @param arrayName The optional ({@code null}able) array name to use randomly.
+ * @param headers The number of headers to create.
+ * @return Never {@code null}.
+ */
+ protected static Header[] generateHeaders(final String baseName, final String arrayName, final int headers) {
+ final Header[] generated = new Header[headers];
+ for (int i = 0; i < headers; i++) {
+ String headerName = baseName + i;
+ if (arrayName != null && rarely()) {
+ headerName = arrayName;
+ }
+
+ generated[i] = new BasicHeader(headerName, randomAsciiOfLengthBetween(3, 10));
+ }
+ return generated;
+ }
+
+ /**
+ * Create a new {@link List} within the {@code map} if none exists for {@code name} or append to the existing list.
+ *
+ * @param map The map to manipulate.
+ * @param name The name to create/append the list for.
+ * @param value The value to add.
+ */
+ private static void createOrAppendList(final Map<String, List<String>> map, final String name, final String value) {
+ List<String> values = map.get(name);
+
+ if (values == null) {
+ values = new ArrayList<>();
+ map.put(name, values);
+ }
+
+ values.add(value);
+ }
+
+ /**
+ * Add the {@code headers} to the {@code map} so that related tests can more easily assert that they exist.
+ * <p>
+ * If both the {@code defaultHeaders} and {@code headers} contain the same {@link Header}, based on its
+ * {@linkplain Header#getName() name}, then this will only use the {@code Header}(s) from {@code headers}.
+ *
+ * @param map The map to build with name/value(s) pairs.
+ * @param defaultHeaders The headers to add to the map representing default headers.
+ * @param headers The headers to add to the map representing request-level headers.
+ * @see #createOrAppendList(Map, String, String)
+ */
+ protected static void addHeaders(final Map<String, List<String>> map, final Header[] defaultHeaders, final Header[] headers) {
+ final Set<String> uniqueHeaders = new HashSet<>();
+ for (final Header header : headers) {
+ final String name = header.getName();
+ createOrAppendList(map, name, header.getValue());
+ uniqueHeaders.add(name);
+ }
+ for (final Header defaultHeader : defaultHeaders) {
+ final String name = defaultHeader.getName();
+ if (uniqueHeaders.contains(name) == false) {
+ createOrAppendList(map, name, defaultHeader.getValue());
+ }
+ }
+ }
+
}