import operator
from logging import getLogger
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Union

from . import utils
from .exceptions import PJVMNotImplemented, PJVMTypeError
from .old_jtypes import JInteger, JObj, JString
from .stackframe import Frame

if TYPE_CHECKING:
    from .vm import PJVirtualMachine

LOGGER = getLogger(__name__)

# Signature of a handler as stored in the dispatch table (VM already bound).
OPCODE_TYPE = Callable[[int, "Frame"], Tuple[int, "JObj"]]
# Signature of a handler as written in this module (VM passed explicitly).
OPCODE_BOUND_TYPE = Callable[["PJVirtualMachine", int, "Frame"], Tuple[int, "JObj"]]

# Comparison applied as ``value1 <op> value2`` for each if_icmp<cond> opcode.
_ICMP_COMPARATORS: Dict[int, Callable[[int, int], bool]] = {
    0x9f: operator.eq,
    0xa0: operator.ne,
    0xa1: operator.lt,
    0xa2: operator.ge,
    0xa3: operator.gt,
    0xa4: operator.le,
}


def _to_signed(value: int, bits: int) -> int:
    """Interpret the low ``bits`` bits of ``value`` as a two's-complement int.

    The result is the same whether ``value`` arrives as an unsigned or an
    already-signed quantity, so it is safe to apply regardless of how the
    frame's byte/short readers decode their data.
    """
    value &= (1 << bits) - 1
    if value & (1 << (bits - 1)):
        value -= 1 << bits
    return value


class Opcode:
    """Decorator that registers a function as the handler for opcode(s).

    Usage: ``@Opcode(0x10, size=2, no_ret=True)`` above a function taking
    ``(vm, opcode, frame)``. The registered entry is the bound ``wrapped``
    method, which supplies the VM, advances ``frame.ip`` by ``size`` after
    the handler runs, and normalizes a missing return value.
    """

    # Dispatch table shared by every handler: opcode number -> wrapped handler.
    opcodes: Dict[int, OPCODE_TYPE] = {}
    # Set on the class by get_opcodes(); read by wrapped() at dispatch time.
    vm: "PJVirtualMachine"
    func: OPCODE_BOUND_TYPE
    owner: object

    @classmethod
    def register(cls, opcode: Union[int, List[int]], method: OPCODE_TYPE):
        """Map ``opcode`` (one number or a list of numbers) to ``method``."""
        if isinstance(opcode, int):
            cls.opcodes[opcode] = method
        else:
            for oc in opcode:
                cls.opcodes[oc] = method

    @classmethod
    def get_opcodes(cls, vm: "PJVirtualMachine") -> Dict[int, OPCODE_TYPE]:
        """Store ``vm`` on the class and return the shared dispatch table."""
        cls.vm = vm
        return cls.opcodes

    def __init__(self, opcode: Union[int, List[int]], size: int = 0,
                 no_ret: bool = False):
        """
        :param opcode: opcode number, or list of numbers, to handle
        :param size: amount added to ``frame.ip`` after the handler runs;
            handlers that set ``frame.ip`` themselves leave this at 0
        :param no_ret: when true, a handler returning ``None`` is reported
            as ``(0, None)``
        """
        self.opcode = opcode
        self.size = size
        self.no_ret = no_ret

    def wrapped(self, opcode: int, frame: Frame):
        """Run the handler with the class-level VM, then advance ``frame.ip``.

        ``get_opcodes`` must have been called first so that ``vm`` is set.
        """
        retval = self.func(self.vm, opcode, frame)
        frame.ip += self.size
        if self.no_ret and retval is None:
            retval = 0, None
        return retval

    def __call__(self, func: OPCODE_BOUND_TYPE) -> OPCODE_BOUND_TYPE:
        """Register ``func`` and return it so the decorated name stays usable."""
        self.func = func
        self.register(self.opcode, self.wrapped)
        return func


