ZJIT: Add newrange support (#13505)

* Add newrange support to zjit

* Add RangeType enum for Range insn's flag

* Address other feedback
This commit is contained in:
Stan Lo 2025-06-05 02:51:53 +01:00 committed by GitHub
parent 0ca80484ac
commit 111986f8b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
Notes: git 2025-06-05 01:52:07 +00:00
Merged-By: k0kubun <takashikkbn@gmail.com>
8 changed files with 206 additions and 14 deletions

View File

@ -245,6 +245,27 @@ class TestZJIT < Test::Unit::TestCase
}
end
# Builds an inclusive range from two method arguments (`a..b` with 1 and 5)
# and expects assert_compiles to report the result `1..5`.
def test_new_range_inclusive
assert_compiles '1..5', %q{
def test(a, b) = a..b
test(1, 5)
}
end
# Builds an end-exclusive range from two method arguments (`a...b` with 1 and 5)
# and expects assert_compiles to report the result `1...5`.
def test_new_range_exclusive
assert_compiles '1...5', %q{
def test(a, b) = a...b
test(1, 5)
}
end
# Builds an inclusive range whose lower bound is an argument and whose upper
# bound is the literal 10 (`n..10` with n = 3); the expected result is `3..10`.
def test_new_range_with_literal
assert_compiles '3..10', %q{
def test(n) = n..10
test(3)
}
end
def test_if
assert_compiles '[0, nil]', %q{
def test(n)

View File

@ -182,6 +182,7 @@ fn main() {
.allowlist_var("rb_cSymbol")
.allowlist_var("rb_cFloat")
.allowlist_var("rb_cNumeric")
.allowlist_var("rb_cRange")
.allowlist_var("rb_cString")
.allowlist_var("rb_cThread")
.allowlist_var("rb_cArray")

View File

@ -7,7 +7,7 @@ use crate::state::ZJITState;
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
use crate::invariants::{iseq_escapes_ep, track_no_ep_escape_assumption};
use crate::backend::lir::{self, asm_comment, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, SP};
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo};
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo, RangeType};
use crate::hir::{Const, FrameState, Function, Insn, InsnId};
use crate::hir_type::{types::Fixnum, Type};
use crate::options::get_option;
@ -251,6 +251,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::PutSelf => gen_putself(),
Insn::Const { val: Const::Value(val) } => gen_const(*val),
Insn::NewArray { elements, state } => gen_new_array(jit, asm, elements, &function.frame_state(*state)),
Insn::NewRange { low, high, flag, state } => gen_new_range(asm, opnd!(low), opnd!(high), *flag, &function.frame_state(*state)),
Insn::ArrayDup { val, state } => gen_array_dup(asm, opnd!(val), &function.frame_state(*state)),
Insn::Param { idx } => unreachable!("block.insns should not have Insn::Param({idx})"),
Insn::Snapshot { .. } => return Some(()), // we don't need to do anything for this instruction at the moment
@ -552,6 +553,28 @@ fn gen_new_array(
new_array
}
/// Compile a new range instruction.
///
/// Emits a C call to `rb_range_new(low, high, flag)` and returns the operand
/// holding the call's result. `flag` is passed as an immediate carrying the
/// `RangeType` discriminant.
fn gen_new_range(
    asm: &mut Assembler,
    low: lir::Opnd,
    high: lir::Opnd,
    flag: RangeType,
    state: &FrameState,
) -> lir::Opnd {
    asm_comment!(asm, "call rb_range_new");

    // Save PC before leaving JIT code for the C function.
    // NOTE(review): CRuby's rb_range_new may call `<=>` on non-Fixnum endpoints
    // and can raise; confirm that saving only the PC is sufficient for this call.
    gen_save_pc(asm, state);

    // Third C argument: the range-type flag as an immediate.
    let flag_opnd = lir::Opnd::Imm(flag as i64);
    asm.ccall(rb_range_new as *const u8, vec![low, high, flag_opnd])
}
/// Compile code that exits from JIT code with a return value
fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
// Pop the current frame (ec->cfp++)

View File

@ -771,6 +771,7 @@ unsafe extern "C" {
pub static mut rb_cModule: VALUE;
pub static mut rb_cNilClass: VALUE;
pub static mut rb_cNumeric: VALUE;
pub static mut rb_cRange: VALUE;
pub static mut rb_cString: VALUE;
pub static mut rb_cSymbol: VALUE;
pub static mut rb_cThread: VALUE;

View File

@ -228,6 +228,50 @@ impl Const {
}
}
/// Kind of range produced by a `NewRange` instruction.
///
/// The explicit discriminants (0 and 1) are the raw integer flag values that
/// the `u32` conversions for this type map to and from.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum RangeType {
    Inclusive = 0, // include the end value
    Exclusive = 1, // exclude the end value
}

