Skip to content

Commit

Permalink
Merge pull request typelevel#2458 from vasilmkd/piped-bug
Browse files Browse the repository at this point in the history
PipedStreamBuffer circular copying fix
  • Loading branch information
mpilquist authored Jul 2, 2021
2 parents f51e5bd + bfdd2a0 commit 2e3a96a
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 3 deletions.
62 changes: 60 additions & 2 deletions io/src/main/scala/fs2/io/internal/PipedStreamBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
// or just a part of it.
val toRead = math.min(available, length)
// Transfer the bytes to the provided byte array.
System.arraycopy(buffer, head % capacity, b, offset, toRead)
circularRead(buffer, head, capacity, b, offset, toRead)
// The bytes are marked as read by advancing the head of the
// circular buffer.
head += toRead
Expand Down Expand Up @@ -192,6 +192,35 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
override def available(): Int = self.synchronized {
if (closed) 0 else tail - head
}

/** Reads bytes from a circular buffer by copying them into a regular
* buffer.
*
* @param src the source circular buffer
* @param srcPos the offset into the source circular buffer
* @param srcCap the capacity of the source circular buffer
* @param dst the destination buffer
* @param dstPos the offset into the destination buffer
* @param length the number of bytes to be transferred
*/
private[this] def circularRead(
src: Array[Byte],
srcPos: Int,
srcCap: Int,
dst: Array[Byte],
dstPos: Int,
length: Int
): Unit = {
val srcOffset = srcPos % srcCap
if (srcOffset + length >= srcCap) {
val batch1 = srcCap - srcOffset
val batch2 = length - batch1
System.arraycopy(src, srcOffset, dst, dstPos, batch1)
System.arraycopy(src, 0, dst, dstPos + batch1, batch2)
} else {
System.arraycopy(src, srcOffset, dst, dstPos, length)
}
}
}

val outputStream: OutputStream = new OutputStream {
Expand Down Expand Up @@ -270,7 +299,7 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
// or just a part of it.
val toWrite = math.min(available, length)
// Transfer the bytes to the provided byte array.
System.arraycopy(b, offset, buffer, tail % capacity, toWrite)
circularWrite(b, offset, buffer, tail, capacity, toWrite)
// The bytes are marked as written by advancing the tail of the
// circular buffer.
tail += toWrite
Expand Down Expand Up @@ -316,5 +345,34 @@ private[io] final class PipedStreamBuffer(private[this] val capacity: Int) { sel
readerPermit.release()
}
}

/** Writes bytes into a circular buffer by copying them from a regular
* buffer.
*
* @param src the source buffer
* @param srcPos the offset into the source buffer
* @param dst the destination circular buffer
* @param dstPos the offset into the destination circular buffer
* @param dstCap the capacity of the destination circular buffer
* @param length the number of bytes to be transferred
*/
private[this] def circularWrite(
src: Array[Byte],
srcPos: Int,
dst: Array[Byte],
dstPos: Int,
dstCap: Int,
length: Int
): Unit = {
val dstOffset = dstPos % dstCap
if (dstOffset + length >= dstCap) {
val batch1 = dstCap - dstOffset
val batch2 = length - batch1
System.arraycopy(src, srcPos, dst, dstOffset, batch1)
System.arraycopy(src, srcPos + batch1, dst, 0, batch2)
} else {
System.arraycopy(src, srcPos, dst, dstOffset, length)
}
}
}
}
39 changes: 38 additions & 1 deletion io/src/test/scala/fs2/io/IoSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.io
package fs2
package io

import java.io.{ByteArrayInputStream, InputStream, OutputStream}
import java.util.concurrent.Executors
Expand Down Expand Up @@ -168,6 +169,42 @@ class IoSuite extends Fs2Suite {
}.compile.drain.map(_ => assert(true))
}
}

test("different chunk sizes function correctly") {

def test(blocker: Blocker, chunkSize: Int): Pipe[IO, Byte, Byte] = source => {
readOutputStream(blocker, chunkSize) { os =>
source.through(writeOutputStream(IO.delay(os), blocker, true)).compile.drain
}
}

def source(chunkSize: Int, bufferSize: Int): Stream[Pure, Byte] =
Stream.range(65, 75).map(_.toByte).repeat.take(chunkSize.toLong * 2).buffer(bufferSize)

forAllF { (chunkSize0: Int, bufferSize0: Int) =>
val chunkSize = (chunkSize0 % 512).abs + 1
val bufferSize = (bufferSize0 % 511).abs + 1

val src = source(chunkSize, bufferSize)

Blocker[IO].use { blocker =>
src
.through(text.utf8Decode)
.foldMonoid
.flatMap { expected =>
src
.through(test(blocker, chunkSize))
.through(text.utf8Decode)
.foldMonoid
.evalMap { actual =>
IO(assertEquals(actual, expected))
}
}
.compile
.drain
}
}
}
}

group("unsafeReadInputStream") {
Expand Down

0 comments on commit 2e3a96a

Please sign in to comment.