Skip to content

Memory Leak Issue with ForwardIs Method in gotch Library #133

@yinziyang

Description

@yinziyang

Description:

When using the gotch library with a BERT model exported to TorchScript (via `torch.jit.trace`), calling the `ForwardIs` method repeatedly causes a memory leak. Memory usage increases with every call to `ForwardIs`, which can lead to system instability after prolonged operation.

Steps to Reproduce:

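  1. Export a BERT model to TorchScript using the following Python code (it loads the pretrained `bert-base-multilingual-cased` weights and traces the model; no training is performed):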

    import torch
    from transformers import BertTokenizer, BertModel, BertForSequenceClassification
    import torch.nn as nn
    
    class ScriptableBertForSequenceClassification(BertForSequenceClassification):
        """BertForSequenceClassification with a trace-friendly ``forward``.

        ``forward`` returns a plain tuple instead of a ModelOutput so that
        ``torch.jit.trace`` yields a module whose output is a tuple of tensors:
        ``(logits, pooled_output, sequence_output)``, or
        ``(loss, logits, pooled_output, sequence_output)`` when ``labels`` is
        given. ``self.bert``, ``self.dropout`` and ``self.classifier`` are the
        submodules created by the parent class constructor.
        """
    
        def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
            """Run BERT and the classification head on a batch of sequences.
    
            One of ``input_ids`` / ``inputs_embeds`` must be given.
            ``attention_mask`` defaults to all ones and ``token_type_ids`` to
            all zeros. When ``labels`` is given, the loss is MSE for
            ``num_labels == 1`` and cross-entropy otherwise.
    
            Raises:
                ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
            """
            if input_ids is not None:
                input_shape = input_ids.size()
                device = input_ids.device
            elif inputs_embeds is not None:
                input_shape = inputs_embeds.size()[:-1]
                device = inputs_embeds.device
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
    
            if attention_mask is None:
                attention_mask = torch.ones(input_shape, device=device)
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    
            # Broadcast the (batch, seq) mask to (batch, 1, 1, seq) and turn it
            # into an additive mask: 0 where attended, -10000 where masked.
            extended_attention_mask = attention_mask[:, None, None, :]
            first_param = next(self.bert.parameters(), None)
            if first_param is not None:
                # Match the model's parameter dtype so the additive mask has the same dtype.
                extended_attention_mask = extended_attention_mask.to(dtype=first_param.dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    
            embedding_output = self.bert.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                # Segment ids must reach the embedding layer; otherwise the
                # token_type_ids argument has no effect on the output.
                token_type_ids=token_type_ids,
                inputs_embeds=inputs_embeds,
            )
            encoder_outputs = self.bert.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
            )
            sequence_output = encoder_outputs[0]
            pooled_output = self.bert.pooler(sequence_output)
    
            # Dropout on the classifier input as in the parent class; it is an
            # identity in eval mode, so traced/inference outputs are unchanged.
            logits = self.classifier(self.dropout(pooled_output))
    
            if labels is None:
                return (logits, pooled_output, sequence_output)
    
            if self.num_labels == 1:
                loss = nn.MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            return (loss, logits, pooled_output, sequence_output)
    
    # Load the pretrained checkpoint with a 3-way classification head, plus its tokenizer.
    checkpoint = 'bert-base-multilingual-cased'
    model = ScriptableBertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    
    # Tokenize one example sentence; its tensors serve as the tracing inputs.
    input_text = "Hello, this is a test."
    inputs = tokenizer(input_text, return_tensors='pt')
    input_ids, attention_mask, token_type_ids = (
        inputs[name] for name in ("input_ids", "attention_mask", "token_type_ids")
    )
    
    # Trace forward() with positional (input_ids, attention_mask, token_type_ids)
    # and write the TorchScript module to disk for loading from Go.
    example_inputs = (input_ids, attention_mask, token_type_ids)
    jit_model = torch.jit.trace(model, example_inputs)
    torch.jit.save(jit_model, "model.pt")
  2. Load and call the model in Go using the gotch library:

    package main
    
    import (
    	"log"
    
    	"github.com/sugarme/gotch/ts"
    )
    
    // main loads the traced BERT model and calls ForwardIs in an endless loop
    // with the same inputs, so memory growth across calls can be observed.
    func main() {
    	modelFile := "model.pt"
    	model, err := ts.ModuleLoad(modelFile)
    	if err != nil {
    		panic(err)
    	}
    
    	var inputIds = []int32{
    		101, 12865, 11639, 56011, 10908, 10473, 47798, 11424, 83438, 13663, 80017, 74661, 47464, 79326, 10271, 10114, 17734, 3378, 7104, 121, 2075, 2102, 7323, 2534, 3642, 8831, 4151, 7069, 3661, 5605, 3197, 3459, 29653, 6088, 2188, 4380, 2072, 4780, 2435, 7498, 5396, 5718, 73784, 5611, 2452, 2763, 2084, 2090, 3001, 8192, 3701, 2735, 4009, 7740, 5755, 2568, 4792, 73784, 3592, 6336, 6779, 3775, 3378, 2448, 7300, 7321, 7356, 3197, 3408, 4284, 3792, 2275, 5605, 7323, 2534, 2730, 7323, 7315, 5142, 2534, 121, 2075, 2102, 7323, 2534, 8332, 3626, 2080, 8098, 5718, 3661, 5605, 4163, 6748, 10900, 68897, 7700, 2102, 3661, 5605, 5769, 5718, 3199, 3240, 5605, 2146, 3661, 5605, 4982, 5619, 2080, 2688, 8422, 5618, 8335, 5061, 4409, 4252, 2259, 2299, 4142, 5484, 4941, 2105, 3419, 3191, 4577, 2773, 2149, 7838, 3031, 5484, 7700, 2102, 73784, 4462, 2731, 2206, 2756, 2204, 7333, 7333, 5765, 5769, 5718, 4380, 8332, 3626, 5718, 2080, 8098, 3731, 5293, 4163, 6748, 4163, 6748, 4163, 6748, 4333, 2597, 6546, 5396, 4476, 2762, 121, 2316, 3848, 2286, 3661, 2890, 3197, 3459, 121, 11517, 11274, 2465, 3410, 3824, 25986, 10929, 3978, 6457, 4449, 8595, 5396, 4476, 5718, 3378, 2678, 4004, 4482, 7349, 7168, 6088, 2251, 2104, 4313, 3642, 8831, 2534, 6309, 8215, 121, 3507, 2204, 2078, 2211, 4580, 2435, 7498, 2435, 7478, 3173, 4368, 4476, 6397, 6036, 2184, 4181, 2286, 2079, 7651, 4012, 6098, 73784, 2468, 3410, 5760, 2457, 2149, 8417, 5611, 3701, 2735, 6457, 7475, 7478, 2081, 4476, 6397, 6036, 4346, 2457, 102,
    	}
    
    	// Every token is attended (mask = 1) and in segment 0 (zero value), so both
    	// slices are derived from inputIds and always have the same length.
    	attentionMask := make([]int32, len(inputIds))
    	for i := range attentionMask {
    		attentionMask[i] = 1
    	}
    	tokenTypeIds := make([]int32, len(inputIds))
    
    	// newInput copies data into a tensor and views it as shape [1, len(data)].
    	// MustView(..., true) drops the flat source tensor itself, so only the
    	// returned view is owned by the caller; dropping the source again would
    	// release it twice.
    	newInput := func(data []int32) *ts.Tensor {
    		flat := ts.MustOfSlice(data)
    		return flat.MustView([]int64{1, int64(len(data))}, true)
    	}
    
    	// Positional order must match the traced forward():
    	// (input_ids, attention_mask, token_type_ids).
    	inputs := []*ts.Tensor{
    		newInput(inputIds),
    		newInput(attentionMask),
    		newInput(tokenTypeIds),
    	}
    	inputParams := make([]*ts.IValue, len(inputs))
    	for i, t := range inputs {
    		defer t.Drop()
    		log.Printf("%+v", t)
    		inputParams[i] = ts.NewIValue(t)
    	}
    
    	for {
    		// Inference only: disable autograd so the forward pass builds no
    		// backward graph that could keep intermediate tensors alive.
    		// NOTE(review): assumes the loaded module's parameters require grad —
    		// confirm. ForwardIs converts inputParams to C IValues on every call
    		// and wraps the returned tuple; whether gotch frees those is internal
    		// to the library. If memory still grows here, the leak is there.
    		ts.NoGrad(func() {
    			ivs, err := model.ForwardIs(inputParams)
    			if err != nil {
    				panic(err)
    			}
    			xs, ok := ivs.Value().([]*ts.Tensor)
    			if !ok {
    				log.Panicf("ForwardIs returned %T, want []*ts.Tensor", ivs.Value())
    			}
    			for _, x := range xs {
    				log.Println(x)
    				x.MustDrop()
    			}
    		})
    	}
    }

Expected Behavior:

Memory usage should remain stable regardless of the number of ForwardIs calls.

Actual Behavior:

Memory usage increases with each call to `ForwardIs` and never levels off, so memory consumption grows without bound during prolonged operation.

Additional Information:

  • Python Version: 3.12.4
  • PyTorch Version: 2.3.1
  • Gotch Version: 0.9.1
  • Transformers Version: 4.43.3
  • Operating System: Ubuntu 22.04
  • Other Libraries: (none specified)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions