Skip to content

Commit

Permalink
Merge pull request #181 from metarank/fix/double-release-guard
Browse files Browse the repository at this point in the history
add guards for native resources
  • Loading branch information
shuttie authored Mar 11, 2024
2 parents 22ad17d + be590f6 commit 967f3b5
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import io.github.metarank.ltrlib.model.{Dataset, Model}
import org.apache.commons.math3.linear.{Array2DRowRealMatrix, ArrayRealVector, RealMatrix, RealVector}

trait Booster[D] extends Model {
protected var isClosed = false
def save(): Array[Byte]
def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double]
def weights(): Array[Double]
def close(): Unit

def whenNotClosed[T](f: => T): T = if (!isClosed) f else throw new Exception("booster is already closed")

override def predict(values: RealMatrix): ArrayRealVector = {
val rows = values.getRowDimension
val cols = values.getColumnDimension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import java.io.ByteArrayInputStream
case class CatboostBooster(booster: CatBoostModel, bytes: Array[Byte]) extends Booster[String] {
override def save(): Array[Byte] = bytes

override def close(): Unit = booster.close()
override def close(): Unit = whenNotClosed { booster.close() }

override def weights(): Array[Double] = Array.emptyDoubleArray

override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = {
override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = whenNotClosed {
val split = new Array[Array[Float]](rows)
var i = 0
while (i < rows) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,24 @@ import scala.collection.mutable

case class LightGBMBooster(model: LGBMBooster) extends Booster[LGBMDataset] with Logging {

override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = {
override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = whenNotClosed {
model.predictForMat(values, rows, cols, true, PredictionType.C_API_PREDICT_NORMAL)
}

override def close(): Unit = model.close()
override def close(): Unit = whenNotClosed {
isClosed = true
model.close()
}

override def save(): Array[Byte] =
override def save(): Array[Byte] = whenNotClosed {
model.saveModelToString(0, 0, FeatureImportanceType.SPLIT).getBytes(StandardCharsets.UTF_8)
}

override def weights(): Array[Double] =
override def weights(): Array[Double] = whenNotClosed {
// numIteration=0 means "use all of them"
// we use split there to match xgboost, which can only do split
model.featureImportance(0, FeatureImportanceType.SPLIT)
}
}

object LightGBMBooster extends BoosterFactory[LGBMDataset, LightGBMBooster, LightGBMOptions] with Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ case class XGBoostBooster(
) extends Booster[DMatrix]
with Logging {

override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = {
override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = whenNotClosed {
val mat = new DMatrix(values.map(_.toFloat), rows, cols, Float.NaN)
mat.setGroup(Array(rows))
mat.setFeatureTypes(featureTypes)
Expand All @@ -30,9 +30,9 @@ case class XGBoostBooster(
out
}

override def close(): Unit = model.dispose()
override def close(): Unit = whenNotClosed { model.dispose() }

override def save(): Array[Byte] = {
override def save(): Array[Byte] = whenNotClosed {
val bytes = new ByteArrayOutputStream()
val data = new DataOutputStream(bytes)
data.writeByte(BITSTREAM_VERSION)
Expand All @@ -47,7 +47,7 @@ case class XGBoostBooster(
bytes.toByteArray
}

override def weights(): Array[Double] = {
override def weights(): Array[Double] = whenNotClosed {
val names = (0 until model.getNumFeature.toInt).map(i => s"feature$i").toArray
val weights = model.getFeatureScore(names).asScala
val result = for {
Expand Down

0 comments on commit 967f3b5

Please sign in to comment.