Skip to content

Commit

Permalink
use ControlMessage instead of previous TaskMessage hack
Browse files Browse the repository at this point in the history
  • Loading branch information
anfeng committed May 9, 2013
1 parent 6e96a1f commit 22c7a8d
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 90 deletions.
29 changes: 4 additions & 25 deletions storm-core/src/jvm/backtype/storm/messaging/TaskMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import java.nio.ByteBuffer;

public class TaskMessage {
final int SHORT_SIZE = 2;
final int INT_SIZE = 4;
private int _task;
private byte[] _message;

Expand All @@ -22,36 +20,17 @@ public byte[] message() {
}

public ByteBuffer serialize() {
ByteBuffer bb = ByteBuffer.allocate(_message.length+SHORT_SIZE+INT_SIZE);
ByteBuffer bb = ByteBuffer.allocate(_message.length+2);
bb.putShort((short)_task);
if (_message==null)
bb.putInt(0);
else {
bb.putInt(_message.length);
bb.put(_message);
}
bb.put(_message);
return bb;
}

public void deserialize(ByteBuffer packet) {
if (packet==null) return;
_task = packet.getShort();
int len = packet.getInt();
if (len ==0)
_message = null;
else {
_message = new byte[len];
packet.get(_message);
}
_message = new byte[packet.limit()-2];
packet.get(_message);
}

public String toString() {
StringBuffer buf = new StringBuffer();
buf.append("task:");
buf.append(_task);
buf.append(" message size:");
if (_message!=null) buf.append(_message.length);
else buf.append(0);
return buf.toString();
}
}
17 changes: 9 additions & 8 deletions storm-netty/src/jvm/backtype/storm/messaging/netty/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Client implements IConnection {
private final int max_retries;
private final int base_sleep_ms;
private final int max_sleep_ms;
private LinkedBlockingQueue<TaskMessage> message_queue;
private LinkedBlockingQueue<Object> message_queue; //entry should either be TaskMessage or ControlMessage
private AtomicReference<Channel> channelRef;
private final ClientBootstrap bootstrap;
private InetSocketAddress remote_addr;
Expand All @@ -42,7 +42,7 @@ class Client implements IConnection {

@SuppressWarnings("rawtypes")
Client(Map storm_conf, String host, int port) {
message_queue = new LinkedBlockingQueue<TaskMessage>();
message_queue = new LinkedBlockingQueue<Object>();
retries = new AtomicInteger(0);
channelRef = new AtomicReference<Channel>(null);
being_closed = new AtomicBoolean(false);
Expand Down Expand Up @@ -120,16 +120,17 @@ public void send(int task, byte[] message) {
* @return
* @throws InterruptedException
*/
ArrayList<TaskMessage> takeMessages() throws InterruptedException {
ArrayList<Object> takeMessages() throws InterruptedException {
int size = 0;
ArrayList<TaskMessage> requests = new ArrayList<TaskMessage>();
ArrayList<Object> requests = new ArrayList<Object>();
requests.add(message_queue.take());
for (TaskMessage msg = message_queue.poll(); msg!=null; msg = message_queue.poll()) {
for (Object msg = message_queue.poll(); msg!=null; msg = message_queue.poll()) {
requests.add(msg);
//we will discard any message after CLOSE
if (msg==Util.CLOSE_MESSAGE) break;
if (msg==ControlMessage.CLOSE_MESSAGE) break;
//we limit the batch per buffer size
size += (msg.message()!=null? msg.message().length : 0) + 6; //INT + SHORT + payload
TaskMessage taskMsg = (TaskMessage) msg;
size += (taskMsg.message()!=null? taskMsg.message().length : 0) + 6; //INT + SHORT + payload
if (size > buffer_size)
break;
}
Expand All @@ -144,7 +145,7 @@ ArrayList<TaskMessage> takeMessages() throws InterruptedException {
public void close() {
//enqueue a SHUTDOWN message so that shutdown() will be invoked
try {
message_queue.put(Util.CLOSE_MESSAGE);
message_queue.put(ControlMessage.CLOSE_MESSAGE);
being_closed.set(true);
} catch (InterruptedException e) {
close_n_release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import java.util.Map;
import java.util.Vector;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import backtype.storm.messaging.IConnection;
import backtype.storm.messaging.IContext;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package backtype.storm.messaging.netty;

class ControlMessage {
static final short BASE_CODE = -100;
static final short OK = -200; //HTTP status: 200
static final short EOB = -201; //end of a batch
static final short FAILURE = -400; //HTTP status: 400 BAD REQUEST
static final short CLOSE = -410; //HTTP status: 410 GONE

static final ControlMessage CLOSE_MESSAGE = new ControlMessage(CLOSE);
static final ControlMessage EOB_MESSAGE = new ControlMessage(EOB);
static final ControlMessage OK_RESPONSE = new ControlMessage(OK);
static final ControlMessage FAILURE_RESPONSE = new ControlMessage(FAILURE);

private short code;

ControlMessage() {
code = OK;
}

ControlMessage(short code) {
assert(code<BASE_CODE);
this.code = code;
}

short code() {
return code;
}

public String toString() {
switch (code) {
case OK: return "OK";
case EOB: return "END_OF_BATCH";
case FAILURE: return "FAILURE";
case CLOSE: return "CLOSE";
default: return "control message w/ code " + code;
}
}

public boolean equals(Object obj) {
if (obj == null) return false;
if (obj instanceof ControlMessage)
return ((ControlMessage)obj).code == code;
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.frame.FrameDecoder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import backtype.storm.messaging.TaskMessage;

public class TaskMessageDecoder extends FrameDecoder {
public class MessageDecoder extends FrameDecoder {
private static final Logger LOG = LoggerFactory.getLogger(MessageDecoder.class);

/*
* Each ControlMessage is encoded as:
* code (<0) ... short(2)
* Each TaskMessage is encoded as:
* task ... short(2)
* task (>=0) ... short(2)
* len ... int(4)
* payload ... byte[] *
*/
protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buf) throws Exception {
// Make sure if both task and len were received.
if (buf.readableBytes() < 6) {
// Make sure that we have received at least a short
if (buf.readableBytes() < 2) {
//need more data
return null;
}
Expand All @@ -27,9 +32,26 @@ protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffe
// there's not enough bytes in the buffer.
buf.markReaderIndex();

//read task field
short task = buf.readShort();
//read the short field
short code = buf.readShort();

//case 1: Control message if val<0
if (code<=ControlMessage.BASE_CODE) {
ControlMessage ctrl_msg = new ControlMessage(code);
LOG.debug("Control message:"+ctrl_msg);
return ctrl_msg;
}

//case 2: task Message
short task = code;

// Make sure that we have received at least an integer (length)
if (buf.readableBytes() < 4) {
//need more data
buf.resetReaderIndex();
return null;
}

// Read the length field.
int length = buf.readInt();
if (length==0) {
Expand All @@ -39,13 +61,7 @@ protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffe
// Make sure if there's enough bytes in the buffer.
if (buf.readableBytes() < length) {
// The whole bytes were not received yet - return null.
// This method will be invoked again when more packets are
// received and appended to the buffer.

// Reset to the marked position to read the length field again
// next time.
buf.resetReaderIndex();

return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,29 @@
import backtype.storm.messaging.TaskMessage;
import backtype.storm.utils.Utils;

public class TaskMessageEncoder extends OneToOneEncoder {
public class MessageEncoder extends OneToOneEncoder {
int estimated_buffer_size;

@SuppressWarnings("rawtypes")
TaskMessageEncoder(Map conf) {
MessageEncoder(Map conf) {
estimated_buffer_size = Utils.getInt(conf.get(Config.STORM_MESSAGING_NETTY_BUFFER_SIZE));
}

@SuppressWarnings("unchecked")
@Override
protected Object encode(ChannelHandlerContext ctx, Channel channel, Object obj) throws Exception {

if (obj instanceof ControlMessage) {
ControlMessage message = (ControlMessage)obj;
ChannelBufferOutputStream bout =
new ChannelBufferOutputStream(ChannelBuffers.dynamicBuffer(
estimated_buffer_size, ctx.getChannel().getConfig().getBufferFactory()));
writeControlMessage(bout, message);
bout.close();

return bout.buffer();
}

if (obj instanceof TaskMessage) {
TaskMessage message = (TaskMessage)obj;
ChannelBufferOutputStream bout =
Expand All @@ -43,6 +54,7 @@ protected Object encode(ChannelHandlerContext ctx, Channel channel, Object obj)
ArrayList<TaskMessage> messages = (ArrayList<TaskMessage>) obj;
for (TaskMessage message : messages)
writeTaskMessage(bout, message);
writeControlMessage(bout, ControlMessage.EOB_MESSAGE);
bout.close();

return bout.buffer();
Expand All @@ -69,4 +81,14 @@ private void writeTaskMessage(ChannelBufferOutputStream bout, TaskMessage messag
if (payload_len >0)
bout.write(message.message());
}

/**
* write a ControlMessage into a stream
*
* Each TaskMessage is encoded as:
* code ... short(2)
*/
private void writeControlMessage(ChannelBufferOutputStream bout, ControlMessage message) throws Exception {
bout.writeShort(message.code());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import backtype.storm.messaging.TaskMessage;

class StormClientHandler extends SimpleChannelUpstreamHandler {
private static final Logger LOG = LoggerFactory.getLogger(StormClientHandler.class);
private Client client;
Expand Down Expand Up @@ -49,8 +47,8 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent event) {
LOG.debug("send/recv time (ms):"+(System.currentTimeMillis() - start_time));

//examine the response message from server
TaskMessage msg = (TaskMessage)event.getMessage();
if (msg.task()!=Util.OK)
ControlMessage msg = (ControlMessage)event.getMessage();
if (msg.equals(ControlMessage.FAILURE_RESPONSE))
LOG.info("failure response:"+msg);

//send next request
Expand All @@ -66,12 +64,12 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent event) {
* Retrieve a request from message queue, and send to server
* @param channel
*/
private void sendRequests(Channel channel, final ArrayList<TaskMessage> requests) {
private void sendRequests(Channel channel, final ArrayList<Object> requests) {
if (being_closed.get()) return;

//if task==CLOSE_MESSAGE for our last request, the channel is to be closed
TaskMessage last_msg = requests.get(requests.size()-1);
if (last_msg==Util.CLOSE_MESSAGE) {
Object last_msg = requests.get(requests.size()-1);
if (last_msg==ControlMessage.CLOSE_MESSAGE) {
being_closed.set(true);
requests.remove(last_msg);
}
Expand All @@ -83,8 +81,6 @@ private void sendRequests(Channel channel, final ArrayList<TaskMessage> requests
return;
}

//add an EOB_MESSAGE to the end of our batch
requests.add(Util.EOB_MESSAGE);
//write request into socket channel
ChannelFuture future = channel.write(requests);
future.addListener(new ChannelFutureListener() {
Expand All @@ -94,7 +90,7 @@ public void operationComplete(ChannelFuture future)
LOG.info("failed to send requests:", future.getCause());
future.getChannel().close();
} else {
LOG.debug((requests.size()-1) + " request(s) sent");
LOG.debug(requests.size() + " request(s) sent");
}
if (being_closed.get())
client.close_n_release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public ChannelPipeline getPipeline() throws Exception {
ChannelPipeline pipeline = Channels.pipeline();

// Decoder
pipeline.addLast("decoder", new TaskMessageDecoder());
pipeline.addLast("decoder", new MessageDecoder());
// Encoder
pipeline.addLast("encoder", new TaskMessageEncoder(conf));
pipeline.addLast("encoder", new MessageEncoder(conf));
// business logic.
pipeline.addLast("handler", new StormClientHandler(client));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) {

@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) {
TaskMessage message = (TaskMessage)e.getMessage();
if (message == null) return;
Object msg = e.getMessage();
if (msg == null) return;

//end of batch?
if (message.task() == Util.EOB) {
if (ControlMessage.EOB_MESSAGE.equals(msg)) {
Channel channel = ctx.getChannel();
LOG.debug("Sendback response ...");
LOG.debug("Send back response ...");
if (failure_count.get()==0)
channel.write(Util.OK_RESPONSE);
else channel.write(Util.FAILURE_RESPONSE);
channel.write(ControlMessage.OK_RESPONSE);
else channel.write(ControlMessage.FAILURE_RESPONSE);
return;
}

//enqueue the received message for processing
try {
server.enqueue(message);
server.enqueue((TaskMessage)msg);
} catch (InterruptedException e1) {
LOG.info("failed to enqueue a request message", e);
failure_count.incrementAndGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ public ChannelPipeline getPipeline() throws Exception {
ChannelPipeline pipeline = Channels.pipeline();

// Decoder
pipeline.addLast("decoder", new TaskMessageDecoder());
pipeline.addLast("decoder", new MessageDecoder());
// Encoder
pipeline.addLast("encoder", new TaskMessageEncoder(conf));
pipeline.addLast("encoder", new MessageEncoder(conf));
// business logic.
pipeline.addLast("handler", new StormServerHandler(server));

Expand Down
Loading

0 comments on commit 22c7a8d

Please sign in to comment.