Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ repos:
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.254
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
10 changes: 7 additions & 3 deletions hugie/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import srsly
import typer

from hugie.models import InferenceEndpointConfig
from hugie.models import EndpointConfig

app = typer.Typer()

Expand All @@ -26,7 +26,11 @@ def modify(
framework: str = typer.Option("huggingface", help="Framework to use"),
image: str = typer.Option(
None,
help="Image to use when deploying model endppint. Must be string representing a valid JSON, e.g. '{'huggingface': {}}'",
help=(
"Image to use when deploying model endpoint."
"Must be string representing a valid JSON,"
"e.g. '{'huggingface': {}}'"
),
),
repository: str = typer.Option(None, help="Name of the hf model repository"),
revision: str = typer.Option(None, help="Revision of the hf model repository"),
Expand All @@ -41,7 +45,7 @@ def modify(
Modify an existing endpoint config file
"""

config = InferenceEndpointConfig.from_json(path)
config = EndpointConfig.from_json(path)

# Standard configs

Expand Down
14 changes: 8 additions & 6 deletions hugie/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests
import typer

from hugie.models import InferenceEndpointConfig
from hugie.models import EndpointConfig
from hugie.settings import Settings
from hugie.utils import format_table, load_json

Expand Down Expand Up @@ -76,7 +76,9 @@ def list(

@app.command()
def create(
data: str = typer.Argument(..., help="Path JSON data to create the endpoint"),
data: str = typer.Argument(
..., help="Path or url of a JSON from which to create the endpoint"
),
json: Optional[bool] = typer.Option(
None, "--json", help="Prints the full output in JSON."
),
Expand All @@ -85,10 +87,10 @@ def create(
Create an endpoint

Args:
data (str): Path to JSON data to create the endpoint
data (str): Path or url of a JSON from which to create the endpoint.
"""

data = InferenceEndpointConfig.from_json(data).dict()
data = EndpointConfig.from_json(data).dict()

try:
response = requests.post(settings.endpoint_url, headers=headers, json=data)
Expand All @@ -105,7 +107,7 @@ def create(
elif response.status_code == 401:
typer.secho("Invalid token", fg=typer.colors.YELLOW)
elif response.status_code == 409:
typer.secho(f"Endpoint {name} already exists", fg=typer.colors.YELLOW)
typer.secho(f"Endpoint {data['name']} already exists", fg=typer.colors.YELLOW)
else:
typer.secho(
f"Endpoint {data['name']} created successfully on {data['provider']['vendor']} using {data['model']['repository']}",
Expand All @@ -126,7 +128,7 @@ def update(
"""
Update an endpoint
"""
data = dict(InferenceEndpointConfig.from_json(data))
data = dict(EndpointConfig.from_json(data))

try:
response = requests.put(
Expand Down
86 changes: 63 additions & 23 deletions hugie/models.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,61 @@
from pydantic import BaseModel, BaseSettings
"""
These models are based on the openapi specification of the Hugging Face
Inference Endpoints API: https://api.endpoints.huggingface.cloud/
"""

from pydantic import BaseModel, BaseSettings, Field

from hugie.utils import load_json


class ScalingModel(BaseModel):
minReplica: int = 1
maxReplica: int = 1
class EndpointScaling(BaseModel):
minReplica: int = Field(..., alias="Minimum number of endpoint replicas")
maxReplica: int = Field(..., alias="Maximum number of endpoint replicas")


class EndpointCompute(BaseModel):
accelerator: str = Field(
..., alias="Accelerator type, one of [cpu, gpu]", regex="^(cpu|gpu)$"
)
instanceSize: str = Field(..., alias="Instance size, e.g. large")
instanceType: str = Field(..., alias="Instance type, e.g. c6i")
scaling: EndpointScaling = EndpointScaling()


class EndpointImageCredentials(BaseModel):
username: str = Field(..., alias="Username for private registry")
password: str = Field(..., alias="Password for private registry")


class EndpointModelImageConfig(BaseModel):
credentials: EndpointImageCredentials = EndpointImageCredentials()
env: dict = Field({}, alias="Environment variables")
health_route: str = Field("/health", alias="Health route")
port: int = Field(80, alias="Port", description="Endpoint API port")
url: str = Field(
..., alias="URL for the container", example="https://host/image:tag"
)

class ComputeModel(BaseModel):

accelerator: str = None
instanceSize: str = None
instanceType: str = None
scaling: ScalingModel = ScalingModel()
class EndpointModelImage(BaseModel):
image: str = Field(
"huggingface",
description="One of ['huggingface', 'custom']",
regex="^(huggingface|custom)$",
)
config: dict = {}

def __call__(self, **kwargs):
return {self.image: self.config}

class ModelModel(BaseModel):

framework: str = None
image: dict = {"huggingface": {}}
repository: str = None
revision: str = None
class EndpointModel(BaseModel):
framework: str = Field(..., alias="Framework, one of [custom, pytorch, tensorflow]")
image: dict = Field({"huggingface": {}})
repository: str = Field(..., alias="Repository name, e.g. gpt2")
revision: str = Field(
..., description="Model commit hash, if not set, the latest commit will be used"
)
task: str = None


Expand All @@ -30,16 +64,22 @@ class ProviderModel(BaseModel):
region: str = None


class InferenceEndpointConfig(BaseSettings):
class EndpointConfig(BaseSettings):
"""
Config for the inference endpoint
"""

accountId: str = None
type: str = None
compute: ComputeModel = ComputeModel()
model: ModelModel = ModelModel()
name: str = None
type: str = Field(
...,
description="Type of the endpoint, must be one of ['public', 'protected', 'private']",
regex="^(public|protected|private)$",
)
compute: EndpointCompute = EndpointCompute()
model: EndpointModel = EndpointModel()
name: str = Field(
..., description="Name of the endpoint", max_length=32, regex="^[a-z0-9-]+$"
)
provider: ProviderModel = ProviderModel()

@classmethod
Expand All @@ -49,20 +89,20 @@ def from_json(self, path: str):
"""
config = load_json(path)

model = ModelModel(
model = EndpointModel(
framework=config["model"]["framework"],
image=config["model"]["image"],
repository=config["model"]["repository"],
revision=config["model"]["revision"],
task=config["model"]["task"],
)

scaling = ScalingModel(
scaling = EndpointScaling(
minReplica=config["compute"]["scaling"]["minReplica"],
maxReplica=config["compute"]["scaling"]["maxReplica"],
)

compute = ComputeModel(
compute = EndpointCompute(
accelerator=config["compute"]["accelerator"],
instanceSize=config["compute"]["instanceSize"],
instanceType=config["compute"]["instanceType"],
Expand All @@ -74,7 +114,7 @@ def from_json(self, path: str):
region=config["provider"]["region"],
)

config = InferenceEndpointConfig(
config = EndpointConfig(
accountId=config["accountId"],
type=config["type"],
compute=compute,
Expand Down