Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions apps/models_provider/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def get_model_by_id(_id, workspace_id):
raise Exception(_("Model does not exist"))
return model


def get_model_default_params(model):
def convert_to_int(value):
if isinstance(value, str):
Expand All @@ -127,10 +128,18 @@ def convert_to_int(value):
return {
p.get('field'): convert_to_int(p.get('default_value'))
for p in model.model_params_form
if p.get('default_value') is not None
}


def reset_model_params(default_model_params, **kwargs):
result = {}
for key, value in default_model_params.items():
_value = kwargs.get(key) if kwargs.get(key) is not None else default_model_params.get(key)
if _value is not None:
result[key] = _value
return result


def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
"""
获取模型实例,根据模型相关数据
Expand All @@ -139,5 +148,6 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
@return: 模型实例
"""
model = get_model_by_id(model_id, workspace_id)
s = get_model_default_params(model)
return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs}))
default_model_params = get_model_default_params(model)
model_params = reset_model_params(default_model_params, **kwargs)
return ModelManage.get_model(model_id, lambda _id: get_model(model, **model_params))
Loading