Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix early stopping for VAD #155

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,14 @@
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSMicrophoneUsageDescription = "Required to record audio from the microphone for transcription.";
INFOPLIST_KEY_UISupportedInterfaceOrientations = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown";
INFOPLIST_KEY_WKCompanionAppBundleIdentifier = com.argmax.whisperkit.WhisperAX;
INFOPLIST_KEY_WKCompanionAppBundleIdentifier = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}";
INFOPLIST_KEY_WKRunsIndependentlyOfCompanionApp = YES;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 0.1.2;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX.watchapp;
PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}.watchapp";
PRODUCT_NAME = "WhisperAX Watch App";
PROVISIONING_PROFILE_SPECIFIER = "";
SDKROOT = watchos;
Expand Down Expand Up @@ -893,7 +893,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.3.0;
MARKETING_VERSION = 0.3.1;
PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}";
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
Expand Down Expand Up @@ -939,7 +939,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.3.0;
MARKETING_VERSION = 0.3.1;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
Expand Down
4 changes: 4 additions & 0 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1308,10 +1308,12 @@ struct ContentView: View {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
Logging.debug("Early stopping due to compression threshold")
return false
}
}
if progress.avgLogprob! < options.logProbThreshold! {
Logging.debug("Early stopping due to logprob threshold")
return false
}
return nil
Expand Down Expand Up @@ -1519,10 +1521,12 @@ struct ContentView: View {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
Logging.debug("Early stopping due to compression threshold")
return false
}
}
if progress.avgLogprob! < options.logProbThreshold! {
Logging.debug("Early stopping due to logprob threshold")
return false
}

Expand Down
12 changes: 8 additions & 4 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
public var tokenizer: WhisperTokenizer?
public var prefillData: WhisperMLModel?
public var isModelMultilingual: Bool = false
public var shouldEarlyStop: Bool = false
public var shouldEarlyStop = [UUID: Bool]()
private var languageLogitsFilter: LanguageLogitsFilter?

public var supportsWordTimestamps: Bool {
Expand Down Expand Up @@ -588,7 +588,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
Logging.debug("Running main loop for a maximum of \(loopCount) iterations, starting at index \(prefilledIndex)")
var hasAlignment = false
var isFirstTokenLogProbTooLow = false
shouldEarlyStop = false
let windowUUID = UUID()
shouldEarlyStop[windowUUID] = false
for tokenIndex in prefilledIndex..<loopCount {
let loopStart = Date()

Expand Down Expand Up @@ -733,7 +734,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
let shouldContinue = callback(result)
if let shouldContinue = shouldContinue, !shouldContinue, !isPrefill {
Logging.debug("Early stopping")
self?.shouldEarlyStop = true
self?.shouldEarlyStop[windowUUID] = true
}
}
}
Expand All @@ -749,10 +750,13 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
}

// Check if early stopping is triggered
if shouldEarlyStop {
if let shouldStop = shouldEarlyStop[windowUUID], shouldStop {
break
}
}

// Cleanup the early stop flag after loop completion
shouldEarlyStop.removeValue(forKey: windowUUID)

let cache = DecodingCache(
keyCache: decoderInputs.keyCache,
Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public extension MLComputeUnits {

#if os(macOS)
// From: https://stackoverflow.com/a/71726663
extension Process {
public extension Process {
static func stringFromTerminal(command: String) -> String {
let task = Process()
let pipe = Pipe()
Expand Down