From 8881bc78e2243f5e537b6601cab0b6d7af8f29a4 Mon Sep 17 00:00:00 2001 From: LIU Zhi Date: Tue, 24 May 2022 12:41:38 +0800 Subject: [PATCH 1/3] feat(batch): abort task --- proto/task_service.proto | 11 ++- src/batch/src/rpc/service/task_service.rs | 34 ++++++- src/batch/src/task/task_.rs | 88 ++++++++++++++++--- src/batch/src/task/task_manager.rs | 78 +++++++++++++--- src/expr/src/expr/expr_field.rs | 30 +------ src/expr/src/expr/mod.rs | 2 +- src/expr/src/expr/test_utils.rs | 30 ++++++- .../src/executor/managed_state/join/mod.rs | 9 +- 8 files changed, 221 insertions(+), 61 deletions(-) diff --git a/proto/task_service.proto b/proto/task_service.proto index a33ae407b7b6..4333279ab6db 100644 --- a/proto/task_service.proto +++ b/proto/task_service.proto @@ -32,6 +32,7 @@ message TaskInfo { CANCELLING = 4; FINISHED = 5; FAILED = 6; + ABORTED = 7; } batch_plan.TaskId task_id = 1; TaskStatus task_status = 2; @@ -49,13 +50,20 @@ message CreateTaskResponse { message AbortTaskRequest { batch_plan.TaskId task_id = 1; - bool force = 2; } message AbortTaskResponse { common.Status status = 1; } +message RemoveTaskRequest { + batch_plan.TaskId task_id = 1; +} + +message RemoveTaskResponse { + common.Status status = 1; +} + message GetTaskInfoRequest { batch_plan.TaskId task_id = 1; } @@ -79,6 +87,7 @@ service TaskService { rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse); rpc GetTaskInfo(GetTaskInfoRequest) returns (GetTaskInfoResponse); rpc AbortTask(AbortTaskRequest) returns (AbortTaskResponse); + rpc RemoveTask(RemoveTaskRequest) returns (RemoveTaskResponse); } message GetDataRequest { diff --git a/src/batch/src/rpc/service/task_service.rs b/src/batch/src/rpc/service/task_service.rs index c2b2775640ac..f51e83d95c90 100644 --- a/src/batch/src/rpc/service/task_service.rs +++ b/src/batch/src/rpc/service/task_service.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use risingwave_pb::task_service::task_service_server::TaskService; use risingwave_pb::task_service::{ AbortTaskRequest, AbortTaskResponse, CreateTaskRequest, CreateTaskResponse, GetTaskInfoRequest, - GetTaskInfoResponse, + GetTaskInfoResponse, RemoveTaskRequest, RemoveTaskResponse, }; use tonic::{Request, Response, Status}; @@ -70,8 +70,36 @@ impl TaskService for BatchServiceImpl { #[cfg_attr(coverage, no_coverage)] async fn abort_task( &self, - _: Request, + req: Request, ) -> Result, Status> { - todo!() + let req = req.into_inner(); + let res = self + .mgr + .abort_task(req.get_task_id().expect("no task id found")); + match res { + Ok(_) => Ok(Response::new(AbortTaskResponse { status: None })), + Err(e) => { + error!("failed to abort task {}", e); + Err(e.to_grpc_status()) + } + } + } + + #[cfg_attr(coverage, no_coverage)] + async fn remove_task( + &self, + req: Request, + ) -> Result, Status> { + let req = req.into_inner(); + let res = self + .mgr + .remove_task(req.get_task_id().expect("no task id found")); + match res { + Ok(_) => Ok(Response::new(RemoveTaskResponse { status: None })), + Err(e) => { + error!("failed to remove task {}", e); + Err(e.to_grpc_status()) + } + } } } diff --git a/src/batch/src/task/task_.rs b/src/batch/src/task/task_.rs index 669956777182..22ddca8e8ab5 100644 --- a/src/batch/src/task/task_.rs +++ b/src/batch/src/task/task_.rs @@ -15,7 +15,7 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use futures_async_stream::for_await; +use futures::{Stream, StreamExt}; use parking_lot::Mutex; use risingwave_common::array::DataChunk; use risingwave_common::error::{ErrorCode, Result, RwError}; @@ -24,6 +24,8 @@ use risingwave_pb::batch_plan::{ }; use risingwave_pb::task_service::task_info::TaskStatus; use risingwave_pb::task_service::GetDataResponse; +use tokio::sync::mpsc::UnboundedSender; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing_futures::Instrument; use crate::executor::ExecutorBuilder; @@ -182,6 +184,9 @@ pub struct BatchTaskExecution { /// The execution failure. failure: Arc>>, + /// Shutdown signal sender. + shutdown_tx: Mutex>>, + epoch: u64, } @@ -200,6 +205,7 @@ impl BatchTaskExecution { context, failure: Arc::new(Mutex::new(None)), epoch, + shutdown_tx: Mutex::new(None), }) } @@ -213,7 +219,7 @@ impl BatchTaskExecution { /// hash partitioned across multiple channels. /// To obtain the result, one must pick one of the channels to consume via [`TaskOutputId`]. As /// such, parallel consumers are able to consume the result idependently. - pub fn async_execute(&self) -> Result<()> { + pub fn async_execute(self: Arc) -> Result<()> { trace!( "Prepare executing plan [{:?}]: {}", self.task_id, @@ -229,6 +235,9 @@ impl BatchTaskExecution { .build2()?; let (sender, receivers) = create_output_channel(self.plan.get_exchange_info()?)?; + let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::unbounded_channel::(); + *self.shutdown_tx.lock() = Some(shutdown_tx); + let shutdown_rx = UnboundedReceiverStream::new(shutdown_rx); self.receivers .lock() .extend(receivers.into_iter().map(Some)); @@ -243,7 +252,8 @@ impl BatchTaskExecution { let join_handle = tokio::spawn(async move { // We should only pass a reference of sender to execution because we should only // close it after task error has been set. - if let Err(e) = try_execute(exec, &mut sender) + if let Err(e) = self + .try_execute(exec, &mut sender, shutdown_rx) .instrument(tracing::trace_span!( "batch_execute", task_id = ?task_id.task_id, @@ -255,6 +265,7 @@ impl BatchTaskExecution { // Prints the entire backtrace of error. error!("Execution failed [{:?}]: {:?}", &task_id, &e); *failure.lock() = Some(e); + *self.state.lock() = TaskStatus::Failed; } }); @@ -265,6 +276,60 @@ impl BatchTaskExecution { Ok(()) } + pub async fn try_execute>( + &self, + root: BoxedExecutor2, + sender: &mut ChanSenderImpl, + shutdown_rx: S, + ) -> Result<()> { + let mut shutdown_stream = Box::pin(shutdown_rx); + let mut data_chunk_stream = root.execute(); + loop { + tokio::select! { + // We prioritize abort signal over normal data chunks. + biased; + _ = shutdown_stream.next() => { + sender.send(None).await?; + *self.state.lock() = TaskStatus::Aborted; + break; + } + res = data_chunk_stream.next() => { + match res { + Some(data_chunk) => { + sender.send(Some(data_chunk?)).await?; + } + None => { + debug!("data chunk stream shuts down"); + sender.send(None).await?; + break; + } + } + } + } + } + Ok(()) + } + + pub fn abort_task(&self) -> Result<()> { + self.shutdown_tx + .lock() + .as_mut() + .ok_or_else(|| { + ErrorCode::InternalError(format!( + "Task{:?}'s shutdown channel does not exist.", + self.task_id + )) + })? + .send(0) + .map_err(|err| { + ErrorCode::InternalError(format!( + "Task{:?};s shutdown channel send error:{:?}", + self.task_id, err + )) + .into() + }) + } + pub fn get_task_output(&self, output_id: &ProstOutputId) -> Result> { let task_id = TaskId::from(output_id.get_task_id()?); let receiver = self.receivers.lock()[output_id.get_output_id() as usize] @@ -298,18 +363,17 @@ impl BatchTaskExecution { } Ok(()) } -} -pub async fn try_execute(root: BoxedExecutor2, sender: &mut ChanSenderImpl) -> Result<()> { - #[for_await] - for chunk in root.execute() { - let chunk = chunk?; - if chunk.cardinality() > 0 { - sender.send(Some(chunk)).await?; + pub fn check_if_aborted(&self) -> Result<()> { + if *self.state.lock() != TaskStatus::Aborted { + return Err(ErrorCode::InternalError(format!( + "task {:?} has not been aborted", + self.get_task_id() + )) + .into()); } + Ok(()) } - sender.send(None).await?; - Ok(()) } #[cfg(test)] diff --git a/src/batch/src/task/task_manager.rs b/src/batch/src/task/task_manager.rs index 4461f8373e5d..765054c7e221 100644 --- a/src/batch/src/task/task_manager.rs +++ b/src/batch/src/task/task_manager.rs @@ -28,7 +28,7 @@ use crate::task::{BatchTaskExecution, ComputeNodeContext, TaskId, TaskOutput}; #[derive(Clone)] pub struct BatchManager { /// Every task id has a corresponding task execution. - tasks: Arc>>>>, + tasks: Arc>>>>, } impl BatchManager { @@ -48,10 +48,11 @@ impl BatchManager { trace!("Received task id: {:?}, plan: {:?}", tid, plan); let task = BatchTaskExecution::new(tid, plan, context, epoch)?; let task_id = task.get_task_id().clone(); + let task = Arc::new(task); - task.async_execute()?; + task.clone().async_execute()?; if let hash_map::Entry::Vacant(e) = self.tasks.lock().entry(task_id.clone()) { - e.insert(Box::new(task)); + e.insert(task); Ok(()) } else { Err(ErrorCode::InternalError(format!( @@ -72,11 +73,18 @@ impl BatchManager { .get_task_output(output_id) } - #[cfg(test)] + pub fn abort_task(&self, sid: &ProstTaskId) -> Result<()> { + let sid = TaskId::from(sid); + match self.tasks.lock().get(&sid) { + Some(task) => task.abort_task(), + None => Err(TaskNotFound.into()), + } + } + pub fn remove_task( &self, sid: &ProstTaskId, - ) -> Result>>> { + ) -> Result>>> { let task_id = TaskId::from(sid); match self.tasks.lock().remove(&task_id) { Some(t) => Ok(Some(t)), @@ -92,6 +100,13 @@ impl BatchManager { } } + pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<()> { + match self.tasks.lock().get(task_id) { + Some(task) => task.check_if_aborted(), + None => Err(TaskNotFound.into()), + } + } + pub fn get_error(&self, task_id: &TaskId) -> Result> { Ok(self .tasks @@ -110,9 +125,15 @@ impl Default for BatchManager { #[cfg(test)] mod tests { + use std::time::Duration; + + use risingwave_expr::expr::make_i32_literal; use risingwave_pb::batch_plan::exchange_info::DistributionMode; use risingwave_pb::batch_plan::plan_node::NodeBody; - use risingwave_pb::batch_plan::TaskOutputId as ProstTaskOutputId; + use risingwave_pb::batch_plan::{ + ExchangeInfo, GenerateSeriesNode, PlanFragment, PlanNode, TaskId as ProstTaskId, + TaskOutputId as ProstTaskOutputId, ValuesNode, + }; use tonic::Code; use crate::task::{BatchManager, ComputeNodeContext, TaskId}; @@ -151,8 +172,6 @@ mod tests { #[tokio::test] async fn test_task_id_conflict() { - use risingwave_pb::batch_plan::*; - let manager = BatchManager::new(); let plan = PlanFragment { root: Some(PlanNode { @@ -169,8 +188,10 @@ mod tests { }), }; let context = ComputeNodeContext::new_for_test(); - let task_id = TaskId { - ..Default::default() + let task_id = ProstTaskId { + query_id: "".to_string(), + stage_id: 0, + task_id: 0, }; manager .fire_task(&task_id, plan.clone(), 0, context.clone()) @@ -180,4 +201,41 @@ mod tests { .to_string() .contains("can not create duplicate task with the same id")); } + + #[tokio::test] + async fn test_task_aborted() { + let manager = BatchManager::new(); + let plan = PlanFragment { + root: Some(PlanNode { + children: vec![], + identity: "".to_string(), + node_body: Some(NodeBody::GenerateSeries(GenerateSeriesNode { + start: Some(make_i32_literal(1)), + // This is a bit hacky as we want to make sure the task lasts long enough + // for us to abort it. + stop: Some(make_i32_literal(i32::MAX)), + step: Some(make_i32_literal(1)), + })), + }), + exchange_info: Some(ExchangeInfo { + mode: DistributionMode::Single as i32, + distribution: None, + }), + }; + let context = ComputeNodeContext::new_for_test(); + let task_id = ProstTaskId { + query_id: "".to_string(), + stage_id: 0, + task_id: 0, + }; + manager + .fire_task(&task_id, plan.clone(), 0, context.clone()) + .unwrap(); + manager.abort_task(&task_id).unwrap(); + let task_id = TaskId::from(&task_id); + while let Err(e) = manager.check_if_task_aborted(&task_id) { + println!("check_if_task_aborted:{}", e); + tokio::time::sleep(Duration::from_millis(50)).await; + } + } } diff --git a/src/expr/src/expr/expr_field.rs b/src/expr/src/expr/expr_field.rs index a1d97786150a..9fbd33a886a2 100644 --- a/src/expr/src/expr/expr_field.rs +++ b/src/expr/src/expr/expr_field.rs @@ -87,39 +87,11 @@ mod tests { use risingwave_common::array::{DataChunk, F32Array, I32Array, StructArray}; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::DataType as ProstDataType; - use risingwave_pb::expr::expr_node::Type::Field; - use risingwave_pb::expr::expr_node::{RexNode, Type}; - use risingwave_pb::expr::{ConstantValue, ExprNode, FunctionCall}; use crate::expr::expr_field::FieldExpression; - use crate::expr::test_utils::make_input_ref; + use crate::expr::test_utils::{make_field_function, make_i32_literal, make_input_ref}; use crate::expr::Expression; - pub fn make_i32_literal(data: i32) -> ExprNode { - ExprNode { - expr_type: Type::ConstantValue as i32, - return_type: Some(ProstDataType { - type_name: TypeName::Int32 as i32, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(ConstantValue { - body: data.to_be_bytes().to_vec(), - })), - } - } - - pub fn make_field_function(children: Vec, ret: TypeName) -> ExprNode { - ExprNode { - expr_type: Field as i32, - return_type: Some(ProstDataType { - type_name: ret as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(FunctionCall { children })), - } - } - #[test] fn test_field_expr() { let input_node = make_input_ref(0, TypeName::Struct); diff --git a/src/expr/src/expr/mod.rs b/src/expr/src/expr/mod.rs index 0a9067f8b120..5e623d1a275b 100644 --- a/src/expr/src/expr/mod.rs +++ b/src/expr/src/expr/mod.rs @@ -128,5 +128,5 @@ impl RowExpression { } } -#[cfg(test)] mod test_utils; +pub use test_utils::*; diff --git a/src/expr/src/expr/test_utils.rs b/src/expr/src/expr/test_utils.rs index 3c063b21038f..58d072d98444 100644 --- a/src/expr/src/expr/test_utils.rs +++ b/src/expr/src/expr/test_utils.rs @@ -14,10 +14,10 @@ use itertools::Itertools; use risingwave_pb::data::data_type::TypeName; -use risingwave_pb::data::DataType; -use risingwave_pb::expr::expr_node::Type::InputRef; +use risingwave_pb::data::{DataType as ProstDataType, DataType}; +use risingwave_pb::expr::expr_node::Type::{Field, InputRef}; use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::{ExprNode, FunctionCall, InputRefExpr}; +use risingwave_pb::expr::{ConstantValue, ExprNode, FunctionCall, InputRefExpr}; pub fn make_expression(kind: Type, rets: &[TypeName], indices: &[i32]) -> ExprNode { let mut exprs = Vec::new(); @@ -46,3 +46,27 @@ pub fn make_input_ref(idx: i32, ret: TypeName) -> ExprNode { rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })), } } + +pub fn make_i32_literal(data: i32) -> ExprNode { + ExprNode { + expr_type: Type::ConstantValue as i32, + return_type: Some(ProstDataType { + type_name: TypeName::Int32 as i32, + ..Default::default() + }), + rex_node: Some(RexNode::Constant(ConstantValue { + body: data.to_be_bytes().to_vec(), + })), + } +} + +pub fn make_field_function(children: Vec, ret: TypeName) -> ExprNode { + ExprNode { + expr_type: Field as i32, + return_type: Some(ProstDataType { + type_name: ret as i32, + ..Default::default() + }), + rex_node: Some(RexNode::FuncCall(FunctionCall { children })), + } +} diff --git a/src/stream/src/executor/managed_state/join/mod.rs b/src/stream/src/executor/managed_state/join/mod.rs index d8cd48c3c20e..ea832d0f1dd5 100644 --- a/src/stream/src/executor/managed_state/join/mod.rs +++ b/src/stream/src/executor/managed_state/join/mod.rs @@ -190,7 +190,9 @@ impl JoinHashMap { /// Returns a mutable reference to the value of the key in the memory, if does not exist, look /// up in remote storage and return, if still not exist, return None. - pub async fn get_mut<'a>(&'a mut self, key: &K) -> Option<&'a mut HashValueType> { + /// FIXME(lmatz): Lifetime 'b is added due to some weird errors, possibly compiler error. May + /// double check after bumping the toolchain. + pub async fn get_mut<'a, 'b>(&'a mut self, key: &'b K) -> Option<&'a mut HashValueType> { let state = self.inner.get(key); // TODO: we should probably implement a entry function for `LruCache` match state { @@ -280,7 +282,10 @@ impl JoinHashMap { /// Get or create a [`JoinEntryState`] without cached state. Should only be called if the key /// does not exist in memory or remote storage. - pub async fn get_or_init_without_cache(&mut self, key: &K) -> RwResult<&mut JoinEntryState> { + pub async fn get_or_init_without_cache<'a>( + &'a mut self, + key: &K, + ) -> RwResult<&'a mut JoinEntryState> { // TODO: we should probably implement a entry function for `LruCache` let contains = self.inner.contains(key); if contains { From 4017943d96122b2c3d7d76672a697e124ab2f707 Mon Sep 17 00:00:00 2001 From: LIU Zhi Date: Wed, 25 May 2022 12:21:03 +0800 Subject: [PATCH 2/3] revision --- proto/task_service.proto | 1 + src/batch/src/task/task_.rs | 63 ++++++++++++++---------------- src/batch/src/task/task_manager.rs | 28 +++++++++---- 3 files changed, 52 insertions(+), 40 deletions(-) diff --git a/proto/task_service.proto b/proto/task_service.proto index 4333279ab6db..fe51408b30b6 100644 --- a/proto/task_service.proto +++ b/proto/task_service.proto @@ -33,6 +33,7 @@ message TaskInfo { FINISHED = 5; FAILED = 6; ABORTED = 7; + ABORTING = 8; } batch_plan.TaskId task_id = 1; TaskStatus task_status = 2; diff --git a/src/batch/src/task/task_.rs b/src/batch/src/task/task_.rs index 22ddca8e8ab5..9336b74bce4c 100644 --- a/src/batch/src/task/task_.rs +++ b/src/batch/src/task/task_.rs @@ -15,7 +15,7 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use futures::{Stream, StreamExt}; +use futures::StreamExt; use parking_lot::Mutex; use risingwave_common::array::DataChunk; use risingwave_common::error::{ErrorCode, Result, RwError}; @@ -24,8 +24,7 @@ use risingwave_pb::batch_plan::{ }; use risingwave_pb::task_service::task_info::TaskStatus; use risingwave_pb::task_service::GetDataResponse; -use tokio::sync::mpsc::UnboundedSender; -use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio::sync::oneshot::{Receiver, Sender}; use tracing_futures::Instrument; use crate::executor::ExecutorBuilder; @@ -185,7 +184,7 @@ pub struct BatchTaskExecution { failure: Arc>>, /// Shutdown signal sender. - shutdown_tx: Mutex>>, + shutdown_tx: Mutex>>, epoch: u64, } @@ -235,9 +234,8 @@ impl BatchTaskExecution { .build2()?; let (sender, receivers) = create_output_channel(self.plan.get_exchange_info()?)?; - let (shutdown_tx, shutdown_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::(); *self.shutdown_tx.lock() = Some(shutdown_tx); - let shutdown_rx = UnboundedReceiverStream::new(shutdown_rx); self.receivers .lock() .extend(receivers.into_iter().map(Some)); @@ -276,19 +274,18 @@ impl BatchTaskExecution { Ok(()) } - pub async fn try_execute>( + pub async fn try_execute( &self, root: BoxedExecutor2, sender: &mut ChanSenderImpl, - shutdown_rx: S, + mut shutdown_rx: Receiver, ) -> Result<()> { - let mut shutdown_stream = Box::pin(shutdown_rx); let mut data_chunk_stream = root.execute(); loop { tokio::select! { // We prioritize abort signal over normal data chunks. biased; - _ = shutdown_stream.next() => { + _ = &mut shutdown_rx => { sender.send(None).await?; *self.state.lock() = TaskStatus::Aborted; break; @@ -311,23 +308,22 @@ impl BatchTaskExecution { } pub fn abort_task(&self) -> Result<()> { - self.shutdown_tx - .lock() - .as_mut() - .ok_or_else(|| { - ErrorCode::InternalError(format!( - "Task{:?}'s shutdown channel does not exist.", - self.task_id - )) - })? - .send(0) - .map_err(|err| { - ErrorCode::InternalError(format!( - "Task{:?};s shutdown channel send error:{:?}", - self.task_id, err - )) - .into() - }) + let sender = self.shutdown_tx.lock().take().ok_or_else(|| { + ErrorCode::InternalError(format!( + "Task{:?}'s shutdown channel does not exist. \ + Either the task has been aborted once, \ + or the channel has neven been initialized.", + self.task_id + )) + })?; + *self.state.lock() = TaskStatus::Aborting; + sender.send(0).map_err(|err| { + ErrorCode::InternalError(format!( + "Task{:?};s shutdown channel send error:{:?}", + self.task_id, err + )) + .into() + }) } pub fn get_task_output(&self, output_id: &ProstOutputId) -> Result> { @@ -364,15 +360,16 @@ impl BatchTaskExecution { Ok(()) } - pub fn check_if_aborted(&self) -> Result<()> { - if *self.state.lock() != TaskStatus::Aborted { - return Err(ErrorCode::InternalError(format!( - "task {:?} has not been aborted", + pub fn check_if_aborted(&self) -> Result { + match *self.state.lock() { + TaskStatus::Aborted => Ok(true), + TaskStatus::Finished => Err(ErrorCode::InternalError(format!( + "task {:?} has been finished", self.get_task_id() )) - .into()); + .into()), + _ => Ok(false), } - Ok(()) } } diff --git a/src/batch/src/task/task_manager.rs b/src/batch/src/task/task_manager.rs index 765054c7e221..98fdfd02dd64 100644 --- a/src/batch/src/task/task_manager.rs +++ b/src/batch/src/task/task_manager.rs @@ -14,6 +14,7 @@ use std::collections::{hash_map, HashMap}; use std::sync::Arc; +use std::time::Duration; use parking_lot::Mutex; use risingwave_common::error::ErrorCode::{self, TaskNotFound}; @@ -100,13 +101,30 @@ impl BatchManager { } } - pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result<()> { + pub fn check_if_task_aborted(&self, task_id: &TaskId) -> Result { match self.tasks.lock().get(task_id) { Some(task) => task.check_if_aborted(), None => Err(TaskNotFound.into()), } } + pub async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> { + loop { + match self.tasks.lock().get(task_id) { + Some(task) => { + let ret = task.check_if_aborted(); + match ret { + Ok(true) => return Ok(()), + Ok(false) => {} + Err(err) => return Err(err), + } + } + None => return Err(TaskNotFound.into()), + } + tokio::time::sleep(Duration::from_millis(100)).await + } + } + pub fn get_error(&self, task_id: &TaskId) -> Result> { Ok(self .tasks @@ -125,8 +143,6 @@ impl Default for BatchManager { #[cfg(test)] mod tests { - use std::time::Duration; - use risingwave_expr::expr::make_i32_literal; use risingwave_pb::batch_plan::exchange_info::DistributionMode; use risingwave_pb::batch_plan::plan_node::NodeBody; @@ -233,9 +249,7 @@ mod tests { .unwrap(); manager.abort_task(&task_id).unwrap(); let task_id = TaskId::from(&task_id); - while let Err(e) = manager.check_if_task_aborted(&task_id) { - println!("check_if_task_aborted:{}", e); - tokio::time::sleep(Duration::from_millis(50)).await; - } + let res = manager.wait_until_task_aborted(&task_id).await; + assert_eq!(res, Ok(())); } } From c1cbc14cab98bdc86d17d9b27f0635a675ed718c Mon Sep 17 00:00:00 2001 From: LIU Zhi Date: Wed, 25 May 2022 13:17:06 +0800 Subject: [PATCH 3/3] make wait_until_task_aborted private and test-only --- src/batch/src/task/task_manager.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/batch/src/task/task_manager.rs b/src/batch/src/task/task_manager.rs index 98fdfd02dd64..bba178158dea 100644 --- a/src/batch/src/task/task_manager.rs +++ b/src/batch/src/task/task_manager.rs @@ -14,7 +14,6 @@ use std::collections::{hash_map, HashMap}; use std::sync::Arc; -use std::time::Duration; use parking_lot::Mutex; use risingwave_common::error::ErrorCode::{self, TaskNotFound}; @@ -108,7 +107,9 @@ impl BatchManager { } } - pub async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> { + #[cfg(test)] + async fn wait_until_task_aborted(&self, task_id: &TaskId) -> Result<()> { + use std::time::Duration; loop { match self.tasks.lock().get(task_id) { Some(task) => {