# 使用的包
pip install datasets |
# 使用的函数 datasets.load_dataset()
# 函数原型
datasets.load_dataset( | |
path: str, | |
name: Optional[str] = None, | |
data_dir: Optional[str] = None, | |
data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, | |
split: Optional[Union[str, Split]] = None, | |
cache_dir: Optional[str] = None, | |
features: Optional[Features] = None, | |
download_config: Optional[DownloadConfig] = None, | |
download_mode: Optional[DownloadMode] = None, | |
ignore_verifications: bool = False, | |
keep_in_memory: Optional[bool] = None, | |
save_infos: bool = False, | |
revision: Optional[Union[str, Version]] = None, | |
use_auth_token: Optional[Union[bool, str]] = None, | |
task: Optional[Union[str, TaskTemplate]] = None, | |
streaming: bool = False, | |
**config_kwargs | |
) |
# 参数说明
load_dataset
函数从 Hugging Face Hub
或者本地数据集文件中加载一个数据集。可以通过 https://huggingface.co/datasets 或者 datasets.list_datasets()
函数来获取所有可用的数据集。
参数
path
表示数据集的名字或者路径。可以是一个数据集的名字,比如imdb
、glue
;也可以是通用的产生数据集文件的脚本,比如 "json"、“csv”、“parquet”、“text”;或者是在数据集目录中的脚本(.py
) 文件,比如 “glue/glue.py
”。参数
name
表示数据集中的子数据集,当一个数据集包含多个数据集时,就需要这个参数。比如 "glue
"数据集下就包含"sst2
"、“cola
”、"qqp
" 等多个子数据集,此时就需要指定name
来表示加载哪一个子数据集。参数
data_dir
表示数据集所在的目录,参数data_files
表示本地数据集文件。参数
split
如果为None
,则返回一个DataDict
对象,包含多个DataSet
数据集对象;如果给定的话,则返回单个DataSet
对象。参数
cache_dir
表示缓存数据的目录,默认为 "~/.cache/huggingface/datasets
"。参数keep_in_memory
表示是否将数据集缓存在内存中,加载一次后,再次加载可以提高加载速度。参数
revision
表示加载数据集的脚本的版本。
# 函数使用
- 加载 imdb 数据集
>>> from datasets import load_dataset | |
>>> dataset = datasets.load_dataset("imdb") | |
>>> dataset | |
DatasetDict({ | |
train: Dataset({ | |
features: ['text', 'label'], | |
num_rows: 25000 | |
}) | |
test: Dataset({ | |
features: ['text', 'label'], | |
num_rows: 25000 | |
}) | |
unsupervised: Dataset({ | |
features: ['text', 'label'], | |
num_rows: 50000 | |
}) | |
}) | |
>>> dataset['train'] | |
Dataset({ | |
features: ['text', 'label'], | |
num_rows: 25000 | |
}) |
使用时会自行下载
- 加载 glue 下的 cola 子数据集
>>> dataset = datasets.load_dataset("glue", name="cola") | |
>>> dataset | |
DatasetDict({ | |
train: Dataset({ | |
features: ['sentence', 'label', 'idx'], | |
num_rows: 8551 | |
}) | |
validation: Dataset({ | |
features: ['sentence', 'label', 'idx'], | |
num_rows: 1043 | |
}) | |
test: Dataset({ | |
features: ['sentence', 'label', 'idx'], | |
num_rows: 1063 | |
}) | |
}) |
- 通过 csv 脚本加载本地的 test.tsv 文件中的数据集
>>> dataset = datasets.load_dataset("csv", data_dir="E:\Python\\transfomers\\test", data_files="test.tsv") | |
>>> dataset | |
DatasetDict({ | |
train: Dataset({ | |
features: ['14'], | |
num_rows: 4 | |
}) | |
}) |
- 通过 glue.py 脚本文件加载 cola 数据集
>>> dataset_1 = datasets.load_dataset("../dataset/glue/glue.py", name="cola") | |
# 与上一个等价 | |
>>> dataset_2 = datasets.load_dataset("../dataset/glue", name="cola") |