torch.utils.data.Dataset vs TensorDataset

 

1. torch.utils.data.TensorDataset

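A dataset that wraps tensors whose first dimension (the number of samples) is the same size. Each sample is retrieved by indexing every wrapped tensor along that first dimension.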

import torch
from torch.utils.data import TensorDataset

# Five samples with three features each
X_train = torch.FloatTensor([[73, 80, 75],
                             [93, 88, 93],
                             [89, 91, 90],
                             [96, 98, 100],
                             [73, 66, 70]])
# One target value per sample
y_train = torch.FloatTensor([[152], [185], [180], [196], [142]])
ds = TensorDataset(X_train, y_train)
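
To make the "same first dimension" requirement concrete, here is a small sketch that reuses the tensors above. It indexes the dataset, checks its length, and then tries to wrap tensors of different lengths. The names n_samples, x0, y0, and mismatch_rejected are introduced only for this illustration, and the AssertionError on mismatch is the behavior I would expect from current PyTorch releases rather than a documented guarantee.

```python
import torch
from torch.utils.data import TensorDataset

X_train = torch.FloatTensor([[73, 80, 75],
                             [93, 88, 93],
                             [89, 91, 90],
                             [96, 98, 100],
                             [73, 66, 70]])
y_train = torch.FloatTensor([[152], [185], [180], [196], [142]])
ds = TensorDataset(X_train, y_train)

# The dataset length is the size of the shared first dimension.
n_samples = len(ds)

# Indexing returns a tuple with one slice per wrapped tensor.
x0, y0 = ds[0]

# Tensors with different first dimensions are rejected at construction time.
# Assumption: the size check is an assert, so it raises AssertionError.
try:
    TensorDataset(X_train, y_train[:4])
    mismatch_rejected = False
except AssertionError:
    mismatch_rejected = True
```

Here ds[0] pairs the first row of X_train with the first row of y_train, which is all TensorDataset does: it zips tensors along their first dimension.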

2. torch.utils.data.Dataset

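An abstract class used to build a custom dataset. A subclass implements __getitem__ to return the sample at a given index and __len__ to report the number of samples.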

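from torch.utils.data import Dataset

class CustomDataset(Dataset):
  def __init__(self, *args, **kwargs):
    # Dataset.__init__ takes no arguments, so do not forward them.
    super().__init__()
    ...  # store the data passed in

  def __getitem__(self, idx):
    return ...  # the sample at index idx

  def __len__(self):
    return ...  # the number of samples

ds = CustomDataset(X_train, y_train)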
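
As one possible way to fill in the skeleton, the sketch below stores a feature tensor and a target tensor and returns them in pairs. The class name PairDataset and the attribute names features and targets are hypothetical, chosen only for this example. The length check is an added assumption about what a reasonable dataset should enforce, not something the base class requires.

```python
import torch
from torch.utils.data import Dataset


class PairDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, features, targets):
        super().__init__()
        # Assumption: both tensors must describe the same number of samples.
        if len(features) != len(targets):
            raise ValueError("features and targets must have the same length")
        self.features = features
        self.targets = targets

    def __getitem__(self, idx):
        # Return one (features, target) pair for the requested index.
        return self.features[idx], self.targets[idx]

    def __len__(self):
        # Number of samples is the size of the first dimension.
        return len(self.features)


X_train = torch.FloatTensor([[73, 80, 75],
                             [93, 88, 93],
                             [89, 91, 90],
                             [96, 98, 100],
                             [73, 66, 70]])
y_train = torch.FloatTensor([[152], [185], [180], [196], [142]])

ds = PairDataset(X_train, y_train)
n_samples = len(ds)
x_last, y_last = ds[4]
```

For in-memory tensors like these, this behaves the same as TensorDataset. A custom subclass mainly pays off when __getitem__ has to do per-sample work, such as reading a file or applying a transform.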
