[ruby/prism] Add an IntegerField for parsing integer values

https://github.com/ruby/prism/commit/120d8c0479
This commit is contained in:
Kevin Newton 2024-02-22 10:55:29 -05:00
parent ff6ebba9de
commit af0a6ea1d5
17 changed files with 199 additions and 56 deletions

View File

@ -71,13 +71,6 @@ module Prism
end
end
class IntegerNode < Node
  # Converts the source text covered by this node into a Ruby Integer by
  # handing the node's slice to Kernel#Integer.
  def value
    literal = slice
    Integer(literal)
  end
end
class RationalNode < Node
# Returns the value of the node as a Ruby Rational.
def value

View File

@ -1867,6 +1867,9 @@ nodes:
- name: flags
type: flags
kind: IntegerBaseFlags
- name: value
type: integer
comment: The value of the integer literal as a number.
comment: |
Represents an integer number literal.

View File

@ -7,9 +7,9 @@
#include <ruby/encoding.h>
#include "prism.h"
VALUE pm_source_new(pm_parser_t *parser, rb_encoding *encoding);
VALUE pm_token_new(pm_parser_t *parser, pm_token_t *token, rb_encoding *encoding, VALUE source);
VALUE pm_ast_new(pm_parser_t *parser, pm_node_t *node, rb_encoding *encoding, VALUE source);
VALUE pm_source_new(const pm_parser_t *parser, rb_encoding *encoding);
VALUE pm_token_new(const pm_parser_t *parser, const pm_token_t *token, rb_encoding *encoding, VALUE source);
VALUE pm_ast_new(const pm_parser_t *parser, const pm_node_t *node, rb_encoding *encoding, VALUE source);
void Init_prism_api_node(void);
void Init_prism_pack(void);

View File

