Support SSL handshake failures.
Also add an API to RecordedRequest to help differentiate between
original TLSv1 requests and SSLv3 fallback requests.
Bug: http://b/4462288
Change-Id: I31ad0e51ca4d21365e8f1d1e717f97cc94eb040e
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index ba92458..65f4547 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -17,6 +17,7 @@
package com.google.mockwebserver;
import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
+import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
@@ -33,6 +34,8 @@
import java.net.SocketException;
import java.net.URL;
import java.net.UnknownHostException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
@@ -47,8 +50,11 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
+import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.X509TrustManager;
/**
* A scriptable web server. Callers supply canned responses and the server
@@ -267,6 +273,11 @@
if (tunnelProxy) {
createTunnel();
}
+ MockResponse response = responseQueue.peek();
+ if (response != null && response.getSocketPolicy() == FAIL_HANDSHAKE) {
+ processHandshakeFailure(raw, sequenceNumber++);
+ return;
+ }
socket = sslSocketFactory.createSocket(
raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
((SSLSocket) socket).setUseClientMode(false);
@@ -279,7 +290,7 @@
InputStream in = new BufferedInputStream(socket.getInputStream());
OutputStream out = new BufferedOutputStream(socket.getOutputStream());
- while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {}
+ while (!responseQueue.isEmpty() && processOneRequest(socket, in, out)) {}
if (sequenceNumber == 0) {
logger.warning("MockWebServer connection didn't make a request");
@@ -301,7 +312,7 @@
private void createTunnel() throws IOException, InterruptedException {
while (true) {
MockResponse connect = responseQueue.peek();
- if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
+ if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
throw new IllegalStateException("Tunnel without any CONNECT!");
}
if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
@@ -314,9 +325,9 @@
* Reads a request and writes its response. Returns true if a request
* was processed.
*/
- private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
+ private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
throws IOException, InterruptedException {
- RecordedRequest request = readRequest(in, sequenceNumber);
+ RecordedRequest request = readRequest(socket, in, sequenceNumber);
if (request == null) {
return false;
}
@@ -336,10 +347,40 @@
}));
}
+ private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
+ responseQueue.take();
+ X509TrustManager untrusted = new X509TrustManager() {
+ @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
+ throws CertificateException {
+ throw new CertificateException();
+ }
+ @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
+ throw new AssertionError();
+ }
+ @Override public X509Certificate[] getAcceptedIssuers() {
+ throw new AssertionError();
+ }
+ };
+ SSLContext context = SSLContext.getInstance("TLS");
+ context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
+ SSLSocketFactory sslSocketFactory = context.getSocketFactory();
+ SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
+ raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
+ try {
+ socket.startHandshake(); // we're testing a handshake failure
+ throw new AssertionError();
+ } catch (IOException expected) {
+ }
+ socket.close();
+ requestCount.incrementAndGet();
+ requestQueue.add(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket));
+ }
+
/**
* @param sequenceNumber the index of this request on this connection.
*/
- private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
+ private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
+ throws IOException {
String request;
try {
request = readAsciiUntilCrlf(in);
@@ -401,7 +442,7 @@
}
return new RecordedRequest(request, headers, chunkSizes,
- requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
+ requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
}
/**
diff --git a/src/main/java/com/google/mockwebserver/RecordedRequest.java b/src/main/java/com/google/mockwebserver/RecordedRequest.java
index 8f09084..a06c0bc 100644
--- a/src/main/java/com/google/mockwebserver/RecordedRequest.java
+++ b/src/main/java/com/google/mockwebserver/RecordedRequest.java
@@ -16,7 +16,9 @@
package com.google.mockwebserver;
+import java.net.Socket;
import java.util.List;
+import javax.net.ssl.SSLSocket;
/**
* An HTTP request that came into the mock web server.
@@ -28,15 +30,23 @@
private final int bodySize;
private final byte[] body;
private final int sequenceNumber;
+ private final String sslProtocol;
RecordedRequest(String requestLine, List<String> headers, List<Integer> chunkSizes,
- int bodySize, byte[] body, int sequenceNumber) {
+ int bodySize, byte[] body, int sequenceNumber, Socket socket) {
this.requestLine = requestLine;
this.headers = headers;
this.chunkSizes = chunkSizes;
this.bodySize = bodySize;
this.body = body;
this.sequenceNumber = sequenceNumber;
+
+ if (socket instanceof SSLSocket) {
+ SSLSocket sslSocket = (SSLSocket) socket;
+ sslProtocol = sslSocket.getSession().getProtocol();
+ } else {
+ sslProtocol = null;
+ }
}
public String getRequestLine() {
@@ -79,6 +89,14 @@
return sequenceNumber;
}
+ /**
+ * Returns the connection's SSL protocol like {@code TLSv1}, {@code SSLv3},
+ * {@code NONE} or null if the connection doesn't use SSL.
+ */
+ public String getSslProtocol() {
+ return sslProtocol;
+ }
+
@Override public String toString() {
return requestLine;
}
diff --git a/src/main/java/com/google/mockwebserver/SocketPolicy.java b/src/main/java/com/google/mockwebserver/SocketPolicy.java
index d256a45..3a6797b 100644
--- a/src/main/java/com/google/mockwebserver/SocketPolicy.java
+++ b/src/main/java/com/google/mockwebserver/SocketPolicy.java
@@ -51,6 +51,11 @@
DISCONNECT_AT_START,
/**
+ * Don't trust the client during the SSL handshake.
+ */
+ FAIL_HANDSHAKE,
+
+ /**
* Shutdown the socket input after sending the response. For testing bad
* behavior.
*/