Add BaseMockResponse with two flavors.

Some tests rely on MockResponses being cloned as they're enqueued,
which doesn't mix well with the recent switch to using InputStream
to represent all bodies.

This change introduces a base class that handles common header
management, and two distinct flavors: MockResponse which returns to
using byte[] bodies, and MockStreamResponse which is InputStream
based.  MockResponse are again cloned when enqueued, making existing
tests happy.

Bug: 8334369
Change-Id: I12b83a06d17ba82d46166c5550397a2198ce68fd
diff --git a/src/main/java/com/google/mockwebserver/BaseMockResponse.java b/src/main/java/com/google/mockwebserver/BaseMockResponse.java
new file mode 100644
index 0000000..a12abf5
--- /dev/null
+++ b/src/main/java/com/google/mockwebserver/BaseMockResponse.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright (C) 2013 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.mockwebserver;
+
+import static com.google.mockwebserver.MockWebServer.ASCII;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Base scripted response to be replayed by {@link MockWebServer}.
+ */
+abstract class BaseMockResponse<T extends BaseMockResponse<T>> {
+    protected static final String CONTENT_LENGTH = "Content-Length";
+
+    private String status = "HTTP/1.1 200 OK";
+    private List<String> headers = new ArrayList<String>();
+    private int bytesPerSecond = Integer.MAX_VALUE;
+    private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+
+    protected BaseMockResponse() {
+    }
+
+    @Override
+    protected Object clone() throws CloneNotSupportedException {
+        final BaseMockResponse<?> result = (BaseMockResponse<?>) super.clone();
+        result.headers = new ArrayList<String>(result.headers);
+        return result;
+    }
+
+    /**
+     * Returns the HTTP response line, such as "HTTP/1.1 200 OK".
+     */
+    public String getStatus() {
+        return status;
+    }
+
+    public T setResponseCode(int code) {
+        this.status = "HTTP/1.1 " + code + " OK";
+        return self();
+    }
+
+    public T setStatus(String status) {
+        this.status = status;
+        return self();
+    }
+
+    /**
+     * Returns the HTTP headers, such as "Content-Length: 0".
+     */
+    public List<String> getHeaders() {
+        return headers;
+    }
+
+    public T clearHeaders() {
+        headers.clear();
+        return self();
+    }
+
+    public T addHeader(String header) {
+        headers.add(header);
+        return self();
+    }
+
+    public T addHeader(String name, Object value) {
+        return addHeader(name + ": " + String.valueOf(value));
+    }
+
+    public T setHeader(String name, Object value) {
+        removeHeader(name);
+        return addHeader(name, value);
+    }
+
+    public T removeHeader(String name) {
+        name += ": ";
+        for (Iterator<String> i = headers.iterator(); i.hasNext();) {
+            String header = i.next();
+            if (name.regionMatches(true, 0, header, 0, name.length())) {
+                i.remove();
+            }
+        }
+        return self();
+    }
+
+    public SocketPolicy getSocketPolicy() {
+        return socketPolicy;
+    }
+
+    public T setSocketPolicy(SocketPolicy socketPolicy) {
+        this.socketPolicy = socketPolicy;
+        return self();
+    }
+
+    public int getBytesPerSecond() {
+        return bytesPerSecond;
+    }
+
+    /**
+     * Set simulated network speed, in bytes per second.
+     */
+    public T setBytesPerSecond(int bytesPerSecond) {
+        this.bytesPerSecond = bytesPerSecond;
+        return self();
+    }
+
+    @Override public String toString() {
+        return status;
+    }
+
+    /**
+     * Write complete response, including all headers and the given body.
+     * Handles applying {@link #setBytesPerSecond(int)} limits.
+     */
+    protected void writeResponse(InputStream body, OutputStream out) throws IOException {
+        out.write((getStatus() + "\r\n").getBytes(ASCII));
+        for (String header : getHeaders()) {
+            out.write((header + "\r\n").getBytes(ASCII));
+        }
+        out.write(("\r\n").getBytes(ASCII));
+        out.flush();
+
+        // Stream data in MTU-sized increments
+        final byte[] buffer = new byte[1452];
+        final long delayMs;
+        if (bytesPerSecond == Integer.MAX_VALUE) {
+            delayMs = 0;
+        } else {
+            delayMs = (1000 * buffer.length) / bytesPerSecond;
+        }
+
+        int read;
+        long sinceDelay = 0;
+        while ((read = body.read(buffer)) != -1) {
+            out.write(buffer, 0, read);
+            out.flush();
+
+            sinceDelay += read;
+            if (sinceDelay >= buffer.length && delayMs > 0) {
+                sinceDelay %= buffer.length;
+                try {
+                    Thread.sleep(delayMs);
+                } catch (InterruptedException e) {
+                    throw new AssertionError();
+                }
+            }
+        }
+    }
+
+    /**
+     * Write complete response. Usually implemented by calling
+     * {@link #writeResponse(InputStream, OutputStream)} with the
+     * implementation-specific body.
+     */
+    public abstract void writeResponse(OutputStream out) throws IOException;
+
+    /**
+     * Return concrete {@code this} to enable builder-style methods.
+     */
+    protected abstract T self();
+}
diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java
index 64eb8f3..8f11996 100644
--- a/src/main/java/com/google/mockwebserver/MockResponse.java
+++ b/src/main/java/com/google/mockwebserver/MockResponse.java
@@ -21,132 +21,35 @@
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
-import java.io.InputStream;
+import java.io.OutputStream;
 import java.io.UnsupportedEncodingException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
 
 /**
- * A scripted response to be replayed by the mock web server.
+ * A scripted response to be replayed by {@link MockWebServer}.
  */
