From 853b5523e52b7cb6be9c3f3adc622027d8ef0fa8 Mon Sep 17 00:00:00 2001 From: Wong Ding Feng Date: Sat, 6 Jun 2026 22:51:06 +0800 Subject: [PATCH] feat: add --device flag and devices command for mic selection Lets the user pick an input device by index or name substring. Adds `cohere devices` for listing. For devices that don't support 16kHz natively (e.g. Sipeed MicArray hw at 48kHz), captures at the device's native rate and resamples to 16kHz via scipy.signal.resample_poly. Also changes the VAD segment filter: the MIN_SPEECH_SECONDS (0.3s) duration check is replaced by MIN_LOUD_FRAMES, so a segment is emitted only when at least 8 loud frames (~400ms) were counted. Co-Authored-By: Claude Opus 4.7 --- src/cohere_transcribe/cli/cli.py | 42 ++++++++++++++++++++-- src/cohere_transcribe/daemon.py | 25 +++++++++---- src/cohere_transcribe/daemon_main.py | 13 ++++++- src/cohere_transcribe/model.py | 8 +++-- src/cohere_transcribe/stream.py | 16 +++++---- src/cohere_transcribe/vad.py | 52 +++++++++++++++++++++++++--- 6 files changed, 132 insertions(+), 24 deletions(-) diff --git a/src/cohere_transcribe/cli/cli.py b/src/cohere_transcribe/cli/cli.py index d4561a1..154a6b5 100644 --- a/src/cohere_transcribe/cli/cli.py +++ b/src/cohere_transcribe/cli/cli.py @@ -12,10 +12,20 @@ app = typer.Typer(help="Cohere live transcription — speaks into your keyboard. 
console = Console() +def _parse_device(value: str | None): + if value is None: + return None + try: + return int(value) + except ValueError: + return value + + @app.command() def on( language: str = typer.Option("en", "--lang", "-l", help="Language code"), pause: float = typer.Option(0.3, "--pause", "-p", help="Seconds of silence before sending text"), + device: str = typer.Option(None, "--device", "-d", help="Input device index or name substring (see `cohere devices`)"), foreground: bool = typer.Option(False, "--fg", help="Run in foreground (don't daemonize)"), ): """Start transcribing and typing into your focused window.""" @@ -26,7 +36,7 @@ def on( if foreground: from ..daemon import run_daemon console.print("[green]Starting cohere (foreground)...[/green]") - run_daemon(language, pause=pause) + run_daemon(language, pause=pause, device=_parse_device(device)) return console.print("[green]Starting cohere daemon...[/green]") @@ -34,6 +44,8 @@ def on( cmd = [sys.executable, "-m", "cohere_transcribe.daemon_main", "--lang", language] if pause != 0.3: cmd += ["--pause", str(pause)] + if device is not None: + cmd += ["--device", device] subprocess.Popen( cmd, start_new_session=True, @@ -90,20 +102,23 @@ def transcribe( stream: bool = typer.Option(False, "--stream", "-s", help="Live streaming mode (prints to terminal)"), language: str = typer.Option("en", "--lang", "-l", help="Language code"), pause: float = typer.Option(0.3, "--pause", "-p", help="Seconds of silence before sending text"), + device: str = typer.Option(None, "--device", "-d", help="Input device index or name substring (see `cohere devices`)"), ): """One-shot transcription (file, mic, or stream to terminal).""" from ..model import load_model, transcribe_audio from ..vad import pause_seconds_to_frames + dev = _parse_device(device) + if stream: from ..stream import stream_transcribe processor, model = load_model() - stream_transcribe(processor, model, language, silence_frames=pause_seconds_to_frames(pause)) + 
stream_transcribe(processor, model, language, silence_frames=pause_seconds_to_frames(pause), device=dev) elif mic is not None: from ..model import record_audio processor, model = load_model() try: - audio = record_audio(mic) + audio = record_audio(mic, device=dev) console.print("Transcribing...") text = transcribe_audio(processor, model, audio, language) console.print(f"\n{text}\n") @@ -122,5 +137,26 @@ def transcribe( raise typer.Exit(1) +@app.command() +def devices(): + """List available audio input devices.""" + import sounddevice as sd + + default_in = sd.default.device[0] + for idx, dev in enumerate(sd.query_devices()): + if dev["max_input_channels"] <= 0: + continue + marker = "[green]*[/green]" if idx == default_in else " " + hostapi = sd.query_hostapis(dev["hostapi"])["name"] + console.print( + f"{marker} [bold]{idx:>2}[/bold] {dev['name']} " + f"[dim]({dev['max_input_channels']}ch, {int(dev['default_samplerate'])}Hz, {hostapi})[/dim]" + ) + console.print( + "\n[dim]Tip: indices can shift between runs on PipeWire. 
" + "Prefer [bold]-d pipewire[/bold] (uses PipeWire's default source) or pass a name substring like [bold]-d Sipeed[/bold].[/dim]" + ) + + def main(): app() diff --git a/src/cohere_transcribe/daemon.py b/src/cohere_transcribe/daemon.py index 131e7c2..82aff5c 100644 --- a/src/cohere_transcribe/daemon.py +++ b/src/cohere_transcribe/daemon.py @@ -11,7 +11,16 @@ import sounddevice as sd from .backend import WtypeBackend from .commands import process_and_output from .model import SAMPLE_RATE, load_model, transcribe_audio -from .vad import DEFAULT_SILENCE_FRAMES, FRAME_SIZE, VADStateMachine, calibrate_silence, pause_seconds_to_frames +from .vad import ( + DEFAULT_SILENCE_FRAMES, + FRAME_SIZE, + VADStateMachine, + calibrate_silence, + describe_input_device, + pause_seconds_to_frames, + resample_to_target, + resolve_input_rate, +) STATE_DIR = os.path.expanduser("~/.local/state/cohere") STATE_FILE = os.path.join(STATE_DIR, "state.json") @@ -71,7 +80,7 @@ def stop_daemon() -> bool: return False -def run_daemon(language: str = "en", pause: float | None = None): +def run_daemon(language: str = "en", pause: float | None = None, device=None): pid = os.getpid() _write_state(pid, "starting") @@ -82,7 +91,10 @@ def run_daemon(language: str = "en", pause: float | None = None): silence_frames = pause_seconds_to_frames(pause) if pause else DEFAULT_SILENCE_FRAMES processor, model = load_model() - threshold = calibrate_silence() + print(f"Using input device: {describe_input_device(device)}") + threshold = calibrate_silence(device=device) + capture_rate = resolve_input_rate(device) + capture_blocksize = FRAME_SIZE * capture_rate // SAMPLE_RATE vad = VADStateMachine(threshold, silence_frames=silence_frames) seg_queue: queue.Queue = queue.Queue() stop_event = threading.Event() @@ -108,13 +120,14 @@ def run_daemon(language: str = "en", pause: float | None = None): if stop_event.is_set(): return elapsed = time.monotonic() - start_time - result = vad.process_frame(indata[:, 0].copy(), 
elapsed) + frame = resample_to_target(indata[:, 0].copy(), capture_rate) + result = vad.process_frame(frame, elapsed) if result is not None: seg_queue.put(result) stream = sd.InputStream( - samplerate=SAMPLE_RATE, channels=1, dtype="float32", - callback=audio_callback, blocksize=FRAME_SIZE, + samplerate=capture_rate, channels=1, dtype="float32", + callback=audio_callback, blocksize=capture_blocksize, device=device, ) try: diff --git a/src/cohere_transcribe/daemon_main.py b/src/cohere_transcribe/daemon_main.py index 0d21305..7fca2c3 100644 --- a/src/cohere_transcribe/daemon_main.py +++ b/src/cohere_transcribe/daemon_main.py @@ -2,8 +2,19 @@ import argparse from .daemon import run_daemon + +def _parse_device(value): + if value is None: + return None + try: + return int(value) + except ValueError: + return value + + parser = argparse.ArgumentParser() parser.add_argument("--lang", default="en") parser.add_argument("--pause", type=float, default=None) +parser.add_argument("--device", default=None) args = parser.parse_args() -run_daemon(args.lang, pause=args.pause) +run_daemon(args.lang, pause=args.pause, device=_parse_device(args.device)) diff --git a/src/cohere_transcribe/model.py b/src/cohere_transcribe/model.py index 98938f8..128c290 100644 --- a/src/cohere_transcribe/model.py +++ b/src/cohere_transcribe/model.py @@ -23,10 +23,12 @@ def transcribe_audio(processor, model, audio, language="en"): return " ".join(texts) if isinstance(texts, list) else texts -def record_audio(duration): +def record_audio(duration, device=None): import sounddevice as sd + from .vad import resolve_input_rate, resample_to_target print(f"Recording for {duration} seconds...") - audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32") + rate = resolve_input_rate(device) + audio = sd.rec(int(duration * rate), samplerate=rate, channels=1, dtype="float32", device=device) sd.wait() - return audio.flatten() + return resample_to_target(audio.flatten(), rate) 
diff --git a/src/cohere_transcribe/stream.py b/src/cohere_transcribe/stream.py index 7040941..a75fee4 100644 --- a/src/cohere_transcribe/stream.py +++ b/src/cohere_transcribe/stream.py @@ -7,11 +7,14 @@ import numpy as np import sounddevice as sd from .model import SAMPLE_RATE, transcribe_audio -from .vad import DEFAULT_SILENCE_FRAMES, FRAME_SIZE, VADStateMachine, calibrate_silence +from .vad import DEFAULT_SILENCE_FRAMES, FRAME_SIZE, VADStateMachine, calibrate_silence, describe_input_device, resample_to_target, resolve_input_rate -def stream_transcribe(processor, model, language, silence_frames=DEFAULT_SILENCE_FRAMES): - threshold = calibrate_silence() +def stream_transcribe(processor, model, language, silence_frames=DEFAULT_SILENCE_FRAMES, device=None): + print(f"Using input device: {describe_input_device(device)}") + threshold = calibrate_silence(device=device) + capture_rate = resolve_input_rate(device) + capture_blocksize = FRAME_SIZE * capture_rate // SAMPLE_RATE vad = VADStateMachine(threshold, silence_frames=silence_frames) seg_queue = queue.Queue() stop_event = threading.Event() @@ -36,14 +39,15 @@ def stream_transcribe(processor, model, language, silence_frames=DEFAULT_SILENCE if stop_event.is_set(): return elapsed = time.monotonic() - start_time - result = vad.process_frame(indata[:, 0].copy(), elapsed) + frame = resample_to_target(indata[:, 0].copy(), capture_rate) + result = vad.process_frame(frame, elapsed) if result is not None: seg_queue.put(result) print("Listening... 
(Ctrl+C to stop)") stream = sd.InputStream( - samplerate=SAMPLE_RATE, channels=1, dtype="float32", - callback=audio_callback, blocksize=FRAME_SIZE, + samplerate=capture_rate, channels=1, dtype="float32", + callback=audio_callback, blocksize=capture_blocksize, device=device, ) try: diff --git a/src/cohere_transcribe/vad.py b/src/cohere_transcribe/vad.py index e4168ec..0306c31 100644 --- a/src/cohere_transcribe/vad.py +++ b/src/cohere_transcribe/vad.py @@ -1,4 +1,5 @@ import collections +from math import gcd import numpy as np import sounddevice as sd @@ -10,17 +11,56 @@ PRE_ROLL_FRAMES = 6 # ~0.3s of audio before speech onset DEFAULT_SILENCE_FRAMES = 16 # ~0.8s of silence to end a segment SPEECH_ONSET_FRAMES = 3 # ~150ms of speech to trigger MAX_SPEECH_SECONDS = 30 # force chunk boundary -MIN_SPEECH_SECONDS = 0.3 # discard segments shorter than this (mic bumps, clicks) +MIN_LOUD_FRAMES = 8 # need at least ~400ms of loud frames to count as speech def pause_seconds_to_frames(seconds: float) -> int: return max(1, round(seconds / (FRAME_SIZE / SAMPLE_RATE))) -def calibrate_silence(duration=0.5): +def _query_input(device): + resolved = device if device is not None else sd.default.device[0] + info = sd.query_devices(resolved) + if info["max_input_channels"] < 1: + raise ValueError( + f"Device {device!r} ({info['name']}) is not an input device. " + f"Run `cohere devices` to see current input indices — they can shift between runs on PipeWire." + ) + return info + + +def describe_input_device(device) -> str: + info = _query_input(device) + return info["name"] + + +def resolve_input_rate(device) -> int: + """Pick a samplerate the device will accept. Prefer SAMPLE_RATE; if the device + refuses (e.g. 
raw ALSA hw: that doesn't resample), fall back to its native rate.""" + info = _query_input(device) + try: + sd.check_input_settings(device=device, samplerate=SAMPLE_RATE, channels=1, dtype="float32") + return SAMPLE_RATE + except sd.PortAudioError: + rate = int(info["default_samplerate"]) + print(f" Device doesn't support {SAMPLE_RATE}Hz; capturing at {rate}Hz and resampling.") + return rate + + +def resample_to_target(audio: np.ndarray, src_rate: int) -> np.ndarray: + if src_rate == SAMPLE_RATE: + return audio + from scipy.signal import resample_poly + g = gcd(SAMPLE_RATE, src_rate) + return resample_poly(audio, SAMPLE_RATE // g, src_rate // g).astype(np.float32) + + +def calibrate_silence(duration=0.5, device=None): print("Calibrating silence threshold...") - audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32") + rate = resolve_input_rate(device) + audio = sd.rec(int(duration * rate), samplerate=rate, channels=1, dtype="float32", device=device) sd.wait() + audio = resample_to_target(audio.flatten(), rate) rms = np.sqrt(np.mean(audio ** 2)) threshold = max(rms * 3, 0.01) print(f" Ambient RMS: {rms:.4f}, threshold: {threshold:.4f}") @@ -34,6 +74,7 @@ class VADStateMachine: self.speaking = False self.speech_frames = 0 self.silence_frames = 0 + self.loud_frames = 0 self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) self.segment = [] self.segment_start_time = 0.0 @@ -63,18 +104,19 @@ class VADStateMachine: if is_loud: self.silence_frames = 0 + self.loud_frames += 1 else: self.silence_frames += 1 segment_duration = len(self.segment) * FRAME_SIZE / SAMPLE_RATE if self.silence_frames >= self.silence_limit or segment_duration >= MAX_SPEECH_SECONDS: - speech_duration = segment_duration - self.silence_frames * FRAME_SIZE / SAMPLE_RATE result = None - if speech_duration >= MIN_SPEECH_SECONDS: + if self.loud_frames >= MIN_LOUD_FRAMES: result = (self.segment_start_time, np.concatenate(self.segment)) self.speaking = False 
self.speech_frames = 0 self.silence_frames = 0 + self.loud_frames = 0 self.segment = [] self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) return result