Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
shirayu committed Oct 17, 2022
1 parent 62b6d9a commit 40dd471
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 34 deletions.
39 changes: 24 additions & 15 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,42 @@
#!/usr/bin/env python3


import sys
from unittest.mock import patch

from pydantic import BaseModel

from whispering.cli import get_opts, is_valid_arg
from whispering.cli import Mode, is_valid_arg


class ArgExample(BaseModel):
mode: Mode
cmd: str
ok: bool


def test_options():

exs = [
ArgExample(cmd="--mode server --mic 0", ok=False),
ArgExample(cmd="--mode server --mic 1", ok=False),
ArgExample(cmd="--mode server --beam_size 3", ok=False),
ArgExample(cmd="--mode server --temperature 0", ok=False),
ArgExample(cmd="--mode server --num_block 3", ok=False),
ArgExample(cmd="--mode mic --host 0.0.0.0", ok=False),
ArgExample(cmd="--mode mic --port 8000", ok=False),
ArgExample(mode=Mode.server, cmd="--mic 0", ok=False),
ArgExample(mode=Mode.server, cmd="--mic 1", ok=False),
ArgExample(
mode=Mode.server,
cmd="--host 0.0.0.0 --port 8000",
ok=True,
),
ArgExample(
mode=Mode.server,
cmd="--language en --model tiny --host 0.0.0.0 --port 8000",
ok=True,
),
ArgExample(mode=Mode.server, cmd="--beam_size 3", ok=False),
ArgExample(mode=Mode.server, cmd="--temperature 0", ok=False),
ArgExample(mode=Mode.server, cmd="--num_block 3", ok=False),
ArgExample(mode=Mode.mic, cmd="--host 0.0.0.0", ok=False),
ArgExample(mode=Mode.mic, cmd="--port 8000", ok=False),
]

for ex in exs:
with patch.object(sys, "argv", [""] + ex.cmd.split()):
opts = get_opts()
ok = is_valid_arg(opts)
assert ok is ex.ok, f"{ex.cmd} should be invalid"
ok = is_valid_arg(
mode=ex.mode.value,
args=ex.cmd.split(),
)
assert ok is ex.ok, f"{ex.cmd} should be {ex.ok}"
53 changes: 34 additions & 19 deletions whispering/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from logging import DEBUG, INFO, basicConfig, getLogger
from pathlib import Path
from typing import Iterator, Optional, Union
from typing import Iterator, List, Optional, Union

import sounddevice as sd
import torch
Expand Down Expand Up @@ -240,24 +240,36 @@ def show_devices():
print(f"{i}: {device['name']}")


def is_valid_arg(opts) -> bool:
def is_valid_arg(
*,
args: List[str],
mode: str,
) -> bool:
keys = []
if opts.mode == Mode.server.value:
keys = [
"mic",
"beam_size",
"temperature",
]
elif opts.mode == Mode.mic.value:
keys = [
"host",
"port",
]

for key in keys:
_val = vars(opts).get(key)
if _val is not None and _val is not False:
sys.stderr.write(f"{key} is not accepted option for {opts.mode} mode\n")
if mode == Mode.server.value:
keys = {
"--mic",
"--beam_size",
"-b",
"--temperature",
"-t",
"--num_block",
"-n",
"--vad",
"--max_nospeech_skip",
"--output",
"--show-devices",
"--no-progress",
}
elif mode == Mode.mic.value:
keys = {
"--host",
"--port",
}

for arg in args:
if arg in keys:
sys.stderr.write(f"{arg} is not accepted option for {mode} mode\n")
return False
return True

Expand All @@ -280,7 +292,10 @@ def main() -> None:
):
opts.mode = Mode.server.value

if not is_valid_arg(opts):
if not is_valid_arg(
args=sys.argv[1:],
mode=opts.mode,
):
sys.exit(1)

if opts.mode == Mode.client.value:
Expand Down

0 comments on commit 40dd471

Please sign in to comment.