RJIT: Support entry with different PCs

This commit is contained in:
Takashi Kokubun 2023-04-02 15:26:46 -07:00
parent 4fc336127e
commit 6002b12611
9 changed files with 106 additions and 29 deletions

View File

@ -3,6 +3,7 @@ require 'ruby_vm/rjit/block'
require 'ruby_vm/rjit/branch_stub' require 'ruby_vm/rjit/branch_stub'
require 'ruby_vm/rjit/code_block' require 'ruby_vm/rjit/code_block'
require 'ruby_vm/rjit/context' require 'ruby_vm/rjit/context'
require 'ruby_vm/rjit/entry_stub'
require 'ruby_vm/rjit/exit_compiler' require 'ruby_vm/rjit/exit_compiler'
require 'ruby_vm/rjit/insn_compiler' require 'ruby_vm/rjit/insn_compiler'
require 'ruby_vm/rjit/instruction' require 'ruby_vm/rjit/instruction'
@ -64,6 +65,48 @@ module RubyVM::RJIT
exit 1 exit 1
end end
# Compile an entry.
# @param entry [RubyVM::RJIT::EntryStub]
def entry_stub_hit(entry_stub, cfp)
# Compile a new entry guard as a next entry
pc = cfp.pc.to_i
next_entry = Assembler.new.then do |asm|
compile_entry_chain_guard(asm, cfp.iseq, pc)
@cb.write(asm)
end
# Try to find an existing compiled version of this block
ctx = Context.new
block = find_block(cfp.iseq, pc, ctx)
if block
# If an existing block is found, generate a jump to the block.
asm = Assembler.new
asm.jmp(block.start_addr)
@cb.write(asm)
else
# If this block hasn't yet been compiled, generate blocks after the entry guard.
asm = Assembler.new
jit = JITState.new(iseq: cfp.iseq, cfp:)
compile_block(asm, jit:, pc:, ctx:)
@cb.write(asm)
block = jit.block
end
# Regenerate the previous entry
@cb.with_write_addr(entry_stub.start_addr) do
# The last instruction of compile_entry_chain_guard is jne
asm = Assembler.new
asm.jne(next_entry)
@cb.write(asm)
end
return block.start_addr
rescue Exception => e
$stderr.puts e.full_message
exit 1
end
# Compile a branch stub. # Compile a branch stub.
# @param branch_stub [RubyVM::RJIT::BranchStub] # @param branch_stub [RubyVM::RJIT::BranchStub]
# @param cfp `RubyVM::RJIT::CPointer::Struct_rb_control_frame_t` # @param cfp `RubyVM::RJIT::CPointer::Struct_rb_control_frame_t`
@ -210,30 +253,24 @@ module RubyVM::RJIT
# compiled for is the same PC that the interpreter wants us to run with. # compiled for is the same PC that the interpreter wants us to run with.
# If they don't match, then we'll take a side exit. # If they don't match, then we'll take a side exit.
if iseq.body.param.flags.has_opt if iseq.body.param.flags.has_opt
compile_pc_guard(asm, iseq, pc) compile_entry_chain_guard(asm, iseq, pc)
end end
end end
def compile_pc_guard(asm, iseq, pc) def compile_entry_chain_guard(asm, iseq, pc)
entry_stub = EntryStub.new
stub_addr = Assembler.new.then do |ocb_asm|
@exit_compiler.compile_entry_stub(ocb_asm, entry_stub)
@ocb.write(ocb_asm)
end
asm.comment('guard expected PC') asm.comment('guard expected PC')
asm.mov(:rax, pc) asm.mov(:rax, pc)
asm.cmp([CFP, C.rb_control_frame_t.offsetof(:pc)], :rax) asm.cmp([CFP, C.rb_control_frame_t.offsetof(:pc)], :rax)
pc_match = asm.new_label('pc_match') asm.stub(entry_stub) do
asm.je(pc_match) asm.jne(stub_addr)
end
# We're not starting at the first PC, so we need to exit.
asm.incr_counter(:leave_start_pc_non_zero)
asm.pop(SP)
asm.pop(EC)
asm.pop(CFP)
asm.mov(:rax, Qundef)
asm.ret
# PC should match the expected insn_idx
asm.write_label(pc_match)
end end
# @param asm [RubyVM::RJIT::Assembler] # @param asm [RubyVM::RJIT::Assembler]