@ -3803,12 +3803,25 @@ pm_integer_node_create(pm_parser_t *parser, pm_node_flags_t base, const pm_token
assert(token->type == PM_TOKEN_INTEGER);
pm_integer_node_t *node = PM_ALLOC_NODE(parser, pm_integer_node_t);
*node = (pm_integer_node_t) {{
.type = PM_INTEGER_NODE,
.flags = base | PM_NODE_FLAG_STATIC_LITERAL,
.location = PM_LOCATION_TOKEN_VALUE(token)
}};
*node = (pm_integer_node_t) {
{
.type = PM_INTEGER_NODE,
.flags = base | PM_NODE_FLAG_STATIC_LITERAL,
.location = PM_LOCATION_TOKEN_VALUE(token)
},
.value = { 0 }
};
pm_number_base_t number_base;
switch (base) {
case PM_INTEGER_BASE_FLAGS_BINARY: number_base = PM_NUMBER_BASE_BINARY; break;
case PM_INTEGER_BASE_FLAGS_OCTAL: number_base = PM_NUMBER_BASE_OCTAL; break;
case PM_INTEGER_BASE_FLAGS_DECIMAL: number_base = PM_NUMBER_BASE_DECIMAL; break;
case PM_INTEGER_BASE_FLAGS_HEXADECIMAL: number_base = PM_NUMBER_BASE_HEXADECIMAL; break;
default: assert(false && "unreachable");
}
pm_number_parse(&node->value, number_base, token->start, token->end);
return node;
}
@ -14281,6 +14294,9 @@ static inline void
parse_negative_numeric(pm_node_t *node) {
switch (PM_NODE_TYPE(node)) {
case PM_INTEGER_NODE:
node->location.start--;
((pm_integer_node_t *) node)->value.negative = true;
break;
case PM_FLOAT_NODE:
node->location.start--;
break;

View File

@ -12,13 +12,13 @@ static VALUE rb_cPrism<%= node.name %>;
<%- end -%>
static VALUE
pm_location_new(pm_parser_t *parser, const uint8_t *start, const uint8_t *end) {
pm_location_new(const pm_parser_t *parser, const uint8_t *start, const uint8_t *end) {
uint64_t value = ((((uint64_t) (start - parser->start)) << 32) | ((uint32_t) (end - start)));
return ULL2NUM(value);
}
VALUE
pm_token_new(pm_parser_t *parser, pm_token_t *token, rb_encoding *encoding, VALUE source) {
pm_token_new(const pm_parser_t *parser, const pm_token_t *token, rb_encoding *encoding, VALUE source) {
ID type = rb_intern(pm_token_type_name(token->type));
VALUE location = pm_location_new(parser, token->start, token->end);
@ -33,13 +33,30 @@ pm_token_new(pm_parser_t *parser, pm_token_t *token, rb_encoding *encoding, VALU
}
static VALUE
pm_string_new(pm_string_t *string, rb_encoding *encoding) {
pm_string_new(const pm_string_t *string, rb_encoding *encoding) {
return rb_enc_str_new((const char *) pm_string_source(string), pm_string_length(string), encoding);
}
/**
 * Build a Ruby Integer from a pm_number_t.
 *
 * The head word supplies the lowest 32 bits. Every following node in the
 * linked list is shifted left by a further 32 bits and OR-ed into the result.
 * The result is negated at the end when the number is flagged as negative.
 */
static VALUE
pm_integer_new(const pm_number_t *number) {
    VALUE integer = UINT2NUM(number->head.value);
    size_t bit_offset = 0;

    const pm_number_node_t *word = number->head.next;
    while (word != NULL) {
        // Each additional word sits 32 bits above the previous one.
        bit_offset += 32;

        VALUE shifted = rb_funcall(UINT2NUM(word->value), rb_intern("<<"), 1, ULONG2NUM(bit_offset));
        integer = rb_funcall(shifted, rb_intern("|"), 1, integer);

        word = word->next;
    }

    if (!number->negative) {
        return integer;
    }

    return rb_funcall(integer, rb_intern("-@"), 0);
}
// Create a Prism::Source object from the given parser, after pm_parse() was called.
VALUE
pm_source_new(pm_parser_t *parser, rb_encoding *encoding) {
pm_source_new(const pm_parser_t *parser, rb_encoding *encoding) {
VALUE source_string = rb_enc_str_new((const char *) parser->start, parser->end - parser->start, encoding);
VALUE offsets = rb_ary_new_capa(parser->newline_list.size);
@ -53,12 +70,12 @@ pm_source_new(pm_parser_t *parser, rb_encoding *encoding) {
typedef struct pm_node_stack_node {
struct pm_node_stack_node *prev;
pm_node_t *visit;
const pm_node_t *visit;
bool visited;
} pm_node_stack_node_t;
static void
pm_node_stack_push(pm_node_stack_node_t **stack, pm_node_t *visit) {
pm_node_stack_push(pm_node_stack_node_t **stack, const pm_node_t *visit) {
pm_node_stack_node_t *node = malloc(sizeof(pm_node_stack_node_t));
node->prev = *stack;
node->visit = visit;
@ -66,10 +83,10 @@ pm_node_stack_push(pm_node_stack_node_t **stack, pm_node_t *visit) {
*stack = node;
}
static pm_node_t *
static const pm_node_t *
pm_node_stack_pop(pm_node_stack_node_t **stack) {
pm_node_stack_node_t *current = *stack;
pm_node_t *visit = current->visit;
const pm_node_t *visit = current->visit;
*stack = current->prev;
free(current);
@ -78,7 +95,7 @@ pm_node_stack_pop(pm_node_stack_node_t **stack) {
}
VALUE
pm_ast_new(pm_parser_t *parser, pm_node_t *node, rb_encoding *encoding, VALUE source) {
pm_ast_new(const pm_parser_t *parser, const pm_node_t *node, rb_encoding *encoding, VALUE source) {
ID *constants = calloc(parser->constant_pool.size, sizeof(ID));
for (uint32_t index = 0; index < parser->constant_pool.size; index++) {
@ -108,7 +125,7 @@ pm_ast_new(pm_parser_t *parser, pm_node_t *node, rb_encoding *encoding, VALUE so
continue;
}
pm_node_t *node = node_stack->visit;
const pm_node_t *node = node_stack->visit;
node_stack->visited = true;
switch (PM_NODE_TYPE(node)) {
@ -136,7 +153,7 @@ pm_ast_new(pm_parser_t *parser, pm_node_t *node, rb_encoding *encoding, VALUE so
}
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
} else {
pm_node_t *node = pm_node_stack_pop(&node_stack);
const pm_node_t *node = pm_node_stack_pop(&node_stack);
switch (PM_NODE_TYPE(node)) {
<%- nodes.each do |node| -%>
@ -193,6 +210,9 @@ pm_ast_new(pm_parser_t *parser, pm_node_t *node, rb_encoding *encoding, VALUE so
<%- when Prism::FlagsField -%>
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
argv[<%= index %>] = ULONG2NUM(node->flags & ~PM_NODE_FLAG_COMMON_MASK);
<%- when Prism::IntegerField -%>
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
argv[<%= index %>] = pm_integer_new(&cast-><%= field.name %>);
<%- else -%>
<%- raise -%>
<%- end -%>

View File

@ -183,6 +183,7 @@ typedef struct pm_<%= node.human %> {
when Prism::LocationField, Prism::OptionalLocationField then "pm_location_t #{field.name}"
when Prism::UInt8Field then "uint8_t #{field.name}"
when Prism::UInt32Field then "uint32_t #{field.name}"
when Prism::IntegerField then "pm_number_t #{field.name}"
else raise field.class.name
end
%>;

View File

@ -133,7 +133,7 @@ module Prism
else
table.field("<%= field.name %>", "[]")
end
<%- when Prism::StringField, Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt8Field, Prism::UInt32Field, Prism::ConstantListField -%>
<%- when Prism::StringField, Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt8Field, Prism::UInt32Field, Prism::ConstantListField, Prism::IntegerField -%>
table.field("<%= field.name %>", node.<%= field.name %>.inspect)
<%- when Prism::LocationField -%>
table.field("<%= field.name %>", location_inspect(node.<%= field.name %>))

View File

@ -8,6 +8,7 @@ module Prism
# [
# Prism::IntegerNode.new(
# Prism::IntegerBaseFlags::DECIMAL,
# 1,
# Prism::Location.new(source, 1, 1),
# source
# )
@ -22,7 +23,7 @@ module Prism
# source = Prism::Source.new("[1]")
#
# ArrayNode(
# IntegerNode(Prism::IntegerBaseFlags::DECIMAL, Location(source, 1, 1)), source),
# IntegerNode(Prism::IntegerBaseFlags::DECIMAL, 1, Location(source, 1, 1)), source),
# Location(source, 0, 1),
# Location(source, 2, 1),
# source

View File

@ -281,7 +281,7 @@ module Prism
inspector << "<%= pointer %><%= field.name %>:\n"
inspector << <%= field.name %>.inspect(inspector.child_inspector("<%= preadd %>")).delete_prefix(inspector.prefix)
end
<%- when Prism::ConstantField, Prism::StringField, Prism::UInt8Field, Prism::UInt32Field -%>
<%- when Prism::ConstantField, Prism::StringField, Prism::UInt8Field, Prism::UInt32Field, Prism::IntegerField -%>
inspector << "<%= pointer %><%= field.name %>: #{<%= field.name %>.inspect}\n"
<%- when Prism::OptionalConstantField -%>
if (<%= field.name %> = self.<%= field.name %>).nil?

View File

@ -172,6 +172,17 @@ module Prism
(n >> 1) ^ (-(n & 1))
end
# Reads a serialized integer: one sign byte (non-zero means negative), a
# varuint giving the number of words, then that many varuint words. Each
# word is placed 32 bits above the previous one.
def load_integer
  sign_byte = io.getbyte
  word_count = load_varuint

  magnitude = 0
  word_count.times do |position|
    magnitude |= load_varuint << (position * 32)
  end

  sign_byte != 0 ? -magnitude : magnitude
end
def load_serialized_length
io.read(4).unpack1("L")
end
@ -288,6 +299,7 @@ module Prism
when Prism::OptionalLocationField then "load_optional_location"
when Prism::UInt8Field then "io.getbyte"
when Prism::UInt32Field, Prism::FlagsField then "load_varuint"
when Prism::IntegerField then "load_integer"
else raise
end
} + ["location"]).join(", ") -%>)
@ -323,6 +335,7 @@ module Prism
when Prism::OptionalLocationField then "load_optional_location"
when Prism::UInt8Field then "io.getbyte"
when Prism::UInt32Field, Prism::FlagsField then "load_varuint"
when Prism::IntegerField then "load_integer"
else raise
end
} + ["location"]).join(", ") -%>)

View File

@ -74,6 +74,8 @@ pm_node_destroy(pm_parser_t *parser, pm_node_t *node) {
pm_node_list_free(parser, &cast-><%= field.name %>);
<%- when Prism::ConstantListField -%>
pm_constant_id_list_free(&cast-><%= field.name %>);
<%- when Prism::IntegerField -%>
pm_number_free(&cast-><%= field.name %>);
<%- else -%>
<%- raise -%>
<%- end -%>
@ -103,14 +105,6 @@ pm_node_memsize_node(pm_node_t *node, pm_memsize_t *memsize) {
case <%= node.type %>: {
pm_<%= node.human %>_t *cast = (pm_<%= node.human %>_t *) node;
memsize->memsize += sizeof(*cast);
<%- if node.fields.any? { |f| f.is_a?(Prism::NodeListField) } -%>
// Node lists will add in their own sizes below.
memsize->memsize -= sizeof(pm_node_list_t) * <%= node.fields.count { |f| f.is_a?(Prism::NodeListField) } %>;
<%- end -%>
<%- if node.fields.any? { |f| f.is_a?(Prism::ConstantListField) } -%>
// Constant id lists will add in their own sizes below.
memsize->memsize -= sizeof(pm_constant_id_list_t) * <%= node.fields.count { |f| f.is_a?(Prism::ConstantListField) } %>;
<%- end -%>
<%- node.fields.each do |field| -%>
<%- case field -%>
<%- when Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt8Field, Prism::UInt32Field, Prism::FlagsField, Prism::LocationField, Prism::OptionalLocationField -%>
@ -121,11 +115,13 @@ pm_node_memsize_node(pm_node_t *node, pm_memsize_t *memsize) {
pm_node_memsize_node((pm_node_t *)cast-><%= field.name %>, memsize);
}
<%- when Prism::StringField -%>
memsize->memsize += pm_string_memsize(&cast-><%= field.name %>);
memsize->memsize += (pm_string_memsize(&cast-><%= field.name %>) - sizeof(pm_string_t));
<%- when Prism::NodeListField -%>
memsize->memsize += pm_node_list_memsize(&cast-><%= field.name %>, memsize);
memsize->memsize += (pm_node_list_memsize(&cast-><%= field.name %>, memsize) - sizeof(pm_node_list_t));
<%- when Prism::ConstantListField -%>
memsize->memsize += pm_constant_id_list_memsize(&cast-><%= field.name %>);
memsize->memsize += (pm_constant_id_list_memsize(&cast-><%= field.name %>) - sizeof(pm_constant_id_list_t));
<%- when Prism::IntegerField -%>
memsize->memsize += (pm_number_memsize(&cast-><%= field.name %>) - sizeof(pm_number_t));
<%- else -%>
<%- raise -%>
<%- end -%>
@ -257,6 +253,29 @@ pm_dump_json(pm_buffer_t *buffer, const pm_parser_t *parser, const pm_node_t *no
}
<%- end -%>
pm_buffer_append_byte(buffer, ']');
<%- when Prism::IntegerField -%>
{
    // Dump the integer as JSON. A value held in one or two 32-bit words
    // is written as a plain decimal number (with a leading '-' when
    // negative). Anything larger is written as an object holding the sign
    // and the list of 32-bit words, starting from the head word.
    const pm_number_t *number = &cast-><%= field.name %>;
    if (number->length == 0) {
        if (number->negative) pm_buffer_append_byte(buffer, '-');
        // This must be the formatting call: pm_buffer_append_string takes a
        // byte length as its third argument, not a printf-style argument.
        pm_buffer_append_format(buffer, "%" PRIu32, number->head.value);
    } else if (number->length == 1) {
        if (number->negative) pm_buffer_append_byte(buffer, '-');
        pm_buffer_append_format(buffer, "%" PRIu64, ((uint64_t) number->head.value) | (((uint64_t) number->head.next->value) << 32));
    } else {
        pm_buffer_append_byte(buffer, '{');
        pm_buffer_append_format(buffer, "\"negative\": %s", number->negative ? "true" : "false");
        pm_buffer_append_string(buffer, ",\"values\":[", 11);

        const pm_number_node_t *node = &number->head;
        while (node != NULL) {
            pm_buffer_append_format(buffer, "%" PRIu32, node->value);
            node = node->next;
            // Separate the words with commas, but not after the last one.
            if (node != NULL) pm_buffer_append_byte(buffer, ',');
        }

        pm_buffer_append_string(buffer, "]}", 2);
    }
}
<%- else -%>
<%- raise %>
<%- end -%>

View File

@ -126,6 +126,36 @@ prettyprint_node(pm_buffer_t *output_buffer, const pm_parser_t *parser, const pm
<%- end -%>
if (!found) pm_buffer_append_string(output_buffer, " nil", 4);
pm_buffer_append_byte(output_buffer, '\n');
<%- when Prism::IntegerField -%>
// Pretty-print the integer after a single leading space. A value held in
// one or two 32-bit words is printed as a decimal number (with a leading
// '-' when negative). Larger values are printed as an expression of the
// form "w0 | w1 << 32 | w2 << 64 ...".
const pm_number_t *number = &cast-><%= field.name %>;
if (number->length == 0) {
    pm_buffer_append_byte(output_buffer, ' ');
    if (number->negative) pm_buffer_append_byte(output_buffer, '-');
    // This must be the formatting call: pm_buffer_append_string takes a
    // byte length as its third argument, not a printf-style argument.
    pm_buffer_append_format(output_buffer, "%" PRIu32 "\n", number->head.value);
} else if (number->length == 1) {
    pm_buffer_append_byte(output_buffer, ' ');
    if (number->negative) pm_buffer_append_byte(output_buffer, '-');
    pm_buffer_append_format(output_buffer, "%" PRIu64 "\n", ((uint64_t) number->head.value) | (((uint64_t) number->head.next->value) << 32));
} else {
    // NOTE(review): the sign is not printed in this branch; confirm
    // whether negative multi-word values should be marked here.
    pm_buffer_append_byte(output_buffer, ' ');

    const pm_number_node_t *node = &number->head;
    uint32_t index = 0;

    while (node != NULL) {
        if (index != 0) pm_buffer_append_string(output_buffer, " | ", 3);
        pm_buffer_append_format(output_buffer, "%" PRIu32, node->value);

        if (index != 0) {
            pm_buffer_append_string(output_buffer, " << ", 4);
            pm_buffer_append_format(output_buffer, "%" PRIu32, index * 32);
        }

        node = node->next;
        index++;
    }

    // Only terminate the line: no '[' is ever opened in this branch, so no
    // closing bracket is emitted.
    pm_buffer_append_byte(output_buffer, '\n');
}
<%- else -%>
<%- raise -%>
<%- end -%>

View File

@ -25,7 +25,7 @@ pm_serialize_location(const pm_parser_t *parser, const pm_location_t *location,
}
static void
pm_serialize_string(pm_parser_t *parser, pm_string_t *string, pm_buffer_t *buffer) {
pm_serialize_string(const pm_parser_t *parser, const pm_string_t *string, pm_buffer_t *buffer) {
switch (string->type) {
case PM_STRING_SHARED: {
pm_buffer_append_byte(buffer, 1);
@ -49,6 +49,16 @@ pm_serialize_string(pm_parser_t *parser, pm_string_t *string, pm_buffer_t *buffe
}
}
/**
 * Serialize an integer into the buffer.
 *
 * The layout is a sign byte (1 when negative, 0 otherwise), then a varuint
 * holding the number of words (length + 1, since the head word is always
 * written), then each word as a varuint starting with the head.
 */
static void
pm_serialize_integer(const pm_number_t *number, pm_buffer_t *buffer) {
    if (number->negative) {
        pm_buffer_append_byte(buffer, 1);
    } else {
        pm_buffer_append_byte(buffer, 0);
    }

    pm_buffer_append_varuint(buffer, pm_sizet_to_u32(number->length + 1));

    const pm_number_node_t *word = &number->head;
    while (word != NULL) {
        pm_buffer_append_varuint(buffer, word->value);
        word = word->next;
    }
}
static void
pm_serialize_node(pm_parser_t *parser, pm_node_t *node, pm_buffer_t *buffer) {
pm_buffer_append_byte(buffer, (uint8_t) PM_NODE_TYPE(node));
@ -115,6 +125,8 @@ pm_serialize_node(pm_parser_t *parser, pm_node_t *node, pm_buffer_t *buffer) {
pm_buffer_append_varuint(buffer, ((pm_<%= node.human %>_t *)node)-><%= field.name %>);
<%- when Prism::FlagsField -%>
pm_buffer_append_varuint(buffer, (uint32_t)(node->flags & ~PM_NODE_FLAG_COMMON_MASK));
<%- when Prism::IntegerField -%>
pm_serialize_integer(&((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
<%- else -%>
<%- raise -%>
<%- end -%>

View File

@ -254,6 +254,22 @@ module Prism
end
end
# A field holding an integer of arbitrary size. On the Ruby side it is
# exposed as an Integer.
class IntegerField < Field
  # Type name used in the generated RBS signatures.
  def rbs_class
    return "Integer"
  end

  # Type name used in the generated RBI signatures.
  def rbi_class
    return "Integer"
  end

  # Type name used in the generated Java sources.
  def java_type
    return "VariableInteger"
  end
end
# This class represents a node in the tree, configured by the config.yml file
# in YAML format. It contains information about the name of the node and the
# various child nodes it contains.
@ -315,6 +331,7 @@ module Prism
when "uint8" then UInt8Field
when "uint32" then UInt32Field
when "flags" then FlagsField
when "integer" then IntegerField
else raise("Unknown field type: #{name.inspect}")
end
end

View File

@ -87,17 +87,11 @@ pm_number_parse_digit(const uint8_t character) {
*/
PRISM_EXPORTED_FUNCTION void
pm_number_parse(pm_number_t *number, pm_number_base_t base, const uint8_t *start, const uint8_t *end) {
switch (*start) {
case '-':
number->negative = true;
/* fallthrough */
case '+':
start++;
break;
default:
break;
}
// Ignore unary +. Unary - is parsed differently and will not end up here.
// Instead, it will modify the parsed number later.
if (*start == '+') start++;
// Determine the multiplier from the base, and skip past any prefixes.
uint32_t multiplier;
switch (base) {
case PM_NUMBER_BASE_BINARY:
@ -133,13 +127,29 @@ pm_number_parse(pm_number_t *number, pm_number_base_t base, const uint8_t *start
break;
}
for (pm_number_add(number, pm_number_parse_digit(*start++)); start < end; start++) {
// It's possible that we've consumed everything at this point if there is an
// invalid number. If this is the case, we'll just return 0.
if (start >= end) return;
// Add the first digit to the number.
pm_number_add(number, pm_number_parse_digit(*start++));
// Add the subsequent digits to the number.
for (; start < end; start++) {
if (*start == '_') continue;
pm_number_multiply(number, multiplier);
pm_number_add(number, pm_number_parse_digit(*start));
}
}
/**
 * Compute the memory size of a number: the size of the pm_number_t structure
 * itself plus one pm_number_node_t for each of the `length` additional nodes.
 */
size_t
pm_number_memsize(const pm_number_t *number) {
    const size_t extra_nodes = number->length * sizeof(pm_number_node_t);
    return sizeof(*number) + extra_nodes;
}
/**
* Recursively destroy the linked list of a number.
*/

View File

@ -84,6 +84,14 @@ typedef enum {
*/
PRISM_EXPORTED_FUNCTION void pm_number_parse(pm_number_t *number, pm_number_base_t base, const uint8_t *start, const uint8_t *end);
/**
* Return the memory size of the number.
*
* @param number The number to get the memory size of.
* @return The size of the memory associated with the number.
*/
size_t pm_number_memsize(const pm_number_t *number);
/**
* Free the internal memory of a number. This memory will only be allocated if
* the number exceeds the size of a single node in the linked list.

View File

@ -966,8 +966,8 @@ module Prism
ParametersNode(
[RequiredParameterNode(0, :a)],
[
OptionalParameterNode(0, :b, Location(), Location(), IntegerNode(IntegerBaseFlags::DECIMAL)),
OptionalParameterNode(0, :d, Location(), Location(), IntegerNode(IntegerBaseFlags::DECIMAL))
OptionalParameterNode(0, :b, Location(), Location(), IntegerNode(IntegerBaseFlags::DECIMAL, 1)),
OptionalParameterNode(0, :d, Location(), Location(), IntegerNode(IntegerBaseFlags::DECIMAL, 2))
],
nil,
[RequiredParameterNode(0, :c), RequiredParameterNode(0, :e)],
@ -1024,7 +1024,7 @@ module Prism
Location(),
nil,
nil,
StatementsNode([IntegerNode(IntegerBaseFlags::DECIMAL)]),
StatementsNode([IntegerNode(IntegerBaseFlags::DECIMAL, 42)]),
[],
Location(),
nil,
@ -1214,7 +1214,7 @@ module Prism
:foo,
Location(),
nil,
ParametersNode([], [OptionalParameterNode(0, :a, Location(), Location(), IntegerNode(IntegerBaseFlags::DECIMAL))], RestParameterNode(0, :c, Location(), Location()), [RequiredParameterNode(0, :b)], [], nil, nil),
ParametersNode([], [OptionalParameterNode(0, :a, Location(), Location(), IntegerNode(IntegerBaseFlags::DECIMAL, 1))], RestParameterNode(0, :c, Location(), Location()), [RequiredParameterNode(0, :b)], [], nil, nil),
nil,
[:a, :b, :c],
Location(),