Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions timm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
input_img_mode='RGB',
transform=None,
target_transform=None,
additional_features=None,
**kwargs,
):
if reader is None or isinstance(reader, str):
Expand All @@ -38,17 +39,19 @@ def __init__(
root=root,
split=split,
class_map=class_map,
additional_features=additional_features,
**kwargs,
)
self.reader = reader
self.load_bytes = load_bytes
self.input_img_mode = input_img_mode
self.transform = transform
self.target_transform = target_transform
self.additional_features = additional_features
self._consecutive_errors = 0

def __getitem__(self, index):
img, target = self.reader[index]
img, target, *features = self.reader[index]

try:
img = img.read() if self.load_bytes else Image.open(img)
Expand All @@ -71,7 +74,10 @@ def __getitem__(self, index):
elif self.target_transform is not None:
target = self.target_transform(target)

return img, target
if self.additional_features is None:
return img, target
else:
return img, target, *features

def __len__(self):
return len(self.reader)
Expand Down
5 changes: 4 additions & 1 deletion timm/data/readers/reader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ def create_reader(
prefix = name[0]
name = name[-1]

# FIXME the additional features are only supported by ReaderHfds for now.
additional_features = kwargs.pop("additional_features", None)

# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'hfds':
from .reader_hfds import ReaderHfds # defer Hf datasets import
reader = ReaderHfds(name=name, root=root, split=split, **kwargs)
reader = ReaderHfds(name=name, root=root, split=split, additional_features=additional_features, **kwargs)
elif prefix == 'hfids':
from .reader_hfids import ReaderHfids # defer HF datasets import
reader = ReaderHfids(name=name, root=root, split=split, **kwargs)
Expand Down
13 changes: 12 additions & 1 deletion timm/data/readers/reader_hfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
class_map: dict = None,
input_key: str = 'image',
target_key: str = 'label',
additional_features: Optional[list[str]] = None,
download: bool = False,
trust_remote_code: bool = False
):
Expand Down Expand Up @@ -65,9 +66,18 @@ def __init__(
self.split_info = self.dataset.info.splits[split]
self.num_samples = self.split_info.num_examples

if isinstance(additional_features, str):
self.additional_features = [additional_features]
elif isinstance(additional_features, list):
self.additional_features = additional_features
else:
self.additional_features = []

def __getitem__(self, index):
item = self.dataset[index]
image = item[self.image_key]
features = [item[feat] for feat in self.additional_features]

if 'bytes' in image and image['bytes']:
image = io.BytesIO(image['bytes'])
else:
Expand All @@ -76,7 +86,8 @@ def __getitem__(self, index):
label = item[self.label_key]
if self.remap_class:
label = self.class_to_idx[label]
return image, label

return image, label, *features

def __len__(self):
return len(self.dataset)
Expand Down