"""Minimal Java bytecode interpreter.

Defines the Python-side representations of Java values (``JObj`` and its
subclasses), the per-invocation ``Frame``, the ``Opcode`` registry together
with the implemented instruction handlers, and ``PJVirtualMachine``, which
loads class files and interprets their bytecode. A few library members
(``System.out``/``System.in``, ``PrintStream.println``, ``InputStream.read``)
are mocked in Python.
"""
import operator
import struct
import sys
from logging import getLogger
from typing import Callable, Dict, List, Optional, Tuple, Union

from pjvm.classloader import load_class
from pjvm.clazz import Class, Code, Method

from . import utils
from .opcodes import INSTRUCTIONS
# Provides PJVMException, PJVMTypeError, PJVMNotImplemented and
# PJVMUnknownOpcode, which are used below.
from .exceptions import *

# A registered handler: (opcode, frame) -> (status, value).
OPCODE_TYPE = Callable[[int, "Frame"], Tuple[int, "JObj"]]
# An undecorated handler, which additionally receives the VM.
OPCODE_BOUND_TYPE = Callable[["PJVirtualMachine", int, "Frame"], Tuple[int, "JObj"]]

# Status codes returned by handlers and interpreted by PJVirtualMachine.execute.
_CONTINUE = 0  # keep executing the current frame
_RETURN = 1    # the method returned; the second tuple item is the return value
_THROW = 2     # an exception was signalled (not supported yet)

LOGGER = getLogger(__name__)


class JObj:
    """Base class for every value placed on an operand stack or in a local."""

    type = 'java/lang/Object'
    value = None

    def instance_of(self, java_type: str) -> bool:
        """Return True if this value's type string is exactly *java_type*."""
        return self.type == java_type  # TODO: model inheritance


class JInteger(JObj):
    """A Java ``int`` (descriptor ``I``)."""

    type = "I"

    def __init__(self, val: int):
        self.value = val


class JCharacter(JObj):
    """A Java ``char`` (descriptor ``C``), stored as its integer code."""

    type = 'C'

    def __init__(self, val: int):
        self.value = val


class JGenericObj(JObj):
    """Base class for objects parameterised by type arguments."""

    generic_types: List[str]


class JString(JObj):
    """A ``java/lang/String`` backed by a Python ``str``."""

    type = 'java/lang/String'
    value: str

    def __init__(self, value: str):
        self.value = value


class JArray(JGenericObj):
    """A Java array; ``generic_types`` holds the element type(s)."""

    type = 'java/lang/Array'
    generic_types: List[str]
    values: List

    def __init__(self, generic_types: List[str], values: List):
        self.generic_types = generic_types
        self.values = values


class JMockSystemOut(JObj):
    """Stand-in for ``System.out``."""

    type = "Ljava/io/PrintStream;"
    value = None


class JMockSystemIn(JObj):
    """Stand-in for ``System.in``."""

    type = "Ljava/io/InputStream;"
    value = None


def mock_printstream_println(self: JObj, args: List[JObj]):
    """Mock of ``PrintStream.println(String)``: print the string to stdout.

    Raises:
        ValueError: unless exactly one ``java/lang/String`` argument is given.
    """
    if len(args) != 1:
        raise ValueError("Argument length mismatch")
    if args[0].type != 'java/lang/String':
        raise ValueError("Argument type mismatch")
    print(args[0].value)


def mock_inputstream_read(self: JObj, args: List[JObj]) -> JInteger:
    """Mock of ``InputStream.read()``: read a single character from stdin.

    Returns the character's code as a ``JInteger``, or ``-1`` at end of input
    as in Java. Whitespace such as the newline is returned like any other
    character rather than being stripped (stripping produced ``''``, on which
    ``ord`` fails).

    NOTE(review): stdin is read in text mode, so non-ASCII input yields a
    Unicode code point rather than a single byte value.
    """
    character = sys.stdin.read(1)
    if not character:
        return JInteger(-1)
    return JInteger(ord(character))


