ZJIT: compile the `newrange` YARV instruction

What this patch does, file by file:
- zjit/src/hir.rs: adds the `RangeType` enum (Inclusive = 0, Exclusive = 1) and
  the `Insn::NewRange` HIR instruction; `iseq_to_hir` lowers `YARVINSN_newrange`
  to it, and the instruction is typed as `RangeExact`.
- zjit/src/codegen.rs: adds `gen_new_range`, which saves the PC and then calls
  `rb_range_new(low, high, flag)`.
- zjit/src/hir_type/*: adds the `Range`, `RangeExact` and `RangeSubclass` type
  bits (renumbering the bits that follow) and maps them to `rb_cRange`.
- zjit/bindgen/src/main.rs, zjit/src/cruby_bindings.inc.rs: export `rb_cRange`.
- test/ruby/test_zjit.rb and the Rust test modules: cover inclusive, exclusive
  and literal-bound ranges, plus elimination of an unused range.

diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb
index aba05ddebd..be1a1b4599 100644
--- a/test/ruby/test_zjit.rb
+++ b/test/ruby/test_zjit.rb
@@ -245,6 +245,27 @@ class TestZJIT < Test::Unit::TestCase
     }
   end
 
+  def test_new_range_inclusive
+    assert_compiles '1..5', %q{
+      def test(a, b) = a..b
+      test(1, 5)
+    }
+  end
+
+  def test_new_range_exclusive
+    assert_compiles '1...5', %q{
+      def test(a, b) = a...b
+      test(1, 5)
+    }
+  end
+
+  def test_new_range_with_literal
+    assert_compiles '3..10', %q{
+      def test(n) = n..10
+      test(3)
+    }
+  end
+
   def test_if
     assert_compiles '[0, nil]', %q{
       def test(n)
diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs
index 4c012cd3dc..80ecc4f0e6 100644
--- a/zjit/bindgen/src/main.rs
+++ b/zjit/bindgen/src/main.rs
@@ -182,6 +182,7 @@ fn main() {
         .allowlist_var("rb_cSymbol")
         .allowlist_var("rb_cFloat")
         .allowlist_var("rb_cNumeric")
+        .allowlist_var("rb_cRange")
         .allowlist_var("rb_cString")
         .allowlist_var("rb_cThread")
         .allowlist_var("rb_cArray")
diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs
index c1fad915f0..c8713bb612 100644
--- a/zjit/src/codegen.rs
+++ b/zjit/src/codegen.rs
@@ -7,7 +7,7 @@ use crate::state::ZJITState;
 use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
 use crate::invariants::{iseq_escapes_ep, track_no_ep_escape_assumption};
 use crate::backend::lir::{self, asm_comment, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, SP};
-use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo};
+use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, CallInfo, RangeType};
 use crate::hir::{Const, FrameState, Function, Insn, InsnId};
 use crate::hir_type::{types::Fixnum, Type};
 use crate::options::get_option;
@@ -251,6 +251,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
         Insn::PutSelf => gen_putself(),
         Insn::Const { val: Const::Value(val) } => gen_const(*val),
         Insn::NewArray { elements, state } => gen_new_array(jit, asm, elements, &function.frame_state(*state)),
+        Insn::NewRange { low, high, flag, state } => gen_new_range(asm, opnd!(low), opnd!(high), *flag, &function.frame_state(*state)),
         Insn::ArrayDup { val, state } => gen_array_dup(asm, opnd!(val), &function.frame_state(*state)),
         Insn::Param { idx } => unreachable!("block.insns should not have Insn::Param({idx})"),
         Insn::Snapshot { .. } => return Some(()), // we don't need to do anything for this instruction at the moment
@@ -552,6 +553,28 @@ fn gen_new_array(
     new_array
 }
 
+/// Compile a new range instruction
+fn gen_new_range(
+    asm: &mut Assembler,
+    low: lir::Opnd,
+    high: lir::Opnd,
+    flag: RangeType,
+    state: &FrameState,
+) -> lir::Opnd {
+    asm_comment!(asm, "call rb_range_new");
+
+    // Save PC
+    gen_save_pc(asm, state);
+
+    // Call rb_range_new(low, high, flag)
+    let new_range = asm.ccall(
+        rb_range_new as *const u8,
+        vec![low, high, lir::Opnd::Imm(flag as i64)],
+    );
+
+    new_range
+}
+
 /// Compile code that exits from JIT code with a return value
 fn gen_return(asm: &mut Assembler, val: lir::Opnd) -> Option<()> {
     // Pop the current frame (ec->cfp++)
diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs
index 58e8d40493..a5569a3db0 100644
--- a/zjit/src/cruby_bindings.inc.rs
+++ b/zjit/src/cruby_bindings.inc.rs
@@ -771,6 +771,7 @@ unsafe extern "C" {
     pub static mut rb_cModule: VALUE;
     pub static mut rb_cNilClass: VALUE;
     pub static mut rb_cNumeric: VALUE;
+    pub static mut rb_cRange: VALUE;
     pub static mut rb_cString: VALUE;
     pub static mut rb_cSymbol: VALUE;
     pub static mut rb_cThread: VALUE;
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs
index 14a94cdff7..746a3b7e9a 100644
--- a/zjit/src/hir.rs
+++ b/zjit/src/hir.rs
@@ -228,6 +228,50 @@ impl Const {
     }
 }
 
+pub enum RangeType {
+    Inclusive = 0, // include the end value
+    Exclusive = 1, // exclude the end value
+}
+
+impl std::fmt::Display for RangeType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}", match self {
+            RangeType::Inclusive => "NewRangeInclusive",
+            RangeType::Exclusive => "NewRangeExclusive",
+        })
+    }
+}
+
+impl std::fmt::Debug for RangeType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}", self.to_string())
+    }
+}
+
+impl Clone for RangeType {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+
+impl Copy for RangeType {}
+
+impl From<u32> for RangeType {
+    fn from(flag: u32) -> Self {
+        match flag {
+            0 => RangeType::Inclusive,
+            1 => RangeType::Exclusive,
+            _ => panic!("Invalid range flag: {}", flag),
+        }
+    }
+}
+
+impl From<RangeType> for u32 {
+    fn from(range_type: RangeType) -> Self {
+        range_type as u32
+    }
+}
+
 /// Print adaptor for [`Const`]. See [`PtrPrintMap`].
 struct ConstPrinter<'a> {
     inner: &'a Const,
@@ -330,6 +374,7 @@ pub enum Insn {
     NewArray { elements: Vec<InsnId>, state: InsnId },
     /// NewHash contains a vec of (key, value) pairs
     NewHash { elements: Vec<(InsnId,InsnId)>, state: InsnId },
+    NewRange { low: InsnId, high: InsnId, flag: RangeType, state: InsnId },
     ArraySet { array: InsnId, idx: usize, val: InsnId },
     ArrayDup { val: InsnId, state: InsnId },
     ArrayMax { elements: Vec<InsnId>, state: InsnId },
@@ -439,6 +484,7 @@ impl Insn {
             Insn::StringCopy { .. } => false,
             Insn::NewArray { .. } => false,
             Insn::NewHash { .. } => false,
+            Insn::NewRange { .. } => false,
             Insn::ArrayDup { .. } => false,
             Insn::HashDup { .. } => false,
             Insn::Test { .. } => false,
@@ -490,6 +536,9 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
                 }
                 Ok(())
             }
+            Insn::NewRange { low, high, flag, .. } => {
+                write!(f, "NewRange {low} {flag} {high}")
+            }
             Insn::ArrayMax { elements, .. } => {
                 write!(f, "ArrayMax")?;
                 let mut prefix = " ";
@@ -912,6 +961,7 @@ impl Function {
                 }
                 NewHash { elements: found_elements, state: find!(state) }
             }
+            &NewRange { low, high, flag, state } => NewRange { low: find!(low), high: find!(high), flag, state: find!(state) },
             ArrayMax { elements, state } => ArrayMax { elements: find_vec!(*elements), state: find!(*state) },
             &GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state },
             &SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val, state },
@@ -972,6 +1022,7 @@ impl Function {
             Insn::ArrayDup { .. } => types::ArrayExact,
             Insn::NewHash { .. } => types::HashExact,
             Insn::HashDup { .. } => types::HashExact,
+            Insn::NewRange { .. } => types::RangeExact,
             Insn::CCall { return_type, .. } => *return_type,
             Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
             Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_value(*expected)),
@@ -1486,6 +1537,11 @@ impl Function {
                     }
                     worklist.push_back(state);
                 }
+                Insn::NewRange { low, high, state, .. } => {
+                    worklist.push_back(low);
+                    worklist.push_back(high);
+                    worklist.push_back(state);
+                }
                 Insn::StringCopy { val }
                 | Insn::StringIntern { val }
                 | Insn::Return { val }
@@ -2342,6 +2398,14 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result {
                     let val = state.stack_pop()?;
                     fun.push_insn(block, Insn::SetIvar { self_val, id, val, state: exit_id });
                 }
+                YARVINSN_newrange => {
+                    let flag = RangeType::from(get_arg(pc, 0).as_u32());
+                    let high = state.stack_pop()?;
+                    let low = state.stack_pop()?;
+                    let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
+                    let insn_id = fun.push_insn(block, Insn::NewRange { low, high, flag, state: exit_id });
+                    state.stack_push(insn_id);
+                }
                 _ => {
                     // Unknown opcode; side-exit into the interpreter
                     let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
@@ -2735,6 +2799,52 @@ mod tests {
         "#]]);
     }
 
+    #[test]
+    fn test_new_range_inclusive_with_one_element() {
+        eval("def test(a) = (a..10)");
+        assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
+            fn test:
+            bb0(v0:BasicObject):
+              v2:Fixnum[10] = Const Value(10)
+              v4:RangeExact = NewRange v0 NewRangeInclusive v2
+              Return v4
+        "#]]);
+    }
+
+    #[test]
+    fn test_new_range_inclusive_with_two_elements() {
+        eval("def test(a, b) = (a..b)");
+        assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
+            fn test:
+            bb0(v0:BasicObject, v1:BasicObject):
+              v4:RangeExact = NewRange v0 NewRangeInclusive v1
+              Return v4
+        "#]]);
+    }
+
+    #[test]
+    fn test_new_range_exclusive_with_one_element() {
+        eval("def test(a) = (a...10)");
+        assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
+            fn test:
+            bb0(v0:BasicObject):
+              v2:Fixnum[10] = Const Value(10)
+              v4:RangeExact = NewRange v0 NewRangeExclusive v2
+              Return v4
+        "#]]);
+    }
+
+    #[test]
+    fn test_new_range_exclusive_with_two_elements() {
+        eval("def test(a, b) = (a...b)");
+        assert_method_hir_with_opcode("test", YARVINSN_newrange, expect![[r#"
+            fn test:
+            bb0(v0:BasicObject, v1:BasicObject):
+              v4:RangeExact = NewRange v0 NewRangeExclusive v1
+              Return v4
+        "#]]);
+    }
+
     #[test]
     fn test_array_dup() {
         eval("def test = [1, 2, 3]");
@@ -4273,6 +4383,23 @@ mod opt_tests {
         "#]]);
     }
 
+    #[test]
+    fn test_eliminate_new_range() {
+        eval("
+            def test()
+              c = (1..2)
+              5
+            end
+            test; test
+        ");
+        assert_optimized_method_hir("test", expect![[r#"
+            fn test:
+            bb0():
+              v3:Fixnum[5] = Const Value(5)
+              Return v3
+        "#]]);
+    }
+
     #[test]
     fn test_eliminate_new_array_with_elements() {
         eval("
diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb
index ae00a34d87..92351aafa2 100644
--- a/zjit/src/hir_type/gen_hir_type.rb
+++ b/zjit/src/hir_type/gen_hir_type.rb
@@ -71,6 +71,7 @@ end
 base_type "String"
 base_type "Array"
 base_type "Hash"
+base_type "Range"
 
 (integer, integer_exact) = base_type "Integer"
 # CRuby partitions Integer into immediate and non-immediate variants.
diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs
index 1560751933..7d6f92a180 100644
--- a/zjit/src/hir_type/hir_type.inc.rs
+++ b/zjit/src/hir_type/hir_type.inc.rs
@@ -9,7 +9,7 @@ mod bits {
   pub const BasicObjectSubclass: u64 = 1u64 << 3;
   pub const Bignum: u64 = 1u64 << 4;
   pub const BoolExact: u64 = FalseClassExact | TrueClassExact;
-  pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | StringExact | SymbolExact | TrueClassExact;
+  pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | RangeExact | StringExact | SymbolExact | TrueClassExact;
   pub const CBool: u64 = 1u64 << 5;
   pub const CDouble: u64 = 1u64 << 6;
   pub const CInt: u64 = CSigned | CUnsigned;
@@ -48,23 +48,26 @@ mod bits {
   pub const NilClass: u64 = NilClassExact | NilClassSubclass;
   pub const NilClassExact: u64 = 1u64 << 28;
   pub const NilClassSubclass: u64 = 1u64 << 29;
-  pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | String | Symbol | TrueClass;
+  pub const Object: u64 = Array | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | Range | String | Symbol | TrueClass;
   pub const ObjectExact: u64 = 1u64 << 30;
   pub const ObjectSubclass: u64 = 1u64 << 31;
+  pub const Range: u64 = RangeExact | RangeSubclass;
+  pub const RangeExact: u64 = 1u64 << 32;
+  pub const RangeSubclass: u64 = 1u64 << 33;
   pub const RubyValue: u64 = BasicObject | CallableMethodEntry | Undef;
-  pub const StaticSymbol: u64 = 1u64 << 32;
+  pub const StaticSymbol: u64 = 1u64 << 34;
   pub const String: u64 = StringExact | StringSubclass;
-  pub const StringExact: u64 = 1u64 << 33;
-  pub const StringSubclass: u64 = 1u64 << 34;
-  pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
+  pub const StringExact: u64 = 1u64 << 35;
+  pub const StringSubclass: u64 = 1u64 << 36;
+  pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
   pub const Symbol: u64 = SymbolExact | SymbolSubclass;
   pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol;
-  pub const SymbolSubclass: u64 = 1u64 << 35;
+  pub const SymbolSubclass: u64 = 1u64 << 37;
   pub const TrueClass: u64 = TrueClassExact | TrueClassSubclass;
-  pub const TrueClassExact: u64 = 1u64 << 36;
-  pub const TrueClassSubclass: u64 = 1u64 << 37;
-  pub const Undef: u64 = 1u64 << 38;
-  pub const AllBitPatterns: [(&'static str, u64); 64] = [
+  pub const TrueClassExact: u64 = 1u64 << 38;
+  pub const TrueClassSubclass: u64 = 1u64 << 39;
+  pub const Undef: u64 = 1u64 << 40;
+  pub const AllBitPatterns: [(&'static str, u64); 67] = [
     ("Any", Any),
     ("RubyValue", RubyValue),
     ("Immediate", Immediate),
@@ -84,6 +87,9 @@ mod bits {
     ("StringExact", StringExact),
     ("SymbolExact", SymbolExact),
     ("StaticSymbol", StaticSymbol),
+    ("Range", Range),
+    ("RangeSubclass", RangeSubclass),
+    ("RangeExact", RangeExact),
     ("ObjectSubclass", ObjectSubclass),
     ("ObjectExact", ObjectExact),
     ("NilClass", NilClass),
@@ -130,7 +136,7 @@ mod bits {
     ("ArrayExact", ArrayExact),
     ("Empty", Empty),
   ];
-  pub const NumTypeBits: u64 = 39;
+  pub const NumTypeBits: u64 = 41;
 }
 pub mod types {
   use super::*;
@@ -185,6 +191,9 @@ pub mod types {
   pub const Object: Type = Type::from_bits(bits::Object);
   pub const ObjectExact: Type = Type::from_bits(bits::ObjectExact);
   pub const ObjectSubclass: Type = Type::from_bits(bits::ObjectSubclass);
+  pub const Range: Type = Type::from_bits(bits::Range);
+  pub const RangeExact: Type = Type::from_bits(bits::RangeExact);
+  pub const RangeSubclass: Type = Type::from_bits(bits::RangeSubclass);
   pub const RubyValue: Type = Type::from_bits(bits::RubyValue);
   pub const StaticSymbol: Type = Type::from_bits(bits::StaticSymbol);
   pub const String: Type = Type::from_bits(bits::String);
diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs
index 0459482757..dd53fed105 100644
--- a/zjit/src/hir_type/mod.rs
+++ b/zjit/src/hir_type/mod.rs
@@ -1,6 +1,6 @@
 #![allow(non_upper_case_globals)]
 use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH};
-use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass};
+use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass, rb_cRange};
 use crate::cruby::ClassRelationship;
 use crate::cruby::get_class_name;
 use crate::hir::PtrPrintMap;
@@ -137,6 +137,10 @@ fn is_hash_exact(val: VALUE) -> bool {
     val.class_of() == unsafe { rb_cHash } || (val.class_of() == VALUE(0) && val.builtin_type() == RUBY_T_HASH)
 }
 
+fn is_range_exact(val: VALUE) -> bool {
+    val.class_of() == unsafe { rb_cRange }
+}
+
 impl Type {
     /// Create a `Type` from the given integer.
     pub const fn fixnum(val: i64) -> Type {
@@ -183,6 +187,9 @@ impl Type {
         else if is_hash_exact(val) {
             Type { bits: bits::HashExact, spec: Specialization::Object(val) }
         }
+        else if is_range_exact(val) {
+            Type { bits: bits::RangeExact, spec: Specialization::Object(val) }
+        }
         else if is_string_exact(val) {
             Type { bits: bits::StringExact, spec: Specialization::Object(val) }
         }
@@ -277,6 +284,7 @@ impl Type {
         if class == unsafe { rb_cInteger } { return true; }
         if class == unsafe { rb_cNilClass } { return true; }
         if class == unsafe { rb_cObject } { return true; }
+        if class == unsafe { rb_cRange } { return true; }
         if class == unsafe { rb_cString } { return true; }
         if class == unsafe { rb_cSymbol } { return true; }
         if class == unsafe { rb_cTrueClass } { return true; }
@@ -383,6 +391,7 @@ impl Type {
         if self.is_subtype(types::IntegerExact) { return Some(unsafe { rb_cInteger }); }
         if self.is_subtype(types::NilClassExact) { return Some(unsafe { rb_cNilClass }); }
         if self.is_subtype(types::ObjectExact) { return Some(unsafe { rb_cObject }); }
+        if self.is_subtype(types::RangeExact) { return Some(unsafe { rb_cRange }); }
         if self.is_subtype(types::StringExact) { return Some(unsafe { rb_cString }); }
         if self.is_subtype(types::SymbolExact) { return Some(unsafe { rb_cSymbol }); }
         if self.is_subtype(types::TrueClassExact) { return Some(unsafe { rb_cTrueClass }); }