View File

@ -0,0 +1,7 @@
module RubyVM::RJIT
class EntryStub < Struct.new(
:start_addr, # @param [Integer] Stub source start address to be re-generated
:end_addr, # @param [Integer] Stub source end address to be re-generated
)
end
end

View File

@ -78,6 +78,18 @@ module RubyVM::RJIT
asm.ret asm.ret
end end
# @param asm [RubyVM::RJIT::Assembler]
# @param entry_stub [RubyVM::RJIT::EntryStub]
def compile_entry_stub(asm, entry_stub)
# Call rb_rjit_entry_stub_hit
asm.comment('entry stub hit')
asm.mov(C_ARGS[0], to_value(entry_stub))
asm.call(C.rb_rjit_entry_stub_hit)
# Jump to the address returned by rb_rjit_entry_stub_hit
asm.jmp(:rax)
end
# @param ctx [RubyVM::RJIT::Context] # @param ctx [RubyVM::RJIT::Context]
# @param asm [RubyVM::RJIT::Assembler] # @param asm [RubyVM::RJIT::Assembler]
# @param branch_stub [RubyVM::RJIT::BranchStub] # @param branch_stub [RubyVM::RJIT::BranchStub]
@ -93,7 +105,7 @@ module RubyVM::RJIT
asm.mov(:edx, target0_p ? 1 : 0) asm.mov(:edx, target0_p ? 1 : 0)
asm.call(C.rb_rjit_branch_stub_hit) asm.call(C.rb_rjit_branch_stub_hit)
# Jump to the address returned by rb_rjit_stub_hit # Jump to the address returned by rb_rjit_branch_stub_hit
asm.jmp(:rax) asm.jmp(:rax)
end end

View File

@ -40,7 +40,6 @@ module RubyVM::RJIT
print_counters(stats, prefix: 'send_', prompt: 'method call exit reasons') print_counters(stats, prefix: 'send_', prompt: 'method call exit reasons')
print_counters(stats, prefix: 'invokeblock_', prompt: 'invokeblock exit reasons') print_counters(stats, prefix: 'invokeblock_', prompt: 'invokeblock exit reasons')
print_counters(stats, prefix: 'invokesuper_', prompt: 'invokesuper exit reasons') print_counters(stats, prefix: 'invokesuper_', prompt: 'invokesuper exit reasons')
print_counters(stats, prefix: 'leave_', prompt: 'leave exit reasons')
print_counters(stats, prefix: 'getblockpp_', prompt: 'getblockparamproxy exit reasons') print_counters(stats, prefix: 'getblockpp_', prompt: 'getblockparamproxy exit reasons')
print_counters(stats, prefix: 'getivar_', prompt: 'getinstancevariable exit reasons') print_counters(stats, prefix: 'getivar_', prompt: 'getinstancevariable exit reasons')
print_counters(stats, prefix: 'setivar_', prompt: 'setinstancevariable exit reasons') print_counters(stats, prefix: 'setivar_', prompt: 'setinstancevariable exit reasons')

20
rjit.c
View File

