import logging
import os
from typing import BinaryIO, IO, Union

from .unpacker import Unpacker

LOGGER = logging.getLogger(__name__)

# Named struct format codes consumed by ``Unpacker`` (byte order is left to
# ``Unpacker``). The class file's u1/u2/u4 items are *unsigned*, so unsigned
# codes are used: signed ones turn e.g. ACC_MODULE (0x8000), or any count or
# length >= 32768, into negative numbers.
formats = {
    'magic': '4s',
    'major_version': 'H',
    'minor_version': 'H',
    'constant_pool_count': 'H',
    'access_flags': 'H',
    'this_class': 'H',
    'super_class': 'H',
    'interfaces_count': 'H',
    'fields_count': 'H',
    'methods_count': 'H',
    'attributes_count': 'H',
    'byte': 'B',
    'short': 'H',
}


def create_constant_type(name, fmt, *fieldnames):
    """Describe how to unpack the payload of one ``cp_info`` structure.

    Args:
        name: Human-readable constant type name (e.g. ``'Class'``).
        fmt: Struct format with exactly one single-character code per field.
        *fieldnames: Names given, in order, to the unpacked values.

    Returns:
        A dict with ``'name'``, ``'format'`` and ``'fieldnames'`` keys.

    Raises:
        ValueError: If ``fmt`` and ``fieldnames`` differ in length.
    """
    if len(fmt) != len(fieldnames):
        raise ValueError("Length mismatch format and names")
    return {
        'name': name,
        'format': fmt,
        'fieldnames': fieldnames
    }


# https://docs.oracle.com/javase/specs/jvms/se17/html/jvms-4.html#jvms-4.4
CONSTANT_POOL_TYPES = {
    7: create_constant_type('Class', 'H', 'name_index'),
    9: create_constant_type('Fieldref', 'HH', 'class_index', 'name_and_type_index'),
    10: create_constant_type('Methodref', 'HH', 'class_index', 'name_and_type_index'),
    11: create_constant_type('InterfaceMethodref', 'HH', 'class_index', 'name_and_type_index'),
    8: create_constant_type('String', 'H', 'string_index'),
    3: create_constant_type('Integer', 'i', 'value'),  # divergence, actual value instead of raw bytes
    4: create_constant_type('Float', 'f', 'value'),  # divergence, actual value instead of raw bytes
    5: create_constant_type('Long', 'q', 'value'),  # divergence, actual value instead of raw bytes
    6: create_constant_type('Double', 'd', 'value'),  # divergence, actual value instead of raw bytes
    12: create_constant_type('NameAndType', 'HH', 'name_index', 'descriptor_index'),
    # Utf8: a u2 length followed by that many modified-UTF-8 bytes; it is
    # variable-length, so ClassLoader parses it specially.
    1: None,
    15: create_constant_type('MethodHandle', 'BH', 'reference_kind', 'reference_index'),
    16: create_constant_type('MethodType', 'H', 'descriptor_index'),
    17: create_constant_type('Dynamic', 'HH', 'bootstrap_method_attr_index', 'name_and_type_index'),
    18: create_constant_type('InvokeDynamic', 'HH', 'bootstrap_method_attr_index', 'name_and_type_index'),
    19: create_constant_type('Module', 'H', 'name_index'),
    20: create_constant_type('Package', 'H', 'name_index'),
}

# Long (5) and Double (6) constants take two constant pool slots (JVMS 4.4.5).
TWO_SLOT_CONSTANT_TAGS = frozenset({5, 6})

JAVA_CLASS_MAGIC = b'\xca\xfe\xba\xbe'


def _decode_modified_utf8(raw: bytes) -> str:
    """Decode the "modified UTF-8" bytes of a ``CONSTANT_Utf8_info``.

    Modified UTF-8 (JVMS 4.4.7) differs from standard UTF-8 in two ways:
    U+0000 is encoded as ``C0 80``, and supplementary characters are stored
    as a surrogate pair with three bytes per surrogate. Plain
    ``bytes.decode()`` rejects both forms.
    """
    text = raw.replace(b'\xc0\x80', b'\x00').decode('utf-8', 'surrogatepass')
    # Round-trip through UTF-16 to fuse surrogate pairs into real code points;
    # unpaired surrogates are passed through unchanged.
    return text.encode('utf-16-le', 'surrogatepass').decode('utf-16-le', 'surrogatepass')


