From d1fa1ffb15366d4d994eb27a14af113694de223c Mon Sep 17 00:00:00 2001 From: Adolfo Reyna Date: Thu, 26 Feb 2026 20:57:31 -0500 Subject: [PATCH] Optimize for Apple Silicon using MLX --- transcribe.py | 40 +++++++++++----------------------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/transcribe.py b/transcribe.py index e872ef5..b29581b 100644 --- a/transcribe.py +++ b/transcribe.py @@ -1,4 +1,4 @@ -import whisper +import mlx_whisper import numpy as np import sounddevice as sd import queue @@ -7,13 +7,13 @@ import torch from silero_vad import load_silero_vad, get_speech_timestamps # Parameters -MODEL_TYPE = "tiny.en" +MODEL_PATH = "mlx-community/whisper-tiny.en-mlx" # MLX optimized model CHANNELS = 1 SAMPLERATE = 16000 -BLOCK_SIZE = 512 # Silero VAD prefers specific block sizes (512, 1024, 1536) -VAD_THRESHOLD = 0.5 # Confidence threshold for speech -BUFFER_LIMIT = SAMPLERATE * 30 # Max 30 seconds of audio buffer -MIN_SILENCE_DURATION_MS = 500 # Silence duration to trigger transcription +BLOCK_SIZE = 512 # Silero VAD prefers 512, 1024, or 1536 +VAD_THRESHOLD = 0.5 +BUFFER_LIMIT = SAMPLERATE * 30 +MIN_SILENCE_DURATION_MS = 500 audio_queue = queue.Queue() @@ -23,22 +23,15 @@ def callback(indata, frames, time, status): audio_queue.put(indata.copy()) def main(): - print(f"Loading Whisper model '{MODEL_TYPE}'...") - whisper_model = whisper.load_model(MODEL_TYPE) + print(f"Loading MLX-optimized Whisper model '{MODEL_PATH}'...") + # mlx-whisper uses the same model names or Hugging Face paths print("Loading Silero VAD model...") vad_model = load_silero_vad() print("Models loaded.") - print("\nAvailable Audio Devices:") - devices = sd.query_devices() - print(devices) - - default_device = sd.default.device[0] - print(f"\nUsing default input device index: {default_device}") - - print("\nStarting live transcription with VAD... (Press Ctrl+C to stop)") + print("\nStarting live transcription (MLX + VAD)... (Press Ctrl+C to stop)") audio_buffer = [] speech_started = False @@ -51,13 +44,9 @@ def main(): audio_buffer.append(data.flatten()) if len(audio_buffer) > 0: - # Concatenate buffer to check for speech current_audio = np.concatenate(audio_buffer) - - # Convert to torch tensor for Silero audio_tensor = torch.from_numpy(current_audio) - # Get speech timestamps speech_timestamps = get_speech_timestamps( audio_tensor, vad_model, @@ -66,31 +55,24 @@ def main(): min_silence_duration_ms=MIN_SILENCE_DURATION_MS ) - # If we have speech and then silence, or buffer is getting too long if len(speech_timestamps) > 0: speech_started = True - - # Check if the last speech segment has "ended" (i.e., we have enough silence after it) - # or if we've reached a significant buffer size last_end = speech_timestamps[-1]['end'] buffer_len_samples = len(current_audio) - # If the speech ended more than MIN_SILENCE_DURATION_MS ago if (buffer_len_samples - last_end) > (SAMPLERATE * MIN_SILENCE_DURATION_MS / 1000) or buffer_len_samples > BUFFER_LIMIT: - # Transcribe the valid speech segment - result = whisper_model.transcribe(current_audio, fp16=False, language="en") + # Transcribe with MLX + result = mlx_whisper.transcribe(current_audio, path_or_hf_repo=MODEL_PATH) text = result['text'].strip() if text: print(f"Transcription: {text}") - # Reset buffer audio_buffer = [] speech_started = False elif not speech_started and len(current_audio) > SAMPLERATE * 2: - # Clear buffer if it's just silence for more than 2 seconds audio_buffer = [] except KeyboardInterrupt: