diff --git a/transcribe.py b/transcribe.py
index 4f8c097..45f373f 100644
--- a/transcribe.py
+++ b/transcribe.py
@@ -27,7 +27,8 @@ def transcribe_audio(processor, model, audio, language="en"):
     inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", language=language)
     inputs.to(model.device, dtype=model.dtype)
     outputs = model.generate(**inputs, max_new_tokens=256)
-    return processor.decode(outputs, skip_special_tokens=True)
+    texts = processor.decode(outputs, skip_special_tokens=True)
+    return " ".join(texts) if isinstance(texts, list) else texts
 
 
 def record_audio(duration):