Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 8 additions & 91 deletions nemo_run/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def parse_value(self, value: str) -> Any:
Parse a string value into its corresponding Python object.

This method attempts to evaluate the string as a Python literal or expression.
It handles various types including booleans, constructors, comprehensions,
lambda functions, and ternary expressions.
It handles various types including booleans, constructors, lambda functions,
and ternary expressions.

Args:
value (str): The string value to parse.
Expand All @@ -267,8 +267,6 @@ def parse_value(self, value: str) -> Any:
return value.lower() == "true"
elif value.startswith(("dict(", "list(", "tuple(", "set(")):
return self.parse_constructor(value)
elif "[" in value and "]" in value and "for" in value:
return self.parse_comprehension(value)
elif value.startswith("lambda"):
return self.parse_lambda(value)
elif "if" in value and "else" in value:
Expand Down Expand Up @@ -349,101 +347,20 @@ def parse_constructor_args(self, args: str) -> List[Any]:

def parse_comprehension(self, value: str) -> Any:
"""
Parse a comprehension expression into its corresponding Python object.

This method safely evaluates list, dict, and set comprehensions.
Reject comprehension expressions in CLI values.

Args:
value (str): The comprehension string to parse.

Returns:
Any: The result of the comprehension.

Raises:
ArgumentValueError: If the comprehension is invalid or cannot be safely evaluated.

Example:
>>> parser = PythonicParser()
>>> parser.parse_comprehension("[x for x in range(3)]")
[0, 1, 2]
"""
try:
tree = ast.parse(value, mode="eval")
if isinstance(tree.body, (ast.ListComp, ast.DictComp, ast.SetComp)):
return self.eval_ast(tree.body)
raise ValueError("Not a valid comprehension")
except Exception as e:
raise ArgumentValueError(f"Invalid comprehension: {str(e)}", value, {})

def eval_ast(self, node: ast.AST, context: Dict[str, Any] = None) -> Any:
"""
Safely evaluate an AST node.

This method traverses the AST and evaluates it in a restricted environment,
allowing only safe operations and built-in functions.

Args:
node (ast.AST): The AST node to evaluate.
context (Dict[str, Any], optional): A dictionary of variables for evaluation context.

Returns:
Any: The result of evaluating the AST node.
Any: This method never returns successfully.

Raises:
ValueError: If an unsupported or unsafe operation is encountered.

Note:
This method is recursive and handles various AST node types.
ArgumentValueError: Always, because comprehensions are not supported in CLI values.
"""
if context is None:
context = {}

if isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.Name):
return context.get(node.id, getattr(__builtins__, node.id, None))
elif isinstance(node, ast.BinOp):
left = self.eval_ast(node.left, context)
right = self.eval_ast(node.right, context)
op_func = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
}.get(type(node.op))
if op_func:
return op_func(left, right)
elif isinstance(node, ast.Compare):
left = self.eval_ast(node.left, context)
for op, right in zip(node.ops, node.comparators):
right_val = self.eval_ast(right, context)
op_func = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
}.get(type(op))
if op_func:
if not op_func(left, right_val):
return False
left = right_val
else:
raise ValueError(f"Unsupported comparison operator: {type(op)}")
return True
elif isinstance(node, ast.Call):
func = self.eval_ast(node.func, context)
args = [self.eval_ast(arg, context) for arg in node.args]
return func(*args)
elif isinstance(node, (ast.ListComp, ast.DictComp, ast.SetComp)):
# Implement safe evaluation of comprehensions
# This is a simplified version and may need more robust implementation
return eval(compile(ast.Expression(node), "<string>", "eval"), {}, {})
raise ValueError(f"Unsupported AST node: {type(node)}")
raise ArgumentValueError(
"Comprehensions are not supported in CLI values", value, {}
)

def parse_lambda(self, value: str) -> Callable:
"""
Expand Down
21 changes: 15 additions & 6 deletions test/cli/test_cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,21 @@ def test_parse_constructor(self, parser):
assert parser.parse_constructor("set(1, 2, 3)") == {1, 2, 3}

def test_parse_comprehension(self, parser):
assert parser.parse_comprehension("[x for x in range(3)]") == [0, 1, 2]
assert parser.parse_comprehension("{x: x**2 for x in range(3)}") == {
0: 0,
1: 1,
2: 4,
}
with pytest.raises(ArgumentValueError, match="Comprehensions are not supported"):
parser.parse_comprehension("[x for x in range(3)]")

with pytest.raises(ArgumentValueError, match="Comprehensions are not supported"):
parser.parse_comprehension("{x: x**2 for x in range(3)}")

def test_parse_value_does_not_execute_comprehension_payload(self, parser, tmp_path):
marker = tmp_path / "payload_ran"
payload = (
f"[__import__('pathlib').Path({str(marker)!r}).write_text('owned') "
"for _ in [0]]"
)

assert parser.parse_value(payload) == payload
assert not marker.exists()

def test_parse_lambda(self, parser):
# Test safe lambdas
Expand Down