-public final class MockResponse implements Cloneable {
-    private static final String EMPTY_BODY_HEADER = "Content-Length: 0";
+public class MockResponse extends BaseMockResponse<MockResponse> implements Cloneable {
     private static final String CHUNKED_BODY_HEADER = "Transfer-encoding: chunked";
 
-    private String status = "HTTP/1.1 200 OK";
-    private List<String> headers = new ArrayList<String>();
-    private InputStream body;
-    private long bodyLength;
-    private int bytesPerSecond = Integer.MAX_VALUE;
-    private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
+    private byte[] body;
 
     public MockResponse() {
-        headers.add(EMPTY_BODY_HEADER);
+        this.body = new byte[0];
+        addHeader(CONTENT_LENGTH, 0);
     }
 
-    @Override public MockResponse clone() {
+    @Override
+    public MockResponse clone() {
         try {
-            MockResponse result = (MockResponse) super.clone();
-            result.headers = new ArrayList<String>(result.headers);
-            return result;
+            return (MockResponse) super.clone();
         } catch (CloneNotSupportedException e) {
             throw new AssertionError();
         }
     }
 
-    /**
-     * Returns the HTTP response line, such as "HTTP/1.1 200 OK".
-     */
-    public String getStatus() {
-        return status;
-    }
-
-    public MockResponse setResponseCode(int code) {
-        this.status = "HTTP/1.1 " + code + " OK";
-        return this;
-    }
-
-    public MockResponse setStatus(String status) {
-        this.status = status;
-        return this;
-    }
-
-    /**
-     * Returns the HTTP headers, such as "Content-Length: 0".
-     */
-    public List<String> getHeaders() {
-        return headers;
-    }
-
-    public MockResponse clearHeaders() {
-        headers.clear();
-        return this;
-    }
-
-    public MockResponse addHeader(String header) {
-        headers.add(header);
-        return this;
-    }
-
-    public MockResponse addHeader(String name, Object value) {
-        return addHeader(name + ": " + String.valueOf(value));
-    }
-
-    public MockResponse setHeader(String name, Object value) {
-        removeHeader(name);
-        return addHeader(name, value);
-    }
-
-    public MockResponse removeHeader(String name) {
-        name += ": ";
-        for (Iterator<String> i = headers.iterator(); i.hasNext();) {
-            String header = i.next();
-            if (name.regionMatches(true, 0, header, 0, name.length())) {
-                i.remove();
-            }
-        }
-        return this;
-    }
-
-    /**
-     * Returns a {@code byte[]} containing the raw HTTP payload. This is less
-     * efficient than {@link #getBodyStream()}.
-     */
-    public byte[] getBody() {
-        try {
-            return readFullyNoClose(body);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    /**
-     * Returns an input stream containing the raw HTTP payload.
-     */
-    public InputStream getBodyStream() {
-        return body;
-    }
-
-    /**
-     * Returns length of raw HTTP payload.
-     */
-    public long getBodyLength() {
-        return bodyLength;
-    }
-
-    public MockResponse setBody(InputStream body, long bodyLength) {
-        if (this.body == null) {
-            headers.remove(EMPTY_BODY_HEADER);
-        }
-        this.headers.add("Content-Length: " + bodyLength);
-        this.body = body;
-        this.bodyLength = bodyLength;
-        return this;
-    }
-
     public MockResponse setBody(byte[] body) {
-        return setBody(new ByteArrayInputStream(body), body.length);
+        this.body = body;
+        setHeader(CONTENT_LENGTH, body.length);
+        return this;
     }
 
     public MockResponse setBody(String body) {
@@ -157,9 +60,13 @@
         }
     }
 
+    public byte[] getBody() {
+        return body;
+    }
+
     public MockResponse setChunkedBody(byte[] body, int maxChunkSize) throws IOException {
-        headers.remove(EMPTY_BODY_HEADER);
-        headers.add(CHUNKED_BODY_HEADER);
+        removeHeader(CONTENT_LENGTH);
+        addHeader(CHUNKED_BODY_HEADER);
 
         ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
         int pos = 0;
@@ -173,9 +80,7 @@
         }
         bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf
 
-        body = bytesOut.toByteArray();
-        this.body = new ByteArrayInputStream(body);
-        this.bodyLength = body.length;
+        this.body = bytesOut.toByteArray();
         return this;
     }
 
@@ -183,38 +88,13 @@
         return setChunkedBody(body.getBytes(ASCII), maxChunkSize);
     }
 
-    public SocketPolicy getSocketPolicy() {
-        return socketPolicy;
-    }
-
-    public MockResponse setSocketPolicy(SocketPolicy socketPolicy) {
-        this.socketPolicy = socketPolicy;
+    @Override
+    protected MockResponse self() {
         return this;
     }
 
-    public int getBytesPerSecond() {
-        return bytesPerSecond;
-    }
-
-    /**
-     * Set simulated network speed, in bytes per second.
-     */
-    public MockResponse setBytesPerSecond(int bytesPerSecond) {
-        this.bytesPerSecond = bytesPerSecond;
-        return this;
-    }
-
-    @Override public String toString() {
-        return status;
-    }
-
-    private static byte[] readFullyNoClose(InputStream in) throws IOException {
-        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
-        byte[] buffer = new byte[1024];
-        int count;
-        while ((count = in.read(buffer)) != -1) {
-            bytes.write(buffer, 0, count);
-        }
-        return bytes.toByteArray();
+    @Override
+    public void writeResponse(OutputStream out) throws IOException {
+        super.writeResponse(new ByteArrayInputStream(body), out);
     }
 }
diff --git a/src/main/java/com/google/mockwebserver/MockStreamResponse.java b/src/main/java/com/google/mockwebserver/MockStreamResponse.java
new file mode 100644
index 0000000..9db1c3c
--- /dev/null
+++ b/src/main/java/com/google/mockwebserver/MockStreamResponse.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2013 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.mockwebserver;
+
+import java.io.ByteArrayInputStream;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/**
+ * A scripted response to be replayed by {@link MockWebServer}. This specific
+ * variant uses an {@link InputStream} as its data source. Each instance can
+ * only be consumed once.
+ */
+public class MockStreamResponse extends BaseMockResponse<MockStreamResponse> {
+    private InputStream body;
+
+    public MockStreamResponse() {
+        body = new ByteArrayInputStream(new byte[0]);
+        addHeader(CONTENT_LENGTH, 0);
+    }
+
+    public MockStreamResponse setBody(InputStream body, long bodyLength) {
+        // Release any existing body
+        if (this.body != null) {
+            closeQuietly(this.body);
+        }
+
+        this.body = body;
+        setHeader(CONTENT_LENGTH, bodyLength);
+        return this;
+    }
+
+    @Override
+    public void writeResponse(OutputStream out) throws IOException {
+        if (body == null) {
+            throw new IllegalStateException("Stream already consumed");
+        }
+
+        try {
+            super.writeResponse(body, out);
+        } finally {
+            closeQuietly(body);
+        }
+        body = null;
+    }
+
+    @Override
+    protected MockStreamResponse self() {
+        return this;
+    }
+
+    private static void closeQuietly(Closeable closeable) {
+        if (closeable != null) {
+            try {
+                closeable.close();
+            } catch (RuntimeException rethrown) {
+                throw rethrown;
+            } catch (Exception ignored) {
+            }
+        }
+    }
+}
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index 6173c88..b774eff 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -51,6 +51,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.logging.Level;
 import java.util.logging.Logger;
+
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLSocket;
 import javax.net.ssl.SSLSocketFactory;
@@ -68,8 +69,8 @@
     private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
     private final BlockingQueue<RecordedRequest> requestQueue
             = new LinkedBlockingQueue<RecordedRequest>();
-    private final BlockingQueue<MockResponse> responseQueue
-            = new LinkedBlockingDeque<MockResponse>();
+    private final BlockingQueue<BaseMockResponse<?>> responseQueue
+            = new LinkedBlockingDeque<BaseMockResponse<?>>();
     private final Set<Socket> openClientSockets
             = Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
     private boolean singleResponse;
@@ -158,8 +159,12 @@
         return requestCount.get();
     }
 
-    public void enqueue(MockResponse response) {
-        responseQueue.add(response.clone());
+    public void enqueue(BaseMockResponse<?> response) {
+        if (response instanceof MockResponse) {
+            responseQueue.add(((MockResponse) response).clone());
+        } else {
+            responseQueue.add(response);
+        }
     }
 
     /**
@@ -236,7 +241,7 @@
                     } catch (SocketException ignored) {
                         continue;
                     }
-                    MockResponse peek = responseQueue.peek();
+                    BaseMockResponse<?> peek = responseQueue.peek();
                     if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) {
                         responseQueue.take();
                         socket.close();
@@ -274,7 +279,7 @@
                     if (tunnelProxy) {
                         createTunnel();
                     }
-                    MockResponse response = responseQueue.peek();
+                    BaseMockResponse<?> response = responseQueue.peek();
                     if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) {
                         processHandshakeFailure(raw, sequenceNumber++);
                         return;
@@ -312,7 +317,7 @@
              */
             private void createTunnel() throws IOException, InterruptedException {
                 while (true) {
-                    MockResponse connect = responseQueue.peek();
+                    BaseMockResponse<?> connect = responseQueue.peek();
                     if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
                         throw new IllegalStateException("Tunnel without any CONNECT!");
                     }
@@ -332,8 +337,8 @@
                 if (request == null) {
                     return false;
                 }
-                MockResponse response = dispatch(request);
-                writeResponse(out, response);
+                BaseMockResponse<?> response = dispatch(request);
+                response.writeResponse(out);
                 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
                     in.close();
                     out.close();
@@ -450,7 +455,7 @@
     /**
      * Returns a response to satisfy {@code request}.
      */
-    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
+    private BaseMockResponse<?> dispatch(RecordedRequest request) throws InterruptedException {
         if (responseQueue.isEmpty()) {
             throw new IllegalStateException("Unexpected request: " + request);
         }
@@ -471,44 +476,6 @@
         }
     }
 
-    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
-        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
-        for (String header : response.getHeaders()) {
-            out.write((header + "\r\n").getBytes(ASCII));
-        }
-        out.write(("\r\n").getBytes(ASCII));
-        out.flush();
-
-        final InputStream in = response.getBodyStream();
-        final int bytesPerSecond = response.getBytesPerSecond();
-
-        // Stream data in MTU-sized increments
-        final byte[] buffer = new byte[1452];
-        final long delayMs;
-        if (bytesPerSecond == Integer.MAX_VALUE) {
-            delayMs = 0;
-        } else {
-            delayMs = (1000 * buffer.length) / bytesPerSecond;
-        }
-
-        int read;
-        long sinceDelay = 0;
-        while ((read = in.read(buffer)) != -1) {
-            out.write(buffer, 0, read);
-            out.flush();
-
-            sinceDelay += read;
-            if (sinceDelay >= buffer.length && delayMs > 0) {
-                sinceDelay %= buffer.length;
-                try {
-                    Thread.sleep(delayMs);
-                } catch (InterruptedException e) {
-                    throw new AssertionError();
-                }
-            }
-        }
-    }
-
     /**
      * Transfer bytes from {@code in} to {@code out} until either {@code length}
      * bytes have been transferred or {@code in} is exhausted.