Skip to content

Commit

Permalink
rewrite propagator on inheritance tree
Browse files Browse the repository at this point in the history
  • Loading branch information
wagyourtail committed Sep 19, 2024
1 parent d545736 commit e20205f
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 324 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Main: CliktCommand(printHelpOnEmptyArgs = true) {
if (prop.isNotEmpty() || cp.isNotEmpty()) {
LOGGER.info { "Propagating..." }
val t = measureTime {
Propagator(Namespace(propagationNs!!), mappings, prop + cp)
Propagator(mappings, Namespace(propagationNs!!), prop + cp)
.propagate(mappings.namespaces.toSet() - Namespace(propagationNs!!))
}
LOGGER.info { "Propagated in ${t.inWholeMilliseconds}ms" }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package xyz.wagyourtail.unimined.mapping.propagator

import xyz.wagyourtail.commonskt.reader.CharReader
import xyz.wagyourtail.unimined.mapping.Namespace
import xyz.wagyourtail.unimined.mapping.jvms.ext.FieldOrMethodDescriptor
import xyz.wagyourtail.unimined.mapping.jvms.four.AccessFlag
import xyz.wagyourtail.unimined.mapping.jvms.four.ElementType
import xyz.wagyourtail.unimined.mapping.jvms.four.two.one.InternalName
import xyz.wagyourtail.unimined.mapping.tree.AbstractMappingTree

class CachedInheritanceTree(tree: AbstractMappingTree, fns: Namespace, data: CharReader<*>): InheritanceTree(tree, fns) {

companion object {

fun write(tree: InheritanceTree, append: (String) -> Unit) {
for (cls in tree.classes.values) {
append(cls.name.toString())
append("\t")
append(cls.superType.toString())
append("\t")
append(cls.interfaces.joinToString("\t") { it.toString() })
append("\n")

for (method in cls.methods) {
append("\t")
append(AccessFlag.of(ElementType.METHOD, method.access).joinToString("|") { it.toString() })
append("\t")
append(method.name)
append("\t")
append(method.descriptor.toString())
append("\n")
}
}
}

}

override val classes by lazy {
val classes = mutableMapOf<InternalName, ClassInfo>()
var ci: ClassInfo? = null
while (!data.exhausted()) {
if (data.peek() == '\n') {
data.take()
continue
}
var col = data.takeNext()!!
var indent = 0
while (col.isEmpty()) {
indent++
col = data.takeNext()!!
}
if (indent > 1) {
throw IllegalArgumentException("expected method, found double indent")
}
if (indent == 0) {
val cls = col
val sup = data.takeNext()!!.ifEmpty { null }
val intf = data.takeRemainingOnLine().map { InternalName.read(it) }
ci = ClassInfo(InternalName.read(cls), sup?.let { InternalName.read(it) }, intf)
classes[ci!!.name] = ci!!
} else {
val acc = col.split("|").map { AccessFlag.valueOf(it.uppercase()) }
val name = data.takeNext()!!
val desc = FieldOrMethodDescriptor.read(data.takeNext()!!)

if (desc.isMethodDescriptor()) {
ci!!.methods.add(
MethodInfo(
name,
desc.getMethodDescriptor(),
AccessFlag.toInt(acc.toSet())
)
)
} else {
ci!!.fields.add(
FieldInfo(
name,
desc.getFieldDescriptor(),
AccessFlag.toInt(acc.toSet())
)
)
}
}
}
classes
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,97 +4,23 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import xyz.wagyourtail.commonskt.utils.coroutines.parallelMap
import xyz.wagyourtail.unimined.mapping.Namespace
import xyz.wagyourtail.unimined.mapping.jvms.ext.FieldOrMethodDescriptor
import xyz.wagyourtail.unimined.mapping.jvms.four.AccessFlag
import xyz.wagyourtail.unimined.mapping.jvms.four.ElementType
import xyz.wagyourtail.unimined.mapping.jvms.four.three.three.MethodDescriptor
import xyz.wagyourtail.unimined.mapping.jvms.four.three.two.FieldDescriptor
import xyz.wagyourtail.unimined.mapping.jvms.four.two.one.InternalName
import xyz.wagyourtail.unimined.mapping.tree.AbstractMappingTree
import xyz.wagyourtail.commonskt.reader.CharReader
import xyz.wagyourtail.commonskt.utils.coroutines.parallelMap
import xyz.wagyourtail.unimined.mapping.tree.node._class.ClassNode

open class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace, val targets: Set<Namespace>) {
abstract class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace) {
val LOGGER = KotlinLogging.logger { }

private val _classes = mutableMapOf<InternalName, ClassInfo>()
val classes: Map<InternalName, ClassInfo> get() = _classes
abstract val classes: Map<InternalName, ClassInfo>

suspend fun propagate() = coroutineScope {
suspend fun propagate(targets: Set<Namespace>) = coroutineScope {
classes.values.parallelMap {
it.propagate()
}
}

fun read(data: CharReader<*>) {
var ci: ClassInfo? = null
while (!data.exhausted()) {
if (data.peek() == '\n') {
data.take()
continue
}
var col = data.takeNext()!!
var indent = 0
while (col.isEmpty()) {
indent++
col = data.takeNext()!!
}
if (indent > 1) {
throw IllegalArgumentException("expected method, found double indent")
}
if (indent == 0) {
val cls = col
val sup = data.takeNext()!!.ifEmpty { null }
val intf = data.takeRemainingOnLine().map { InternalName.read(it) }
ci = ClassInfo(InternalName.read(cls), sup?.let { InternalName.read(it) }, intf)
_classes[ci.name] = ci
} else {
val acc = col.split("|").map { AccessFlag.valueOf(it.uppercase()) }
val name = data.takeNext()!!
val desc = FieldOrMethodDescriptor.read(data.takeNext()!!)

if (desc.isMethodDescriptor()) {
ci!!.methods.add(
MethodInfo(
name,
desc.getMethodDescriptor(),
AccessFlag.toInt(acc.toSet())
)
)
} else {
ci!!.fields.add(
FieldInfo(
name,
desc.getFieldDescriptor(),
AccessFlag.toInt(acc.toSet())
)
)
}
}
}
}

fun write(append: (String) -> Unit) {
for (cls in _classes.values) {
append(cls.name.toString())
append("\t")
append(cls.superType.toString())
append("\t")
append(cls.interfaces.joinToString("\t") { it.toString() })
append("\n")

for (method in cls.methods) {
append("\t")
append(AccessFlag.of(ElementType.METHOD, method.access).joinToString("|") { it.toString() })
append("\t")
append(method.name)
append("\t")
append(method.descriptor.toString())
append("\n")
}

it.propagate(targets)
}
}

Expand Down Expand Up @@ -123,16 +49,16 @@ open class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace, va

lateinit var methodData: MutableMap<MethodInfo, MutableMap<Namespace, String>>

suspend fun propagate(): Unit = coroutineScope {
suspend fun propagate(targets: Set<Namespace>): Unit = coroutineScope {
if (::methodData.isInitialized) methodData
propagateLock.withLock {
if (::methodData.isInitialized) methodData

superClass?.propagate()
interfaceClasses.parallelMap { it.propagate() }
superClass?.propagate(targets)
interfaceClasses.parallelMap { it.propagate(targets) }

for (method in methods) {
clsNode?.visitMethod(mapOf(fns to (method.name to method.descriptor))).visitEnd()
clsNode?.visitMethod(mapOf(fns to (method.name to method.descriptor)))?.visitEnd()
}

val methods = methods.filter { md ->
Expand All @@ -144,28 +70,41 @@ open class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace, va
}
AccessFlag.isInheritable(acc)
}.parallelMap { md ->
val names = (clsNode?.getMethods(fns, md.name, md.descriptor)?.firstOrNull()?.names?.filterKeys { it in targets } ?: emptyMap()).toMutableMap()
md to (clsNode?.getMethods(fns, md.name, md.descriptor)?.firstOrNull()?.names?.filterKeys { it in targets } ?: emptyMap()).toMutableMap()
// }.parallelMap { (md, names) ->
}.parallelMap { (md, names) ->
// traverse parents, retrieve matching mappings
val superNames = superClass?.methodData?.get(md)
val interfaces = interfaceClasses.map { it to it.methodData[md] }
for (ns in targets) {
if (superNames != null) {
val needsOverwrite = mutableListOf<Namespace>()
if (superNames != null) {
val needsOverwrite = mutableListOf<Namespace>()
for (ns in targets) {
if (names[ns] != superNames[ns]) {
if (superNames[ns] == null) {
superClass!!.overwriteMethodName(md, ns, names[ns]!!)
needsOverwrite += ns
} else {
names[ns] = superNames.getValue(ns)
}
}
}
for ((intf, intfNames) in interfaces) {
if (intfNames != null) {
if (names[ns] != intfNames[ns] && names[ns] != null) {
intf.overwriteMethodName(md, ns, names[ns]!!)
}
for ((intf, intfNames) in interfaces) {
if (intfNames != null) {
for (ns in targets) {
if (names[ns] != intfNames[ns]) {
if (intfNames[ns] == null) {
needsOverwrite += ns
} else {
names[ns] = intfNames.getValue(ns)
}
}
}
}
}
if (needsOverwrite.isNotEmpty()) {
overwriteParentMethodNames(md, names.filterKeys { it in needsOverwrite })
}
clsNode?.visitMethod(
mapOf(fns to (md.name to md.descriptor)) +
names.mapValues { it.value to null }
Expand All @@ -192,20 +131,28 @@ open class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace, va
}
}

private fun overwriteMethodName(md: MethodInfo, namespace: Namespace, newName: String) {
private fun overwriteMethodNames(md: MethodInfo, names: Map<Namespace, String>) {
if (md in methodData) {
methodData[md]!![namespace] = newName
methodData[md]!!.putAll(names)
clsNode?.visitMethod(mapOf(
fns to (md.name to md.descriptor),
namespace to (newName to null)
*names.mapValues { it.value to null }.entries.map { it.key to it.value }.toTypedArray()
))?.visitEnd()
superClass?.overwriteMethodName(md, namespace, newName)
for (interfaceClass in interfaceClasses) {
interfaceClass.overwriteMethodName(md, namespace, newName)
}
overwriteParentMethodNames(md, names)
}
}

private fun overwriteParentMethodNames(md: MethodInfo, names: Map<Namespace, String>) {
superClass?.overwriteMethodNames(md, names)
for (interfaceClass in interfaceClasses) {
interfaceClass.overwriteMethodNames(md, names)
}
}

override fun toString(): String {
return "$name extends $superClass implements $interfaces"
}

}

class MethodInfo(
Expand All @@ -214,6 +161,10 @@ open class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace, va
var access: Int
) {

override fun toString(): String {
return "${AccessFlag.of(ElementType.METHOD, access)} $name;$descriptor"
}

override fun equals(other: Any?): Boolean {
return other is MethodInfo && name == other.name && descriptor == other.descriptor
}
Expand All @@ -230,6 +181,10 @@ open class InheritanceTree(val tree: AbstractMappingTree, val fns: Namespace, va
var access: Int
) {

override fun toString(): String {
return "${AccessFlag.of(ElementType.FIELD, access)} $name;$descriptor"
}

override fun equals(other: Any?): Boolean {
return other is FieldInfo && name == other.name && descriptor == other.descriptor
}
Expand Down
Loading

0 comments on commit e20205f

Please sign in to comment.