class ClassLoader:
    """Parse the top-level ``ClassFile`` structure of a Java class file.

    Call :meth:`load` to populate the attributes below. ``constant_pool[i]``
    describes constant pool index ``i + 1``. The unusable slot after each
    Long/Double constant is stored as ``None``, so that mapping holds for
    every index. Field, method and attribute ``info`` payloads are kept as
    raw bytes.
    """
    stream: BinaryIO
    unpacker: Unpacker
    magic: bytes
    major_version: int
    minor_version: int
    constant_pool_count: int
    constant_pool: list
    access_flags: int
    this_class: int
    super_class: int
    interfaces_count: int
    interfaces: list
    fields_count: int
    fields: list
    methods_count: int
    methods: list
    attributes_count: int
    attributes: list

    def __init__(self, stream: BinaryIO):
        self.stream = stream
        self.unpacker = Unpacker(formats, self.stream)

    def load(self):
        """Read the whole class file from the stream given to ``__init__``.

        After a successful parse, the references to the stream and unpacker
        are dropped. The stream itself is not closed.

        Raises:
            ValueError: If the magic number is wrong or a constant pool tag
                is unknown.
        """
        self.magic, = self.unpacker.magic
        if self.magic != JAVA_CLASS_MAGIC:
            raise ValueError(f"Not a valid java class file! Magic: {self.magic!r}")
        self.minor_version, = self.unpacker.minor_version
        self.major_version, = self.unpacker.major_version
        LOGGER.info("Parsing class with version %s.%s", self.major_version, self.minor_version)

        self.constant_pool_count, = self.unpacker.constant_pool_count
        LOGGER.info("Parsing %s constant pool items", self.constant_pool_count)
        self.constant_pool = self._parse_constant_pool()

        self.access_flags, = self.unpacker.access_flags  # TODO: enum.IntFlag
        self.this_class, = self.unpacker.this_class
        self.super_class, = self.unpacker.super_class

        self.interfaces_count, = self.unpacker.interfaces_count
        self.interfaces = [self.unpacker.short[0] for _ in range(self.interfaces_count)]

        self.fields_count, = self.unpacker.fields_count
        self.fields = [self._parse_member() for _ in range(self.fields_count)]

        self.methods_count, = self.unpacker.methods_count
        self.methods = [self._parse_member() for _ in range(self.methods_count)]

        self.attributes_count, = self.unpacker.attributes_count
        self.attributes = [self._parse_attribute() for _ in range(self.attributes_count)]

        # cleanup: the stream and unpacker are only needed while parsing
        del self.stream
        del self.unpacker

    def _parse_constant_pool(self):
        """Read the constant pool entries for indices 1 .. constant_pool_count - 1."""
        pool = []
        index = 1
        while index < self.constant_pool_count:
            entry = self._parse_constant()
            pool.append(entry)
            index += 1
            if entry['tag'] in TWO_SLOT_CONSTANT_TAGS:
                # The slot after a Long/Double is valid but unusable; keep a
                # placeholder so list positions still match pool indices.
                pool.append(None)
                index += 1
        return pool

    def _parse_constant(self):
        """Read one ``cp_info`` structure and return it as a dict."""
        tag, = self.unpacker.byte
        if tag == 1:
            length, = self.unpacker.short
            raw, = self.unpacker[f'{length}s']
            return {
                'tag': tag,
                'type': 'Utf8',
                'value': _decode_modified_utf8(raw)
            }
        constant_info = CONSTANT_POOL_TYPES.get(tag)
        if constant_info is None:
            raise ValueError(f"Unknown constant pool tag: {tag}")
        data = {
            'tag': tag,
            'type': constant_info['name'],
            **dict(zip(constant_info['fieldnames'], self.unpacker[constant_info['format']]))
        }
        if tag == 18:
            # Legacy misspelled key, kept so existing consumers keep working.
            data['name_ant_type_index'] = data['name_and_type_index']
        return data

    def _parse_member(self):
        """Read one ``field_info`` or ``method_info`` (both share one layout)."""
        access_flags, name_index, descriptor_index, attributes_count = \
            self.unpacker['HHHH']
        return {
            'access_flags': access_flags,
            'name_index': name_index,
            'descriptor_index': descriptor_index,
            'attributes_count': attributes_count,
            'attributes': [self._parse_attribute() for _ in range(attributes_count)]
        }

    def _parse_attribute(self):
        """Read one ``attribute_info``; its ``info`` payload is left as raw bytes."""
        attribute_name_index, attribute_length = self.unpacker['HI']
        data, = self.unpacker[f'{attribute_length}s']
        return {'attribute_name_index': attribute_name_index,
                'attribute_length': attribute_length,
                'info': data}


def load_class(name_or_stream: Union[str, os.PathLike, BinaryIO]) -> ClassLoader:
    """
    Open a class file and parse it using the ClassLoader

    Args:
        name_or_stream: A file name / path, or a binary stream. A file opened
            here is closed before returning; a caller-supplied stream is left
            open.

    Returns:
        A ClassLoader instance with a loaded class
    """
    if isinstance(name_or_stream, (str, os.PathLike)):
        # Parsing is eager, so the file can be closed as soon as load() returns.
        with open(name_or_stream, 'rb') as stream:
            loader = ClassLoader(stream)
            loader.load()
        return loader
    loader = ClassLoader(name_or_stream)
    loader.load()
    return loader