Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(batch): abort task #2757

Merged
merged 3 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion proto/task_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ message TaskInfo {
CANCELLING = 4;
FINISHED = 5;
FAILED = 6;
ABORTED = 7;
ABORTING = 8;
}
batch_plan.TaskId task_id = 1;
TaskStatus task_status = 2;
Expand All @@ -49,13 +51,20 @@ message CreateTaskResponse {

message AbortTaskRequest {
batch_plan.TaskId task_id = 1;
bool force = 2;
}

// Response returned by the `AbortTask` RPC.
message AbortTaskResponse {
// Outcome of the abort request.
common.Status status = 1;
}

// Request for the `RemoveTask` RPC.
message RemoveTaskRequest {
// Identifies the task the request refers to.
batch_plan.TaskId task_id = 1;
}

// Response returned by the `RemoveTask` RPC.
message RemoveTaskResponse {
// Outcome of the remove request.
common.Status status = 1;
}

message GetTaskInfoRequest {
batch_plan.TaskId task_id = 1;
}
Expand All @@ -79,6 +88,7 @@ service TaskService {
rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse);
rpc GetTaskInfo(GetTaskInfoRequest) returns (GetTaskInfoResponse);
rpc AbortTask(AbortTaskRequest) returns (AbortTaskResponse);
rpc RemoveTask(RemoveTaskRequest) returns (RemoveTaskResponse);
}

message GetDataRequest {
Expand Down
34 changes: 31 additions & 3 deletions src/batch/src/rpc/service/task_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;
use risingwave_pb::task_service::task_service_server::TaskService;
use risingwave_pb::task_service::{
AbortTaskRequest, AbortTaskResponse, CreateTaskRequest, CreateTaskResponse, GetTaskInfoRequest,
GetTaskInfoResponse,
GetTaskInfoResponse, RemoveTaskRequest, RemoveTaskResponse,
};
use tonic::{Request, Response, Status};

Expand Down Expand Up @@ -70,8 +70,36 @@ impl TaskService for BatchServiceImpl {
#[cfg_attr(coverage, no_coverage)]
async fn abort_task(
&self,
_: Request<AbortTaskRequest>,
req: Request<AbortTaskRequest>,
) -> Result<Response<AbortTaskResponse>, Status> {
todo!()
let req = req.into_inner();
let res = self
.mgr
.abort_task(req.get_task_id().expect("no task id found"));
match res {
Ok(_) => Ok(Response::new(AbortTaskResponse { status: None })),
Err(e) => {
error!("failed to abort task {}", e);
Err(e.to_grpc_status())
}
}
}

#[cfg_attr(coverage, no_coverage)]
async fn remove_task(
&self,
req: Request<RemoveTaskRequest>,
) -> Result<Response<RemoveTaskResponse>, Status> {
let req = req.into_inner();
let res = self
.mgr
.remove_task(req.get_task_id().expect("no task id found"));
match res {
Ok(_) => Ok(Response::new(RemoveTaskResponse { status: None })),
Err(e) => {
error!("failed to remove task {}", e);
Err(e.to_grpc_status())
}
}
}
}
85 changes: 73 additions & 12 deletions src/batch/src/task/task_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use futures_async_stream::for_await;
use futures::StreamExt;
use parking_lot::Mutex;
use risingwave_common::array::DataChunk;
use risingwave_common::error::{ErrorCode, Result, RwError};
Expand All @@ -24,6 +24,7 @@ use risingwave_pb::batch_plan::{
};
use risingwave_pb::task_service::task_info::TaskStatus;
use risingwave_pb::task_service::GetDataResponse;
use tokio::sync::oneshot::{Receiver, Sender};
use tracing_futures::Instrument;

use crate::executor::ExecutorBuilder;
Expand Down Expand Up @@ -182,6 +183,9 @@ pub struct BatchTaskExecution<C> {
/// The execution failure.
failure: Arc<Mutex<Option<RwError>>>,

/// Shutdown signal sender.
shutdown_tx: Mutex<Option<Sender<u64>>>,

epoch: u64,
}

Expand All @@ -200,6 +204,7 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
context,
failure: Arc::new(Mutex::new(None)),
epoch,
shutdown_tx: Mutex::new(None),
})
}

