Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 74 additions & 80 deletions librubyfmt/src/format_prism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
heredoc_string::HeredocKind,
parser_state::{FormattingContext, HashType, ParserState},
types::SourceOffset,
util::{const_to_str, loc_to_str},
util::loc_to_str,
};

pub fn format_node<'src>(ps: &mut ParserState<'src>, node: prism::Node<'src>) {
Expand Down Expand Up @@ -1619,13 +1619,9 @@ fn format_block_parameter_node<'src>(
ps.emit_soft_indent();
ps.emit_ident(b"&");
if let Some(ident) = block_arg.name() {
let ident_str = const_to_str(ident);
let ident_str = ident.as_slice();
ps.bind_variable(ident_str);
handle_string_at_offset(
ps,
ident_str.as_bytes(),
block_arg.name_loc().unwrap().end_offset(),
);
handle_string_at_offset(ps, ident_str, block_arg.name_loc().unwrap().end_offset());
}
});
}
Expand All @@ -1642,16 +1638,17 @@ fn format_block_argument_node<'src>(
}
}

pub static RSPEC_METHODS: [&str; 2] = ["it", "describe"];
pub static RSPEC_METHODS: [&[u8]; 2] = [b"it", b"describe"];

pub static GEMFILE_METHODS: [&str; 4] = ["gem", "source", "ruby", "group"];
pub static GEMFILE_METHODS: [&[u8]; 4] = [b"gem", b"source", b"ruby", b"group"];

pub static OPTIONALLY_PARENTHESIZED_METHODS: [&str; 3] = ["super", "require", "require_relative"];
pub static OPTIONALLY_PARENTHESIZED_METHODS: [&[u8]; 3] =
[b"super", b"require", b"require_relative"];

fn use_parens_for_call_node<'src>(
ps: &ParserState<'src>,
call_node: &prism::CallNode<'src>,
method_name: &str,
method_name: &[u8],
// Whether we're the final call in a chain, e.g.
// foo.bar.baz
// ^ terminal call
Expand All @@ -1677,14 +1674,14 @@ fn use_parens_for_call_node<'src>(
})
.unwrap_or(false);

if is_terminal_call && method_name.chars().next().is_some_and(|c| c.is_uppercase()) {
if is_terminal_call && method_name.first().is_some_and(|c| c.is_ascii_uppercase()) {
if !has_arguments && call_node.block().is_some() {
return false;
}
return true;
}

if method_name.starts_with("attr_") && context == FormattingContext::ClassOrModule {
if method_name.starts_with(b"attr_") && context == FormattingContext::ClassOrModule {
return original_used_parens;
}

Expand All @@ -1699,7 +1696,7 @@ fn use_parens_for_call_node<'src>(
}
}