class Frame:
    """Execution state of one method invocation.

    Holds the instruction pointer, the operand stack and the local-variable
    slots for the method's ``Code`` attribute.
    """

    ip: int = 0
    clazz: Class
    method: Method
    code: Code
    stack: List[JObj]
    this: Optional[JObj] = None
    local_variables: List[Optional[JObj]]

    def __init__(self, clazz: Class, method: Method):
        self.ip = 0
        self.clazz = clazz
        self.method = method
        self.code = self.method.code
        self.stack = []
        self.this = None
        self.local_variables = [None] * self.code.max_locals

    def get_byte(self, idx: int) -> int:
        """Read a signed 8-bit big-endian operand at bytecode offset *idx*."""
        return struct.unpack_from('>b', self.code.code, idx)[0]

    def get_ubyte(self, idx: int) -> int:
        """Read an unsigned 8-bit operand (e.g. an ``ldc`` pool index)."""
        return struct.unpack_from('>B', self.code.code, idx)[0]

    def get_short(self, idx: int) -> int:
        """Read a signed 16-bit big-endian operand (e.g. a branch offset)."""
        return struct.unpack_from('>h', self.code.code, idx)[0]

    def get_ushort(self, idx: int) -> int:
        """Read an unsigned 16-bit big-endian operand (a constant-pool index)."""
        return struct.unpack_from('>H', self.code.code, idx)[0]

    def get_int(self, idx: int) -> int:
        """Read a signed 32-bit big-endian operand at bytecode offset *idx*."""
        return struct.unpack_from('>i', self.code.code, idx)[0]

    def _check_local_index(self, index: int):
        """Raise PJVMException unless 0 <= index < max_locals."""
        # Negative indices would otherwise silently wrap around the list.
        if not 0 <= index < self.code.max_locals:
            raise PJVMException("Local out of bound!")

    def set_local(self, val: JObj, index: int):
        """Store *val* into local-variable slot *index*.

        Raises:
            PJVMException: if *index* is outside the frame's locals.
        """
        self._check_local_index(index)
        self.local_variables[index] = val

    def get_local(self, index: int) -> JObj:
        """Return the value of local-variable slot *index*.

        Raises:
            PJVMException: if *index* is out of range or the slot is unset.
        """
        self._check_local_index(index)
        val = self.local_variables[index]
        if val is None:
            raise PJVMException("Local not set")
        return val


class Opcode:
    """Decorator that registers an instruction handler for one or more opcodes.

    ``size`` is added to ``frame.ip`` after the handler runs (0 for branch
    instructions that move ``ip`` themselves). With ``no_ret=True`` a handler
    returning ``None`` is treated as ``(_CONTINUE, None)``.
    """

    # Shared registry: opcode -> bound handler.
    opcodes: Dict[int, OPCODE_TYPE] = {}
    # Set by get_opcodes(); every handler is invoked with this VM.
    vm: "PJVirtualMachine"
    func: OPCODE_BOUND_TYPE
    owner: object

    @classmethod
    def register(cls, opcode: Union[int, List[int]], method: OPCODE_TYPE):
        """Register *method* for a single opcode or for each opcode in a list."""
        if isinstance(opcode, int):
            cls.opcodes[opcode] = method
        else:
            for oc in opcode:
                cls.opcodes[oc] = method

    @classmethod
    def get_opcodes(cls, vm: "PJVirtualMachine") -> Dict[int, OPCODE_TYPE]:
        """Bind all handlers to *vm* and return the opcode registry."""
        cls.vm = vm
        return cls.opcodes

    def __init__(self, opcode: Union[int, List[int]], size: int = 0, no_ret: bool = False):
        self.opcode = opcode
        self.size = size
        self.no_ret = no_ret

    def wrapped(self, opcode: int, frame: Frame):
        """Run the handler, advance ``ip`` by ``size`` and normalise the result."""
        retval = self.func(self.vm, opcode, frame)
        frame.ip += self.size
        if self.no_ret and retval is None:
            retval = _CONTINUE, None
        return retval

    def __call__(self, func: OPCODE_BOUND_TYPE) -> OPCODE_BOUND_TYPE:
        self.func = func
        self.register(self.opcode, self.wrapped)
        # Return the function so the decorated module-level name stays usable.
        return func


