diff --git a/nemo_run/cli/cli_parser.py b/nemo_run/cli/cli_parser.py index 76d8cca1..e7bc3543 100644 --- a/nemo_run/cli/cli_parser.py +++ b/nemo_run/cli/cli_parser.py @@ -244,8 +244,8 @@ def parse_value(self, value: str) -> Any: Parse a string value into its corresponding Python object. This method attempts to evaluate the string as a Python literal or expression. - It handles various types including booleans, constructors, comprehensions, - lambda functions, and ternary expressions. + It handles various types including booleans, constructors, lambda functions, + and ternary expressions. Args: value (str): The string value to parse. @@ -267,8 +267,6 @@ def parse_value(self, value: str) -> Any: return value.lower() == "true" elif value.startswith(("dict(", "list(", "tuple(", "set(")): return self.parse_constructor(value) - elif "[" in value and "]" in value and "for" in value: - return self.parse_comprehension(value) elif value.startswith("lambda"): return self.parse_lambda(value) elif "if" in value and "else" in value: @@ -349,101 +347,20 @@ def parse_constructor_args(self, args: str) -> List[Any]: def parse_comprehension(self, value: str) -> Any: """ - Parse a comprehension expression into its corresponding Python object. - - This method safely evaluates list, dict, and set comprehensions. + Reject comprehension expressions in CLI values. Args: value (str): The comprehension string to parse. Returns: - Any: The result of the comprehension. - - Raises: - ArgumentValueError: If the comprehension is invalid or cannot be safely evaluated. - - Example: - >>> parser = PythonicParser() - >>> parser.parse_comprehension("[x for x in range(3)]") - [0, 1, 2] - """ - try: - tree = ast.parse(value, mode="eval") - if isinstance(tree.body, (ast.ListComp, ast.DictComp, ast.SetComp)): - return self.eval_ast(tree.body) - raise ValueError("Not a valid comprehension") - except Exception as e: - raise ArgumentValueError(f"Invalid comprehension: {str(e)}", value, {}) - - def eval_ast(self, node: ast.AST, context: Dict[str, Any] = None) -> Any: - """ - Safely evaluate an AST node. - - This method traverses the AST and evaluates it in a restricted environment, - allowing only safe operations and built-in functions. - - Args: - node (ast.AST): The AST node to evaluate. - context (Dict[str, Any], optional): A dictionary of variables for evaluation context. - - Returns: - Any: The result of evaluating the AST node. + Any: This method never returns successfully. Raises: - ValueError: If an unsupported or unsafe operation is encountered. - - Note: - This method is recursive and handles various AST node types. + ArgumentValueError: Always, because comprehensions are not supported in CLI values. """ - if context is None: - context = {} - - if isinstance(node, ast.Constant): - return node.value - elif isinstance(node, ast.Name): - return context.get(node.id, getattr(__builtins__, node.id, None)) - elif isinstance(node, ast.BinOp): - left = self.eval_ast(node.left, context) - right = self.eval_ast(node.right, context) - op_func = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv, - ast.FloorDiv: operator.floordiv, - ast.Mod: operator.mod, - ast.Pow: operator.pow, - }.get(type(node.op)) - if op_func: - return op_func(left, right) - elif isinstance(node, ast.Compare): - left = self.eval_ast(node.left, context) - for op, right in zip(node.ops, node.comparators): - right_val = self.eval_ast(right, context) - op_func = { - ast.Eq: operator.eq, - ast.NotEq: operator.ne, - ast.Lt: operator.lt, - ast.LtE: operator.le, - ast.Gt: operator.gt, - ast.GtE: operator.ge, - }.get(type(op)) - if op_func: - if not op_func(left, right_val): - return False - left = right_val - else: - raise ValueError(f"Unsupported comparison operator: {type(op)}") - return True - elif isinstance(node, ast.Call): - func = self.eval_ast(node.func, context) - args = [self.eval_ast(arg, context) for arg in node.args] - return func(*args) - elif isinstance(node, (ast.ListComp, ast.DictComp, ast.SetComp)): - # Implement safe evaluation of comprehensions - # This is a simplified version and may need more robust implementation - return eval(compile(ast.Expression(node), "", "eval"), {}, {}) - raise ValueError(f"Unsupported AST node: {type(node)}") + raise ArgumentValueError( + "Comprehensions are not supported in CLI values", value, {} + ) def parse_lambda(self, value: str) -> Callable: """ diff --git a/test/cli/test_cli_parser.py b/test/cli/test_cli_parser.py index f8b8afd2..cbb66af6 100644 --- a/test/cli/test_cli_parser.py +++ b/test/cli/test_cli_parser.py @@ -633,12 +633,21 @@ def test_parse_constructor(self, parser): assert parser.parse_constructor("set(1, 2, 3)") == {1, 2, 3} def test_parse_comprehension(self, parser): - assert parser.parse_comprehension("[x for x in range(3)]") == [0, 1, 2] - assert parser.parse_comprehension("{x: x**2 for x in range(3)}") == { - 0: 0, - 1: 1, - 2: 4, - } + with pytest.raises(ArgumentValueError, match="Comprehensions are not supported"): + parser.parse_comprehension("[x for x in range(3)]") + + with pytest.raises(ArgumentValueError, match="Comprehensions are not supported"): + parser.parse_comprehension("{x: x**2 for x in range(3)}") + + def test_parse_value_does_not_execute_comprehension_payload(self, parser, tmp_path): + marker = tmp_path / "payload_ran" + payload = ( + f"[__import__('pathlib').Path({str(marker)!r}).write_text('owned') " + "for _ in [0]]" + ) + + assert parser.parse_value(payload) == payload + assert not marker.exists() def test_parse_lambda(self, parser): # Test safe lambdas