Merge "Add BaseMockResponse with two flavors."
diff --git a/src/main/java/com/google/mockwebserver/BaseMockResponse.java b/src/main/java/com/google/mockwebserver/BaseMockResponse.java
new file mode 100644
index 0000000..a12abf5
--- /dev/null
+++ b/src/main/java/com/google/mockwebserver/BaseMockResponse.java
@@ -0,0 +1,230 @@
+/*
+ * Copyright (C) 2013 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.mockwebserver;
+
+import static com.google.mockwebserver.MockWebServer.ASCII;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Base scripted response to be replayed by {@link MockWebServer}.
+ *
+ * <p>Holds the status line, headers, socket policy and simulated bandwidth
+ * shared by every response flavor. Subclasses supply the body by implementing
+ * {@link #writeResponse(OutputStream)}.
+ *
+ * @param <T> the concrete subclass, returned by the builder-style setters
+ */
+abstract class BaseMockResponse<T extends BaseMockResponse<T>> {
+    protected static final String CONTENT_LENGTH = "Content-Length";
+
+    private String status = "HTTP/1.1 200 OK";
+    private List<String> headers = new ArrayList<String>();
+    private int bytesPerSecond = Integer.MAX_VALUE;
+    private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+
+    protected BaseMockResponse() {
+    }
+
+    /**
+     * Returns a copy of this response with its own header list, so header
+     * changes on the copy do not affect the original.
+     *
+     * @throws CloneNotSupportedException if thrown by {@link Object#clone()}
+     */
+    @Override
+    protected Object clone() throws CloneNotSupportedException {
+        final BaseMockResponse<?> result = (BaseMockResponse<?>) super.clone();
+        result.headers = new ArrayList<String>(result.headers);
+        return result;
+    }
+
+    /**
+     * Returns the HTTP response line, such as "HTTP/1.1 200 OK".
+     */
+    public String getStatus() {
+        return status;
+    }
+
+    /**
+     * Sets the status line to {@code "HTTP/1.1 <code> OK"}. The reason phrase
+     * is always "OK"; use {@link #setStatus(String)} for any other phrase.
+     */
+    public T setResponseCode(int code) {
+        this.status = "HTTP/1.1 " + code + " OK";
+        return self();
+    }
+
+    /**
+     * Sets the complete HTTP response line, without the trailing CRLF.
+     */
+    public T setStatus(String status) {
+        this.status = status;
+        return self();
+    }
+
+    /**
+     * Returns the HTTP headers, such as "Content-Length: 0". This is the live
+     * internal list, not a copy.
+     */
+    public List<String> getHeaders() {
+        return headers;
+    }
+
+    /**
+     * Removes every header.
+     */
+    public T clearHeaders() {
+        headers.clear();
+        return self();
+    }
+
+    /**
+     * Appends a raw header line such as "Content-Length: 0".
+     */
+    public T addHeader(String header) {
+        headers.add(header);
+        return self();
+    }
+
+    /**
+     * Appends a header formatted as {@code name + ": " + value}.
+     */
+    public T addHeader(String name, Object value) {
+        return addHeader(name + ": " + String.valueOf(value));
+    }
+
+    /**
+     * Removes all headers named {@code name}, then adds a single one with
+     * {@code value}.
+     */
+    public T setHeader(String name, Object value) {
+        removeHeader(name);
+        return addHeader(name, value);
+    }
+
+    /**
+     * Removes every header whose line starts with {@code name + ": "},
+     * ignoring case.
+     */
+    public T removeHeader(String name) {
+        name += ": ";
+        for (Iterator<String> i = headers.iterator(); i.hasNext();) {
+            String header = i.next();
+            if (name.regionMatches(true, 0, header, 0, name.length())) {
+                i.remove();
+            }
+        }
+        return self();
+    }
+
+    public SocketPolicy getSocketPolicy() {
+        return socketPolicy;
+    }
+
+    public T setSocketPolicy(SocketPolicy socketPolicy) {
+        this.socketPolicy = socketPolicy;
+        return self();
+    }
+
+    public int getBytesPerSecond() {
+        return bytesPerSecond;
+    }
+
+    /**
+     * Set simulated network speed, in bytes per second. The default,
+     * {@link Integer#MAX_VALUE}, disables throttling.
+     *
+     * @throws IllegalArgumentException if {@code bytesPerSecond} is zero
+     */
+    public T setBytesPerSecond(int bytesPerSecond) {
+        if (bytesPerSecond == 0) {
+            // Zero would otherwise divide by zero while the body is written
+            throw new IllegalArgumentException("bytesPerSecond must not be 0");
+        }
+        this.bytesPerSecond = bytesPerSecond;
+        return self();
+    }
+
+    @Override public String toString() {
+        return status;
+    }
+
+    /**
+     * Write complete response, including all headers and the given body.
+     * Handles applying {@link #setBytesPerSecond(int)} limits.
+     *
+     * @param body source of the response body; read until end of stream and
+     *     not closed by this method
+     * @param out destination for the status line, headers and body
+     * @throws IOException if reading {@code body} or writing {@code out} fails
+     */
+    protected void writeResponse(InputStream body, OutputStream out) throws IOException {
+        out.write((getStatus() + "\r\n").getBytes(ASCII));
+        for (String header : getHeaders()) {
+            out.write((header + "\r\n").getBytes(ASCII));
+        }
+        out.write(("\r\n").getBytes(ASCII));
+        out.flush();
+
+        // Stream data in MTU-sized increments
+        final byte[] buffer = new byte[1452];
+        final long delayMs;
+        if (bytesPerSecond == Integer.MAX_VALUE) {
+            delayMs = 0;
+        } else {
+            delayMs = (1000 * buffer.length) / bytesPerSecond;
+        }
+
+        int read;
+        long sinceDelay = 0;
+        while ((read = body.read(buffer)) != -1) {
+            out.write(buffer, 0, read);
+            out.flush();
+
+            // Pause once per full buffer's worth of bytes written
+            sinceDelay += read;
+            if (sinceDelay >= buffer.length && delayMs > 0) {
+                sinceDelay %= buffer.length;
+                try {
+                    Thread.sleep(delayMs);
+                } catch (InterruptedException e) {
+                    // Restore the interrupt flag and keep the cause attached
+                    Thread.currentThread().interrupt();
+                    throw new AssertionError(e);
+                }
+            }
+        }
+    }
+
+    /**
+     * Write complete response. Usually implemented by calling
+     * {@link #writeResponse(InputStream, OutputStream)} with the
+     * implementation-specific body.
+     */
+    public abstract void writeResponse(OutputStream out) throws IOException;
+
+    /**
+     * Return concrete {@code this} to enable builder-style methods.
+     */
+    protected abstract T self();
+}
diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java
index 64eb8f3..8f11996 100644
--- a/src/main/java/com/google/mockwebserver/MockResponse.java
+++ b/src/main/java/com/google/mockwebserver/MockResponse.java
@@ -21,132 +21,35 @@
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
-import java.io.InputStream;
+import java.io.OutputStream;
 import java.io.UnsupportedEncodingException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
 
 /**
- * A scripted response to be replayed by the mock web server.
+ * A scripted response to be replayed by {@link MockWebServer}.
  */