impl std::fmt::Display for RangeType {
    /// Writes the mnemonic for this range type: `NewRangeInclusive` or
    /// `NewRangeExclusive`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match self {
            RangeType::Inclusive => "NewRangeInclusive",
            RangeType::Exclusive => "NewRangeExclusive",
        })
    }
}

impl std::fmt::Debug for RangeType {
    /// Debug output is intentionally identical to the `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Delegate directly instead of allocating a temporary String via to_string().
        std::fmt::Display::fmt(self, f)
    }
}
impl From<u32> for RangeType {
    /// Decodes a raw range flag: 0 is inclusive, 1 is exclusive.
    ///
    /// # Panics
    ///
    /// Panics on any other value.
    fn from(flag: u32) -> Self {
        // Derive the accepted values from the enum discriminants so the two
        // directions of the conversion cannot drift apart.
        const INCLUSIVE: u32 = RangeType::Inclusive as u32;
        const EXCLUSIVE: u32 = RangeType::Exclusive as u32;

        match flag {
            INCLUSIVE => Self::Inclusive,
            EXCLUSIVE => Self::Exclusive,
            _ => panic!("Invalid range flag: {}", flag),
        }
    }
}
impl From<RangeType> for u32 {
    /// Returns the raw flag value of `range_type`, i.e. its enum discriminant
    /// (0 for `Inclusive`, 1 for `Exclusive`).
    fn from(range_type: RangeType) -> u32 {
        let raw_flag = range_type as u32;
        raw_flag
    }
}
/// Print adaptor for [`Const`]. See [`PtrPrintMap`].
struct ConstPrinter<'a> {
inner: &'a Const,
@ -330,6 +374,7 @@ pub enum Insn {
NewArray { elements: Vec<InsnId>, state: InsnId },
/// NewHash contains a vec of (key, value) pairs
NewHash { elements: Vec<(InsnId,InsnId)>, state: InsnId },
NewRange { low: InsnId, high: InsnId, flag: RangeType, state: InsnId },
ArraySet { array: InsnId, idx: usize, val: InsnId },
ArrayDup { val: InsnId, state: InsnId },
ArrayMax { elements: Vec<InsnId>, state: InsnId },
@ -439,6 +484,7 @@ impl Insn {
Insn::StringCopy { .. } => false,
Insn::NewArray { .. } => false,
Insn::NewHash { .. } => false,
Insn::NewRange { .. } => false,
Insn::ArrayDup { .. } => false,
Insn::HashDup { .. } => false,
Insn::Test { .. } => false,
@ -490,6 +536,9 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
}
Ok(())
}
Insn::NewRange { low, high, flag, .. } => {
write!(f, "NewRange {low} {flag} {high}")
}
Insn::ArrayMax { elements, .. } => {
write!(f, "ArrayMax")?;
let mut prefix = " ";
@ -912,6 +961,7 @@ impl Function {
}
NewHash { elements: found_elements, state: find!(state) }
}
&NewRange { low, high, flag, state } => NewRange { low: find!(low), high: find!(high), flag, state: find!(state) },
ArrayMax { elements, state } => ArrayMax { elements: find_vec!(*elements), state: find!(*state) },
&GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state },
&SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val, state },
@ -972,6 +1022,7 @@ impl Function {
Insn::ArrayDup { .. } => types::ArrayExact,
Insn::NewHash { .. } => types::HashExact,
Insn::HashDup { .. } => types::HashExact,
Insn::NewRange { .. } => types::RangeExact,
Insn::CCall { return_type, .. } => *return_type,
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_value(*expected)),
@ -1486,6 +1537,11 @@ impl Function {
}
worklist.push_back(state);
}
Insn::NewRange { low, high, state, .. } => {
worklist.push_back(low);
worklist.push_back(high);
worklist.push_back(state);
}
Insn::StringCopy { val }
| Insn::StringIntern { val }
| Insn::Return { val }
@ -2342,6 +2398,14 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let val = state.stack_pop()?;
fun.push_insn(block, Insn::SetIvar { self_val, id, val, state: exit_id });
}
YARVINSN_newrange => {
let flag = RangeType::from(get_arg(pc, 0).as_u32());
let high = state.stack_pop()?;
let low = state.stack_pop()?;
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
let insn_id = fun.push_insn(block, Insn::NewRange { low, high, flag, state: exit_id });
state.stack_push(insn_id);
}
_ => {
// Unknown opcode; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
@ -2735,6 +2799,52 @@ mod tests {
"#]]);
}
// HIR snapshot for `a..10`: the method must contain the `newrange` opcode, and
// the expected HIR is a constant 10 followed by a `NewRange` tagged
// `NewRangeInclusive` with the parameter as the lower bound.
#[test]
fn test_new_range_inclusive_with_one_element() {
eval("def test(a) = (a..10)");
assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
fn test:
bb0(v0:BasicObject):
v2:Fixnum[10] = Const Value(10)
v4:RangeExact = NewRange v0 NewRangeInclusive v2
Return v4
"#]]);
}
// HIR snapshot for `a..b`: both bounds are method parameters, so the expected
// HIR is a single `NewRange` tagged `NewRangeInclusive` over the two parameters.
#[test]
fn test_new_range_inclusive_with_two_elements() {
eval("def test(a, b) = (a..b)");
assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
fn test:
bb0(v0:BasicObject, v1:BasicObject):
v4:RangeExact = NewRange v0 NewRangeInclusive v1
Return v4
"#]]);
}
// HIR snapshot for `a...10`: same shape as the inclusive literal case, but the
// `NewRange` is expected to carry the `NewRangeExclusive` tag.
#[test]
fn test_new_range_exclusive_with_one_element() {
eval("def test(a) = (a...10)");
assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
fn test:
bb0(v0:BasicObject):
v2:Fixnum[10] = Const Value(10)
v4:RangeExact = NewRange v0 NewRangeExclusive v2
Return v4
"#]]);
}
// HIR snapshot for `a...b`: both bounds are method parameters and the
// `NewRange` is expected to carry the `NewRangeExclusive` tag.
#[test]
fn test_new_range_exclusive_with_two_elements() {
eval("def test(a, b) = (a...b)");
assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
fn test:
bb0(v0:BasicObject, v1:BasicObject):
v4:RangeExact = NewRange v0 NewRangeExclusive v1
Return v4
"#]]);
}
#[test]
fn test_array_dup() {
eval("def test = [1, 2, 3]");
@ -4273,6 +4383,23 @@ mod opt_tests {
"#]]);
}
// A range literal `(1..2)` assigned to an otherwise unused local should not
// survive optimization: the expected optimized HIR contains only the constant 5
// and the return, with no `NewRange` instruction.
#[test]
fn test_eliminate_new_range() {
eval("
def test()
c = (1..2)
5
end
test; test
");
assert_optimized_method_hir("test", expect![[r#"
fn test:
bb0():
v3:Fixnum[5] = Const Value(5)
Return v3
"#]]);
}
#[test]
fn test_eliminate_new_array_with_elements() {
eval("

View File

@ -71,6 +71,7 @@ end
base_type "String"
base_type "Array"
base_type "Hash"
base_type "Range"
(integer, integer_exact) = base_type "Integer"
# CRuby partitions Integer into immediate and non-immediate variants.

View File

@ -9,7 +9,7 @@ mod bits {
pub const BasicObjectSubclass: u64 = 1u64 << 3;
pub const Bignum: u64 = 1u64 << 4;
pub const BoolExact: u64 = FalseClassExact | TrueClassExact;
pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | StringExact | SymbolExact | TrueClassExact;
pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | RangeExact | StringExact | SymbolExact | TrueClassExact;
pub const CBool: u64 = 1u64 << 5;
pub const CDouble: u64 = 1u64 << 6;
pub const CInt: u64 = CSigned | CUnsigned;
@ -48,23 +48,26 @@ mod bits {
pub const NilClass: u64 = NilClassExact | NilClassSubclass;
pub const NilClassExact: u64 = 1u64 << 28;
pub const NilClassSubclass: u64 = 1u64 << 29;
pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | String | Symbol | TrueClass;
pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | Range | String | Symbol | TrueClass;
pub const ObjectExact: u64 = 1u64 << 30;
pub const ObjectSubclass: u64 = 1u64 << 31;
pub const Range: u64 = RangeExact | RangeSubclass;
pub const RangeExact: u64 = 1u64 << 32;
pub const RangeSubclass: u64 = 1u64 << 33;
pub const RubyValue: u64 = BasicObject | CallableMethodEntry | Undef;
pub const StaticSymbol: u64 = 1u64 << 32;
pub const StaticSymbol: u64 = 1u64 << 34;
pub const String: u64 = StringExact | StringSubclass;
pub const StringExact: u64 = 1u64 << 33;
pub const StringSubclass: u64 = 1u64 << 34;
pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
pub const StringExact: u64 = 1u64 << 35;
pub const StringSubclass: u64 = 1u64 << 36;
pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
pub const Symbol: u64 = SymbolExact | SymbolSubclass;
pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol;
pub const SymbolSubclass: u64 = 1u64 << 35;
pub const SymbolSubclass: u64 = 1u64 << 37;
pub const TrueClass: u64 = TrueClassExact | TrueClassSubclass;
pub const TrueClassExact: u64 = 1u64 << 36;
pub const TrueClassSubclass: u64 = 1u64 << 37;
pub const Undef: u64 = 1u64 << 38;
pub const AllBitPatterns: [(&'static str, u64); 64] = [
pub const TrueClassExact: u64 = 1u64 << 38;
pub const TrueClassSubclass: u64 = 1u64 << 39;
pub const Undef: u64 = 1u64 << 40;
pub const AllBitPatterns: [(&'static str, u64); 67] = [
("Any", Any),
("RubyValue", RubyValue),
("Immediate", Immediate),
@ -84,6 +87,9 @@ mod bits {
("StringExact", StringExact),
("SymbolExact", SymbolExact),
("StaticSymbol", StaticSymbol),
("Range", Range),
("RangeSubclass", RangeSubclass),
("RangeExact", RangeExact),
("ObjectSubclass", ObjectSubclass),
("ObjectExact", ObjectExact),
("NilClass", NilClass),
@ -130,7 +136,7 @@ mod bits {
("ArrayExact", ArrayExact),
("Empty", Empty),
];
pub const NumTypeBits: u64 = 39;
pub const NumTypeBits: u64 = 41;
}
pub mod types {
use super::*;
@ -185,6 +191,9 @@ pub mod types {
pub const Object: Type = Type::from_bits(bits::Object);
pub const ObjectExact: Type = Type::from_bits(bits::ObjectExact);
pub const ObjectSubclass: Type = Type::from_bits(bits::ObjectSubclass);
pub const Range: Type = Type::from_bits(bits::Range);
pub const RangeExact: Type = Type::from_bits(bits::RangeExact);
pub const RangeSubclass: Type = Type::from_bits(bits::RangeSubclass);
pub const RubyValue: Type = Type::from_bits(bits::RubyValue);
pub const StaticSymbol: Type = Type::from_bits(bits::StaticSymbol);
pub const String: Type = Type::from_bits(bits::String);

View File

@ -1,6 +1,6 @@
#![allow(non_upper_case_globals)]
use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH};
use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass};
use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass, rb_cRange};
use crate::cruby::ClassRelationship;
use crate::cruby::get_class_name;
use crate::hir::PtrPrintMap;
@ -137,6 +137,10 @@ fn is_hash_exact(val: VALUE) -> bool {
val.class_of() == unsafe { rb_cHash } || (val.class_of() == VALUE(0) && val.builtin_type() == RUBY_T_HASH)
}
/// Returns true when the class of `val` is `rb_cRange` itself.
fn is_range_exact(val: VALUE) -> bool {
    let val_class = val.class_of();
    // SAFETY (assumed, TODO confirm): rb_cRange is a `static mut` exported by
    // CRuby; this reads its current value without holding a reference to it.
    let range_class = unsafe { rb_cRange };
    val_class == range_class
}
impl Type {
/// Create a `Type` from the given integer.
pub const fn fixnum(val: i64) -> Type {
@ -183,6 +187,9 @@ impl Type {
else if is_hash_exact(val) {
Type { bits: bits::HashExact, spec: Specialization::Object(val) }
}
else if is_range_exact(val) {
Type { bits: bits::RangeExact, spec: Specialization::Object(val) }
}
else if is_string_exact(val) {
Type { bits: bits::StringExact, spec: Specialization::Object(val) }
}
@ -277,6 +284,7 @@ impl Type {
if class == unsafe { rb_cInteger } { return true; }
if class == unsafe { rb_cNilClass } { return true; }
if class == unsafe { rb_cObject } { return true; }
if class == unsafe { rb_cRange } { return true; }
if class == unsafe { rb_cString } { return true; }
if class == unsafe { rb_cSymbol } { return true; }
if class == unsafe { rb_cTrueClass } { return true; }
@ -383,6 +391,7 @@ impl Type {
if self.is_subtype(types::IntegerExact) { return Some(unsafe { rb_cInteger }); }
if self.is_subtype(types::NilClassExact) { return Some(unsafe { rb_cNilClass }); }
if self.is_subtype(types::ObjectExact) { return Some(unsafe { rb_cObject }); }
if self.is_subtype(types::RangeExact) { return Some(unsafe { rb_cRange }); }
if self.is_subtype(types::StringExact) { return Some(unsafe { rb_cString }); }
if self.is_subtype(types::SymbolExact) { return Some(unsafe { rb_cSymbol }); }
if self.is_subtype(types::TrueClassExact) { return Some(unsafe { rb_cTrueClass }); }