package main
import (
"log"
"github.com/sugarme/gotch/ts"
)
// main loads the TorchScript module stored in model.pt and feeds it the same
// pre-tokenized input in an endless loop, logging and then dropping every
// output tensor. The loop never terminates on its own; the program only stops
// by panicking (load failure, forward failure, unexpected output type) or by
// being killed.
func main() {
	modelFile := "model.pt"
	model, err := ts.ModuleLoad(modelFile)
	if err != nil {
		panic(err)
	}

	// Token ids of the single input sequence fed to the model.
	inputIds := []int32{
		101, 12865, 11639, 56011, 10908, 10473, 47798, 11424, 83438, 13663, 80017, 74661, 47464, 79326, 10271, 10114, 17734, 3378, 7104, 121, 2075, 2102, 7323, 2534, 3642, 8831, 4151, 7069, 3661, 5605, 3197, 3459, 29653, 6088, 2188, 4380, 2072, 4780, 2435, 7498, 5396, 5718, 73784, 5611, 2452, 2763, 2084, 2090, 3001, 8192, 3701, 2735, 4009, 7740, 5755, 2568, 4792, 73784, 3592, 6336, 6779, 3775, 3378, 2448, 7300, 7321, 7356, 3197, 3408, 4284, 3792, 2275, 5605, 7323, 2534, 2730, 7323, 7315, 5142, 2534, 121, 2075, 2102, 7323, 2534, 8332, 3626, 2080, 8098, 5718, 3661, 5605, 4163, 6748, 10900, 68897, 7700, 2102, 3661, 5605, 5769, 5718, 3199, 3240, 5605, 2146, 3661, 5605, 4982, 5619, 2080, 2688, 8422, 5618, 8335, 5061, 4409, 4252, 2259, 2299, 4142, 5484, 4941, 2105, 3419, 3191, 4577, 2773, 2149, 7838, 3031, 5484, 7700, 2102, 73784, 4462, 2731, 2206, 2756, 2204, 7333, 7333, 5765, 5769, 5718, 4380, 8332, 3626, 5718, 2080, 8098, 3731, 5293, 4163, 6748, 4163, 6748, 4163, 6748, 4333, 2597, 6546, 5396, 4476, 2762, 121, 2316, 3848, 2286, 3661, 2890, 3197, 3459, 121, 11517, 11274, 2465, 3410, 3824, 25986, 10929, 3978, 6457, 4449, 8595, 5396, 4476, 5718, 3378, 2678, 4004, 4482, 7349, 7168, 6088, 2251, 2104, 4313, 3642, 8831, 2534, 6309, 8215, 121, 3507, 2204, 2078, 2211, 4580, 2435, 7498, 2435, 7478, 3173, 4368, 4476, 6397, 6036, 2184, 4181, 2286, 2079, 7651, 4012, 6098, 73784, 2468, 3410, 5760, 2457, 2149, 8417, 5611, 3701, 2735, 6457, 7475, 7478, 2081, 4476, 6397, 6036, 4346, 2457, 102,
	}

	// The attention mask is all ones and the token type ids are all zeros.
	// Both are derived from len(inputIds) so the three inputs can never
	// disagree on length, which the [1, len] reshape below requires.
	attentionMask := make([]int32, len(inputIds))
	for i := range attentionMask {
		attentionMask[i] = 1
	}
	tokenTypeIds := make([]int32, len(inputIds))

	// Positional inputs, in the order they are passed to the module.
	features := [][]int32{inputIds, attentionMask, tokenTypeIds}
	rows := make([]*ts.Tensor, 0, len(features))
	// Release the input tensors when main unwinds. Because the loop below
	// never returns, this only runs while a panic propagates.
	defer func() {
		for _, row := range rows {
			if err := row.Drop(); err != nil {
				log.Printf("dropping input tensor: %v", err)
			}
		}
	}()

	inputParams := make([]*ts.IValue, len(features))
	for i, data := range features {
		row := newRowTensor(data)
		rows = append(rows, row)
		log.Printf("%+v", row)
		inputParams[i] = ts.NewIValue(row)
	}

	for {
		// Calling this line multiple times will lead to increasing memory usage.
		// NOTE(review): every tensor returned here is dropped below, so the
		// growth presumably originates inside ForwardIs / the IValue
		// conversion rather than in this loop — confirm against the library.
		ivs, err := model.ForwardIs(inputParams)
		if err != nil {
			panic(err)
		}

		xs, ok := ivs.Value().([]*ts.Tensor)
		if !ok {
			log.Panicf("unexpected ForwardIs output type %T, want []*ts.Tensor", ivs.Value())
		}
		for _, x := range xs {
			log.Println(x)
			x.MustDrop()
		}
	}
}

// newRowTensor copies data into a new tensor and reshapes it to
// [1, len(data)].
//
// The intermediate flat tensor is released by MustView itself (del=true), so
// it must not be dropped again; the caller owns only the returned tensor and
// is responsible for dropping it.
func newRowTensor(data []int32) *ts.Tensor {
	flat := ts.MustOfSlice(data)
	return flat.MustView([]int64{1, int64(len(data))}, true)
}
Description:
When using the `gotch` library with a JIT-compiled BERT model, calling the `ForwardIs` method repeatedly causes a memory leak. Memory usage increases with each call to `ForwardIs`, which may lead to system instability after prolonged operation.

Steps to Reproduce:
Train and save a BERT model using the following Python code:
Load and call the model in Go using the `gotch` library:

Expected Behavior:
Memory usage should remain stable regardless of the number of `ForwardIs` calls.

Actual Behavior:
Memory usage continues to increase with each call to `ForwardIs`, leading to steadily growing memory consumption over prolonged operation.

Additional Information: