diff --git a/src/main/scala/io/github/metarank/ltrlib/booster/Booster.scala b/src/main/scala/io/github/metarank/ltrlib/booster/Booster.scala index 637626b..8afe04b 100644 --- a/src/main/scala/io/github/metarank/ltrlib/booster/Booster.scala +++ b/src/main/scala/io/github/metarank/ltrlib/booster/Booster.scala @@ -5,11 +5,14 @@ import io.github.metarank.ltrlib.model.{Dataset, Model} import org.apache.commons.math3.linear.{Array2DRowRealMatrix, ArrayRealVector, RealMatrix, RealVector} trait Booster[D] extends Model { + protected var isClosed = false def save(): Array[Byte] def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] def weights(): Array[Double] def close(): Unit + def whenNotClosed[T](f: => T): T = if (!isClosed) f else throw new Exception("booster is already closed") + override def predict(values: RealMatrix): ArrayRealVector = { val rows = values.getRowDimension val cols = values.getColumnDimension diff --git a/src/main/scala/io/github/metarank/ltrlib/booster/CatboostBooster.scala b/src/main/scala/io/github/metarank/ltrlib/booster/CatboostBooster.scala index 5a73505..060a5cd 100644 --- a/src/main/scala/io/github/metarank/ltrlib/booster/CatboostBooster.scala +++ b/src/main/scala/io/github/metarank/ltrlib/booster/CatboostBooster.scala @@ -11,11 +11,11 @@ import java.io.ByteArrayInputStream case class CatboostBooster(booster: CatBoostModel, bytes: Array[Byte]) extends Booster[String] { override def save(): Array[Byte] = bytes - override def close(): Unit = booster.close() + override def close(): Unit = whenNotClosed { booster.close() } override def weights(): Array[Double] = Array.emptyDoubleArray - override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = { + override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = whenNotClosed { val split = new Array[Array[Float]](rows) var i = 0 while (i < rows) { diff --git a/src/main/scala/io/github/metarank/ltrlib/booster/LightGBMBooster.scala b/src/main/scala/io/github/metarank/ltrlib/booster/LightGBMBooster.scala index 73d45ca..a367aae 100644 --- a/src/main/scala/io/github/metarank/ltrlib/booster/LightGBMBooster.scala +++ b/src/main/scala/io/github/metarank/ltrlib/booster/LightGBMBooster.scala @@ -11,19 +11,24 @@ import scala.collection.mutable case class LightGBMBooster(model: LGBMBooster) extends Booster[LGBMDataset] with Logging { - override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = { + override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = whenNotClosed { model.predictForMat(values, rows, cols, true, PredictionType.C_API_PREDICT_NORMAL) } - override def close(): Unit = model.close() + override def close(): Unit = whenNotClosed { + isClosed = true + model.close() + } - override def save(): Array[Byte] = + override def save(): Array[Byte] = whenNotClosed { model.saveModelToString(0, 0, FeatureImportanceType.SPLIT).getBytes(StandardCharsets.UTF_8) + } - override def weights(): Array[Double] = + override def weights(): Array[Double] = whenNotClosed { // numIteration=0 means "use all of them" // we use split there to match xgboost, which can only do split model.featureImportance(0, FeatureImportanceType.SPLIT) + } } object LightGBMBooster extends BoosterFactory[LGBMDataset, LightGBMBooster, LightGBMOptions] with Logging { diff --git a/src/main/scala/io/github/metarank/ltrlib/booster/XGBoostBooster.scala b/src/main/scala/io/github/metarank/ltrlib/booster/XGBoostBooster.scala index a65c294..4fb1989 100644 --- a/src/main/scala/io/github/metarank/ltrlib/booster/XGBoostBooster.scala +++ b/src/main/scala/io/github/metarank/ltrlib/booster/XGBoostBooster.scala @@ -16,7 +16,7 @@ case class XGBoostBooster( ) extends Booster[DMatrix] with Logging { - override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = { + override def predictMat(values: Array[Double], rows: Int, cols: Int): Array[Double] = whenNotClosed { val mat = new DMatrix(values.map(_.toFloat), rows, cols, Float.NaN) mat.setGroup(Array(rows)) mat.setFeatureTypes(featureTypes) @@ -30,9 +30,9 @@ case class XGBoostBooster( out } - override def close(): Unit = model.dispose() + override def close(): Unit = whenNotClosed { model.dispose() } - override def save(): Array[Byte] = { + override def save(): Array[Byte] = whenNotClosed { val bytes = new ByteArrayOutputStream() val data = new DataOutputStream(bytes) data.writeByte(BITSTREAM_VERSION) @@ -47,7 +47,7 @@ case class XGBoostBooster( bytes.toByteArray } - override def weights(): Array[Double] = { + override def weights(): Array[Double] = whenNotClosed { val names = (0 until model.getNumFeature.toInt).map(i => s"feature$i").toArray val weights = model.getFeatureScore(names).asScala val result = for {