Expand All @@ -213,7 +218,7 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
/// hash partitioned across multiple channels.
/// To obtain the result, one must pick one of the channels to consume via [`TaskOutputId`]. As
/// such, parallel consumers are able to consume the result independently.
pub fn async_execute(&self) -> Result<()> {
pub fn async_execute(self: Arc<Self>) -> Result<()> {
trace!(
"Prepare executing plan [{:?}]: {}",
self.task_id,
Expand All @@ -229,6 +234,8 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
.build2()?;

let (sender, receivers) = create_output_channel(self.plan.get_exchange_info()?)?;
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<u64>();
*self.shutdown_tx.lock() = Some(shutdown_tx);
self.receivers
.lock()
.extend(receivers.into_iter().map(Some));
Expand All @@ -243,7 +250,8 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
let join_handle = tokio::spawn(async move {
// We should only pass a reference of sender to execution because we should only
// close it after task error has been set.
if let Err(e) = try_execute(exec, &mut sender)
if let Err(e) = self
.try_execute(exec, &mut sender, shutdown_rx)
.instrument(tracing::trace_span!(
"batch_execute",
task_id = ?task_id.task_id,
Expand All @@ -255,6 +263,7 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
// Prints the entire backtrace of error.
error!("Execution failed [{:?}]: {:?}", &task_id, &e);
*failure.lock() = Some(e);
*self.state.lock() = TaskStatus::Failed;
}
});

Expand All @@ -265,6 +274,58 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
Ok(())
}

/// Drives the `root` executor and forwards its output to `sender` until the stream ends
/// or a shutdown signal arrives on `shutdown_rx`.
///
/// * Every chunk yielded by the executor stream is sent as `Some(chunk)`.
/// * When the stream is exhausted, `None` is sent and the loop ends.
/// * When `shutdown_rx` completes, `None` is sent, the task state is set to
///   [`TaskStatus::Aborted`], and the loop ends.
///
/// # Errors
///
/// Returns the first error produced by the executor stream or by `sender.send`.
pub async fn try_execute(
&self,
root: BoxedExecutor2,
sender: &mut ChanSenderImpl,
mut shutdown_rx: Receiver<u64>,
) -> Result<()> {
let mut data_chunk_stream = root.execute();
loop {
tokio::select! {
// We prioritize abort signal over normal data chunks.
biased;
// The received value is discarded (`_`), so this arm runs on any completion of the
// receiver. It is polled through `&mut` so the receiver survives across iterations.
_ = &mut shutdown_rx => {
// If this send fails, `?` returns before the state is set to `Aborted`.
sender.send(None).await?;
*self.state.lock() = TaskStatus::Aborted;
break;
}
res = data_chunk_stream.next() => {
match res {
Some(data_chunk) => {
// `data_chunk?` propagates an executor error before anything is sent.
sender.send(Some(data_chunk?)).await?;
}
None => {
debug!("data chunk stream shuts down");
// `None` marks the end of the output for the receiving side.
sender.send(None).await?;
break;
}
}
}
}
}
Ok(())
}

/// Requests that this task stop executing.
///
/// Takes the shutdown sender out of `shutdown_tx`, marks the task as
/// [`TaskStatus::Aborting`], and then sends the shutdown signal through the channel.
///
/// # Errors
///
/// Returns an `InternalError` if no shutdown sender is stored (the task has already been
/// aborted once, or the channel was never initialized), or if the signal cannot be sent.
pub fn abort_task(&self) -> Result<()> {
    // `take()` leaves `None` behind, so a repeated abort request reports an error instead
    // of signalling a second time.
    let sender = self.shutdown_tx.lock().take().ok_or_else(|| {
        ErrorCode::InternalError(format!(
            "Task{:?}'s shutdown channel does not exist. \
            Either the task has been aborted once, \
            or the channel has never been initialized.",
            self.task_id
        ))
    })?;
    *self.state.lock() = TaskStatus::Aborting;
    sender.send(0).map_err(|err| {
        ErrorCode::InternalError(format!(
            "Task{:?}'s shutdown channel send error:{:?}",
            self.task_id, err
        ))
        .into()
    })
}

pub fn get_task_output(&self, output_id: &ProstOutputId) -> Result<TaskOutput<C>> {
let task_id = TaskId::from(output_id.get_task_id()?);
let receiver = self.receivers.lock()[output_id.get_output_id() as usize]
Expand Down Expand Up @@ -298,18 +359,18 @@ impl<C: BatchTaskContext> BatchTaskExecution<C> {
}
Ok(())
}
}

pub async fn try_execute(root: BoxedExecutor2, sender: &mut ChanSenderImpl) -> Result<()> {
#[for_await]
for chunk in root.execute() {
let chunk = chunk?;
if chunk.cardinality() > 0 {
sender.send(Some(chunk)).await?;
/// Reports whether this task has reached the `Aborted` state.
///
/// Returns `Ok(true)` when the state is `Aborted`, an `InternalError` when the state is
/// `Finished`, and `Ok(false)` for every other state.
pub fn check_if_aborted(&self) -> Result<bool> {
    let state = self.state.lock();
    if matches!(*state, TaskStatus::Finished) {
        return Err(ErrorCode::InternalError(format!(
            "task {:?} has been finished",
            self.get_task_id()
        ))
        .into());
    }
    Ok(matches!(*state, TaskStatus::Aborted))
}
sender.send(None).await?;
Ok(())
}

#[cfg(test)]
Expand Down
93 changes: 83 additions & 10 deletions src/batch/src/task/task_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::task::{BatchTaskExecution, ComputeNodeContext, TaskId, TaskOutput};
#[derive(Clone)]
pub struct BatchManager {
/// Every task id has a corresponding task execution.
tasks: Arc<Mutex<HashMap<TaskId, Box<BatchTaskExecution<ComputeNodeContext>>>>>,
tasks: Arc<Mutex<HashMap<TaskId, Arc<BatchTaskExecution<ComputeNodeContext>>>>>,
}

impl BatchManager {
Expand All @@ -48,10 +48,11 @@ impl BatchManager {
trace!("Received task id: {:?}, plan: {:?}", tid, plan);
let task = BatchTaskExecution::new(tid, plan, context, epoch)?;
let task_id = task.get_task_id().clone();
let task = Arc::new(task);

task.async_execute()?;
task.clone().async_execute()?;
if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) {
e.insert(Box::new(task));
e.insert(task);
Ok(())
} else {
Err(ErrorCode::InternalError(format!(
Expand All @@ -72,11 +73,18 @@ impl BatchManager {
.get_task_output(output_id)
}

#[cfg(test)]
/// Forwards an abort request to the task registered under `sid`.
///
/// Returns `TaskNotFound` when no such task is registered; otherwise returns whatever the
/// task's own `abort_task` returns.
pub fn abort_task(&self, sid: &ProstTaskId) -> Result<()> {
    let task_id = TaskId::from(sid);
    let tasks = self.tasks.lock();
    if let Some(task) = tasks.get(&task_id) {
        task.abort_task()
    } else {
        Err(TaskNotFound.into())
    }
}

pub fn remove_task(
&self,
sid: &ProstTaskId,
) -> Result<Option<Box<BatchTaskExecution<ComputeNodeContext>>>> {
) -> Result<Option<Arc<BatchTaskExecution<ComputeNodeContext>>>> {
let task_id = TaskId::from(sid);
match self.tasks.lock().remove(&task_id) {
Some(t) => Ok(Some(t)),
Expand All @@ -92,6 +100,32 @@ impl BatchManager {
}
}

/// Looks up `task_id` and reports whether that task has been aborted.
///
/// Returns `TaskNotFound` when the task is not registered; otherwise returns the result of
/// the task's `check_if_aborted`.
pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<bool> {
    let tasks = self.tasks.lock();
    tasks.get(task_id).map_or_else(
        || Err(TaskNotFound.into()),
        |task| task.check_if_aborted(),
    )
}

/// Test helper that polls every 100 ms until the task reports itself as aborted.
///
/// Returns `TaskNotFound` if the task is not registered, and propagates any error from the
/// task's `check_if_aborted`.
#[cfg(test)]
async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> {
    use std::time::Duration;
    loop {
        // The lock guard is a temporary of this statement, so it is released before the
        // sleep below.
        let aborted = match self.tasks.lock().get(task_id) {
            Some(task) => task.check_if_aborted()?,
            None => return Err(TaskNotFound.into()),
        };
        if aborted {
            return Ok(());
        }
        tokio::time::sleep(Duration::from_millis(100)).await
    }
}

pub fn get_error(&self, task_id: &TaskId) -> Result<Option<RwError>> {
Ok(self
.tasks
Expand All @@ -110,9 +144,13 @@ impl Default for BatchManager {

#[cfg(test)]
mod tests {
use risingwave_expr::expr::make_i32_literal;
use risingwave_pb::batch_plan::exchange_info::DistributionMode;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::TaskOutputId as ProstTaskOutputId;
use risingwave_pb::batch_plan::{
ExchangeInfo, GenerateSeriesNode, PlanFragment, PlanNode, TaskId as ProstTaskId,
TaskOutputId as ProstTaskOutputId, ValuesNode,
};
use tonic::Code;

use crate::task::{BatchManager, ComputeNodeContext, TaskId};
Expand Down Expand Up @@ -151,8 +189,6 @@ mod tests {

#[tokio::test]
async fn test_task_id_conflict() {
use risingwave_pb::batch_plan::*;

let manager = BatchManager::new();
let plan = PlanFragment {
root: Some(PlanNode {
Expand All @@ -169,8 +205,10 @@ mod tests {
}),
};
let context = ComputeNodeContext::new_for_test();
let task_id = TaskId {
..Default::default()
let task_id = ProstTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager
.fire_task(&task_id, plan.clone(), 0, context.clone())
Expand All @@ -180,4 +218,39 @@ mod tests {
.to_string()
.contains("can not create duplicate task with the same id"));
}

#[tokio::test]
async fn test_task_aborted() {
let manager = BatchManager::new();
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
identity: "".to_string(),
node_body: Some(NodeBody::GenerateSeries(GenerateSeriesNode {
start: Some(make_i32_literal(1)),
// This is a bit hacky as we want to make sure the task lasts long enough
// for us to abort it.
lmatz marked this conversation as resolved.
Show resolved Hide resolved
stop: Some(make_i32_literal(i32::MAX)),
step: Some(make_i32_literal(1)),
})),
}),
exchange_info: Some(ExchangeInfo {
mode: DistributionMode::Single as i32,
distribution: None,
}),
};
let context = ComputeNodeContext::new_for_test();
let task_id = ProstTaskId {
query_id: "".to_string(),
stage_id: 0,
task_id: 0,
};
manager
.fire_task(&task_id, plan.clone(), 0, context.clone())
.unwrap();
manager.abort_task(&task_id).unwrap();
let task_id = TaskId::from(&task_id);
let res = manager.wait_until_task_aborted(&task_id).await;
assert_eq!(res, Ok(()));
}
}
Loading