ZJIT: Side-exit into the interpreter on unknown call types

This commit is contained in:
Max Bernstein 2025-05-23 13:45:38 -04:00 committed by Takashi Kokubun
parent a0df4cf6f1
commit d23fe287b6
Notes: git 2025-05-23 20:33:03 +00:00

View File

@ -1797,7 +1797,6 @@ pub enum CallType {
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum ParseError { pub enum ParseError {
StackUnderflow(FrameState), StackUnderflow(FrameState),
UnhandledCallType(CallType),
} }
/// Return the number of locals in the current ISEQ (includes parameters) /// Return the number of locals in the current ISEQ (includes parameters)
@ -1806,19 +1805,19 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize {
} }
/// If we can't handle the type of send (yet), bail out. /// If we can't handle the type of send (yet), bail out.
fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> { fn unknown_call_type(flag: u32) -> bool {
if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplatMut)); } if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return true; }
if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::SplatMut)); } if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return true; }
if (flag & VM_CALL_ARGS_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::Splat)); } if (flag & VM_CALL_ARGS_SPLAT) != 0 { return true; }
if (flag & VM_CALL_KW_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplat)); } if (flag & VM_CALL_KW_SPLAT) != 0 { return true; }
if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::BlockArg)); } if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return true; }
if (flag & VM_CALL_KWARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::Kwarg)); } if (flag & VM_CALL_KWARG) != 0 { return true; }
if (flag & VM_CALL_TAILCALL) != 0 { return Err(ParseError::UnhandledCallType(CallType::Tailcall)); } if (flag & VM_CALL_TAILCALL) != 0 { return true; }
if (flag & VM_CALL_SUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Super)); } if (flag & VM_CALL_SUPER) != 0 { return true; }
if (flag & VM_CALL_ZSUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Zsuper)); } if (flag & VM_CALL_ZSUPER) != 0 { return true; }
if (flag & VM_CALL_OPT_SEND) != 0 { return Err(ParseError::UnhandledCallType(CallType::OptSend)); } if (flag & VM_CALL_OPT_SEND) != 0 { return true; }
if (flag & VM_CALL_FORWARDING) != 0 { return Err(ParseError::UnhandledCallType(CallType::Forwarding)); } if (flag & VM_CALL_FORWARDING) != 0 { return true; }
Ok(()) false
} }
/// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful /// We have IseqPayload, which keeps track of HIR Types in the interpreter, but this is not useful
@ -2147,7 +2146,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
// NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq // NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq
let cd: *const rb_call_data = get_arg(pc, 1).as_ptr(); let cd: *const rb_call_data = get_arg(pc, 1).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) }; let call_info = unsafe { rb_get_call_data_ci(cd) };
filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
// Unknown call type; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
fun.push_insn(block, Insn::SideExit { state: exit_id });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) }; let argc = unsafe { vm_ci_argc((*cd).ci) };
@ -2190,7 +2194,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
YARVINSN_opt_send_without_block => { YARVINSN_opt_send_without_block => {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let call_info = unsafe { rb_get_call_data_ci(cd) }; let call_info = unsafe { rb_get_call_data_ci(cd) };
filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
// Unknown call type; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
fun.push_insn(block, Insn::SideExit { state: exit_id });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) }; let argc = unsafe { vm_ci_argc((*cd).ci) };
@ -2213,7 +2222,12 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq(); let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq();
let call_info = unsafe { rb_get_call_data_ci(cd) }; let call_info = unsafe { rb_get_call_data_ci(cd) };
filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; if unknown_call_type(unsafe { rb_vm_ci_flag(call_info) }) {
// Unknown call type; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
fun.push_insn(block, Insn::SideExit { state: exit_id });
break; // End the block
}
let argc = unsafe { vm_ci_argc((*cd).ci) }; let argc = unsafe { vm_ci_argc((*cd).ci) };
let method_name = unsafe { let method_name = unsafe {
@ -3077,7 +3091,13 @@ mod tests {
eval(" eval("
def test(a) = foo(*a) def test(a) = foo(*a)
"); ");
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Splat)) assert_method_hir("test", expect![[r#"
fn test:
bb0(v0:BasicObject):
v2:BasicObject = PutSelf
v4:ArrayExact = ToArray v0
SideExit
"#]]);
} }
#[test] #[test]
@ -3085,7 +3105,12 @@ mod tests {
eval(" eval("
def test(a) = foo(&a) def test(a) = foo(&a)
"); ");
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::BlockArg)) assert_method_hir("test", expect![[r#"
fn test:
bb0(v0:BasicObject):
v2:BasicObject = PutSelf
SideExit
"#]]);
} }
#[test] #[test]
@ -3093,7 +3118,13 @@ mod tests {
eval(" eval("
def test(a) = foo(a: 1) def test(a) = foo(a: 1)
"); ");
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Kwarg)) assert_method_hir("test", expect![[r#"
fn test:
bb0(v0:BasicObject):
v2:BasicObject = PutSelf
v3:Fixnum[1] = Const Value(1)
SideExit
"#]]);
} }
#[test] #[test]
@ -3101,7 +3132,12 @@ mod tests {
eval(" eval("
def test(a) = foo(**a) def test(a) = foo(**a)
"); ");
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::KwSplat)) assert_method_hir("test", expect![[r#"
fn test:
bb0(v0:BasicObject):
v2:BasicObject = PutSelf
SideExit
"#]]);
} }
// TODO(max): Figure out how to generate a call with TAILCALL flag // TODO(max): Figure out how to generate a call with TAILCALL flag
@ -3165,7 +3201,15 @@ mod tests {
eval(" eval("
def test(*) = foo *, 1 def test(*) = foo *, 1
"); ");
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::SplatMut)) assert_method_hir("test", expect![[r#"
fn test:
bb0(v0:ArrayExact):
v2:BasicObject = PutSelf
v4:ArrayExact = ToNewArray v0
v5:Fixnum[1] = Const Value(1)
ArrayPush v4, v5
SideExit
"#]]);
} }
#[test] #[test]