Skip to content

Commit

Permalink
SocketParams refactor 3: Add ClientSocketPool::CreateConnectJob.
Browse files Browse the repository at this point in the history
This new method takes in all the objects needed to create the socket params
for a ConnectJob, and then returns a newly created one by invoking the
ClientSocketPool::SocketParams callbacks (Which does not actually need
all those parameters, as it has a number of them baked in). These parameters
include the GroupID, the ProxyServer, and whether or not the socket is for
use with websockets.

This also required passing in the ProxyServer and websocket information
to TransportClientSocketPool, which didn't previously need that
information.  In a followup CL, I'll make SocketParams no longer a
callback, and have the method do some of what ClientSocketPoolMaager
does.

Bug: 533571
Change-Id: Ie9113ec941f5db3363255bf27a113841d1b30377
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1562613
Commit-Queue: Matt Menke <mmenke@chromium.org>
Reviewed-by: David Benjamin <davidben@chromium.org>
Cr-Commit-Position: refs/heads/master@{#652971}
  • Loading branch information
Matt Menke authored and Commit Bot committed Apr 22, 2019
1 parent 8cc68ad commit aafff54
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 49 deletions.
3 changes: 3 additions & 0 deletions net/http/http_network_transaction_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,8 @@ class CaptureGroupIdTransportSocketPool : public TransportClientSocketPool {
: TransportClientSocketPool(0,
0,
base::TimeDelta(),
ProxyServer::Direct(),
false /* is_for_websockets */,
common_connect_job_params,
nullptr /* ssl_config_service */) {}

Expand Down Expand Up @@ -14355,6 +14357,7 @@ TEST_F(HttpNetworkTransactionTest, MultiRoundAuth) {
50, // Max sockets for pool
1, // Max sockets per group
base::TimeDelta::FromSeconds(10), // unused_idle_socket_timeout
ProxyServer::Direct(), false, // is_for_websockets
&common_connect_job_params, session_deps_.ssl_config_service.get());
auto mock_pool_manager = std::make_unique<MockClientSocketPoolManager>();
mock_pool_manager->SetSocketPool(ProxyServer::Direct(),
Expand Down
2 changes: 2 additions & 0 deletions net/http/http_stream_factory_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,8 @@ class CapturePreconnectsTransportSocketPool : public TransportClientSocketPool {
: TransportClientSocketPool(0,
0,
base::TimeDelta(),
ProxyServer::Direct(),
false /* is_for_websockets */,
common_connect_job_params,
nullptr /* ssl_config_service */),
last_num_streams_(-1) {}
Expand Down
15 changes: 14 additions & 1 deletion net/socket/client_socket_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,25 @@ void ClientSocketPool::NetLogTcpClientSocketPoolRequestedSocket(
}

std::unique_ptr<base::Value> ClientSocketPool::NetLogGroupIdCallback(
const ClientSocketPool::GroupId* group_id,
const GroupId* group_id,
NetLogCaptureMode /* capture_mode */) {
std::unique_ptr<base::DictionaryValue> event_params(
new base::DictionaryValue());
event_params->SetString("group_id", group_id->ToString());
return event_params;
}

std::unique_ptr<ConnectJob> ClientSocketPool::CreateConnectJob(
GroupId group_id,
scoped_refptr<SocketParams> socket_params,
const ProxyServer& proxy_server,
bool is_for_websockets,
const CommonConnectJobParams* common_connect_job_params,
RequestPriority request_priority,
SocketTag socket_tag,
ConnectJob::Delegate* delegate) {
return socket_params->create_connect_job_callback().Run(
request_priority, socket_tag, common_connect_job_params, delegate);
}

} // namespace net
14 changes: 13 additions & 1 deletion net/socket/client_socket_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "net/http/http_request_info.h"
#include "net/log/net_log_capture_mode.h"
#include "net/socket/connect_job.h"
#include "net/socket/socket_tag.h"

namespace base {
class DictionaryValue;
Expand All @@ -37,6 +38,7 @@ class HttpAuthController;
class HttpProxySocketParams;
class HttpResponseInfo;
class NetLogWithSource;
class ProxyServer;
class SOCKSSocketParams;
class SSLSocketParams;
class StreamSocket;
Expand Down Expand Up @@ -342,9 +344,19 @@ class NET_EXPORT ClientSocketPool : public LowerLayeredPool {

// Utility method to log a GroupId with a NetLog event.
static std::unique_ptr<base::Value> NetLogGroupIdCallback(
const ClientSocketPool::GroupId* group_id,
const GroupId* group_id,
NetLogCaptureMode capture_mode);

static std::unique_ptr<ConnectJob> CreateConnectJob(
GroupId group_id,
scoped_refptr<SocketParams> socket_params,
const ProxyServer& proxy_server,
bool is_for_websockets,
const CommonConnectJobParams* common_connect_job_params,
RequestPriority request_priority,
SocketTag socket_tag,
ConnectJob::Delegate* delegate);

private:
DISALLOW_COPY_AND_ASSIGN(ClientSocketPool);
};
Expand Down
4 changes: 3 additions & 1 deletion net/socket/client_socket_pool_base_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "net/base/load_timing_info_test_util.h"
#include "net/base/net_errors.h"
#include "net/base/privacy_mode.h"
#include "net/base/proxy_server.h"
#include "net/base/request_priority.h"
#include "net/base/test_completion_callback.h"
#include "net/http/http_response_headers.h"
Expand Down Expand Up @@ -572,9 +573,10 @@ class TestConnectJobFactory
// ConnectJobFactory implementation.

std::unique_ptr<ConnectJob> NewConnectJob(
ClientSocketPool::GroupId group_id,
scoped_refptr<ClientSocketPool::SocketParams> socket_params,
RequestPriority request_priority,
SocketTag socket_tag,
scoped_refptr<ClientSocketPool::SocketParams> socket_params,
ConnectJob::Delegate* delegate) const override {
EXPECT_TRUE(!job_types_ || !job_types_->empty());
TestConnectJob::JobType job_type = job_type_;
Expand Down
7 changes: 4 additions & 3 deletions net/socket/client_socket_pool_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ ClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPool(
if (pool_type_ == HttpNetworkSession::WEBSOCKET_SOCKET_POOL &&
proxy_server.is_direct()) {
new_pool = std::make_unique<WebSocketTransportClientSocketPool>(
sockets_per_proxy_server, sockets_per_group,
sockets_per_proxy_server, sockets_per_group, proxy_server,
&websocket_common_connect_job_params_);
} else {
new_pool = std::make_unique<TransportClientSocketPool>(
sockets_per_proxy_server, sockets_per_group,
unused_idle_socket_timeout(pool_type_), &common_connect_job_params_,
ssl_config_service_);
unused_idle_socket_timeout(pool_type_), proxy_server,
pool_type_ == HttpNetworkSession::WEBSOCKET_SOCKET_POOL,
&common_connect_job_params_, ssl_config_service_);
}

std::pair<SocketPoolMap::iterator, bool> ret =
Expand Down
3 changes: 3 additions & 0 deletions net/socket/socket_test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "net/base/hex_utils.h"
#include "net/base/ip_address.h"
#include "net/base/load_timing_info.h"
#include "net/base/proxy_server.h"
#include "net/http/http_network_session.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
Expand Down Expand Up @@ -2096,6 +2097,8 @@ MockTransportClientSocketPool::MockTransportClientSocketPool(
max_sockets,
max_sockets_per_group,
base::TimeDelta::FromSeconds(10) /* unused_idle_socket_timeout */,
ProxyServer::Direct(),
false /* is_for_websockets */,
common_connect_job_params,
nullptr /* ssl_config_service */),
client_socket_factory_(common_connect_job_params->client_socket_factory),
Expand Down
60 changes: 36 additions & 24 deletions net/socket/transport_client_socket_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "base/trace_event/process_memory_dump.h"
#include "base/values.h"
#include "net/base/net_errors.h"
#include "net/base/proxy_server.h"
#include "net/log/net_log.h"
#include "net/log/net_log_event_type.h"
#include "net/log/net_log_source.h"
Expand All @@ -46,40 +47,47 @@ std::unique_ptr<base::Value> NetLogCreateConnectJobCallback(
return std::move(dict);
}

} // namespace

// ConnectJobFactory implementation that creates the standard ConnectJob
// classes, using SocketParams.
class TransportConnectJobFactory
class TransportClientSocketPool::ConnectJobFactoryImpl
: public TransportClientSocketPool::ConnectJobFactory {
public:
explicit TransportConnectJobFactory(
const CommonConnectJobParams* common_connect_job_params)
: common_connect_job_params_(common_connect_job_params) {
ConnectJobFactoryImpl(const ProxyServer& proxy_server,
bool is_for_websockets,
const CommonConnectJobParams* common_connect_job_params)
: proxy_server_(proxy_server),
is_for_websockets_(is_for_websockets),
common_connect_job_params_(common_connect_job_params) {
// This class should not be used with WebSockets. Note that
// |common_connect_job_params| may be nullptr in tests.
DCHECK(!common_connect_job_params ||
!common_connect_job_params->websocket_endpoint_lock_manager);
}

~TransportConnectJobFactory() override = default;
~ConnectJobFactoryImpl() override = default;

// ClientSocketPoolBase::ConnectJobFactory methods.
std::unique_ptr<ConnectJob> NewConnectJob(
ClientSocketPool::GroupId group_id,
scoped_refptr<ClientSocketPool::SocketParams> socket_params,
RequestPriority request_priority,
SocketTag socket_tag,
scoped_refptr<ClientSocketPool::SocketParams> socket_params,
ConnectJob::Delegate* delegate) const override {
return socket_params->create_connect_job_callback().Run(
request_priority, socket_tag, common_connect_job_params_, delegate);
return CreateConnectJob(group_id, socket_params, proxy_server_,
is_for_websockets_, common_connect_job_params_,
request_priority, socket_tag, delegate);
}

private:
const ProxyServer proxy_server_;
const bool is_for_websockets_;
const CommonConnectJobParams* common_connect_job_params_;

DISALLOW_COPY_AND_ASSIGN(TransportConnectJobFactory);
DISALLOW_COPY_AND_ASSIGN(ConnectJobFactoryImpl);
};

} // namespace

TransportClientSocketPool::Request::Request(
ClientSocketHandle* handle,
CompletionOnceCallback callback,
Expand Down Expand Up @@ -131,16 +139,20 @@ TransportClientSocketPool::TransportClientSocketPool(
int max_sockets,
int max_sockets_per_group,
base::TimeDelta unused_idle_socket_timeout,
const ProxyServer& proxy_server,
bool is_for_websockets,
const CommonConnectJobParams* common_connect_job_params,
SSLConfigService* ssl_config_service)
: TransportClientSocketPool(max_sockets,
max_sockets_per_group,
unused_idle_socket_timeout,
ClientSocketPool::used_idle_socket_timeout(),
std::make_unique<TransportConnectJobFactory>(
common_connect_job_params),
ssl_config_service,
true /* connect_backup_jobs_enabled */) {}
: TransportClientSocketPool(
max_sockets,
max_sockets_per_group,
unused_idle_socket_timeout,
ClientSocketPool::used_idle_socket_timeout(),
std::make_unique<ConnectJobFactoryImpl>(proxy_server,
is_for_websockets,
common_connect_job_params),
ssl_config_service,
true /* connect_backup_jobs_enabled */) {}

TransportClientSocketPool::~TransportClientSocketPool() {
// Clean up any idle sockets and pending connect jobs. Assert that we have no
Expand Down Expand Up @@ -400,9 +412,9 @@ int TransportClientSocketPool::RequestSocketInternal(const GroupId& group_id,
group = GetOrCreateGroup(group_id);
connecting_socket_count_++;
std::unique_ptr<ConnectJob> owned_connect_job(
connect_job_factory_->NewConnectJob(request.priority(),
request.socket_tag(),
request.socket_params(), group));
connect_job_factory_->NewConnectJob(group_id, request.socket_params(),
request.priority(),
request.socket_tag(), group));
owned_connect_job->net_log().AddEvent(
NetLogEventType::SOCKET_POOL_CONNECT_JOB_CREATED,
base::BindRepeating(&NetLogCreateConnectJobCallback,
Expand Down Expand Up @@ -1515,8 +1527,8 @@ void TransportClientSocketPool::Group::OnBackupJobTimerFired(
Request* request = unbound_requests_.FirstMax().value().get();
std::unique_ptr<ConnectJob> owned_backup_job =
client_socket_pool_base_helper_->connect_job_factory_->NewConnectJob(
request->priority(), request->socket_tag(), request->socket_params(),
this);
group_id, request->socket_params(), request->priority(),
request->socket_tag(), this);
owned_backup_job->net_log().AddEvent(
NetLogEventType::SOCKET_POOL_CONNECT_JOB_CREATED,
base::BindRepeating(&NetLogCreateConnectJobCallback,
Expand Down
8 changes: 7 additions & 1 deletion net/socket/transport_client_socket_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace net {

struct CommonConnectJobParams;
struct NetLogSource;
class ProxyServer;

// TransportClientSocketPool establishes network connections through using
// ConnectJobs, and maintains a list of idle persistent sockets available for
Expand Down Expand Up @@ -150,9 +151,10 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool
virtual ~ConnectJobFactory() {}

virtual std::unique_ptr<ConnectJob> NewConnectJob(
ClientSocketPool::GroupId group_id,
scoped_refptr<ClientSocketPool::SocketParams> socket_params,
RequestPriority request_priority,
SocketTag socket_tag,
scoped_refptr<SocketParams> socket_params,
ConnectJob::Delegate* delegate) const = 0;

private:
Expand All @@ -163,6 +165,8 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool
int max_sockets,
int max_sockets_per_group,
base::TimeDelta unused_idle_socket_timeout,
const ProxyServer& proxy_server,
bool is_for_websockets,
const CommonConnectJobParams* common_connect_job_params,
SSLConfigService* ssl_config_service);

Expand Down Expand Up @@ -264,6 +268,8 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool
void OnIPAddressChanged() override;

private:
class ConnectJobFactoryImpl;

// Entry for a persistent socket which became idle at time |start_time|.
struct IdleSocket {
IdleSocket() : socket(nullptr) {}
Expand Down
25 changes: 16 additions & 9 deletions net/socket/transport_client_socket_pool_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "net/base/load_timing_info_test_util.h"
#include "net/base/net_errors.h"
#include "net/base/privacy_mode.h"
#include "net/base/proxy_server.h"
#include "net/base/test_completion_callback.h"
#include "net/cert/ct_policy_enforcer.h"
#include "net/cert/mock_cert_verifier.h"
Expand Down Expand Up @@ -126,6 +127,7 @@ class TransportClientSocketPoolTest : public ::testing::Test,
common_connect_job_params_->client_socket_factory = &client_socket_factory_;
pool_ = std::make_unique<TransportClientSocketPool>(
kMaxSockets, kMaxSocketsPerGroup, kUnusedIdleSocketTimeout,
ProxyServer::Direct(), false /* is_for_websockets */,
common_connect_job_params_.get(),
session_deps_.ssl_config_service.get());

Expand All @@ -136,6 +138,7 @@ class TransportClientSocketPoolTest : public ::testing::Test,
&tagging_client_socket_factory_;
tagging_pool_ = std::make_unique<TransportClientSocketPool>(
kMaxSockets, kMaxSocketsPerGroup, kUnusedIdleSocketTimeout,
ProxyServer::Direct(), false /* is_for_websockets */,
tagging_common_connect_job_params_.get(),
session_deps_.ssl_config_service.get());

Expand All @@ -146,6 +149,7 @@ class TransportClientSocketPoolTest : public ::testing::Test,
ClientSocketFactory::GetDefaultFactory();
pool_for_real_sockets_ = std::make_unique<TransportClientSocketPool>(
kMaxSockets, kMaxSocketsPerGroup, kUnusedIdleSocketTimeout,
ProxyServer::Direct(), false /* is_for_websockets */,
common_connect_job_params_for_real_sockets_.get(),
session_deps_.ssl_config_service.get());
}
Expand Down Expand Up @@ -505,9 +509,10 @@ TEST_F(TransportClientSocketPoolTest, ReprioritizeRequests) {
}

TEST_F(TransportClientSocketPoolTest, RequestIgnoringLimitsIsReprioritized) {
TransportClientSocketPool pool(kMaxSockets, 1, kUnusedIdleSocketTimeout,
common_connect_job_params_.get(),
nullptr /* ssl_config_service */);
TransportClientSocketPool pool(
kMaxSockets, 1, kUnusedIdleSocketTimeout, ProxyServer::Direct(),
false /* is_for_websockets */, common_connect_job_params_.get(),
nullptr /* ssl_config_service */);

// Creates a job which ignores limits whose priority is MAXIMUM_PRIORITY.
TestCompletionCallback callback1;
Expand Down Expand Up @@ -1312,9 +1317,10 @@ TEST_F(TransportClientSocketPoolTest, SpdyOneConnectJobTwoRequestsError) {
session_deps_.host_resolver->set_synchronous_mode(true);

// Create a socket pool which only allows a single connection at a time.
TransportClientSocketPool pool(1, 1, kUnusedIdleSocketTimeout,
tagging_common_connect_job_params_.get(),
session_deps_.ssl_config_service.get());
TransportClientSocketPool pool(
1, 1, kUnusedIdleSocketTimeout, ProxyServer::Direct(),
false /* is_for_websockets */, tagging_common_connect_job_params_.get(),
session_deps_.ssl_config_service.get());

// First connection attempt will get an error after creating the SpdyStream.

Expand Down Expand Up @@ -1417,9 +1423,10 @@ TEST_F(TransportClientSocketPoolTest, SpdyAuthOneConnectJobTwoRequests) {
session_deps_.host_resolver->set_synchronous_mode(true);

// Create a socket pool which only allows a single connection at a time.
TransportClientSocketPool pool(1, 1, kUnusedIdleSocketTimeout,
tagging_common_connect_job_params_.get(),
session_deps_.ssl_config_service.get());
TransportClientSocketPool pool(
1, 1, kUnusedIdleSocketTimeout, ProxyServer::Direct(),
false /* is_for_websockets */, tagging_common_connect_job_params_.get(),
session_deps_.ssl_config_service.get());

SpdyTestUtil spdy_util;
spdy::SpdySerializedFrame connect(spdy_util.ConstructSpdyConnect(
Expand Down
Loading

0 comments on commit aafff54

Please sign in to comment.