if method_name == "raise" {
if method_name == b"raise" {
if ps.current_formatting_context_requires_parens() {
return true;
}
Expand Down Expand Up @@ -1751,8 +1748,8 @@ fn use_parens_for_call_node<'src>(
if let Some(receiver) = call_node.receiver()
&& let Some(const_read) = receiver.as_constant_read_node()
{
let const_name = const_to_str(const_read.name());
if const_name == "RSpec" && method_name == "describe" {
let const_name = const_read.name().as_slice();
if const_name == b"RSpec" && method_name == b"describe" {
return false;
}
}
Expand All @@ -1771,7 +1768,7 @@ fn format_call_node<'src>(
is_final_call_in_chain: bool,
skip_attr_write_value: bool,
) {
let method_name = const_to_str(call_node.name());
let method_name = call_node.name().as_slice();
// When we skip the attr_write value (because it will be formatted separately),
// we should only wind to the end of the method name, not the full call.
// Otherwise we'd extract comments from inside the value prematurely.
Expand All @@ -1783,23 +1780,22 @@ fn format_call_node<'src>(
} else {
call_node.location().end_offset()
};
let is_dot_call = method_name == "call" && call_node.message_loc().is_none(); // e.g. `a.()`
let is_dot_call = method_name == b"call" && call_node.message_loc().is_none(); // e.g. `a.()`

// Only treat [] and []= as aref syntax when there's no explicit call operator.
let has_call_operator = call_node.call_operator_loc().is_some();
let is_aref = method_name == "[]" && !has_call_operator;
let is_aref_write = method_name == "[]=" && !has_call_operator;
let is_aref = method_name == b"[]" && !has_call_operator;
let is_aref_write = method_name == b"[]=" && !has_call_operator;

if skip_receiver || call_node.receiver().is_none() {
if !is_aref && !is_aref_write {
let method_ident = if call_node.is_attribute_write() {
loc_to_str(
call_node
.message_loc()
.expect("Attribute writes must have a message"),
)
call_node
.message_loc()
.expect("Attribute writes must have a message")
.as_slice()
} else if is_dot_call {
""
b""
} else {
method_name
};
Expand Down Expand Up @@ -2005,7 +2001,7 @@ fn format_call_node<'src>(
} else {
let is_unary_operator = call_node.arguments().is_none()
&& call_node.call_operator_loc().is_none()
&& matches!(method_name, "-@" | "+@" | "!" | "~");
&& [b"-@" as &[u8], b"+@", b"!", b"~"].contains(&method_name);

if is_unary_operator {
format_unary_operator(ps, call_node, method_name);
Expand All @@ -2020,7 +2016,7 @@ fn format_call_node<'src>(
format_infix_operator(
ps,
call_node.receiver().unwrap(),
method_name.as_bytes(),
method_name,
// For infix operators, we still get an ArgumentsNode, but it will
// always be an argument list of a single node.
call_node.arguments().unwrap().arguments().first().unwrap(),
Expand All @@ -2040,50 +2036,52 @@ fn format_call_node<'src>(
fn format_unary_operator<'src>(
ps: &mut ParserState<'src>,
call_node: prism::CallNode<'src>,
method_name: &'src str,
method_name: &'src [u8],
) {
// We need to preserve parens for `not`, they can be semantically meaningful
let is_not_with_parens = method_name == "!"
let is_not_with_parens = method_name == b"!"
&& call_node
.message_loc()
.map(|loc| loc_to_str(loc) == "not")
.map(|loc| loc.as_slice() == b"not")
.unwrap_or(false)
&& call_node.opening_loc().is_some();

let operator_symbol = match method_name {
"!" => {
let operator_symbol: &[u8] = match method_name {
b"!" => {
// `not` and `!` both have a `name` of `!` but different messages
if let Some(message_loc) = call_node.message_loc() {
let message_text = loc_to_str(message_loc);
if message_text == "not" {
if is_not_with_parens { "not" } else { "not " }
let message_text = message_loc.as_slice();
if message_text == b"not" {
if is_not_with_parens { b"not" } else { b"not " }
} else {
"!"
b"!"
}
} else {
"!"
b"!"
}
}
"-@" => "-",
"+@" => "+",
"~" => "~",
b"-@" => b"-",
b"+@" => b"+",
b"~" => b"~",
_ => {
if cfg!(debug_assertions) {
unreachable!("Received unexpected unary operator: {}", method_name);
unreachable!(
"Received unexpected unary operator: {}",
String::from_utf8_lossy(method_name)
);
}

// Try to render the message loc as a fallback in unexpected cases, but
// panic if we don't find one, otherwise we're rendering a total guess.
loc_to_str(
call_node
.message_loc()
.expect("Expected unary operator to have a message loc"),
)
call_node
.message_loc()
.expect("Expected unary operator to have a message loc")
.as_slice()
}
};

ps.with_start_of_line(false, |ps| {
ps.emit_ident(operator_symbol.as_bytes());
ps.emit_ident(operator_symbol);
let receiver = call_node
.receiver()
.expect("Unary operators must have a receiver");
Expand Down Expand Up @@ -2210,9 +2208,9 @@ fn format_call_chain_segments<'src>(
// break independently of each other
let trailing_attr_write_value = chain_elements.last().and_then(|last| {
last.as_call_node().and_then(|call| {
let method_name = const_to_str(call.name());
let method_name = call.name().as_slice();
// arefs are handled in format_call_node
if call.is_attribute_write() && method_name != "[]=" {
if call.is_attribute_write() && method_name != b"[]=" {
call.arguments().and_then(|args| {
debug_assert!(
args.arguments().len() == 1,
Expand Down Expand Up @@ -3134,10 +3132,10 @@ fn split_node_into_call_chains<'src>(node: prism::Node<'src>) -> Vec<Vec<prism::
while let Some(receiver) = maybe_receiver {
maybe_receiver = receiver.as_call_node().and_then(|call_node| {
// Don't traverse into unary operators, they should not be treated as part of a call chain.
let method_name = const_to_str(call_node.name());
let method_name = call_node.name().as_slice();
let is_unary_operator = call_node.arguments().is_none()
&& call_node.call_operator_loc().is_none()
&& matches!(method_name, "-@" | "+@" | "!" | "~");
&& [b"-@" as &[u8], b"+@", b"!", b"~"].contains(&method_name);
if is_unary_operator {
return None;
}
Expand Down Expand Up @@ -3218,13 +3216,9 @@ fn format_rest_parameter_node<'src>(
ps.emit_ident(b"*");
ps.with_start_of_line(false, |ps| {
if let Some(name) = rest_param.name() {
let name_str = const_to_str(name);
let name_str = name.as_slice();
ps.bind_variable(name_str);
handle_string_at_offset(
ps,
name_str.as_bytes(),
rest_param.name_loc().unwrap().end_offset(),
);
handle_string_at_offset(ps, name_str, rest_param.name_loc().unwrap().end_offset());
}
});
});
Expand Down Expand Up @@ -3279,40 +3273,40 @@ fn format_keyword_rest_parameter_node<'src>(
ps.emit_soft_indent();
ps.emit_ident(b"**");
if let Some(constant_id) = keyword_rest_parameter_node.name() {
let name = const_to_str(constant_id);
let name = constant_id.as_slice();
ps.bind_variable(name);
ps.emit_ident(name.as_bytes());
ps.emit_ident(name);
}
}

fn format_required_keyword_parameter_node<'src>(
ps: &mut ParserState<'src>,
required_keyword_parameter_node: prism::RequiredKeywordParameterNode<'src>,
) {
let name = const_to_str(required_keyword_parameter_node.name());
let name = required_keyword_parameter_node.name().as_slice();
ps.bind_variable(name);
ps.emit_ident(name.as_bytes());
ps.emit_ident(name);
ps.emit_ident(b":");
}

fn format_required_parameter_node<'src>(
ps: &mut ParserState<'src>,
required_parameter_node: prism::RequiredParameterNode<'src>,
) {
let name = const_to_str(required_parameter_node.name());
let name = required_parameter_node.name().as_slice();
ps.bind_variable(name);
ps.emit_ident(name.as_bytes());
ps.emit_ident(name);
}

fn format_local_variable_and_write_node<'src>(
ps: &mut ParserState<'src>,
local_variable_and_write_node: prism::LocalVariableAndWriteNode<'src>,
) {
let variable_name = const_to_str(local_variable_and_write_node.name());
let variable_name = local_variable_and_write_node.name().as_slice();
ps.bind_variable(variable_name);
format_write_node(
ps,
variable_name.as_bytes(),
variable_name,
b"&&=",
local_variable_and_write_node.value(),
);
Expand All @@ -3322,11 +3316,11 @@ fn format_local_variable_operator_write_node<'src>(
ps: &mut ParserState<'src>,
local_variable_operator_write_node: prism::LocalVariableOperatorWriteNode<'src>,
) {
let variable_name = const_to_str(local_variable_operator_write_node.name());
let variable_name = local_variable_operator_write_node.name().as_slice();
ps.bind_variable(variable_name);
format_write_node(
ps,
variable_name.as_bytes(),
variable_name,
local_variable_operator_write_node
.binary_operator_loc()
.as_slice(),
Expand All @@ -3338,11 +3332,11 @@ fn format_local_variable_or_write_node<'src>(
ps: &mut ParserState<'src>,
local_variable_or_write_node: prism::LocalVariableOrWriteNode<'src>,
) {
let variable_name = const_to_str(local_variable_or_write_node.name());
let variable_name = local_variable_or_write_node.name().as_slice();
ps.bind_variable(variable_name);
format_write_node(
ps,
variable_name.as_bytes(),
variable_name,
b"||=",
local_variable_or_write_node.value(),
);
Expand All @@ -3352,26 +3346,26 @@ fn format_local_variable_target_node<'src>(
ps: &mut ParserState<'src>,
local_variable_target_node: prism::LocalVariableTargetNode<'src>,
) {
let variable_name = const_to_str(local_variable_target_node.name());
let variable_name = local_variable_target_node.name().as_slice();
ps.bind_variable(variable_name);
ps.emit_ident(variable_name.as_bytes());
ps.emit_ident(variable_name);
}

fn format_local_variable_read_node<'src>(
ps: &mut ParserState<'src>,
local_variable_read_node: prism::LocalVariableReadNode<'src>,
) {
let name = const_to_str(local_variable_read_node.name());
ps.emit_ident(name.as_bytes());
let name = local_variable_read_node.name().as_slice();
ps.emit_ident(name);
}

fn format_local_variable_write_node<'src>(
ps: &mut ParserState<'src>,
local_variable_write_node: prism::LocalVariableWriteNode<'src>,
) {
let name = const_to_str(local_variable_write_node.name());
let name = local_variable_write_node.name().as_slice();
ps.bind_variable(name);
format_write_node(ps, name.as_bytes(), b"=", local_variable_write_node.value());
format_write_node(ps, name, b"=", local_variable_write_node.value());
}

fn format_splat_node<'src>(ps: &mut ParserState<'src>, splat_node: prism::SplatNode<'src>) {
Expand Down Expand Up @@ -4493,9 +4487,9 @@ fn format_optional_keyword_parameter_node<'src>(
ps: &mut ParserState<'src>,
optional_keyword_parameter_node: prism::OptionalKeywordParameterNode<'src>,
) {
let name = const_to_str(optional_keyword_parameter_node.name());
let name = optional_keyword_parameter_node.name().as_slice();
ps.bind_variable(name);
ps.emit_ident(name.as_bytes());
ps.emit_ident(name);
ps.emit_op(b":");
ps.emit_space();
ps.with_start_of_line(false, |ps| {
Expand All @@ -4507,9 +4501,9 @@ fn format_optional_parameter_node<'src>(
ps: &mut ParserState<'src>,
optional_parameter_node: prism::OptionalParameterNode<'src>,
) {
let name = const_to_str(optional_parameter_node.name());
let name = optional_parameter_node.name().as_slice();
ps.bind_variable(name);
ps.emit_ident(name.as_bytes());
ps.emit_ident(name);
ps.emit_space();
ps.emit_op(b"=");
ps.emit_space();
Expand Down
2 changes: 1 addition & 1 deletion librubyfmt/src/intermediary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl<'src> Intermediary<'src> {
}
}
ConcreteLineToken::MethodName { name } => {
if *name == "require" && self.tokens.last().map(|t| t.is_indent()).unwrap_or(false)
if *name == b"require" && self.tokens.last().map(|t| t.is_indent()).unwrap_or(false)
{
self.current_line_metadata.set_has_require();
}
Expand Down
Loading