Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configure pylint and fix linter violations #179

Merged
merged 22 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7896847
Add workaround for setuptools bug
mattt Oct 29, 2023
83aa9ce
Extract setup step into script
mattt Oct 29, 2023
bdc79c9
Extract lint step into script
mattt Oct 29, 2023
0ff1258
Add script/format
mattt Oct 29, 2023
ee51121
Add pylint as optional dev dependency
mattt Oct 29, 2023
2372d43
Add pylint configuration to pyproject.toml
mattt Oct 29, 2023
8fc2dcf
Fix W4905: Using deprecated decorator abc.abstractproperty()
mattt Oct 29, 2023
63762f5
Fix W0719: Raising too general exception: Exception
mattt Oct 29, 2023
d049225
Fix W0621: Redefining name 'version' from outer scope (line 3)
mattt Oct 29, 2023
c503629
Fix E1101: Instance of 'ModelPrivateAttr' has no member
mattt Oct 29, 2023
c334577
Fix C0103: Name doesn't conform to snake_case naming style
mattt Oct 29, 2023
a922c27
Ignore C0103: Variable name doesn't conform to snake_case naming styl…
mattt Oct 29, 2023
bb6c526
Fix R1705: Unnecessary elif after return, remove the leading el from …
mattt Oct 29, 2023
d6ebe9f
Fix C0103: Constant name doesn't conform to UPPER_CASE naming style (…
mattt Oct 29, 2023
e92fa74
Fix C0103: Variable name doesn't conform to snake_case naming style (…
mattt Oct 29, 2023
eaf4747
Ignore R0911: Too many return statements (too-many-return-statements)
mattt Oct 29, 2023
3344a2a
Fix C0116: Missing function or method docstring (missing-function-doc…
mattt Oct 29, 2023
aa21664
Ignore C0116: Missing function or method docstring (missing-function-…
mattt Oct 29, 2023
18b7ad3
Fix C0115: Missing class docstring (missing-class-docstring)
mattt Oct 29, 2023
2a7e235
Run pylint in script/lint
mattt Oct 29, 2023
0266320
Extract test step into script
mattt Oct 29, 2023
b80a9d4
Run lint after test step
mattt Oct 29, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Install dependencies
run: |
python -m pip install -r requirements.txt -r requirements-dev.txt .
yes | python -m mypy --install-types replicate || true

- name: Lint
run: |
python -m mypy replicate
python -m ruff .
python -m ruff format --check .
- name: Setup
run: ./script/setup

- name: Test
run: python -m pytest
run: ./script/test

- name: Lint
run: ./script/lint

15 changes: 15 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ requires-python = ">=3.8"
dependencies = ["packaging", "pydantic>1", "httpx>=0.21.0,<1"]
optional-dependencies = { dev = [
"mypy",
"pylint",
"pytest",
"pytest-asyncio",
"pytest-recording",
Expand All @@ -27,13 +28,27 @@ repository = "https://github.com/replicate/replicate-python"
[tool.pytest.ini_options]
testpaths = "tests/"

[tool.setuptools]
# See https://github.com/pypa/setuptools/issues/3197#issuecomment-1078770109
py-modules = []

[tool.setuptools.package-data]
"replicate" = ["py.typed"]

[tool.mypy]
plugins = "pydantic.mypy"
exclude = ["tests/"]

[tool.pylint.main]
disable = [
"C0301", # Line too long
"C0413", # Import should be placed at the top of the module
"C0114", # Missing module docstring
"R0801", # Similar lines in N files
"W0212", # Access to a protected member
"W0622", # Redefining built-in
]

[tool.ruff]
select = [
"E", # pycodestyle error
Expand Down
5 changes: 3 additions & 2 deletions replicate/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def reload(self) -> None:
"""
Load this object from the server again.
"""
new_model = self._collection.get(self.id)
for k, v in new_model.dict().items():

new_model = self._collection.get(self.id) # pylint: disable=no-member
for k, v in new_model.dict().items(): # pylint: disable=invalid-name
setattr(self, k, v)
12 changes: 12 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,30 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:

@property
def models(self) -> ModelCollection:
"""
Namespace for operations related to models.
"""
return ModelCollection(client=self)

@property
def predictions(self) -> PredictionCollection:
"""
Namespace for operations related to predictions.
"""
return PredictionCollection(client=self)

@property
def trainings(self) -> TrainingCollection:
"""
Namespace for operations related to trainings.
"""
return TrainingCollection(client=self)

@property
def deployments(self) -> DeploymentCollection:
"""
Namespace for operations related to deployments.
"""
return DeploymentCollection(client=self)

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
Expand Down
23 changes: 12 additions & 11 deletions replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from replicate.client import Client

from replicate.base_model import BaseModel
from replicate.exceptions import ReplicateException

Model = TypeVar("Model", bound=BaseModel)

Expand All @@ -17,20 +18,21 @@ class Collection(abc.ABC, Generic[Model]):
def __init__(self, client: "Client") -> None:
self._client = client

@abc.abstractproperty
def model(self) -> Model:
@property
@abc.abstractmethod
def model(self) -> Model: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def list(self) -> List[Model]:
def list(self) -> List[Model]: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def get(self, key: str) -> Model:
def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def create(self, **kwargs) -> Model:
def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring
pass

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
Expand All @@ -41,13 +43,12 @@ def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
attrs._client = self._client
attrs._collection = self
return cast(Model, attrs)
elif (
isinstance(attrs, dict) and self.model is not None and callable(self.model)
):

if isinstance(attrs, dict) and self.model is not None and callable(self.model):
model = self.model(**attrs)
model._client = self._client
model._collection = self
return model
else:
name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
raise Exception(f"Can't create {name} from {attrs}")

name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
raise ReplicateException(f"Can't create {name} from {attrs}")
22 changes: 22 additions & 0 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,19 @@ def predictions(self) -> "DeploymentPredictionCollection":


class DeploymentCollection(Collection):
"""
Namespace for operations related to deployments.
"""

model = Deployment

def list(self) -> List[Deployment]:
"""
List deployments.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()

def get(self, name: str) -> Deployment:
Expand All @@ -56,6 +66,12 @@ def get(self, name: str) -> Deployment:
return self.prepare_model({"username": username, "name": name})

def create(self, **kwargs) -> Deployment:
"""
Create a deployment.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()

def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
Expand All @@ -74,6 +90,12 @@ def __init__(self, client: "Client", deployment: Deployment) -> None:
self._deployment = deployment

def list(self) -> List[Prediction]:
"""
List predictions in a deployment.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()

def get(self, id: str) -> Prediction:
Expand Down
32 changes: 15 additions & 17 deletions replicate/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,34 @@
import httpx


def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
"""
Upload a file to the server.

Args:
fh: A file handle to upload.
file: A file handle to upload.
output_file_prefix: A string to prepend to the output file name.
Returns:
str: A URL to the uploaded file.
"""
# Lifted straight from cog.files

fh.seek(0)
file.seek(0)

if output_file_prefix is not None:
name = getattr(fh, "name", "output")
name = getattr(file, "name", "output")
url = output_file_prefix + os.path.basename(name)
resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
resp.raise_for_status()

return url

b = fh.read()
# The file handle is strings, not bytes
if isinstance(b, str):
b = b.encode("utf-8")
encoded_body = base64.b64encode(b)
if getattr(fh, "name", None):
# despite doing a getattr check here, mypy complains that io.IOBase has no attribute name
mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore
else:
mime_type = "application/octet-stream"
s = encoded_body.decode("utf-8")
return f"data:{mime_type};base64,{s}"
body = file.read()
# Ensure the file handle is in bytes
body = body.encode("utf-8") if isinstance(body, str) else body
encoded_body = base64.b64encode(body).decode("utf-8")
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
mime_type = (
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
)
return f"data:{mime_type};base64,{encoded_body}"
11 changes: 6 additions & 5 deletions replicate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
try:
import numpy as np # type: ignore

has_numpy = True
HAS_NUMPY = True
except ImportError:
has_numpy = False
HAS_NUMPY = False


# pylint: disable=too-many-return-statements
def encode_json(
obj: Any, # noqa: ANN401
upload_file: Callable[[io.IOBase], str],
Expand All @@ -25,11 +26,11 @@ def encode_json(
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [encode_json(value, upload_file) for value in obj]
if isinstance(obj, Path):
with obj.open("rb") as f:
return upload_file(f)
with obj.open("rb") as file:
return upload_file(file)
if isinstance(obj, io.IOBase):
return upload_file(obj)
if has_numpy:
if HAS_NUMPY:
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
Expand Down
10 changes: 10 additions & 0 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def versions(self) -> VersionCollection:


class ModelCollection(Collection):
"""
Namespace for operations related to models.
"""

model = Model

def list(self) -> List[Model]:
Expand Down Expand Up @@ -136,6 +140,12 @@ def get(self, key: str) -> Model:
return self.prepare_model(resp.json())

def create(self, **kwargs) -> Model:
"""
Create a model.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
Expand Down
10 changes: 7 additions & 3 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def wait(self) -> None:
Wait for prediction to finish.
"""
while self.status not in ["succeeded", "failed", "canceled"]:
time.sleep(self._client.poll_interval)
time.sleep(self._client.poll_interval) # pylint: disable=no-member
self.reload()

def output_iterator(self) -> Iterator[Any]:
Expand All @@ -114,7 +114,7 @@ def output_iterator(self) -> Iterator[Any]:
new_output = output[len(previous_output) :]
yield from new_output
previous_output = output
time.sleep(self._client.poll_interval)
time.sleep(self._client.poll_interval) # pylint: disable=no-member
self.reload()

if self.status == "failed":
Expand All @@ -129,10 +129,14 @@ def cancel(self) -> None:
"""
Cancels a running prediction.
"""
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")
self._client._request("POST", f"/v1/predictions/{self.id}/cancel") # pylint: disable=no-member


class PredictionCollection(Collection):
"""
Namespace for operations related to predictions.
"""

model = Prediction

def list(self) -> List[Prediction]:
Expand Down
4 changes: 2 additions & 2 deletions replicate/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]:

def make_schema_backwards_compatible(
schema: dict,
version: str,
cog_version: str,
) -> dict:
"""A place to add backwards compatibility logic for our openapi schema"""

# If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
if version_has_no_array_type(version):
if version_has_no_array_type(cog_version):
output = schema["components"]["schemas"]["Output"]
if output.get("type") == "array":
output["x-cog-array-type"] = "iterator"
Expand Down
6 changes: 5 additions & 1 deletion replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ class Training(BaseModel):

def cancel(self) -> None:
"""Cancel a running training"""
self._client._request("POST", f"/v1/trainings/{self.id}/cancel")
self._client._request("POST", f"/v1/trainings/{self.id}/cancel") # pylint: disable=no-member


class TrainingCollection(Collection):
"""
Namespace for operations related to trainings.
"""

model = Training

def list(self) -> List[Training]:
Expand Down
Loading