"""Command-line speech transcription with Cohere's ASR model.

Input modes (``--mic`` and ``--stream`` are mutually exclusive):

* no flag       -- transcribe a demo clip downloaded from the model repo;
* ``--mic [N]`` -- record N seconds (default 5) from the default microphone;
* ``--stream``  -- live streaming transcription (placeholder, not implemented).
"""

import sys
import argparse

import numpy as np
from transformers import AutoProcessor, CohereAsrForConditionalGeneration
from transformers.audio_utils import load_audio
from huggingface_hub import hf_hub_download

# sounddevice loads the PortAudio shared library at import time and raises
# OSError when it is missing. Guard the import so the demo mode (which needs no
# audio hardware) still works; the original error is re-raised on first use.
try:
    import sounddevice as sd
except OSError as _exc:
    sd = None
    _SD_IMPORT_ERROR = _exc
    _MIC_ERRORS = (OSError,)
else:
    _SD_IMPORT_ERROR = None
    # PortAudioError derives from Exception, not OSError, so list it explicitly.
    _MIC_ERRORS = (OSError, sd.PortAudioError)

MODEL_ID = "CohereLabs/cohere-transcribe-03-2026"
SAMPLE_RATE = 16000


def load_model():
    """Load the processor and model for MODEL_ID.

    Returns:
        ``(processor, model)``; the model is placed with ``device_map="auto"``.
    """
    print("Loading model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = CohereAsrForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto")
    return processor, model


def transcribe_audio(processor, model, audio, language="en", *, max_new_tokens=256):
    """Transcribe a mono waveform sampled at SAMPLE_RATE.

    Args:
        processor: processor returned by ``load_model``.
        model: model returned by ``load_model``.
        audio: 1-D waveform sampled at SAMPLE_RATE.
        language: language code forwarded to the processor.
        max_new_tokens: generation budget; raise it for long recordings, which
            would otherwise be cut off.

    Returns:
        The decoded transcription with special tokens stripped.
    """
    inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", language=language)
    # Put the features on the model's device/dtype so generate() does not hit
    # a device or dtype mismatch. Rebinding is harmless if .to() works in place.
    inputs = inputs.to(model.device, dtype=model.dtype)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(outputs, skip_special_tokens=True)


def record_audio(duration) -> np.ndarray:
    """Record ``duration`` seconds of mono float32 audio at SAMPLE_RATE.

    Returns:
        A 1-D float32 array of ``int(duration * SAMPLE_RATE)`` samples.

    Raises:
        ValueError: if ``duration`` is not positive.
        OSError: if sounddevice could not be imported (e.g. PortAudio missing).
        sounddevice.PortAudioError: on audio device / stream failures.
    """
    if duration <= 0:
        raise ValueError(f"duration must be positive, got {duration}")
    if sd is None:
        raise OSError(f"sounddevice is unavailable: {_SD_IMPORT_ERROR}")
    print(f"Recording for {duration} seconds...")
    audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32")
    sd.wait()  # sd.rec() returns immediately; block until recording finishes
    return audio.flatten()  # (frames, 1) -> (frames,)


def stream_transcribe(processor, model, language):
    """Live VAD-driven streaming transcription -- placeholder.

    Currently ignores its arguments and only prints a TODO marker.
    """
    # TODO: implement streaming capture with voice-activity detection.
    print("TODO: streaming mode")


def _positive_int(text):
    """argparse ``type=`` callable that accepts only integers > 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main():
    """Parse CLI arguments, run the selected mode, and return an exit status.

    Returns:
        0 on success, 1 if microphone recording failed.
    """
    parser = argparse.ArgumentParser(description="Cohere ASR Transcription")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--mic", type=_positive_int, nargs="?", const=5, metavar="SECONDS",
                       help="Record from microphone for N seconds (default: 5)")
    group.add_argument("--stream", action="store_true",
                       help="Live streaming transcription with VAD")
    parser.add_argument("--lang", default="en", help="Language code (default: en)")
    args = parser.parse_args()

    # Every mode needs the model; loading it up front (as before) means a
    # microphone user is not left waiting between speaking and the result.
    processor, model = load_model()

    if args.stream:
        stream_transcribe(processor, model, args.lang)
        return 0

    if args.mic is not None:
        # Only the recording step is guarded, so transcription failures are
        # not misreported as microphone problems.
        try:
            audio = record_audio(args.mic)
        except _MIC_ERRORS as e:
            print(f"Microphone error: {e}", file=sys.stderr)
            print("Hint: Run with nix-shell for PortAudio support", file=sys.stderr)
            return 1
    else:
        print("Loading demo audio...")
        audio_file = hf_hub_download(repo_id=MODEL_ID, filename="demo/voxpopuli_test_en_demo.wav")
        audio = load_audio(audio_file, sampling_rate=SAMPLE_RATE)

    print("Transcribing...")
    text = transcribe_audio(processor, model, audio, args.lang)
    print(f"\nTranscription:\n{text}\n")
    return 0


if __name__ == "__main__":
    sys.exit(main())