@ -360,6 +360,26 @@ rb_rjit_compile(const rb_iseq_t *iseq)
RB_VM_LOCK_LEAVE(); RB_VM_LOCK_LEAVE();
} }
void *
rb_rjit_entry_stub_hit(VALUE branch_stub)
{
VALUE result;
RB_VM_LOCK_ENTER();
rb_vm_barrier();
rb_control_frame_t *cfp = GET_EC()->cfp;
WITH_RJIT_ISOLATED({
VALUE cfp_ptr = rb_funcall(rb_cRJITCfpPtr, rb_intern("new"), 1, SIZET2NUM((size_t)cfp));
result = rb_funcall(rb_RJITCompiler, rb_intern("entry_stub_hit"), 2, branch_stub, cfp_ptr);
});
RB_VM_LOCK_LEAVE();
return (void *)NUM2SIZET(result);
}
void * void *
rb_rjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int target0_p) rb_rjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int target0_p)
{ {

View File

@ -530,6 +530,8 @@ extern const rb_callable_method_entry_t *rb_callable_method_entry_or_negative(VA
extern VALUE rb_vm_yield_with_cfunc(rb_execution_context_t *ec, const struct rb_captured_block *captured, int argc, const VALUE *argv); extern VALUE rb_vm_yield_with_cfunc(rb_execution_context_t *ec, const struct rb_captured_block *captured, int argc, const VALUE *argv);
extern VALUE rb_vm_set_ivar_id(VALUE obj, ID id, VALUE val); extern VALUE rb_vm_set_ivar_id(VALUE obj, ID id, VALUE val);
extern VALUE rb_ary_unshift_m(int argc, VALUE *argv, VALUE ary); extern VALUE rb_ary_unshift_m(int argc, VALUE *argv, VALUE ary);
extern void* rb_rjit_entry_stub_hit(VALUE branch_stub, int sp_offset, int target0_p);
extern void* rb_rjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int target0_p);
#include "rjit_c.rbinc" #include "rjit_c.rbinc"

View File

@ -157,8 +157,6 @@ RJIT_RUNTIME_COUNTERS(
getblockpp_not_gc_guarded, getblockpp_not_gc_guarded,
getblockpp_not_iseq_block, getblockpp_not_iseq_block,
leave_start_pc_non_zero,
compiled_block_count compiled_block_count
) )
#undef RJIT_RUNTIME_COUNTERS #undef RJIT_RUNTIME_COUNTERS

View File

@ -25,13 +25,6 @@ module RubyVM::RJIT # :nodoc: all
CType::Immediate.parse("size_t").new(addr) CType::Immediate.parse("size_t").new(addr)
end end
def rb_rjit_branch_stub_hit
Primitive.cstmt! %{
extern void *rb_rjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int target0_p);
return SIZET2NUM((size_t)rb_rjit_branch_stub_hit);
}
end
def rb_rjit_counters def rb_rjit_counters
addr = Primitive.cexpr! 'SIZET2NUM((size_t)&rb_rjit_counters)' addr = Primitive.cexpr! 'SIZET2NUM((size_t)&rb_rjit_counters)'
rb_rjit_runtime_counters.new(addr) rb_rjit_runtime_counters.new(addr)
@ -659,6 +652,14 @@ module RubyVM::RJIT # :nodoc: all
Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_reg_nth_match) } Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_reg_nth_match) }
end end
def C.rb_rjit_branch_stub_hit
Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_rjit_branch_stub_hit) }
end
def C.rb_rjit_entry_stub_hit
Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_rjit_entry_stub_hit) }
end
def C.rb_str_buf_append def C.rb_str_buf_append
Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_str_buf_append) } Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_str_buf_append) }
end end
@ -1436,7 +1437,6 @@ module RubyVM::RJIT # :nodoc: all
getblockpp_block_handler_none: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), getblockpp_block_handler_none)")], getblockpp_block_handler_none: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), getblockpp_block_handler_none)")],
getblockpp_not_gc_guarded: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), getblockpp_not_gc_guarded)")], getblockpp_not_gc_guarded: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), getblockpp_not_gc_guarded)")],
getblockpp_not_iseq_block: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), getblockpp_not_iseq_block)")], getblockpp_not_iseq_block: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), getblockpp_not_iseq_block)")],
leave_start_pc_non_zero: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), leave_start_pc_non_zero)")],
compiled_block_count: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), compiled_block_count)")], compiled_block_count: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_rjit_runtime_counters *)NULL)), compiled_block_count)")],
) )
end end

View File

@ -568,6 +568,8 @@ generator = BindingGenerator.new(
rjit_rb_ary_subseq_length rjit_rb_ary_subseq_length
rb_ary_unshift_m rb_ary_unshift_m
rjit_build_kwhash rjit_build_kwhash
rb_rjit_entry_stub_hit
rb_rjit_branch_stub_hit
], ],
types: %w[ types: %w[
CALL_DATA CALL_DATA