"""Opcode handlers and the bytecode interpreter loop of the PJVM virtual machine."""
import operator
from logging import getLogger
from typing import Callable, Dict, List, Optional, Tuple, Union

from pjvm.classloader import load_class
from pjvm.clazz import Class, Code, Method

from . import utils
from .mockimpl import mock_instance_methods, mock_static_objects
from .old_jtypes import JObj, JInteger, JString, JArray
from .opcodes_names import INSTRUCTIONS
from .exceptions import *
from .stackframe import Frame

OPCODE_TYPE = Callable[[int, "Frame"], Tuple[int, "JObj"]]
OPCODE_BOUND_TYPE = Callable[["PJVirtualMachine", int, "Frame"], Tuple[int, "JObj"]]

LOGGER = getLogger(__name__)


def _to_signed(value: int, bits: int) -> int:
    """Reinterpret ``value`` as a two's-complement integer ``bits`` wide.

    Values already inside the signed range are returned unchanged, so this is
    correct whether the ``Frame`` byte/short readers hand back signed or
    unsigned operands.
    """
    sign_bit = 1 << (bits - 1)
    if value >= sign_bit:
        value -= 1 << bits
    return value


# Comparison for each if_icmp<cond> opcode, applied as ``value1 <op> value2``.
_ICMP_COMPARISONS: Dict[int, Callable[[int, int], bool]] = {
    0x9f: operator.eq,  # if_icmpeq
    0xa0: operator.ne,  # if_icmpne
    0xa1: operator.lt,  # if_icmplt
    0xa2: operator.ge,  # if_icmpge
    0xa3: operator.gt,  # if_icmpgt
    0xa4: operator.le,  # if_icmple
}


class Opcode:
    """Decorator that registers a function as the handler of one or more opcodes.

    Handlers are called as ``func(vm, opcode, frame)`` and return a
    ``(result, meta)`` pair; ``PJVirtualMachine.execute`` treats ``result == 0``
    as "continue", ``1`` as "return ``meta``" and ``2`` as "exception".
    """

    # Registry shared by every VM: opcode -> callable taking (opcode, frame).
    opcodes: Dict[int, OPCODE_TYPE] = {}
    # NOTE(review): one class-level VM reference; constructing a second VM
    # rebinds every registered handler to that newest VM.
    vm: "PJVirtualMachine"
    func: OPCODE_BOUND_TYPE
    owner: object

    @classmethod
    def register(cls, opcode: Union[int, List[int]], method: OPCODE_TYPE):
        """Map ``opcode`` (or every opcode in a list) to ``method``."""
        codes = [opcode] if isinstance(opcode, int) else opcode
        for oc in codes:
            cls.opcodes[oc] = method

    @classmethod
    def get_opcodes(cls, vm: "PJVirtualMachine") -> Dict[int, OPCODE_TYPE]:
        """Bind ``vm`` for all handlers and return the opcode registry."""
        cls.vm = vm
        return cls.opcodes

    def __init__(self, opcode: Union[int, List[int]], size: int = 0, no_ret: bool = False):
        """
        :param opcode: opcode, or list of opcodes, implemented by the handler
        :param size: bytes to advance ``frame.ip`` after the handler runs; 0 for
            handlers that move ``ip`` themselves (branches)
        :param no_ret: when True a handler returning ``None`` yields ``(0, None)``
        """
        self.opcode = opcode
        self.size = size
        self.no_ret = no_ret

    def wrapped(self, opcode: int, frame: Frame):
        """Run the handler, advance ``frame.ip`` by ``size`` and normalize the result."""
        retval = self.func(self.vm, opcode, frame)
        frame.ip += self.size
        if self.no_ret and retval is None:
            retval = 0, None
        return retval

    def __call__(self, func: OPCODE_BOUND_TYPE) -> OPCODE_BOUND_TYPE:
        """Register ``func`` and return it, so the decorated name stays bound to it."""
        self.func = func
        self.register(self.opcode, self.wrapped)
        return func


