Skip to content

Commit

Permalink
Support recovery from SSL errors.
Browse files Browse the repository at this point in the history
Previously, the new WebSocket implementation was unable to handle sites
with self-signed certificates and other cases where the user had
overridden certificate errors. Add code to support this case.

This requires adding infrastructure to pass the SSL error back up to the
content layer which knows how to handle it. It also requires that the ID
of the frame be known, so an extra parameter has been added to the
WebSocketHostMsg_AddChannelRequest IPC message.

BUG=364361

Review URL: https://codereview.chromium.org/304093003

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@275066 0039d316-1c4b-4281-b951-d872f2087c98
  • Loading branch information
ricea@chromium.org committed Jun 5, 2014
1 parent e1c913c commit a624495
Show file tree
Hide file tree
Showing 16 changed files with 430 additions and 47 deletions.
2 changes: 2 additions & 0 deletions content/browser/renderer_host/websocket_dispatcher_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class CONTENT_EXPORT WebSocketDispatcherHost : public BrowserMessageFilter {
// Returns whether the associated renderer process can read raw cookies.
bool CanReadRawCookies() const;

int render_process_id() const { return process_id_; }

private:
typedef base::hash_map<int, WebSocketHost*> WebSocketHostTable;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
namespace content {
namespace {

// This number is unlikely to occur by chance.
static const int kMagicRenderProcessId = 506116062;

// A mock of WebsocketHost which records received messages.
class MockWebSocketHost : public WebSocketHost {
public:
Expand All @@ -43,7 +46,7 @@ class WebSocketDispatcherHostTest : public ::testing::Test {
public:
WebSocketDispatcherHostTest() {
dispatcher_host_ = new WebSocketDispatcherHost(
0,
kMagicRenderProcessId,
base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext,
base::Unretained(this)),
base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost,
Expand Down Expand Up @@ -81,14 +84,19 @@ TEST_F(WebSocketDispatcherHostTest, UnrelatedMessage) {
EXPECT_FALSE(dispatcher_host_->OnMessageReceived(message));
}

TEST_F(WebSocketDispatcherHostTest, RenderProcessIdGetter) {
EXPECT_EQ(kMagicRenderProcessId, dispatcher_host_->render_process_id());
}

TEST_F(WebSocketDispatcherHostTest, AddChannelRequest) {
int routing_id = 123;
GURL socket_url("ws://example.com/test");
std::vector<std::string> requested_protocols;
requested_protocols.push_back("hello");
url::Origin origin("http://example.com/test");
int render_frame_id = -2;
WebSocketHostMsg_AddChannelRequest message(
routing_id, socket_url, requested_protocols, origin);
routing_id, socket_url, requested_protocols, origin, render_frame_id);

ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message));

Expand Down Expand Up @@ -120,8 +128,9 @@ TEST_F(WebSocketDispatcherHostTest, SendFrame) {
std::vector<std::string> requested_protocols;
requested_protocols.push_back("hello");
url::Origin origin("http://example.com/test");
int render_frame_id = -2;
WebSocketHostMsg_AddChannelRequest add_channel_message(
routing_id, socket_url, requested_protocols, origin);
routing_id, socket_url, requested_protocols, origin, render_frame_id);

ASSERT_TRUE(dispatcher_host_->OnMessageReceived(add_channel_message));

Expand Down
119 changes: 107 additions & 12 deletions content/browser/renderer_host/websocket_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
#include "content/browser/renderer_host/websocket_host.h"

#include "base/basictypes.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_util.h"
#include "content/browser/renderer_host/websocket_dispatcher_host.h"
#include "content/browser/ssl/ssl_error_handler.h"
#include "content/browser/ssl/ssl_manager.h"
#include "content/common/websocket_messages.h"
#include "ipc/ipc_message_macros.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/ssl/ssl_info.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
Expand Down Expand Up @@ -80,7 +84,9 @@ ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
// renderer or child process via WebSocketDispatcherHost.
class WebSocketEventHandler : public net::WebSocketEventInterface {
public:
WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, int routing_id);
WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
int routing_id,
int render_frame_id);
virtual ~WebSocketEventHandler();

// net::WebSocketEventInterface implementation
Expand All @@ -102,18 +108,50 @@ class WebSocketEventHandler : public net::WebSocketEventInterface {
scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
virtual ChannelState OnFinishOpeningHandshake(
scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
virtual ChannelState OnSSLCertificateError(
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
const GURL& url,
const net::SSLInfo& ssl_info,
bool fatal) OVERRIDE;

private:
class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
public:
SSLErrorHandlerDelegate(
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
virtual ~SSLErrorHandlerDelegate();

base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();

// SSLErrorHandler::Delegate methods
virtual void CancelSSLRequest(const GlobalRequestID& id,
int error,
const net::SSLInfo* ssl_info) OVERRIDE;
virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE;

private:
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;

DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
};

WebSocketDispatcherHost* const dispatcher_;
const int routing_id_;
const int render_frame_id_;
scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;

DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
};

WebSocketEventHandler::WebSocketEventHandler(
WebSocketDispatcherHost* dispatcher,
int routing_id)
: dispatcher_(dispatcher), routing_id_(routing_id) {}
int routing_id,
int render_frame_id)
: dispatcher_(dispatcher),
routing_id_(routing_id),
render_frame_id_(render_frame_id) {
}

WebSocketEventHandler::~WebSocketEventHandler() {
DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_;
Expand Down Expand Up @@ -227,18 +265,67 @@ ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
response_to_pass));
}

ChannelState WebSocketEventHandler::OnSSLCertificateError(
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
const GURL& url,
const net::SSLInfo& ssl_info,
bool fatal) {
DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
<< " routing_id=" << routing_id_ << " url=" << url.spec()
<< " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;
ssl_error_handler_delegate_.reset(
new SSLErrorHandlerDelegate(callbacks.Pass()));
// We don't need request_id to be unique so just make a fake one.
GlobalRequestID request_id(-1, -1);
SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(),
request_id,
ResourceType::SUB_RESOURCE,
url,
dispatcher_->render_process_id(),
render_frame_id_,
ssl_info,
fatal);
// The above method is always asynchronous.
return WebSocketEventInterface::CHANNEL_ALIVE;
}

WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
: callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {}

WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}

base::WeakPtr<SSLErrorHandler::Delegate>
WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}

void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
const GlobalRequestID& id,
int error,
const net::SSLInfo* ssl_info) {
DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
<< " error=" << error
<< " cert_status=" << (ssl_info ? ssl_info->cert_status
: static_cast<net::CertStatus>(-1));
callbacks_->CancelSSLRequest(error, ssl_info);
}

void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest(
const GlobalRequestID& id) {
DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
callbacks_->ContinueSSLRequest();
}

} // namespace

WebSocketHost::WebSocketHost(int routing_id,
WebSocketDispatcherHost* dispatcher,
net::URLRequestContext* url_request_context)
: routing_id_(routing_id) {
: dispatcher_(dispatcher),
url_request_context_(url_request_context),
routing_id_(routing_id) {
DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id;

scoped_ptr<net::WebSocketEventInterface> event_interface(
new WebSocketEventHandler(dispatcher, routing_id));
channel_.reset(
new net::WebSocketChannel(event_interface.Pass(), url_request_context));
}

WebSocketHost::~WebSocketHost() {}
Expand All @@ -258,15 +345,20 @@ bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
void WebSocketHost::OnAddChannelRequest(
const GURL& socket_url,
const std::vector<std::string>& requested_protocols,
const url::Origin& origin) {
const url::Origin& origin,
int render_frame_id) {
DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
<< " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
<< "\" requested_protocols=\""
<< JoinString(requested_protocols, ", ") << "\" origin=\""
<< origin.string() << "\"";

channel_->SendAddChannelRequest(
socket_url, requested_protocols, origin);
DCHECK(!channel_);
scoped_ptr<net::WebSocketEventInterface> event_interface(
new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
channel_.reset(
new net::WebSocketChannel(event_interface.Pass(), url_request_context_));
channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
}

void WebSocketHost::OnSendFrame(bool fin,
Expand All @@ -276,13 +368,15 @@ void WebSocketHost::OnSendFrame(bool fin,
<< " routing_id=" << routing_id_ << " fin=" << fin
<< " type=" << type << " data is " << data.size() << " bytes";

DCHECK(channel_);
channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
}

void WebSocketHost::OnFlowControl(int64 quota) {
DVLOG(3) << "WebSocketHost::OnFlowControl"
<< " routing_id=" << routing_id_ << " quota=" << quota;

DCHECK(channel_);
channel_->SendFlowControl(quota);
}

Expand All @@ -293,6 +387,7 @@ void WebSocketHost::OnDropChannel(bool was_clean,
<< " routing_id=" << routing_id_ << " was_clean=" << was_clean
<< " code=" << code << " reason=\"" << reason << "\"";

DCHECK(channel_);
// TODO(yhirano): Handle |was_clean| appropriately.
channel_->StartClosingHandshake(code, reason);
}
Expand Down
11 changes: 9 additions & 2 deletions content/browser/renderer_host/websocket_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class CONTENT_EXPORT WebSocketHost {

void OnAddChannelRequest(const GURL& socket_url,
const std::vector<std::string>& requested_protocols,
const url::Origin& origin);
const url::Origin& origin,
int render_frame_id);

void OnSendFrame(bool fin,
WebSocketMessageType type,
Expand All @@ -63,8 +64,14 @@ class CONTENT_EXPORT WebSocketHost {
// The channel we use to send events to the network.
scoped_ptr<net::WebSocketChannel> channel_;

// The WebSocketHostDispatcher that created this object.
WebSocketDispatcherHost* const dispatcher_;

// The URL request context for the channel.
net::URLRequestContext* const url_request_context_;

// The ID used to route messages.
int routing_id_;
const int routing_id_;

DISALLOW_COPY_AND_ASSIGN(WebSocketHost);
};
Expand Down
7 changes: 2 additions & 5 deletions content/child/websocket_bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,8 @@ void WebSocketBridge::connect(
<< JoinString(protocols_to_pass, ", ") << "), "
<< origin_to_pass.string() << ")";

ChildThread::current()->Send(
new WebSocketHostMsg_AddChannelRequest(channel_id_,
url,
protocols_to_pass,
origin_to_pass));
ChildThread::current()->Send(new WebSocketHostMsg_AddChannelRequest(
channel_id_, url, protocols_to_pass, origin_to_pass, render_frame_id_));
}

void WebSocketBridge::send(bool fin,
Expand Down
5 changes: 3 additions & 2 deletions content/common/websocket_messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ IPC_STRUCT_TRAITS_END()
// The browser process will not send |channel_id| as-is to the remote server; it
// will try to use a short id on the wire. This saves the renderer from
// having to try to choose the ids cleverly.
IPC_MESSAGE_ROUTED3(WebSocketHostMsg_AddChannelRequest,
IPC_MESSAGE_ROUTED4(WebSocketHostMsg_AddChannelRequest,
GURL /* socket_url */,
std::vector<std::string> /* requested_protocols */,
url::Origin /* origin */)
url::Origin /* origin */,
int /* render_frame_id */)

// WebSocket messages sent from the browser to the renderer.

Expand Down
17 changes: 17 additions & 0 deletions net/websockets/websocket_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ class WebSocketChannel::ConnectDelegate
creator_->OnFinishOpeningHandshake(response.Pass());
}

virtual void OnSSLCertificateError(
scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
ssl_error_callbacks,
const SSLInfo& ssl_info,
bool fatal) OVERRIDE {
creator_->OnSSLCertificateError(
ssl_error_callbacks.Pass(), ssl_info, fatal);
}

private:
// A pointer to the WebSocketChannel that created this object. There is no
// danger of this pointer being stale, because deleting the WebSocketChannel
Expand Down Expand Up @@ -576,6 +585,14 @@ void WebSocketChannel::OnConnectFailure(const std::string& message) {
// |this| has been deleted.
}

void WebSocketChannel::OnSSLCertificateError(
scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,
const SSLInfo& ssl_info,
bool fatal) {
AllowUnused(event_interface_->OnSSLCertificateError(
ssl_error_callbacks.Pass(), socket_url_, ssl_info, fatal));
}

void WebSocketChannel::OnStartOpeningHandshake(
scoped_ptr<WebSocketHandshakeRequestInfo> request) {
DCHECK(!notification_sender_->handshake_request_info());
Expand Down
9 changes: 9 additions & 0 deletions net/websockets/websocket_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ class NET_EXPORT WebSocketChannel {
// failure to the event interface. May delete |this|.
void OnConnectFailure(const std::string& message);

// SSL certificate error callback from
// WebSocketStream::CreateAndConnectStream(). Forwards the request to the
// event interface.
void OnSSLCertificateError(
scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
ssl_error_callbacks,
const SSLInfo& ssl_info,
bool fatal);

// Posts a task that sends pending notifications relating WebSocket Opening
// Handshake to the renderer.
void ScheduleOpeningHandshakeNotification();
Expand Down
Loading

0 comments on commit a624495

Please sign in to comment.