@Opcode(0xb2, size=3, no_ret=True)
def op_getstatic(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Get static value from field

    Reads the field reference index at ``ip + 1``, resolves it through the
    class constant pool, fetches the value from the VM and pushes it.

    :raises PJVMTypeError: if the value does not match the field descriptor
    """
    index = frame.get_short(frame.ip + 1)
    owner, field_name, field_descriptor = frame.clazz.cp_get_fieldref(index)
    LOGGER.info(f"Fetching {owner} -> {field_name} of type {field_descriptor}")
    field = vm.get_static_field(owner, field_name)
    if not field.instance_of(field_descriptor):
        raise PJVMTypeError(f"Type mismatch {field.type} != {field_descriptor}")
    frame.stack.append(field)


@Opcode(0x12, size=2, no_ret=True)
def op_ldc(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Get constant from constant pool

    Only entries whose ``'type'`` is ``'String'`` are supported.

    :raises PJVMNotImplemented: for any other constant type
    """
    index = frame.get_byte(frame.ip + 1)
    value = frame.clazz.cp_get(index)
    if value['type'] != 'String':
        raise PJVMNotImplemented("did not feel like implementing the other stuff")
    resulting_value: JObj = JString(frame.clazz.cp_get_utf8(value['string_index']))
    frame.stack.append(resulting_value)


@Opcode(0xb8, size=3, no_ret=True)
def op_invokestatic(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Call a static method

    Pops one stack value per argument type in the descriptor, executes the
    method through the VM and pushes the result unless the descriptor ends
    in ``V``.
    """
    index = frame.get_short(frame.ip + 1)
    class_name, method_name, descriptor = frame.clazz.cp_get_methodref(index)
    LOGGER.info(f"Calling static {class_name} -> {method_name} {descriptor}")
    args_types = utils.get_argument_count_types_descriptor(descriptor)
    # TODO: type check
    # NOTE(review): arguments are collected in pop order, i.e. last argument
    # first -- confirm that vm.execute_static_method expects this ordering.
    args = [frame.stack.pop() for _ in args_types]
    result = vm.execute_static_method(class_name, method_name, args)
    if descriptor[-1] != 'V':
        frame.stack.append(result)


@Opcode(0xb6, size=3, no_ret=True)
def op_invokevirtual(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Call an instance method

    Pops one stack value per argument type in the descriptor, then the
    object reference, executes the method through the VM and pushes the
    result unless the descriptor ends in ``V``.
    """
    index = frame.get_short(frame.ip + 1)
    class_name, method_name, descriptor = frame.clazz.cp_get_methodref(index)
    LOGGER.info(f"Calling virtual {class_name} -> {method_name} {descriptor}")
    args_types = utils.get_argument_count_types_descriptor(descriptor)
    # TODO: type check
    # NOTE(review): arguments are collected in pop order, i.e. last argument
    # first -- confirm that vm.execute_instance_method expects this ordering.
    args = [frame.stack.pop() for _ in args_types]
    object_ref = frame.stack.pop()
    result: JObj = vm.execute_instance_method(class_name, method_name, args, object_ref)
    if descriptor[-1] != 'V':
        frame.stack.append(result)


@Opcode(0xb1, size=1)
def op_return(vm: "PJVirtualMachine", opcode: int, frame: Frame) -> Tuple[int, Optional[JObj]]:
    """
    Return without a value

    Yields ``(1, None)``; every other handler here reports ``(0, None)``.
    """
    return 1, None


@Opcode(0x92, size=1, no_ret=True)
def op_i2c(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Convert int to char

    Replaces the integer on top of the stack with its low 16 bits.

    :raises ValueError: if the top of the stack is not of type ``I``
    """
    top = frame.stack.pop()
    if not top.instance_of('I'):
        raise ValueError("Type mismatch!")
    char = top.value & 0xFFFF
    frame.stack.append(JInteger(char))


@Opcode([0x3b, 0x3c, 0x3d, 0x3e], size=1, no_ret=True)
def op_istore_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Store int into local variable 0-3

    :raises PJVMTypeError: if the popped value is not of type ``I``
    """
    # 0x3b..0x3e -> 0..3: adding 1 aligns 0x3b to a multiple of four.
    idx = (opcode + 1) & 0b11
    item = frame.stack.pop()
    if not item.instance_of('I'):
        raise PJVMTypeError("Must be integer")
    frame.set_local(item, idx)


@Opcode([0x1a, 0x1b, 0x1c, 0x1d], size=1, no_ret=True)
def op_iload_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Load int from local variable 0-3

    :raises PJVMTypeError: if the local is not of type ``I``
    """
    # 0x1a..0x1d -> 0..3: adding 2 aligns 0x1a to a multiple of four.
    idx = (opcode + 2) & 0b11
    item = frame.get_local(idx)
    if not item.instance_of('I'):
        raise PJVMTypeError("Must be integer")
    frame.stack.append(item)


@Opcode(0x10, size=2, no_ret=True)
def op_bipush(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Push the immediate byte as an int

    The byte is sign-extended, so 0xFF pushes -1. get_byte is also used for
    constant pool indices and may therefore be unsigned; the normalization
    leaves an already-signed value unchanged.
    """
    val = _to_signed(frame.get_byte(frame.ip + 1), 8)
    frame.stack.append(JInteger(val))


@Opcode([0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4], no_ret=True)
def op_icmp_cond(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Branch if int comparison succeeds

    Pops value2 then value1 and tests ``value1 <op> value2``. On success the
    signed 16-bit offset at ``ip + 1`` is added to ``frame.ip``; otherwise
    ``frame.ip`` moves past this three byte instruction. The decorator size
    is 0 because the handler positions ``frame.ip`` itself.

    :raises PJVMTypeError: if either operand is not of type ``I``
    """
    # value2 was pushed last, so it comes off the stack first.
    right = frame.stack.pop()
    left = frame.stack.pop()
    if not left.instance_of("I") or not right.instance_of("I"):
        raise PJVMTypeError("Must be of type integer")
    if _ICMP_COMPARATORS[opcode](left.value, right.value):
        frame.ip += _to_signed(frame.get_short(frame.ip + 1), 16)
    else:
        frame.ip += 3


@Opcode(0xa7, no_ret=True)
def op_goto(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Branch always

    Adds the signed 16-bit offset at ``ip + 1`` to ``frame.ip``; negative
    offsets jump backwards. The decorator size is 0 because the handler
    positions ``frame.ip`` itself.
    """
    frame.ip += _to_signed(frame.get_short(frame.ip + 1), 16)