Support SSL handshake failures.

Also add an API to RecordedRequest to help differentiate between
original TLSv1 requests and SSLv3 fallback requests.

Bug: http://b/4462288
Change-Id: I31ad0e51ca4d21365e8f1d1e717f97cc94eb040e
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index ba92458..65f4547 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -17,6 +17,7 @@
 package com.google.mockwebserver;
 
 import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
+import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
 import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
 import java.io.ByteArrayOutputStream;
@@ -33,6 +34,8 @@
 import java.net.SocketException;
 import java.net.URL;
 import java.net.UnknownHostException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Iterator;
@@ -47,8 +50,11 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.logging.Level;
 import java.util.logging.Logger;
+import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLSocket;
 import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.X509TrustManager;
 
 /**
  * A scriptable web server. Callers supply canned responses and the server
@@ -267,6 +273,11 @@
                     if (tunnelProxy) {
                         createTunnel();
                     }
+                    MockResponse response = responseQueue.peek();
+                    if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) {
+                        processHandshakeFailure(raw, sequenceNumber++);
+                        return;
+                    }
                     socket = sslSocketFactory.createSocket(
                             raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
                     ((SSLSocket) socket).setUseClientMode(false);
@@ -279,7 +290,7 @@
                 InputStream in = new BufferedInputStream(socket.getInputStream());
                 OutputStream out = new BufferedOutputStream(socket.getOutputStream());
 
-                while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {}
+                while (!responseQueue.isEmpty() && processOneRequest(socket, in, out)) {}
 
                 if (sequenceNumber == 0) {
                     logger.warning("MockWebServer connection didn't make a request");
@@ -301,7 +312,7 @@
             private void createTunnel() throws IOException, InterruptedException {
                 while (true) {
                     MockResponse connect = responseQueue.peek();
-                    if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
+                    if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
                         throw new IllegalStateException("Tunnel without any CONNECT!");
                     }
                     if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
@@ -314,9 +325,9 @@
              * Reads a request and writes its response. Returns true if a request
              * was processed.
              */
-            private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
+            private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
                     throws IOException, InterruptedException {
-                RecordedRequest request = readRequest(in, sequenceNumber);
+                RecordedRequest request = readRequest(socket, in, sequenceNumber);
                 if (request == null) {
                     return false;
                 }
@@ -336,10 +347,40 @@
         }));
     }
 
+    private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
+        responseQueue.take();
+        X509TrustManager untrusted = new X509TrustManager() {
+            @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
+                    throws CertificateException {
+                throw new CertificateException();
+            }
+            @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
+                throw new AssertionError();
+            }
+            @Override public X509Certificate[] getAcceptedIssuers() {
+                throw new AssertionError();
+            }
+        };
+        SSLContext context = SSLContext.getInstance("TLS");
+        context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
+        SSLSocketFactory sslSocketFactory = context.getSocketFactory();
+        SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
+                raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
+        try {
+            socket.startHandshake(); // we're testing a handshake failure
+            throw new AssertionError();
+        } catch (IOException expected) {
+        }
+        socket.close();
+        requestCount.incrementAndGet();
+        requestQueue.add(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket));
+    }
+
     /**
      * @param sequenceNumber the index of this request on this connection.
      */
-    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
+    private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
+            throws IOException {
         String request;
         try {
             request = readAsciiUntilCrlf(in);
@@ -401,7 +442,7 @@
         }
 
         return new RecordedRequest(request, headers, chunkSizes,
-                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
+                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
     }
 
     /**
diff --git a/src/main/java/com/google/mockwebserver/RecordedRequest.java b/src/main/java/com/google/mockwebserver/RecordedRequest.java
index 8f09084..a06c0bc 100644
--- a/src/main/java/com/google/mockwebserver/RecordedRequest.java
+++ b/src/main/java/com/google/mockwebserver/RecordedRequest.java
@@ -16,7 +16,9 @@
 
 package com.google.mockwebserver;
 
+import java.net.Socket;
 import java.util.List;
+import javax.net.ssl.SSLSocket;
 
 /**
  * An HTTP request that came into the mock web server.
@@ -28,15 +30,23 @@
     private final int bodySize;
     private final byte[] body;
     private final int sequenceNumber;
+    private final String sslProtocol;
 
     RecordedRequest(String requestLine, List<String> headers, List<Integer> chunkSizes,
-            int bodySize, byte[] body, int sequenceNumber) {
+            int bodySize, byte[] body, int sequenceNumber, Socket socket) {
         this.requestLine = requestLine;
         this.headers = headers;
         this.chunkSizes = chunkSizes;
         this.bodySize = bodySize;
         this.body = body;
         this.sequenceNumber = sequenceNumber;
+
+        if (socket instanceof SSLSocket) {
+            SSLSocket sslSocket = (SSLSocket) socket;
+            sslProtocol = sslSocket.getSession().getProtocol();
+        } else {
+            sslProtocol = null;
+        }
     }
 
     public String getRequestLine() {
@@ -79,6 +89,14 @@
         return sequenceNumber;
     }
 
+    /**
+     * Returns the connection's SSL protocol like {@code TLSv1}, {@code SSLv3},
+     * {@code NONE} or null if the connection doesn't use SSL.
+     */
+    public String getSslProtocol() {
+        return sslProtocol;
+    }
+
     @Override public String toString() {
         return requestLine;
     }
diff --git a/src/main/java/com/google/mockwebserver/SocketPolicy.java b/src/main/java/com/google/mockwebserver/SocketPolicy.java
index d256a45..3a6797b 100644
--- a/src/main/java/com/google/mockwebserver/SocketPolicy.java
+++ b/src/main/java/com/google/mockwebserver/SocketPolicy.java
@@ -51,6 +51,11 @@
     DISCONNECT_AT_START,
 
     /**
+     * Don't trust the client during the SSL handshake.
+     */
+    FAIL_HANDSHAKE,
+
+    /**
      * Shutdown the socket input after sending the response. For testing bad
      * behavior.
      */