-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b993ecc
commit 0ff4cf1
Showing
2 changed files
with
231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from tqdm import tqdm | ||
from typing import Callable, Optional | ||
import torch | ||
import pandas as pd | ||
|
||
from datasets import load_dataset | ||
|
||
import dataset_info | ||
|
||
|
||
# Load and prepare dataset
def load_huggingface_dataset(dataset_name: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load a supported HuggingFace dataset and return (train_df, test_df).

    Args:
        dataset_name: one of "bias_in_bios", "amazon_reviews_all_ratings",
            "amazon_reviews_1and5".

    Raises:
        ValueError: if dataset_name is not one of the supported names.
    """
    if dataset_name == "bias_in_bios":
        dataset = load_dataset("LabHC/bias_in_bios")
    elif dataset_name == "amazon_reviews_all_ratings":
        # NOTE(review): the documented load_dataset parameter for selecting a
        # configuration is `name`; the original passed `config_name`, which is
        # not part of the documented signature — confirm the config loads.
        dataset = load_dataset(
            "canrager/amazon_reviews_mcauley",
            name="dataset_all_categories_and_ratings_train1000_test250",
        )
    elif dataset_name == "amazon_reviews_1and5":
        dataset = load_dataset(
            "canrager/amazon_reviews_mcauley_1and5",
        )
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")
    # Build the DataFrames once for every branch. The original only assigned
    # train_df/test_df in two of the three branches, so the
    # "amazon_reviews_all_ratings" path raised UnboundLocalError at return.
    train_df = pd.DataFrame(dataset["train"])
    test_df = pd.DataFrame(dataset["test"])
    return train_df, test_df
|
||
|
||
def get_balanced_dataset(
    df: pd.DataFrame,
    dataset_name: str,
    min_samples_per_quadrant: int,
    random_seed: int,
) -> dict[str, list[str]]:
    """Return a class-balanced mapping from class label to texts.

    For each unique value of column1 (e.g. profession in bias_in_bios),
    sample exactly `min_samples_per_quadrant` rows from every column2 group
    (e.g. gender), skipping column1 values that cannot supply that many rows
    in each group.

    Args:
        df: source DataFrame with the columns named in
            dataset_info.dataset_metadata[dataset_name].
        dataset_name: key into dataset_info.dataset_metadata.
        min_samples_per_quadrant: rows sampled per (column1, column2) cell.
        random_seed: passed to DataFrame.sample for reproducibility.

    Returns:
        dict mapping str(column1 value) -> list of texts of length
        min_samples_per_quadrant * 2 (the trailing truncation + assert assume
        column2 has exactly two values — TODO confirm for new datasets).
    """
    metadata = dataset_info.dataset_metadata[dataset_name]
    text_column_name = metadata["text_column_name"]
    column1_name = metadata["column1_name"]
    column2_name = metadata["column2_name"]

    balanced_df_list = []

    for label in tqdm(df[column1_name].unique()):
        label_df = df[df[column1_name] == label]
        min_count = label_df[column2_name].value_counts().min()

        # Skip classes that cannot fill every quadrant.
        if min_count < min_samples_per_quadrant:
            continue

        balanced_label_df = pd.concat(
            [
                group.sample(n=min_samples_per_quadrant, random_state=random_seed)
                for _, group in label_df.groupby(column2_name)
            ]
        ).reset_index(drop=True)
        balanced_df_list.append(balanced_label_df)

    balanced_df = pd.concat(balanced_df_list).reset_index(drop=True)
    grouped = balanced_df.groupby(column1_name)[text_column_name].apply(list)

    # Stringify keys in one pass. (The original built this dict and then made
    # a second, byte-identical copy of it — the redundant copy is removed.)
    balanced_data = {str(key): texts for key, texts in grouped.items()}

    for key in balanced_data:
        balanced_data[key] = balanced_data[key][: min_samples_per_quadrant * 2]
        assert len(balanced_data[key]) == min_samples_per_quadrant * 2

    return balanced_data
|
||
|
||
def ensure_shared_keys(train_data: dict, test_data: dict) -> tuple[dict, dict]:
    """Mutate both dicts so they contain exactly the same keys.

    Any key present in only one of the two splits is deleted (with a printed
    notice), and the same dict objects are returned.
    """
    shared = train_data.keys() & test_data.keys()

    # Drop test-only keys.
    for key in set(test_data) - shared:
        print(f"Removing {key} from test set")
        del test_data[key]

    # Drop train-only keys.
    for key in set(train_data) - shared:
        print(f"Removing {key} from train set")
        del train_data[key]

    return train_data, test_data
|
||
|
||
def get_multi_label_train_test_data(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    dataset_name: str,
    train_set_size: int,
    test_set_size: int,
    random_seed: int,
) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
    """Returns a dict of [class_name, list[str]]"""
    # Each class splits into 4 quadrants (e.g. male/female per profession),
    # so each quadrant gets a quarter of the requested set size.
    per_quadrant = {
        "train": train_set_size // 4,
        "test": test_set_size // 4,
    }

    train_data = get_balanced_dataset(
        train_df,
        dataset_name,
        per_quadrant["train"],
        random_seed=random_seed,
    )
    test_data = get_balanced_dataset(
        test_df,
        dataset_name,
        per_quadrant["test"],
        random_seed=random_seed,
    )

    # Keep only class labels present in both splits.
    return ensure_shared_keys(train_data, test_data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# TODO: Consolidate all bias in bios utility stuff
# TODO: Only use strings for keys, only use ints when initializing the dictionary datasets

# Labels used for binary probing tasks.
POSITIVE_CLASS_LABEL = 1
NEGATIVE_CLASS_LABEL = 0

# bias_in_bios profession name -> integer class id.
profession_dict = {
    "accountant": 0,
    "architect": 1,
    "attorney": 2,
    "chiropractor": 3,
    "comedian": 4,
    "composer": 5,
    "dentist": 6,
    "dietitian": 7,
    "dj": 8,
    "filmmaker": 9,
    "interior_designer": 10,
    "journalist": 11,
    "model": 12,
    "nurse": 13,
    "painter": 14,
    "paralegal": 15,
    "pastor": 16,
    "personal_trainer": 17,
    "photographer": 18,
    "physician": 19,
    "poet": 20,
    "professor": 21,
    "psychologist": 22,
    "rapper": 23,
    "software_engineer": 24,
    "surgeon": 25,
    "teacher": 26,
    "yoga_teacher": 27,
}
# Inverse lookup: integer class id -> profession name.
profession_int_to_str = {v: k for k, v in profession_dict.items()}

# bias_in_bios gender label -> integer id.
gender_dict = {
    "male": 0,
    "female": 1,
}

# From the original dataset
# Amazon review category name -> integer id.
amazon_category_dict = {
    "All_Beauty": 0,
    "Toys_and_Games": 1,
    "Cell_Phones_and_Accessories": 2,
    "Industrial_and_Scientific": 3,
    "Gift_Cards": 4,
    "Musical_Instruments": 5,
    "Electronics": 6,
    "Handmade_Products": 7,
    "Arts_Crafts_and_Sewing": 8,
    "Baby_Products": 9,
    "Health_and_Household": 10,
    "Office_Products": 11,
    "Digital_Music": 12,
    "Grocery_and_Gourmet_Food": 13,
    "Sports_and_Outdoors": 14,
    "Home_and_Kitchen": 15,
    "Subscription_Boxes": 16,
    "Tools_and_Home_Improvement": 17,
    "Pet_Supplies": 18,
    "Video_Games": 19,
    "Kindle_Store": 20,
    "Clothing_Shoes_and_Jewelry": 21,
    "Patio_Lawn_and_Garden": 22,
    "Unknown": 23,
    "Books": 24,
    "Automotive": 25,
    "CDs_and_Vinyl": 26,
    "Beauty_and_Personal_Care": 27,
    "Amazon_Fashion": 28,
    "Magazine_Subscriptions": 29,
    "Software": 30,
    "Health_and_Personal_Care": 31,
    "Appliances": 32,
    "Movies_and_TV": 33,
}
# Inverse lookup: integer id -> category name.
amazon_int_to_str = {v: k for k, v in amazon_category_dict.items()}


# Identity mapping for the two ratings kept in the 1-and-5-star subset;
# presumably here so ratings can be handled like the other label mappings —
# TODO confirm against consumers of column2_mapping.
amazon_rating_dict = {
    1.0: 1.0,
    5.0: 5.0,
}

# Per-dataset schema used by the loading utilities:
#   text_column_name        — column holding the raw text
#   column1_name            — primary class label (balanced across column2)
#   column2_name            — secondary attribute balanced within each class
#   column2_autointerp_name — human-readable name for column2
#   column1_mapping/column2_mapping — label-name -> id dictionaries above
dataset_metadata = {
    "bias_in_bios": {
        "text_column_name": "hard_text",
        "column1_name": "profession",
        "column2_name": "gender",
        "column2_autointerp_name": "gender",
        "column1_mapping": profession_dict,
        "column2_mapping": gender_dict,
    },
    "amazon_reviews_1and5": {
        "text_column_name": "text",
        "column1_name": "category",
        "column2_name": "rating",
        "column2_autointerp_name": "Amazon Review Sentiment",
        "column1_mapping": amazon_category_dict,
        "column2_mapping": amazon_rating_dict,
    },
}