diff --git a/.gitignore b/.gitignore index 505a3b1..72a0047 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,7 @@ wheels/ # Virtual environments .venv + +# Nix +.direnv/ +result diff --git a/main.py b/main.py deleted file mode 100644 index 83d671f..0000000 --- a/main.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from cohere!") - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 7cd40d0..bebb5c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "cohere" +name = "cohere-transcribe" version = "0.1.0" -description = "Add your description here" +description = "Live speech transcription using Cohere ASR" readme = "README.md" requires-python = ">=3.14" dependencies = [ @@ -15,3 +15,13 @@ dependencies = [ "torch>=2.12.0", "transformers>=5.9.0", ] + +[project.scripts] +cohere-transcribe = "cohere_transcribe.cli:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/cohere_transcribe"] diff --git a/shell.nix b/shell.nix deleted file mode 100644 index 28ed63a..0000000 --- a/shell.nix +++ /dev/null @@ -1,15 +0,0 @@ -{ pkgs ? 
import { config.allowUnfree = true; } }: - -pkgs.mkShell { - buildInputs = with pkgs; [ - portaudio - cudaPackages.cudatoolkit - uv - python314 - ]; - - shellHook = '' - export LD_LIBRARY_PATH="${pkgs.cudaPackages.cudatoolkit}/lib:$LD_LIBRARY_PATH" - echo "Dev shell ready - microphone input enabled" - ''; -} diff --git a/src/cohere_transcribe/__init__.py b/src/cohere_transcribe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cohere_transcribe/cli.py b/src/cohere_transcribe/cli.py new file mode 100644 index 0000000..1c4d581 --- /dev/null +++ b/src/cohere_transcribe/cli.py @@ -0,0 +1,40 @@ +import argparse + +from huggingface_hub import hf_hub_download +from transformers.audio_utils import load_audio + +from .model import MODEL_ID, SAMPLE_RATE, load_model, record_audio, transcribe_audio +from .stream import stream_transcribe + + +def main(): + parser = argparse.ArgumentParser(description="Cohere ASR Transcription") + group = parser.add_mutually_exclusive_group() + group.add_argument("--mic", type=int, nargs="?", const=5, metavar="SECONDS", + help="Record from microphone for N seconds (default: 5)") + group.add_argument("--stream", action="store_true", + help="Live streaming transcription with VAD") + parser.add_argument("--lang", default="en", help="Language code (default: en)") + args = parser.parse_args() + + if args.stream: + processor, model = load_model() + stream_transcribe(processor, model, args.lang) + elif args.mic is not None: + processor, model = load_model() + try: + mic_audio = record_audio(args.mic) + print("Transcribing...") + text = transcribe_audio(processor, model, mic_audio, args.lang) + print(f"\nTranscription:\n{text}\n") + except OSError as e: + print(f"Microphone error: {e}") + print("Hint: Run with nix-shell for PortAudio support") + else: + processor, model = load_model() + print("Loading demo audio...") + audio_file = hf_hub_download(repo_id=MODEL_ID, filename="demo/voxpopuli_test_en_demo.wav") + audio = 
load_audio(audio_file, sampling_rate=SAMPLE_RATE) + print("Transcribing...") + text = transcribe_audio(processor, model, audio, args.lang) + print(f"\nTranscription:\n{text}\n") diff --git a/src/cohere_transcribe/model.py b/src/cohere_transcribe/model.py new file mode 100644 index 0000000..98938f8 --- /dev/null +++ b/src/cohere_transcribe/model.py @@ -0,0 +1,32 @@ +import numpy as np +from transformers import AutoProcessor, CohereAsrForConditionalGeneration +from transformers.audio_utils import load_audio + +MODEL_ID = "CohereLabs/cohere-transcribe-03-2026" +SAMPLE_RATE = 16000 + + +def load_model(): + print("Loading model...") + processor = AutoProcessor.from_pretrained(MODEL_ID) + model = CohereAsrForConditionalGeneration.from_pretrained( + MODEL_ID, device_map="auto" + ) + return processor, model + + +def transcribe_audio(processor, model, audio, language="en"): + inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", language=language) + inputs.to(model.device, dtype=model.dtype) + outputs = model.generate(**inputs, max_new_tokens=256) + texts = processor.decode(outputs, skip_special_tokens=True) + return " ".join(texts) if isinstance(texts, list) else texts + + +def record_audio(duration): + import sounddevice as sd + + print(f"Recording for {duration} seconds...") + audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32") + sd.wait() + return audio.flatten() diff --git a/src/cohere_transcribe/stream.py b/src/cohere_transcribe/stream.py new file mode 100644 index 0000000..5f06412 --- /dev/null +++ b/src/cohere_transcribe/stream.py @@ -0,0 +1,64 @@ +import sys +import queue +import threading +import time + +import numpy as np +import sounddevice as sd + +from .model import SAMPLE_RATE, transcribe_audio +from .vad import FRAME_SIZE, VADStateMachine, calibrate_silence + + +def stream_transcribe(processor, model, language): + threshold = calibrate_silence() + vad = VADStateMachine(threshold) + seg_queue = 
queue.Queue() + stop_event = threading.Event() + start_time = time.monotonic() + + def transcription_worker(): + while not stop_event.is_set() or not seg_queue.empty(): + try: + seg_start, audio = seg_queue.get(timeout=0.5) + except queue.Empty: + continue + minutes = int(seg_start) // 60 + seconds = int(seg_start) % 60 + text = transcribe_audio(processor, model, audio, language) + if text.strip(): + print(f"[{minutes:02d}:{seconds:02d}] {text.strip()}") + + worker = threading.Thread(target=transcription_worker, daemon=True) + worker.start() + + def audio_callback(indata, frames, time_info, status): + if stop_event.is_set(): + return + elapsed = time.monotonic() - start_time + result = vad.process_frame(indata[:, 0].copy(), elapsed) + if result is not None: + seg_queue.put(result) + + print("Listening... (Ctrl+C to stop)") + stream = sd.InputStream( + samplerate=SAMPLE_RATE, channels=1, dtype="float32", + callback=audio_callback, blocksize=FRAME_SIZE, + ) + + try: + with stream: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + pass + + # Enqueue the unfinished segment BEFORE signalling stop: the worker exits once stop is set and the queue is empty, so the old order could drop the final utterance. + if vad.speaking and vad.segment: + seg_queue.put((vad.segment_start_time, np.concatenate(vad.segment))) + + stop_event.set() + worker.join(timeout=30) + if worker.is_alive(): + print("Warning: transcription worker did not finish in time.", file=sys.stderr) + print("\nDone.") diff --git a/src/cohere_transcribe/vad.py b/src/cohere_transcribe/vad.py new file mode 100644 index 0000000..e292c22 --- /dev/null +++ b/src/cohere_transcribe/vad.py @@ -0,0 +1,73 @@ +import collections + +import numpy as np +import sounddevice as sd + +from .model import SAMPLE_RATE + +FRAME_SIZE = 800 # 50ms at 16kHz +PRE_ROLL_FRAMES = 6 # ~0.3s of audio before speech onset +SILENCE_FRAMES = 16 # ~0.8s of silence to end a segment +SPEECH_ONSET_FRAMES = 3 # ~150ms of speech to trigger +MAX_SPEECH_SECONDS = 30 # force chunk boundary + + +def calibrate_silence(duration=0.5): + print("Calibrating silence threshold...") + audio = sd.rec(int(duration * 
SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32") + sd.wait() + rms = np.sqrt(np.mean(audio ** 2)) + threshold = max(rms * 3, 0.01) + print(f" Ambient RMS: {rms:.4f}, threshold: {threshold:.4f}") + return threshold + + +class VADStateMachine: + def __init__(self, threshold): + self.threshold = threshold + self.speaking = False + self.speech_frames = 0 + self.silence_frames = 0 + self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) + self.segment = [] + self.segment_start_time = 0.0 + + def process_frame(self, frame, elapsed_time): + """Process one 50ms frame. Returns a (start_time, audio_array) tuple when a + complete speech segment is detected, otherwise None.""" + rms = np.sqrt(np.mean(frame ** 2)) + is_loud = rms > self.threshold + + if not self.speaking: + self.pre_roll.append(frame) + + if is_loud: + self.speech_frames += 1 + if self.speech_frames >= SPEECH_ONSET_FRAMES: + self.speaking = True + self.silence_frames = 0 + self.segment = list(self.pre_roll) + self.segment_start_time = max(0.0, elapsed_time - len(self.pre_roll) * FRAME_SIZE / SAMPLE_RATE) + self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) + else: + self.speech_frames = 0 + return None + + self.segment.append(frame) + + if is_loud: + self.silence_frames = 0 + else: + self.silence_frames += 1 + + segment_duration = len(self.segment) * FRAME_SIZE / SAMPLE_RATE + if self.silence_frames >= SILENCE_FRAMES or segment_duration >= MAX_SPEECH_SECONDS: + result = (self.segment_start_time, np.concatenate(self.segment)) + self.speaking = False + self.speech_frames = 0 + self.silence_frames = 0 + self.segment = [] + self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) + return result + + return None diff --git a/test_mic.py b/tests/test_mic.py similarity index 100% rename from test_mic.py rename to tests/test_mic.py diff --git a/transcribe.py b/transcribe.py deleted file mode 100644 index 45f373f..0000000 --- a/transcribe.py +++ /dev/null @@ -1,200 +0,0 @@ -import sys 
-import argparse -import collections -import queue -import threading -import time -import numpy as np -import sounddevice as sd -from transformers import AutoProcessor, CohereAsrForConditionalGeneration -from transformers.audio_utils import load_audio -from huggingface_hub import hf_hub_download - -MODEL_ID = "CohereLabs/cohere-transcribe-03-2026" -SAMPLE_RATE = 16000 - - -def load_model(): - print("Loading model...") - processor = AutoProcessor.from_pretrained(MODEL_ID) - model = CohereAsrForConditionalGeneration.from_pretrained( - MODEL_ID, device_map="auto" - ) - return processor, model - - -def transcribe_audio(processor, model, audio, language="en"): - inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", language=language) - inputs.to(model.device, dtype=model.dtype) - outputs = model.generate(**inputs, max_new_tokens=256) - texts = processor.decode(outputs, skip_special_tokens=True) - return " ".join(texts) if isinstance(texts, list) else texts - - -def record_audio(duration): - print(f"Recording for {duration} seconds...") - audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32") - sd.wait() - return audio.flatten() - - -def main(): - parser = argparse.ArgumentParser(description="Cohere ASR Transcription") - group = parser.add_mutually_exclusive_group() - group.add_argument("--mic", type=int, nargs="?", const=5, metavar="SECONDS", - help="Record from microphone for N seconds (default: 5)") - group.add_argument("--stream", action="store_true", - help="Live streaming transcription with VAD") - parser.add_argument("--lang", default="en", help="Language code (default: en)") - args = parser.parse_args() - - if args.stream: - processor, model = load_model() - stream_transcribe(processor, model, args.lang) - elif args.mic is not None: - processor, model = load_model() - try: - mic_audio = record_audio(args.mic) - print("Transcribing...") - text = transcribe_audio(processor, model, mic_audio, args.lang) - 
print(f"\nTranscription:\n{text}\n") - except OSError as e: - print(f"Microphone error: {e}") - print("Hint: Run with nix-shell for PortAudio support") - else: - processor, model = load_model() - print("Loading demo audio...") - audio_file = hf_hub_download(repo_id=MODEL_ID, filename="demo/voxpopuli_test_en_demo.wav") - audio = load_audio(audio_file, sampling_rate=SAMPLE_RATE) - print("Transcribing...") - text = transcribe_audio(processor, model, audio, args.lang) - print(f"\nTranscription:\n{text}\n") - - -def calibrate_silence(duration=0.5): - print("Calibrating silence threshold...") - audio = sd.rec(int(duration * SAMPLE_RATE), samplerate=SAMPLE_RATE, channels=1, dtype="float32") - sd.wait() - rms = np.sqrt(np.mean(audio ** 2)) - threshold = max(rms * 3, 0.01) - print(f" Ambient RMS: {rms:.4f}, threshold: {threshold:.4f}") - return threshold - - -FRAME_SIZE = 800 # 50ms at 16kHz -PRE_ROLL_FRAMES = 6 # ~0.3s of audio before speech onset -SILENCE_FRAMES = 16 # ~0.8s of silence to end a segment -SPEECH_ONSET_FRAMES = 3 # ~150ms of speech to trigger -MAX_SPEECH_SECONDS = 30 # force chunk boundary - - -class VADStateMachine: - def __init__(self, threshold): - self.threshold = threshold - self.speaking = False - self.speech_frames = 0 - self.silence_frames = 0 - self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) - self.segment = [] - self.segment_start_time = 0.0 - - def process_frame(self, frame, elapsed_time): - """Process one 50ms frame. 
Returns a (start_time, audio_array) tuple when a - complete speech segment is detected, otherwise None.""" - rms = np.sqrt(np.mean(frame ** 2)) - is_loud = rms > self.threshold - - if not self.speaking: - self.pre_roll.append(frame) - - if is_loud: - self.speech_frames += 1 - if self.speech_frames >= SPEECH_ONSET_FRAMES: - self.speaking = True - self.silence_frames = 0 - self.segment = list(self.pre_roll) - self.segment_start_time = max(0.0, elapsed_time - len(self.pre_roll) * FRAME_SIZE / SAMPLE_RATE) - self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) - else: - self.speech_frames = 0 - return None - - # Currently speaking - self.segment.append(frame) - - if is_loud: - self.silence_frames = 0 - else: - self.silence_frames += 1 - - segment_duration = len(self.segment) * FRAME_SIZE / SAMPLE_RATE - if self.silence_frames >= SILENCE_FRAMES or segment_duration >= MAX_SPEECH_SECONDS: - result = (self.segment_start_time, np.concatenate(self.segment)) - self.speaking = False - self.speech_frames = 0 - self.silence_frames = 0 - self.segment = [] - self.pre_roll = collections.deque(maxlen=PRE_ROLL_FRAMES) - return result - - return None - - -def stream_transcribe(processor, model, language): - threshold = calibrate_silence() - vad = VADStateMachine(threshold) - seg_queue = queue.Queue() - stop_event = threading.Event() - start_time = time.monotonic() - - def transcription_worker(): - while not stop_event.is_set() or not seg_queue.empty(): - try: - seg_start, audio = seg_queue.get(timeout=0.5) - except queue.Empty: - continue - minutes = int(seg_start) // 60 - seconds = int(seg_start) % 60 - text = transcribe_audio(processor, model, audio, language) - if text.strip(): - print(f"[{minutes:02d}:{seconds:02d}] {text.strip()}") - - worker = threading.Thread(target=transcription_worker, daemon=True) - worker.start() - - def audio_callback(indata, frames, time_info, status): - if stop_event.is_set(): - return - elapsed = time.monotonic() - start_time - result = 
vad.process_frame(indata[:, 0].copy(), elapsed) - if result is not None: - seg_queue.put(result) - - print("Listening... (Ctrl+C to stop)") - stream = sd.InputStream( - samplerate=SAMPLE_RATE, channels=1, dtype="float32", - callback=audio_callback, blocksize=FRAME_SIZE, - ) - - try: - with stream: - while True: - time.sleep(0.1) - except KeyboardInterrupt: - pass - - stop_event.set() - - # Flush any remaining speech segment - if vad.speaking and vad.segment: - elapsed = time.monotonic() - start_time - seg_queue.put((vad.segment_start_time, np.concatenate(vad.segment))) - - worker.join(timeout=30) - if worker.is_alive(): - print("Warning: transcription worker did not finish in time.", file=sys.stderr) - print("\nDone.") - - -if __name__ == "__main__": - main()