Merge "Add BaseMockResponse with two flavors."
diff --git a/src/main/java/com/google/mockwebserver/BaseMockResponse.java b/src/main/java/com/google/mockwebserver/BaseMockResponse.java
new file mode 100644
index 0000000..a12abf5
--- /dev/null
+++ b/src/main/java/com/google/mockwebserver/BaseMockResponse.java
@@ -0,0 +1,226 @@
+/*
+ * Copyright (C) 2013 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.mockwebserver;
+
+import static com.google.mockwebserver.MockWebServer.ASCII;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InterruptedIOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Base scripted response to be replayed by {@link MockWebServer}.
+ *
+ * <p>Holds the state shared by every response flavor: status line, headers,
+ * throttling rate and socket policy. Subclasses supply the body by
+ * implementing {@link #writeResponse(OutputStream)}, typically by delegating
+ * to {@link #writeResponse(InputStream, OutputStream)}.
+ *
+ * @param <T> the concrete subclass, returned by the builder-style setters
+ *     declared here (see {@link #self()})
+ */
+abstract class BaseMockResponse<T extends BaseMockResponse<T>> {
+  protected static final String CONTENT_LENGTH = "Content-Length";
+
+  /** Size of each body write; the body is streamed in MTU-sized increments. */
+  private static final int CHUNK_SIZE = 1452;
+
+  private String status = "HTTP/1.1 200 OK";
+  /** Raw header lines such as "Content-Length: 0", written in insertion order. */
+  private List<String> headers = new ArrayList<String>();
+  /** Simulated network speed; {@link Integer#MAX_VALUE} disables throttling. */
+  private int bytesPerSecond = Integer.MAX_VALUE;
+  private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+
+  protected BaseMockResponse() {
+  }
+
+  /**
+   * Returns a shallow copy whose header list is independent of this one.
+   * Only succeeds for subclasses that implement {@link Cloneable}.
+   */
+  @Override
+  protected Object clone() throws CloneNotSupportedException {
+    final BaseMockResponse<?> result = (BaseMockResponse<?>) super.clone();
+    result.headers = new ArrayList<String>(result.headers);
+    return result;
+  }
+
+  /**
+   * Returns the HTTP response line, such as "HTTP/1.1 200 OK".
+   */
+  public String getStatus() {
+    return status;
+  }
+
+  /**
+   * Sets the status line to {@code "HTTP/1.1 <code> OK"}. The reason phrase is
+   * always "OK"; use {@link #setStatus(String)} for a different one.
+   */
+  public T setResponseCode(int code) {
+    this.status = "HTTP/1.1 " + code + " OK";
+    return self();
+  }
+
+  /** Sets the complete status line, such as "HTTP/1.1 404 Not Found". */
+  public T setStatus(String status) {
+    this.status = status;
+    return self();
+  }
+
+  /**
+   * Returns the HTTP headers, such as "Content-Length: 0". This is the live
+   * list: changes made through it affect the response.
+   */
+  public List<String> getHeaders() {
+    return headers;
+  }
+
+  /** Removes all headers. */
+  public T clearHeaders() {
+    headers.clear();
+    return self();
+  }
+
+  /** Appends a raw header line, such as "Cache-Control: no-cache". */
+  public T addHeader(String header) {
+    headers.add(header);
+    return self();
+  }
+
+  /** Appends a header line formatted as {@code name + ": " + value}. */
+  public T addHeader(String name, Object value) {
+    return addHeader(name + ": " + String.valueOf(value));
+  }
+
+  /** Replaces every header named {@code name} with a single {@code name: value} line. */
+  public T setHeader(String name, Object value) {
+    removeHeader(name);
+    return addHeader(name, value);
+  }
+
+  /** Removes every header named {@code name}; the name is matched case-insensitively. */
+  public T removeHeader(String name) {
+    final String prefix = name + ": ";
+    for (Iterator<String> i = headers.iterator(); i.hasNext();) {
+      if (prefix.regionMatches(true, 0, i.next(), 0, prefix.length())) {
+        i.remove();
+      }
+    }
+    return self();
+  }
+
+  public SocketPolicy getSocketPolicy() {
+    return socketPolicy;
+  }
+
+  public T setSocketPolicy(SocketPolicy socketPolicy) {
+    this.socketPolicy = socketPolicy;
+    return self();
+  }
+
+  /** Returns the simulated network speed; {@link Integer#MAX_VALUE} means unthrottled. */
+  public int getBytesPerSecond() {
+    return bytesPerSecond;
+  }
+
+  /**
+   * Set simulated network speed, in bytes per second. {@link Integer#MAX_VALUE},
+   * the default, disables throttling.
+   *
+   * @throws IllegalArgumentException if {@code bytesPerSecond} is not positive
+   */
+  public T setBytesPerSecond(int bytesPerSecond) {
+    // Fail fast: zero would divide by zero when the response is written, and a
+    // negative rate would silently disable throttling.
+    if (bytesPerSecond <= 0) {
+      throw new IllegalArgumentException("bytesPerSecond <= 0: " + bytesPerSecond);
+    }
+    this.bytesPerSecond = bytesPerSecond;
+    return self();
+  }
+
+  /** Returns the status line. */
+  @Override public String toString() {
+    return status;
+  }
+
+  /**
+   * Write complete response, including all headers and the given body.
+   * Handles applying {@link #setBytesPerSecond(int)} limits by pausing each
+   * time roughly {@code CHUNK_SIZE} body bytes have been written.
+   *
+   * @throws InterruptedIOException if interrupted while throttling; the
+   *     thread's interrupt status is restored first
+   */
+  protected void writeResponse(InputStream body, OutputStream out) throws IOException {
+    out.write((getStatus() + "\r\n").getBytes(ASCII));
+    for (String header : getHeaders()) {
+      out.write((header + "\r\n").getBytes(ASCII));
+    }
+    out.write(("\r\n").getBytes(ASCII));
+    out.flush();
+
+    // Stream data in MTU-sized increments
+    final byte[] buffer = new byte[CHUNK_SIZE];
+    final long delayMs;
+    if (bytesPerSecond == Integer.MAX_VALUE) {
+      delayMs = 0;
+    } else {
+      delayMs = (1000L * buffer.length) / bytesPerSecond;
+    }
+
+    int read;
+    long sinceDelay = 0;
+    while ((read = body.read(buffer)) != -1) {
+      out.write(buffer, 0, read);
+      out.flush();
+
+      sinceDelay += read;
+      if (sinceDelay >= buffer.length && delayMs > 0) {
+        sinceDelay %= buffer.length;
+        try {
+          Thread.sleep(delayMs);
+        } catch (InterruptedException e) {
+          // Keep the interrupt visible to the caller and report it as an I/O
+          // failure instead of an AssertionError that discards both.
+          Thread.currentThread().interrupt();
+          InterruptedIOException interrupted =
+              new InterruptedIOException("interrupted while throttling");
+          interrupted.initCause(e);
+          throw interrupted;
+        }
+      }
+    }
+  }
+
+  /**
+   * Write complete response. Usually implemented by calling
+   * {@link #writeResponse(InputStream, OutputStream)} with the
+   * implementation-specific body.
+   */
+  public abstract void writeResponse(OutputStream out) throws IOException;
+
+  /**
+   * Return concrete {@code this} to enable builder-style methods.
+   */
+  protected abstract T self();
+}
diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java
index 64eb8f3..8f11996 100644
--- a/src/main/java/com/google/mockwebserver/MockResponse.java
+++ b/src/main/java/com/google/mockwebserver/MockResponse.java
@@ -21,132 +21,49 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.io.InputStream;
+import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
/**
- * A scripted response to be replayed by the mock web server.
+ * A scripted response to be replayed by {@link MockWebServer}.
+ *
+ * <p>The body is kept in memory as a byte array, and every write reads it
+ * through a fresh stream, so the same instance can be written repeatedly.
+ * For a body backed by an {@link java.io.InputStream}, see
+ * {@link MockStreamResponse}.
*/
-public final class MockResponse implements Cloneable {
- private static final String EMPTY_BODY_HEADER = "Content-Length: 0";
+public class MockResponse extends BaseMockResponse<MockResponse> implements Cloneable {
private static final String CHUNKED_BODY_HEADER = "Transfer-encoding: chunked";
- private String status = "HTTP/1.1 200 OK";
- private List<String> headers = new ArrayList<String>();
- private InputStream body;
- private long bodyLength;
- private int bytesPerSecond = Integer.MAX_VALUE;
- private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+  /** The raw body bytes; already chunk-encoded if set by setChunkedBody. */
+  private byte[] body;
+  /** Creates a {@code 200 OK} response with an empty body and {@code Content-Length: 0}. */
public MockResponse() {
- headers.add(EMPTY_BODY_HEADER);
+    this.body = new byte[0];
+    addHeader(CONTENT_LENGTH, 0);
}
- @Override public MockResponse clone() {
+  /**
+   * Returns a copy of this response with its own header list. The body array
+   * is shared with the copy rather than duplicated.
+   */
+  @Override
+  public MockResponse clone() {
try {
- MockResponse result = (MockResponse) super.clone();
- result.headers = new ArrayList<String>(result.headers);
- return result;
+      return (MockResponse) super.clone();
} catch (CloneNotSupportedException e) {
throw new AssertionError();
}
}
- /**
- * Returns the HTTP response line, such as "HTTP/1.1 200 OK".
- */
- public String getStatus() {
- return status;
- }
-
- public MockResponse setResponseCode(int code) {
- this.status = "HTTP/1.1 " + code + " OK";
- return this;
- }
-
- public MockResponse setStatus(String status) {
- this.status = status;
- return this;
- }
-
- /**
- * Returns the HTTP headers, such as "Content-Length: 0".
- */
- public List<String> getHeaders() {
- return headers;
- }
-
- public MockResponse clearHeaders() {
- headers.clear();
- return this;
- }
-
- public MockResponse addHeader(String header) {
- headers.add(header);
- return this;
- }
-
- public MockResponse addHeader(String name, Object value) {
- return addHeader(name + ": " + String.valueOf(value));
- }
-
- public MockResponse setHeader(String name, Object value) {
- removeHeader(name);
- return addHeader(name, value);
- }
-
- public MockResponse removeHeader(String name) {
- name += ": ";
- for (Iterator<String> i = headers.iterator(); i.hasNext();) {
- String header = i.next();
- if (name.regionMatches(true, 0, header, 0, name.length())) {
- i.remove();
- }
- }
- return this;
- }
-
- /**
- * Returns a {@code byte[]} containing the raw HTTP payload. This is less
- * efficient than {@link #getBodyStream()}.
- */
- public byte[] getBody() {
- try {
- return readFullyNoClose(body);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- /**
- * Returns an input stream containing the raw HTTP payload.
- */
- public InputStream getBodyStream() {
- return body;
- }
-
- /**
- * Returns length of raw HTTP payload.
- */
- public long getBodyLength() {
- return bodyLength;
- }
-
- public MockResponse setBody(InputStream body, long bodyLength) {
- if (this.body == null) {
- headers.remove(EMPTY_BODY_HEADER);
- }
- this.headers.add("Content-Length: " + bodyLength);
- this.body = body;
- this.bodyLength = bodyLength;
- return this;
- }
-
+  /**
+   * Sets the body and replaces any {@code Content-Length} header with its length.
+   */
public MockResponse setBody(byte[] body) {
- return setBody(new ByteArrayInputStream(body), body.length);
+    setHeader(CONTENT_LENGTH, body.length); // NPE on null before any state changes
+    this.body = body;
+    return this;
}
public MockResponse setBody(String body) {
@@ -157,9 +74,23 @@
}
}
+  /**
+   * Returns the body bytes exactly as they are written, including chunk framing
+   * if set by {@code setChunkedBody}. This is the live array, not a copy.
+   */
+  public byte[] getBody() {
+    return body;
+  }
+
+  /**
+   * Sets the body to {@code body} with chunked transfer encoding (chunk sizes
+   * governed by {@code maxChunkSize}), replacing any {@code Content-Length} or
+   * {@code Transfer-encoding} header with {@code Transfer-encoding: chunked}.
+   */
public MockResponse setChunkedBody(byte[] body, int maxChunkSize) throws IOException {
- headers.remove(EMPTY_BODY_HEADER);
- headers.add(CHUNKED_BODY_HEADER);
+    removeHeader(CONTENT_LENGTH);
+    removeHeader("Transfer-encoding"); // don't duplicate the header on repeated calls
+    addHeader(CHUNKED_BODY_HEADER);
ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
int pos = 0;
@@ -173,9 +104,7 @@
}
bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf
- body = bytesOut.toByteArray();
- this.body = new ByteArrayInputStream(body);
- this.bodyLength = body.length;
+    this.body = bytesOut.toByteArray();
return this;
}
@@ -183,38 +112,18 @@
return setChunkedBody(body.getBytes(ASCII), maxChunkSize);
}
- public SocketPolicy getSocketPolicy() {
- return socketPolicy;
- }
-
- public MockResponse setSocketPolicy(SocketPolicy socketPolicy) {
- this.socketPolicy = socketPolicy;
+  /** Returns {@code this} so inherited setters can return {@code MockResponse}. */
+  @Override
+  protected MockResponse self() {
return this;
}
- public int getBytesPerSecond() {
- return bytesPerSecond;
- }
-
- /**
- * Set simulated network speed, in bytes per second.
- */
- public MockResponse setBytesPerSecond(int bytesPerSecond) {
- this.bytesPerSecond = bytesPerSecond;
- return this;
- }
-
- @Override public String toString() {
- return status;
- }
-
- private static byte[] readFullyNoClose(InputStream in) throws IOException {
- ByteArrayOutputStream bytes = new ByteArrayOutputStream();
- byte[] buffer = new byte[1024];
- int count;
- while ((count = in.read(buffer)) != -1) {
- bytes.write(buffer, 0, count);
- }
- return bytes.toByteArray();
+  /**
+   * Writes the status line, headers and body to {@code out}. Each call reads
+   * the body through a fresh stream, so the response can be written again.
+   */
+  @Override
+  public void writeResponse(OutputStream out) throws IOException {
+    super.writeResponse(new ByteArrayInputStream(body), out);
}
}
diff --git a/src/main/java/com/google/mockwebserver/MockStreamResponse.java b/src/main/java/com/google/mockwebserver/MockStreamResponse.java
new file mode 100644
index 0000000..9db1c3c
--- /dev/null
+++ b/src/main/java/com/google/mockwebserver/MockStreamResponse.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2013 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.mockwebserver;
+
+import java.io.ByteArrayInputStream;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/**
+ * A scripted response to be replayed by {@link MockWebServer}. This specific
+ * variant uses an {@link InputStream} as its data source. Each instance can
+ * only be consumed once.
+ *
+ * <p>Writing the response closes the body stream; a second write throws
+ * {@link IllegalStateException} unless a new body is set first.
+ */
+public class MockStreamResponse extends BaseMockResponse<MockStreamResponse> {
+  /** The body source, or {@code null} once it has been written. */
+  private InputStream body;
+
+  /** Creates a response with an empty body and {@code Content-Length: 0}. */
+  public MockStreamResponse() {
+    body = new ByteArrayInputStream(new byte[0]);
+    addHeader(CONTENT_LENGTH, 0);
+  }
+
+  /**
+   * Sets the body stream, closing any previously set one, and replaces the
+   * {@code Content-Length} header with {@code bodyLength}. The length is sent
+   * as-is; it is not checked against the bytes {@code body} actually yields.
+   *
+   * @throws NullPointerException if {@code body} is null
+   */
+  public MockStreamResponse setBody(InputStream body, long bodyLength) {
+    if (body == null) {
+      throw new NullPointerException("body == null");
+    }
+    // Release the previous body, but never the stream being installed.
+    if (this.body != null && this.body != body) {
+      closeQuietly(this.body);
+    }
+
+    this.body = body;
+    setHeader(CONTENT_LENGTH, bodyLength);
+    return this;
+  }
+
+  /**
+   * Writes the status line, headers and body to {@code out}, then closes the
+   * body. The body counts as consumed even if writing fails part-way.
+   *
+   * @throws IllegalStateException if the body has already been written
+   */
+  @Override
+  public void writeResponse(OutputStream out) throws IOException {
+    final InputStream source = body;
+    if (source == null) {
+      throw new IllegalStateException("Stream already consumed");
+    }
+
+    // Mark consumed before writing: the stream is closed below whatever happens,
+    // so a retry must report "already consumed" rather than read a closed stream.
+    body = null;
+    try {
+      super.writeResponse(source, out);
+    } finally {
+      closeQuietly(source);
+    }
+  }
+
+  @Override
+  protected MockStreamResponse self() {
+    return this;
+  }
+
+  /** Closes {@code closeable} if non-null, ignoring checked exceptions from {@code close()}. */
+  private static void closeQuietly(Closeable closeable) {
+    if (closeable != null) {
+      try {
+        closeable.close();
+      } catch (RuntimeException rethrown) {
+        throw rethrown;
+      } catch (Exception ignored) {
+        // Best-effort release: a failed close must not mask the write's outcome.
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index 6173c88..b774eff 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -51,6 +51,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
+
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
@@ -68,8 +69,8 @@
private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
private final BlockingQueue<RecordedRequest> requestQueue
= new LinkedBlockingQueue<RecordedRequest>();
- private final BlockingQueue<MockResponse> responseQueue
- = new LinkedBlockingDeque<MockResponse>();
+ private final BlockingQueue<BaseMockResponse<?>> responseQueue
+ = new LinkedBlockingDeque<BaseMockResponse<?>>();
private final Set<Socket> openClientSockets
= Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
private boolean singleResponse;
@@ -158,8 +159,12 @@
return requestCount.get();
}
- public void enqueue(MockResponse response) {
- responseQueue.add(response.clone());
+  public void enqueue(BaseMockResponse<?> response) {
+    BaseMockResponse<?> queued = response; // other flavors (e.g. streams) are queued as-is
+    if (response instanceof MockResponse) {
+      queued = ((MockResponse) response).clone(); // snapshot, isolated from later edits
+    }
+    responseQueue.add(queued);
}
/**
@@ -236,7 +241,7 @@
} catch (SocketException ignored) {
continue;
}
- MockResponse peek = responseQueue.peek();
+ BaseMockResponse<?> peek = responseQueue.peek();
if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) {
responseQueue.take();
socket.close();
@@ -274,7 +279,7 @@
if (tunnelProxy) {
createTunnel();
}
- MockResponse response = responseQueue.peek();
+ BaseMockResponse<?> response = responseQueue.peek();
if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) {
processHandshakeFailure(raw, sequenceNumber++);
return;
@@ -312,7 +317,7 @@
*/
private void createTunnel() throws IOException, InterruptedException {
while (true) {
- MockResponse connect = responseQueue.peek();
+ BaseMockResponse<?> connect = responseQueue.peek();
if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
throw new IllegalStateException("Tunnel without any CONNECT!");
}
@@ -332,8 +337,8 @@
if (request == null) {
return false;
}
- MockResponse response = dispatch(request);
- writeResponse(out, response);
+ BaseMockResponse<?> response = dispatch(request);
+ response.writeResponse(out);
if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
in.close();
out.close();
@@ -450,7 +455,7 @@
/**
* Returns a response to satisfy {@code request}.
*/
- private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
+ private BaseMockResponse<?> dispatch(RecordedRequest request) throws InterruptedException {
if (responseQueue.isEmpty()) {
throw new IllegalStateException("Unexpected request: " + request);
}
@@ -471,44 +476,6 @@
}
}
- private void writeResponse(OutputStream out, MockResponse response) throws IOException {
- out.write((response.getStatus() + "\r\n").getBytes(ASCII));
- for (String header : response.getHeaders()) {
- out.write((header + "\r\n").getBytes(ASCII));
- }
- out.write(("\r\n").getBytes(ASCII));
- out.flush();
-
- final InputStream in = response.getBodyStream();
- final int bytesPerSecond = response.getBytesPerSecond();
-
- // Stream data in MTU-sized increments
- final byte[] buffer = new byte[1452];
- final long delayMs;
- if (bytesPerSecond == Integer.MAX_VALUE) {
- delayMs = 0;
- } else {
- delayMs = (1000 * buffer.length) / bytesPerSecond;
- }
-
- int read;
- long sinceDelay = 0;
- while ((read = in.read(buffer)) != -1) {
- out.write(buffer, 0, read);
- out.flush();
-
- sinceDelay += read;
- if (sinceDelay >= buffer.length && delayMs > 0) {
- sinceDelay %= buffer.length;
- try {
- Thread.sleep(delayMs);
- } catch (InterruptedException e) {
- throw new AssertionError();
- }
- }
- }
- }
-
/**
* Transfer bytes from {@code in} to {@code out} until either {@code length}
* bytes have been transferred or {@code in} is exhausted.