Initial commit
This commit is contained in:
391
pjvm/vm.py
Normal file
391
pjvm/vm.py
Normal file
@@ -0,0 +1,391 @@
|
||||
import struct
|
||||
import sys
|
||||
from logging import getLogger
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pjvm.classloader import load_class
|
||||
from pjvm.clazz import Class, Code, Method
|
||||
from . import utils
|
||||
from .opcodes import INSTRUCTIONS
|
||||
from .exceptions import *
|
||||
|
||||
OPCODE_TYPE = Callable[[int, "Frame"], Tuple[int, "JObj"]]
|
||||
OPCODE_BOUND_TYPE = Callable[["PJVirtualMachine", int, "Frame"], Tuple[int, "JObj"]]
|
||||
|
||||
LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class JObj:
|
||||
type = 'java/lang/Object'
|
||||
value = None
|
||||
|
||||
def instance_of(self, java_type: str):
|
||||
return self.type == java_type # todo inheritance
|
||||
|
||||
|
||||
class JInteger(JObj):
|
||||
type = "I"
|
||||
|
||||
def __init__(self, val: int):
|
||||
self.value = val
|
||||
|
||||
|
||||
class JCharacter(JObj):
|
||||
type = 'C'
|
||||
|
||||
def __init__(self, val: int):
|
||||
self.value = val
|
||||
|
||||
|
||||
class JGenericObj(JObj):
|
||||
generic_types: List[str]
|
||||
pass
|
||||
|
||||
|
||||
class JString(JObj):
|
||||
type = 'java/lang/String'
|
||||
value: str
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
|
||||
class JArray(JGenericObj):
|
||||
type = 'java/lang/Array'
|
||||
generic_types: List[str]
|
||||
values: List
|
||||
|
||||
def __init__(self, generic_types: List[str], values: List):
|
||||
self.generic_types = generic_types
|
||||
self.values = values
|
||||
|
||||
|
||||
class JMockSystemOut(JObj):
|
||||
type = "Ljava/io/PrintStream;"
|
||||
value = None
|
||||
|
||||
|
||||
class JMockSystemIn(JObj):
|
||||
type = "Ljava/io/InputStream;"
|
||||
value = None
|
||||
|
||||
|
||||
def mock_printstream_println(self: JObj, args: List[JObj]):
|
||||
if len(args) != 1:
|
||||
raise ValueError("Argument length mismatch")
|
||||
if args[0].type != 'java/lang/String':
|
||||
raise ValueError("Argument type mismatch")
|
||||
|
||||
print(args[0].value)
|
||||
|
||||
|
||||
def mock_inputstream_read(self: JObj, args: List[JObj]):
|
||||
character = sys.stdin.read(1).strip()
|
||||
return JInteger(ord(character))
|
||||
|
||||
|
||||
class Frame:
|
||||
ip: int = 0
|
||||
clazz: Class
|
||||
method: Method
|
||||
code: Code
|
||||
stack: List[JObj]
|
||||
this: JObj
|
||||
local_variables: List
|
||||
|
||||
def __init__(self, clazz: Class, method: Method):
|
||||
self.clazz = clazz
|
||||
self.method = method
|
||||
self.code = self.method.code
|
||||
self.stack = []
|
||||
self.local_variables = [None for _ in range(self.code.max_locals)]
|
||||
|
||||
def get_byte(self, idx: int) -> int:
|
||||
return struct.unpack('>b', self.code.code[idx:idx + 1])[0]
|
||||
|
||||
def get_short(self, idx: int) -> int:
|
||||
return struct.unpack('>h', self.code.code[idx:idx + 2])[0]
|
||||
|
||||
def get_int(self, idx: int) -> int:
|
||||
return struct.unpack('>i', self.code.code[idx:idx + 4])[0]
|
||||
|
||||
def set_local(self, val: JObj, index: int):
|
||||
if index > self.code.max_locals:
|
||||
raise PJVMException("Local out of bound!")
|
||||
self.local_variables[index] = val
|
||||
|
||||
def get_local(self, index: int) ->JObj:
|
||||
if index > self.code.max_locals:
|
||||
raise PJVMException("Local out of bound!")
|
||||
val = self.local_variables[index]
|
||||
if val is None:
|
||||
raise PJVMException("Local not set")
|
||||
return val
|
||||
|
||||
|
||||
class Opcode:
|
||||
opcodes: Dict[int, OPCODE_TYPE] = {}
|
||||
vm: "PJVirtualMachine"
|
||||
func: OPCODE_BOUND_TYPE
|
||||
owner: object
|
||||
|
||||
@classmethod
|
||||
def register(cls, opcode: Union[int, List[int]], method: OPCODE_TYPE):
|
||||
if isinstance(opcode, int):
|
||||
cls.opcodes[opcode] = method
|
||||
else:
|
||||
for oc in opcode:
|
||||
cls.opcodes[oc] = method
|
||||
|
||||
@classmethod
|
||||
def get_opcodes(cls, vm: "PJVirtualMachine") -> Dict[int, OPCODE_TYPE]:
|
||||
cls.vm = vm
|
||||
return cls.opcodes
|
||||
|
||||
def __init__(self, opcode: int, size: int = 0, no_ret: bool = False):
|
||||
self.opcode = opcode
|
||||
self.size = size
|
||||
self.no_ret = no_ret
|
||||
|
||||
def wrapped(self, opcode: int, frame: Frame):
|
||||
retval = self.func(self.vm, opcode, frame)
|
||||
frame.ip += self.size
|
||||
if self.no_ret and retval is None:
|
||||
retval = 0, None
|
||||
return retval
|
||||
|
||||
def __call__(self, func: OPCODE_BOUND_TYPE):
|
||||
self.func = func
|
||||
self.register(self.opcode, self.wrapped)
|
||||
|
||||
|
||||
@Opcode(0xb2, size=3, no_ret=True)
|
||||
def op_getstatic(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
"""
|
||||
Get static value from field
|
||||
"""
|
||||
index = frame.get_short(frame.ip + 1)
|
||||
owner, field_name, field_descriptor = frame.clazz.cp_get_fieldref(index)
|
||||
LOGGER.info(f"Fetching {owner} -> {field_name} of type {field_descriptor}")
|
||||
|
||||
field = vm.get_static_field(owner, field_name)
|
||||
if not field.instance_of(field_descriptor):
|
||||
raise PJVMTypeError(f"Type mismatch {field.type} != {field_descriptor}")
|
||||
|
||||
frame.stack.append(field)
|
||||
|
||||
|
||||
@Opcode(0x12, size=2, no_ret=True)
|
||||
def op_ldc(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
"""
|
||||
Get constant from constant pool
|
||||
"""
|
||||
index = frame.get_byte(frame.ip + 1)
|
||||
|
||||
value = frame.clazz.cp_get(index)
|
||||
resulting_value: JObj
|
||||
|
||||
if value['type'] == 'String':
|
||||
resulting_value = JString(frame.clazz.cp_get_utf8(value['string_index']))
|
||||
else:
|
||||
raise PJVMNotImplemented("did not feel like implementing the other stuff")
|
||||
|
||||
frame.stack.append(resulting_value)
|
||||
|
||||
|
||||
@Opcode(0xb6, size=3, no_ret=True)
|
||||
def op_invokevirtual(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
index = frame.get_short(frame.ip + 1)
|
||||
result: JObj
|
||||
|
||||
class_name, method_name, descriptor = frame.clazz.cp_get_methodref(index)
|
||||
LOGGER.info(f"Calling {class_name} -> {method_name} {descriptor}")
|
||||
args_types = utils.get_argument_count_types_descriptor(descriptor)
|
||||
# todo type check
|
||||
args = []
|
||||
for _ in args_types:
|
||||
args.append(frame.stack.pop())
|
||||
object_ref = frame.stack.pop()
|
||||
|
||||
result = vm.execute_instance_method(class_name, method_name, args, object_ref)
|
||||
|
||||
if descriptor[-1] != 'V':
|
||||
frame.stack.append(result)
|
||||
|
||||
|
||||
@Opcode(0xb1, size=1)
|
||||
def op_return(vm: "PJVirtualMachine", opcode: int, frame: Frame) -> Tuple[int, Optional[JObj]]:
|
||||
return 1, None
|
||||
|
||||
|
||||
@Opcode(0x92, size=1, no_ret=True)
|
||||
def op_i2c(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
top = frame.stack.pop()
|
||||
if not top.instance_of('I'):
|
||||
raise ValueError("Type mismatch!")
|
||||
|
||||
char = top.value & 0xFFFF
|
||||
frame.stack.append(JInteger(char))
|
||||
|
||||
|
||||
@Opcode([0x03b, 0x3c, 0x3d, 0x3e], size=1, no_ret=True)
|
||||
def op_istore_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
idx = (opcode + 1) & 0b11
|
||||
item = frame.stack.pop()
|
||||
if not item.instance_of('I'):
|
||||
raise PJVMTypeError("Must be integer")
|
||||
|
||||
frame.set_local(item, idx)
|
||||
|
||||
|
||||
@Opcode([0x1a, 0x1b, 0x1c, 0x1d], size=1, no_ret=True)
|
||||
def op_iload_n(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
idx = (opcode + 2) & 0b11
|
||||
item = frame.get_local(idx)
|
||||
if not item.instance_of('I'):
|
||||
raise PJVMTypeError("Must be integer")
|
||||
frame.stack.append(item)
|
||||
|
||||
|
||||
@Opcode(0x10, size=2, no_ret=True)
|
||||
def op_bipush(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
val = frame.get_byte(frame.ip+1)
|
||||
frame.stack.append(JInteger(val))
|
||||
|
||||
|
||||
@Opcode([0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4], no_ret=True)
|
||||
def op_icmp_cond(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
opcode_map = {
|
||||
0x9f: int.__eq__,
|
||||
0xa0: int.__ne__,
|
||||
0xa1: int.__lt__,
|
||||
0xa2: int.__ge__,
|
||||
0xa3: int.__gt__,
|
||||
0xa4: int.__le__,
|
||||
}
|
||||
left = frame.stack.pop()
|
||||
right = frame.stack.pop()
|
||||
if not left.instance_of("I") or not right.instance_of("I"):
|
||||
raise PJVMTypeError("Must be of type integer")
|
||||
if opcode_map[opcode](left.value, right.value):
|
||||
target = frame.get_short(frame.ip + 1)
|
||||
frame.ip += target
|
||||
else:
|
||||
frame.ip += 3
|
||||
|
||||
|
||||
@Opcode(0xa7, no_ret=True)
|
||||
def op_goto(vm: "PJVirtualMachine", opcode: int, frame: Frame):
|
||||
target = frame.get_short(frame.ip + 1)
|
||||
frame.ip += target
|
||||
|
||||
|
||||
class PJVirtualMachine:
|
||||
instructions: Dict[int, OPCODE_TYPE]
|
||||
|
||||
stack: List[Frame] = []
|
||||
|
||||
classes: Dict[str, Class] = {}
|
||||
|
||||
mock_static_objects = {
|
||||
'java/lang/System': {
|
||||
'out': JMockSystemOut(),
|
||||
'in': JMockSystemIn()
|
||||
}
|
||||
}
|
||||
|
||||
mock_instance_methods = {
|
||||
'java/io/PrintStream': {
|
||||
'println': mock_printstream_println
|
||||
},
|
||||
'java/io/InputStream': {
|
||||
'read': mock_inputstream_read
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, classpath: []):
|
||||
self.classpath = classpath
|
||||
self.instructions = Opcode.get_opcodes(self)
|
||||
self.load_classes()
|
||||
|
||||
def load_classes(self):
|
||||
LOGGER.info(f"Loading {len(self.classpath)} classes")
|
||||
for class_file in self.classpath:
|
||||
LOGGER.info(f"Loading {class_file}")
|
||||
loader = load_class(class_file)
|
||||
clazz = Class(loader)
|
||||
self.classes[clazz.this_class] = clazz
|
||||
|
||||
LOGGER.info("Done loading")
|
||||
|
||||
@staticmethod
|
||||
def not_implemented_instruction(opcode: int, frame: Frame):
|
||||
LOGGER.error(f"OPCODE {opcode} -> {INSTRUCTIONS[opcode]['name']} is not yet implemented!")
|
||||
raise PJVMUnknownOpcode(opcode)
|
||||
|
||||
def run(self, main_class: str, args: List[str] = None):
|
||||
if args is None:
|
||||
args = []
|
||||
|
||||
j_args: JArray = JArray(['java/lang/String'], list(map(JString, args)))
|
||||
|
||||
if main_class not in self.classes:
|
||||
raise PJVMException("main_class not found!")
|
||||
|
||||
self.execute_static_method(main_class, 'Main', [j_args])
|
||||
|
||||
def execute_static_method(self, class_name: str, method_name: str, args: List[JObj]):
|
||||
clazz = self.classes[class_name]
|
||||
method = clazz.methods[method_name]
|
||||
self.execute(clazz, method, args, None)
|
||||
|
||||
def execute_instance_method(self, class_name: str, method_name: str, args: List[JObj], this: JObj) -> Optional[
|
||||
JObj]:
|
||||
mock_class = self.mock_instance_methods.get(class_name)
|
||||
if mock_class:
|
||||
mock_method = mock_class.get(method_name)
|
||||
if mock_method:
|
||||
return mock_method(this, args)
|
||||
|
||||
raise PJVMNotImplemented(f"Instance methods are not yet implemented {class_name}->{method_name}")
|
||||
|
||||
def execute(self, clazz: Class, method: Method, args: List[JObj], this: Optional[JObj]):
|
||||
frame = Frame(clazz, method)
|
||||
frame.stack.extend(args)
|
||||
frame.this = this
|
||||
return_value = None
|
||||
|
||||
self.stack.append(frame)
|
||||
|
||||
while True:
|
||||
instruction = frame.code.code[frame.ip]
|
||||
try:
|
||||
impl = self.instructions[instruction]
|
||||
except KeyError:
|
||||
self.not_implemented_instruction(instruction, frame)
|
||||
break # unreachable
|
||||
|
||||
result, meta = impl(instruction, frame)
|
||||
if result == 0:
|
||||
continue
|
||||
|
||||
if result == 1:
|
||||
return_value = meta
|
||||
break
|
||||
if result == 2:
|
||||
raise PJVMNotImplemented("Exceptions are not implemented")
|
||||
|
||||
self.stack.pop()
|
||||
|
||||
if frame.method.return_value != 'V':
|
||||
# add the return value to the stack of the caller if it is not a void call
|
||||
self.stack[-1].stack.append(return_value)
|
||||
|
||||
def get_static_field(self, class_name: str, field_name: str):
|
||||
mock_type = self.mock_static_objects.get(class_name, None)
|
||||
if mock_type:
|
||||
mock_field = mock_type.get(field_name)
|
||||
if mock_field:
|
||||
return mock_field
|
||||
|
||||
raise PJVMNotImplemented(f"Static fields are not implemented yet, only as mock {class_name}->{field_name}")
|
||||
Reference in New Issue
Block a user