diff --git a/apps/models_provider/tools.py b/apps/models_provider/tools.py index 9b94d7800c5..f503e4f1f78 100644 --- a/apps/models_provider/tools.py +++ b/apps/models_provider/tools.py @@ -115,6 +115,7 @@ def get_model_by_id(_id, workspace_id): raise Exception(_("Model does not exist")) return model + def get_model_default_params(model): def convert_to_int(value): if isinstance(value, str): @@ -127,10 +128,18 @@ def convert_to_int(value): return { p.get('field'): convert_to_int(p.get('default_value')) for p in model.model_params_form - if p.get('default_value') is not None } +def reset_model_params(default_model_params, **kwargs): + result = {} + for key, value in default_model_params.items(): + _value = kwargs.get(key) if kwargs.get(key) is not None else default_model_params.get(key) + if _value is not None: + result[key] = _value + return result + + def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs): """ 获取模型实例,根据模型相关数据 @@ -139,5 +148,6 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs): @return: 模型实例 """ model = get_model_by_id(model_id, workspace_id) - s = get_model_default_params(model) - return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs})) + default_model_params = get_model_default_params(model) + model_params = reset_model_params(default_model_params, **kwargs) + return ModelManage.get_model(model_id, lambda _id: get_model(model, **model_params))