Skip to content

Commit

Permalink
[FLINK-25817] Let TaskLocalStateStoreImpl persist TaskStateSnapshots
Browse files Browse the repository at this point in the history
This commit lets the TaskLocalStateStoreImpl persist the TaskStateSnapshots into the
directory of the local state checkpoint. This allows to recover the TaskStateSnapshots
in case of a process crash. If the TaskStateSnapshot cannot be read then the whole local
checkpointing directory will be deleted to avoid corrupted files.
  • Loading branch information
tillrohrmann committed Feb 8, 2022
1 parent 19efab6 commit 4b39e4a
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ public TaskExecutorLocalStateStoresManager(
@Nonnull Executor discardExecutor)
throws IOException {

LOG.debug(
"Start {} with local state root directories {}.",
getClass().getSimpleName(),
localStateRootDirectories);

this.taskStateStoresByAllocationID = new HashMap<>();
this.localRecoveryEnabled = localRecoveryEnabled;
this.localStateRootDirectories = localStateRootDirectories;
Expand Down Expand Up @@ -193,7 +198,6 @@ public TaskLocalStateStore localStateStoreForSubtask(
}

public void releaseLocalStateForAllocationId(@Nonnull AllocationID allocationID) {

if (LOG.isDebugEnabled()) {
LOG.debug("Releasing local state under allocation id {}.", allocationID);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkRuntimeException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,7 +37,11 @@
import javax.annotation.concurrent.GuardedBy;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -60,6 +66,8 @@ public class TaskLocalStateStoreImpl implements OwnedTaskLocalStateStore {
/** Dummy value to use instead of null to satisfy {@link ConcurrentHashMap}. */
@VisibleForTesting static final TaskStateSnapshot NULL_DUMMY = new TaskStateSnapshot(0, false);

public static final String TASK_STATE_SNAPSHOT_FILENAME = "_task_state_snapshot";

/** JobID from the owning subtask. */
@Nonnull private final JobID jobID;

Expand Down Expand Up @@ -165,6 +173,7 @@ public void storeLocalState(
} else {
TaskStateSnapshot previous =
storedTaskStateByCheckpointID.put(checkpointId, localState);
persistLocalStateMetadata(checkpointId, localState);

if (previous != null) {
toDiscard = new AbstractMap.SimpleEntry<>(checkpointId, previous);
Expand All @@ -177,14 +186,53 @@ public void storeLocalState(
}
}

/**
* Writes a task state snapshot file that contains the serialized content of the local state.
*
* @param checkpointId identifying the checkpoint
* @param localState task state snapshot that will be persisted
*/
private void persistLocalStateMetadata(long checkpointId, TaskStateSnapshot localState) {
final File taskStateSnapshotFile = getTaskStateSnapshotFile(checkpointId);
try (ObjectOutputStream oos =
new ObjectOutputStream(new FileOutputStream(taskStateSnapshotFile))) {
oos.writeObject(localState);

LOG.debug(
"Successfully written local task state snapshot file {} for checkpoint {}.",
taskStateSnapshotFile,
checkpointId);
} catch (IOException e) {
ExceptionUtils.rethrow(e, "Could not write the local task state snapshot file.");
}
}

@VisibleForTesting
File getTaskStateSnapshotFile(long checkpointId) {
final File checkpointDirectory =
localRecoveryConfig
.getLocalStateDirectoryProvider()
.orElseThrow(
() -> new IllegalStateException("Local recovery must be enabled."))
.subtaskSpecificCheckpointDirectory(checkpointId);

if (!checkpointDirectory.exists() && !checkpointDirectory.mkdirs()) {
throw new FlinkRuntimeException(
String.format(
"Could not create the checkpoint directory '%s'", checkpointDirectory));
}

return new File(checkpointDirectory, TASK_STATE_SNAPSHOT_FILENAME);
}

@Override
@Nullable
public TaskStateSnapshot retrieveLocalState(long checkpointID) {

TaskStateSnapshot snapshot;

synchronized (lock) {
snapshot = storedTaskStateByCheckpointID.get(checkpointID);
snapshot = loadTaskStateSnapshot(checkpointID);
}

if (snapshot != null) {
Expand Down Expand Up @@ -216,6 +264,42 @@ public TaskStateSnapshot retrieveLocalState(long checkpointID) {
return (snapshot != NULL_DUMMY) ? snapshot : null;
}

@GuardedBy("lock")
@Nullable
private TaskStateSnapshot loadTaskStateSnapshot(long checkpointID) {
return storedTaskStateByCheckpointID.computeIfAbsent(
checkpointID, this::tryLoadTaskStateSnapshotFromDisk);
}

@GuardedBy("lock")
@Nullable
private TaskStateSnapshot tryLoadTaskStateSnapshotFromDisk(long checkpointID) {
final File taskStateSnapshotFile = getTaskStateSnapshotFile(checkpointID);

if (taskStateSnapshotFile.exists()) {
TaskStateSnapshot taskStateSnapshot = null;
try (ObjectInputStream ois =
new ObjectInputStream(new FileInputStream(taskStateSnapshotFile))) {
taskStateSnapshot = (TaskStateSnapshot) ois.readObject();

LOG.debug(
"Loaded task state snapshot for checkpoint {} successfully from disk.",
checkpointID);
} catch (IOException | ClassNotFoundException e) {
LOG.debug(
"Could not read task state snapshot file {} for checkpoint {}. Deleting the corresponding local state.",
taskStateSnapshotFile,
checkpointID);

discardLocalStateForCheckpoint(checkpointID, Optional.empty());
}

return taskStateSnapshot;
}

return null;
}

@Override
@Nonnull
public LocalRecoveryConfig getLocalRecoveryConfig() {
Expand Down Expand Up @@ -307,14 +391,14 @@ private void asyncDiscardLocalStateForCollection(
private void syncDiscardLocalStateForCollection(
Collection<Map.Entry<Long, TaskStateSnapshot>> toDiscard) {
for (Map.Entry<Long, TaskStateSnapshot> entry : toDiscard) {
discardLocalStateForCheckpoint(entry.getKey(), entry.getValue());
discardLocalStateForCheckpoint(entry.getKey(), Optional.of(entry.getValue()));
}
}

/**
* Helper method that discards state objects with an executor and reports exceptions to the log.
*/
private void discardLocalStateForCheckpoint(long checkpointID, TaskStateSnapshot o) {
private void discardLocalStateForCheckpoint(long checkpointID, Optional<TaskStateSnapshot> o) {

if (LOG.isTraceEnabled()) {
LOG.trace(
Expand All @@ -333,17 +417,20 @@ private void discardLocalStateForCheckpoint(long checkpointID, TaskStateSnapshot
subtaskIndex);
}

try {
o.discardState();
} catch (Exception discardEx) {
LOG.warn(
"Exception while discarding local task state snapshot of checkpoint {} in subtask ({} - {} - {}).",
checkpointID,
jobID,
jobVertexID,
subtaskIndex,
discardEx);
}
o.ifPresent(
taskStateSnapshot -> {
try {
taskStateSnapshot.discardState();
} catch (Exception discardEx) {
LOG.warn(
"Exception while discarding local task state snapshot of checkpoint {} in subtask ({} - {} - {}).",
checkpointID,
jobID,
jobVertexID,
subtaskIndex,
discardEx);
}
});

Optional<LocalRecoveryDirectoryProvider> directoryProviderOptional =
localRecoveryConfig.getLocalStateDirectoryProvider();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,53 +33,67 @@
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import javax.annotation.Nonnull;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/** Test for the {@link TaskLocalStateStoreImpl}. */
public class TaskLocalStateStoreImplTest extends TestLogger {

private SortedMap<Long, TaskStateSnapshot> internalSnapshotMap;
private Object internalLock;
private TemporaryFolder temporaryFolder;
private File[] allocationBaseDirs;
private TaskLocalStateStoreImpl taskLocalStateStore;
private JobID jobID;
private AllocationID allocationID;
private JobVertexID jobVertexID;
private int subtaskIdx;

@Before
public void before() throws Exception {
JobID jobID = new JobID();
AllocationID allocationID = new AllocationID();
JobVertexID jobVertexID = new JobVertexID();
int subtaskIdx = 0;
jobID = new JobID();
allocationID = new AllocationID();
jobVertexID = new JobVertexID();
subtaskIdx = 0;
this.temporaryFolder = new TemporaryFolder();
this.temporaryFolder.create();
this.allocationBaseDirs =
new File[] {temporaryFolder.newFolder(), temporaryFolder.newFolder()};
this.internalSnapshotMap = new TreeMap<>();
this.internalLock = new Object();

this.taskLocalStateStore =
createTaskLocalStateStoreImpl(
allocationBaseDirs, jobID, allocationID, jobVertexID, subtaskIdx);
}

@Nonnull
private TaskLocalStateStoreImpl createTaskLocalStateStoreImpl(
File[] allocationBaseDirs,
JobID jobID,
AllocationID allocationID,
JobVertexID jobVertexID,
int subtaskIdx) {
LocalRecoveryDirectoryProviderImpl directoryProvider =
new LocalRecoveryDirectoryProviderImpl(
allocationBaseDirs, jobID, jobVertexID, subtaskIdx);

LocalRecoveryConfig localRecoveryConfig = new LocalRecoveryConfig(directoryProvider);

this.taskLocalStateStore =
new TaskLocalStateStoreImpl(
jobID,
allocationID,
jobVertexID,
subtaskIdx,
localRecoveryConfig,
Executors.directExecutor(),
internalSnapshotMap,
internalLock);
return new TaskLocalStateStoreImpl(
jobID,
allocationID,
jobVertexID,
subtaskIdx,
localRecoveryConfig,
Executors.directExecutor());
}

@After
Expand Down Expand Up @@ -180,6 +194,56 @@ public void dispose() throws Exception {
checkPrunedAndDiscarded(taskStateSnapshots, 0, chkCount);
}

@Test
public void retrieveNullIfNoPersistedLocalState() {
assertThat(taskLocalStateStore.retrieveLocalState(0)).isNull();
}

@Test
public void retrievePersistedLocalStateFromDisc() {
final TaskStateSnapshot taskStateSnapshot = createTaskStateSnapshot();
final long checkpointId = 0L;
taskLocalStateStore.storeLocalState(checkpointId, taskStateSnapshot);

final TaskLocalStateStoreImpl newTaskLocalStateStore =
createTaskLocalStateStoreImpl(
allocationBaseDirs, jobID, allocationID, jobVertexID, 0);

final TaskStateSnapshot retrievedTaskStateSnapshot =
newTaskLocalStateStore.retrieveLocalState(checkpointId);

assertThat(retrievedTaskStateSnapshot).isEqualTo(taskStateSnapshot);
}

@Nonnull
private TaskStateSnapshot createTaskStateSnapshot() {
final Map<OperatorID, OperatorSubtaskState> operatorSubtaskStates = new HashMap<>();
operatorSubtaskStates.put(new OperatorID(), OperatorSubtaskState.builder().build());
operatorSubtaskStates.put(new OperatorID(), OperatorSubtaskState.builder().build());
final TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(operatorSubtaskStates);
return taskStateSnapshot;
}

@Test
public void deletesLocalStateIfRetrievalFails() throws IOException {
final TaskStateSnapshot taskStateSnapshot = createTaskStateSnapshot();
final long checkpointId = 0L;
taskLocalStateStore.storeLocalState(checkpointId, taskStateSnapshot);

final File taskStateSnapshotFile =
taskLocalStateStore.getTaskStateSnapshotFile(checkpointId);

Files.write(
taskStateSnapshotFile.toPath(), new byte[] {1, 2, 3, 4}, StandardOpenOption.WRITE);

final TaskLocalStateStoreImpl newTaskLocalStateStore =
createTaskLocalStateStoreImpl(
allocationBaseDirs, jobID, allocationID, jobVertexID, subtaskIdx);

assertThat(newTaskLocalStateStore.retrieveLocalState(checkpointId)).isNull();
assertThat(taskStateSnapshotFile.getParentFile()).doesNotExist();
}

private void checkStoredAsExpected(List<TestingTaskStateSnapshot> history, int start, int end) {
for (int i = start; i < end; ++i) {
TestingTaskStateSnapshot expected = history.get(i);
Expand Down

0 comments on commit 4b39e4a

Please sign in to comment.