Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update KffFile to use FileChannelFactory instead of RandomAccessFile #527

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions src/main/java/emissary/kff/KffFile.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package emissary.kff;

import emissary.core.channels.FileChannelFactory;
import emissary.core.channels.SeekableByteChannelFactory;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Files;
import java.nio.file.Paths;
import javax.annotation.Nonnull;
Expand All @@ -31,16 +34,17 @@ public class KffFile implements KffFilter {
private final Logger logger;

/** File containing SHA-1/CRC32 results of known files */
protected RandomAccessFile knownFile;
// protected RandomAccessFile knownFile;
drivenflywheel marked this conversation as resolved.
Show resolved Hide resolved
protected SeekableByteChannelFactory knownFileFactory;

/** Byte buffer that is mapped to the above file */
protected ByteBuffer mappedBuf;

/** Initial value of high index for binary search */
private int bSearchInitHigh;

protected int RECORD_LENGTH = 24;
protected int recordLength = RECORD_LENGTH;
public static final int DEFAULT_RECORD_LENGTH = 24;
protected int recordLength = DEFAULT_RECORD_LENGTH;

/** String logical name for this filter */
protected String filterName = "UNKNOWN";
Expand Down Expand Up @@ -82,10 +86,11 @@ public KffFile(String filename, String filterName, FilterType ftype, int recordL
logger = LoggerFactory.getLogger(this.getClass());

// Open file in read-only mode
knownFile = new RandomAccessFile(filename, "r");
// knownFile = new RandomAccessFile(filename, "r");
knownFileFactory = FileChannelFactory.create(Paths.get(filename));

// Initial high value for binary search is the largest index
bSearchInitHigh = ((int) knownFile.length() / recordLength) - 1;
bSearchInitHigh = ((int) knownFileFactory.create().size() / recordLength) - 1;

logger.debug("KFF File {} has {} records", filename, (bSearchInitHigh + 1));
}
Expand Down Expand Up @@ -147,33 +152,34 @@ private boolean binaryFileSearch(@Nonnull byte[] hash, long crc) {

/* Buffer to hold a record */
byte[] rec = new byte[recordLength];

ByteBuffer byteBuffer = ByteBuffer.wrap(rec);
// Search until the indexes cross
while (low <= high) {
// Calculate the midpoint
int mid = (low + high) >> 1;
try (SeekableByteChannel knownFile = knownFileFactory.create()) {
while (low <= high) {
byteBuffer.clear();

try {
knownFile.seek(rec.length * (long) mid);
int count = knownFile.read(rec);
// Calculate the midpoint
int mid = (low + high) >> 1;

knownFile.position(rec.length * (long) mid);
int count = knownFile.read(byteBuffer);
drivenflywheel marked this conversation as resolved.
Show resolved Hide resolved
if (count != rec.length) {
logger.warn("Short read on KffFile at {} read {} expected {}", (rec.length * mid), count, rec.length);
logger.warn("Short read on KffFile at {} read {} expected {}", (recordLength * mid), count, recordLength);
return false;
}
} catch (IOException x) {
logger.warn("Exception reading KffFile: {}", x.getMessage());
return false;
}

// Compare the record with the target. Adjust the indexes accordingly.
int c = compare(rec, hash, crc);
if (c < 0) {
high = mid - 1;
} else if (c > 0) {
low = mid + 1;
} else {
return true;
// Compare the record with the target. Adjust the indexes accordingly.
int c = compare(rec, hash, crc);
if (c < 0) {
high = mid - 1;
} else if (c > 0) {
low = mid + 1;
} else {
return true;
}
}
} catch (IOException e) {
logger.warn("Exception reading KffFile: {}", e.getMessage());
return false;
}

// not found
Expand Down
169 changes: 166 additions & 3 deletions src/test/java/emissary/kff/KffFileTest.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,44 @@
package emissary.kff;

import emissary.core.channels.FileChannelFactory;
import emissary.core.channels.SeekableByteChannelFactory;
import emissary.test.core.junit5.UnitTest;
import emissary.util.io.ResourceReader;

import org.apache.commons.compress.utils.ByteUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.Validate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;

import static emissary.kff.KffFile.DEFAULT_RECORD_LENGTH;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

class KffFileTest extends UnitTest {
private static final Logger LOGGER = LoggerFactory.getLogger(KffFileTest.class);

private static final String expectedShaHash = "000000206738748EDD92C4E3D2E823896700F849";
drivenflywheel marked this conversation as resolved.
Show resolved Hide resolved
private static final String ITEM_NAME = "Some_item_name";
private static final byte[] expectedSha1Bytes = {(byte) 0, (byte) 0, (byte) 0, (byte) 32, (byte) 103, (byte) 56, (byte) 116,
(byte) -114, (byte) -35, (byte) -110, (byte) -60, (byte) -29, (byte) -46, (byte) -24, (byte) 35, (byte) -119,
(byte) 103, (byte) 0, (byte) -8, (byte) 73};
Expand All @@ -23,6 +47,8 @@ class KffFileTest extends UnitTest {
private static final String resourcePath = new ResourceReader()
.getResource("emissary/kff/KffFileTest/tmp.bin").getPath();

SeekableByteChannelFactory channelFactory = FileChannelFactory.create(Path.of(resourcePath));

@Override
@BeforeEach
public void setUp() throws Exception {
Expand All @@ -44,7 +70,7 @@ void testKffFileCheck() {
results.setHash("SHA-1", expectedSha1Bytes);
results.setHash("CRC32", expectedCrcBytes);
try {
assertTrue(kffFile.check(expectedShaHash, results));
assertTrue(kffFile.check(ITEM_NAME, results));
} catch (Exception e) {
fail(e);
}
Expand All @@ -54,7 +80,7 @@ void testKffFileCheck() {
results = new ChecksumResults();
results.setHash("SHA-1", incorrectSha1Bytes);
try {
assertFalse(kffFile.check(expectedShaHash, results));
assertFalse(kffFile.check(ITEM_NAME, results));
} catch (Exception e) {
fail(e);
}
Expand All @@ -65,4 +91,141 @@ void testKffFileMain() {
String[] args = {resourcePath, resourcePath};
assertDoesNotThrow(() -> KffFile.main(args));
}

/**
 * Tests concurrent {@link KffFile#check(String, ChecksumResults)} invocations to ensure thread-safety.
 * <p>
 * Every record in the test resource is submitted with an expected result of {@code true}, mixed with randomly
 * generated records that are expected to return {@code false}. All checks run on a fixed-size thread pool against
 * the same KffFile instance.
 *
 * @throws Exception if the resource cannot be read or a submitted task fails
 */
@Test
void testConcurrentKffFileCheckCalls() throws Exception {

    final Random random = new Random();
    final int expectedFailureCount = 500;

    // the inputs we'll submit, along with their expected KffFile.check return values
    Map<ChecksumResults, Boolean> kffRecords = new HashMap<>();

    // parse "known entries" from the binary input file
    try (SeekableByteChannel byteChannel = channelFactory.create()) {
        int recordCount = (int) (byteChannel.size() / DEFAULT_RECORD_LENGTH);
        LOGGER.debug("test file contains {} known file entries", recordCount);

        byte[] recordBytes = new byte[DEFAULT_RECORD_LENGTH];
        ByteBuffer buffer = ByteBuffer.wrap(recordBytes);

        for (int i = 0; i < recordCount; i++) {
            buffer.clear();

            // widen before multiplying so the byte offset cannot overflow int
            byteChannel.position((long) i * DEFAULT_RECORD_LENGTH);

            // read the value into recordBytes; a partial record would yield a bogus "known" entry, so fail fast
            int count = byteChannel.read(buffer);
            assertEquals(DEFAULT_RECORD_LENGTH, count, "Short read on test resource at record " + i);

            // add the parsed "known file" entry to our inputs, with an expected value of true
            ChecksumResults csr = buildChecksumResultsWithSha1AndCRC(recordBytes);
            kffRecords.put(csr, true);
        }
    }

    // the helper copies what it needs out of the array, so one scratch array can be refilled on each pass
    byte[] randomRecord = new byte[DEFAULT_RECORD_LENGTH];
    for (int j = 0; j < expectedFailureCount; j++) {
        // build a ChecksumResults entry with random bytes, and add it to our inputs with an expected value of false
        random.nextBytes(randomRecord);
        ChecksumResults csr = buildChecksumResultsWithSha1AndCRC(randomRecord);
        kffRecords.put(csr, false);
    }

    // convert collection of inputs to a list of callable tasks we can execute in parallel
    List<KffFileCheckTask> callables = kffRecords.entrySet().stream()
            .map(entry -> new KffFileCheckTask(kffFile, entry.getKey(), entry.getValue()))
            .collect(Collectors.toList());

    // shuffle the callables, so we have expected failures interspersed with expected successes
    Collections.shuffle(callables);

    ExecutorService executorService = Executors.newFixedThreadPool(10);
    try {
        // invoke the callable tasks concurrently using the thread pool and get their results
        List<Future<Boolean>> results = executorService.invokeAll(callables);
        for (Future<Boolean> result : results) {
            assertTrue(result.get(), "kffFile.check result didn't match expectations");
        }
    } finally {
        executorService.shutdown();
    }
}

/**
 * Builds a {@link ChecksumResults} from one raw record. The result carries a SHA-1 hash value and a CRC value, both
 * extracted from the supplied bytes.
 *
 * @param recordBytes input byte array, with expected length {@link KffFile#DEFAULT_RECORD_LENGTH}
 * @return the constructed ChecksumResults instance
 */
private static ChecksumResults buildChecksumResultsWithSha1AndCRC(byte[] recordBytes) {
    Validate.notNull(recordBytes, "recordBytes must not be null");
    Validate.isTrue(recordBytes.length == DEFAULT_RECORD_LENGTH, "recordBytes must include 24 elements");

    // split the record into its two parts before populating the result object
    final byte[] sha1 = getSha1Bytes(recordBytes);
    final byte[] crcLittleEndian = getCrc32BytesLE(recordBytes);

    final ChecksumResults checksums = new ChecksumResults();
    checksums.setHash("SHA-1", sha1);
    checksums.setCrc(ByteUtils.fromLittleEndian(crcLittleEndian));
    return checksums;
}

/**
 * Wraps a single {@link KffFile#check(String, ChecksumResults)} call in a {@link Callable} so that many checks can be
 * handed to an executor and evaluated in parallel.
 */
static class KffFileCheckTask implements Callable<Boolean> {
    private final KffFile filter;
    private final ChecksumResults checksums;
    private final boolean expected;

    KffFileCheckTask(KffFile kffFile, ChecksumResults csum, boolean expectedResult) {
        this.filter = kffFile;
        this.checksums = csum;
        this.expected = expectedResult;
    }

    /**
     * Runs the check and compares its outcome with the expectation supplied at construction time.
     *
     * @return {@code true} when the check returned the expected value, {@code false} otherwise
     * @throws Exception if the underlying check throws
     */
    @Override
    public Boolean call() throws Exception {
        final boolean actual = filter.check("ignored param", checksums);
        LOGGER.debug("expected {}, got {}", expected, actual);
        return actual == expected;
    }
}

/**
 * Extracts the SHA-1 portion of a record: every byte except the trailing four.
 *
 * @param recordBytes record to parse; must be non-null and {@link KffFile#DEFAULT_RECORD_LENGTH} bytes long
 * @return a new array holding the SHA-1 bytes
 */
private static byte[] getSha1Bytes(byte[] recordBytes) {
    Validate.notNull(recordBytes, "recordBytes must not be null");
    Validate.isTrue(recordBytes.length == DEFAULT_RECORD_LENGTH, "recordBytes must include 24 elements");

    final int sha1Length = DEFAULT_RECORD_LENGTH - 4;
    final byte[] sha1 = new byte[sha1Length];
    System.arraycopy(recordBytes, 0, sha1, 0, sha1Length);
    return sha1;
}

/**
 * Extracts the last 4 bytes of the input array and returns them in reversed order, converting the stored big-endian
 * value to little-endian.
 *
 * @param recordBytes record to parse; must be non-null and {@link KffFile#DEFAULT_RECORD_LENGTH} bytes long
 * @return a new array holding the CRC32 bytes, in little-endian order
 */
private static byte[] getCrc32BytesLE(byte[] recordBytes) {
    Validate.notNull(recordBytes, "recordBytes must not be null");
    Validate.isTrue(recordBytes.length == DEFAULT_RECORD_LENGTH, "recordBytes must include 24 elements");

    final byte[] reversed = new byte[4];
    // walk backwards from the final byte of the record so the output is already reversed
    for (int i = 0; i < reversed.length; i++) {
        reversed[i] = recordBytes[DEFAULT_RECORD_LENGTH - 1 - i];
    }
    return reversed;
}


}