"""Command-line speech transcription using the Cohere ASR model.

Modes (mutually exclusive):

* default     -- download the demo clip from the model repo and transcribe it
* --mic [N]   -- record N seconds (default 5) from the microphone, then transcribe
* --stream    -- placeholder for live streaming transcription (not implemented)

Heavy third-party packages (transformers, huggingface_hub, sounddevice) are
imported inside the functions that use them.  That keeps ``--help`` and
argument errors fast, lets the demo mode run on machines without PortAudio,
and -- most importantly -- makes a missing PortAudio library surface inside
``record_audio`` where ``main`` can report it with a helpful hint.
"""

import argparse
import sys

import numpy as np

MODEL_ID = "CohereLabs/cohere-transcribe-03-2026"
SAMPLE_RATE = 16000


def load_model():
    """Load the processor and model for ``MODEL_ID``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    from transformers import AutoProcessor, CohereAsrForConditionalGeneration

    print("Loading model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = CohereAsrForConditionalGeneration.from_pretrained(
        MODEL_ID, device_map="auto"
    )
    return processor, model


def transcribe_audio(processor, model, audio, language="en"):
    """Transcribe ``audio`` (sampled at ``SAMPLE_RATE``) and return the decoded text.

    Args:
        processor: Processor returned by ``load_model``.
        model: Model returned by ``load_model``.
        audio: Audio samples to transcribe.
        language: Language code forwarded to the processor.
    """
    inputs = processor(
        audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", language=language
    )
    # Rebind the result so correctness does not depend on ``.to`` mutating in place.
    inputs = inputs.to(model.device, dtype=model.dtype)
    outputs = model.generate(**inputs, max_new_tokens=256)
    return processor.decode(outputs, skip_special_tokens=True)


def record_audio(duration) -> np.ndarray:
    """Record ``duration`` seconds of mono float32 audio from the default input.

    Args:
        duration: Recording length in seconds.

    Returns:
        The recorded samples flattened to one dimension.

    Raises:
        OSError: If the PortAudio library cannot be loaded (raised by the
            ``sounddevice`` import) or the audio device reports an error.
    """
    # Imported lazily: importing sounddevice raises OSError when the PortAudio
    # library is missing, and the caller handles OSError from this function.
    import sounddevice as sd

    print(f"Recording for {duration} seconds...")
    try:
        audio = sd.rec(
            int(duration * SAMPLE_RATE),
            samplerate=SAMPLE_RATE,
            channels=1,
            dtype="float32",
        )
        sd.wait()
    except sd.PortAudioError as exc:
        # PortAudioError is not an OSError; normalise so callers need one handler.
        raise OSError(str(exc)) from exc
    return audio.flatten()


def stream_transcribe(processor, model, language):
    """Placeholder for live streaming transcription; currently only prints a notice."""
    print("TODO: streaming mode")


def _positive_int(value):
    """argparse ``type`` callable: parse ``value`` as an integer greater than zero."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def build_parser():
    """Build the command-line parser (``--mic`` / ``--stream`` are mutually exclusive)."""
    parser = argparse.ArgumentParser(description="Cohere ASR Transcription")
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--mic",
        type=_positive_int,
        nargs="?",
        const=5,
        metavar="SECONDS",
        help="Record from microphone for N seconds (default: 5)",
    )
    group.add_argument(
        "--stream",
        action="store_true",
        help="Live streaming transcription with VAD",
    )
    parser.add_argument("--lang", default="en", help="Language code (default: en)")
    return parser


def main(argv=None):
    """Run the CLI.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if microphone recording failed.
    """
    args = build_parser().parse_args(argv)

    # Every mode needs the model, so load it once after the arguments validate.
    processor, model = load_model()

    if args.stream:
        stream_transcribe(processor, model, args.lang)
        return 0

    if args.mic is not None:
        # Only the recording step is guarded, so that transcription failures
        # are not misreported as microphone problems.
        try:
            audio = record_audio(args.mic)
        except OSError as e:
            print(f"Microphone error: {e}")
            print("Hint: Run with nix-shell for PortAudio support")
            return 1
    else:
        from huggingface_hub import hf_hub_download
        from transformers.audio_utils import load_audio

        print("Loading demo audio...")
        audio_file = hf_hub_download(
            repo_id=MODEL_ID, filename="demo/voxpopuli_test_en_demo.wav"
        )
        audio = load_audio(audio_file, sampling_rate=SAMPLE_RATE)

    print("Transcribing...")
    text = transcribe_audio(processor, model, audio, args.lang)
    print(f"\nTranscription:\n{text}\n")
    return 0


if __name__ == "__main__":
    sys.exit(main())