Skip to content

Commit

Permalink
Fix MessageReader to pass errors to the channel
Browse files Browse the repository at this point in the history
Previously MessageReader was stopping reading after the first error,
but wasn't notifying the client about the problem. This results in some
errors (e.g. from SSL layer) being ignores while they should terminate
connection.

BUG=487451

Review URL: https://codereview.chromium.org/1143443003

Cr-Commit-Position: refs/heads/master@{#329780}
  • Loading branch information
SergeyUlanov authored and Commit bot committed May 14, 2015
1 parent 915be3a commit a04c046
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 53 deletions.
11 changes: 7 additions & 4 deletions remoting/protocol/channel_dispatcher_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,17 @@ void ChannelDispatcherBase::OnChannelReady(

channel_factory_ = nullptr;
channel_ = socket.Pass();
writer_.Init(channel_.get(), base::Bind(&ChannelDispatcherBase::OnWriteFailed,
base::Unretained(this)));
reader_.StartReading(channel_.get());
writer_.Init(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));
reader_.StartReading(channel_.get(),
base::Bind(&ChannelDispatcherBase::OnReadWriteFailed,
base::Unretained(this)));

event_handler_->OnChannelInitialized(this);
}

void ChannelDispatcherBase::OnWriteFailed(int error) {
void ChannelDispatcherBase::OnReadWriteFailed(int error) {
event_handler_->OnChannelError(this, CHANNEL_CONNECTION_ERROR);
}

Expand Down
2 changes: 1 addition & 1 deletion remoting/protocol/channel_dispatcher_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ChannelDispatcherBase {

private:
void OnChannelReady(scoped_ptr<net::StreamSocket> socket);
void OnWriteFailed(int error);
void OnReadWriteFailed(int error);

std::string channel_name_;
StreamChannelFactory* channel_factory_;
Expand Down
72 changes: 46 additions & 26 deletions remoting/protocol/channel_multiplexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "base/bind.h"
#include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/location.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
Expand Down Expand Up @@ -79,7 +80,7 @@ class ChannelMultiplexer::MuxChannel {
scoped_ptr<net::StreamSocket> CreateSocket();
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
const base::Closure& done_task);
void OnWriteFailed();
void OnBaseChannelError(int error);

// Called by MuxSocket.
void OnSocketDestroyed();
Expand Down Expand Up @@ -107,7 +108,7 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
~MuxSocket() override;

void OnWriteComplete();
void OnWriteFailed();
void OnBaseChannelError(int error);
void OnPacketReceived();

// net::StreamSocket interface.
Expand Down Expand Up @@ -168,6 +169,8 @@ class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
private:
MuxChannel* channel_;

int base_channel_error_ = net::OK;

net::CompletionCallback read_callback_;
scoped_refptr<net::IOBuffer> read_buffer_;
int read_buffer_size_;
Expand Down Expand Up @@ -220,9 +223,9 @@ void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
}
}

void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
if (socket_)
socket_->OnWriteFailed();
socket_->OnBaseChannelError(error);
}