@Opcode(0xb2, size=3, no_ret=True)
def op_getstatic(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """getstatic: push a static field value (only mocked fields are supported)."""
    # The constant-pool index is an unsigned 16-bit operand.
    index = frame.get_ushort(frame.ip + 1)
    owner, field_name, field_descriptor = frame.clazz.cp_get_fieldref(index)
    LOGGER.info("Fetching %s -> %s of type %s", owner, field_name, field_descriptor)
    field = vm.get_static_field(owner, field_name)
    if not field.instance_of(field_descriptor):
        raise PJVMTypeError(f"Type mismatch {field.type} != {field_descriptor}")
    frame.stack.append(field)


@Opcode(0x12, size=2, no_ret=True)
def op_ldc(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """ldc: push a constant from the constant pool (only String is supported)."""
    # The constant-pool index is an unsigned 8-bit operand.
    index = frame.get_ubyte(frame.ip + 1)
    value = frame.clazz.cp_get(index)
    resulting_value: JObj
    if value['type'] == 'String':
        resulting_value = JString(frame.clazz.cp_get_utf8(value['string_index']))
    else:
        raise PJVMNotImplemented("did not feel like implementing the other stuff")
    frame.stack.append(resulting_value)


@Opcode(0xb6, size=3, no_ret=True)
def op_invokevirtual(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """invokevirtual: pop the receiver and arguments, call, push non-void results."""
    index = frame.get_ushort(frame.ip + 1)
    class_name, method_name, descriptor = frame.clazz.cp_get_methodref(index)
    LOGGER.info("Calling %s -> %s %s", class_name, method_name, descriptor)
    args_types = utils.get_argument_count_types_descriptor(descriptor)
    # TODO: type-check the arguments against args_types.
    # The last argument is on top of the stack; reverse to restore call order.
    args = [frame.stack.pop() for _ in args_types]
    args.reverse()
    object_ref = frame.stack.pop()
    result = vm.execute_instance_method(class_name, method_name, args, object_ref)
    if descriptor[-1] != 'V':
        frame.stack.append(result)


@Opcode(0xb1, size=1)
def op_return(vm: "PJVirtualMachine", opcode: int, frame: Frame) -> Tuple[int, Optional[JObj]]:
    """return: finish a void method."""
    return _RETURN, None


@Opcode(0x92, size=1, no_ret=True)
def op_i2c(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """i2c: truncate the int on top of the stack to an unsigned 16-bit char."""
    top = frame.stack.pop()
    if not top.instance_of('I'):
        raise ValueError("Type mismatch!")
    # On the operand stack a char is still an int, so push a JInteger.
    frame.stack.append(JInteger(top.value & 0xFFFF))


@Opcode([0x3b, 0x3c, 0x3d, 0x3e], size=1, no_ret=True)
def op_istore_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """istore_<n>: pop an int into local n (n = opcode - istore_0)."""
    idx = opcode - 0x3b
    item = frame.stack.pop()
    if not item.instance_of('I'):
        raise PJVMTypeError("Must be integer")
    frame.set_local(item, idx)


@Opcode([0x1a, 0x1b, 0x1c, 0x1d], size=1, no_ret=True)
def op_iload_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """iload_<n>: push the int held in local n (n = opcode - iload_0)."""
    idx = opcode - 0x1a
    item = frame.get_local(idx)
    if not item.instance_of('I'):
        raise PJVMTypeError("Must be integer")
    frame.stack.append(item)


@Opcode(0x10, size=2, no_ret=True)
def op_bipush(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """bipush: push a sign-extended byte operand as an int."""
    val = frame.get_byte(frame.ip + 1)
    frame.stack.append(JInteger(val))


# if_icmp<cond> opcode -> predicate(value1, value2).
_ICMP_CONDITIONS: Dict[int, Callable[[int, int], bool]] = {
    0x9f: operator.eq,  # if_icmpeq
    0xa0: operator.ne,  # if_icmpne
    0xa1: operator.lt,  # if_icmplt
    0xa2: operator.ge,  # if_icmpge
    0xa3: operator.gt,  # if_icmpgt
    0xa4: operator.le,  # if_icmple
}


@Opcode([0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4], no_ret=True)
def op_icmp_cond(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """if_icmp<cond>: compare two ints and branch by a signed 16-bit offset."""
    # value2 is on top of the stack; the test is value1 <cond> value2.
    right = frame.stack.pop()
    left = frame.stack.pop()
    if not left.instance_of("I") or not right.instance_of("I"):
        raise PJVMTypeError("Must be of type integer")
    if _ICMP_CONDITIONS[opcode](left.value, right.value):
        # The offset is relative to the address of this instruction.
        frame.ip += frame.get_short(frame.ip + 1)
    else:
        frame.ip += 3


@Opcode(0xa7, no_ret=True)
def op_goto(vm: "PJVirtualMachine", opcode: int, frame: Frame):
    """goto: jump by the signed 16-bit offset relative to this instruction."""
    frame.ip += frame.get_short(frame.ip + 1)


class PJVirtualMachine:
    """Loads class files and interprets their bytecode.

    Static fields and instance methods of library classes are served only
    from the ``mock_static_objects`` / ``mock_instance_methods`` tables.
    """

    instructions: Dict[int, OPCODE_TYPE]
    # Per-VM call stack and loaded classes (assigned in __init__, so separate
    # VM instances do not share them).
    stack: List[Frame]
    classes: Dict[str, Class]

    mock_static_objects = {
        'java/lang/System': {
            'out': JMockSystemOut(),
            'in': JMockSystemIn()
        }
    }

    mock_instance_methods = {
        'java/io/PrintStream': {
            'println': mock_printstream_println
        },
        'java/io/InputStream': {
            'read': mock_inputstream_read
        }
    }

    def __init__(self, classpath: List[str]):
        self.classpath = classpath
        self.stack = []
        self.classes = {}
        self.instructions = Opcode.get_opcodes(self)
        self.load_classes()

    def load_classes(self):
        """Load every class file in ``classpath`` and index it by class name."""
        LOGGER.info("Loading %d classes", len(self.classpath))
        for class_file in self.classpath:
            LOGGER.info("Loading %s", class_file)
            loader = load_class(class_file)
            clazz = Class(loader)
            self.classes[clazz.this_class] = clazz
        LOGGER.info("Done loading")

    @staticmethod
    def not_implemented_instruction(opcode: int, frame: Frame):
        """Log and raise for an opcode that has no registered handler.

        Raises:
            PJVMUnknownOpcode: always.
        """
        try:
            name = INSTRUCTIONS[opcode]['name']
        except (KeyError, IndexError):
            name = '<unknown>'
        LOGGER.error("OPCODE %s -> %s is not yet implemented!", opcode, name)
        raise PJVMUnknownOpcode(opcode)

    def run(self, main_class: str, args: Optional[List[str]] = None):
        """Run the entry point of *main_class* with *args* as a ``String[]``.

        Raises:
            PJVMException: if *main_class* was not loaded.
        """
        if args is None:
            args = []
        if main_class not in self.classes:
            raise PJVMException("main_class not found!")
        j_args = JArray(['java/lang/String'], [JString(arg) for arg in args])
        # NOTE(review): the JVM entry point is normally named ``main``; confirm
        # how ``Class.methods`` keys its methods before relying on 'Main'.
        self.execute_static_method(main_class, 'Main', [j_args])

    def execute_static_method(self, class_name: str, method_name: str,
                              args: List[JObj]) -> Optional[JObj]:
        """Look up a loaded class's method by name and execute it without ``this``."""
        clazz = self.classes[class_name]
        method = clazz.methods[method_name]
        return self.execute(clazz, method, args, None)

    def execute_instance_method(self, class_name: str, method_name: str,
                                args: List[JObj], this: JObj) -> Optional[JObj]:
        """Dispatch an instance call; only the mocked library methods are supported.

        Raises:
            PJVMNotImplemented: for any method without a mock.
        """
        mock_class = self.mock_instance_methods.get(class_name)
        if mock_class:
            mock_method = mock_class.get(method_name)
            if mock_method:
                return mock_method(this, args)
        raise PJVMNotImplemented(f"Instance methods are not yet implemented {class_name}->{method_name}")

    def execute(self, clazz: Class, method: Method, args: List[JObj],
                this: Optional[JObj]) -> Optional[JObj]:
        """Interpret *method* of *clazz* until it returns.

        ``this`` (for instance calls) and then *args* are stored in consecutive
        local variables starting at slot 0, as the JVM specifies. The frame is
        on ``self.stack`` only while it runs and is popped even if an
        instruction raises. For non-void methods, the return value is also
        pushed onto the calling frame's operand stack when there is one.

        Returns:
            The value produced by the returning instruction (``None`` for void).

        Raises:
            PJVMUnknownOpcode: on an instruction without a handler.
            PJVMNotImplemented: when an instruction signals an exception.
            PJVMException: if the arguments do not fit the method's locals or
                the instruction pointer leaves the bytecode.
        """
        frame = Frame(clazz, method)
        frame.this = this
        receiver = [] if this is None else [this]
        # NOTE(review): long/double arguments occupy two slots in the JVM; not modelled.
        for slot, value in enumerate(receiver + list(args)):
            frame.set_local(value, slot)

        self.stack.append(frame)
        try:
            return_value = self._interpret(frame)
        finally:
            self.stack.pop()

        if frame.method.return_value != 'V' and self.stack:
            # Hand the result to the caller when this was not the outermost call.
            self.stack[-1].stack.append(return_value)
        return return_value

    def _interpret(self, frame: Frame) -> Optional[JObj]:
        """Fetch/dispatch loop for *frame*; return the returning instruction's value."""
        code = frame.code.code
        while True:
            if not 0 <= frame.ip < len(code):
                raise PJVMException("Instruction pointer out of bounds!")
            instruction = code[frame.ip]
            impl = self.instructions.get(instruction)
            if impl is None:
                self.not_implemented_instruction(instruction, frame)  # always raises
            status, meta = impl(instruction, frame)
            if status == _RETURN:
                return meta
            if status == _THROW:
                raise PJVMNotImplemented("Exceptions are not implemented")
            # Any other status (normally _CONTINUE): keep executing.

    def get_static_field(self, class_name: str, field_name: str) -> JObj:
        """Return a mocked static field value.

        Raises:
            PJVMNotImplemented: if the field is not in ``mock_static_objects``.
        """
        mock_type = self.mock_static_objects.get(class_name)
        if mock_type:
            mock_field = mock_type.get(field_name)
            if mock_field:
                return mock_field
        raise PJVMNotImplemented(f"Static fields are not implemented yet, only as mock {class_name}->{field_name}")