[ruby/prism] Add node ids to nodes

https://github.com/ruby/prism/commit/bf16ade7f9
This commit is contained in:
Kevin Newton 2024-07-03 10:42:56 -04:00
parent 39dcfe26ee
commit 2bf9ae3fa1
10 changed files with 449 additions and 194 deletions

View File

@ -72,6 +72,7 @@ module Prism
def to_interpolated def to_interpolated
InterpolatedStringNode.new( InterpolatedStringNode.new(
source, source,
-1,
location, location,
frozen? ? InterpolatedStringNodeFlags::FROZEN : 0, frozen? ? InterpolatedStringNodeFlags::FROZEN : 0,
opening_loc, opening_loc,
@ -89,10 +90,11 @@ module Prism
def to_interpolated def to_interpolated
InterpolatedXStringNode.new( InterpolatedXStringNode.new(
source, source,
-1,
location, location,
flags, flags,
opening_loc, opening_loc,
[StringNode.new(source, content_loc, 0, nil, content_loc, nil, unescaped)], [StringNode.new(source, node_id, content_loc, 0, nil, content_loc, nil, unescaped)],
closing_loc closing_loc
) )
end end
@ -119,9 +121,9 @@ module Prism
deprecated("value", "numerator", "denominator") deprecated("value", "numerator", "denominator")
if denominator == 1 if denominator == 1
IntegerNode.new(source, location.chop, flags, numerator) IntegerNode.new(source, -1, location.chop, flags, numerator)
else else
FloatNode.new(source, location.chop, 0, numerator.to_f / denominator) FloatNode.new(source, -1, location.chop, 0, numerator.to_f / denominator)
end end
end end
end end
@ -199,7 +201,12 @@ module Prism
# continue to supply that API. # continue to supply that API.
def child def child
deprecated("name", "name_loc") deprecated("name", "name_loc")
name ? ConstantReadNode.new(source, name_loc, 0, name) : MissingNode.new(source, location, 0)
if name
ConstantReadNode.new(source, -1, name_loc, 0, name)
else
MissingNode.new(source, -1, location, 0)
end
end end
end end
@ -235,7 +242,12 @@ module Prism
# continue to supply that API. # continue to supply that API.
def child def child
deprecated("name", "name_loc") deprecated("name", "name_loc")
name ? ConstantReadNode.new(source, name_loc, 0, name) : MissingNode.new(source, location, 0)
if name
ConstantReadNode.new(source, -1, name_loc, 0, name)
else
MissingNode.new(source, -1, location, 0)
end
end end
end end

View File

@ -625,6 +625,13 @@ typedef uint32_t pm_state_stack_t;
* it's considering. * it's considering.
*/ */
struct pm_parser { struct pm_parser {
/**
* The next node identifier that will be assigned. This is a unique
* identifier used to track nodes such that the syntax tree can be dropped
* but the node can be found through another parse.
*/
uint32_t node_id;
/** The current state of the lexer. */ /** The current state of the lexer. */
pm_lex_state_t lex_state; pm_lex_state_t lex_state;

File diff suppressed because it is too large Load Diff

View File

@ -173,17 +173,20 @@ pm_ast_new(const pm_parser_t *parser, const pm_node_t *node, rb_encoding *encodi
<%- if node.fields.any? { |field| ![Prism::Template::NodeField, Prism::Template::OptionalNodeField].include?(field.class) } -%> <%- if node.fields.any? { |field| ![Prism::Template::NodeField, Prism::Template::OptionalNodeField].include?(field.class) } -%>
pm_<%= node.human %>_t *cast = (pm_<%= node.human %>_t *) node; pm_<%= node.human %>_t *cast = (pm_<%= node.human %>_t *) node;
<%- end -%> <%- end -%>
VALUE argv[<%= node.fields.length + 3 %>]; VALUE argv[<%= node.fields.length + 4 %>];
// source // source
argv[0] = source; argv[0] = source;
// node_id
argv[1] = ULONG2NUM(node->node_id);
// location // location
argv[1] = pm_location_new(parser, node->location.start, node->location.end); argv[2] = pm_location_new(parser, node->location.start, node->location.end);
// flags // flags
argv[2] = ULONG2NUM(node->flags); argv[3] = ULONG2NUM(node->flags);
<%- node.fields.each.with_index(3) do |field, index| -%> <%- node.fields.each.with_index(4) do |field, index| -%>
// <%= field.name %> // <%= field.name %>
<%- case field -%> <%- case field -%>
@ -235,7 +238,7 @@ pm_ast_new(const pm_parser_t *parser, const pm_node_t *node, rb_encoding *encodi
<%- end -%> <%- end -%>
<%- end -%> <%- end -%>
rb_ary_push(value_stack, rb_class_new_instance(<%= node.fields.length + 3 %>, argv, rb_cPrism<%= node.name %>)); rb_ary_push(value_stack, rb_class_new_instance(<%= node.fields.length + 4 %>, argv, rb_cPrism<%= node.name %>));
break; break;
} }
<%- end -%> <%- end -%>

View File

@ -136,6 +136,12 @@ typedef struct pm_node {
*/ */
pm_node_flags_t flags; pm_node_flags_t flags;
/**
* The unique identifier for this node, which is deterministic based on the
* source. It is used to identify unique nodes across parses.
*/
uint32_t node_id;
/** /**
* This is the location of the node in the source. It's a range of bytes * This is the location of the node in the source. It's a range of bytes
* containing a start and an end. * containing a start and an end.

View File

@ -6,11 +6,13 @@ module Prism
# #
# Prism::ArrayNode.new( # Prism::ArrayNode.new(
# source, # source,
# 0,
# Prism::Location.new(source, 0, 3), # Prism::Location.new(source, 0, 3),
# 0, # 0,
# [ # [
# Prism::IntegerNode.new( # Prism::IntegerNode.new(
# source, # source,
# 0,
# Prism::Location.new(source, 1, 1), # Prism::Location.new(source, 1, 1),
# Prism::IntegerBaseFlags::DECIMAL, # Prism::IntegerBaseFlags::DECIMAL,
# 1 # 1
@ -65,7 +67,7 @@ module Prism
<%- nodes.each do |node| -%> <%- nodes.each do |node| -%>
# Create a new <%= node.name %> node. # Create a new <%= node.name %> node.
def <%= node.human %>(<%= ["source: default_source", "location: default_location", "flags: 0", *node.fields.map { |field| def <%= node.human %>(<%= ["source: default_source", "node_id: 0", "location: default_location", "flags: 0", *node.fields.map { |field|
case field case field
when Prism::Template::NodeField, Prism::Template::ConstantField when Prism::Template::NodeField, Prism::Template::ConstantField
"#{field.name}: default_node(source, location)" "#{field.name}: default_node(source, location)"
@ -83,7 +85,7 @@ module Prism
raise raise
end end
}].join(", ") %>) }].join(", ") %>)
<%= node.name %>.new(<%= ["source", "location", "flags", *node.fields.map(&:name)].join(", ") %>) <%= node.name %>.new(<%= ["source", "node_id", "location", "flags", *node.fields.map(&:name)].join(", ") %>)
end end
<%- end -%> <%- end -%>
<%- flags.each do |flag| -%> <%- flags.each do |flag| -%>

View File

@ -6,6 +6,12 @@ module Prism
attr_reader :source attr_reader :source
private :source private :source
# A unique identifier for this node. This is used in a very specific
# use case where you want to keep around a reference to a node without
# having to keep around the syntax tree in memory. This unique identifier
# will be consistent across multiple parses of the same source code.
attr_reader :node_id
# A Location instance that represents the location of this node in the # A Location instance that represents the location of this node in the
# source. # source.
def location def location
@ -212,8 +218,9 @@ module Prism
<%- end -%> <%- end -%>
class <%= node.name -%> < Node class <%= node.name -%> < Node
# Initialize a new <%= node.name %> node. # Initialize a new <%= node.name %> node.
def initialize(<%= ["source", "location", "flags", *node.fields.map(&:name)].join(", ") %>) def initialize(<%= ["source", "node_id", "location", "flags", *node.fields.map(&:name)].join(", ") %>)
@source = source @source = source
@node_id = node_id
@location = location @location = location
@flags = flags @flags = flags
<%- node.fields.each do |field| -%> <%- node.fields.each do |field| -%>
@ -274,17 +281,17 @@ module Prism
}.compact.join(", ") %>] #: Array[Prism::node | Location] }.compact.join(", ") %>] #: Array[Prism::node | Location]
end end
# def copy: (<%= (["?location: Location", "?flags: Integer"] + node.fields.map { |field| "?#{field.name}: #{field.rbs_class}" }).join(", ") %>) -> <%= node.name %> # def copy: (<%= (["?node_id: Integer", "?location: Location", "?flags: Integer"] + node.fields.map { |field| "?#{field.name}: #{field.rbs_class}" }).join(", ") %>) -> <%= node.name %>
def copy(<%= (["location", "flags"] + node.fields.map(&:name)).map { |field| "#{field}: self.#{field}" }.join(", ") %>) def copy(<%= (["node_id", "location", "flags"] + node.fields.map(&:name)).map { |field| "#{field}: self.#{field}" }.join(", ") %>)
<%= node.name %>.new(<%= ["source", "location", "flags", *node.fields.map(&:name)].join(", ") %>) <%= node.name %>.new(<%= ["source", "node_id", "location", "flags", *node.fields.map(&:name)].join(", ") %>)
end end
# def deconstruct: () -> Array[nil | Node] # def deconstruct: () -> Array[nil | Node]
alias deconstruct child_nodes alias deconstruct child_nodes
# def deconstruct_keys: (Array[Symbol] keys) -> { <%= (node.fields.map { |field| "#{field.name}: #{field.rbs_class}" } + ["location: Location"]).join(", ") %> } # def deconstruct_keys: (Array[Symbol] keys) -> { <%= (["node_id: Integer", "location: Location"] + node.fields.map { |field| "#{field.name}: #{field.rbs_class}" }).join(", ") %> }
def deconstruct_keys(keys) def deconstruct_keys(keys)
{ <%= (node.fields.map { |field| "#{field.name}: #{field.name}" } + ["location: location"]).join(", ") %> } { <%= (["node_id: node_id", "location: location"] + node.fields.map { |field| "#{field.name}: #{field.name}" }).join(", ") %> }
end end
<%- if (node_flags = node.flags) -%> <%- if (node_flags = node.flags) -%>
<%- node_flags.values.each do |value| -%> <%- node_flags.values.each do |value| -%>

View File

@ -322,6 +322,7 @@ module Prism
if RUBY_ENGINE == "ruby" if RUBY_ENGINE == "ruby"
def load_node def load_node
type = io.getbyte type = io.getbyte
node_id = load_varuint
location = load_location location = load_location
case type case type
@ -330,7 +331,7 @@ module Prism
<%- if node.needs_serialized_length? -%> <%- if node.needs_serialized_length? -%>
load_uint32 load_uint32
<%- end -%> <%- end -%>
<%= node.name %>.new(<%= ["source", "location", "load_varuint", *node.fields.map { |field| <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
case field case field
when Prism::Template::NodeField then "load_node" when Prism::Template::NodeField then "load_node"
when Prism::Template::OptionalNodeField then "load_optional_node" when Prism::Template::OptionalNodeField then "load_optional_node"
@ -362,11 +363,12 @@ module Prism
nil, nil,
<%- nodes.each do |node| -%> <%- nodes.each do |node| -%>
-> { -> {
node_id = load_varuint
location = load_location location = load_location
<%- if node.needs_serialized_length? -%> <%- if node.needs_serialized_length? -%>
load_uint32 load_uint32
<%- end -%> <%- end -%>
<%= node.name %>.new(<%= ["source", "location", "load_varuint", *node.fields.map { |field| <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
case field case field
when Prism::Template::NodeField then "load_node" when Prism::Template::NodeField then "load_node"
when Prism::Template::OptionalNodeField then "load_optional_node" when Prism::Template::OptionalNodeField then "load_optional_node"

View File

@ -74,6 +74,7 @@ pm_serialize_node(pm_parser_t *parser, pm_node_t *node, pm_buffer_t *buffer) {
size_t offset = buffer->length; size_t offset = buffer->length;
pm_buffer_append_varuint(buffer, node->node_id);
pm_serialize_location(parser, &node->location, buffer); pm_serialize_location(parser, &node->location, buffer);
switch (PM_NODE_TYPE(node)) { switch (PM_NODE_TYPE(node)) {

View File

@ -0,0 +1,27 @@
# frozen_string_literal: true
require_relative "../test_helper"
module Prism
class NodeIdTest < TestCase
Fixture.each do |fixture|
define_method(fixture.test_name) { assert_node_ids(fixture.read) }
end
private
def assert_node_ids(source)
queue = [Prism.parse(source).value]
node_ids = []
while (node = queue.shift)
node_ids << node.node_id
queue.concat(node.compact_child_nodes)
end
node_ids.sort!
refute_includes node_ids, 0
assert_equal node_ids, node_ids.uniq
end
end
end