feat(batch): abort task (#2757)
* feat(batch): abort task

* revision

* make wait_until_task_aborted private and test-only
lmatz authored May 25, 2022
1 parent bd7739d commit 1348deb
Showing 8 changed files with 234 additions and 61 deletions.
12 changes: 11 additions & 1 deletion proto/task_service.proto
@@ -32,6 +32,8 @@ message TaskInfo {
CANCELLING = 4;
FINISHED = 5;
FAILED = 6;
ABORTED = 7;
ABORTING = 8;
}
batch_plan.TaskId task_id = 1;
TaskStatus task_status = 2;
@@ -49,13 +51,20 @@ message CreateTaskResponse {

message AbortTaskRequest {
batch_plan.TaskId task_id = 1;
bool force = 2;
}

message AbortTaskResponse {
common.Status status = 1;
}

message RemoveTaskRequest {
batch_plan.TaskId task_id = 1;
}

message RemoveTaskResponse {
common.Status status = 1;
}

message GetTaskInfoRequest {
batch_plan.TaskId task_id = 1;
}
@@ -79,6 +88,7 @@ service TaskService {
rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse);
rpc GetTaskInfo(GetTaskInfoRequest) returns (GetTaskInfoResponse);
rpc AbortTask(AbortTaskRequest) returns (AbortTaskResponse);
rpc RemoveTask(RemoveTaskRequest) returns (RemoveTaskResponse);
}

message GetDataRequest {
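The proto changes above add the ABORTED and ABORTING task states, a `force` flag on AbortTaskRequest, and a new RemoveTask RPC. Below is a minimal client-side sketch of driving the new AbortTask RPC, assuming the standard prost/tonic codegen in `risingwave_pb` (a `task_service_client` module, `Option`-wrapped message fields) and illustrative task-id values; it is not part of this commit.

    use risingwave_pb::batch_plan::TaskId;
    use risingwave_pb::task_service::task_service_client::TaskServiceClient;
    use risingwave_pb::task_service::AbortTaskRequest;

    // Hypothetical helper: ask a compute node to abort one batch task.
    async fn abort_remote_task(endpoint: &str) -> Result<(), Box<dyn std::error::Error>> {
        let mut client = TaskServiceClient::connect(endpoint.to_string()).await?;
        let request = AbortTaskRequest {
            // prost wraps message-typed fields in Option.
            task_id: Some(TaskId {
                query_id: "".to_string(),
                stage_id: 0,
                task_id: 0,
            }),
            // New field added by this commit; its semantics are defined server-side.
            force: false,
        };
        let status = client.abort_task(request).await?.into_inner().status;
        println!("abort task status: {:?}", status);
        Ok(())
    }

Server-side, the handler shown in the next file maps this request onto `BatchManager::abort_task` and returns an empty status on success.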
34 changes: 31 additions & 3 deletions src/batch/src/rpc/service/task_service.rs
@@ -17,7 +17,7 @@ use std::sync::Arc;
use risingwave_pb::task_service::task_service_server::TaskService;
use risingwave_pb::task_service::{
AbortTaskRequest, AbortTaskResponse, CreateTaskRequest, CreateTaskResponse, GetTaskInfoRequest,
GetTaskInfoResponse,
GetTaskInfoResponse, RemoveTaskRequest, RemoveTaskResponse,
};
use tonic::{Request, Response, Status};

@@ -70,8 +70,36 @@ impl TaskService for BatchServiceImpl {
#[cfg_attr(coverage, no_coverage)]
async fn abort_task(
&self,
_: Request<AbortTaskRequest>,
req: Request<AbortTaskRequest>,
) -> Result<Response<AbortTaskResponse>, Status> {
todo!()
let req = req.into_inner();
let res = self
.mgr
.abort_task(req.get_task_id().expect("no task id found"));
match res {
Ok(_) => Ok(Response::new(AbortTaskResponse { status: None })),
Err(e) => {
error!("failed to abort task {}", e);
Err(e.to_grpc_status())
}
}
}

#[cfg_attr(coverage, no_coverage)]
async fn remove_task(
&self,
req: Request<RemoveTaskRequest>,
) -> Result<Response<RemoveTaskResponse>, Status> {
let req = req.into_inner();
let res = self
.mgr
.remove_task(req.get_task_id().expect("no task id found"));
match res {
Ok(_) => Ok(Response::new(RemoveTaskResponse { status: None })),
Err(e) => {
error!("failed to remove task {}", e);
Err(e.to_grpc_status())
}
}
}
}
85 changes: 73 additions & 12 deletions src/batch/src/task/task_.rs
@@ -15,7 +15,7 @@
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use futures_async_stream::for_await;
use futures::StreamExt;
use parking_lot::Mutex;
use risingwave_common::array::DataChunk;
use risingwave_common::error::{ErrorCode, Result, RwError};
@@ -24,6 +24,7 @@ use risingwave_pb::batch_plan::{
};
use risingwave_pb::task_service::task_info::TaskStatus;
use risingwave_pb::task_service::GetDataResponse;
use tokio::sync::oneshot::{Receiver, Sender};
use tracing_futures::Instrument;

use crate::executor::ExecutorBuilder;
@@ -182,6 +183,9 @@ pub struct BatchTaskExecution<C> {
/// The execution failure.
failure: Arc<Mutex<Option<RwError>>>,

/// Shutdown signal sender.
shutdown_tx: Mutex<Option<Sender<u64>>>,

epoch: u64,
}

@@ -200,6 +204,7 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
context,
failure: Arc::new(Mutex::new(None)),
epoch,
shutdown_tx: Mutex::new(None),
})
}

@@ -213,7 +218,7 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
/// hash partitioned across multiple channels.
/// To obtain the result, one must pick one of the channels to consume via [`TaskOutputId`]. As
/// such, parallel consumers are able to consume the result independently.
pub fn async_execute(&self) -> Result<()> {
pub fn async_execute(self: Arc<Self>) -> Result<()> {
trace!(
"Prepare executing plan [{:?}]: {}",
self.task_id,
@@ -229,6 +234,8 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
.build2()?;

let (sender, receivers) = create_output_channel(self.plan.get_exchange_info()?)?;
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<u64>();
*self.shutdown_tx.lock() = Some(shutdown_tx);
self.receivers
.lock()
.extend(receivers.into_iter().map(Some));
@@ -243,7 +250,8 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
let join_handle = tokio::spawn(async move {
// We should only pass a reference to the sender into the execution, because we should only
// close it after the task error has been set.
if let Err(e) = try_execute(exec, &mut sender)
if let Err(e) = self
.try_execute(exec, &mut sender, shutdown_rx)
.instrument(tracing::trace_span!(
"batch_execute",
task_id = ?task_id.task_id,
Expand All @@ -255,6 +263,7 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
// Prints the entire backtrace of error.
error!("Execution failed [{:?}]: {:?}", &task_id, &e);
*failure.lock() = Some(e);
*self.state.lock() = TaskStatus::Failed;
}
});

Expand All @@ -265,6 +274,58 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
Ok(())
}

pub async fn try_execute(
&self,
root: BoxedExecutor2,
sender: &mut ChanSenderImpl,
mut shutdown_rx: Receiver<u64>,
) -> Result<()> {
let mut data_chunk_stream = root.execute();
loop {
tokio::select! {
// We prioritize the abort signal over normal data chunks.
biased;
_ = &mut shutdown_rx => {
sender.send(None).await?;
*self.state.lock() = TaskStatus::Aborted;
break;
}
res = data_chunk_stream.next() => {
match res {
Some(data_chunk) => {
sender.send(Some(data_chunk?)).await?;
}
None => {
debug!("data chunk stream shuts down");
sender.send(None).await?;
break;
}
}
}
}
}
Ok(())
}

pub fn abort_task(&self) -> Result<()> {
let sender = self.shutdown_tx.lock().take().ok_or_else(|| {
ErrorCode::InternalError(format!(
"Task{:?}'s shutdown channel does not exist. \
Either the task has been aborted once, \
or the channel has never been initialized.",
self.task_id
))
})?;
*self.state.lock() = TaskStatus::Aborting;
sender.send(0).map_err(|err| {
ErrorCode::InternalError(format!(
"Task{:?};s shutdown channel send error:{:?}",
self.task_id, err
))
.into()
})
}

pub fn get_task_output(&self, output_id: &ProstOutputId) -> Result<TaskOutput<C>> {
let task_id = TaskId::from(output_id.get_task_id()?);
let receiver = self.receivers.lock()[output_id.get_output_id() as usize]
@@ -298,18 +359,18 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
}
Ok(())
}
}

// Removed by this commit (superseded by the `try_execute` method above):
pub async fn try_execute(root: BoxedExecutor2, sender: &mut ChanSenderImpl) -> Result<()> {
    #[for_await]
    for chunk in root.execute() {
        let chunk = chunk?;
        if chunk.cardinality() > 0 {
            sender.send(Some(chunk)).await?;
        }
    }
    sender.send(None).await?;
    Ok(())
}

// Added by this commit: report whether the task has reached the aborted state.
pub fn check_if_aborted(&self) -> Result<bool> {
    match *self.state.lock() {
        TaskStatus::Aborted => Ok(true),
        TaskStatus::Finished => Err(ErrorCode::InternalError(format!(
            "task {:?} has been finished",
            self.get_task_id()
        ))
        .into()),
        _ => Ok(false),
    }
}

#[cfg(test)]
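The heart of the task.rs change is the oneshot shutdown channel stored on the task plus a `biased` `tokio::select!`, so a pending abort signal is polled before the next data chunk. Below is a self-contained sketch of just that pattern, with illustrative names and plain `i32` items standing in for `DataChunk` (assumes only the tokio and futures crates); it is not the crate's API.

    use futures::stream::{self, Stream, StreamExt};
    use tokio::sync::oneshot;

    // Drain a data stream, but let an abort signal win whenever both are ready.
    async fn drain_or_abort(
        mut data: impl Stream<Item = i32> + Unpin,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) -> Vec<i32> {
        let mut seen = Vec::new();
        loop {
            tokio::select! {
                // `biased` makes select! poll branches in order, so the shutdown
                // receiver is always checked before the next data item.
                biased;
                _ = &mut shutdown_rx => break,      // abort requested
                item = data.next() => match item {
                    Some(v) => seen.push(v),        // forward the item
                    None => break,                  // stream finished normally
                },
            }
        }
        seen
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = oneshot::channel();
        tx.send(()).unwrap(); // abort before any item is consumed
        let out = drain_or_abort(stream::iter(1..=1_000_000), rx).await;
        println!("consumed {} items before abort", out.len()); // prints 0
    }

Because the select! is biased, once the oneshot has fired no further items are pulled from the stream, which is why the example deterministically prints 0.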
93 changes: 83 additions & 10 deletions src/batch/src/task/task_manager.rs
@@ -28,7 +28,7 @@ use crate::task::{BatchTaskExecution, ComputeNodeContext, TaskId, TaskOutput};
#[derive(Clone)]
pub struct BatchManager {
/// Every task id has a corresponding task execution.
tasks: Arc<Mutex<HashMap<TaskId, Box<BatchTaskExecution<ComputeNodeContext>>>>>,
tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution<ComputeNodeContext>>>>>,
}

impl BatchManager {
@@ -48,10 +48,11 @@ impl BatchManager {
trace!("Received task id: {:?}, plan: {:?}", tid, plan);
let task = BatchTaskExecution::new(tid, plan, context, epoch)?;
let task_id = task.get_task_id().clone();
let task = Arc::new(task);

task.async_execute()?;
task.clone().async_execute()?;
if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
e.insert(Box::new(task));
e.insert(task);
Ok(())
} else {
Err(ErrorCode::InternalError(format!(
@@ -72,11 +73,18 @@ impl BatchManager {
.get_task_output(output_id)
}

#[cfg(test)]
pub fn abort_task(&self, sid: &ProstTaskId) -> Result<()> {
let sid = TaskId::from(sid);
match self.tasks.lock().get(&sid) {
Some(task) => task.abort_task(),
None => Err(TaskNotFound.into()),
}
}

pub fn remove_task(
&self,
sid: &ProstTaskId,
) -> Result<Option<Box<BatchTaskExecution<ComputeNodeContext>>>> {
) -> Result<Option<Arc<BatchTaskExecution<ComputeNodeContext>>>> {
let task_id = TaskId::from(sid);
match self.tasks.lock().remove(&task_id) {
Some(t) => Ok(Some(t)),
@@ -92,6 +100,32 @@ impl BatchManager {
}
}

pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
match self.tasks.lock().get(task_id) {
Some(task) => task.check_if_aborted(),
None => Err(TaskNotFound.into()),
}
}

#[cfg(test)]
async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
use std::time::Duration;
loop {
match self.tasks.lock().get(task_id) {
Some(task) => {
let ret = task.check_if_aborted();
match ret {
Ok(true) => return Ok(()),
Ok(false) => {}
Err(err) => return Err(err),
}
}
None => return Err(TaskNotFound.into()),
}
tokio::time::sleep(Duration::from_millis(100)).await
}
}

pub fn get_error(&self, task_id: &TaskId) -> Result<Option<RwError>> {
Ok(self
.tasks
@@ -110,9 +144,13 @@ impl Default for BatchManager {

#[cfg(test)]
mod tests {
use risingwave_expr::expr::make_i32_literal;
use risingwave_pb::batch_plan::exchange_info::DistributionMode;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::TaskOutputId as ProstTaskOutputId;
use risingwave_pb::batch_plan::{
ExchangeInfo, GenerateSeriesNode, PlanFragment, PlanNode, TaskId as ProstTaskId,
TaskOutputId as ProstTaskOutputId, ValuesNode,
};
use tonic::Code;

use crate::task::{BatchManager, ComputeNodeContext, TaskId};
@@ -151,8 +189,6 @@ mod tests {

#[tokio::test]
async fn test_task_id_conflict() {
use risingwave_pb::batch_plan::*;

let manager = BatchManager::new();
let plan = PlanFragment {
root: Some(PlanNode {
@@ -169,8 +205,10 @@ impl BatchManager {
}),
};
let context = ComputeNodeContext::new_for_test();
let task_id = TaskId {
..Default::default()
let task_id = ProstTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager
.fire_task(&task_id, plan.clone(), 0, context.clone())
@@ -180,4 +218,39 @@ mod tests {
.to_string()
.contains("can not create duplicate task with the same id"));
}

#[tokio::test]
async fn test_task_aborted() {
let manager = BatchManager::new();
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
identity: "".to_string(),
node_body: Some(NodeBody::GenerateSeries(GenerateSeriesNode {
start: Some(make_i32_literal(1)),
// This is a bit hacky as we want to make sure the task lasts long enough
// for us to abort it.
stop: Some(make_i32_literal(i32::MAX)),
step: Some(make_i32_literal(1)),
})),
}),
exchange_info: Some(ExchangeInfo {
mode: DistributionMode::Single as i32,
distribution: None,
}),
};
let context = ComputeNodeContext::new_for_test();
let task_id = ProstTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager
.fire_task(&task_id, plan.clone(), 0, context.clone())
.unwrap();
manager.abort_task(&task_id).unwrap();
let task_id = TaskId::from(&task_id);
let res = manager.wait_until_task_aborted(&task_id).await;
assert_eq!(res, Ok(()));
}
}