refactor: switch to argparse, add --stream and --lang flags

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 02:43:47 +08:00
parent 6bff2875c5
commit 4605be5bc9
+58 -35
View File
@@ -1,51 +1,74 @@
import sys import sys
import argparse
import numpy as np import numpy as np
import sounddevice as sd import sounddevice as sd
from transformers import AutoProcessor, CohereAsrForConditionalGeneration from transformers import AutoProcessor, CohereAsrForConditionalGeneration
from transformers.audio_utils import load_audio from transformers.audio_utils import load_audio
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
# Load model MODEL_ID = "CohereLabs/cohere-transcribe-03-2026"
print("Loading model...") SAMPLE_RATE = 16000
processor = AutoProcessor.from_pretrained("CohereLabs/cohere-transcribe-03-2026")
model = CohereAsrForConditionalGeneration.from_pretrained(
"CohereLabs/cohere-transcribe-03-2026",
device_map="auto"
)
def transcribe_audio(audio, language="en"):
inputs = processor(audio, sampling_rate=16000, return_tensors="pt", language=language) def load_model():
print("Loading model...")
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = CohereAsrForConditionalGeneration.from_pretrained(
MODEL_ID, device_map="auto"
)
return processor, model
def transcribe_audio(processor, model, audio, language="en"):
inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", language=language)
inputs.to(model.device, dtype=model.dtype) inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs, max_new_tokens=256) outputs = model.generate(**inputs, max_new_tokens=256)
text = processor.decode(outputs, skip_special_tokens=True) return processor.decode(outputs, skip_special_tokens=True)
return text
def record_audio(duration, samplerate=16000):
def record_audio(duration):
print(f"Recording for {duration} seconds...") print(f"Recording for {duration} seconds...")
audio = sd.rec(int(duration * samplerate), samplerate=samplerate, channels=1, dtype='float32') audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32")
sd.wait() sd.wait()
return audio.flatten() return audio.flatten()
# Parse arguments
if len(sys.argv) > 1 and sys.argv[1] == "--mic":
duration = int(sys.argv[2]) if len(sys.argv) > 2 else 5
try:
mic_audio = record_audio(duration)
print("Transcribing...")
text = transcribe_audio(mic_audio)
print(f"\nTranscription:\n{text}\n")
except OSError as e:
print(f"Microphone error: {e}")
print("Hint: Run with nix-shell for PortAudio support")
else:
print("Loading demo audio...")
audio_file = hf_hub_download(
repo_id="CohereLabs/cohere-transcribe-03-2026",
filename="demo/voxpopuli_test_en_demo.wav",
)
audio = load_audio(audio_file, sampling_rate=16000)
print("Transcribing...") def main():
text = transcribe_audio(audio) parser = argparse.ArgumentParser(description="Cohere ASR Transcription")
print(f"\nTranscription:\n{text}\n") group = parser.add_mutually_exclusive_group()
group.add_argument("--mic", type=int, nargs="?", const=5, metavar="SECONDS",
help="Record from microphone for N seconds (default: 5)")
group.add_argument("--stream", action="store_true",
help="Live streaming transcription with VAD")
parser.add_argument("--lang", default="en", help="Language code (default: en)")
args = parser.parse_args()
if args.stream:
processor, model = load_model()
stream_transcribe(processor, model, args.lang)
elif args.mic is not None:
processor, model = load_model()
try:
mic_audio = record_audio(args.mic)
print("Transcribing...")
text = transcribe_audio(processor, model, mic_audio, args.lang)
print(f"\nTranscription:\n{text}\n")
except OSError as e:
print(f"Microphone error: {e}")
print("Hint: Run with nix-shell for PortAudio support")
else:
processor, model = load_model()
print("Loading demo audio...")
audio_file = hf_hub_download(repo_id=MODEL_ID, filename="demo/voxpopuli_test_en_demo.wav")
audio = load_audio(audio_file, sampling_rate=SAMPLE_RATE)
print("Transcribing...")
text = transcribe_audio(processor, model, audio, args.lang)
print(f"\nTranscription:\n{text}\n")
def stream_transcribe(processor, model, language):
print("TODO: streaming mode")
if __name__ == "__main__":
main()