"""Command-line speech transcription with the Cohere ASR model.

Three modes are available:

* no flag      -- transcribe the demo clip shipped in the model repository
* ``--mic N``  -- record N seconds from the microphone, then transcribe
* ``--stream`` -- continuously listen and transcribe each detected utterance
"""

import argparse
import queue
import sys
import threading
import time
from collections import deque

import numpy as np
import sounddevice as sd
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, CohereAsrForConditionalGeneration
from transformers.audio_utils import load_audio

MODEL_ID = "CohereLabs/cohere-transcribe-03-2026"
SAMPLE_RATE = 16000


def load_model():
    """Load the processor and model identified by ``MODEL_ID``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    print("Loading model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = CohereAsrForConditionalGeneration.from_pretrained(
        MODEL_ID, device_map="auto"
    )
    return processor, model


def transcribe_audio(processor, model, audio, language="en"):
    """Transcribe one audio array and return the decoded text.

    Args:
        processor: Processor returned by ``load_model``.
        model: Model returned by ``load_model``.
        audio: Audio samples; passed to the processor with
            ``sampling_rate=SAMPLE_RATE``.
        language: Language code forwarded to the processor.

    Returns:
        The result of ``processor.decode`` with special tokens skipped.
    """
    inputs = processor(
        audio,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        language=language,
    )
    # Bind the result so this works whether ``.to`` mutates or returns.
    inputs = inputs.to(model.device, dtype=model.dtype)
    outputs = model.generate(**inputs, max_new_tokens=256)
    return processor.decode(outputs, skip_special_tokens=True)


def record_audio(duration):
    """Record ``duration`` seconds of mono float32 audio from the input device.

    Blocks until the recording is complete.

    Returns:
        A one-dimensional array of samples at ``SAMPLE_RATE``.
    """
    print(f"Recording for {duration} seconds...")
    audio = sd.rec(
        int(duration * SAMPLE_RATE),
        samplerate=SAMPLE_RATE,
        channels=1,
        dtype="float32",
    )
    sd.wait()
    return audio.flatten()


def _report_audio_error(exc):
    """Print the user-facing message for an audio device failure."""
    print(f"Microphone error: {exc}")
    print("Hint: Run with nix-shell for PortAudio support")


def main():
    """Parse command-line arguments and run the selected transcription mode."""
    parser = argparse.ArgumentParser(description="Cohere ASR Transcription")
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--mic",
        type=int,
        nargs="?",
        const=5,
        metavar="SECONDS",
        help="Record from microphone for N seconds (default: 5)",
    )
    group.add_argument(
        "--stream",
        action="store_true",
        help="Live streaming transcription with VAD",
    )
    parser.add_argument(
        "--lang", default="en", help="Language code (default: en)"
    )
    args = parser.parse_args()

    # Reject a zero/negative duration before paying for the model load.
    if args.mic is not None and args.mic <= 0:
        parser.error("--mic SECONDS must be a positive integer")

    processor, model = load_model()

    if args.stream:
        try:
            stream_transcribe(processor, model, args.lang)
        except (OSError, sd.PortAudioError) as e:
            _report_audio_error(e)
        return

    if args.mic is not None:
        # Only the recording is guarded: sounddevice reports device failures
        # as PortAudioError (not an OSError subclass), and errors raised by
        # transcription must not be mislabelled as microphone problems.
        try:
            audio = record_audio(args.mic)
        except (OSError, sd.PortAudioError) as e:
            _report_audio_error(e)
            return
    else:
        print("Loading demo audio...")
        audio_file = hf_hub_download(
            repo_id=MODEL_ID, filename="demo/voxpopuli_test_en_demo.wav"
        )
        audio = load_audio(audio_file, sampling_rate=SAMPLE_RATE)

    print("Transcribing...")
    text = transcribe_audio(processor, model, audio, args.lang)
    print(f"\nTranscription:\n{text}\n")


def _rms(samples):
    """Return the root-mean-square amplitude of ``samples``."""
    return np.sqrt(np.mean(samples ** 2))


def calibrate_silence(duration=0.5):
    """Record ambient audio and derive a loudness threshold from it.

    Args:
        duration: Length of the calibration recording in seconds.

    Returns:
        Three times the measured RMS, with a floor of 0.01.
    """
    print("Calibrating silence threshold...")
    audio = sd.rec(
        int(duration * SAMPLE_RATE),
        samplerate=SAMPLE_RATE,
        channels=1,
        dtype="float32",
    )
    sd.wait()
    rms = _rms(audio)
    threshold = max(rms * 3, 0.01)
    print(f"  Ambient RMS: {rms:.4f}, threshold: {threshold:.4f}")
    return threshold


FRAME_SIZE = 800  # 50ms at 16kHz
PRE_ROLL_FRAMES = 6  # ~0.3s of audio before speech onset
SILENCE_FRAMES = 16  # ~0.8s of silence to end a segment
SPEECH_ONSET_FRAMES = 3  # ~150ms of speech to trigger
MAX_SPEECH_SECONDS = 30  # force chunk boundary


class VADStateMachine:
    """Energy-based voice activity detector operating on fixed-size frames.

    While idle, the most recent ``PRE_ROLL_FRAMES`` frames are retained so a
    segment can include audio from just before speech was detected. Speech
    starts after ``SPEECH_ONSET_FRAMES`` consecutive loud frames and ends
    after ``SILENCE_FRAMES`` consecutive quiet frames, or once the segment
    reaches ``MAX_SPEECH_SECONDS``.
    """

    def __init__(self, threshold):
        self.threshold = threshold
        self.speaking = False
        self.speech_frames = 0
        self.silence_frames = 0
        # Bounded buffer: the oldest frame is discarded automatically.
        self.pre_roll = deque(maxlen=PRE_ROLL_FRAMES)
        self.segment = []
        self.segment_start_time = 0.0

    def process_frame(self, frame, elapsed_time):
        """Process one 50ms frame.

        Args:
            frame: Array of ``FRAME_SIZE`` samples.
            elapsed_time: Seconds since listening started, used to timestamp
                the start of a segment.

        Returns:
            A ``(start_time, audio_array)`` tuple when a complete speech
            segment is detected, otherwise ``None``.
        """
        is_loud = _rms(frame) > self.threshold

        if not self.speaking:
            self.pre_roll.append(frame)
            if not is_loud:
                # Onset requires *consecutive* loud frames.
                self.speech_frames = 0
                return None

            self.speech_frames += 1
            if self.speech_frames >= SPEECH_ONSET_FRAMES:
                self.speaking = True
                self.silence_frames = 0
                self.segment = list(self.pre_roll)
                # The pre-roll (which includes this frame) precedes
                # ``elapsed_time``, so back-date the segment start by it.
                pre_roll_seconds = (
                    len(self.pre_roll) * FRAME_SIZE / SAMPLE_RATE
                )
                self.segment_start_time = max(
                    0.0, elapsed_time - pre_roll_seconds
                )
                self.pre_roll.clear()
            return None

        # Currently speaking
        self.segment.append(frame)
        if is_loud:
            self.silence_frames = 0
        else:
            self.silence_frames += 1

        segment_duration = len(self.segment) * FRAME_SIZE / SAMPLE_RATE
        if (
            self.silence_frames >= SILENCE_FRAMES
            or segment_duration >= MAX_SPEECH_SECONDS
        ):
            result = (self.segment_start_time, np.concatenate(self.segment))
            self.speaking = False
            self.speech_frames = 0
            self.silence_frames = 0
            self.segment = []
            self.pre_roll.clear()
            return result
        return None


def stream_transcribe(processor, model, language):
    """Listen on the input device and print a transcript per speech segment.

    Audio frames are fed to a ``VADStateMachine`` from the audio callback;
    completed segments are queued and transcribed on a worker thread so the
    callback never waits on the model. Runs until interrupted with Ctrl+C.

    Args:
        processor: Processor returned by ``load_model``.
        model: Model returned by ``load_model``.
        language: Language code forwarded to ``transcribe_audio``.
    """
    threshold = calibrate_silence()
    vad = VADStateMachine(threshold)
    seg_queue = queue.Queue()
    stop_event = threading.Event()
    start_time = time.monotonic()

    def transcription_worker():
        # Keep draining after a stop request until the queue is empty.
        while not stop_event.is_set() or not seg_queue.empty():
            try:
                seg_start, audio = seg_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            minutes, seconds = divmod(int(seg_start), 60)
            try:
                text = transcribe_audio(processor, model, audio, language)
            except Exception as exc:
                # One failed segment must not silently end live
                # transcription while the main loop keeps listening.
                print(f"Transcription error: {exc}", file=sys.stderr)
                continue
            if text.strip():
                print(f"[{minutes:02d}:{seconds:02d}] {text.strip()}")

    worker = threading.Thread(target=transcription_worker, daemon=True)
    worker.start()

    frame_buf = np.empty(0, dtype="float32")

    def audio_callback(indata, frames, time_info, status):
        nonlocal frame_buf
        if status:
            # Surface input overflow/underflow flags instead of hiding them.
            print(status, file=sys.stderr)
        if stop_event.is_set():
            return
        # np.append copies, so the data outlives the callback's buffer.
        frame_buf = np.append(frame_buf, indata[:, 0])
        while len(frame_buf) >= FRAME_SIZE:
            frame = frame_buf[:FRAME_SIZE]
            frame_buf = frame_buf[FRAME_SIZE:]
            elapsed = time.monotonic() - start_time
            result = vad.process_frame(frame, elapsed)
            if result is not None:
                seg_queue.put(result)

    print("Listening... (Ctrl+C to stop)")
    try:
        with sd.InputStream(
            samplerate=SAMPLE_RATE,
            channels=1,
            dtype="float32",
            callback=audio_callback,
            blocksize=FRAME_SIZE,
        ):
            while True:
                time.sleep(0.1)
    except KeyboardInterrupt:
        pass
    finally:
        # The stream is closed here, so the VAD state is no longer changing.
        # Queue the unfinished segment BEFORE signalling stop: otherwise the
        # worker can observe "stopped and empty" and exit without it.
        if vad.speaking and vad.segment:
            seg_queue.put(
                (vad.segment_start_time, np.concatenate(vad.segment))
            )
        stop_event.set()
        worker.join(timeout=30)
    print("\nDone.")


if __name__ == "__main__":
    main()