void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
Expand Down Expand Up @@ -276,6 +279,9 @@ int ChannelMultiplexer::MuxSocket::Read(
DCHECK(CalledOnValidThread());
DCHECK(read_callback_.is_null());

if (base_channel_error_ != net::OK)
return base_channel_error_;

int result = channel_->DoRead(buffer, buffer_len);
if (result == 0) {
read_buffer_ = buffer;
Expand All @@ -290,6 +296,10 @@ int ChannelMultiplexer::MuxSocket::Write(
net::IOBuffer* buffer, int buffer_len,
const net::CompletionCallback& callback) {
DCHECK(CalledOnValidThread());
DCHECK(write_callback_.is_null());

if (base_channel_error_ != net::OK)
return base_channel_error_;

scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
size_t size = std::min(kMaxPacketSize, buffer_len);
Expand Down Expand Up @@ -317,29 +327,36 @@ int ChannelMultiplexer::MuxSocket::Write(

void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
write_pending_ = false;
if (!write_callback_.is_null()) {
net::CompletionCallback cb;
std::swap(cb, write_callback_);
cb.Run(write_result_);
}
if (!write_callback_.is_null())
base::ResetAndReturn(&write_callback_).Run(write_result_);

}

void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
if (!write_callback_.is_null()) {
net::CompletionCallback cb;
std::swap(cb, write_callback_);
cb.Run(net::ERR_FAILED);
void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
base_channel_error_ = error;

// Here only one of the read and write callbacks is called if both of them are
// pending. Ideally both of them should be called in that case, but that would
// require the second one to be called asynchronously which would complicate
// this code. Channels handle read and write errors the same way (see
// ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
// callbacks is enough.

if (!read_callback_.is_null()) {
base::ResetAndReturn(&read_callback_).Run(error);
return;
}

if (!write_callback_.is_null())
base::ResetAndReturn(&write_callback_).Run(error);
}

void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
if (!read_callback_.is_null()) {
int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
read_buffer_ = nullptr;
DCHECK_GT(result, 0);
net::CompletionCallback cb;
std::swap(cb, read_callback_);
cb.Run(result);
base::ResetAndReturn(&read_callback_).Run(result);
}
}

Expand Down Expand Up @@ -403,9 +420,11 @@ void ChannelMultiplexer::OnBaseChannelReady(

if (base_channel_.get()) {
// Initialize reader and writer.
reader_.StartReading(base_channel_.get());
reader_.StartReading(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnBaseChannelError,
base::Unretained(this)));
writer_.Init(base_channel_.get(),
base::Bind(&ChannelMultiplexer::OnWriteFailed,
base::Bind(&ChannelMultiplexer::OnBaseChannelError,
base::Unretained(this)));
}

Expand Down Expand Up @@ -447,20 +466,21 @@ ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
}


void ChannelMultiplexer::OnWriteFailed(int error) {
void ChannelMultiplexer::OnBaseChannelError(int error) {
for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
it != channels_.end(); ++it) {
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
weak_factory_.GetWeakPtr(), it->second->name()));
FROM_HERE,
base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
weak_factory_.GetWeakPtr(), it->second->name(), error));
}
}

void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
int error) {
std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
if (it != channels_.end()) {
it->second->OnWriteFailed();
}
if (it != channels_.end())
it->second->OnBaseChannelError(error);
}

void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
Expand Down
9 changes: 5 additions & 4 deletions remoting/protocol/channel_multiplexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ class ChannelMultiplexer : public StreamChannelFactory {
// Helper method used to create channels.
MuxChannel* GetOrCreateChannel(const std::string& name);

// Error handling callback for |writer_|.
void OnWriteFailed(int error);
// Error handling callback for |reader_| and |writer_|.
void OnBaseChannelError(int error);

// Failed write notifier, queued asynchronously by OnWriteFailed().
void NotifyWriteFailed(const std::string& name);
// Propagates base channel error to channel |name|, queued asynchronously by
// OnBaseChannelError().
void NotifyBaseChannelError(const std::string& name, int error);

// Callback for |reader_;
void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
Expand Down
9 changes: 8 additions & 1 deletion remoting/protocol/client_video_dispatcher_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ClientVideoDispatcherTest : public testing::Test,

protected:
void OnVideoAck(scoped_ptr<VideoAck> ack, const base::Closure& done);
void OnReadError(int error);

base::MessageLoop message_loop_;

Expand Down Expand Up @@ -72,7 +73,9 @@ ClientVideoDispatcherTest::ClientVideoDispatcherTest()
DCHECK(initialized_);
host_socket_.PairWith(
session_.fake_channel_factory().GetFakeChannel(kVideoChannelName));
reader_.StartReading(&host_socket_);
reader_.StartReading(&host_socket_,
base::Bind(&ClientVideoDispatcherTest::OnReadError,
base::Unretained(this)));
writer_.Init(&host_socket_, BufferedSocketWriter::WriteFailedCallback());
}

Expand Down Expand Up @@ -101,6 +104,10 @@ void ClientVideoDispatcherTest::OnVideoAck(scoped_ptr<VideoAck> ack,
done.Run();
}

void ClientVideoDispatcherTest::OnReadError(int error) {
LOG(FATAL) << "Unexpected read error: " << error;
}

