diff --git a/base/containers/linked_list.h b/base/containers/linked_list.h index 25bbe762cb759b..41461ff365e66f 100644 --- a/base/containers/linked_list.h +++ b/base/containers/linked_list.h @@ -5,6 +5,8 @@ #ifndef BASE_CONTAINERS_LINKED_LIST_H_ #define BASE_CONTAINERS_LINKED_LIST_H_ +#include "base/macros.h" + // Simple LinkedList type. (See the Q&A section to understand how this // differs from std::list). // @@ -82,7 +84,7 @@ namespace base { template class LinkNode { public: - LinkNode() : previous_(0), next_(0) {} + LinkNode() : previous_(NULL), next_(NULL) {} LinkNode(LinkNode* previous, LinkNode* next) : previous_(previous), next_(next) {} @@ -106,6 +108,10 @@ class LinkNode { void RemoveFromList() { this->previous_->next_ = this->next_; this->next_->previous_ = this->previous_; + // next() and previous() return non-NULL if and only this node is not in any + // list. + this->next_ = NULL; + this->previous_ = NULL; } LinkNode* previous() const { @@ -128,6 +134,8 @@ class LinkNode { private: LinkNode* previous_; LinkNode* next_; + + DISALLOW_COPY_AND_ASSIGN(LinkNode); }; template @@ -155,8 +163,12 @@ class LinkedList { return &root_; } + bool empty() const { return head() == end(); } + private: LinkNode root_; + + DISALLOW_COPY_AND_ASSIGN(LinkedList); }; } // namespace base diff --git a/base/containers/linked_list_unittest.cc b/base/containers/linked_list_unittest.cc index 801e3028a07459..93a9f385084c0e 100644 --- a/base/containers/linked_list_unittest.cc +++ b/base/containers/linked_list_unittest.cc @@ -257,5 +257,52 @@ TEST(LinkedList, MultipleInheritanceNode) { EXPECT_EQ(&node, node.value()); } +TEST(LinkedList, EmptyListIsEmpty) { + LinkedList list; + EXPECT_TRUE(list.empty()); +} + +TEST(LinkedList, NonEmptyListIsNotEmpty) { + LinkedList list; + + Node n(1); + list.Append(&n); + + EXPECT_FALSE(list.empty()); +} + +TEST(LinkedList, EmptiedListIsEmptyAgain) { + LinkedList list; + + Node n(1); + list.Append(&n); + n.RemoveFromList(); + + EXPECT_TRUE(list.empty()); +} + +TEST(LinkedList, NodesCanBeReused) { + LinkedList list1; + LinkedList list2; + + Node n(1); + list1.Append(&n); + n.RemoveFromList(); + list2.Append(&n); + + EXPECT_EQ(list2.head()->value(), &n); +} + +TEST(LinkedList, RemovedNodeHasNullNextPrevious) { + LinkedList list; + + Node n(1); + list.Append(&n); + n.RemoveFromList(); + + EXPECT_EQ(NULL, n.next()); + EXPECT_EQ(NULL, n.previous()); +} + } // namespace } // namespace base diff --git a/net/http/http_stream_factory_impl_unittest.cc b/net/http/http_stream_factory_impl_unittest.cc index 924e1c3997624d..8a407e5015d76d 100644 --- a/net/http/http_stream_factory_impl_unittest.cc +++ b/net/http/http_stream_factory_impl_unittest.cc @@ -945,9 +945,6 @@ TEST_P(HttpStreamFactoryTest, RequestWebSocketBasicHandshakeStream) { session->GetTransportSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL))); EXPECT_EQ(0, GetSocketPoolGroupCount( session->GetSSLSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL))); - EXPECT_EQ(1, GetSocketPoolGroupCount( - session->GetTransportSocketPool( - HttpNetworkSession::WEBSOCKET_SOCKET_POOL))); EXPECT_EQ(0, GetSocketPoolGroupCount( session->GetSSLSocketPool(HttpNetworkSession::WEBSOCKET_SOCKET_POOL))); EXPECT_TRUE(waiter.used_proxy_info().is_direct()); @@ -996,9 +993,6 @@ TEST_P(HttpStreamFactoryTest, RequestWebSocketBasicHandshakeStreamOverSSL) { session->GetTransportSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL))); EXPECT_EQ(0, GetSocketPoolGroupCount( session->GetSSLSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL))); - EXPECT_EQ(1, GetSocketPoolGroupCount( - session->GetTransportSocketPool( - HttpNetworkSession::WEBSOCKET_SOCKET_POOL))); EXPECT_EQ(1, GetSocketPoolGroupCount( session->GetSSLSocketPool(HttpNetworkSession::WEBSOCKET_SOCKET_POOL))); EXPECT_TRUE(waiter.used_proxy_info().is_direct()); @@ -1161,9 +1155,6 @@ TEST_P(HttpStreamFactoryTest, RequestWebSocketSpdyHandshakeStreamButGetSSL) { session->GetTransportSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL))); EXPECT_EQ(0, GetSocketPoolGroupCount( session->GetSSLSocketPool(HttpNetworkSession::NORMAL_SOCKET_POOL))); - EXPECT_EQ(1, GetSocketPoolGroupCount( - session->GetTransportSocketPool( - HttpNetworkSession::WEBSOCKET_SOCKET_POOL))); EXPECT_EQ(1, GetSocketPoolGroupCount( session->GetSSLSocketPool(HttpNetworkSession::WEBSOCKET_SOCKET_POOL))); EXPECT_TRUE(waiter1.used_proxy_info().is_direct()); diff --git a/net/net.gypi b/net/net.gypi index 704d19b96f3cee..508c94684ef45a 100644 --- a/net/net.gypi +++ b/net/net.gypi @@ -979,6 +979,12 @@ 'socket/transport_client_socket_pool.h', 'socket/unix_domain_socket_posix.cc', 'socket/unix_domain_socket_posix.h', + 'socket/websocket_endpoint_lock_manager.cc', + 'socket/websocket_endpoint_lock_manager.h', + 'socket/websocket_transport_client_socket_pool.cc', + 'socket/websocket_transport_client_socket_pool.h', + 'socket/websocket_transport_connect_sub_job.cc', + 'socket/websocket_transport_connect_sub_job.h', 'socket_stream/socket_stream.cc', 'socket_stream/socket_stream.h', 'socket_stream/socket_stream_job.cc', @@ -1557,9 +1563,13 @@ 'socket/tcp_listen_socket_unittest.h', 'socket/tcp_server_socket_unittest.cc', 'socket/tcp_socket_unittest.cc', + 'socket/transport_client_socket_pool_test_util.cc', + 'socket/transport_client_socket_pool_test_util.h', 'socket/transport_client_socket_pool_unittest.cc', 'socket/transport_client_socket_unittest.cc', 'socket/unix_domain_socket_posix_unittest.cc', + 'socket/websocket_endpoint_lock_manager_unittest.cc', + 'socket/websocket_transport_client_socket_pool_unittest.cc', 'socket_stream/socket_stream_metrics_unittest.cc', 'socket_stream/socket_stream_unittest.cc', 'spdy/buffered_spdy_framer_unittest.cc', diff --git a/net/socket/client_socket_pool_manager_impl.cc b/net/socket/client_socket_pool_manager_impl.cc index 991278d7341c55..da73122e53fc44 100644 --- a/net/socket/client_socket_pool_manager_impl.cc +++ b/net/socket/client_socket_pool_manager_impl.cc @@ -11,6 +11,7 @@ #include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket_pool.h" #include "net/socket/transport_client_socket_pool.h" +#include "net/socket/websocket_transport_client_socket_pool.h" #include "net/ssl/ssl_config_service.h" namespace net { @@ -57,12 +58,21 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( ssl_config_service_(ssl_config_service), pool_type_(pool_type), transport_pool_histograms_("TCP"), - transport_socket_pool_(new TransportClientSocketPool( - max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), - &transport_pool_histograms_, - host_resolver, - socket_factory_, - net_log)), + transport_socket_pool_( + pool_type == HttpNetworkSession::WEBSOCKET_SOCKET_POOL + ? new WebSocketTransportClientSocketPool( + max_sockets_per_pool(pool_type), + max_sockets_per_group(pool_type), + &transport_pool_histograms_, + host_resolver, + socket_factory_, + net_log) + : new TransportClientSocketPool(max_sockets_per_pool(pool_type), + max_sockets_per_group(pool_type), + &transport_pool_histograms_, + host_resolver, + socket_factory_, + net_log)), ssl_pool_histograms_("SSL2"), ssl_socket_pool_(new SSLClientSocketPool( max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc index dc481fa4b6d600..cf247c03270679 100644 --- a/net/socket/transport_client_socket_pool.cc +++ b/net/socket/transport_client_socket_pool.cc @@ -31,7 +31,7 @@ namespace net { // TODO(willchan): Base this off RTT instead of statically setting it. Note we // choose a timeout that is different from the backup connect job timer so they // don't synchronize. -const int TransportConnectJob::kIPv6FallbackTimerInMs = 300; +const int TransportConnectJobHelper::kIPv6FallbackTimerInMs = 300; namespace { @@ -81,6 +81,107 @@ TransportSocketParams::~TransportSocketParams() {} // See comment #12 at http://crbug.com/23364 for specifics. static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. +TransportConnectJobHelper::TransportConnectJobHelper( + const scoped_refptr& params, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + LoadTimingInfo::ConnectTiming* connect_timing) + : params_(params), + client_socket_factory_(client_socket_factory), + resolver_(host_resolver), + next_state_(STATE_NONE), + connect_timing_(connect_timing) {} + +TransportConnectJobHelper::~TransportConnectJobHelper() {} + +int TransportConnectJobHelper::DoResolveHost(RequestPriority priority, + const BoundNetLog& net_log) { + next_state_ = STATE_RESOLVE_HOST_COMPLETE; + connect_timing_->dns_start = base::TimeTicks::Now(); + + return resolver_.Resolve( + params_->destination(), priority, &addresses_, on_io_complete_, net_log); +} + +int TransportConnectJobHelper::DoResolveHostComplete( + int result, + const BoundNetLog& net_log) { + connect_timing_->dns_end = base::TimeTicks::Now(); + // Overwrite connection start time, since for connections that do not go + // through proxies, |connect_start| should not include dns lookup time. + connect_timing_->connect_start = connect_timing_->dns_end; + + if (result == OK) { + // Invoke callback, and abort if it fails. + if (!params_->host_resolution_callback().is_null()) + result = params_->host_resolution_callback().Run(addresses_, net_log); + + if (result == OK) + next_state_ = STATE_TRANSPORT_CONNECT; + } + return result; +} + +base::TimeDelta TransportConnectJobHelper::HistogramDuration( + ConnectionLatencyHistogram race_result) { + DCHECK(!connect_timing_->connect_start.is_null()); + DCHECK(!connect_timing_->dns_start.is_null()); + base::TimeTicks now = base::TimeTicks::Now(); + base::TimeDelta total_duration = now - connect_timing_->dns_start; + UMA_HISTOGRAM_CUSTOM_TIMES("Net.DNS_Resolution_And_TCP_Connection_Latency2", + total_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + base::TimeDelta connect_duration = now - connect_timing_->connect_start; + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + switch (race_result) { + case CONNECTION_LATENCY_IPV4_WINS_RACE: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + case CONNECTION_LATENCY_IPV4_NO_RACE: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + case CONNECTION_LATENCY_IPV6_RACEABLE: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + case CONNECTION_LATENCY_IPV6_SOLO: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + default: + NOTREACHED(); + break; + } + + return connect_duration; +} + TransportConnectJob::TransportConnectJob( const std::string& group_name, RequestPriority priority, @@ -92,11 +193,9 @@ TransportConnectJob::TransportConnectJob( NetLog* net_log) : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), - params_(params), - client_socket_factory_(client_socket_factory), - resolver_(host_resolver), - next_state_(STATE_NONE), + helper_(params, client_socket_factory, host_resolver, &connect_timing_), interval_between_connects_(CONNECT_INTERVAL_GT_20MS) { + helper_.SetOnIOComplete(this); } TransportConnectJob::~TransportConnectJob() { @@ -105,14 +204,14 @@ TransportConnectJob::~TransportConnectJob() { } LoadState TransportConnectJob::GetLoadState() const { - switch (next_state_) { - case STATE_RESOLVE_HOST: - case STATE_RESOLVE_HOST_COMPLETE: + switch (helper_.next_state()) { + case TransportConnectJobHelper::STATE_RESOLVE_HOST: + case TransportConnectJobHelper::STATE_RESOLVE_HOST_COMPLETE: return LOAD_STATE_RESOLVING_HOST; - case STATE_TRANSPORT_CONNECT: - case STATE_TRANSPORT_CONNECT_COMPLETE: + case TransportConnectJobHelper::STATE_TRANSPORT_CONNECT: + case TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE: return LOAD_STATE_CONNECTING; - case STATE_NONE: + case TransportConnectJobHelper::STATE_NONE: return LOAD_STATE_IDLE; } NOTREACHED(); @@ -129,71 +228,12 @@ void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) { } } -void TransportConnectJob::OnIOComplete(int result) { - int rv = DoLoop(result); - if (rv != ERR_IO_PENDING) - NotifyDelegateOfCompletion(rv); // Deletes |this| -} - -int TransportConnectJob::DoLoop(int result) { - DCHECK_NE(next_state_, STATE_NONE); - - int rv = result; - do { - State state = next_state_; - next_state_ = STATE_NONE; - switch (state) { - case STATE_RESOLVE_HOST: - DCHECK_EQ(OK, rv); - rv = DoResolveHost(); - break; - case STATE_RESOLVE_HOST_COMPLETE: - rv = DoResolveHostComplete(rv); - break; - case STATE_TRANSPORT_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoTransportConnect(); - break; - case STATE_TRANSPORT_CONNECT_COMPLETE: - rv = DoTransportConnectComplete(rv); - break; - default: - NOTREACHED(); - rv = ERR_FAILED; - break; - } - } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); - - return rv; -} - int TransportConnectJob::DoResolveHost() { - next_state_ = STATE_RESOLVE_HOST_COMPLETE; - connect_timing_.dns_start = base::TimeTicks::Now(); - - return resolver_.Resolve( - params_->destination(), - priority(), - &addresses_, - base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)), - net_log()); + return helper_.DoResolveHost(priority(), net_log()); } int TransportConnectJob::DoResolveHostComplete(int result) { - connect_timing_.dns_end = base::TimeTicks::Now(); - // Overwrite connection start time, since for connections that do not go - // through proxies, |connect_start| should not include dns lookup time. - connect_timing_.connect_start = connect_timing_.dns_end; - - if (result == OK) { - // Invoke callback, and abort if it fails. - if (!params_->host_resolution_callback().is_null()) - result = params_->host_resolution_callback().Run(addresses_, net_log()); - - if (result == OK) - next_state_ = STATE_TRANSPORT_CONNECT; - } - return result; + return helper_.DoResolveHostComplete(result, net_log()); } int TransportConnectJob::DoTransportConnect() { @@ -216,42 +256,42 @@ int TransportConnectJob::DoTransportConnect() { interval_between_connects_ = CONNECT_INTERVAL_GT_20MS; } - next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; - transport_socket_ = client_socket_factory_->CreateTransportClientSocket( - addresses_, net_log().net_log(), net_log().source()); - int rv = transport_socket_->Connect( - base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); + helper_.set_next_state( + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE); + transport_socket_ = + helper_.client_socket_factory()->CreateTransportClientSocket( + helper_.addresses(), net_log().net_log(), net_log().source()); + int rv = transport_socket_->Connect(helper_.on_io_complete()); if (rv == ERR_IO_PENDING && - addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV6 && - !AddressListOnlyContainsIPv6(addresses_)) { - fallback_timer_.Start(FROM_HERE, - base::TimeDelta::FromMilliseconds(kIPv6FallbackTimerInMs), - this, &TransportConnectJob::DoIPv6FallbackTransportConnect); + helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV6 && + !AddressListOnlyContainsIPv6(helper_.addresses())) { + fallback_timer_.Start( + FROM_HERE, + base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs), + this, + &TransportConnectJob::DoIPv6FallbackTransportConnect); } return rv; } int TransportConnectJob::DoTransportConnectComplete(int result) { if (result == OK) { - bool is_ipv4 = addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV4; - DCHECK(!connect_timing_.connect_start.is_null()); - DCHECK(!connect_timing_.dns_start.is_null()); - base::TimeTicks now = base::TimeTicks::Now(); - base::TimeDelta total_duration = now - connect_timing_.dns_start; - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.DNS_Resolution_And_TCP_Connection_Latency2", - total_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - base::TimeDelta connect_duration = now - connect_timing_.connect_start; - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - + bool is_ipv4 = + helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV4; + TransportConnectJobHelper::ConnectionLatencyHistogram race_result = + TransportConnectJobHelper::CONNECTION_LATENCY_UNKNOWN; + if (is_ipv4) { + race_result = TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE; + } else { + if (AddressListOnlyContainsIPv6(helper_.addresses())) { + race_result = TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO; + } else { + race_result = + TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE; + } + } + base::TimeDelta connect_duration = helper_.HistogramDuration(race_result); switch (interval_between_connects_) { case CONNECT_INTERVAL_LE_10MS: UMA_HISTOGRAM_CUSTOM_TIMES( @@ -282,27 +322,6 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { break; } - if (is_ipv4) { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - } else { - if (AddressListOnlyContainsIPv6(addresses_)) { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - } else { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - } - } SetSocket(transport_socket_.Pass()); fallback_timer_.Stop(); } else { @@ -317,7 +336,8 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { void TransportConnectJob::DoIPv6FallbackTransportConnect() { // The timer should only fire while we're waiting for the main connect to // succeed. - if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) { + if (helper_.next_state() != + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE) { NOTREACHED(); return; } @@ -325,10 +345,10 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { DCHECK(!fallback_transport_socket_.get()); DCHECK(!fallback_addresses_.get()); - fallback_addresses_.reset(new AddressList(addresses_)); + fallback_addresses_.reset(new AddressList(helper_.addresses())); MakeAddressListStartWithIPv4(fallback_addresses_.get()); fallback_transport_socket_ = - client_socket_factory_->CreateTransportClientSocket( + helper_.client_socket_factory()->CreateTransportClientSocket( *fallback_addresses_, net_log().net_log(), net_log().source()); fallback_connect_start_time_ = base::TimeTicks::Now(); int rv = fallback_transport_socket_->Connect( @@ -341,7 +361,8 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { // This should only happen when we're waiting for the main connect to succeed. - if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) { + if (helper_.next_state() != + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE) { NOTREACHED(); return; } @@ -352,30 +373,11 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { if (result == OK) { DCHECK(!fallback_connect_start_time_.is_null()); - DCHECK(!connect_timing_.dns_start.is_null()); - base::TimeTicks now = base::TimeTicks::Now(); - base::TimeDelta total_duration = now - connect_timing_.dns_start; - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.DNS_Resolution_And_TCP_Connection_Latency2", - total_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - base::TimeDelta connect_duration = now - fallback_connect_start_time_; - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); + connect_timing_.connect_start = fallback_connect_start_time_; + helper_.HistogramDuration( + TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE); SetSocket(fallback_transport_socket_.Pass()); - next_state_ = STATE_NONE; + helper_.set_next_state(TransportConnectJobHelper::STATE_NONE); transport_socket_.reset(); } else { // Be a bit paranoid and kill off the fallback members to prevent reuse. @@ -386,8 +388,7 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { } int TransportConnectJob::ConnectInternal() { - next_state_ = STATE_RESOLVE_HOST; - return DoLoop(OK); + return helper_.DoConnectInternal(this); } scoped_ptr @@ -439,6 +440,15 @@ int TransportClientSocketPool::RequestSocket( const scoped_refptr* casted_params = static_cast*>(params); + NetLogTcpClientSocketPoolRequestedSocket(net_log, casted_params); + + return base_.RequestSocket(group_name, *casted_params, priority, handle, + callback, net_log); +} + +void TransportClientSocketPool::NetLogTcpClientSocketPoolRequestedSocket( + const BoundNetLog& net_log, + const scoped_refptr* casted_params) { if (net_log.IsLogging()) { // TODO(eroman): Split out the host and port parameters. net_log.AddEvent( @@ -446,9 +456,6 @@ int TransportClientSocketPool::RequestSocket( CreateNetLogHostPortPairCallback( &casted_params->get()->destination().host_port_pair())); } - - return base_.RequestSocket(group_name, *casted_params, priority, handle, - callback, net_log); } void TransportClientSocketPool::RequestSockets( diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h index 1c22bf29ec3776..003008a372ddfa 100644 --- a/net/socket/transport_client_socket_pool.h +++ b/net/socket/transport_client_socket_pool.h @@ -55,6 +55,75 @@ class NET_EXPORT_PRIVATE TransportSocketParams DISALLOW_COPY_AND_ASSIGN(TransportSocketParams); }; +// Common data and logic shared between TransportConnectJob and +// WebSocketTransportConnectJob. +class NET_EXPORT_PRIVATE TransportConnectJobHelper { + public: + enum State { + STATE_RESOLVE_HOST, + STATE_RESOLVE_HOST_COMPLETE, + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_NONE, + }; + + // For recording the connection time in the appropriate bucket. + enum ConnectionLatencyHistogram { + CONNECTION_LATENCY_UNKNOWN, + CONNECTION_LATENCY_IPV4_WINS_RACE, + CONNECTION_LATENCY_IPV4_NO_RACE, + CONNECTION_LATENCY_IPV6_RACEABLE, + CONNECTION_LATENCY_IPV6_SOLO, + }; + + TransportConnectJobHelper(const scoped_refptr& params, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + LoadTimingInfo::ConnectTiming* connect_timing); + ~TransportConnectJobHelper(); + + ClientSocketFactory* client_socket_factory() { + return client_socket_factory_; + } + + const AddressList& addresses() const { return addresses_; } + State next_state() const { return next_state_; } + void set_next_state(State next_state) { next_state_ = next_state; } + CompletionCallback on_io_complete() const { return on_io_complete_; } + + int DoResolveHost(RequestPriority priority, const BoundNetLog& net_log); + int DoResolveHostComplete(int result, const BoundNetLog& net_log); + + template + int DoConnectInternal(T* job); + + template + void SetOnIOComplete(T* job); + + template + void OnIOComplete(T* job, int result); + + // Record the histograms Net.DNS_Resolution_And_TCP_Connection_Latency2 and + // Net.TCP_Connection_Latency and return the connect duration. + base::TimeDelta HistogramDuration(ConnectionLatencyHistogram race_result); + + static const int kIPv6FallbackTimerInMs; + + private: + template + int DoLoop(T* job, int result); + + scoped_refptr params_; + ClientSocketFactory* const client_socket_factory_; + SingleRequestHostResolver resolver_; + AddressList addresses_; + State next_state_; + CompletionCallback on_io_complete_; + LoadTimingInfo::ConnectTiming* connect_timing_; + + DISALLOW_COPY_AND_ASSIGN(TransportConnectJobHelper); +}; + // TransportConnectJob handles the host resolution necessary for socket creation // and the transport (likely TCP) connect. TransportConnectJob also has fallback // logic for IPv6 connect() timeouts (which may happen due to networks / routers @@ -82,27 +151,14 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { // WARNING: this method should only be used to implement the prefer-IPv4 hack. static void MakeAddressListStartWithIPv4(AddressList* addrlist); - static const int kIPv6FallbackTimerInMs; - private: - enum State { - STATE_RESOLVE_HOST, - STATE_RESOLVE_HOST_COMPLETE, - STATE_TRANSPORT_CONNECT, - STATE_TRANSPORT_CONNECT_COMPLETE, - STATE_NONE, - }; - enum ConnectInterval { CONNECT_INTERVAL_LE_10MS, CONNECT_INTERVAL_LE_20MS, CONNECT_INTERVAL_GT_20MS, }; - void OnIOComplete(int result); - - // Runs the state transition loop. - int DoLoop(int result); + friend class TransportConnectJobHelper; int DoResolveHost(); int DoResolveHostComplete(int result); @@ -118,11 +174,7 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { // Otherwise, it returns a net error code. virtual int ConnectInternal() OVERRIDE; - scoped_refptr params_; - ClientSocketFactory* const client_socket_factory_; - SingleRequestHostResolver resolver_; - AddressList addresses_; - State next_state_; + TransportConnectJobHelper helper_; scoped_ptr transport_socket_; @@ -187,6 +239,12 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + protected: + // Methods shared with WebSocketTransportClientSocketPool + void NetLogTcpClientSocketPoolRequestedSocket( + const BoundNetLog& net_log, + const scoped_refptr* casted_params); + private: typedef ClientSocketPoolBase PoolBase; @@ -224,6 +282,61 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool); }; +template +int TransportConnectJobHelper::DoConnectInternal(T* job) { + next_state_ = STATE_RESOLVE_HOST; + return this->DoLoop(job, OK); +} + +template +void TransportConnectJobHelper::SetOnIOComplete(T* job) { + // These usages of base::Unretained() are safe because IO callbacks are + // guaranteed not to be called after the object is destroyed. + on_io_complete_ = base::Bind(&TransportConnectJobHelper::OnIOComplete, + base::Unretained(this), + base::Unretained(job)); +} + +template +void TransportConnectJobHelper::OnIOComplete(T* job, int result) { + result = this->DoLoop(job, result); + if (result != ERR_IO_PENDING) + job->NotifyDelegateOfCompletion(result); // Deletes |job| and |this| +} + +template +int TransportConnectJobHelper::DoLoop(T* job, int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_RESOLVE_HOST: + DCHECK_EQ(OK, rv); + rv = job->DoResolveHost(); + break; + case STATE_RESOLVE_HOST_COMPLETE: + rv = job->DoResolveHostComplete(rv); + break; + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = job->DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = job->DoTransportConnectComplete(rv); + break; + default: + NOTREACHED(); + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + return rv; +} + } // namespace net #endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/net/socket/transport_client_socket_pool_test_util.cc b/net/socket/transport_client_socket_pool_test_util.cc new file mode 100644 index 00000000000000..a94167c7954495 --- /dev/null +++ b/net/socket/transport_client_socket_pool_test_util.cc @@ -0,0 +1,421 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/transport_client_socket_pool_test_util.h" + +#include + +#include "base/logging.h" +#include "base/memory/weak_ptr.h" +#include "base/run_loop.h" +#include "net/base/ip_endpoint.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_util.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/ssl_client_socket.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +IPAddressNumber ParseIP(const std::string& ip) { + IPAddressNumber number; + CHECK(ParseIPLiteralToNumber(ip, &number)); + return number; +} + +// A StreamSocket which connects synchronously and successfully. +class MockConnectClientSocket : public StreamSocket { + public: + MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log) + : connected_(false), + addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {} + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + connected_ = true; + return OK; + } + virtual void Disconnect() OVERRIDE { connected_ = false; } + virtual bool IsConnected() const OVERRIDE { return connected_; } + virtual bool IsConnectedAndIdle() const OVERRIDE { return connected_; } + + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + *address = addrlist_.front(); + return OK; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!connected_) + return ERR_SOCKET_NOT_CONNECTED; + if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) + SetIPv4Address(address); + else + SetIPv6Address(address); + return OK; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return false; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { return false; } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } + + // Socket implementation. + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } + virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } + + private: + bool connected_; + const AddressList addrlist_; + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket); +}; + +class MockFailingClientSocket : public StreamSocket { + public: + MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) + : addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {} + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return ERR_CONNECTION_FAILED; + } + + virtual void Disconnect() OVERRIDE {} + + virtual bool IsConnected() const OVERRIDE { return false; } + virtual bool IsConnectedAndIdle() const OVERRIDE { return false; } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_UNEXPECTED; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + return ERR_UNEXPECTED; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return false; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { return false; } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } + + // Socket implementation. + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } + virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } + + private: + const AddressList addrlist_; + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); +}; + +class MockTriggerableClientSocket : public StreamSocket { + public: + // |should_connect| indicates whether the socket should successfully complete + // or fail. + MockTriggerableClientSocket(const AddressList& addrlist, + bool should_connect, + net::NetLog* net_log) + : should_connect_(should_connect), + is_connected_(false), + addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + weak_factory_(this) {} + + // Call this method to get a closure which will trigger the connect callback + // when called. The closure can be called even after the socket is deleted; it + // will safely do nothing. + base::Closure GetConnectCallback() { + return base::Bind(&MockTriggerableClientSocket::DoCallback, + weak_factory_.GetWeakPtr()); + } + + static scoped_ptr MakeMockPendingClientSocket( + const AddressList& addrlist, + bool should_connect, + net::NetLog* net_log) { + scoped_ptr socket( + new MockTriggerableClientSocket(addrlist, should_connect, net_log)); + base::MessageLoop::current()->PostTask(FROM_HERE, + socket->GetConnectCallback()); + return socket.PassAs(); + } + + static scoped_ptr MakeMockDelayedClientSocket( + const AddressList& addrlist, + bool should_connect, + const base::TimeDelta& delay, + net::NetLog* net_log) { + scoped_ptr socket( + new MockTriggerableClientSocket(addrlist, should_connect, net_log)); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, socket->GetConnectCallback(), delay); + return socket.PassAs(); + } + + static scoped_ptr MakeMockStalledClientSocket( + const AddressList& addrlist, + net::NetLog* net_log) { + scoped_ptr socket( + new MockTriggerableClientSocket(addrlist, true, net_log)); + return socket.PassAs(); + } + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + DCHECK(callback_.is_null()); + callback_ = callback; + return ERR_IO_PENDING; + } + + virtual void Disconnect() OVERRIDE {} + + virtual bool IsConnected() const OVERRIDE { return is_connected_; } + virtual bool IsConnectedAndIdle() const OVERRIDE { return is_connected_; } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + *address = addrlist_.front(); + return OK; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!is_connected_) + return ERR_SOCKET_NOT_CONNECTED; + if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) + SetIPv4Address(address); + else + SetIPv6Address(address); + return OK; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return false; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { return false; } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } + + // Socket implementation. + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } + virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } + + private: + void DoCallback() { + is_connected_ = should_connect_; + callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED); + } + + bool should_connect_; + bool is_connected_; + const AddressList addrlist_; + BoundNetLog net_log_; + CompletionCallback callback_; + + base::WeakPtrFactory weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket); +}; + +} // namespace + +void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + // Only pass true in as |is_reused|, as in general, HttpStream types should + // have stricter concepts of reuse than socket pools. + EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); + + EXPECT_TRUE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { + EXPECT_FALSE(handle.is_reused()); + + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + EXPECT_FALSE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasTimes(load_timing_info.connect_timing, + CONNECT_TIMING_HAS_DNS_TIMES); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); + + TestLoadTimingInfoConnectedReused(handle); +} + +void SetIPv4Address(IPEndPoint* address) { + *address = IPEndPoint(ParseIP("1.1.1.1"), 80); +} + +void SetIPv6Address(IPEndPoint* address) { + *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80); +} + +MockTransportClientSocketFactory::MockTransportClientSocketFactory( + NetLog* net_log) + : net_log_(net_log), + allocation_count_(0), + client_socket_type_(MOCK_CLIENT_SOCKET), + client_socket_types_(NULL), + client_socket_index_(0), + client_socket_index_max_(0), + delay_(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs)) {} + +MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {} + +scoped_ptr +MockTransportClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) { + NOTREACHED(); + return scoped_ptr(); +} + +scoped_ptr +MockTransportClientSocketFactory::CreateTransportClientSocket( + const AddressList& addresses, + NetLog* /* net_log */, + const NetLog::Source& /* source */) { + allocation_count_++; + + ClientSocketType type = client_socket_type_; + if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) { + type = client_socket_types_[client_socket_index_++]; + } + + switch (type) { + case MOCK_CLIENT_SOCKET: + return scoped_ptr( + new MockConnectClientSocket(addresses, net_log_)); + case MOCK_FAILING_CLIENT_SOCKET: + return scoped_ptr( + new MockFailingClientSocket(addresses, net_log_)); + case MOCK_PENDING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockPendingClientSocket( + addresses, true, net_log_); + case MOCK_PENDING_FAILING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockPendingClientSocket( + addresses, false, net_log_); + case MOCK_DELAYED_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockDelayedClientSocket( + addresses, true, delay_, net_log_); + case MOCK_DELAYED_FAILING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockDelayedClientSocket( + addresses, false, delay_, net_log_); + case MOCK_STALLED_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses, + net_log_); + case MOCK_TRIGGERABLE_CLIENT_SOCKET: { + scoped_ptr rv( + new MockTriggerableClientSocket(addresses, true, net_log_)); + triggerable_sockets_.push(rv->GetConnectCallback()); + // run_loop_quit_closure_ behaves like a condition variable. It will + // wake up WaitForTriggerableSocketCreation() if it is sleeping. We + // don't need to worry about atomicity because this code is + // single-threaded. + if (!run_loop_quit_closure_.is_null()) + run_loop_quit_closure_.Run(); + return rv.PassAs(); + } + default: + NOTREACHED(); + return scoped_ptr( + new MockConnectClientSocket(addresses, net_log_)); + } +} + +scoped_ptr +MockTransportClientSocketFactory::CreateSSLClientSocket( + scoped_ptr transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) { + NOTIMPLEMENTED(); + return scoped_ptr(); +} + +void MockTransportClientSocketFactory::ClearSSLSessionCache() { + NOTIMPLEMENTED(); +} + +void MockTransportClientSocketFactory::set_client_socket_types( + ClientSocketType* type_list, + int num_types) { + DCHECK_GT(num_types, 0); + client_socket_types_ = type_list; + client_socket_index_ = 0; + client_socket_index_max_ = num_types; +} + +base::Closure +MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() { + while (triggerable_sockets_.empty()) { + base::RunLoop run_loop; + run_loop_quit_closure_ = run_loop.QuitClosure(); + run_loop.Run(); + run_loop_quit_closure_.Reset(); + } + base::Closure trigger = triggerable_sockets_.front(); + triggerable_sockets_.pop(); + return trigger; +} + +} // namespace net diff --git a/net/socket/transport_client_socket_pool_test_util.h b/net/socket/transport_client_socket_pool_test_util.h new file mode 100644 index 00000000000000..d48193c688e8d5 --- /dev/null +++ b/net/socket/transport_client_socket_pool_test_util.h @@ -0,0 +1,127 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Test methods and classes common to transport_client_socket_pool_unittest.cc +// and websocket_transport_client_socket_pool_unittest.cc. If you find you need +// to use these for another purpose, consider moving them to socket_test_util.h. + +#ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_ +#define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_ + +#include + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/address_list.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class ClientSocketHandle; +class IPEndPoint; + +// Make sure |handle| sets load times correctly when it has been assigned a +// reused socket. Uses gtest expectations. +void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle); + +// Make sure |handle| sets load times correctly when it has been assigned a +// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner +// of a connection where |is_reused| is false may consider the connection +// reused. Uses gtest expectations. +void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle); + +// Set |address| to 1.1.1.1:80 +void SetIPv4Address(IPEndPoint* address); + +// Set |address| to [1:abcd::3:4:ff]:80 +void SetIPv6Address(IPEndPoint* address); + +// A ClientSocketFactory that produces sockets with the specified connection +// behaviours. +class MockTransportClientSocketFactory : public ClientSocketFactory { + public: + enum ClientSocketType { + // Connects successfully, synchronously. + MOCK_CLIENT_SOCKET, + // Fails to connect, synchronously. + MOCK_FAILING_CLIENT_SOCKET, + // Connects successfully, asynchronously. + MOCK_PENDING_CLIENT_SOCKET, + // Fails to connect, asynchronously. + MOCK_PENDING_FAILING_CLIENT_SOCKET, + // A delayed socket will pause before connecting through the message loop. + MOCK_DELAYED_CLIENT_SOCKET, + // A delayed socket that fails. + MOCK_DELAYED_FAILING_CLIENT_SOCKET, + // A stalled socket that never connects at all. + MOCK_STALLED_CLIENT_SOCKET, + // A socket that can be triggered to connect explicitly, asynchronously. + MOCK_TRIGGERABLE_CLIENT_SOCKET, + }; + + explicit MockTransportClientSocketFactory(NetLog* net_log); + virtual ~MockTransportClientSocketFactory(); + + virtual scoped_ptr CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE; + + virtual scoped_ptr CreateTransportClientSocket( + const AddressList& addresses, + NetLog* /* net_log */, + const NetLog::Source& /* source */) OVERRIDE; + + virtual scoped_ptr CreateSSLClientSocket( + scoped_ptr transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE; + + virtual void ClearSSLSessionCache() OVERRIDE; + + int allocation_count() const { return allocation_count_; } + + // Set the default ClientSocketType. + void set_client_socket_type(ClientSocketType type) { + client_socket_type_ = type; + } + + // Set a list of ClientSocketTypes to be used. + void set_client_socket_types(ClientSocketType* type_list, int num_types); + + void set_delay(base::TimeDelta delay) { delay_ = delay; } + + // If one or more MOCK_TRIGGERABLE_CLIENT_SOCKETs has already been created, + // then returns a Closure that can be called to cause the first + // not-yet-connected one to connect. If no MOCK_TRIGGERABLE_CLIENT_SOCKETs + // have been created yet, wait for one to be created before returning the + // Closure. This method should be called the same number of times as + // MOCK_TRIGGERABLE_CLIENT_SOCKETs are created in the test. + base::Closure WaitForTriggerableSocketCreation(); + + private: + NetLog* net_log_; + int allocation_count_; + ClientSocketType client_socket_type_; + ClientSocketType* client_socket_types_; + int client_socket_index_; + int client_socket_index_max_; + base::TimeDelta delay_; + std::queue triggerable_sockets_; + base::Closure run_loop_quit_closure_; + + DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketFactory); +}; + +} // namespace net + +#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_ diff --git a/net/socket/transport_client_socket_pool_unittest.cc b/net/socket/transport_client_socket_pool_unittest.cc index 425bb8cc421ac4..146fc490f88e63 100644 --- a/net/socket/transport_client_socket_pool_unittest.cc +++ b/net/socket/transport_client_socket_pool_unittest.cc @@ -7,8 +7,6 @@ #include "base/bind.h" #include "base/bind_helpers.h" #include "base/callback.h" -#include "base/compiler_specific.h" -#include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/threading/platform_thread.h" #include "net/base/capturing_net_log.h" @@ -19,12 +17,11 @@ #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" -#include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" -#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/socket/transport_client_socket_pool_test_util.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -37,410 +34,6 @@ const int kMaxSockets = 32; const int kMaxSocketsPerGroup = 6; const net::RequestPriority kDefaultPriority = LOW; -// Make sure |handle| sets load times correctly when it has been assigned a -// reused socket. -void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { - LoadTimingInfo load_timing_info; - // Only pass true in as |is_reused|, as in general, HttpStream types should - // have stricter concepts of reuse than socket pools. - EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); - - EXPECT_TRUE(load_timing_info.socket_reused); - EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); - - ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); - ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); -} - -// Make sure |handle| sets load times correctly when it has been assigned a -// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner -// of a connection where |is_reused| is false may consider the connection -// reused. -void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { - EXPECT_FALSE(handle.is_reused()); - - LoadTimingInfo load_timing_info; - EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); - - EXPECT_FALSE(load_timing_info.socket_reused); - EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); - - ExpectConnectTimingHasTimes(load_timing_info.connect_timing, - CONNECT_TIMING_HAS_DNS_TIMES); - ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); - - TestLoadTimingInfoConnectedReused(handle); -} - -void SetIPv4Address(IPEndPoint* address) { - IPAddressNumber number; - CHECK(ParseIPLiteralToNumber("1.1.1.1", &number)); - *address = IPEndPoint(number, 80); -} - -void SetIPv6Address(IPEndPoint* address) { - IPAddressNumber number; - CHECK(ParseIPLiteralToNumber("1:abcd::3:4:ff", &number)); - *address = IPEndPoint(number, 80); -} - -class MockClientSocket : public StreamSocket { - public: - MockClientSocket(const AddressList& addrlist, net::NetLog* net_log) - : connected_(false), - addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - } - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - connected_ = true; - return OK; - } - virtual void Disconnect() OVERRIDE { - connected_ = false; - } - virtual bool IsConnected() const OVERRIDE { - return connected_; - } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return connected_; - } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - if (!connected_) - return ERR_SOCKET_NOT_CONNECTED; - if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) - SetIPv4Address(address); - else - SetIPv6Address(address); - return OK; - } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } - - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { return false; } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } - - // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } - - private: - bool connected_; - const AddressList addrlist_; - BoundNetLog net_log_; - - DISALLOW_COPY_AND_ASSIGN(MockClientSocket); -}; - -class MockFailingClientSocket : public StreamSocket { - public: - MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) - : addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - } - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - return ERR_CONNECTION_FAILED; - } - - virtual void Disconnect() OVERRIDE {} - - virtual bool IsConnected() const OVERRIDE { - return false; - } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return false; - } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } - - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { return false; } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } - - // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } - - private: - const AddressList addrlist_; - BoundNetLog net_log_; - - DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); -}; - -class MockPendingClientSocket : public StreamSocket { - public: - // |should_connect| indicates whether the socket should successfully complete - // or fail. - // |should_stall| indicates that this socket should never connect. - // |delay_ms| is the delay, in milliseconds, before simulating a connect. - MockPendingClientSocket( - const AddressList& addrlist, - bool should_connect, - bool should_stall, - base::TimeDelta delay, - net::NetLog* net_log) - : should_connect_(should_connect), - should_stall_(should_stall), - delay_(delay), - is_connected_(false), - addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - weak_factory_(this) { - } - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - base::MessageLoop::current()->PostDelayedTask( - FROM_HERE, - base::Bind(&MockPendingClientSocket::DoCallback, - weak_factory_.GetWeakPtr(), callback), - delay_); - return ERR_IO_PENDING; - } - - virtual void Disconnect() OVERRIDE {} - - virtual bool IsConnected() const OVERRIDE { - return is_connected_; - } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return is_connected_; - } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - if (!is_connected_) - return ERR_SOCKET_NOT_CONNECTED; - if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) - SetIPv4Address(address); - else - SetIPv6Address(address); - return OK; - } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } - - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { return false; } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } - - // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } - - private: - void DoCallback(const CompletionCallback& callback) { - if (should_stall_) - return; - - if (should_connect_) { - is_connected_ = true; - callback.Run(OK); - } else { - is_connected_ = false; - callback.Run(ERR_CONNECTION_FAILED); - } - } - - bool should_connect_; - bool should_stall_; - base::TimeDelta delay_; - bool is_connected_; - const AddressList addrlist_; - BoundNetLog net_log_; - - base::WeakPtrFactory weak_factory_; - - DISALLOW_COPY_AND_ASSIGN(MockPendingClientSocket); -}; - -class MockClientSocketFactory : public ClientSocketFactory { - public: - enum ClientSocketType { - MOCK_CLIENT_SOCKET, - MOCK_FAILING_CLIENT_SOCKET, - MOCK_PENDING_CLIENT_SOCKET, - MOCK_PENDING_FAILING_CLIENT_SOCKET, - // A delayed socket will pause before connecting through the message loop. - MOCK_DELAYED_CLIENT_SOCKET, - // A stalled socket that never connects at all. - MOCK_STALLED_CLIENT_SOCKET, - }; - - explicit MockClientSocketFactory(NetLog* net_log) - : net_log_(net_log), allocation_count_(0), - client_socket_type_(MOCK_CLIENT_SOCKET), client_socket_types_(NULL), - client_socket_index_(0), client_socket_index_max_(0), - delay_(base::TimeDelta::FromMilliseconds( - ClientSocketPool::kMaxConnectRetryIntervalMs)) {} - - virtual scoped_ptr CreateDatagramClientSocket( - DatagramSocket::BindType bind_type, - const RandIntCallback& rand_int_cb, - NetLog* net_log, - const NetLog::Source& source) OVERRIDE { - NOTREACHED(); - return scoped_ptr(); - } - - virtual scoped_ptr CreateTransportClientSocket( - const AddressList& addresses, - NetLog* /* net_log */, - const NetLog::Source& /* source */) OVERRIDE { - allocation_count_++; - - ClientSocketType type = client_socket_type_; - if (client_socket_types_ && - client_socket_index_ < client_socket_index_max_) { - type = client_socket_types_[client_socket_index_++]; - } - - switch (type) { - case MOCK_CLIENT_SOCKET: - return scoped_ptr( - new MockClientSocket(addresses, net_log_)); - case MOCK_FAILING_CLIENT_SOCKET: - return scoped_ptr( - new MockFailingClientSocket(addresses, net_log_)); - case MOCK_PENDING_CLIENT_SOCKET: - return scoped_ptr( - new MockPendingClientSocket( - addresses, true, false, base::TimeDelta(), net_log_)); - case MOCK_PENDING_FAILING_CLIENT_SOCKET: - return scoped_ptr( - new MockPendingClientSocket( - addresses, false, false, base::TimeDelta(), net_log_)); - case MOCK_DELAYED_CLIENT_SOCKET: - return scoped_ptr( - new MockPendingClientSocket( - addresses, true, false, delay_, net_log_)); - case MOCK_STALLED_CLIENT_SOCKET: - return scoped_ptr( - new MockPendingClientSocket( - addresses, true, true, base::TimeDelta(), net_log_)); - default: - NOTREACHED(); - return scoped_ptr( - new MockClientSocket(addresses, net_log_)); - } - } - - virtual scoped_ptr CreateSSLClientSocket( - scoped_ptr transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - const SSLClientSocketContext& context) OVERRIDE { - NOTIMPLEMENTED(); - return scoped_ptr(); - } - - virtual void ClearSSLSessionCache() OVERRIDE { - NOTIMPLEMENTED(); - } - - int allocation_count() const { return allocation_count_; } - - // Set the default ClientSocketType. - void set_client_socket_type(ClientSocketType type) { - client_socket_type_ = type; - } - - // Set a list of ClientSocketTypes to be used. - void set_client_socket_types(ClientSocketType* type_list, int num_types) { - DCHECK_GT(num_types, 0); - client_socket_types_ = type_list; - client_socket_index_ = 0; - client_socket_index_max_ = num_types; - } - - void set_delay(base::TimeDelta delay) { delay_ = delay; } - - private: - NetLog* net_log_; - int allocation_count_; - ClientSocketType client_socket_type_; - ClientSocketType* client_socket_types_; - int client_socket_index_; - int client_socket_index_max_; - base::TimeDelta delay_; - - DISALLOW_COPY_AND_ASSIGN(MockClientSocketFactory); -}; - class TransportClientSocketPoolTest : public testing::Test { protected: TransportClientSocketPoolTest() @@ -494,10 +87,11 @@ class TransportClientSocketPoolTest : public testing::Test { scoped_refptr params_; scoped_ptr histograms_; scoped_ptr host_resolver_; - MockClientSocketFactory client_socket_factory_; + MockTransportClientSocketFactory client_socket_factory_; TransportClientSocketPool pool_; ClientSocketPoolTest test_base_; + private: DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPoolTest); }; @@ -617,7 +211,7 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, @@ -761,7 +355,7 @@ TEST_F(TransportClientSocketPoolTest, TwoRequestsCancelOne) { TEST_F(TransportClientSocketPoolTest, ConnectCancelConnect) { client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, @@ -921,7 +515,7 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) { // cancelled. TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); // Queue up all the requests EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); @@ -951,7 +545,7 @@ TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { // Make sure that pending requests get serviced after active requests fail. TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); const int kNumRequests = 2 * kMaxSocketsPerGroup + 1; ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang. @@ -1022,24 +616,24 @@ TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { // Case 1 tests the first socket stalling, and the backup connecting. - MockClientSocketFactory::ClientSocketType case1_types[] = { + MockTransportClientSocketFactory::ClientSocketType case1_types[] = { // The first socket will not connect. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, // The second socket will connect more quickly. - MockClientSocketFactory::MOCK_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET }; // Case 2 tests the first socket being slow, so that we start the // second connect, but the second connect stalls, and we still // complete the first. - MockClientSocketFactory::ClientSocketType case2_types[] = { + MockTransportClientSocketFactory::ClientSocketType case2_types[] = { // The first socket will connect, although delayed. - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, // The second socket will not connect. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET }; - MockClientSocketFactory::ClientSocketType* cases[2] = { + MockTransportClientSocketFactory::ClientSocketType* cases[2] = { case1_types, case2_types }; @@ -1084,7 +678,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { // of the backup socket, but then we cancelled the request after that. TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); enum { CANCEL_BEFORE_WAIT, CANCEL_AFTER_WAIT }; @@ -1126,11 +720,11 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { // of the backup socket and never completes, and then the backup // connection fails. TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // The first socket will not connect. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, // The second socket will fail immediately. - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -1173,11 +767,11 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { // of the backup socket and eventually completes, but the backup socket // fails. TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // The first socket will connect, although delayed. - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, // The second socket will not connect. - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -1228,11 +822,11 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { &client_socket_factory_, NULL); - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // This is the IPv6 socket. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, // This is the IPv4 socket. - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -1271,16 +865,16 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { &client_socket_factory_, NULL); - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // This is the IPv6 socket. - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, // This is the IPv4 socket. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds( - TransportConnectJob::kIPv6FallbackTimerInMs + 50)); + TransportConnectJobHelper::kIPv6FallbackTimerInMs + 50)); // Resolve an AddressList with a IPv6 address first and then a IPv4 address. host_resolver_->rules() @@ -1314,7 +908,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { NULL); client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); // Resolve an AddressList with only IPv6 addresses. host_resolver_->rules() @@ -1348,7 +942,7 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { NULL); client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); // Resolve an AddressList with only IPv4 addresses. host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); diff --git a/net/socket/websocket_endpoint_lock_manager.cc b/net/socket/websocket_endpoint_lock_manager.cc new file mode 100644 index 00000000000000..c8d9a3b0a4602a --- /dev/null +++ b/net/socket/websocket_endpoint_lock_manager.cc @@ -0,0 +1,95 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_endpoint_lock_manager.h" + +#include + +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" + +namespace net { + +WebSocketEndpointLockManager::Waiter::~Waiter() { + if (next()) { + DCHECK(previous()); + RemoveFromList(); + } +} + +WebSocketEndpointLockManager* WebSocketEndpointLockManager::GetInstance() { + return Singleton::get(); +} + +int WebSocketEndpointLockManager::LockEndpoint(const IPEndPoint& endpoint, + Waiter* waiter) { + EndPointWaiterMap::value_type insert_value(endpoint, NULL); + std::pair rv = + endpoint_waiter_map_.insert(insert_value); + if (rv.second) { + DVLOG(3) << "Locking endpoint " << endpoint.ToString(); + rv.first->second = new ConnectJobQueue; + return OK; + } + DVLOG(3) << "Waiting for endpoint " << endpoint.ToString(); + rv.first->second->Append(waiter); + return ERR_IO_PENDING; +} + +void WebSocketEndpointLockManager::RememberSocket(StreamSocket* socket, + const IPEndPoint& endpoint) { + bool inserted = socket_endpoint_map_.insert(SocketEndPointMap::value_type( + socket, endpoint)).second; + DCHECK(inserted); + DCHECK(endpoint_waiter_map_.find(endpoint) != endpoint_waiter_map_.end()); + DVLOG(3) << "Remembered (StreamSocket*)" << socket << " for " + << endpoint.ToString() << " (" << socket_endpoint_map_.size() + << " sockets remembered)"; +} + +void WebSocketEndpointLockManager::UnlockSocket(StreamSocket* socket) { + SocketEndPointMap::iterator socket_it = socket_endpoint_map_.find(socket); + if (socket_it == socket_endpoint_map_.end()) { + DVLOG(3) << "Ignoring request to unlock already-unlocked socket" + "(StreamSocket*)" << socket; + return; + } + const IPEndPoint& endpoint = socket_it->second; + DVLOG(3) << "Unlocking (StreamSocket*)" << socket << " for " + << endpoint.ToString() << " (" << socket_endpoint_map_.size() + << " sockets left)"; + UnlockEndpoint(endpoint); + socket_endpoint_map_.erase(socket_it); +} + +void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint& endpoint) { + EndPointWaiterMap::iterator found_it = endpoint_waiter_map_.find(endpoint); + CHECK(found_it != endpoint_waiter_map_.end()); // Security critical + ConnectJobQueue* queue = found_it->second; + if (queue->empty()) { + DVLOG(3) << "Unlocking endpoint " << endpoint.ToString(); + delete queue; + endpoint_waiter_map_.erase(found_it); + } else { + DVLOG(3) << "Unlocking endpoint " << endpoint.ToString() + << " and activating next waiter"; + Waiter* next_job = queue->head()->value(); + next_job->RemoveFromList(); + next_job->GotEndpointLock(); + } +} + +bool WebSocketEndpointLockManager::IsEmpty() const { + return endpoint_waiter_map_.empty() && socket_endpoint_map_.empty(); +} + +WebSocketEndpointLockManager::WebSocketEndpointLockManager() {} + +WebSocketEndpointLockManager::~WebSocketEndpointLockManager() { + DCHECK(endpoint_waiter_map_.empty()); + DCHECK(socket_endpoint_map_.empty()); +} + +} // namespace net diff --git a/net/socket/websocket_endpoint_lock_manager.h b/net/socket/websocket_endpoint_lock_manager.h new file mode 100644 index 00000000000000..7ab25c6744ef89 --- /dev/null +++ b/net/socket/websocket_endpoint_lock_manager.h @@ -0,0 +1,85 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_ +#define NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_ + +#include + +#include "base/containers/linked_list.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/singleton.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/socket/websocket_transport_client_socket_pool.h" + +namespace net { + +class StreamSocket; + +class NET_EXPORT_PRIVATE WebSocketEndpointLockManager { + public: + class NET_EXPORT_PRIVATE Waiter : public base::LinkNode { + public: + // If the node is in a list, removes it. + virtual ~Waiter(); + + virtual void GotEndpointLock() = 0; + }; + + static WebSocketEndpointLockManager* GetInstance(); + + // Returns OK if lock was acquired immediately, ERR_IO_PENDING if not. If the + // lock was not acquired, then |waiter->GotEndpointLock()| will be called when + // it is. A Waiter automatically removes itself from the list of waiters when + // its destructor is called. + int LockEndpoint(const IPEndPoint& endpoint, Waiter* waiter); + + // Records the IPEndPoint associated with a particular socket. This is + // necessary because TCPClientSocket refuses to return the PeerAddress after + // the connection is disconnected. The association will be forgotten when + // UnlockSocket() is called. The |socket| pointer must not be deleted between + // the call to RememberSocket() and the call to UnlockSocket(). + void RememberSocket(StreamSocket* socket, const IPEndPoint& endpoint); + + // Releases the lock on an endpoint, and, if appropriate, triggers the next + // socket connection. For a successful WebSocket connection, this method will + // be called once when the handshake completes, and again when the connection + // is closed. Calls after the first are safely ignored. + void UnlockSocket(StreamSocket* socket); + + // Releases the lock on |endpoint|. If RememberSocket() has been called for + // this endpoint, then UnlockSocket() must be used instead of this method. + void UnlockEndpoint(const IPEndPoint& endpoint); + + // Checks that |endpoint_waiter_map_| and |socket_endpoint_map_| are + // empty. For tests. + bool IsEmpty() const; + + private: + typedef base::LinkedList ConnectJobQueue; + typedef std::map EndPointWaiterMap; + typedef std::map SocketEndPointMap; + + WebSocketEndpointLockManager(); + ~WebSocketEndpointLockManager(); + + // If an entry is present in the map for a particular endpoint, then that + // endpoint is locked. If the list is non-empty, then one or more Waiters are + // waiting for the lock. + EndPointWaiterMap endpoint_waiter_map_; + + // Store sockets remembered by RememberSocket() and not yet unlocked by + // UnlockSocket(). + SocketEndPointMap socket_endpoint_map_; + + friend struct DefaultSingletonTraits; + + DISALLOW_COPY_AND_ASSIGN(WebSocketEndpointLockManager); +}; + +} // namespace net + +#endif // NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_ diff --git a/net/socket/websocket_endpoint_lock_manager_unittest.cc b/net/socket/websocket_endpoint_lock_manager_unittest.cc new file mode 100644 index 00000000000000..03d4dbc8d62e6b --- /dev/null +++ b/net/socket/websocket_endpoint_lock_manager_unittest.cc @@ -0,0 +1,211 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_endpoint_lock_manager.h" + +#include "net/base/net_errors.h" +#include "net/socket/next_proto.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/stream_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +// A StreamSocket implementation with no functionality at all. +// TODO(ricea): If you need to use this in another file, please move it to +// socket_test_util.h. +class FakeStreamSocket : public StreamSocket { + public: + FakeStreamSocket() {} + + // StreamSocket implementation + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual void Disconnect() OVERRIDE { return; } + + virtual bool IsConnected() const OVERRIDE { return false; } + + virtual bool IsConnectedAndIdle() const OVERRIDE { return false; } + + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_FAILED; + } + + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + return ERR_FAILED; + } + + virtual const BoundNetLog& NetLog() const OVERRIDE { return bound_net_log_; } + + virtual void SetSubresourceSpeculation() OVERRIDE { return; } + virtual void SetOmniboxSpeculation() OVERRIDE { return; } + + virtual bool WasEverUsed() const OVERRIDE { return false; } + + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + + virtual bool WasNpnNegotiated() const OVERRIDE { return false; } + + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } + + // Socket implementation + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return ERR_FAILED; } + + virtual int SetSendBufferSize(int32 size) OVERRIDE { return ERR_FAILED; } + + private: + BoundNetLog bound_net_log_; + + DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket); +}; + +class FakeWaiter : public WebSocketEndpointLockManager::Waiter { + public: + FakeWaiter() : called_(false) {} + + virtual void GotEndpointLock() OVERRIDE { + CHECK(!called_); + called_ = true; + } + + bool called() const { return called_; } + + private: + bool called_; +}; + +class WebSocketEndpointLockManagerTest : public ::testing::Test { + protected: + WebSocketEndpointLockManagerTest() + : instance_(WebSocketEndpointLockManager::GetInstance()) {} + virtual ~WebSocketEndpointLockManagerTest() { + // If this check fails then subsequent tests may fail. + CHECK(instance_->IsEmpty()); + } + + WebSocketEndpointLockManager* instance() const { return instance_; } + + IPEndPoint DummyEndpoint() { + IPAddressNumber ip_address_number; + CHECK(ParseIPLiteralToNumber("127.0.0.1", &ip_address_number)); + return IPEndPoint(ip_address_number, 80); + } + + void UnlockDummyEndpoint(int times) { + for (int i = 0; i < times; ++i) { + instance()->UnlockEndpoint(DummyEndpoint()); + } + } + + WebSocketEndpointLockManager* const instance_; +}; + +TEST_F(WebSocketEndpointLockManagerTest, GetInstanceWorks) { + // All the work is done by the test framework. +} + +TEST_F(WebSocketEndpointLockManagerTest, LockEndpointReturnsOkOnce) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + + UnlockDummyEndpoint(2); +} + +TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledOnOk) { + FakeWaiter waiter; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter)); + EXPECT_FALSE(waiter.called()); + + UnlockDummyEndpoint(1); +} + +TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledImmediately) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + EXPECT_FALSE(waiters[1].called()); + + UnlockDummyEndpoint(2); +} + +TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockCalledWhenUnlocked) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + instance()->UnlockEndpoint(DummyEndpoint()); + EXPECT_TRUE(waiters[1].called()); + + UnlockDummyEndpoint(1); +} + +TEST_F(WebSocketEndpointLockManagerTest, + EndpointUnlockedIfWaiterAlreadyDeleted) { + FakeWaiter first_lock_holder; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &first_lock_holder)); + + { + FakeWaiter short_lived_waiter; + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &short_lived_waiter)); + } + + instance()->UnlockEndpoint(DummyEndpoint()); + + FakeWaiter second_lock_holder; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &second_lock_holder)); + + UnlockDummyEndpoint(1); +} + +TEST_F(WebSocketEndpointLockManagerTest, RememberSocketWorks) { + FakeWaiter waiters[2]; + FakeStreamSocket dummy_socket; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + + instance()->RememberSocket(&dummy_socket, DummyEndpoint()); + instance()->UnlockSocket(&dummy_socket); + EXPECT_TRUE(waiters[1].called()); + + UnlockDummyEndpoint(1); +} + +// Calling UnlockSocket() on the same socket a second time should be harmless. +TEST_F(WebSocketEndpointLockManagerTest, UnlockSocketTwice) { + FakeWaiter waiter; + FakeStreamSocket dummy_socket; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter)); + instance()->RememberSocket(&dummy_socket, DummyEndpoint()); + instance()->UnlockSocket(&dummy_socket); + instance()->UnlockSocket(&dummy_socket); +} + +} // namespace + +} // namespace net diff --git a/net/socket/websocket_transport_client_socket_pool.cc b/net/socket/websocket_transport_client_socket_pool.cc new file mode 100644 index 00000000000000..39a2771cb2b9f7 --- /dev/null +++ b/net/socket/websocket_transport_client_socket_pool.cc @@ -0,0 +1,645 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_transport_client_socket_pool.h" + +#include + +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/numerics/safe_conversions.h" +#include "base/strings/string_util.h" +#include "base/time/time.h" +#include "base/values.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/websocket_endpoint_lock_manager.h" +#include "net/socket/websocket_transport_connect_sub_job.h" + +namespace net { + +namespace { + +using base::TimeDelta; + +// TODO(ricea): For now, we implement a global timeout for compatability with +// TransportConnectJob. Since WebSocketTransportConnectJob controls the address +// selection process more tightly, it could do something smarter here. +const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. + +} // namespace + +WebSocketTransportConnectJob::WebSocketTransportConnectJob( + const std::string& group_name, + RequestPriority priority, + const scoped_refptr& params, + TimeDelta timeout_duration, + const CompletionCallback& callback, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + ClientSocketHandle* handle, + Delegate* delegate, + NetLog* pool_net_log, + const BoundNetLog& request_net_log) + : ConnectJob(group_name, + timeout_duration, + priority, + delegate, + BoundNetLog::Make(pool_net_log, NetLog::SOURCE_CONNECT_JOB)), + helper_(params, client_socket_factory, host_resolver, &connect_timing_), + race_result_(TransportConnectJobHelper::CONNECTION_LATENCY_UNKNOWN), + handle_(handle), + callback_(callback), + request_net_log_(request_net_log), + had_ipv4_(false), + had_ipv6_(false) { + helper_.SetOnIOComplete(this); +} + +WebSocketTransportConnectJob::~WebSocketTransportConnectJob() {} + +LoadState WebSocketTransportConnectJob::GetLoadState() const { + LoadState load_state = LOAD_STATE_RESOLVING_HOST; + if (ipv6_job_) + load_state = ipv6_job_->GetLoadState(); + // This method should return LOAD_STATE_CONNECTING in preference to + // LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET when possible because "waiting for + // available socket" implies that nothing is happening. + if (ipv4_job_ && load_state != LOAD_STATE_CONNECTING) + load_state = ipv4_job_->GetLoadState(); + return load_state; +} + +int WebSocketTransportConnectJob::DoResolveHost() { + return helper_.DoResolveHost(priority(), net_log()); +} + +int WebSocketTransportConnectJob::DoResolveHostComplete(int result) { + return helper_.DoResolveHostComplete(result, net_log()); +} + +int WebSocketTransportConnectJob::DoTransportConnect() { + AddressList ipv4_addresses; + AddressList ipv6_addresses; + int result = ERR_UNEXPECTED; + helper_.set_next_state( + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE); + + for (AddressList::const_iterator it = helper_.addresses().begin(); + it != helper_.addresses().end(); + ++it) { + switch (it->GetFamily()) { + case ADDRESS_FAMILY_IPV4: + ipv4_addresses.push_back(*it); + break; + + case ADDRESS_FAMILY_IPV6: + ipv6_addresses.push_back(*it); + break; + + default: + DVLOG(1) << "Unexpected ADDRESS_FAMILY: " << it->GetFamily(); + break; + } + } + + if (!ipv4_addresses.empty()) { + had_ipv4_ = true; + ipv4_job_.reset(new WebSocketTransportConnectSubJob( + ipv4_addresses, this, SUB_JOB_IPV4)); + } + + if (!ipv6_addresses.empty()) { + had_ipv6_ = true; + ipv6_job_.reset(new WebSocketTransportConnectSubJob( + ipv6_addresses, this, SUB_JOB_IPV6)); + result = ipv6_job_->Start(); + switch (result) { + case OK: + SetSocket(ipv6_job_->PassSocket()); + race_result_ = + had_ipv4_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO; + return result; + + case ERR_IO_PENDING: + if (ipv4_job_) { + // This use of base::Unretained is safe because |fallback_timer_| is + // owned by this object. + fallback_timer_.Start( + FROM_HERE, + TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs), + base::Bind(&WebSocketTransportConnectJob::StartIPv4JobAsync, + base::Unretained(this))); + } + return result; + + default: + ipv6_job_.reset(); + } + } + + DCHECK(!ipv6_job_); + if (ipv4_job_) { + result = ipv4_job_->Start(); + if (result == OK) { + SetSocket(ipv4_job_->PassSocket()); + race_result_ = + had_ipv6_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE; + } + } + + return result; +} + +int WebSocketTransportConnectJob::DoTransportConnectComplete(int result) { + if (result == OK) + helper_.HistogramDuration(race_result_); + return result; +} + +void WebSocketTransportConnectJob::OnSubJobComplete( + int result, + WebSocketTransportConnectSubJob* job) { + if (result == OK) { + switch (job->type()) { + case SUB_JOB_IPV4: + race_result_ = + had_ipv6_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE; + break; + + case SUB_JOB_IPV6: + race_result_ = + had_ipv4_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO; + break; + } + SetSocket(job->PassSocket()); + + // Make sure all connections are cancelled even if this object fails to be + // deleted. + ipv4_job_.reset(); + ipv6_job_.reset(); + } else { + switch (job->type()) { + case SUB_JOB_IPV4: + ipv4_job_.reset(); + break; + + case SUB_JOB_IPV6: + ipv6_job_.reset(); + if (ipv4_job_ && !ipv4_job_->started()) { + fallback_timer_.Stop(); + result = ipv4_job_->Start(); + if (result != ERR_IO_PENDING) { + OnSubJobComplete(result, ipv4_job_.get()); + return; + } + } + break; + } + if (ipv4_job_ || ipv6_job_) + return; + } + helper_.OnIOComplete(this, result); +} + +void WebSocketTransportConnectJob::StartIPv4JobAsync() { + DCHECK(ipv4_job_); + int result = ipv4_job_->Start(); + if (result != ERR_IO_PENDING) + OnSubJobComplete(result, ipv4_job_.get()); +} + +int WebSocketTransportConnectJob::ConnectInternal() { + return helper_.DoConnectInternal(this); +} + +WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + ClientSocketFactory* client_socket_factory, + NetLog* net_log) + : TransportClientSocketPool(max_sockets, + max_sockets_per_group, + histograms, + host_resolver, + client_socket_factory, + net_log), + connect_job_delegate_(this), + histograms_(histograms), + pool_net_log_(net_log), + client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver), + max_sockets_(max_sockets), + handed_out_socket_count_(0), + flushing_(false), + weak_factory_(this) {} + +WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() { + // Clean up any pending connect jobs. + FlushWithError(ERR_ABORTED); + DCHECK(pending_connects_.empty()); + DCHECK_EQ(0, handed_out_socket_count_); + DCHECK(stalled_request_queue_.empty()); + DCHECK(stalled_request_map_.empty()); +} + +// static +void WebSocketTransportClientSocketPool::UnlockEndpoint( + ClientSocketHandle* handle) { + DCHECK(handle->is_initialized()); + WebSocketEndpointLockManager::GetInstance()->UnlockSocket(handle->socket()); +} + +int WebSocketTransportClientSocketPool::RequestSocket( + const std::string& group_name, + const void* params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& request_net_log) { + DCHECK(params); + const scoped_refptr& casted_params = + *static_cast*>(params); + + NetLogTcpClientSocketPoolRequestedSocket(request_net_log, &casted_params); + + CHECK(!callback.is_null()); + CHECK(handle); + + request_net_log.BeginEvent(NetLog::TYPE_SOCKET_POOL); + + if (ReachedMaxSocketsLimit() && !casted_params->ignore_limits()) { + request_net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); + // TODO(ricea): Use emplace_back when C++11 becomes allowed. + StalledRequest request( + casted_params, priority, handle, callback, request_net_log); + stalled_request_queue_.push_back(request); + StalledRequestQueue::iterator iterator = stalled_request_queue_.end(); + --iterator; + DCHECK_EQ(handle, iterator->handle); + // Because StalledRequestQueue is a std::list, its iterators are guaranteed + // to remain valid as long as the elements are not removed. As long as + // stalled_request_queue_ and stalled_request_map_ are updated in sync, it + // is safe to dereference an iterator in stalled_request_map_ to find the + // corresponding list element. + stalled_request_map_.insert( + StalledRequestMap::value_type(handle, iterator)); + return ERR_IO_PENDING; + } + + scoped_ptr connect_job( + new WebSocketTransportConnectJob(group_name, + priority, + casted_params, + ConnectionTimeout(), + callback, + client_socket_factory_, + host_resolver_, + handle, + &connect_job_delegate_, + pool_net_log_, + request_net_log)); + + int rv = connect_job->Connect(); + // Regardless of the outcome of |connect_job|, it will always be bound to + // |handle|, since this pool uses early-binding. So the binding is logged + // here, without waiting for the result. + request_net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + connect_job->net_log().source().ToEventParametersCallback()); + if (rv == OK) { + HandOutSocket(connect_job->PassSocket(), + connect_job->connect_timing(), + handle, + request_net_log); + request_net_log.EndEvent(NetLog::TYPE_SOCKET_POOL); + } else if (rv == ERR_IO_PENDING) { + // TODO(ricea): Implement backup job timer? + AddJob(handle, connect_job.Pass()); + } else { + scoped_ptr error_socket; + connect_job->GetAdditionalErrorState(handle); + error_socket = connect_job->PassSocket(); + if (error_socket) { + HandOutSocket(error_socket.Pass(), + connect_job->connect_timing(), + handle, + request_net_log); + } + } + + if (rv != ERR_IO_PENDING) { + request_net_log.EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv); + } + + return rv; +} + +void WebSocketTransportClientSocketPool::RequestSockets( + const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) { + NOTIMPLEMENTED(); +} + +void WebSocketTransportClientSocketPool::CancelRequest( + const std::string& group_name, + ClientSocketHandle* handle) { + if (DeleteStalledRequest(handle)) + return; + if (!DeleteJob(handle)) + pending_callbacks_.erase(handle); + if (!ReachedMaxSocketsLimit() && !stalled_request_queue_.empty()) + ActivateStalledRequest(); +} + +void WebSocketTransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr socket, + int id) { + WebSocketEndpointLockManager::GetInstance()->UnlockSocket(socket.get()); + CHECK_GT(handed_out_socket_count_, 0); + --handed_out_socket_count_; + if (!ReachedMaxSocketsLimit() && !stalled_request_queue_.empty()) + ActivateStalledRequest(); +} + +void WebSocketTransportClientSocketPool::FlushWithError(int error) { + // Sockets which are in LOAD_STATE_CONNECTING are in danger of unlocking + // sockets waiting for the endpoint lock. If they connected synchronously, + // then OnConnectJobComplete(). The |flushing_| flag tells this object to + // ignore spurious calls to OnConnectJobComplete(). It is safe to ignore those + // calls because this method will delete the jobs and call their callbacks + // anyway. + flushing_ = true; + for (PendingConnectsMap::iterator it = pending_connects_.begin(); + it != pending_connects_.end(); + ++it) { + InvokeUserCallbackLater( + it->second->handle(), it->second->callback(), error); + delete it->second, it->second = NULL; + } + pending_connects_.clear(); + for (StalledRequestQueue::iterator it = stalled_request_queue_.begin(); + it != stalled_request_queue_.end(); + ++it) { + InvokeUserCallbackLater(it->handle, it->callback, error); + } + stalled_request_map_.clear(); + stalled_request_queue_.clear(); + handed_out_socket_count_ = 0; + flushing_ = false; +} + +void WebSocketTransportClientSocketPool::CloseIdleSockets() { + // We have no idle sockets. +} + +int WebSocketTransportClientSocketPool::IdleSocketCount() const { + return 0; +} + +int WebSocketTransportClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return 0; +} + +LoadState WebSocketTransportClientSocketPool::GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const { + if (stalled_request_map_.find(handle) != stalled_request_map_.end()) + return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET; + if (pending_callbacks_.count(handle)) + return LOAD_STATE_CONNECTING; + return LookupConnectJob(handle)->GetLoadState(); +} + +base::DictionaryValue* WebSocketTransportClientSocketPool::GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("name", name); + dict->SetString("type", type); + dict->SetInteger("handed_out_socket_count", handed_out_socket_count_); + dict->SetInteger("connecting_socket_count", pending_connects_.size()); + dict->SetInteger("idle_socket_count", 0); + dict->SetInteger("max_socket_count", max_sockets_); + dict->SetInteger("max_sockets_per_group", max_sockets_); + dict->SetInteger("pool_generation_number", 0); + return dict; +} + +TimeDelta WebSocketTransportClientSocketPool::ConnectionTimeout() const { + return TimeDelta::FromSeconds(kTransportConnectJobTimeoutInSeconds); +} + +ClientSocketPoolHistograms* WebSocketTransportClientSocketPool::histograms() + const { + return histograms_; +} + +bool WebSocketTransportClientSocketPool::IsStalled() const { + return !stalled_request_queue_.empty(); +} + +void WebSocketTransportClientSocketPool::OnConnectJobComplete( + int result, + WebSocketTransportConnectJob* job) { + DCHECK_NE(ERR_IO_PENDING, result); + + scoped_ptr socket = job->PassSocket(); + + // See comment in FlushWithError. + if (flushing_) { + WebSocketEndpointLockManager::GetInstance()->UnlockSocket(socket.get()); + return; + } + + BoundNetLog request_net_log = job->request_net_log(); + CompletionCallback callback = job->callback(); + LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing(); + + ClientSocketHandle* const handle = job->handle(); + bool handed_out_socket = false; + + if (result == OK) { + DCHECK(socket.get()); + handed_out_socket = true; + HandOutSocket(socket.Pass(), connect_timing, handle, request_net_log); + request_net_log.EndEvent(NetLog::TYPE_SOCKET_POOL); + } else { + // If we got a socket, it must contain error information so pass that + // up so that the caller can retrieve it. + job->GetAdditionalErrorState(handle); + if (socket.get()) { + handed_out_socket = true; + HandOutSocket(socket.Pass(), connect_timing, handle, request_net_log); + } + request_net_log.EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result); + } + bool delete_succeeded = DeleteJob(handle); + DCHECK(delete_succeeded); + if (!handed_out_socket && !stalled_request_queue_.empty() && + !ReachedMaxSocketsLimit()) + ActivateStalledRequest(); + InvokeUserCallbackLater(handle, callback, result); +} + +void WebSocketTransportClientSocketPool::InvokeUserCallbackLater( + ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv) { + DCHECK(!pending_callbacks_.count(handle)); + pending_callbacks_.insert(handle); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&WebSocketTransportClientSocketPool::InvokeUserCallback, + weak_factory_.GetWeakPtr(), + handle, + callback, + rv)); +} + +void WebSocketTransportClientSocketPool::InvokeUserCallback( + ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv) { + if (pending_callbacks_.erase(handle)) + callback.Run(rv); +} + +bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const { + return handed_out_socket_count_ >= max_sockets_ || + base::checked_cast(pending_connects_.size()) >= + max_sockets_ - handed_out_socket_count_; +} + +void WebSocketTransportClientSocketPool::HandOutSocket( + scoped_ptr socket, + const LoadTimingInfo::ConnectTiming& connect_timing, + ClientSocketHandle* handle, + const BoundNetLog& net_log) { + DCHECK(socket); + handle->SetSocket(socket.Pass()); + DCHECK_EQ(ClientSocketHandle::UNUSED, handle->reuse_type()); + DCHECK_EQ(0, handle->idle_time().InMicroseconds()); + handle->set_pool_id(0); + handle->set_connect_timing(connect_timing); + + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + handle->socket()->NetLog().source().ToEventParametersCallback()); + + ++handed_out_socket_count_; +} + +void WebSocketTransportClientSocketPool::AddJob( + ClientSocketHandle* handle, + scoped_ptr connect_job) { + bool inserted = + pending_connects_.insert(PendingConnectsMap::value_type( + handle, connect_job.release())).second; + DCHECK(inserted); +} + +bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) { + PendingConnectsMap::iterator it = pending_connects_.find(handle); + if (it == pending_connects_.end()) + return false; + // Deleting a ConnectJob which holds an endpoint lock can lead to a different + // ConnectJob proceeding to connect. If the connect proceeds synchronously + // (usually because of a failure) then it can trigger that job to be + // deleted. |it| remains valid because std::map guarantees that erase() does + // not invalid iterators to other entries. + delete it->second, it->second = NULL; + DCHECK(pending_connects_.find(handle) == it); + pending_connects_.erase(it); + return true; +} + +const WebSocketTransportConnectJob* +WebSocketTransportClientSocketPool::LookupConnectJob( + const ClientSocketHandle* handle) const { + PendingConnectsMap::const_iterator it = pending_connects_.find(handle); + CHECK(it != pending_connects_.end()); + return it->second; +} + +void WebSocketTransportClientSocketPool::ActivateStalledRequest() { + DCHECK(!stalled_request_queue_.empty()); + DCHECK(!ReachedMaxSocketsLimit()); + // Usually we will only be able to activate one stalled request at a time, + // however if all the connects fail synchronously for some reason, we may be + // able to clear the whole queue at once. + while (!stalled_request_queue_.empty() && !ReachedMaxSocketsLimit()) { + StalledRequest request(stalled_request_queue_.front()); + stalled_request_queue_.pop_front(); + stalled_request_map_.erase(request.handle); + int rv = RequestSocket("ignored", + &request.params, + request.priority, + request.handle, + request.callback, + request.net_log); + // ActivateStalledRequest() never returns synchronously, so it is never + // called re-entrantly. + if (rv != ERR_IO_PENDING) + InvokeUserCallbackLater(request.handle, request.callback, rv); + } +} + +bool WebSocketTransportClientSocketPool::DeleteStalledRequest( + ClientSocketHandle* handle) { + StalledRequestMap::iterator it = stalled_request_map_.find(handle); + if (it == stalled_request_map_.end()) + return false; + stalled_request_queue_.erase(it->second); + stalled_request_map_.erase(it); + return true; +} + +WebSocketTransportClientSocketPool::ConnectJobDelegate::ConnectJobDelegate( + WebSocketTransportClientSocketPool* owner) + : owner_(owner) {} + +WebSocketTransportClientSocketPool::ConnectJobDelegate::~ConnectJobDelegate() {} + +void +WebSocketTransportClientSocketPool::ConnectJobDelegate::OnConnectJobComplete( + int result, + ConnectJob* job) { + owner_->OnConnectJobComplete(result, + static_cast(job)); +} + +WebSocketTransportClientSocketPool::StalledRequest::StalledRequest( + const scoped_refptr& params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) + : params(params), + priority(priority), + handle(handle), + callback(callback), + net_log(net_log) {} + +WebSocketTransportClientSocketPool::StalledRequest::~StalledRequest() {} + +} // namespace net diff --git a/net/socket/websocket_transport_client_socket_pool.h b/net/socket/websocket_transport_client_socket_pool.h new file mode 100644 index 00000000000000..d0d2d9df22637a --- /dev/null +++ b/net/socket/websocket_transport_client_socket_pool.h @@ -0,0 +1,245 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ + +#include +#include +#include +#include + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/transport_client_socket_pool.h" + +namespace net { + +class ClientSocketFactory; +class ClientSocketPoolHistograms; +class HostResolver; +class NetLog; +class WebSocketEndpointLockManager; +class WebSocketTransportConnectSubJob; + +// WebSocketTransportConnectJob handles the host resolution necessary for socket +// creation and the TCP connect. WebSocketTransportConnectJob also has fallback +// logic for IPv6 connect() timeouts (which may happen due to networks / routers +// with broken IPv6 support). Those timeouts take 20s, so rather than make the +// user wait 20s for the timeout to fire, we use a fallback timer +// (kIPv6FallbackTimerInMs) and start a connect() to an IPv4 address if the +// timer fires. Then we race the IPv4 connect(s) against the IPv6 connect(s) and +// use the socket that completes successfully first or fails last. +class NET_EXPORT_PRIVATE WebSocketTransportConnectJob : public ConnectJob { + public: + WebSocketTransportConnectJob( + const std::string& group_name, + RequestPriority priority, + const scoped_refptr& params, + base::TimeDelta timeout_duration, + const CompletionCallback& callback, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + ClientSocketHandle* handle, + Delegate* delegate, + NetLog* pool_net_log, + const BoundNetLog& request_net_log); + virtual ~WebSocketTransportConnectJob(); + + // Unlike normal socket pools, the WebSocketTransportClientPool uses + // early-binding of sockets. + ClientSocketHandle* handle() const { return handle_; } + + // Stash the callback from RequestSocket() here for convenience. + const CompletionCallback& callback() const { return callback_; } + + const BoundNetLog& request_net_log() const { return request_net_log_; } + + // ConnectJob methods. + virtual LoadState GetLoadState() const OVERRIDE; + + private: + friend class WebSocketTransportConnectSubJob; + friend class TransportConnectJobHelper; + friend class WebSocketEndpointLockManager; + + // Although it is not strictly necessary, it makes the code simpler if each + // subjob knows what type it is. + enum SubJobType { SUB_JOB_IPV4, SUB_JOB_IPV6 }; + + int DoResolveHost(); + int DoResolveHostComplete(int result); + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + + // Called back from a SubJob when it completes. + void OnSubJobComplete(int result, WebSocketTransportConnectSubJob* job); + + // Called from |fallback_timer_|. + void StartIPv4JobAsync(); + + // Begins the host resolution and the TCP connect. Returns OK on success + // and ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal() OVERRIDE; + + TransportConnectJobHelper helper_; + + // The addresses are divided into IPv4 and IPv6, which are performed partially + // in parallel. If the list of IPv6 addresses is non-empty, then the IPv6 jobs + // go first, followed after |kIPv6FallbackTimerInMs| by the IPv4 + // addresses. First sub-job to establish a connection wins. + scoped_ptr ipv4_job_; + scoped_ptr ipv6_job_; + + base::OneShotTimer fallback_timer_; + TransportConnectJobHelper::ConnectionLatencyHistogram race_result_; + ClientSocketHandle* const handle_; + CompletionCallback callback_; + BoundNetLog request_net_log_; + + bool had_ipv4_; + bool had_ipv6_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportConnectJob); +}; + +class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool + : public TransportClientSocketPool { + public: + WebSocketTransportClientSocketPool(int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + ClientSocketFactory* client_socket_factory, + NetLog* net_log); + + virtual ~WebSocketTransportClientSocketPool(); + + // Allow another connection to be started to the IPEndPoint that this |handle| + // is connected to. Used when the WebSocket handshake completes successfully. + static void UnlockEndpoint(ClientSocketHandle* handle); + + // ClientSocketPool implementation. + virtual int RequestSocket(const std::string& group_name, + const void* resolve_info, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE; + virtual void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) OVERRIDE; + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE; + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr socket, + int id) OVERRIDE; + virtual void FlushWithError(int error) OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; + virtual int IdleSocketCount() const OVERRIDE; + virtual int IdleSocketCountInGroup( + const std::string& group_name) const OVERRIDE; + virtual LoadState GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const OVERRIDE; + virtual base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const OVERRIDE; + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + + // HigherLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + + private: + class ConnectJobDelegate : public ConnectJob::Delegate { + public: + explicit ConnectJobDelegate(WebSocketTransportClientSocketPool* owner); + virtual ~ConnectJobDelegate(); + + virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE; + + private: + WebSocketTransportClientSocketPool* owner_; + + DISALLOW_COPY_AND_ASSIGN(ConnectJobDelegate); + }; + + // Store the arguments from a call to RequestSocket() that has stalled so we + // can replay it when there are available socket slots. + struct StalledRequest { + StalledRequest(const scoped_refptr& params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log); + ~StalledRequest(); + const scoped_refptr params; + const RequestPriority priority; + ClientSocketHandle* const handle; + const CompletionCallback callback; + const BoundNetLog net_log; + }; + friend class ConnectJobDelegate; + typedef std::map + PendingConnectsMap; + // This is a list so that we can remove requests from the middle, and also + // so that iterators are not invalidated unless the corresponding request is + // removed. + typedef std::list StalledRequestQueue; + typedef std::map + StalledRequestMap; + + void OnConnectJobComplete(int result, WebSocketTransportConnectJob* job); + void InvokeUserCallbackLater(ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv); + void InvokeUserCallback(ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv); + bool ReachedMaxSocketsLimit() const; + void HandOutSocket(scoped_ptr socket, + const LoadTimingInfo::ConnectTiming& connect_timing, + ClientSocketHandle* handle, + const BoundNetLog& net_log); + void AddJob(ClientSocketHandle* handle, + scoped_ptr connect_job); + bool DeleteJob(ClientSocketHandle* handle); + const WebSocketTransportConnectJob* LookupConnectJob( + const ClientSocketHandle* handle) const; + void ActivateStalledRequest(); + bool DeleteStalledRequest(ClientSocketHandle* handle); + + ConnectJobDelegate connect_job_delegate_; + std::set pending_callbacks_; + PendingConnectsMap pending_connects_; + StalledRequestQueue stalled_request_queue_; + StalledRequestMap stalled_request_map_; + ClientSocketPoolHistograms* const histograms_; + NetLog* const pool_net_log_; + ClientSocketFactory* const client_socket_factory_; + HostResolver* const host_resolver_; + const int max_sockets_; + int handed_out_socket_count_; + bool flushing_; + + base::WeakPtrFactory weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPool); +}; + +} // namespace net + +#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/net/socket/websocket_transport_client_socket_pool_unittest.cc b/net/socket/websocket_transport_client_socket_pool_unittest.cc new file mode 100644 index 00000000000000..c122502f129087 --- /dev/null +++ b/net/socket/websocket_transport_client_socket_pool_unittest.cc @@ -0,0 +1,1066 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_transport_client_socket_pool.h" + +#include +#include + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback.h" +#include "base/macros.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/strings/stringprintf.h" +#include "base/time/time.h" +#include "net/base/capturing_net_log.h" +#include "net/base/ip_endpoint.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/stream_socket.h" +#include "net/socket/transport_client_socket_pool_test_util.h" +#include "net/socket/websocket_endpoint_lock_manager.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; +const net::RequestPriority kDefaultPriority = LOW; + +// RunLoop doesn't support this natively but it is easy to emulate. +void RunLoopForTimePeriod(base::TimeDelta period) { + base::RunLoop run_loop; + base::Closure quit_closure(run_loop.QuitClosure()); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, quit_closure, period); + run_loop.Run(); +} + +class WebSocketTransportClientSocketPoolTest : public testing::Test { + protected: + WebSocketTransportClientSocketPoolTest() + : params_(new TransportSocketParams(HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback())), + histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), + host_resolver_(new MockHostResolver), + client_socket_factory_(&net_log_), + pool_(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL) {} + + virtual ~WebSocketTransportClientSocketPoolTest() { + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + EXPECT_TRUE(WebSocketEndpointLockManager::GetInstance()->IsEmpty()); + } + + int StartRequest(const std::string& group_name, RequestPriority priority) { + scoped_refptr params( + new TransportSocketParams(HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback())); + return test_base_.StartRequestUsingPool( + &pool_, group_name, priority, params); + } + + int GetOrderOfRequest(size_t index) { + return test_base_.GetOrderOfRequest(index); + } + + bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) { + return test_base_.ReleaseOneConnection(keep_alive); + } + + void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) { + test_base_.ReleaseAllConnections(keep_alive); + } + + TestSocketRequest* request(int i) { return test_base_.request(i); } + + ScopedVector* requests() { return test_base_.requests(); } + size_t completion_count() const { return test_base_.completion_count(); } + + CapturingNetLog net_log_; + scoped_refptr params_; + scoped_ptr histograms_; + scoped_ptr host_resolver_; + MockTransportClientSocketFactory client_socket_factory_; + WebSocketTransportClientSocketPool pool_; + ClientSocketPoolTest test_base_; + + private: + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPoolTest); +}; + +TEST_F(WebSocketTransportClientSocketPoolTest, Basic) { + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init( + "a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoConnectedNotReused(handle); +} + +// Make sure that WebSocketTransportConnectJob passes on its priority to its +// HostResolver request on Init. +TEST_F(WebSocketTransportClientSocketPoolTest, SetResolvePriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) { + RequestPriority priority = static_cast(i); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + priority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(priority, host_resolver_->last_request_priority()); + } +} + +TEST_F(WebSocketTransportClientSocketPoolTest, InitHostResolutionFailure) { + host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name"); + TestCompletionCallback callback; + ClientSocketHandle handle; + HostPortPair host_port_pair("unresolvable.host.name", 80); + scoped_refptr dest(new TransportSocketParams( + host_port_pair, false, false, OnHostResolutionCallback())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + dest, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, InitConnectionFailure) { + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + + // Make the host resolutions complete synchronously this time. + host_resolver_->set_synchronous_mode(true); + EXPECT_EQ(ERR_CONNECTION_FAILED, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequestsFinishFifo) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + // Rest of them wait for the first socket to be released. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(6, client_socket_factory_.allocation_count()); + + // One initial asynchronous request and then 5 pending requests. + EXPECT_EQ(6U, completion_count()); + + // The requests finish in FIFO order. + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(7)); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequests_NoKeepAlive) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + // Rest of them wait for the first socket to be released. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + // The pending requests should finish successfully. + EXPECT_EQ(OK, request(1)->WaitForResult()); + EXPECT_EQ(OK, request(2)->WaitForResult()); + EXPECT_EQ(OK, request(3)->WaitForResult()); + EXPECT_EQ(OK, request(4)->WaitForResult()); + EXPECT_EQ(OK, request(5)->WaitForResult()); + + EXPECT_EQ(static_cast(requests()->size()), + client_socket_factory_.allocation_count()); + + // First asynchronous request, and then last 5 pending requests. + EXPECT_EQ(6U, completion_count()); +} + +// This test will start up a RequestSocket() and then immediately Cancel() it. +// The pending host resolution will eventually complete, and destroy the +// ClientSocketPool which will crash if the group was not cleared properly. +TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestClearGroup) { + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + handle.Reset(); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, TwoRequestsCancelOne) { + ClientSocketHandle handle; + TestCompletionCallback callback; + ClientSocketHandle handle2; + TestCompletionCallback callback2; + + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + &pool_, + BoundNetLog())); + + handle.Reset(); + + EXPECT_EQ(OK, callback2.WaitForResult()); + handle2.Reset(); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, ConnectCancelConnect) { + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + + handle.Reset(); + + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + &pool_, + BoundNetLog())); + + host_resolver_->set_synchronous_mode(true); + // At this point, handle has two ConnectingSockets out for it. Due to the + // setting the mock resolver into synchronous mode, the host resolution for + // both will return in the same loop of the MessageLoop. The client socket + // is a pending socket, so the Connect() will asynchronously complete on the + // next loop of the MessageLoop. That means that the first + // ConnectingSocket will enter OnIOComplete, and then the second one will. + // If the first one is not cancelled, it will advance the load state, and + // then the second one will crash. + + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_FALSE(callback.have_result()); + + handle.Reset(); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequest) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Cancel a request. + const size_t index_to_cancel = 2; + EXPECT_FALSE(request(index_to_cancel)->handle()->is_initialized()); + request(index_to_cancel)->handle()->Reset(); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(5, client_socket_factory_.allocation_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, + GetOrderOfRequest(3)); // Canceled request. + EXPECT_EQ(3, GetOrderOfRequest(4)); + EXPECT_EQ(4, GetOrderOfRequest(5)); + EXPECT_EQ(5, GetOrderOfRequest(6)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(7)); +} + +class RequestSocketCallback : public TestCompletionCallbackBase { + public: + RequestSocketCallback(ClientSocketHandle* handle, + WebSocketTransportClientSocketPool* pool) + : handle_(handle), + pool_(pool), + within_callback_(false), + callback_(base::Bind(&RequestSocketCallback::OnComplete, + base::Unretained(this))) {} + + virtual ~RequestSocketCallback() {} + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + SetResult(result); + ASSERT_EQ(OK, result); + + if (!within_callback_) { + // Don't allow reuse of the socket. Disconnect it and then release it and + // run through the MessageLoop once to get it completely released. + handle_->socket()->Disconnect(); + handle_->Reset(); + { + base::MessageLoop::ScopedNestableTaskAllower allow( + base::MessageLoop::current()); + base::MessageLoop::current()->RunUntilIdle(); + } + within_callback_ = true; + scoped_refptr dest( + new TransportSocketParams(HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback())); + int rv = + handle_->Init("a", dest, LOWEST, callback(), pool_, BoundNetLog()); + EXPECT_EQ(OK, rv); + } + } + + ClientSocketHandle* const handle_; + WebSocketTransportClientSocketPool* const pool_; + bool within_callback_; + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(RequestSocketCallback); +}; + +TEST_F(WebSocketTransportClientSocketPoolTest, RequestTwice) { + ClientSocketHandle handle; + RequestSocketCallback callback(&handle, &pool_); + scoped_refptr dest( + new TransportSocketParams(HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback())); + int rv = handle.Init( + "a", dest, LOWEST, callback.callback(), &pool_, BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + + // The callback is going to request "www.google.com". We want it to complete + // synchronously this time. + host_resolver_->set_synchronous_mode(true); + + EXPECT_EQ(OK, callback.WaitForResult()); + + handle.Reset(); +} + +// Make sure that pending requests get serviced after active requests get +// cancelled. +TEST_F(WebSocketTransportClientSocketPoolTest, + CancelActiveRequestWithPendingRequests) { + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + + // Queue up all the requests + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Now, kMaxSocketsPerGroup requests should be active. Let's cancel them. + ASSERT_LE(kMaxSocketsPerGroup, static_cast(requests()->size())); + for (int i = 0; i < kMaxSocketsPerGroup; i++) + request(i)->handle()->Reset(); + + // Let's wait for the rest to complete now. + for (size_t i = kMaxSocketsPerGroup; i < requests()->size(); ++i) { + EXPECT_EQ(OK, request(i)->WaitForResult()); + request(i)->handle()->Reset(); + } + + EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count()); +} + +// Make sure that pending requests get serviced after active requests fail. +TEST_F(WebSocketTransportClientSocketPoolTest, + FailingActiveRequestWithPendingRequests) { + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); + + const int kNumRequests = 2 * kMaxSocketsPerGroup + 1; + ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang. + + // Queue up all the requests + for (int i = 0; i < kNumRequests; i++) + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + for (int i = 0; i < kNumRequests; i++) + EXPECT_EQ(ERR_CONNECTION_FAILED, request(i)->WaitForResult()); +} + +// The lock on the endpoint is released when a ClientSocketHandle is reset. +TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleReset) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + EXPECT_FALSE(request(1)->handle()->is_initialized()); + request(0)->handle()->Reset(); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(1)->handle()->is_initialized()); +} + +// The lock on the endpoint is released when a ClientSocketHandle is deleted. +TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleDelete) { + TestCompletionCallback callback; + scoped_ptr handle(new ClientSocketHandle); + int rv = handle->Init( + "a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_FALSE(request(0)->handle()->is_initialized()); + handle.reset(); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(0)->handle()->is_initialized()); +} + +// A new connection is performed when the lock on the previous connection is +// explicitly released. +TEST_F(WebSocketTransportClientSocketPoolTest, + ConnectionProceedsOnExplicitRelease) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + EXPECT_FALSE(request(1)->handle()->is_initialized()); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle()); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(1)->handle()->is_initialized()); +} + +// A connection which is cancelled before completion does not block subsequent +// connections. +TEST_F(WebSocketTransportClientSocketPoolTest, + CancelDuringConnectionReleasesLock) { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, + arraysize(case_types)); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + base::RunLoop().RunUntilIdle(); + pool_.CancelRequest("a", request(0)->handle()); + EXPECT_EQ(OK, request(1)->WaitForResult()); +} + +// Test the case of the IPv6 address stalling, and falling back to the IPv4 +// socket which finishes first. +TEST_F(WebSocketTransportClientSocketPoolTest, + IPv6FallbackSocketIPv4FinishesFirst) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, 2); + + // Resolve an AddressList with an IPv6 address first and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +// Test the case of the IPv6 address being slow, thus falling back to trying to +// connect to the IPv4 address, but having the connect to the IPv6 address +// finish first. +TEST_F(WebSocketTransportClientSocketPoolTest, + IPv6FallbackSocketIPv6FinishesFirst) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, 2); + client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs + 50)); + + // Resolve an AddressList with an IPv6 address first and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + IPv6NoIPv4AddressesToFallbackTo) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + + // Resolve an AddressList with only IPv6 addresses. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, IPv4HasNoFallback) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + + // Resolve an AddressList with only IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); +} + +// If all IPv6 addresses fail to connect synchronously, then IPv4 connections +// proceeed immediately. +TEST_F(WebSocketTransportClientSocketPoolTest, IPv6InstantFail) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // First IPv6 socket. + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET, + // Second IPv6 socket. + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, + arraysize(case_types)); + + // Resolve an AddressList with two IPv6 addresses and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2:abcd::3:5:ff,2.2.2.2", std::string()); + host_resolver_->set_synchronous_mode(true); + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(OK, rv); + ASSERT_TRUE(handle.socket()); + + IPEndPoint endpoint; + handle.socket()->GetPeerAddress(&endpoint); + EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort()); +} + +// If all IPv6 addresses fail before the IPv4 fallback timeout, then the IPv4 +// connections proceed immediately. +TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // First IPv6 socket. + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET, + // Second IPv6 socket. + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, + arraysize(case_types)); + + // Resolve an AddressList with two IPv6 addresses and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2:abcd::3:5:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.socket()); + + base::Time start(base::Time::NowFromSystemTime()); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_LT(base::Time::NowFromSystemTime() - start, + base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs)); + ASSERT_TRUE(handle.socket()); + + IPEndPoint endpoint; + handle.socket()->GetPeerAddress(&endpoint); + EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort()); +} + +// If two sockets connect successfully, the one which connected first wins (this +// can only happen if the sockets are different types, since sockets of the same +// type do not race). +TEST_F(WebSocketTransportClientSocketPoolTest, FirstSuccessWins) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_TRIGGERABLE_CLIENT_SOCKET); + + // Resolve an AddressList with an IPv6 addresses and an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + ASSERT_FALSE(handle.socket()); + + base::Closure ipv6_connect_trigger = + client_socket_factory_.WaitForTriggerableSocketCreation(); + base::Closure ipv4_connect_trigger = + client_socket_factory_.WaitForTriggerableSocketCreation(); + + ipv4_connect_trigger.Run(); + ipv6_connect_trigger.Run(); + + EXPECT_EQ(OK, callback.WaitForResult()); + ASSERT_TRUE(handle.socket()); + + IPEndPoint endpoint; + handle.socket()->GetPeerAddress(&endpoint); + EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort()); +} + +// We should not report failure until all connections have failed. +TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_FAILING_CLIENT_SOCKET); + base::TimeDelta delay = base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs / 3); + client_socket_factory_.set_delay(delay); + + // Resolve an AddressList with 4 IPv6 addresses and 2 IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", + "1:abcd::3:4:ff,2:abcd::3:4:ff," + "3:abcd::3:4:ff,4:abcd::3:4:ff," + "1.1.1.1,2.2.2.2", + std::string()); + + // Expected order of events: + // After 100ms: Connect to 1:abcd::3:4:ff times out + // After 200ms: Connect to 2:abcd::3:4:ff times out + // After 300ms: Connect to 3:abcd::3:4:ff times out, IPv4 fallback starts + // After 400ms: Connect to 4:abcd::3:4:ff and 1.1.1.1 time out + // After 500ms: Connect to 2.2.2.2 times out + + TestCompletionCallback callback; + ClientSocketHandle handle; + base::Time start(base::Time::NowFromSystemTime()); + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + + EXPECT_GE(base::Time::NowFromSystemTime() - start, delay * 5); +} + +// Global timeout for all connects applies. This test is disabled by default +// because it takes 4 minutes. Run with --gtest_also_run_disabled_tests if you +// want to run it. +TEST_F(WebSocketTransportClientSocketPoolTest, DISABLED_OverallTimeoutApplies) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + const base::TimeDelta connect_job_timeout = pool.ConnectionTimeout(); + + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_FAILING_CLIENT_SOCKET); + client_socket_factory_.set_delay(base::TimeDelta::FromSeconds(1) + + connect_job_timeout / 6); + + // Resolve an AddressList with 6 IPv6 addresses and 6 IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", + "1:abcd::3:4:ff,2:abcd::3:4:ff," + "3:abcd::3:4:ff,4:abcd::3:4:ff," + "5:abcd::3:4:ff,6:abcd::3:4:ff," + "1.1.1.1,2.2.2.2,3.3.3.3," + "4.4.4.4,5.5.5.5,6.6.6.6", + std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_TIMED_OUT, callback.WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforced) { + host_resolver_->set_synchronous_mode(true); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + } + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforcedWhenPending) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + // Now there are 32 sockets waiting to connect, and one stalled. + for (int i = 0; i < kMaxSockets; ++i) { + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(i)->handle()->is_initialized()); + EXPECT_TRUE(request(i)->handle()->socket()); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + } + // Now there are 32 sockets connected, and one stalled. + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(request(kMaxSockets)->handle()->is_initialized()); + EXPECT_FALSE(request(kMaxSockets)->handle()->socket()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, StalledSocketReleased) { + host_resolver_->set_synchronous_mode(true); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + } + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE); + EXPECT_TRUE(request(kMaxSockets)->handle()->is_initialized()); + EXPECT_TRUE(request(kMaxSockets)->handle()->socket()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, IsStalledTrueWhenStalled) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + EXPECT_EQ(OK, request(0)->WaitForResult()); + EXPECT_TRUE(pool_.IsStalled()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + CancellingPendingSocketUnstallsStalledSocket) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + EXPECT_EQ(OK, request(0)->WaitForResult()); + request(1)->handle()->Reset(); + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(pool_.IsStalled()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + LoadStateOfStalledSocketIsWaitingForAvailableSocket) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, + pool_.GetLoadState("a", request(kMaxSockets)->handle())); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + CancellingStalledSocketUnstallsPool) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + request(kMaxSockets)->handle()->Reset(); + EXPECT_FALSE(pool_.IsStalled()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorFlushesPendingConnections) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + pool_.FlushWithError(ERR_FAILED); + EXPECT_EQ(ERR_FAILED, request(0)->WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorFlushesStalledConnections) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + pool_.FlushWithError(ERR_FAILED); + EXPECT_EQ(ERR_FAILED, request(kMaxSockets)->WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + AfterFlushWithErrorCanMakeNewConnections) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + pool_.FlushWithError(ERR_FAILED); + host_resolver_->set_synchronous_mode(true); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); +} + +// Deleting pending connections can release the lock on the endpoint, which can +// in principle lead to other pending connections succeeding. However, when we +// call FlushWithError(), everything should fail. +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorDoesNotCauseSuccessfulConnections) { + host_resolver_->set_synchronous_mode(true); + MockTransportClientSocketFactory::ClientSocketType first_type[] = { + // First socket + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + }; + client_socket_factory_.set_client_socket_types(first_type, + arraysize(first_type)); + // The rest of the sockets will connect synchronously. + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + // Now we have one socket in STATE_TRANSPORT_CONNECT and the rest in + // STATE_OBTAIN_LOCK. If any of the sockets in STATE_OBTAIN_LOCK is given the + // lock, they will synchronously connect. + pool_.FlushWithError(ERR_FAILED); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(ERR_FAILED, request(i)->WaitForResult()); + } +} + +// This is a regression test for the first attempted fix for +// FlushWithErrorDoesNotCauseSuccessfulConnections. Because a ConnectJob can +// have both IPv4 and IPv6 subjobs, it can be both connecting and waiting for +// the lock at the same time. +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorDoesNotCauseSuccessfulConnectionsMultipleAddressTypes) { + host_resolver_->set_synchronous_mode(true); + // The first |kMaxSockets| sockets to connect will be IPv6. Then we will have + // one IPv4. + std::vector socket_types( + kMaxSockets + 1, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); + client_socket_factory_.set_client_socket_types(&socket_types[0], + socket_types.size()); + // The rest of the sockets will connect synchronously. + client_socket_factory_.set_client_socket_type( + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET); + for (int i = 0; i < kMaxSockets; ++i) { + host_resolver_->rules()->ClearRules(); + // Each connect job has a different IPv6 address but the same IPv4 address. + // So the IPv6 connections happen in parallel but the IPv4 ones are + // serialised. + host_resolver_->rules()->AddIPLiteralRule("*", + base::StringPrintf( + "%x:abcd::3:4:ff," + "1.1.1.1", + i + 1), + std::string()); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + // Now we have |kMaxSockets| IPv6 sockets stalled in connect. No IPv4 sockets + // are started yet. + RunLoopForTimePeriod(base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs)); + // Now we have |kMaxSockets| IPv6 sockets and one IPv4 socket stalled in + // connect, and |kMaxSockets - 1| IPv4 sockets waiting for the endpoint lock. + pool_.FlushWithError(ERR_FAILED); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(ERR_FAILED, request(i)->WaitForResult()); + } +} + +} // namespace + +} // namespace net diff --git a/net/socket/websocket_transport_connect_sub_job.cc b/net/socket/websocket_transport_connect_sub_job.cc new file mode 100644 index 00000000000000..fbe8bbcc82c92c --- /dev/null +++ b/net/socket/websocket_transport_connect_sub_job.cc @@ -0,0 +1,170 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_transport_connect_sub_job.h" + +#include "base/logging.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/websocket_endpoint_lock_manager.h" + +namespace net { + +WebSocketTransportConnectSubJob::WebSocketTransportConnectSubJob( + const AddressList& addresses, + WebSocketTransportConnectJob* parent_job, + SubJobType type) + : parent_job_(parent_job), + addresses_(addresses), + current_address_index_(0), + next_state_(STATE_NONE), + type_(type) {} + +WebSocketTransportConnectSubJob::~WebSocketTransportConnectSubJob() { + // We don't worry about cancelling the TCP connect, since ~StreamSocket will + // take care of it. + if (next()) { + DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_); + // The ~Waiter destructor will remove this object from the waiting list. + } else if (next_state_ == STATE_TRANSPORT_CONNECT_COMPLETE) { + WebSocketEndpointLockManager::GetInstance()->UnlockEndpoint( + CurrentAddress()); + } +} + +// Start connecting. +int WebSocketTransportConnectSubJob::Start() { + DCHECK_EQ(STATE_NONE, next_state_); + next_state_ = STATE_OBTAIN_LOCK; + return DoLoop(OK); +} + +// Called by WebSocketEndpointLockManager when the lock becomes available. +void WebSocketTransportConnectSubJob::GotEndpointLock() { + DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_); + OnIOComplete(OK); +} + +LoadState WebSocketTransportConnectSubJob::GetLoadState() const { + switch (next_state_) { + case STATE_OBTAIN_LOCK: + case STATE_OBTAIN_LOCK_COMPLETE: + // TODO(ricea): Add a WebSocket-specific LOAD_STATE ? + return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET; + case STATE_TRANSPORT_CONNECT: + case STATE_TRANSPORT_CONNECT_COMPLETE: + case STATE_DONE: + return LOAD_STATE_CONNECTING; + case STATE_NONE: + return LOAD_STATE_IDLE; + } + NOTREACHED(); + return LOAD_STATE_IDLE; +} + +ClientSocketFactory* WebSocketTransportConnectSubJob::client_socket_factory() + const { + return parent_job_->helper_.client_socket_factory(); +} + +const BoundNetLog& WebSocketTransportConnectSubJob::net_log() const { + return parent_job_->net_log(); +} + +const IPEndPoint& WebSocketTransportConnectSubJob::CurrentAddress() const { + DCHECK_LT(current_address_index_, addresses_.size()); + return addresses_[current_address_index_]; +} + +void WebSocketTransportConnectSubJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + parent_job_->OnSubJobComplete(rv, this); // |this| deleted +} + +int WebSocketTransportConnectSubJob::DoLoop(int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_OBTAIN_LOCK: + DCHECK_EQ(OK, rv); + rv = DoEndpointLock(); + break; + case STATE_OBTAIN_LOCK_COMPLETE: + DCHECK_EQ(OK, rv); + rv = DoEndpointLockComplete(); + break; + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = DoTransportConnectComplete(rv); + break; + default: + NOTREACHED(); + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE && + next_state_ != STATE_DONE); + + return rv; +} + +int WebSocketTransportConnectSubJob::DoEndpointLock() { + int rv = WebSocketEndpointLockManager::GetInstance()->LockEndpoint( + CurrentAddress(), this); + next_state_ = STATE_OBTAIN_LOCK_COMPLETE; + return rv; +} + +int WebSocketTransportConnectSubJob::DoEndpointLockComplete() { + next_state_ = STATE_TRANSPORT_CONNECT; + return OK; +} + +int WebSocketTransportConnectSubJob::DoTransportConnect() { + // TODO(ricea): Update global g_last_connect_time and report + // ConnectInterval. + next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; + AddressList one_address(CurrentAddress()); + transport_socket_ = client_socket_factory()->CreateTransportClientSocket( + one_address, net_log().net_log(), net_log().source()); + // This use of base::Unretained() is safe because transport_socket_ is + // destroyed in the destructor. + return transport_socket_->Connect(base::Bind( + &WebSocketTransportConnectSubJob::OnIOComplete, base::Unretained(this))); +} + +int WebSocketTransportConnectSubJob::DoTransportConnectComplete(int result) { + next_state_ = STATE_DONE; + WebSocketEndpointLockManager* endpoint_lock_manager = + WebSocketEndpointLockManager::GetInstance(); + if (result != OK) { + endpoint_lock_manager->UnlockEndpoint(CurrentAddress()); + + if (current_address_index_ + 1 < addresses_.size()) { + // Try falling back to the next address in the list. + next_state_ = STATE_OBTAIN_LOCK; + ++current_address_index_; + result = OK; + } + + return result; + } + + endpoint_lock_manager->RememberSocket(transport_socket_.get(), + CurrentAddress()); + + return result; +} + +} // namespace net diff --git a/net/socket/websocket_transport_connect_sub_job.h b/net/socket/websocket_transport_connect_sub_job.h new file mode 100644 index 00000000000000..79980d295f952f --- /dev/null +++ b/net/socket/websocket_transport_connect_sub_job.h @@ -0,0 +1,90 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_ +#define NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_ + +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/load_states.h" +#include "net/socket/websocket_endpoint_lock_manager.h" +#include "net/socket/websocket_transport_client_socket_pool.h" + +namespace net { + +class BoundNetLog; +class ClientSocketFactory; +class IPEndPoint; +class StreamSocket; + +// Attempts to connect to a subset of the addresses required by a +// WebSocketTransportConnectJob, specifically either the IPv4 or IPv6 +// addresses. Each address is tried in turn, and parent_job->OnSubJobComplete() +// is called when the first address succeeds or the last address fails. +class WebSocketTransportConnectSubJob + : public WebSocketEndpointLockManager::Waiter { + public: + typedef WebSocketTransportConnectJob::SubJobType SubJobType; + + WebSocketTransportConnectSubJob(const AddressList& addresses, + WebSocketTransportConnectJob* parent_job, + SubJobType type); + + virtual ~WebSocketTransportConnectSubJob(); + + // Start connecting. + int Start(); + + bool started() { return next_state_ != STATE_NONE; } + + LoadState GetLoadState() const; + + SubJobType type() const { return type_; } + + scoped_ptr PassSocket() { return transport_socket_.Pass(); } + + // Implementation of WebSocketEndpointLockManager::EndpointWaiter. + virtual void GotEndpointLock() OVERRIDE; + + private: + enum State { + STATE_NONE, + STATE_OBTAIN_LOCK, + STATE_OBTAIN_LOCK_COMPLETE, + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_DONE, + }; + + ClientSocketFactory* client_socket_factory() const; + + const BoundNetLog& net_log() const; + + const IPEndPoint& CurrentAddress() const; + + void OnIOComplete(int result); + int DoLoop(int result); + int DoEndpointLock(); + int DoEndpointLockComplete(); + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + + WebSocketTransportConnectJob* const parent_job_; + + const AddressList addresses_; + size_t current_address_index_; + + State next_state_; + const SubJobType type_; + + scoped_ptr transport_socket_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportConnectSubJob); +}; + +} // namespace net + +#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_ diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc index a51fee2a45c773..a19de55f6c07f3 100644 --- a/net/websockets/websocket_basic_handshake_stream.cc +++ b/net/websockets/websocket_basic_handshake_stream.cc @@ -30,6 +30,7 @@ #include "net/http/http_status_code.h" #include "net/http/http_stream_parser.h" #include "net/socket/client_socket_handle.h" +#include "net/socket/websocket_transport_client_socket_pool.h" #include "net/websockets/websocket_basic_stream.h" #include "net/websockets/websocket_deflate_predictor.h" #include "net/websockets/websocket_deflate_predictor_impl.h" @@ -496,6 +497,7 @@ scoped_ptr WebSocketBasicHandshakeStream::Upgrade() { // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make // sure it does not touch it again before it is destroyed. state_.DeleteParser(); + WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection()); scoped_ptr basic_stream( new WebSocketBasicStream(state_.ReleaseConnection(), state_.read_buf(), diff --git a/net/websockets/websocket_test_util.cc b/net/websockets/websocket_test_util.cc index bfa8980344775b..60593f61fa5e65 100644 --- a/net/websockets/websocket_test_util.cc +++ b/net/websockets/websocket_test_util.cc @@ -136,7 +136,7 @@ WebSocketDeterministicMockClientSocketFactoryMaker::AddSSLSocketDataProvider( } WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost() - : url_request_context_(true) { + : url_request_context_(true), url_request_context_initialized_(false) { url_request_context_.set_client_socket_factory(maker_.factory()); } @@ -154,9 +154,12 @@ void WebSocketTestURLRequestContextHost::AddSSLSocketDataProvider( TestURLRequestContext* WebSocketTestURLRequestContextHost::GetURLRequestContext() { - url_request_context_.Init(); - // A Network Delegate is required to make the URLRequest::Delegate work. - url_request_context_.set_network_delegate(&network_delegate_); + if (!url_request_context_initialized_) { + url_request_context_.Init(); + // A Network Delegate is required to make the URLRequest::Delegate work. + url_request_context_.set_network_delegate(&network_delegate_); + url_request_context_initialized_ = true; + } return &url_request_context_; } diff --git a/net/websockets/websocket_test_util.h b/net/websockets/websocket_test_util.h index 2ad86c08fe0268..e95db1a4040b90 100644 --- a/net/websockets/websocket_test_util.h +++ b/net/websockets/websocket_test_util.h @@ -117,14 +117,14 @@ struct WebSocketTestURLRequestContextHost { scoped_ptr ssl_socket_data); // Call after calling one of SetExpections() or AddRawExpectations(). The - // returned pointer remains owned by this object. This should only be called - // once. + // returned pointer remains owned by this object. TestURLRequestContext* GetURLRequestContext(); private: WebSocketDeterministicMockClientSocketFactoryMaker maker_; TestURLRequestContext url_request_context_; TestNetworkDelegate network_delegate_; + bool url_request_context_initialized_; DISALLOW_COPY_AND_ASSIGN(WebSocketTestURLRequestContextHost); };