diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bb9463d..e90abaf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,3 +9,8 @@ repos:
   rev: 22.3.0
   hooks:
     - id: black
+- repo: https://github.com/charliermarsh/ruff-pre-commit
+  rev: v0.0.254
+  hooks:
+    - id: ruff
+      args: [--fix, --exit-non-zero-on-fix]
diff --git a/hugie/config.py b/hugie/config.py
index 6d0a296..2677e84 100644
--- a/hugie/config.py
+++ b/hugie/config.py
@@ -1,7 +1,7 @@
 import srsly
 import typer
 
-from hugie.models import InferenceEndpointConfig
+from hugie.models import EndpointConfig
 
 app = typer.Typer()
 
@@ -26,7 +26,11 @@ def modify(
     framework: str = typer.Option("huggingface", help="Framework to use"),
     image: str = typer.Option(
         None,
-        help="Image to use when deploying model endppint. Must be string representing a valid JSON, e.g. '{'huggingface': {}}'",
+        help=(
+            "Image to use when deploying model endpoint. "
+            "Must be string representing a valid JSON, "
+            "e.g. '{\"huggingface\": {}}'"
+        ),
     ),
     repository: str = typer.Option(None, help="Name of the hf model repository"),
     revision: str = typer.Option(None, help="Revision of the hf model repository"),
@@ -41,7 +45,7 @@ def modify(
     Modify an existing endpoint config file
     """
 
-    config = InferenceEndpointConfig.from_json(path)
+    config = EndpointConfig.from_json(path)
 
     # Standard configs
 
diff --git a/hugie/endpoint.py b/hugie/endpoint.py
index ca7cbb4..8f16067 100644
--- a/hugie/endpoint.py
+++ b/hugie/endpoint.py
@@ -3,7 +3,7 @@
 import requests
 import typer
 
-from hugie.models import InferenceEndpointConfig
+from hugie.models import EndpointConfig
 from hugie.settings import Settings
 from hugie.utils import format_table, load_json
 
@@ -76,7 +76,9 @@ def list(
 
 @app.command()
 def create(
-    data: str = typer.Argument(..., help="Path JSON data to create the endpoint"),
+    data: str = typer.Argument(
+        ..., help="Path or url of a JSON from which to create the endpoint"
+    ),
     json: Optional[bool] = typer.Option(
         None, "--json", help="Prints the full output in JSON."
     ),
@@ -85,10 +87,10 @@ def create(
     Create an endpoint
 
     Args:
-        data (str): Path to JSON data to create the endpoint
+        data (str): Path or url of a JSON from which to create the endpoint.
     """
 
-    data = InferenceEndpointConfig.from_json(data).dict()
+    data = EndpointConfig.from_json(data).dict()
 
     try:
         response = requests.post(settings.endpoint_url, headers=headers, json=data)
@@ -105,7 +107,7 @@ def create(
     elif response.status_code == 401:
         typer.secho("Invalid token", fg=typer.colors.YELLOW)
     elif response.status_code == 409:
-        typer.secho(f"Endpoint {name} already exists", fg=typer.colors.YELLOW)
+        typer.secho(f"Endpoint {data['name']} already exists", fg=typer.colors.YELLOW)
     else:
         typer.secho(
             f"Endpoint {data['name']} created successfully on {data['provider']['vendor']} using {data['model']['repository']}",
@@ -126,7 +128,7 @@ def update(
     """
     Update an endpoint
     """
-    data = dict(InferenceEndpointConfig.from_json(data))
+    data = EndpointConfig.from_json(data).dict()
 
     try:
         response = requests.put(
diff --git a/hugie/models.py b/hugie/models.py
index eb41aaf..185e34a 100644
--- a/hugie/models.py
+++ b/hugie/models.py
@@ -1,27 +1,72 @@
-from pydantic import BaseModel, BaseSettings
+"""
+These models are based on the openapi specification of the Hugging Face
+Inference Endpoints API: https://api.endpoints.huggingface.cloud/
+"""
+
+from pydantic import BaseModel, BaseSettings, Field
 
 from hugie.utils import load_json
 
 
-class ScalingModel(BaseModel):
-    minReplica: int = 1
-    maxReplica: int = 1
+class EndpointScaling(BaseModel):
+    """Replica bounds used when scaling an endpoint."""
+    minReplica: int = Field(..., description="Minimum number of endpoint replicas")
+    maxReplica: int = Field(..., description="Maximum number of endpoint replicas")
+
+
+class EndpointCompute(BaseModel):
+    """Compute settings of an endpoint: accelerator, instance and scaling."""
+    accelerator: str = Field(
+        ..., description="Accelerator type, one of [cpu, gpu]", regex="^(cpu|gpu)$"
+    )
+    instanceSize: str = Field(..., description="Instance size, e.g. large")
+    instanceType: str = Field(..., description="Instance type, e.g. c6i")
+    # Required: replica counts are mandatory, so no default can be built here.
+    scaling: EndpointScaling
+
+
+class EndpointImageCredentials(BaseModel):
+    """Username and password for a private image registry."""
+    username: str = Field(..., description="Username for private registry")
+    password: str = Field(..., description="Password for private registry")
+
 
+class EndpointModelImageConfig(BaseModel):
+    """Settings of a container image used to serve a model."""
+    # None by default: a credentials instance needs a username and a password.
+    credentials: EndpointImageCredentials = None
+    env: dict = Field({}, description="Environment variables")
+    health_route: str = Field("/health", description="Health route")
+    port: int = Field(80, title="Port", description="Endpoint API port")
+    url: str = Field(
+        ..., description="URL for the container", example="https://host/image:tag"
+    )
 
 
-class ComputeModel(BaseModel):
-    accelerator: str = None
-    instanceSize: str = None
-    instanceType: str = None
-    scaling: ScalingModel = ScalingModel()
+class EndpointModelImage(BaseModel):
+    """Image kind plus its config; calling it returns ``{image: config}``."""
+    image: str = Field(
+        "huggingface",
+        description="One of ['huggingface', 'custom']",
+        regex="^(huggingface|custom)$",
+    )
+    config: dict = {}
 
+    def __call__(self, **kwargs):
+        return {self.image: self.config}
 
 
-class ModelModel(BaseModel):
-    framework: str = None
-    image: dict = {"huggingface": {}}
-    repository: str = None
-    revision: str = None
+class EndpointModel(BaseModel):
+    """Model served by an endpoint: framework, image, repository and revision."""
+    framework: str = Field(
+        ..., description="Framework, one of [custom, pytorch, tensorflow]"
+    )
+    image: dict = Field({"huggingface": {}})
+    repository: str = Field(..., description="Repository name, e.g. gpt2")
+    revision: str = Field(
+        None,
+        description="Model commit hash, if not set, the latest commit will be used",
+    )
     task: str = None
 
 
@@ -30,16 +75,24 @@ class ProviderModel(BaseModel):
     region: str = None
 
 
-class InferenceEndpointConfig(BaseSettings):
+class EndpointConfig(BaseSettings):
     """
     Config for the inference endpoint
     """
 
     accountId: str = None
-    type: str = None
-    compute: ComputeModel = ComputeModel()
-    model: ModelModel = ModelModel()
-    name: str = None
+    type: str = Field(
+        ...,
+        description="Type of the endpoint, must be one of ['public', 'protected', 'private']",
+        regex="^(public|protected|private)$",
+    )
+    # compute and model are required: both have mandatory fields, so a default
+    # instance cannot be constructed while this class body is evaluated.
+    compute: EndpointCompute
+    model: EndpointModel
+    name: str = Field(
+        ..., description="Name of the endpoint", max_length=32, regex="^[a-z0-9-]+$"
+    )
     provider: ProviderModel = ProviderModel()
 
     @classmethod
@@ -49,7 +102,7 @@ def from_json(self, path: str):
         """
         config = load_json(path)
 
-        model = ModelModel(
+        model = EndpointModel(
             framework=config["model"]["framework"],
             image=config["model"]["image"],
             repository=config["model"]["repository"],
@@ -57,12 +110,12 @@ def from_json(self, path: str):
             task=config["model"]["task"],
         )
 
-        scaling = ScalingModel(
+        scaling = EndpointScaling(
             minReplica=config["compute"]["scaling"]["minReplica"],
             maxReplica=config["compute"]["scaling"]["maxReplica"],
         )
 
-        compute = ComputeModel(
+        compute = EndpointCompute(
             accelerator=config["compute"]["accelerator"],
             instanceSize=config["compute"]["instanceSize"],
             instanceType=config["compute"]["instanceType"],
@@ -74,7 +127,7 @@ def from_json(self, path: str):
             region=config["provider"]["region"],
         )
 
-        config = InferenceEndpointConfig(
+        config = EndpointConfig(
             accountId=config["accountId"],
             type=config["type"],
             compute=compute,