-public final class MockResponse implements Cloneable {
-    private static final String EMPTY_BODY_HEADER = "Content-Length: 0";
+public class MockResponse extends BaseMockResponse<MockResponse> implements Cloneable {
     private static final String CHUNKED_BODY_HEADER = "Transfer-encoding: chunked";
 
-    private String status = "HTTP/1.1 200 OK";
-    private List<String> headers = new ArrayList<String>();
-    private InputStream body;
-    private long bodyLength;
-    private int bytesPerSecond = Integer.MAX_VALUE;
-    private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+    private byte[] body;
 
     public MockResponse() {
-        headers.add(EMPTY_BODY_HEADER);
+        this.body = new byte[0];
+        addHeader(CONTENT_LENGTH, 0);
     }
 
-    @Override public MockResponse clone() {
+    @Override
+    public MockResponse clone() {
         try {
-            MockResponse result = (MockResponse) super.clone();
-            result.headers = new ArrayList<String>(result.headers);
-            return result;
+            return (MockResponse) super.clone();
         } catch (CloneNotSupportedException e) {
             throw new AssertionError();
         }
     }
 
-    /**
-     * Returns the HTTP response line, such as "HTTP/1.1 200 OK".
-     */
-    public String getStatus() {
-        return status;
-    }
-
-    public MockResponse setResponseCode(int code) {
-        this.status = "HTTP/1.1 " + code + " OK";
-        return this;
-    }
-
-    public MockResponse setStatus(String status) {
-        this.status = status;
-        return this;
-    }
-
-    /**
-     * Returns the HTTP headers, such as "Content-Length: 0".
-     */
-    public List<String> getHeaders() {
-        return headers;
-    }
-
-    public MockResponse clearHeaders() {
-        headers.clear();
-        return this;
-    }
-
-    public MockResponse addHeader(String header) {
-        headers.add(header);
-        return this;
-    }
-
-    public MockResponse addHeader(String name, Object value) {
-        return addHeader(name + ": " + String.valueOf(value));
-    }
-
-    public MockResponse setHeader(String name, Object value) {
-        removeHeader(name);
-        return addHeader(name, value);
-    }
-
-    public MockResponse removeHeader(String name) {
-        name += ": ";
-        for (Iterator<String> i = headers.iterator(); i.hasNext();) {
-            String header = i.next();
-            if (name.regionMatches(true, 0, header, 0, name.length())) {
-                i.remove();
-            }
-        }
-        return this;
-    }
-
-    /**
-     * Returns a {@code byte[]} containing the raw HTTP payload. This is less
-     * efficient than {@link #getBodyStream()}.
-     */
-    public byte[] getBody() {
-        try {
-            return readFullyNoClose(body);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    /**
-     * Returns an input stream containing the raw HTTP payload.
-     */
-    public InputStream getBodyStream() {
-        return body;
-    }
-
-    /**
-     * Returns length of raw HTTP payload.
-     */
-    public long getBodyLength() {
-        return bodyLength;
-    }
-
-    public MockResponse setBody(InputStream body, long bodyLength) {
-        if (this.body == null) {
-            headers.remove(EMPTY_BODY_HEADER);
-        }
-        this.headers.add("Content-Length: " + bodyLength);
-        this.body = body;
-        this.bodyLength = bodyLength;
-        return this;
-    }
-
     public MockResponse setBody(byte[] body) {
-        return setBody(new ByteArrayInputStream(body), body.length);
+        this.body = body;
+        setHeader(CONTENT_LENGTH, body.length);
+        return this;
     }
 
     public MockResponse setBody(String body) {
@@ -157,9 +60,13 @@
         }
     }
 
+    public byte[] getBody() {
+        return body;
+    }
+
     public MockResponse setChunkedBody(byte[] body, int maxChunkSize) throws IOException {
-        headers.remove(EMPTY_BODY_HEADER);
-        headers.add(CHUNKED_BODY_HEADER);
+        removeHeader(CONTENT_LENGTH);
+        addHeader(CHUNKED_BODY_HEADER);
 
         ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
         int pos = 0;
@@ -173,9 +80,7 @@
         }
         bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf
 
-        body = bytesOut.toByteArray();
-        this.body = new ByteArrayInputStream(body);
-        this.bodyLength = body.length;
+        this.body = bytesOut.toByteArray();
         return this;
     }
 
@@ -183,38 +88,13 @@
         return setChunkedBody(body.getBytes(ASCII), maxChunkSize);
     }
 
-    public SocketPolicy getSocketPolicy() {
-        return socketPolicy;
-    }
-
-    public MockResponse setSocketPolicy(SocketPolicy socketPolicy) {
-        this.socketPolicy = socketPolicy;
+    @Override
+    protected MockResponse self() {
         return this;
     }
 
-    public int getBytesPerSecond() {
-        return bytesPerSecond;
-    }
-
-    /**
-     * Set simulated network speed, in bytes per second.
-     */
-    public MockResponse setBytesPerSecond(int bytesPerSecond) {
-        this.bytesPerSecond = bytesPerSecond;
-        return this;
-    }
-
-    @Override public String toString() {
-        return status;
-    }
-
-    private static byte[] readFullyNoClose(InputStream in) throws IOException {
-        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
-        byte[] buffer = new byte[1024];
-        int count;
-        while ((count = in.read(buffer)) != -1) {
-            bytes.write(buffer, 0, count);
-        }
-        return bytes.toByteArray();
+    @Override
+    public void writeResponse(OutputStream out) throws IOException {
+        super.writeResponse(new ByteArrayInputStream(body), out);
     }
 }
diff --git a/src/main/java/com/google/mockwebserver/MockStreamResponse.java b/src/main/java/com/google/mockwebserver/MockStreamResponse.java
new file mode 100644
index 0000000..9db1c3c
--- /dev/null
+++ b/src/main/java/com/google/mockwebserver/MockStreamResponse.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2013 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.mockwebserver;
+
+import java.io.ByteArrayInputStream;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/**
+ * A scripted response to be replayed by {@link MockWebServer}. This specific
+ * variant uses an {@link InputStream} as its data source. Each instance can
+ * only be consumed once.
+ */
+public class MockStreamResponse extends BaseMockResponse<MockStreamResponse> {
+    private InputStream body;
+
+    /**
+     * Creates a response with an empty body and a {@code Content-Length} of 0.
+     */
+    public MockStreamResponse() {
+        body = new ByteArrayInputStream(new byte[0]);
+        addHeader(CONTENT_LENGTH, 0);
+    }
+
+    /**
+     * Replaces the body and sets the {@code Content-Length} header to
+     * {@code bodyLength}. Any body set earlier is closed first.
+     *
+     * @param body stream to send; closed after {@link #writeResponse(OutputStream)}
+     * @param bodyLength value for the {@code Content-Length} header; not
+     *     checked against the bytes {@code body} actually yields
+     */
+    public MockStreamResponse setBody(InputStream body, long bodyLength) {
+        // Release any existing body
+        closeQuietly(this.body);
+
+        this.body = body;
+        setHeader(CONTENT_LENGTH, bodyLength);
+        return this;
+    }
+
+    /**
+     * Writes the status line, headers and body to {@code out}, then closes the
+     * body stream. The instance counts as consumed even when writing fails.
+     *
+     * @throws IllegalStateException if the body was already consumed
+     * @throws IOException if reading the body or writing {@code out} fails
+     */
+    @Override
+    public void writeResponse(OutputStream out) throws IOException {
+        if (body == null) {
+            throw new IllegalStateException("Stream already consumed");
+        }
+
+        // Mark consumed up front so a failed write cannot leave a closed
+        // stream behind for the next call to read from
+        final InputStream in = body;
+        body = null;
+        try {
+            super.writeResponse(in, out);
+        } finally {
+            closeQuietly(in);
+        }
+    }
+
+    @Override
+    protected MockStreamResponse self() {
+        return this;
+    }
+
+    /**
+     * Closes {@code closeable} if it is non-null. Checked exceptions are
+     * ignored; runtime exceptions propagate.
+     */
+    private static void closeQuietly(Closeable closeable) {
+        if (closeable != null) {
+            try {
+                closeable.close();
+            } catch (RuntimeException rethrown) {
+                throw rethrown;
+            } catch (Exception ignored) {
+            }
+        }
+    }
+}
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index 6173c88..b774eff 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -51,6 +51,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.logging.Level;
 import java.util.logging.Logger;
+
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLSocket;
 import javax.net.ssl.SSLSocketFactory;
@@ -68,8 +69,8 @@
     private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
     private final BlockingQueue<RecordedRequest> requestQueue
             = new LinkedBlockingQueue<RecordedRequest>();
-    private final BlockingQueue<MockResponse> responseQueue
-            = new LinkedBlockingDeque<MockResponse>();
+    private final BlockingQueue<BaseMockResponse<?>> responseQueue
+            = new LinkedBlockingDeque<BaseMockResponse<?>>();
     private final Set<Socket> openClientSockets
             = Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
     private boolean singleResponse;
@@ -158,8 +159,12 @@
         return requestCount.get();
     }
 
-    public void enqueue(MockResponse response) {
-        responseQueue.add(response.clone());
+    public void enqueue(BaseMockResponse<?> response) {
+        BaseMockResponse<?> queued = response; // non-MockResponse flavors are queued as-is
+        if (response instanceof MockResponse) {
+            queued = ((MockResponse) response).clone(); // queue a copy, not the caller's instance
+        }
+        responseQueue.add(queued);
     }
 
     /**
@@ -236,7 +241,7 @@
                     } catch (SocketException ignored) {
                         continue;
                     }
-                    MockResponse peek = responseQueue.peek();
+                    BaseMockResponse<?> peek = responseQueue.peek();
                     if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) {
                         responseQueue.take();
                         socket.close();
@@ -274,7 +279,7 @@
                     if (tunnelProxy) {
                         createTunnel();
                     }
-                    MockResponse response = responseQueue.peek();
+                    BaseMockResponse<?> response = responseQueue.peek();
                     if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) {
                         processHandshakeFailure(raw, sequenceNumber++);
                         return;
@@ -312,7 +317,7 @@
              */
             private void createTunnel() throws IOException, InterruptedException {
                 while (true) {
-                    MockResponse connect = responseQueue.peek();
+                    BaseMockResponse<?> connect = responseQueue.peek();
                     if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
                         throw new IllegalStateException("Tunnel without any CONNECT!");
                     }
@@ -332,8 +337,8 @@
                 if (request == null) {
                     return false;
                 }
-                MockResponse response = dispatch(request);
-                writeResponse(out, response);
+                BaseMockResponse<?> response = dispatch(request);
+                response.writeResponse(out);
                 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
                     in.close();
                     out.close();
@@ -450,7 +455,7 @@
     /**
      * Returns a response to satisfy {@code request}.
      */
-    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
+    private BaseMockResponse<?> dispatch(RecordedRequest request) throws InterruptedException {
         if (responseQueue.isEmpty()) {
             throw new IllegalStateException("Unexpected request: " + request);
         }
@@ -471,44 +476,6 @@
         }
     }
 
-    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
-        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
-        for (String header : response.getHeaders()) {
-            out.write((header + "\r\n").getBytes(ASCII));
-        }
-        out.write(("\r\n").getBytes(ASCII));
-        out.flush();
-
-        final InputStream in = response.getBodyStream();
-        final int bytesPerSecond = response.getBytesPerSecond();
-
-        // Stream data in MTU-sized increments
-        final byte[] buffer = new byte[1452];
-        final long delayMs;
-        if (bytesPerSecond == Integer.MAX_VALUE) {
-            delayMs = 0;
-        } else {
-            delayMs = (1000 * buffer.length) / bytesPerSecond;
-        }
-
-        int read;
-        long sinceDelay = 0;
-        while ((read = in.read(buffer)) != -1) {
-            out.write(buffer, 0, read);
-            out.flush();
-
-            sinceDelay += read;
-            if (sinceDelay >= buffer.length && delayMs > 0) {
-                sinceDelay %= buffer.length;
-                try {
-                    Thread.sleep(delayMs);
-                } catch (InterruptedException e) {
-                    throw new AssertionError();
-                }
-            }
-        }
-    }
-
     /**
      * Transfer bytes from {@code in} to {@code out} until either {@code length}
      * bytes have been transferred or {@code in} is exhausted.