@Opcode(0xb2, size=3, no_ret=True)
def op_getstatic(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Get static value from field
    """
    # Constant-pool indices are unsigned 16-bit.
    index = frame.get_short(frame.ip + 1) & 0xFFFF
    owner, field_name, field_descriptor = frame.clazz.cp_get_fieldref(index)
    LOGGER.info("Fetching %s -> %s of type %s", owner, field_name, field_descriptor)
    field = vm.get_static_field(owner, field_name)
    if not field.instance_of(field_descriptor):
        raise PJVMTypeError(f"Type mismatch {field.type} != {field_descriptor}")
    frame.stack.append(field)


@Opcode(0x12, size=2, no_ret=True)
def op_ldc(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """
    Get constant from constant pool (only String constants are supported)
    """
    # The ldc index is an unsigned byte.
    index = frame.get_byte(frame.ip + 1) & 0xFF
    value = frame.clazz.cp_get(index)
    resulting_value: JObj
    if value['type'] == 'String':
        resulting_value = JString(frame.clazz.cp_get_utf8(value['string_index']))
    else:
        raise PJVMNotImplemented("did not feel like implementing the other stuff")
    frame.stack.append(resulting_value)


@Opcode(0xb6, size=3, no_ret=True)
def op_invokevirtual(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """Invoke an instance method; push its result unless the descriptor returns void."""
    index = frame.get_short(frame.ip + 1) & 0xFFFF
    class_name, method_name, descriptor = frame.clazz.cp_get_methodref(index)
    LOGGER.info("Calling %s -> %s %s", class_name, method_name, descriptor)
    args_types = utils.get_argument_count_types_descriptor(descriptor)
    # todo type check
    # Arguments were pushed left-to-right, so they come off the stack reversed;
    # restore declaration order before handing them to the callee.
    args = [frame.stack.pop() for _ in args_types]
    args.reverse()
    object_ref = frame.stack.pop()
    result = vm.execute_instance_method(class_name, method_name, args, object_ref)
    if descriptor[-1] != 'V':
        frame.stack.append(result)


@Opcode(0xb1, size=1)
def op_return(vm: "PJVirtualMachine", opcode: int, frame: Frame) -> Tuple[int, Optional[JObj]]:
    """Return from a void method: result code 1 with no value."""
    return 1, None


@Opcode(0x92, size=1, no_ret=True)
def op_i2c(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """Convert the int on top of the stack to a char (keep the low 16 bits)."""
    top = frame.stack.pop()
    if not top.instance_of('I'):
        # NOTE(review): other handlers raise PJVMTypeError; kept ValueError so
        # callers catching it are unaffected.
        raise ValueError("Type mismatch!")
    frame.stack.append(JInteger(top.value & 0xFFFF))


@Opcode([0x3b, 0x3c, 0x3d, 0x3e], size=1, no_ret=True)
def op_istore_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """istore_<n>: pop an int and store it in local variable n (0-3)."""
    idx = opcode - 0x3b
    item = frame.stack.pop()
    if not item.instance_of('I'):
        raise PJVMTypeError("Must be integer")
    frame.set_local(item, idx)


@Opcode([0x1a, 0x1b, 0x1c, 0x1d], size=1, no_ret=True)
def op_iload_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """iload_<n>: push the int held in local variable n (0-3)."""
    idx = opcode - 0x1a
    item = frame.get_local(idx)
    if not item.instance_of('I'):
        raise PJVMTypeError("Must be integer")
    frame.stack.append(item)


@Opcode(0x10, size=2, no_ret=True)
def op_bipush(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """bipush: push the immediate byte, sign-extended to an int."""
    val = _to_signed(frame.get_byte(frame.ip + 1), 8)
    frame.stack.append(JInteger(val))


@Opcode(list(_ICMP_COMPARISONS), no_ret=True)
def op_icmp_cond(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """if_icmp<cond>: pop value2 then value1 and branch when ``value1 <cond> value2``."""
    # value2 was pushed last, so it is popped first.
    value2 = frame.stack.pop()
    value1 = frame.stack.pop()
    if not value2.instance_of("I") or not value1.instance_of("I"):
        raise PJVMTypeError("Must be of type integer")
    if _ICMP_COMPARISONS[opcode](value1.value, value2.value):
        # Branch offsets are signed and relative to this instruction's address.
        frame.ip += _to_signed(frame.get_short(frame.ip + 1), 16)
    else:
        frame.ip += 3


@Opcode(0xa7, no_ret=True)
def op_goto(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """goto: jump by the signed 16-bit offset relative to this instruction."""
    frame.ip += _to_signed(frame.get_short(frame.ip + 1), 16)


class PJVirtualMachine:
    """Loads classes from a classpath and interprets their bytecode."""

    instructions: Dict[int, OPCODE_TYPE]
    # Per-instance state, created in __init__; class-level mutable defaults
    # would be shared between every VM instance.
    stack: List[Frame]
    classes: Dict[str, Class]

    def __init__(self, classpath: list):
        """
        :param classpath: class files to load; each entry is passed to ``load_class``
        """
        self.classpath = classpath
        self.stack = []
        self.classes = {}
        self.instructions = Opcode.get_opcodes(self)
        self.load_classes()

    def load_classes(self):
        """Parse every classpath entry and index the result by its ``this_class`` name."""
        LOGGER.info("Loading %s classes", len(self.classpath))
        for class_file in self.classpath:
            LOGGER.info("Loading %s", class_file)
            loader = load_class(class_file)
            clazz = Class(loader)
            self.classes[clazz.this_class] = clazz
        LOGGER.info("Done loading")

    @staticmethod
    def not_implemented_instruction(opcode: int, frame: Frame):
        """Log the unsupported opcode and raise ``PJVMUnknownOpcode``."""
        try:
            name = INSTRUCTIONS[opcode]['name']
        except (KeyError, IndexError):
            # Unknown to the name table as well; still raise the VM's own error.
            name = '<unknown>'
        LOGGER.error("OPCODE %s -> %s is not yet implemented!", opcode, name)
        raise PJVMUnknownOpcode(opcode)

    def run(self, main_class: str, args: Optional[List[str]] = None, entry_point: str = 'main'):
        """Run ``entry_point`` (the JVM ``main`` method by default) of ``main_class``.

        :param main_class: name of an already loaded class
        :param args: command-line strings, wrapped into a ``String[]``
        :param entry_point: static method to invoke
        :raises PJVMException: if ``main_class`` was not loaded
        """
        if args is None:
            args = []
        j_args: JArray = JArray(['java/lang/String'], list(map(JString, args)))
        if main_class not in self.classes:
            raise PJVMException("main_class not found!")
        self.execute_static_method(main_class, entry_point, [j_args])

    def execute_static_method(self, class_name: str, method_name: str,
                              args: List[JObj]) -> Optional[JObj]:
        """Execute a static method of a loaded class and return its result."""
        clazz = self.classes[class_name]
        method = clazz.methods[method_name]
        return self.execute(clazz, method, args, None)

    def execute_instance_method(self, class_name: str, method_name: str, args: List[JObj],
                                this: JObj) -> Optional[JObj]:
        """Dispatch an instance call; only mocked methods are supported so far."""
        mock_class = mock_instance_methods.get(class_name)
        if mock_class:
            mock_method = mock_class.get(method_name)
            if mock_method:
                return mock_method(this, args)
        raise PJVMNotImplemented(f"Instance methods are not yet implemented {class_name}->{method_name}")

    def execute(self, clazz: Class, method: Method, args: List[JObj],
                this: Optional[JObj]) -> Optional[JObj]:
        """Interpret ``method`` in a fresh frame until it returns, and return its value.

        NOTE(review): ``args`` are pushed onto the operand stack, whereas the JVM
        passes arguments in local variables -- confirm against Frame's design.
        """
        frame = Frame(clazz, method)
        frame.stack.extend(args)
        frame.this = this
        return_value = None
        self.stack.append(frame)
        try:
            while True:
                instruction = frame.code.code[frame.ip]
                impl = self.instructions.get(instruction)
                if impl is None:
                    self.not_implemented_instruction(instruction, frame)  # always raises
                result, meta = impl(instruction, frame)
                if result == 1:
                    return_value = meta
                    break
                if result == 2:
                    raise PJVMNotImplemented("Exceptions are not implemented")
                # any other result (normally 0): continue with the next instruction
        finally:
            # Pop the frame even when an instruction raised, so the VM call stack
            # does not keep dead frames around.
            self.stack.pop()

        # Hand a non-void result to the calling frame, if there is one.
        if frame.method.return_value != 'V' and self.stack:
            self.stack[-1].stack.append(return_value)
        return return_value

    def get_static_field(self, class_name: str, field_name: str):
        """Return a static field value; only mocked static objects are supported."""
        mock_type = mock_static_objects.get(class_name, None)
        if mock_type:
            mock_field = mock_type.get(field_name)
            if mock_field:
                return mock_field
        raise PJVMNotImplemented(f"Static fields are not implemented yet, only as mock {class_name}->{field_name}")