// Verify that the client can receive video packets and acks are not sent for
// VideoPackets that don't have frame_id field set.
TEST_F(ClientVideoDispatcherTest, WithoutAcks) {
Expand Down
34 changes: 25 additions & 9 deletions remoting/protocol/message_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,32 @@ void MessageReader::SetMessageReceivedCallback(
message_received_callback_ = callback;
}

void MessageReader::StartReading(net::Socket* socket) {
void MessageReader::StartReading(
net::Socket* socket,
const ReadFailedCallback& read_failed_callback) {
DCHECK(CalledOnValidThread());
DCHECK(socket);
DCHECK(!read_failed_callback.is_null());

socket_ = socket;
read_failed_callback_ = read_failed_callback;
DoRead();
}

void MessageReader::DoRead() {
DCHECK(CalledOnValidThread());
// Don't try to read again if there is another read pending or we
// have messages that we haven't finished processing yet.
while (!closed_ && !read_pending_ && pending_messages_ == 0) {
bool read_succeeded = true;
while (read_succeeded && !closed_ && !read_pending_ &&
pending_messages_ == 0) {
read_buffer_ = new net::IOBuffer(kReadBufferSize);
int result = socket_->Read(
read_buffer_.get(),
kReadBufferSize,
base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr()));
HandleReadResult(result);

HandleReadResult(result, &read_succeeded);
}
}

Expand All @@ -65,26 +73,34 @@ void MessageReader::OnRead(int result) {
read_pending_ = false;

if (!closed_) {
HandleReadResult(result);
DoRead();
bool read_succeeded;
HandleReadResult(result, &read_succeeded);
if (read_succeeded)
DoRead();
}
}

void MessageReader::HandleReadResult(int result) {
void MessageReader::HandleReadResult(int result, bool* read_succeeded) {
DCHECK(CalledOnValidThread());
if (closed_)
return;

*read_succeeded = true;

if (result > 0) {
OnDataReceived(read_buffer_.get(), result);
*read_succeeded = true;
} else if (result == net::ERR_IO_PENDING) {
read_pending_ = true;
} else {
if (result != net::ERR_CONNECTION_CLOSED) {
LOG(ERROR) << "Read() returned error " << result;
}
DCHECK_LT(result, 0);

// Stop reading after any error.
closed_ = true;
*read_succeeded = false;

LOG(ERROR) << "Read() returned error " << result;
read_failed_callback_.Run(result);
}
}

Expand Down
8 changes: 6 additions & 2 deletions remoting/protocol/message_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class MessageReader : public base::NonThreadSafe {
public:
typedef base::Callback<void(scoped_ptr<CompoundBuffer>, const base::Closure&)>
MessageReceivedCallback;
typedef base::Callback<void(int)> ReadFailedCallback;

MessageReader();
virtual ~MessageReader();
Expand All @@ -43,16 +44,19 @@ class MessageReader : public base::NonThreadSafe {
void SetMessageReceivedCallback(const MessageReceivedCallback& callback);

// Starts reading from |socket|.
void StartReading(net::Socket* socket);
void StartReading(net::Socket* socket,
const ReadFailedCallback& read_failed_callback);

private:
void DoRead();
void OnRead(int result);
void HandleReadResult(int result);
void HandleReadResult(int result, bool* read_succeeded);
void OnDataReceived(net::IOBuffer* data, int data_size);
void RunCallback(scoped_ptr<CompoundBuffer> message);
void OnMessageDone();

ReadFailedCallback read_failed_callback_;

net::Socket* socket_;

// Set to true, when we have a socket read pending, and expecting
Expand Down
18 changes: 12 additions & 6 deletions remoting/protocol/message_reader_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class MessageReaderTest : public testing::Test {
void InitReader() {
reader_->SetMessageReceivedCallback(
base::Bind(&MessageReaderTest::OnMessage, base::Unretained(this)));
reader_->StartReading(&socket_);
reader_->StartReading(&socket_, base::Bind(&MessageReaderTest::OnReadError,
base::Unretained(this)));
}

void AddMessage(const std::string& message) {
Expand All @@ -92,6 +93,11 @@ class MessageReaderTest : public testing::Test {
return result == expected;
}

void OnReadError(int error) {
read_error_ = error;
reader_.reset();
}

void OnMessage(scoped_ptr<CompoundBuffer> buffer,
const base::Closure& done_callback) {
messages_.push_back(buffer.release());
Expand All @@ -102,6 +108,7 @@ class MessageReaderTest : public testing::Test {
scoped_ptr<MessageReader> reader_;
FakeStreamSocket socket_;
MockMessageReceivedCallback callback_;
int read_error_ = 0;
std::vector<CompoundBuffer*> messages_;
bool in_callback_;
};
Expand Down Expand Up @@ -281,13 +288,12 @@ TEST_F(MessageReaderTest, TwoMessages_Separately) {
TEST_F(MessageReaderTest, ReadError) {
socket_.AppendReadError(net::ERR_FAILED);

// Add a message. It should never be read after the error above.
AddMessage(kTestMessage1);

EXPECT_CALL(callback_, OnMessage(_))
.Times(0);
EXPECT_CALL(callback_, OnMessage(_)).Times(0);

InitReader();

EXPECT_EQ(net::ERR_FAILED, read_error_);
EXPECT_FALSE(reader_);
}

// Verify that we the OnMessage callback is not reentered.
Expand Down

0 comments on commit a04c046

Please sign in to comment.