-
Notifications
You must be signed in to change notification settings - Fork 30.1k
Open
Labels
Description
System Info
- transformers version: 4.55.3
- Platform: Linux-6.8.0-1029-aws-x86_64-with-glibc2.39
- Python version: 3.10.16
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.6.2
- Accelerate version: 1.10.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (NA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script? no
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
from concurrent.futures import ThreadPoolExecutor
import torch
from transformers import AutoTokenizer, AutoModel
from typing import Any
# Settings shared by every Hugging Face `from_pretrained` call in this script.
HFCacheSettings = {
    # Commit hash to load for each supported model name; `from_config` looks
    # the requested model name up in this mapping.
    "REVISIONS": {
        "princeton-nlp/unsup-simcse-bert-base-uncased": "6504ae026e02a1464538d443b15e36afc318e034",
    },
    # Passed as `cache_dir` to `from_pretrained`.
    "PRETRAINED_CACHE_DIR": "/tmp/hf",
}
class SimcseGenerator:
    """Hold a tokenizer, a model, and the device the model was moved to."""

    def __init__(
        self, tokenizer: AutoTokenizer, model: AutoModel, device: torch.device
    ) -> None:
        """Store the already-loaded tokenizer, model and device as attributes."""
        self.tokenizer = tokenizer
        self.model = model
        self.device = device

    @classmethod
    def from_config(
        cls,
        model_name: str = "princeton-nlp/unsup-simcse-bert-base-uncased",
        **kwargs: Any,
    ) -> "SimcseGenerator":
        """Load the tokenizer and model for ``model_name`` and wrap them.

        The revision and cache directory are read from ``HFCacheSettings``.
        The model is moved to CUDA when ``torch.cuda.is_available()`` is true,
        otherwise to the CPU.

        Args:
            model_name: Key into ``HFCacheSettings["REVISIONS"]`` and the name
                passed to both ``from_pretrained`` calls.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            An instance of ``cls`` (so subclasses get their own type back).

        Raises:
            KeyError: If ``model_name`` has no entry in
                ``HFCacheSettings["REVISIONS"]``.
        """
        model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Both loaders take the same revision and cache location; build the
        # keyword arguments once so the two calls cannot drift apart.
        pretrained_kwargs = {
            "revision": HFCacheSettings["REVISIONS"][model_name],
            "cache_dir": HFCacheSettings["PRETRAINED_CACHE_DIR"],
        }

        tokenizer = AutoTokenizer.from_pretrained(model_name, **pretrained_kwargs)
        # NOTE: this `.to(...)` call is the frame that raises in the traceback
        # of the accompanying report when `from_config` runs in many threads.
        model = AutoModel.from_pretrained(model_name, **pretrained_kwargs).to(
            model_device
        )

        # Use `cls` rather than the hard-coded class name so that calling
        # `from_config` on a subclass returns an instance of that subclass.
        return cls(tokenizer, model, model_device)
def test_simcse_generator():
    """Call ``SimcseGenerator.from_config()`` from 20 threads at once.

    Every job is submitted to a thread pool sized to run all of them
    concurrently; the test fails if any job raised.
    """

    def execute():
        SimcseGenerator.from_config()

    num_jobs = 20
    with ThreadPoolExecutor(max_workers=num_jobs) as executor:
        futures = [executor.submit(execute) for _ in range(num_jobs)]
        # Plain loop instead of a throwaway list: `result()` is called only
        # for its side effect of re-raising the worker's exception, if any.
        for future in futures:
            future.result()
uv venv --python 3.10
source .venv/bin/activate
uv pip install transformers[torch] "torch @ https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" pytest
pytest test_simcse.py
(tmp.rdAd75nvYG) ubuntu@ip-10-1-27-118:/tmp/tmp.rdAd75nvYG$ pytest test_simcse.py
====================================================================================================================================================== test session starts ======================================================================================================================================================
platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.6.0
rootdir: /tmp/tmp.rdAd75nvYG
collected 1 item
test_simcse.py F [100%]
=========================================================================================================================================================== FAILURES ============================================================================================================================================================
_____________________________________________________________________________________________________________________________________________________ test_simcse_generator _____________________________________________________________________________________________________________________________________________________
def test_simcse_generator():
def execute():
SimcseGenerator.from_config()
num_jobs = 20
with ThreadPoolExecutor(max_workers=num_jobs) as executor:
futures = [executor.submit(execute) for _ in range(num_jobs)]
> [future.result() for future in futures]
test_simcse.py:48:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_simcse.py:48: in <listcomp>
[future.result() for future in futures]
/home/ubuntu/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py:458: in result
return self.__get_result()
/home/ubuntu/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/concurrent/futures/_base.py:403: in __get_result
raise self._exception
/home/ubuntu/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/concurrent/futures/thread.py:58: in run
result = self.fn(*self.args, **self.kwargs)
test_simcse.py:43: in execute
SimcseGenerator.from_config()
test_simcse.py:37: in from_config
).to(model_device)
.venv/lib/python3.10/site-packages/transformers/modeling_utils.py:4346: in to
return super().to(*args, **kwargs)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1369: in to
return self._apply(convert)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:928: in _apply
module._apply(fn)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:928: in _apply
module._apply(fn)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:955: in _apply
param_applied = fn(param)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
t = Parameter containing:
tensor(..., device='meta', size=(30522, 768), requires_grad=True)
def convert(t):
try:
if convert_to_format is not None and t.dim() in (4, 5):
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
memory_format=convert_to_format,
)
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
)
except NotImplementedError as e:
if str(e) == "Cannot copy out of meta tensor; no data!":
> raise NotImplementedError(
f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
f"when moving module from meta to a different device."
) from None
E NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1362: NotImplementedError
==================================================================================================================================================== short test summary info ====================================================================================================================================================
FAILED test_simcse.py::test_simcse_generator - NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
====================================================================================================================================================== 1 failed in 12.04s =======================================================================================================================================================
Expected behavior
The model should be loaded onto the "cpu" or "cuda" device successfully every time at runtime, instead of intermittently failing with NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.