diff --git a/spec/std/channel_spec.cr b/spec/std/channel_spec.cr index c13571aec30e..c68cf19dbdff 100644 --- a/spec/std/channel_spec.cr +++ b/spec/std/channel_spec.cr @@ -137,6 +137,29 @@ describe Channel::Unbuffered do spawn { ch.send 123 } ch.receive?.should eq(123) end + + it "wakes up the sender fiber when channel is closed" do + ch = Channel::Unbuffered(Nil).new + sender_closed = false + spawn do + ch.send nil + ch.send nil + rescue Channel::ClosedError + sender_closed = true + end + receiver_closed = false + spawn do + Fiber.yield + ch.receive + rescue Channel::ClosedError + receiver_closed = true + end + Fiber.yield + ch.close + Fiber.yield + sender_closed.should be_true + receiver_closed.should be_true + end end describe Channel::Buffered do diff --git a/src/concurrent/channel.cr b/src/concurrent/channel.cr index c3eb1ccfb347..7e5c20857586 100644 --- a/src/concurrent/channel.cr +++ b/src/concurrent/channel.cr @@ -30,6 +30,8 @@ abstract class Channel(T) def close @closed = true + Scheduler.enqueue @senders + @senders.clear Scheduler.enqueue @receivers @receivers.clear nil @@ -259,6 +261,7 @@ class Channel::Unbuffered(T) < Channel(T) @value.tap do @has_value = false Scheduler.enqueue @sender.not_nil! + @sender = nil end end @@ -269,4 +272,12 @@ class Channel::Unbuffered(T) < Channel(T) def full? @has_value || @receivers.empty? end + + def close + super + if sender = @sender + Scheduler.enqueue sender + @sender = nil + end + end end