try:
    from astra.torch.models import ResNetClassifier
    print('installed')
except ImportError:
    print('not installed')
    %pip install git+https://github.com/sustainability-lab/ASTRA
installed
Suraj Jaiswal
November 1, 2023
Here I have implemented active learning on the CIFAR-10 dataset using ResNet-18 as the base model. We currently use the following active learning strategies and compare their performance:
- Diversity acquisition
- Random acquisition
Reference Paper: link, Reference Notebook: link
We want to select the most uncertain samples from the unlabeled pool. However, we also want the selected samples to be diverse: similar samples carry largely redundant information, so labeling several of them teaches the model little more than labeling one.
Since we are working with images, we can use their latent representations (the embeddings produced by the network's featurizer) and select the pool points that lie furthest from the training data in that latent space.
Let's see how this works on a toy example with random points in 2D. We want to select the pool point that is furthest from the training data, i.e. the point with the largest distance to the training set (commonly measured as the distance to its nearest training point). This is the intuition behind diversity acquisition.
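A minimal sketch of this selection rule, assuming Euclidean distance and a greedy max-min (farthest-first) choice in which the distance from a pool point to the training set is its distance to the nearest training point. The helper farthest_first is hypothetical and only illustrates the idea; ASTRA's Furthest acquisition may compute its scores differently.

```python
import torch


def farthest_first(train_feats: torch.Tensor, pool_feats: torch.Tensor, n: int) -> list:
    """Greedy max-min selection (hypothetical helper, not ASTRA's API).

    Repeatedly picks the pool point whose distance to its nearest
    already-labeled point (training points plus earlier picks) is largest.
    """
    # Distance from every pool point to its nearest training point, shape (n_pool,)
    min_dist = torch.cdist(pool_feats, train_feats).min(dim=1).values
    selected = []
    for _ in range(n):
        idx = int(torch.argmax(min_dist))
        selected.append(idx)
        # Treat the picked point as labeled: nearest-neighbour distances can only shrink
        dist_to_new = torch.cdist(pool_feats, pool_feats[idx : idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, dist_to_new)
    return selected


train = torch.tensor([[0.0, 0.0]])
pool = torch.tensor([[1.0, 0.0], [5.0, 0.0], [5.0, 1.0], [-3.0, 0.0]])
print(farthest_first(train, pool, n=2))  # [2, 3]
```

Simply taking the two largest nearest-neighbour distances in one shot would return indices 2 and 1, two almost identical points at (5, 1) and (5, 0). The greedy update instead moves on to the point at (-3, 0), which is what keeps a queried batch diverse.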
Importing the required libraries and modules.
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
# Confusion matrix
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import torchsummary
import umap
# ASTRA
from astra.torch.data import load_cifar_10
from astra.torch.utils import train_fn
from astra.torch.models import ResNetClassifier
from astra.torch.al import Furthest, Centroid, DiversityStrategy, UniformRandomAcquisition, RandomStrategy
# Netron, ONNX for model visualization
import netron
import onnx
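# Use a specific GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
current_device = device
device_name = torch.cuda.get_device_name(current_device) if device.type == "cuda" else "CPU"
print(f"Current GPU assigned: {current_device}, Name: {device_name}, {device}")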
Current GPU assigned: cuda:2, Name: NVIDIA A100-SXM4-80GB, cuda:2
# Create dummy data
n_train = 5
n_pool = 10
n_features = 2
torch.manual_seed(0)
# Generate random train and pool features
train_features = torch.rand(n_train, n_features)
pool_features = torch.rand(n_pool, n_features)
plt.scatter(train_features[:, 0], train_features[:, 1], label='Train Data', marker='o')
plt.scatter(pool_features[:, 0], pool_features[:, 1], label='Pool Data', marker='x')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.title('Train and Pool Data')
plt.grid()
plt.show()
n_samples = 2 # number of samples to select from pool
acquisition = Furthest()
selected_indices = acquisition.acquire_scores(
train_features.cpu(), pool_features.cpu(), n=n_samples
)
selected_indices
[8, 9]
print("Index of the n farthest Pool Feature from train features:", selected_indices)
plt.scatter(pool_features[selected_indices][:,0], pool_features[selected_indices][:,1], label='Selected data', marker='s', c='green')
plt.scatter(train_features[:, 0], train_features[:, 1], label='Train Data', marker='o')
plt.scatter(pool_features[:, 0], pool_features[:, 1], label='Pool Data', marker='x')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.title('Pool points farthest from the train data')
plt.grid()
plt.show()
Index of the n farthest Pool Feature from train features: [8, 9]
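Let's now work on the CIFAR-10 dataset, loaded with ASTRA's load_cifar_10 helper. The summary below shows that it contains all 60,000 images (3×32×32, 10 classes).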
Files already downloaded and verified
Files already downloaded and verified
CIFAR-10 Dataset
length of dataset: 60000
shape of images: torch.Size([3, 32, 32])
len of classes: 10
classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dtype of images: torch.float32
dtype of labels: torch.int64
# Plot some images
plt.figure(figsize=(6, 6))
for i in range(25):
    plt.subplot(5, 5, i+1)
    # Images are stored channel-first (C, H, W); imshow expects (H, W, C)
    plt.imshow(torch.einsum("chw->hwc", dataset.data[i].cpu()))
    plt.axis('off')
    plt.title(dataset.classes[dataset.targets[i]])
plt.tight_layout()
n_train = 1000
n_test = 20000
X = dataset.data.to(device)
y = dataset.targets.to(device)
print(X.shape)
print(X.shape, X.dtype)
print(X.min(), X.max())
print(y.shape, y.dtype)
print(X.device, y.device)
torch.Size([60000, 3, 32, 32])
torch.Size([60000, 3, 32, 32]) torch.float32
tensor(0., device='cuda:2') tensor(1., device='cuda:2')
torch.Size([60000]) torch.int64
cuda:2 cuda:2
Before we start training, let's see how the untrained model performs, which gives a lower bound on the model's performance. CIFAR-10 has 10 balanced classes, so random guessing scores about 10%; the untrained network lands close to that.
def get_accuracy(net, X, y):
    net.eval()  # disable dropout so predictions are deterministic
    with torch.no_grad():
        logits_pred = net(X)
        y_pred = logits_pred.argmax(dim=1)
        acc = (y_pred == y).float().mean()
    return y_pred, acc

def predict(net, classes, plot_confusion_matrix=False, train_idx=train_idx, pool_idx=pool_idx, test_idx=test_idx):
    # Note: the default index arguments are bound when this function is defined,
    # so the train/pool/test split must already exist at that point
    accuracies = {}
    for name, idx in zip(("train", "pool", "test"), (train_idx, pool_idx, test_idx)):
        X_dataset = X[idx].to(device)
        y_dataset = y[idx].to(device)
        y_pred, acc = get_accuracy(net, X_dataset, y_dataset)
        accuracies[name] = acc.item()
        print(f'{name} set accuracy: {acc*100:.2f}%')
        if plot_confusion_matrix:
            cm = confusion_matrix(y_dataset.cpu(), y_pred.cpu())
            ConfusionMatrixDisplay(cm, display_labels=classes).plot(values_format='d', cmap='Blues')
            # Rotate the x-axis labels to make them readable
            _ = plt.xticks(rotation=90)
            plt.show()
    return accuracies
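get_accuracy pushes a whole split (up to 39,000 images) through the network in a single forward pass. That fits on an 80 GB A100 but can run out of memory on smaller GPUs. Below is a hedged sketch of a mini-batched variant; batched_accuracy and its batch_size default are illustrative choices, not part of ASTRA. It switches the model to eval mode (so dropout is off) and restores the previous mode afterwards.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def batched_accuracy(net: nn.Module, X: torch.Tensor, y: torch.Tensor, batch_size: int = 1024):
    """Return (predictions, accuracy), evaluating `net` on X in mini-batches."""
    was_training = net.training
    net.eval()  # disable dropout during evaluation
    preds = torch.cat([net(X[i : i + batch_size]).argmax(dim=1)
                       for i in range(0, len(X), batch_size)])
    net.train(was_training)  # restore the caller's train/eval mode
    return preds, (preds == y).float().mean()


# Tiny check with an identity linear "classifier": the predicted class is the larger feature
net_toy = nn.Linear(2, 2)
with torch.no_grad():
    net_toy.weight.copy_(torch.eye(2))
    net_toy.bias.zero_()
X_toy = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [0.0, 3.0]])
y_toy = torch.tensor([0, 1, 0, 0])
preds, acc = batched_accuracy(net_toy, X_toy, y_toy, batch_size=3)
print(preds.tolist(), acc.item())  # [0, 1, 0, 1] 0.75
```

Using batch_size=3 on four samples exercises the uneven last batch; the toy names avoid clobbering the notebook's X, y and resnet.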
accuracy_summary = {}  # collects train/pool/test accuracies for each experiment
accuracy_summary['untrain_acc'] = predict(resnet, dataset.classes, plot_confusion_matrix=False)
accuracy_summary
train set accuracy: 7.70%
pool set accuracy: 8.37%
test set accuracy: 8.58%
{'untrain_acc': {'train': 0.07700000703334808,
'pool': 0.08366666734218597,
'test': 0.08579999953508377}}
def viz_embeddings(net, X, y, device):
    # Project the featurizer's embeddings to 2D with UMAP and colour them by class
    reducer = umap.UMAP()
    net.eval()  # disable dropout while extracting embeddings
    with torch.no_grad():
        emb = net.featurizer(X.to(device))
    emb = emb.cpu().numpy()
    emb = reducer.fit_transform(emb)
    plt.figure(figsize=(4, 4))
    plt.scatter(emb[:, 0], emb[:, 1], c=y.cpu().numpy(), cmap='tab10')
    # Add a colorbar legend to mark color to class mapping
    cb = plt.colorbar(boundaries=np.arange(11)-0.5)
    cb.set_ticks(np.arange(10))
    cb.set_ticklabels(dataset.classes)
    plt.title("UMAP embeddings")
    plt.tight_layout()
[Figure: training loss curve]
Accuracy of a model trained only on the 1,000 initial training samples:
train set accuracy: 100.00%
pool set accuracy: 36.06%
test set accuracy: 36.24%
torch.manual_seed(0)
idx = torch.randperm(len(X))
train_idx = idx[:n_train]        # 1,000 initial labeled samples
pool_idx = idx[n_train:-n_test]  # 39,000 unlabeled pool samples
test_idx = idx[-n_test:]         # 20,000 held-out test samples
# Train on train + the first 5,000 pool samples
train_plus_pool_idx = torch.cat([train_idx, pool_idx[:5000]])
print(train_plus_pool_idx.shape)
resnet = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)
iter_losses, epoch_losses = train_fn(resnet, X[train_plus_pool_idx], y[train_plus_pool_idx], loss_fn=nn.CrossEntropyLoss(),
                                     lr=3e-4, batch_size=1024, epochs=30, verbose=False)
torch.Size([6000])
[Figure: training loss curve]
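The index split above can be checked in isolation. The sketch below uses a hypothetical split_indices helper with its own torch.Generator, so its permutation need not match the notebook's; it only confirms the arithmetic 60,000 = 1,000 + 39,000 + 20,000 and that the three index sets are disjoint.

```python
import torch


def split_indices(n_total: int, n_train: int, n_test: int, seed: int = 0):
    """Shuffle 0..n_total-1 and cut it into train / pool / test index tensors (needs n_test > 0)."""
    gen = torch.Generator().manual_seed(seed)
    idx = torch.randperm(n_total, generator=gen)
    return idx[:n_train], idx[n_train:-n_test], idx[-n_test:]


# Short names so the notebook's train_idx / pool_idx / test_idx are not overwritten
tr, pl, te = split_indices(60_000, 1_000, 20_000)
print(len(tr), len(pl), len(te))  # 1000 39000 20000

# No sample may appear in two splits, and together the splits cover the whole dataset
assert not torch.isin(tr, pl).any() and not torch.isin(tr, te).any() and not torch.isin(pl, te).any()
assert len(torch.cat([tr, pl, te]).unique()) == 60_000
```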
# Same split as above (same seed, so the same indices)
torch.manual_seed(0)
idx = torch.randperm(len(X))
train_idx = idx[:n_train]
pool_idx = idx[n_train:-n_test]
test_idx = idx[-n_test:]
# Train on train + the entire pool (1,000 + 39,000 = 40,000 samples): the fully labeled reference
train_plus_pool_idx = torch.cat([train_idx, pool_idx])
resnet = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)
iter_losses, epoch_losses = train_fn(resnet, X[train_plus_pool_idx], y[train_plus_pool_idx], loss_fn=nn.CrossEntropyLoss(),
                                     lr=3e-4, batch_size=1024, epochs=30, verbose=False)
[Figure: training loss curve]
accuracy_summary[f"train_1000_pool_39000"] = predict(resnet, dataset.classes, plot_confusion_matrix=True)
train set accuracy: 99.60%
pool set accuracy: 99.43%
test set accuracy: 61.61%
{'untrain_acc': {'train': 0.07700000703334808,
'pool': 0.08366666734218597,
'test': 0.08579999953508377},
'train_1000': {'train': 1.0,
'pool': 0.36056411266326904,
'test': 0.3624500036239624},
'train_1000_pool_5000': {'train': 0.9920000433921814,
'pool': 0.4998205304145813,
'test': 0.4273499846458435},
'train_1000_pool_39000': {'train': 0.9960000514984131,
'pool': 0.994282066822052,
'test': 0.616100013256073}}
Saving and loading the accuracy results
# import json
# file_path = "/home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/accuracy_summary_without.json"
# with open(file_path, 'w') as json_file:
# json.dump(accuracy_summary, json_file)
# print(f"Accuracy summary has been saved to {file_path}.")
Accuracy summary has been saved to /home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/accuracy_summary_without.json.
{'untrain_acc': {'train': 0.07700000703334808,
'pool': 0.08366666734218597,
'test': 0.08579999953508377},
'train_1000': {'train': 1.0,
'pool': 0.36056411266326904,
'test': 0.3624500036239624},
'train_1000_pool_5000': {'train': 0.9920000433921814,
'pool': 0.4998205304145813,
'test': 0.4273499846458435},
'train_1000_pool_39000': {'train': 0.9960000514984131,
'pool': 0.994282066822052,
'test': 0.616100013256073}}
Paper link
def AL_loop(train_idx, pool_idx, strategy, acquisition, n_query_samples, num_iter):
    """Run `num_iter` rounds of active learning.

    Each round trains a fresh ResNet-18 on the current training set, evaluates it,
    queries `n_query_samples` points from the pool with `strategy`, and moves them
    from the pool to the training set.
    """
    train_idx_copy = train_idx.clone().to(device)
    pool_idx_copy = pool_idx.clone().to(device)
    acq_name = acquisition.__class__.__name__  # key under which the strategy returns selected indices
    accuracy_list = []     # train/pool/test accuracies after each round
    class_count_list = []  # per-class counts of the training set after each round

    for _ in trange(num_iter):
        # Train a fresh model from scratch on the current labeled set
        resnet = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)
        iter_losses, epoch_losses = train_fn(resnet, X[train_idx_copy], y[train_idx_copy], nn.CrossEntropyLoss(),
                                             lr=3e-4, batch_size=128, epochs=30, verbose=False)

        # Diversity queries use the featurizer's embeddings; random queries only need the pool
        if isinstance(strategy, DiversityStrategy):
            best_indices = strategy.query(resnet.featurizer, pool_idx_copy, train_idx_copy,
                                          n_query_samples=n_query_samples)
        elif isinstance(strategy, RandomStrategy):
            best_indices = strategy.query(resnet, pool_idx_copy, n_query_samples=n_query_samples)
        else:
            raise ValueError(f"Invalid strategy: {strategy.__class__.__name__}")

        accuracy_list.append(predict(resnet, dataset.classes, plot_confusion_matrix=False,
                                     train_idx=train_idx_copy, pool_idx=pool_idx_copy, test_idx=test_idx))
        class_count_list.append(torch.bincount(y[train_idx_copy], minlength=len(dataset.classes)))

        # Move the queried indices from the pool to the training set
        selected = best_indices[acq_name]
        train_idx_copy = torch.cat([train_idx_copy, selected])
        pool_idx_copy = pool_idx_copy[~torch.isin(pool_idx_copy, selected)]

    # Bookkeeping after the last query (the last model is evaluated on the enlarged sets)
    class_count_list.append(torch.bincount(y[train_idx_copy], minlength=len(dataset.classes)))
    accuracy_list.append(predict(resnet, dataset.classes, plot_confusion_matrix=False,
                                 train_idx=train_idx_copy, pool_idx=pool_idx_copy, test_idx=test_idx))
    print('train length: ', train_idx_copy.shape, 'pool length: ', pool_idx_copy.shape)
    # Returns the accuracy list (over AL iterations), the final accuracies (train, pool, test) and the class counts
    return accuracy_list, accuracy_list[-1], class_count_list
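The last lines of each iteration in AL_loop move the queried samples from the pool to the training set. Here is the same bookkeeping in isolation, as a small sketch with a hypothetical move_to_train helper. Like AL_loop, it assumes the strategy returns dataset indices (values taken from pool_idx), not positions within the pool.

```python
import torch


def move_to_train(train_idx: torch.Tensor, pool_idx: torch.Tensor, selected: torch.Tensor):
    """Append the selected dataset indices to the training set and drop them from the pool."""
    new_train = torch.cat([train_idx, selected])
    new_pool = pool_idx[~torch.isin(pool_idx, selected)]  # vectorised membership test
    return new_train, new_pool


train_ids = torch.tensor([3, 7])
pool_ids = torch.tensor([0, 1, 2, 4, 5])
train_ids, pool_ids = move_to_train(train_ids, pool_ids, torch.tensor([4, 1]))
print(train_ids.tolist(), pool_ids.tolist())  # [3, 7, 4, 1] [0, 2, 5]
```

torch.isin replaces the commented-out list comprehension from the original loop, avoiding a Python-level membership check per pool element. With 100 queries per iteration over 50 iterations, the training set grows from 1,000 to 6,000 indices and the pool shrinks from 39,000 to 34,000, matching the final shapes printed by AL_loop.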
acquisition = Furthest()
strategy = DiversityStrategy(acquisition, X, y)
strategy.to(device)

accuracy_AL_list = {}  # per-iteration accuracies, keyed by experiment name
class_count_dict = {}  # per-iteration training-set class counts, keyed by experiment name

n_query_samples = 100  # samples moved from the pool to the training set per iteration
num_iter = 50          # number of active learning iterations
seeds = [0]  # , 1, 2, 3, 4]
for seed in seeds:
    print(strategy.__class__.__name__, ' For seed: ', seed, ' ----------------------------------------- ')
    torch.manual_seed(seed)  # seed model initialisation and training, as in the random-strategy runs
    a, b, c = AL_loop(train_idx, pool_idx, strategy, acquisition, n_query_samples=n_query_samples, num_iter=num_iter)
    key = f"train_1000_pool_query_{n_query_samples}_iter_{num_iter}_{strategy.__class__.__name__}_seed_{seed}"
    accuracy_AL_list[key] = a
    accuracy_summary[key] = b
    class_count_dict[key] = c
DiversityStrategy For seed: 0 -----------------------------------------
train set accuracy: 100.00%
pool set accuracy: 36.18%
test set accuracy: 36.43%
train set accuracy: 100.00%
pool set accuracy: 36.20%
test set accuracy: 36.69%
train set accuracy: 100.00%
pool set accuracy: 37.16%
test set accuracy: 37.14%
train set accuracy: 99.54%
pool set accuracy: 36.48%
test set accuracy: 36.71%
train set accuracy: 100.00%
pool set accuracy: 38.77%
test set accuracy: 38.72%
train set accuracy: 100.00%
pool set accuracy: 39.10%
test set accuracy: 39.38%
train set accuracy: 100.00%
pool set accuracy: 39.21%
test set accuracy: 39.09%
train set accuracy: 99.82%
pool set accuracy: 38.07%
test set accuracy: 38.21%
train set accuracy: 99.61%
pool set accuracy: 38.90%
test set accuracy: 39.38%
train set accuracy: 100.00%
pool set accuracy: 39.75%
test set accuracy: 39.89%
train set accuracy: 100.00%
pool set accuracy: 39.51%
test set accuracy: 39.52%
train set accuracy: 99.67%
pool set accuracy: 39.12%
test set accuracy: 39.17%
train set accuracy: 100.00%
pool set accuracy: 40.86%
test set accuracy: 40.92%
train set accuracy: 100.00%
pool set accuracy: 39.84%
test set accuracy: 39.86%
train set accuracy: 96.38%
pool set accuracy: 38.85%
test set accuracy: 38.47%
train set accuracy: 99.92%
pool set accuracy: 40.29%
test set accuracy: 40.52%
train set accuracy: 95.46%
pool set accuracy: 40.26%
test set accuracy: 40.57%
train set accuracy: 99.30%
pool set accuracy: 39.75%
test set accuracy: 40.44%
train set accuracy: 99.61%
pool set accuracy: 39.95%
test set accuracy: 40.60%
train set accuracy: 100.00%
pool set accuracy: 41.72%
test set accuracy: 42.24%
train set accuracy: 99.97%
pool set accuracy: 41.81%
test set accuracy: 42.22%
train set accuracy: 99.68%
pool set accuracy: 41.07%
test set accuracy: 42.00%
train set accuracy: 99.81%
pool set accuracy: 42.33%
test set accuracy: 42.37%
train set accuracy: 99.00%
pool set accuracy: 39.79%
test set accuracy: 41.30%
train set accuracy: 99.47%
pool set accuracy: 42.15%
test set accuracy: 42.94%
train set accuracy: 99.74%
pool set accuracy: 40.05%
test set accuracy: 40.45%
train set accuracy: 99.31%
pool set accuracy: 43.81%
test set accuracy: 43.76%
train set accuracy: 99.57%
pool set accuracy: 41.94%
test set accuracy: 42.24%
train set accuracy: 98.95%
pool set accuracy: 41.29%
test set accuracy: 42.16%
train set accuracy: 99.82%
pool set accuracy: 43.28%
test set accuracy: 43.58%
train set accuracy: 99.45%
pool set accuracy: 41.29%
test set accuracy: 42.22%
train set accuracy: 99.93%
pool set accuracy: 44.22%
test set accuracy: 44.94%
train set accuracy: 99.48%
pool set accuracy: 43.57%
test set accuracy: 44.35%
train set accuracy: 98.74%
pool set accuracy: 42.53%
test set accuracy: 42.91%
train set accuracy: 99.55%
pool set accuracy: 43.60%
test set accuracy: 44.20%
train set accuracy: 99.13%
pool set accuracy: 45.06%
test set accuracy: 46.16%
train set accuracy: 99.91%
pool set accuracy: 44.80%
test set accuracy: 45.64%
train set accuracy: 99.23%
pool set accuracy: 42.77%
test set accuracy: 43.78%
train set accuracy: 99.85%
pool set accuracy: 45.19%
test set accuracy: 45.91%
train set accuracy: 99.78%
pool set accuracy: 46.17%
test set accuracy: 47.22%
train set accuracy: 99.92%
pool set accuracy: 46.44%
test set accuracy: 46.54%
train set accuracy: 99.55%
pool set accuracy: 45.56%
test set accuracy: 46.45%
train set accuracy: 99.88%
pool set accuracy: 45.21%
test set accuracy: 46.16%
train set accuracy: 98.98%
pool set accuracy: 43.86%
test set accuracy: 44.58%
train set accuracy: 99.00%
pool set accuracy: 45.94%
test set accuracy: 46.47%
train set accuracy: 99.35%
pool set accuracy: 46.94%
test set accuracy: 47.50%
train set accuracy: 99.91%
pool set accuracy: 46.90%
test set accuracy: 48.71%
train set accuracy: 99.84%
pool set accuracy: 47.50%
test set accuracy: 48.06%
train set accuracy: 99.67%
pool set accuracy: 47.27%
test set accuracy: 48.16%
train set accuracy: 99.85%
pool set accuracy: 47.33%
test set accuracy: 48.04%
train set accuracy: 98.82%
pool set accuracy: 47.44%
test set accuracy: 48.17%
train length: torch.Size([6000]) pool length: torch.Size([34000])
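The baseline used next, random acquisition, ignores the model and draws pool indices uniformly without replacement. A minimal sketch of that idea follows; uniform_random_query is a hypothetical helper, and ASTRA's UniformRandomAcquisition and RandomStrategy may be implemented differently.

```python
import torch


def uniform_random_query(pool_idx: torch.Tensor, n: int, generator=None) -> torch.Tensor:
    """Return n distinct dataset indices drawn uniformly at random from the pool."""
    # randperm runs on the CPU (matching a default CPU generator), then moves to the pool's device
    perm = torch.randperm(len(pool_idx), generator=generator).to(pool_idx.device)
    return pool_idx[perm[:n]]


pool_ids = torch.arange(1000, 1100)  # stand-in pool of 100 dataset indices
picked = uniform_random_query(pool_ids, n=10, generator=torch.Generator().manual_seed(0))
print(picked.tolist())
```

Because every pool point is equally likely to be drawn, the queried batch has, in expectation, the same class mix as the pool.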
acquisition = UniformRandomAcquisition()
strategy = RandomStrategy(acquisition, X, y)
n_query_samples = 100  # samples moved from the pool to the training set per iteration
num_iter = 50          # number of active learning iterations
seeds = [0, 1, 2, 3, 4]
for seed in seeds:
    print(strategy.__class__.__name__, ' For seed: ', seed, ' ----------------------------------------- ')
    torch.manual_seed(seed)  # seed model initialisation, training and the random queries
    a, b, c = AL_loop(train_idx, pool_idx, strategy, acquisition, n_query_samples=n_query_samples, num_iter=num_iter)
    key = f"train_1000_pool_query_{n_query_samples}_iter_{num_iter}_{strategy.__class__.__name__}_seed_{seed}"
    accuracy_AL_list[key] = a
    accuracy_summary[key] = b
    class_count_dict[key] = c
RandomStrategy For seed: 0 -----------------------------------------
train set accuracy: 99.90%
pool set accuracy: 36.75%
test set accuracy: 37.28%
train set accuracy: 100.00%
pool set accuracy: 37.13%
test set accuracy: 36.95%
train set accuracy: 100.00%
pool set accuracy: 37.36%
test set accuracy: 37.30%
train set accuracy: 99.31%
pool set accuracy: 36.82%
test set accuracy: 36.89%
train set accuracy: 100.00%
pool set accuracy: 39.26%
test set accuracy: 39.49%
train set accuracy: 100.00%
pool set accuracy: 40.21%
test set accuracy: 40.17%
train set accuracy: 99.94%
pool set accuracy: 40.04%
test set accuracy: 40.26%
train set accuracy: 98.94%
pool set accuracy: 38.73%
test set accuracy: 38.51%
train set accuracy: 99.11%
pool set accuracy: 39.41%
test set accuracy: 39.38%
train set accuracy: 100.00%
pool set accuracy: 41.66%
test set accuracy: 40.93%
train set accuracy: 100.00%
pool set accuracy: 40.68%
test set accuracy: 41.06%
train set accuracy: 99.33%
pool set accuracy: 40.79%
test set accuracy: 40.35%
train set accuracy: 99.41%
pool set accuracy: 41.95%
test set accuracy: 42.38%
train set accuracy: 100.00%
pool set accuracy: 43.02%
test set accuracy: 42.95%
train set accuracy: 99.79%
pool set accuracy: 42.26%
test set accuracy: 41.33%
train set accuracy: 99.08%
pool set accuracy: 42.47%
test set accuracy: 42.95%
train set accuracy: 100.00%
pool set accuracy: 45.05%
test set accuracy: 44.65%
train set accuracy: 99.15%
pool set accuracy: 43.83%
test set accuracy: 43.52%
train set accuracy: 96.46%
pool set accuracy: 40.97%
test set accuracy: 40.78%
train set accuracy: 98.07%
pool set accuracy: 42.18%
test set accuracy: 42.10%
train set accuracy: 96.00%
pool set accuracy: 42.37%
test set accuracy: 42.40%
train set accuracy: 100.00%
pool set accuracy: 45.65%
test set accuracy: 45.89%
train set accuracy: 99.66%
pool set accuracy: 45.22%
test set accuracy: 45.42%
train set accuracy: 100.00%
pool set accuracy: 46.08%
test set accuracy: 46.19%
train set accuracy: 100.00%
pool set accuracy: 45.51%
test set accuracy: 45.74%
train set accuracy: 99.69%
pool set accuracy: 45.97%
test set accuracy: 45.96%
train set accuracy: 99.97%
pool set accuracy: 47.10%
test set accuracy: 46.11%
train set accuracy: 99.68%
pool set accuracy: 46.29%
test set accuracy: 45.72%
train set accuracy: 99.87%
pool set accuracy: 47.46%
test set accuracy: 47.24%
train set accuracy: 99.69%
pool set accuracy: 46.62%
test set accuracy: 46.11%
train set accuracy: 99.95%
pool set accuracy: 47.58%
test set accuracy: 47.62%
train set accuracy: 96.80%
pool set accuracy: 47.16%
test set accuracy: 46.98%
train set accuracy: 98.36%
pool set accuracy: 46.82%
test set accuracy: 46.95%
train set accuracy: 99.21%
pool set accuracy: 47.16%
test set accuracy: 46.65%
train set accuracy: 97.61%
pool set accuracy: 45.40%
test set accuracy: 45.27%
train set accuracy: 99.71%
pool set accuracy: 47.11%
test set accuracy: 47.45%
train set accuracy: 99.57%
pool set accuracy: 47.18%
test set accuracy: 47.20%
train set accuracy: 99.70%
pool set accuracy: 47.92%
test set accuracy: 47.90%
train set accuracy: 99.23%
pool set accuracy: 47.92%
test set accuracy: 47.75%
train set accuracy: 99.12%
pool set accuracy: 48.58%
test set accuracy: 48.38%
train set accuracy: 99.32%
pool set accuracy: 48.11%
test set accuracy: 48.56%
train set accuracy: 98.16%
pool set accuracy: 47.28%
test set accuracy: 47.16%
train set accuracy: 99.60%
pool set accuracy: 49.05%
test set accuracy: 48.90%
train set accuracy: 98.83%
pool set accuracy: 48.64%
test set accuracy: 48.23%
train set accuracy: 99.31%
pool set accuracy: 47.98%
test set accuracy: 48.22%
train set accuracy: 98.73%
pool set accuracy: 49.87%
test set accuracy: 49.91%
train set accuracy: 99.61%
pool set accuracy: 49.80%
test set accuracy: 50.30%
train set accuracy: 98.74%
pool set accuracy: 49.02%
test set accuracy: 48.50%
train set accuracy: 99.22%
pool set accuracy: 50.22%
test set accuracy: 50.20%
train set accuracy: 99.39%
pool set accuracy: 50.98%
test set accuracy: 50.98%
train set accuracy: 98.63%
pool set accuracy: 50.91%
test set accuracy: 50.92%
train length: torch.Size([6000]) pool length: torch.Size([34000])
RandomStrategy For seed: 1 -----------------------------------------
train set accuracy: 100.00%
pool set accuracy: 35.72%
test set accuracy: 36.41%
train set accuracy: 100.00%
pool set accuracy: 37.39%
test set accuracy: 37.40%
train set accuracy: 100.00%
pool set accuracy: 37.63%
test set accuracy: 37.96%
train set accuracy: 100.00%
pool set accuracy: 38.02%
test set accuracy: 37.73%
train set accuracy: 100.00%
pool set accuracy: 39.50%
test set accuracy: 39.06%
train set accuracy: 100.00%
pool set accuracy: 39.78%
test set accuracy: 40.08%
train set accuracy: 100.00%
pool set accuracy: 39.65%
test set accuracy: 39.60%
train set accuracy: 99.88%
pool set accuracy: 40.05%
test set accuracy: 39.87%
train set accuracy: 99.06%
pool set accuracy: 39.13%
test set accuracy: 39.77%
train set accuracy: 100.00%
pool set accuracy: 41.85%
test set accuracy: 41.67%
train set accuracy: 98.55%
pool set accuracy: 39.45%
test set accuracy: 39.88%
train set accuracy: 100.00%
pool set accuracy: 42.06%
test set accuracy: 41.87%
train set accuracy: 98.32%
pool set accuracy: 40.48%
test set accuracy: 40.22%
train set accuracy: 99.61%
pool set accuracy: 40.97%
test set accuracy: 41.18%
train set accuracy: 99.79%
pool set accuracy: 41.43%
test set accuracy: 42.19%
train set accuracy: 100.00%
pool set accuracy: 42.32%
test set accuracy: 41.87%
train set accuracy: 99.85%
pool set accuracy: 42.32%
test set accuracy: 42.17%
train set accuracy: 99.96%
pool set accuracy: 44.04%
test set accuracy: 44.11%
train set accuracy: 99.96%
pool set accuracy: 43.93%
test set accuracy: 43.18%
train set accuracy: 100.00%
pool set accuracy: 44.38%
test set accuracy: 44.37%
train set accuracy: 99.70%
pool set accuracy: 43.96%
test set accuracy: 43.90%
train set accuracy: 99.45%
pool set accuracy: 43.47%
test set accuracy: 42.73%
train set accuracy: 99.19%
pool set accuracy: 43.23%
test set accuracy: 43.10%
train set accuracy: 99.73%
pool set accuracy: 43.74%
test set accuracy: 43.37%
train set accuracy: 99.47%
pool set accuracy: 43.39%
test set accuracy: 43.78%
train set accuracy: 99.83%
pool set accuracy: 45.58%
test set accuracy: 45.03%
train set accuracy: 96.89%
pool set accuracy: 44.22%
test set accuracy: 44.20%
train set accuracy: 99.97%
pool set accuracy: 45.64%
test set accuracy: 46.04%
train set accuracy: 100.00%
pool set accuracy: 46.92%
test set accuracy: 46.88%
train set accuracy: 97.67%
pool set accuracy: 44.22%
test set accuracy: 45.09%
train set accuracy: 99.93%
pool set accuracy: 47.12%
test set accuracy: 47.29%
train set accuracy: 98.61%
pool set accuracy: 46.78%
test set accuracy: 46.44%
train set accuracy: 99.88%
pool set accuracy: 46.82%
test set accuracy: 46.79%
train set accuracy: 99.14%
pool set accuracy: 46.20%
test set accuracy: 46.61%
train set accuracy: 99.68%
pool set accuracy: 47.55%
test set accuracy: 47.23%
train set accuracy: 99.53%
pool set accuracy: 46.96%
test set accuracy: 47.06%
train set accuracy: 99.26%
pool set accuracy: 46.62%
test set accuracy: 46.71%
train set accuracy: 99.04%
pool set accuracy: 46.56%
test set accuracy: 46.51%
train set accuracy: 98.42%
pool set accuracy: 48.53%
test set accuracy: 48.99%
train set accuracy: 99.96%
pool set accuracy: 48.66%
test set accuracy: 48.12%
train set accuracy: 99.06%
pool set accuracy: 47.45%
test set accuracy: 48.01%
train set accuracy: 99.55%
pool set accuracy: 47.75%
test set accuracy: 47.28%
train set accuracy: 98.94%
pool set accuracy: 48.38%
test set accuracy: 48.28%
train set accuracy: 99.02%
pool set accuracy: 47.15%
test set accuracy: 47.45%
train set accuracy: 99.15%
pool set accuracy: 48.05%
test set accuracy: 48.33%
train set accuracy: 99.93%
92%|█████████▏| 46/50 [09:30<01:18, 19.56s/it]
pool set accuracy: 49.98%
test set accuracy: 49.54%
train set accuracy: 99.30%
94%|█████████▍| 47/50 [09:50<00:59, 19.87s/it]
pool set accuracy: 48.43%
test set accuracy: 47.73%
train set accuracy: 98.98%
96%|█████████▌| 48/50 [10:11<00:40, 20.21s/it]
pool set accuracy: 48.38%
test set accuracy: 49.05%
train set accuracy: 99.21%
98%|█████████▊| 49/50 [10:33<00:20, 20.59s/it]
pool set accuracy: 48.79%
test set accuracy: 49.01%
train set accuracy: 99.51%
100%|██████████| 50/50 [10:55<00:00, 13.11s/it]
pool set accuracy: 50.57%
test set accuracy: 50.98%
train set accuracy: 98.67%
pool set accuracy: 50.54%
test set accuracy: 50.91%
train length: torch.Size([6000]) pool length: torch.Size([34000])
RandomStrategy For seed: 2 -----------------------------------------
0%| | 0/50 [00:00<?, ?it/s]
train set accuracy: 100.00%
2%|▏ | 1/50 [00:04<03:27, 4.23s/it]
pool set accuracy: 37.07%
test set accuracy: 36.98%
train set accuracy: 100.00%
4%|▍ | 2/50 [00:08<03:36, 4.51s/it]
pool set accuracy: 36.33%
test set accuracy: 36.39%
train set accuracy: 100.00%
6%|▌ | 3/50 [00:14<03:45, 4.80s/it]
pool set accuracy: 37.08%
test set accuracy: 37.59%
train set accuracy: 99.62%
8%|▊ | 4/50 [00:19<03:55, 5.12s/it]
pool set accuracy: 36.81%
test set accuracy: 36.99%
train set accuracy: 100.00%
10%|█ | 5/50 [00:25<03:56, 5.26s/it]
pool set accuracy: 38.37%
test set accuracy: 38.68%
train set accuracy: 100.00%
12%|█▏ | 6/50 [00:31<04:03, 5.53s/it]
pool set accuracy: 38.73%
test set accuracy: 39.20%
train set accuracy: 99.44%
14%|█▍ | 7/50 [00:37<04:10, 5.84s/it]
pool set accuracy: 38.01%
test set accuracy: 37.38%
train set accuracy: 99.35%
16%|█▌ | 8/50 [00:44<04:20, 6.20s/it]
pool set accuracy: 38.35%
test set accuracy: 38.17%
train set accuracy: 99.00%
18%|█▊ | 9/50 [00:52<04:28, 6.55s/it]
pool set accuracy: 39.31%
test set accuracy: 39.16%
train set accuracy: 100.00%
20%|██ | 10/50 [00:59<04:32, 6.80s/it]
pool set accuracy: 40.93%
test set accuracy: 41.33%
train set accuracy: 96.60%
22%|██▏ | 11/50 [01:07<04:39, 7.18s/it]
pool set accuracy: 39.05%
test set accuracy: 39.37%
train set accuracy: 99.14%
24%|██▍ | 12/50 [01:15<04:46, 7.54s/it]
pool set accuracy: 40.25%
test set accuracy: 40.12%
train set accuracy: 99.68%
26%|██▌ | 13/50 [01:24<04:54, 7.95s/it]
pool set accuracy: 41.62%
test set accuracy: 41.22%
train set accuracy: 99.09%
28%|██▊ | 14/50 [01:33<04:56, 8.24s/it]
pool set accuracy: 40.56%
test set accuracy: 40.86%
train set accuracy: 99.96%
30%|███ | 15/50 [01:42<04:58, 8.54s/it]
pool set accuracy: 42.18%
test set accuracy: 42.26%
train set accuracy: 99.56%
32%|███▏ | 16/50 [01:52<05:02, 8.89s/it]
pool set accuracy: 43.10%
test set accuracy: 43.32%
train set accuracy: 98.46%
34%|███▍ | 17/50 [02:02<05:06, 9.28s/it]
pool set accuracy: 41.04%
test set accuracy: 41.33%
train set accuracy: 99.81%
36%|███▌ | 18/50 [02:13<05:08, 9.64s/it]
pool set accuracy: 43.68%
test set accuracy: 43.55%
train set accuracy: 99.96%
38%|███▊ | 19/50 [02:23<05:07, 9.92s/it]
pool set accuracy: 43.65%
test set accuracy: 43.53%
train set accuracy: 99.28%
40%|████ | 20/50 [02:34<05:06, 10.22s/it]
pool set accuracy: 42.56%
test set accuracy: 43.53%
train set accuracy: 100.00%
42%|████▏ | 21/50 [02:46<05:08, 10.62s/it]
pool set accuracy: 45.20%
test set accuracy: 45.12%
train set accuracy: 99.52%
44%|████▍ | 22/50 [02:58<05:09, 11.05s/it]
pool set accuracy: 44.52%
test set accuracy: 44.42%
train set accuracy: 99.88%
46%|████▌ | 23/50 [03:10<05:05, 11.32s/it]
pool set accuracy: 44.88%
test set accuracy: 45.14%
train set accuracy: 98.61%
48%|████▊ | 24/50 [03:22<05:02, 11.63s/it]
pool set accuracy: 43.54%
test set accuracy: 43.58%
train set accuracy: 100.00%
50%|█████ | 25/50 [03:35<04:59, 12.00s/it]
pool set accuracy: 46.71%
test set accuracy: 46.47%
train set accuracy: 99.69%
52%|█████▏ | 26/50 [03:48<04:56, 12.37s/it]
pool set accuracy: 45.08%
test set accuracy: 45.17%
train set accuracy: 99.19%
54%|█████▍ | 27/50 [04:02<04:53, 12.75s/it]
pool set accuracy: 46.53%
test set accuracy: 46.11%
train set accuracy: 99.46%
56%|█████▌ | 28/50 [04:16<04:47, 13.06s/it]
pool set accuracy: 46.35%
test set accuracy: 45.52%
train set accuracy: 99.87%
58%|█████▊ | 29/50 [04:30<04:42, 13.44s/it]
pool set accuracy: 46.80%
test set accuracy: 47.13%
train set accuracy: 99.08%
60%|██████ | 30/50 [04:45<04:36, 13.83s/it]
pool set accuracy: 47.27%
test set accuracy: 46.46%
train set accuracy: 99.38%
62%|██████▏ | 31/50 [05:00<04:30, 14.25s/it]
pool set accuracy: 47.46%
test set accuracy: 47.09%
train set accuracy: 95.15%
64%|██████▍ | 32/50 [05:15<04:23, 14.64s/it]
pool set accuracy: 46.45%
test set accuracy: 46.71%
train set accuracy: 99.43%
66%|██████▌ | 33/50 [05:31<04:13, 14.91s/it]
pool set accuracy: 47.03%
test set accuracy: 47.15%
train set accuracy: 98.63%
68%|██████▊ | 34/50 [05:47<04:03, 15.23s/it]
pool set accuracy: 45.75%
test set accuracy: 45.94%
train set accuracy: 98.41%
70%|███████ | 35/50 [06:04<03:54, 15.66s/it]
pool set accuracy: 47.16%
test set accuracy: 47.42%
train set accuracy: 99.53%
72%|███████▏ | 36/50 [06:21<03:45, 16.08s/it]
pool set accuracy: 47.53%
test set accuracy: 47.82%
train set accuracy: 99.13%
74%|███████▍ | 37/50 [06:38<03:33, 16.39s/it]
pool set accuracy: 48.05%
test set accuracy: 47.92%
train set accuracy: 99.26%
76%|███████▌ | 38/50 [06:55<03:20, 16.70s/it]
pool set accuracy: 46.91%
test set accuracy: 47.00%
train set accuracy: 99.52%
78%|███████▊ | 39/50 [07:13<03:07, 17.07s/it]
pool set accuracy: 47.27%
test set accuracy: 46.81%
train set accuracy: 99.00%
80%|████████ | 40/50 [07:31<02:54, 17.44s/it]
pool set accuracy: 48.35%
test set accuracy: 48.15%
train set accuracy: 97.28%
82%|████████▏ | 41/50 [07:50<02:41, 17.90s/it]
pool set accuracy: 47.87%
test set accuracy: 48.11%
train set accuracy: 98.92%
84%|████████▍ | 42/50 [08:10<02:26, 18.30s/it]
pool set accuracy: 49.28%
test set accuracy: 48.49%
train set accuracy: 99.60%
86%|████████▌ | 43/50 [08:29<02:10, 18.62s/it]
pool set accuracy: 49.51%
test set accuracy: 49.32%
train set accuracy: 99.32%
88%|████████▊ | 44/50 [08:49<01:53, 18.95s/it]
pool set accuracy: 48.38%
test set accuracy: 48.70%
train set accuracy: 99.61%
90%|█████████ | 45/50 [09:09<01:36, 19.33s/it]
pool set accuracy: 49.62%
test set accuracy: 49.55%
train set accuracy: 99.64%
92%|█████████▏| 46/50 [09:29<01:18, 19.61s/it]
pool set accuracy: 48.61%
test set accuracy: 48.56%
train set accuracy: 98.88%
94%|█████████▍| 47/50 [09:50<00:59, 19.94s/it]
pool set accuracy: 48.15%
test set accuracy: 48.02%
train set accuracy: 99.18%
96%|█████████▌| 48/50 [10:11<00:40, 20.20s/it]
pool set accuracy: 50.68%
test set accuracy: 50.92%
train set accuracy: 98.97%
98%|█████████▊| 49/50 [10:32<00:20, 20.52s/it]
pool set accuracy: 50.27%
test set accuracy: 51.03%
train set accuracy: 99.80%
100%|██████████| 50/50 [10:54<00:00, 13.09s/it]
pool set accuracy: 50.67%
test set accuracy: 51.59%
train set accuracy: 98.83%
pool set accuracy: 50.75%
test set accuracy: 51.61%
train length: torch.Size([6000]) pool length: torch.Size([34000])
RandomStrategy For seed: 3 -----------------------------------------
0%| | 0/50 [00:00<?, ?it/s]
train set accuracy: 99.80%
2%|▏ | 1/50 [00:04<03:29, 4.28s/it]
pool set accuracy: 35.89%
test set accuracy: 35.80%
train set accuracy: 99.82%
4%|▍ | 2/50 [00:09<03:37, 4.54s/it]
pool set accuracy: 37.30%
test set accuracy: 37.31%
train set accuracy: 100.00%
6%|▌ | 3/50 [00:14<03:46, 4.82s/it]
pool set accuracy: 36.98%
test set accuracy: 37.20%
train set accuracy: 99.08%
8%|▊ | 4/50 [00:19<03:55, 5.12s/it]
pool set accuracy: 37.41%
test set accuracy: 37.72%
train set accuracy: 100.00%
10%|█ | 5/50 [00:25<03:58, 5.31s/it]
pool set accuracy: 38.41%
test set accuracy: 37.81%
train set accuracy: 100.00%
12%|█▏ | 6/50 [00:31<04:05, 5.57s/it]
pool set accuracy: 38.99%
test set accuracy: 38.59%
train set accuracy: 99.06%
14%|█▍ | 7/50 [00:37<04:11, 5.86s/it]
pool set accuracy: 38.25%
test set accuracy: 38.35%
train set accuracy: 100.00%
16%|█▌ | 8/50 [00:44<04:20, 6.20s/it]
pool set accuracy: 39.57%
test set accuracy: 39.32%
train set accuracy: 97.83%
18%|█▊ | 9/50 [00:52<04:28, 6.56s/it]
pool set accuracy: 38.43%
test set accuracy: 38.65%
train set accuracy: 96.53%
20%|██ | 10/50 [00:59<04:32, 6.81s/it]
pool set accuracy: 39.05%
test set accuracy: 38.99%
train set accuracy: 100.00%
22%|██▏ | 11/50 [01:07<04:37, 7.13s/it]
pool set accuracy: 40.94%
test set accuracy: 41.22%
train set accuracy: 99.33%
24%|██▍ | 12/50 [01:15<04:43, 7.47s/it]
pool set accuracy: 40.68%
test set accuracy: 40.56%
train set accuracy: 99.95%
26%|██▌ | 13/50 [01:24<04:50, 7.86s/it]
pool set accuracy: 42.06%
test set accuracy: 42.33%
train set accuracy: 100.00%
28%|██▊ | 14/50 [01:33<04:53, 8.14s/it]
pool set accuracy: 41.50%
test set accuracy: 42.03%
train set accuracy: 99.83%
30%|███ | 15/50 [01:42<04:55, 8.45s/it]
pool set accuracy: 41.11%
test set accuracy: 41.32%
train set accuracy: 99.68%
32%|███▏ | 16/50 [01:52<04:59, 8.82s/it]
pool set accuracy: 42.66%
test set accuracy: 42.49%
train set accuracy: 99.27%
34%|███▍ | 17/50 [02:02<05:04, 9.22s/it]
pool set accuracy: 42.14%
test set accuracy: 42.21%
train set accuracy: 97.74%
36%|███▌ | 18/50 [02:12<05:08, 9.63s/it]
pool set accuracy: 41.28%
test set accuracy: 40.80%
train set accuracy: 100.00%
38%|███▊ | 19/50 [02:23<05:07, 9.91s/it]
pool set accuracy: 45.58%
test set accuracy: 46.01%
train set accuracy: 98.24%
40%|████ | 20/50 [02:34<05:06, 10.22s/it]
pool set accuracy: 41.72%
test set accuracy: 42.25%
train set accuracy: 99.67%
42%|████▏ | 21/50 [02:45<05:07, 10.61s/it]
pool set accuracy: 42.78%
test set accuracy: 42.68%
train set accuracy: 99.23%
44%|████▍ | 22/50 [02:57<05:08, 11.00s/it]
pool set accuracy: 43.62%
test set accuracy: 43.41%
train set accuracy: 98.53%
46%|████▌ | 23/50 [03:09<05:05, 11.31s/it]
pool set accuracy: 41.91%
test set accuracy: 41.89%
train set accuracy: 99.94%
48%|████▊ | 24/50 [03:22<05:03, 11.67s/it]
pool set accuracy: 44.48%
test set accuracy: 44.77%
train set accuracy: 99.32%
50%|█████ | 25/50 [03:35<05:00, 12.03s/it]
pool set accuracy: 44.60%
test set accuracy: 44.26%
train set accuracy: 99.91%
52%|█████▏ | 26/50 [03:48<04:56, 12.37s/it]
pool set accuracy: 46.16%
test set accuracy: 45.97%
train set accuracy: 99.61%
54%|█████▍ | 27/50 [04:01<04:52, 12.72s/it]
pool set accuracy: 45.07%
test set accuracy: 44.92%
train set accuracy: 99.35%
56%|█████▌ | 28/50 [04:15<04:46, 13.04s/it]
pool set accuracy: 45.71%
test set accuracy: 45.42%
train set accuracy: 99.79%
58%|█████▊ | 29/50 [04:29<04:40, 13.37s/it]
pool set accuracy: 45.21%
test set accuracy: 45.19%
train set accuracy: 99.28%
60%|██████ | 30/50 [04:44<04:35, 13.78s/it]
pool set accuracy: 45.39%
test set accuracy: 45.01%
train set accuracy: 98.10%
62%|██████▏ | 31/50 [04:59<04:30, 14.22s/it]
pool set accuracy: 45.05%
test set accuracy: 45.19%
train set accuracy: 98.37%
64%|██████▍ | 32/50 [05:15<04:23, 14.66s/it]
pool set accuracy: 47.32%
test set accuracy: 47.18%
train set accuracy: 99.83%
66%|██████▌ | 33/50 [05:31<04:14, 14.96s/it]
pool set accuracy: 47.89%
test set accuracy: 47.47%
train set accuracy: 99.77%
68%|██████▊ | 34/50 [05:46<04:03, 15.21s/it]
pool set accuracy: 46.24%
test set accuracy: 45.42%
train set accuracy: 98.86%
70%|███████ | 35/50 [06:03<03:53, 15.55s/it]
pool set accuracy: 46.10%
test set accuracy: 45.33%
train set accuracy: 99.47%
72%|███████▏ | 36/50 [06:20<03:43, 15.99s/it]
pool set accuracy: 47.35%
test set accuracy: 47.38%
train set accuracy: 99.80%
74%|███████▍ | 37/50 [06:37<03:31, 16.30s/it]
pool set accuracy: 47.29%
test set accuracy: 47.27%
train set accuracy: 99.45%
76%|███████▌ | 38/50 [06:54<03:19, 16.65s/it]
pool set accuracy: 47.19%
test set accuracy: 47.61%
train set accuracy: 98.71%
78%|███████▊ | 39/50 [07:12<03:07, 17.04s/it]
pool set accuracy: 46.17%
test set accuracy: 46.40%
train set accuracy: 99.86%
80%|████████ | 40/50 [07:30<02:53, 17.38s/it]
pool set accuracy: 48.77%
test set accuracy: 49.12%
train set accuracy: 99.98%
82%|████████▏ | 41/50 [07:49<02:40, 17.81s/it]
pool set accuracy: 50.20%
test set accuracy: 50.24%
train set accuracy: 99.53%
84%|████████▍ | 42/50 [08:08<02:24, 18.10s/it]
pool set accuracy: 47.43%
test set accuracy: 47.42%
train set accuracy: 98.56%
86%|████████▌ | 43/50 [08:27<02:08, 18.43s/it]
pool set accuracy: 48.58%
test set accuracy: 47.99%
train set accuracy: 98.85%
88%|████████▊ | 44/50 [08:47<01:52, 18.78s/it]
pool set accuracy: 48.15%
test set accuracy: 47.79%
train set accuracy: 98.65%
90%|█████████ | 45/50 [09:07<01:35, 19.17s/it]
pool set accuracy: 48.90%
test set accuracy: 48.65%
train set accuracy: 99.02%
92%|█████████▏| 46/50 [09:27<01:17, 19.47s/it]
pool set accuracy: 48.53%
test set accuracy: 48.22%
train set accuracy: 99.57%
94%|█████████▍| 47/50 [09:48<00:59, 19.88s/it]
pool set accuracy: 49.10%
test set accuracy: 48.96%
train set accuracy: 98.89%
96%|█████████▌| 48/50 [10:09<00:40, 20.20s/it]
pool set accuracy: 50.22%
test set accuracy: 50.13%
train set accuracy: 99.76%
98%|█████████▊| 49/50 [10:30<00:20, 20.56s/it]
pool set accuracy: 50.31%
test set accuracy: 50.01%
train set accuracy: 98.97%
100%|██████████| 50/50 [10:52<00:00, 13.05s/it]
pool set accuracy: 50.53%
test set accuracy: 49.58%
train set accuracy: 98.03%
pool set accuracy: 50.60%
test set accuracy: 49.55%
train length: torch.Size([6000]) pool length: torch.Size([34000])
RandomStrategy For seed: 4 -----------------------------------------
0%| | 0/50 [00:00<?, ?it/s]
train set accuracy: 100.00%
2%|▏ | 1/50 [00:04<03:30, 4.30s/it]
pool set accuracy: 37.15%
test set accuracy: 37.13%
train set accuracy: 98.73%
4%|▍ | 2/50 [00:09<03:37, 4.54s/it]
pool set accuracy: 35.28%
test set accuracy: 35.24%
train set accuracy: 100.00%
6%|▌ | 3/50 [00:14<03:45, 4.80s/it]
pool set accuracy: 39.10%
test set accuracy: 39.35%
train set accuracy: 97.69%
8%|▊ | 4/50 [00:19<03:53, 5.08s/it]
pool set accuracy: 36.62%
test set accuracy: 35.85%
train set accuracy: 100.00%
10%|█ | 5/50 [00:25<03:56, 5.25s/it]
pool set accuracy: 38.76%
test set accuracy: 39.25%
train set accuracy: 100.00%
12%|█▏ | 6/50 [00:31<04:02, 5.52s/it]
pool set accuracy: 39.26%
test set accuracy: 38.87%
train set accuracy: 100.00%
14%|█▍ | 7/50 [00:37<04:11, 5.84s/it]
pool set accuracy: 39.17%
test set accuracy: 39.69%
train set accuracy: 99.59%
16%|█▌ | 8/50 [00:44<04:19, 6.18s/it]
pool set accuracy: 39.60%
test set accuracy: 39.82%
train set accuracy: 96.50%
18%|█▊ | 9/50 [00:51<04:28, 6.55s/it]
pool set accuracy: 38.41%
test set accuracy: 38.51%
train set accuracy: 100.00%
20%|██ | 10/50 [00:59<04:31, 6.79s/it]
pool set accuracy: 40.59%
test set accuracy: 41.12%
train set accuracy: 100.00%
22%|██▏ | 11/50 [01:07<04:38, 7.13s/it]
pool set accuracy: 41.53%
test set accuracy: 41.60%
train set accuracy: 99.67%
24%|██▍ | 12/50 [01:15<04:44, 7.48s/it]
pool set accuracy: 40.49%
test set accuracy: 40.06%
train set accuracy: 99.95%
26%|██▌ | 13/50 [01:24<04:49, 7.84s/it]
pool set accuracy: 41.60%
test set accuracy: 41.17%
train set accuracy: 99.52%
28%|██▊ | 14/50 [01:32<04:51, 8.10s/it]
pool set accuracy: 40.09%
test set accuracy: 40.39%
train set accuracy: 100.00%
30%|███ | 15/50 [01:42<04:55, 8.43s/it]
pool set accuracy: 42.34%
test set accuracy: 42.35%
train set accuracy: 100.00%
32%|███▏ | 16/50 [01:51<05:00, 8.82s/it]
pool set accuracy: 41.83%
test set accuracy: 42.60%
train set accuracy: 99.46%
34%|███▍ | 17/50 [02:01<05:04, 9.23s/it]
pool set accuracy: 42.99%
test set accuracy: 42.93%
train set accuracy: 98.26%
36%|███▌ | 18/50 [02:12<05:08, 9.63s/it]
pool set accuracy: 41.61%
test set accuracy: 41.24%
train set accuracy: 99.96%
38%|███▊ | 19/50 [02:23<05:07, 9.92s/it]
pool set accuracy: 42.74%
test set accuracy: 43.03%
train set accuracy: 100.00%
40%|████ | 20/50 [02:34<05:06, 10.21s/it]
pool set accuracy: 43.67%
test set accuracy: 43.20%
train set accuracy: 100.00%
42%|████▏ | 21/50 [02:45<05:06, 10.56s/it]
pool set accuracy: 44.78%
test set accuracy: 45.42%
train set accuracy: 99.74%
44%|████▍ | 22/50 [02:57<05:05, 10.93s/it]
pool set accuracy: 43.89%
test set accuracy: 43.42%
train set accuracy: 99.72%
46%|████▌ | 23/50 [03:08<05:01, 11.18s/it]
pool set accuracy: 42.16%
test set accuracy: 42.58%
train set accuracy: 99.97%
48%|████▊ | 24/50 [03:21<05:00, 11.55s/it]
pool set accuracy: 44.74%
test set accuracy: 44.31%
train set accuracy: 99.53%
50%|█████ | 25/50 [03:34<04:58, 11.93s/it]
pool set accuracy: 44.83%
test set accuracy: 44.15%
train set accuracy: 99.83%
52%|█████▏ | 26/50 [03:47<04:55, 12.33s/it]
pool set accuracy: 44.77%
test set accuracy: 44.92%
train set accuracy: 99.72%
54%|█████▍ | 27/50 [04:00<04:52, 12.70s/it]
pool set accuracy: 45.96%
test set accuracy: 46.24%
train set accuracy: 98.95%
56%|█████▌ | 28/50 [04:14<04:45, 12.98s/it]
pool set accuracy: 43.96%
test set accuracy: 43.83%
train set accuracy: 99.47%
58%|█████▊ | 29/50 [04:28<04:39, 13.32s/it]
pool set accuracy: 45.95%
test set accuracy: 46.07%
train set accuracy: 99.87%
60%|██████ | 30/50 [04:43<04:34, 13.71s/it]
pool set accuracy: 46.82%
test set accuracy: 46.68%
train set accuracy: 99.18%
62%|██████▏ | 31/50 [04:58<04:29, 14.18s/it]
pool set accuracy: 45.98%
test set accuracy: 45.68%
train set accuracy: 98.71%
64%|██████▍ | 32/50 [05:14<04:22, 14.60s/it]
pool set accuracy: 45.79%
test set accuracy: 45.53%
train set accuracy: 99.33%
66%|██████▌ | 33/50 [05:29<04:13, 14.91s/it]
pool set accuracy: 47.30%
test set accuracy: 46.70%
train set accuracy: 97.79%
68%|██████▊ | 34/50 [05:45<04:03, 15.21s/it]
pool set accuracy: 45.08%
test set accuracy: 44.92%
train set accuracy: 99.11%
70%|███████ | 35/50 [06:01<03:52, 15.52s/it]
pool set accuracy: 45.37%
test set accuracy: 45.58%
train set accuracy: 99.98%
72%|███████▏ | 36/50 [06:18<03:42, 15.87s/it]
pool set accuracy: 47.57%
test set accuracy: 47.87%
train set accuracy: 99.59%
74%|███████▍ | 37/50 [06:35<03:30, 16.20s/it]
pool set accuracy: 48.28%
test set accuracy: 48.73%
train set accuracy: 99.94%
76%|███████▌ | 38/50 [06:53<03:18, 16.58s/it]
pool set accuracy: 47.44%
test set accuracy: 47.51%
train set accuracy: 99.35%
78%|███████▊ | 39/50 [07:10<03:06, 16.97s/it]
pool set accuracy: 48.37%
test set accuracy: 48.11%
train set accuracy: 99.82%
80%|████████ | 40/50 [07:29<02:53, 17.33s/it]
pool set accuracy: 48.07%
test set accuracy: 47.68%
train set accuracy: 98.80%
82%|████████▏ | 41/50 [07:47<02:39, 17.73s/it]
pool set accuracy: 47.57%
test set accuracy: 48.01%
train set accuracy: 99.47%
84%|████████▍ | 42/50 [08:06<02:24, 18.03s/it]
pool set accuracy: 48.10%
test set accuracy: 47.86%
train set accuracy: 99.40%
86%|████████▌ | 43/50 [08:25<02:08, 18.38s/it]
pool set accuracy: 47.59%
test set accuracy: 47.60%
train set accuracy: 99.74%
88%|████████▊ | 44/50 [08:45<01:52, 18.75s/it]
pool set accuracy: 48.26%
test set accuracy: 48.35%
train set accuracy: 98.04%
90%|█████████ | 45/50 [09:05<01:35, 19.15s/it]
pool set accuracy: 48.60%
test set accuracy: 48.44%
train set accuracy: 99.45%
92%|█████████▏| 46/50 [09:25<01:17, 19.36s/it]
pool set accuracy: 48.88%
test set accuracy: 48.99%
train set accuracy: 99.73%
94%|█████████▍| 47/50 [09:45<00:59, 19.68s/it]
pool set accuracy: 48.19%
test set accuracy: 48.47%
train set accuracy: 98.56%
96%|█████████▌| 48/50 [10:06<00:40, 20.07s/it]
pool set accuracy: 49.66%
test set accuracy: 49.45%
train set accuracy: 99.72%
98%|█████████▊| 49/50 [10:28<00:20, 20.44s/it]
pool set accuracy: 49.89%
test set accuracy: 49.45%
train set accuracy: 99.12%
100%|██████████| 50/50 [10:49<00:00, 13.00s/it]
pool set accuracy: 51.40%
test set accuracy: 50.55%
train set accuracy: 98.27%
pool set accuracy: 51.48%
test set accuracy: 50.55%
train length: torch.Size([6000]) pool length: torch.Size([34000])
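Every run finishes with 6000 training samples and 34000 pool samples. This matches the bookkeeping: 1000 initial labels plus 50 queries of 100 gives 1000 + 50 × 100 = 6000, and the 39000-sample pool loses the same 5000. The sketch below reproduces only that index bookkeeping for random acquisition, using Python's standard `random` module. `random_acquire` and `run_random_loop` are hypothetical names, not ASTRA's `RandomStrategy` API. The model retraining between queries is left out.

```python
import random


def random_acquire(pool_indices, n_query, rng):
    """Move n_query uniformly random indices out of the pool (hypothetical helper)."""
    n_query = min(n_query, len(pool_indices))
    picked = set(rng.sample(range(len(pool_indices)), n_query))
    selected = [idx for i, idx in enumerate(pool_indices) if i in picked]
    remaining = [idx for i, idx in enumerate(pool_indices) if i not in picked]
    return selected, remaining


def run_random_loop(n_train=1000, n_pool=39000, n_query=100, n_iter=50, seed=0):
    rng = random.Random(seed)  # one seeded generator per run, like the per-seed runs above
    indices = list(range(n_train + n_pool))
    rng.shuffle(indices)
    train_idx, pool_idx = indices[:n_train], indices[n_train:]
    for _ in range(n_iter):
        # The real loop would retrain the ResNet on train_idx at this point.
        new_idx, pool_idx = random_acquire(pool_idx, n_query, rng)
        train_idx = train_idx + new_idx
    return train_idx, pool_idx


train_idx, pool_idx = run_random_loop()
print(len(train_idx), len(pool_idx))  # 6000 34000
```

Because the sampled indices are removed from the pool at each step, the two sets stay disjoint. Their sizes change by exactly `n_query` per round, which is why the printed lengths are the same for every seed.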
{'untrain_acc': {'train': 0.07700000703334808,
'pool': 0.08366666734218597,
'test': 0.08579999953508377},
'train_1000': {'train': 1.0,
'pool': 0.36056411266326904,
'test': 0.3624500036239624},
'train_1000_pool_5000': {'train': 0.9920000433921814,
'pool': 0.4998205304145813,
'test': 0.4273499846458435},
'train_1000_pool_39000': {'train': 0.9960000514984131,
'pool': 0.994282066822052,
'test': 0.616100013256073},
'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0': {'train': 0.9881666302680969,
'pool': 0.47435295581817627,
'test': 0.4817499816417694}}
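The `DiversityStrategy` entry refers to the acquisition described in the introduction: pick the pool point that lies furthest in latent space from the training data. Below is a minimal, dependency-free sketch of that idea. It assumes the embeddings are plain lists of floats, for example penultimate-layer ResNet features. It also recomputes distances greedily after each pick, so that one batch of queries does not cluster in a single far-away region. Whether ASTRA's `DiversityStrategy` performs this greedy update is an assumption, and `furthest_point_query` is a hypothetical name.

```python
import math


def furthest_point_query(train_emb, pool_emb, n_query):
    """Greedy furthest-point (k-center style) selection in latent space.

    Sketch only: needs at least one train embedding; returns pool indices in pick order.
    """
    # Distance from every pool point to its nearest labeled (train) point.
    min_dist = [min(math.dist(p, t) for t in train_emb) for p in pool_emb]
    chosen, taken = [], set()
    for _ in range(min(n_query, len(pool_emb))):
        # Pick the unpicked pool point that is furthest from everything labeled so far.
        best = max((i for i in range(len(pool_emb)) if i not in taken),
                   key=lambda i: min_dist[i])
        chosen.append(best)
        taken.add(best)
        # Treat the pick as labeled: shrink each pool point's nearest-labeled distance.
        for i, p in enumerate(pool_emb):
            min_dist[i] = min(min_dist[i], math.dist(p, pool_emb[best]))
    return chosen


train = [[0.0, 0.0]]
pool = [[1.0, 0.0], [5.0, 0.0], [5.0, 1.0], [-4.0, 0.0]]
print(furthest_point_query(train, pool, n_query=2))  # [2, 3]
```

Take a single labeled point at the origin and a query of two. After (5, 1) is taken, its neighbour (5, 0) is no longer far from the labeled set, so the point on the opposite side, (-4, 0), wins. Ranking purely by the initial distances would pick the two neighbouring points instead.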
{'untrain_acc': {'train': 0.07700000703334808,
'pool': 0.08366666734218597,
'test': 0.08579999953508377},
'train_1000': {'train': 1.0,
'pool': 0.36056411266326904,
'test': 0.3624500036239624},
'train_1000_pool_5000': {'train': 0.9920000433921814,
'pool': 0.4998205304145813,
'test': 0.4273499846458435},
'train_1000_pool_39000': {'train': 0.9960000514984131,
'pool': 0.994282066822052,
'test': 0.616100013256073},
'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0': {'train': 0.9881666302680969,
'pool': 0.47435295581817627,
'test': 0.4817499816417694},
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_0': {'train': 0.9863333106040955,
'pool': 0.5091176629066467,
'test': 0.5091999769210815},
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_1': {'train': 0.9866666793823242,
'pool': 0.5054118037223816,
'test': 0.5090500116348267},
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_2': {'train': 0.9883333444595337,
'pool': 0.5074999928474426,
'test': 0.5161499977111816},
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_3': {'train': 0.9803333282470703,
'pool': 0.5059705972671509,
'test': 0.49549999833106995},
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_4': {'train': 0.9826666712760925,
'pool': 0.5147647261619568,
'test': 0.5054500102996826}}
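To compare strategies across seeds, you can group the final accuracies in this dictionary by the strategy name embedded in each key. The sketch below assumes keys end in `<Strategy>_seed_<k>`, as they do above, and uses only the standard library. `final_accuracy_by_strategy` is a hypothetical helper, not part of ASTRA.

```python
import re
import statistics

# Matches the tail of keys such as '..._iter_50_RandomStrategy_seed_3'.
KEY_PATTERN = re.compile(r"([A-Za-z]+Strategy)_seed_(\d+)$")


def final_accuracy_by_strategy(summary, split="test"):
    """Return {strategy: (mean, stdev, n_seeds)} for the chosen split."""
    grouped = {}
    for key, metrics in summary.items():
        match = KEY_PATTERN.search(key)
        if match:  # baseline keys such as 'train_1000' are skipped
            grouped.setdefault(match.group(1), []).append(metrics[split])
    return {
        name: (statistics.mean(vals),
               statistics.stdev(vals) if len(vals) > 1 else 0.0,
               len(vals))
        for name, vals in grouped.items()
    }
```

Calling `final_accuracy_by_strategy(accuracy_summary)` on the dictionary above gives these final test accuracies:

- The five RandomStrategy seeds have a mean of about 50.7%, with a standard deviation of about 0.75 points.
- The single DiversityStrategy seed reaches 48.2%.

With only one diversity run, read that gap as indicative rather than conclusive.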
file_path = "/home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/accuracy_summary_with_AL.json"
# with open(file_path, 'w') as json_file:
#     json.dump(accuracy_summary, json_file)
# print(f"Accuracy summary has been saved to {file_path}.")
# Load the accuracy summary from the JSON file
with open(file_path, 'r') as json_file:
    accuracy_summary = json.load(json_file)
print(f"Accuracy summary has been loaded from {file_path}.")
Accuracy summary has been loaded from /home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/accuracy_summary_with_AL.json.
{'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0': [{'train': 1.0,
'pool': 0.3618461489677429,
'test': 0.36434999108314514},
{'train': 1.0, 'pool': 0.3619794249534607, 'test': 0.3668999969959259},
{'train': 1.0, 'pool': 0.37159794569015503, 'test': 0.3714499771595001},
{'train': 0.9953846335411072,
'pool': 0.364806205034256,
'test': 0.3671000003814697},
{'train': 1.0, 'pool': 0.3876684010028839, 'test': 0.3871999979019165},
{'train': 1.0, 'pool': 0.39098700881004333, 'test': 0.3937999904155731},
{'train': 1.0, 'pool': 0.39205729961395264, 'test': 0.39089998602867126},
{'train': 0.9982352256774902,
'pool': 0.38067886233329773,
'test': 0.382099986076355},
{'train': 0.9961111545562744,
'pool': 0.38900521397590637,
'test': 0.3938499987125397},
{'train': 0.9999999403953552,
'pool': 0.397480309009552,
'test': 0.3989499807357788},
{'train': 1.0, 'pool': 0.3950789272785187, 'test': 0.3951999843120575},
{'train': 0.996666669845581,
'pool': 0.39124009013175964,
'test': 0.3917499780654907},
{'train': 1.0, 'pool': 0.4085714519023895, 'test': 0.40915000438690186},
{'train': 1.0, 'pool': 0.39838194847106934, 'test': 0.3985999822616577},
{'train': 0.9637500047683716,
'pool': 0.38851064443588257,
'test': 0.38464999198913574},
{'train': 0.9991999864578247,
'pool': 0.4029066562652588,
'test': 0.4052499830722809},
{'train': 0.9546154141426086,
'pool': 0.4025668501853943,
'test': 0.4056999981403351},
{'train': 0.9929629564285278,
'pool': 0.39745309948921204,
'test': 0.404449999332428},
{'train': 0.9960713982582092,
'pool': 0.39954301714897156,
'test': 0.4059999883174896},
{'train': 1.0, 'pool': 0.41716980934143066, 'test': 0.42239999771118164},
{'train': 0.9996666312217712,
'pool': 0.41810810565948486,
'test': 0.42225000262260437},
{'train': 0.9967742562294006,
'pool': 0.4106775224208832,
'test': 0.41999998688697815},
{'train': 0.9981249570846558,
'pool': 0.42331522703170776,
'test': 0.4236999750137329},
{'train': 0.9900000095367432,
'pool': 0.3979291319847107,
'test': 0.4129999876022339},
{'train': 0.994705855846405,
'pool': 0.4214754104614258,
'test': 0.42944997549057007},
{'train': 0.9974285364151001,
'pool': 0.40052053332328796,
'test': 0.4045499861240387},
{'train': 0.9930555820465088,
'pool': 0.4380769431591034,
'test': 0.43764999508857727},
{'train': 0.9956756830215454,
'pool': 0.419393926858902,
'test': 0.4224499762058258},
{'train': 0.9894736409187317,
'pool': 0.41287294030189514,
'test': 0.42159998416900635},
{'train': 0.9982050657272339,
'pool': 0.4328254759311676,
'test': 0.43584999442100525},
{'train': 0.9945000410079956,
'pool': 0.4129444360733032,
'test': 0.4221999943256378},
{'train': 0.9992682933807373,
'pool': 0.44217267632484436,
'test': 0.4493499994277954},
{'train': 0.9947618842124939,
'pool': 0.43572625517845154,
'test': 0.44349998235702515},
{'train': 0.987441897392273,
'pool': 0.42532214522361755,
'test': 0.429099977016449},
{'train': 0.9954545497894287,
'pool': 0.4360112249851227,
'test': 0.4420499801635742},
{'train': 0.9913333654403687,
'pool': 0.4505915343761444,
'test': 0.4615999758243561},
{'train': 0.9991304278373718,
'pool': 0.4479943513870239,
'test': 0.4564499855041504},
{'train': 0.9923403859138489,
'pool': 0.4277053773403168,
'test': 0.4378499984741211},
{'train': 0.9985417127609253,
'pool': 0.45193183422088623,
'test': 0.4591499865055084},
{'train': 0.9977551102638245,
'pool': 0.4616524279117584,
'test': 0.4721499979496002},
{'train': 0.9991999864578247,
'pool': 0.4643999934196472,
'test': 0.46539998054504395},
{'train': 0.9954901933670044,
'pool': 0.4556446969509125,
'test': 0.4645499885082245},
{'train': 0.998846173286438,
'pool': 0.4521264433860779,
'test': 0.4615999758243561},
{'train': 0.9898113012313843,
'pool': 0.43855908513069153,
'test': 0.44579997658729553},
{'train': 0.9900000095367432,
'pool': 0.45942196249961853,
'test': 0.46469998359680176},
{'train': 0.9934545159339905,
'pool': 0.46942028403282166,
'test': 0.4749999940395355},
{'train': 0.9991071224212646,
'pool': 0.4689534902572632,
'test': 0.4870999753475189},
{'train': 0.9984210729598999,
'pool': 0.4750145673751831,
'test': 0.4806499779224396},
{'train': 0.9967241883277893,
'pool': 0.4726608097553253,
'test': 0.48159998655319214},
{'train': 0.9984745979309082,
'pool': 0.4733137786388397,
'test': 0.4803999960422516},
{'train': 0.9881666302680969,
'pool': 0.47435295581817627,
'test': 0.4817499816417694}]}
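Each strategy's history is a list with one accuracy dictionary per recorded step, so you can read the peak and final values off it directly. Here is a small sketch that assumes that list layout. `summarize_history` is a hypothetical helper.

```python
def summarize_history(history, split="test"):
    """Return (best_index, best_value, final_value) for one per-step history list.

    Assumes each entry is a dict with 'train', 'pool' and 'test' accuracies,
    as in the lists above; best_index is 0-based.
    """
    values = [step[split] for step in history]
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index, values[best_index], values[-1]
```

The DiversityStrategy list above has 51 entries. Applied to that list, this picks out the 47th entry (index 46), with a test accuracy of about 48.7%. That is a little above the final value of about 48.2%.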
{'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0': [{'train': 1.0,
'pool': 0.3618461489677429,
'test': 0.36434999108314514},
{'train': 1.0, 'pool': 0.3619794249534607, 'test': 0.3668999969959259},
{'train': 1.0, 'pool': 0.37159794569015503, 'test': 0.3714499771595001},
{'train': 0.9953846335411072,
'pool': 0.364806205034256,
'test': 0.3671000003814697},
{'train': 1.0, 'pool': 0.3876684010028839, 'test': 0.3871999979019165},
{'train': 1.0, 'pool': 0.39098700881004333, 'test': 0.3937999904155731},
{'train': 1.0, 'pool': 0.39205729961395264, 'test': 0.39089998602867126},
{'train': 0.9982352256774902,
'pool': 0.38067886233329773,
'test': 0.382099986076355},
{'train': 0.9961111545562744,
'pool': 0.38900521397590637,
'test': 0.3938499987125397},
{'train': 0.9999999403953552,
'pool': 0.397480309009552,
'test': 0.3989499807357788},
{'train': 1.0, 'pool': 0.3950789272785187, 'test': 0.3951999843120575},
{'train': 0.996666669845581,
'pool': 0.39124009013175964,
'test': 0.3917499780654907},
{'train': 1.0, 'pool': 0.4085714519023895, 'test': 0.40915000438690186},
{'train': 1.0, 'pool': 0.39838194847106934, 'test': 0.3985999822616577},
{'train': 0.9637500047683716,
'pool': 0.38851064443588257,
'test': 0.38464999198913574},
{'train': 0.9991999864578247,
'pool': 0.4029066562652588,
'test': 0.4052499830722809},
{'train': 0.9546154141426086,
'pool': 0.4025668501853943,
'test': 0.4056999981403351},
{'train': 0.9929629564285278,
'pool': 0.39745309948921204,
'test': 0.404449999332428},
{'train': 0.9960713982582092,
'pool': 0.39954301714897156,
'test': 0.4059999883174896},
{'train': 1.0, 'pool': 0.41716980934143066, 'test': 0.42239999771118164},
{'train': 0.9996666312217712,
'pool': 0.41810810565948486,
'test': 0.42225000262260437},
{'train': 0.9967742562294006,
'pool': 0.4106775224208832,
'test': 0.41999998688697815},
{'train': 0.9981249570846558,
'pool': 0.42331522703170776,
'test': 0.4236999750137329},
{'train': 0.9900000095367432,
'pool': 0.3979291319847107,
'test': 0.4129999876022339},
{'train': 0.994705855846405,
'pool': 0.4214754104614258,
'test': 0.42944997549057007},
{'train': 0.9974285364151001,
'pool': 0.40052053332328796,
'test': 0.4045499861240387},
{'train': 0.9930555820465088,
'pool': 0.4380769431591034,
'test': 0.43764999508857727},
{'train': 0.9956756830215454,
'pool': 0.419393926858902,
'test': 0.4224499762058258},
{'train': 0.9894736409187317,
'pool': 0.41287294030189514,
'test': 0.42159998416900635},
{'train': 0.9982050657272339,
'pool': 0.4328254759311676,
'test': 0.43584999442100525},
{'train': 0.9945000410079956,
'pool': 0.4129444360733032,
'test': 0.4221999943256378},
{'train': 0.9992682933807373,
'pool': 0.44217267632484436,
'test': 0.4493499994277954},
{'train': 0.9947618842124939,
'pool': 0.43572625517845154,
'test': 0.44349998235702515},
{'train': 0.987441897392273,
'pool': 0.42532214522361755,
'test': 0.429099977016449},
{'train': 0.9954545497894287,
'pool': 0.4360112249851227,
'test': 0.4420499801635742},
{'train': 0.9913333654403687,
'pool': 0.4505915343761444,
'test': 0.4615999758243561},
{'train': 0.9991304278373718,
'pool': 0.4479943513870239,
'test': 0.4564499855041504},
{'train': 0.9923403859138489,
'pool': 0.4277053773403168,
'test': 0.4378499984741211},
{'train': 0.9985417127609253,
'pool': 0.45193183422088623,
'test': 0.4591499865055084},
{'train': 0.9977551102638245,
'pool': 0.4616524279117584,
'test': 0.4721499979496002},
{'train': 0.9991999864578247,
'pool': 0.4643999934196472,
'test': 0.46539998054504395},
{'train': 0.9954901933670044,
'pool': 0.4556446969509125,
'test': 0.4645499885082245},
{'train': 0.998846173286438,
'pool': 0.4521264433860779,
'test': 0.4615999758243561},
{'train': 0.9898113012313843,
'pool': 0.43855908513069153,
'test': 0.44579997658729553},
{'train': 0.9900000095367432,
'pool': 0.45942196249961853,
'test': 0.46469998359680176},
{'train': 0.9934545159339905,
'pool': 0.46942028403282166,
'test': 0.4749999940395355},
{'train': 0.9991071224212646,
'pool': 0.4689534902572632,
'test': 0.4870999753475189},
{'train': 0.9984210729598999,
'pool': 0.4750145673751831,
'test': 0.4806499779224396},
{'train': 0.9967241883277893,
'pool': 0.4726608097553253,
'test': 0.48159998655319214},
{'train': 0.9984745979309082,
'pool': 0.4733137786388397,
'test': 0.4803999960422516},
{'train': 0.9881666302680969,
'pool': 0.47435295581817627,
'test': 0.4817499816417694}],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_0': [{'train': 0.999000072479248,
'pool': 0.3675128221511841,
'test': 0.3728500008583069},
{'train': 1.0, 'pool': 0.37133675813674927, 'test': 0.36949998140335083},
{'train': 1.0, 'pool': 0.3735567033290863, 'test': 0.37299999594688416},
{'train': 0.9930769205093384,
'pool': 0.36824291944503784,
'test': 0.36890000104904175},
{'train': 1.0, 'pool': 0.3925906717777252, 'test': 0.39489999413490295},
{'train': 1.0, 'pool': 0.40210390090942383, 'test': 0.40174999833106995},
{'train': 0.9993749856948853,
'pool': 0.40036460757255554,
'test': 0.4025999903678894},
{'train': 0.9894117116928101,
'pool': 0.3873368203639984,
'test': 0.38509997725486755},
{'train': 0.991111159324646,
'pool': 0.39408376812934875,
'test': 0.3938499987125397},
{'train': 0.9999999403953552,
'pool': 0.4166141748428345,
'test': 0.4092999994754791},
{'train': 1.0, 'pool': 0.40684211254119873, 'test': 0.41054999828338623},
{'train': 0.9933333396911621,
'pool': 0.40791556239128113,
'test': 0.4034999907016754},
{'train': 0.9940909147262573,
'pool': 0.41952383518218994,
'test': 0.42374998331069946},
{'train': 1.0, 'pool': 0.4301856756210327, 'test': 0.4295499920845032},
{'train': 0.9979166984558105,
'pool': 0.42260637879371643,
'test': 0.413349986076355},
{'train': 0.9907999634742737,
'pool': 0.42471998929977417,
'test': 0.4295499920845032},
{'train': 1.0, 'pool': 0.45048126578330994, 'test': 0.44644999504089355},
{'train': 0.9914814829826355,
'pool': 0.4382841885089874,
'test': 0.4352499842643738},
{'train': 0.9646428227424622,
'pool': 0.4097042977809906,
'test': 0.4078499972820282},
{'train': 0.9806897044181824,
'pool': 0.42175203561782837,
'test': 0.42100000381469727},
{'train': 0.9599999785423279,
'pool': 0.42367565631866455,
'test': 0.42399999499320984},
{'train': 1.0, 'pool': 0.4565311670303345, 'test': 0.4589499831199646},
{'train': 0.9965624809265137,
'pool': 0.45220109820365906,
'test': 0.4542499780654907},
{'train': 1.0, 'pool': 0.4608174264431, 'test': 0.461899995803833},
{'train': 0.9999999403953552,
'pool': 0.455081969499588,
'test': 0.45739999413490295},
{'train': 0.9968571066856384,
'pool': 0.4596712291240692,
'test': 0.4596499800682068},
{'train': 0.9997222423553467,
'pool': 0.4709615409374237,
'test': 0.46114999055862427},
{'train': 0.9967567920684814,
'pool': 0.46294763684272766,
'test': 0.45719999074935913},
{'train': 0.9986841678619385,
'pool': 0.4745856523513794,
'test': 0.47244998812675476},
{'train': 0.99692302942276,
'pool': 0.46623268723487854,
'test': 0.4610999822616577},
{'train': 0.999500036239624,
'pool': 0.4758055508136749,
'test': 0.4761999845504761},
{'train': 0.9680488109588623,
'pool': 0.47164344787597656,
'test': 0.46984997391700745},
{'train': 0.9835714101791382,
'pool': 0.4681564271450043,
'test': 0.46949997544288635},
{'train': 0.992093026638031,
'pool': 0.47156864404678345,
'test': 0.46654999256134033},
{'train': 0.9761363863945007,
'pool': 0.45396068692207336,
'test': 0.4526999890804291},
{'train': 0.9971111416816711,
'pool': 0.47107040882110596,
'test': 0.47450000047683716},
{'train': 0.9956521987915039,
'pool': 0.47180789709091187,
'test': 0.47200000286102295},
{'train': 0.9970212578773499,
'pool': 0.4791501462459564,
'test': 0.4790499806404114},
{'train': 0.9922916889190674,
'pool': 0.4792329668998718,
'test': 0.47749999165534973},
{'train': 0.991224467754364,
'pool': 0.48575499653816223,
'test': 0.48374998569488525},
{'train': 0.9932000041007996,
'pool': 0.48111429810523987,
'test': 0.485649973154068},
{'train': 0.9815686345100403,
'pool': 0.47280803322792053,
'test': 0.47164997458457947},
{'train': 0.9959615468978882,
'pool': 0.49054598808288574,
'test': 0.4890500009059906},
{'train': 0.9883018732070923,
'pool': 0.4864265024662018,
'test': 0.4822999835014343},
{'train': 0.993148148059845,
'pool': 0.47979769110679626,
'test': 0.4822499752044678},
{'train': 0.9872726798057556,
'pool': 0.49869564175605774,
'test': 0.499099999666214},
{'train': 0.9960713982582092,
'pool': 0.49802327156066895,
'test': 0.5030499696731567},
{'train': 0.9873684644699097,
'pool': 0.4902332127094269,
'test': 0.48499998450279236},
{'train': 0.992241382598877,
'pool': 0.5022221803665161,
'test': 0.5019999742507935},
{'train': 0.993898332118988,
'pool': 0.5097947120666504,
'test': 0.5098000168800354},
{'train': 0.9863333106040955,
'pool': 0.5091176629066467,
'test': 0.5091999769210815}],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_1': [{'train': 1.0,
'pool': 0.3571794927120209,
'test': 0.36409997940063477},
{'train': 1.0, 'pool': 0.37393316626548767, 'test': 0.3739999830722809},
{'train': 1.0, 'pool': 0.37628865242004395, 'test': 0.37959998846054077},
{'train': 1.0, 'pool': 0.3801550567150116, 'test': 0.3772999942302704},
{'train': 1.0, 'pool': 0.39502590894699097, 'test': 0.3906500041484833},
{'train': 1.0, 'pool': 0.3978441655635834, 'test': 0.40084999799728394},
{'train': 1.0, 'pool': 0.39645835757255554, 'test': 0.3959999978542328},
{'train': 0.9988234639167786,
'pool': 0.4005222022533417,
'test': 0.398749977350235},
{'train': 0.9905555844306946,
'pool': 0.39130887389183044,
'test': 0.39774999022483826},
{'train': 0.9999999403953552,
'pool': 0.41850391030311584,
'test': 0.41669997572898865},
{'train': 0.9855000376701355,
'pool': 0.394473671913147,
'test': 0.39879998564720154},
{'train': 1.0, 'pool': 0.4205540716648102, 'test': 0.4186999797821045},
{'train': 0.9831818342208862,
'pool': 0.4047619104385376,
'test': 0.40219998359680176},
{'train': 0.9960869550704956,
'pool': 0.40973472595214844,
'test': 0.41179999709129333},
{'train': 0.9979166984558105,
'pool': 0.41430848836898804,
'test': 0.4219000041484833},
{'train': 1.0, 'pool': 0.42322665452957153, 'test': 0.4186500012874603},
{'train': 0.9984615445137024,
'pool': 0.4232085347175598,
'test': 0.42170000076293945},
{'train': 0.9996296167373657,
'pool': 0.44037532806396484,
'test': 0.44110000133514404},
{'train': 0.9996428489685059,
'pool': 0.439301073551178,
'test': 0.43184998631477356},
{'train': 1.0, 'pool': 0.44382748007774353, 'test': 0.44369998574256897},
{'train': 0.996999979019165,
'pool': 0.43956756591796875,
'test': 0.4389999806880951},
{'train': 0.9945161938667297,
'pool': 0.4347154498100281,
'test': 0.42729997634887695},
{'train': 0.9918749928474426,
'pool': 0.4323098063468933,
'test': 0.4309999942779541},
{'train': 0.9972727298736572,
'pool': 0.43735694885253906,
'test': 0.43369999527931213},
{'train': 0.994705855846405,
'pool': 0.43393445014953613,
'test': 0.43779999017715454},
{'train': 0.998285710811615,
'pool': 0.4557808041572571,
'test': 0.45034998655319214},
{'train': 0.9688889384269714,
'pool': 0.4421703517436981,
'test': 0.4420499801635742},
{'train': 0.9997297525405884,
'pool': 0.4564187228679657,
'test': 0.4603999853134155},
{'train': 0.9999999403953552,
'pool': 0.46919891238212585,
'test': 0.4688499867916107},
{'train': 0.9766666293144226,
'pool': 0.4421883821487427,
'test': 0.45089998841285706},
{'train': 0.999250054359436,
'pool': 0.4711666703224182,
'test': 0.4728999733924866},
{'train': 0.9860975742340088,
'pool': 0.46779942512512207,
'test': 0.4643999934196472},
{'train': 0.9988095164299011,
'pool': 0.46824023127555847,
'test': 0.46789997816085815},
{'train': 0.9913953542709351,
'pool': 0.46198880672454834,
'test': 0.4661499857902527},
{'train': 0.9968181848526001,
'pool': 0.47550562024116516,
'test': 0.4722999930381775},
{'train': 0.9953333735466003,
'pool': 0.4696337878704071,
'test': 0.470550000667572},
{'train': 0.9926087260246277,
'pool': 0.46621468663215637,
'test': 0.46709999442100525},
{'train': 0.9904255270957947,
'pool': 0.465552419424057,
'test': 0.46514999866485596},
{'train': 0.98416668176651,
'pool': 0.48528409004211426,
'test': 0.4899500012397766},
{'train': 0.9995918273925781,
'pool': 0.4865812063217163,
'test': 0.48124998807907104},
{'train': 0.9905999898910522,
'pool': 0.47451427578926086,
'test': 0.48009997606277466},
{'train': 0.9954901933670044,
'pool': 0.47750717401504517,
'test': 0.47279998660087585},
{'train': 0.9894230961799622,
'pool': 0.4837643802165985,
'test': 0.4827499985694885},
{'train': 0.9901886582374573,
'pool': 0.4715273678302765,
'test': 0.4745499789714813},
{'train': 0.9914814829826355,
'pool': 0.4805491268634796,
'test': 0.4833499789237976},
{'train': 0.9992727041244507,
'pool': 0.4998260736465454,
'test': 0.49539998173713684},
{'train': 0.9930356740951538,
'pool': 0.4843023419380188,
'test': 0.4772999882698059},
{'train': 0.9898245930671692,
'pool': 0.48376092314720154,
'test': 0.49049997329711914},
{'train': 0.9920690059661865,
'pool': 0.48786547780036926,
'test': 0.49014997482299805},
{'train': 0.9950847625732422,
'pool': 0.505659818649292,
'test': 0.5098499655723572},
{'train': 0.9866666793823242,
'pool': 0.5054118037223816,
'test': 0.5090500116348267}],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_2': [{'train': 1.0,
'pool': 0.37071794271469116,
'test': 0.3698499798774719},
{'train': 1.0, 'pool': 0.36334189772605896, 'test': 0.3639499843120575},
{'train': 1.0, 'pool': 0.3708247244358063, 'test': 0.3759499788284302},
{'train': 0.9961538314819336,
'pool': 0.3681136965751648,
'test': 0.3698999881744385},
{'train': 1.0, 'pool': 0.38370466232299805, 'test': 0.38679999113082886},
{'train': 1.0, 'pool': 0.38727274537086487, 'test': 0.3919999897480011},
{'train': 0.9943749904632568,
'pool': 0.3801041841506958,
'test': 0.37379997968673706},
{'train': 0.9935293793678284,
'pool': 0.3834986984729767,
'test': 0.38169997930526733},
{'train': 0.9900000095367432,
'pool': 0.39311516284942627,
'test': 0.39159998297691345},
{'train': 0.9999999403953552,
'pool': 0.40926507115364075,
'test': 0.41324999928474426},
{'train': 0.9660000205039978,
'pool': 0.3904999792575836,
'test': 0.3937000036239624},
{'train': 0.991428554058075,
'pool': 0.40250658988952637,
'test': 0.40119999647140503},
{'train': 0.9968181848526001,
'pool': 0.41624340415000916,
'test': 0.412200003862381},
{'train': 0.9908695816993713,
'pool': 0.405649870634079,
'test': 0.4086499810218811},
{'train': 0.99958336353302,
'pool': 0.421755313873291,
'test': 0.4225499927997589},
{'train': 0.9955999851226807,
'pool': 0.4310133457183838,
'test': 0.43324998021125793},
{'train': 0.9846153855323792,
'pool': 0.41037431359291077,
'test': 0.4132999777793884},
{'train': 0.9981481432914734,
'pool': 0.43683648109436035,
'test': 0.43549999594688416},
{'train': 0.9996428489685059,
'pool': 0.43647849559783936,
'test': 0.4353500008583069},
{'train': 0.9927586317062378,
'pool': 0.4255795180797577,
'test': 0.4353500008583069},
{'train': 1.0, 'pool': 0.45197296142578125, 'test': 0.4511999785900116},
{'train': 0.9951613545417786,
'pool': 0.4452032744884491,
'test': 0.44415000081062317},
{'train': 0.9987499713897705,
'pool': 0.4488315284252167,
'test': 0.45135000348091125},
{'train': 0.986060619354248,
'pool': 0.43539509177207947,
'test': 0.4357999861240387},
{'train': 0.9999999403953552,
'pool': 0.46707651019096375,
'test': 0.46469998359680176},
{'train': 0.9968571066856384,
'pool': 0.450794517993927,
'test': 0.45170000195503235},
{'train': 0.9919444918632507,
'pool': 0.46527472138404846,
'test': 0.4610999822616577},
{'train': 0.9945946335792542,
'pool': 0.46347105503082275,
'test': 0.4551999866962433},
{'train': 0.9986841678619385,
'pool': 0.4680110514163971,
'test': 0.4713499844074249},
{'train': 0.9907692074775696,
'pool': 0.47271469235420227,
'test': 0.4646499752998352},
{'train': 0.9937500357627869,
'pool': 0.4745555520057678,
'test': 0.4708999991416931},
{'train': 0.9514634013175964,
'pool': 0.46454036235809326,
'test': 0.46709999442100525},
{'train': 0.9942857027053833,
'pool': 0.4702514111995697,
'test': 0.47145000100135803},
{'train': 0.9862790703773499,
'pool': 0.45747900009155273,
'test': 0.45944997668266296},
{'train': 0.9840909242630005,
'pool': 0.47157302498817444,
'test': 0.4742499887943268},
{'train': 0.9953333735466003,
'pool': 0.47526758909225464,
'test': 0.4781999886035919},
{'train': 0.9913043975830078,
'pool': 0.4804519712924957,
'test': 0.4791499972343445},
{'train': 0.992553174495697,
'pool': 0.4690934717655182,
'test': 0.4699999988079071},
{'train': 0.9952083826065063,
'pool': 0.4726988673210144,
'test': 0.46814998984336853},
{'train': 0.9900000095367432,
'pool': 0.48353275656700134,
'test': 0.48144999146461487},
{'train': 0.9727999567985535,
'pool': 0.4786857068538666,
'test': 0.4810999929904938},
{'train': 0.9892156720161438,
'pool': 0.49275073409080505,
'test': 0.48489999771118164},
{'train': 0.9959615468978882,
'pool': 0.49514368176460266,
'test': 0.4931999742984772},
{'train': 0.993207573890686,
'pool': 0.48383286595344543,
'test': 0.4869999885559082},
{'train': 0.9961110949516296,
'pool': 0.4962427616119385,
'test': 0.4955499768257141},
{'train': 0.996363639831543,
'pool': 0.48614493012428284,
'test': 0.4855499863624573},
{'train': 0.9887499809265137,
'pool': 0.4814535081386566,
'test': 0.48019999265670776},
{'train': 0.991754412651062,
'pool': 0.5067930221557617,
'test': 0.5092499852180481},
{'train': 0.9896551966667175,
'pool': 0.5026608109474182,
'test': 0.5103499889373779},
{'train': 0.9979661107063293,
'pool': 0.5066568851470947,
'test': 0.5159499645233154},
{'train': 0.9883333444595337,
'pool': 0.5074999928474426,
'test': 0.5161499977111816}],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_3': [{'train': 0.9980000257492065,
'pool': 0.358923077583313,
'test': 0.3579999804496765},
{'train': 0.9981818199157715,
'pool': 0.37303340435028076,
'test': 0.3730999827384949},
{'train': 1.0, 'pool': 0.3697938024997711, 'test': 0.3720499873161316},
{'train': 0.9907692074775696,
'pool': 0.37408268451690674,
'test': 0.3771999776363373},
{'train': 1.0, 'pool': 0.3841450810432434, 'test': 0.3780499994754791},
{'train': 1.0, 'pool': 0.38989609479904175, 'test': 0.38589999079704285},
{'train': 0.9906249642372131,
'pool': 0.382500022649765,
'test': 0.38349997997283936},
{'train': 0.9999999403953552,
'pool': 0.3956919014453888,
'test': 0.39319998025894165},
{'train': 0.9783333539962769,
'pool': 0.38426700234413147,
'test': 0.3865499794483185},
{'train': 0.9652631282806396,
'pool': 0.3904724419116974,
'test': 0.3899499773979187},
{'train': 1.0, 'pool': 0.4093684256076813, 'test': 0.412200003862381},
{'train': 0.9933333396911621,
'pool': 0.4068337678909302,
'test': 0.405599981546402},
{'train': 0.9995454549789429,
'pool': 0.420634925365448,
'test': 0.4232499897480011},
{'train': 1.0, 'pool': 0.41496020555496216, 'test': 0.4202499985694885},
{'train': 0.9983333349227905,
'pool': 0.4111170172691345,
'test': 0.4131999909877777},
{'train': 0.9967999458312988,
'pool': 0.42661333084106445,
'test': 0.42489999532699585},
{'train': 0.9926922917366028,
'pool': 0.4214438498020172,
'test': 0.4220999777317047},
{'train': 0.9774073958396912,
'pool': 0.4128418266773224,
'test': 0.40799999237060547},
{'train': 1.0, 'pool': 0.4558333158493042, 'test': 0.46014997363090515},
{'train': 0.9824138283729553,
'pool': 0.41719675064086914,
'test': 0.42249998450279236},
{'train': 0.996666669845581,
'pool': 0.42781081795692444,
'test': 0.4267999827861786},
{'train': 0.9922581315040588,
'pool': 0.43620598316192627,
'test': 0.4341000020503998},
{'train': 0.9853124618530273,
'pool': 0.4191032648086548,
'test': 0.4188999831676483},
{'train': 0.9993939399719238,
'pool': 0.44476836919784546,
'test': 0.44769999384880066},
{'train': 0.9932352304458618,
'pool': 0.44603827595710754,
'test': 0.44259998202323914},
{'train': 0.9991428256034851,
'pool': 0.4615616500377655,
'test': 0.4597499966621399},
{'train': 0.9961111545562744,
'pool': 0.4507417678833008,
'test': 0.44919997453689575},
{'train': 0.9935135245323181,
'pool': 0.45713499188423157,
'test': 0.45419999957084656},
{'train': 0.9978947043418884,
'pool': 0.4520718455314636,
'test': 0.4518499970436096},
{'train': 0.9928205013275146,
'pool': 0.4539058208465576,
'test': 0.45009997487068176},
{'train': 0.9810000658035278,
'pool': 0.45047223567962646,
'test': 0.4518499970436096},
{'train': 0.9836585521697998,
'pool': 0.4731754660606384,
'test': 0.4717999994754791},
{'train': 0.9983333349227905,
'pool': 0.4789385497570038,
'test': 0.47474998235702515},
{'train': 0.9976744651794434,
'pool': 0.462380975484848,
'test': 0.45419999957084656},
{'train': 0.9886363744735718,
'pool': 0.46098315715789795,
'test': 0.45329999923706055},
{'train': 0.9946666955947876,
'pool': 0.4735211133956909,
'test': 0.47384998202323914},
{'train': 0.9980434775352478,
'pool': 0.47293785214424133,
'test': 0.4727499783039093},
{'train': 0.994468092918396,
'pool': 0.47186967730522156,
'test': 0.47609999775886536},
{'train': 0.987083375453949,
'pool': 0.46173295378685,
'test': 0.4640499949455261},
{'train': 0.9985714554786682,
'pool': 0.4876638352870941,
'test': 0.4912000000476837},
{'train': 0.9997999668121338,
'pool': 0.5019999742507935,
'test': 0.5024499893188477},
{'train': 0.9952940940856934,
'pool': 0.4742693603038788,
'test': 0.47419998049736023},
{'train': 0.9855769276618958,
'pool': 0.4858333468437195,
'test': 0.4799000024795532},
{'train': 0.9884905815124512,
'pool': 0.4814697504043579,
'test': 0.47794997692108154},
{'train': 0.9864814877510071,
'pool': 0.48895952105522156,
'test': 0.48649999499320984},
{'train': 0.9901818037033081,
'pool': 0.48533332347869873,
'test': 0.4821999967098236},
{'train': 0.9957142472267151,
'pool': 0.4910465180873871,
'test': 0.4896499812602997},
{'train': 0.9889473915100098,
'pool': 0.5021865963935852,
'test': 0.5012999773025513},
{'train': 0.9975862503051758,
'pool': 0.5030701756477356,
'test': 0.5001000165939331},
{'train': 0.9896610379219055,
'pool': 0.5053079128265381,
'test': 0.49584999680519104},
{'train': 0.9803333282470703,
'pool': 0.5059705972671509,
'test': 0.49549999833106995}],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_4': [{'train': 1.0,
'pool': 0.37148717045783997,
'test': 0.37129998207092285},
{'train': 0.9872727394104004,
'pool': 0.3528277575969696,
'test': 0.3524499833583832},
{'train': 1.0, 'pool': 0.3909793794155121, 'test': 0.39354997873306274},
{'train': 0.9769230484962463,
'pool': 0.36617571115493774,
'test': 0.35854998230934143},
{'train': 1.0, 'pool': 0.3875647783279419, 'test': 0.39249998331069946},
{'train': 1.0, 'pool': 0.39257141947746277, 'test': 0.3886999785900116},
{'train': 1.0, 'pool': 0.3916666805744171, 'test': 0.3968999981880188},
{'train': 0.9958823323249817,
'pool': 0.3959529995918274,
'test': 0.39819997549057007},
{'train': 0.9650000333786011,
'pool': 0.3841099441051483,
'test': 0.38509997725486755},
{'train': 0.9999999403953552,
'pool': 0.40590548515319824,
'test': 0.41119998693466187},
{'train': 1.0, 'pool': 0.41531577706336975, 'test': 0.41599997878074646},
{'train': 0.996666669845581,
'pool': 0.4048548638820648,
'test': 0.4006499946117401},
{'train': 0.9995454549789429,
'pool': 0.4159788489341736,
'test': 0.4117499887943268},
{'train': 0.9952173829078674,
'pool': 0.4009018540382385,
'test': 0.4039499759674072},
{'train': 1.0, 'pool': 0.42343083024024963, 'test': 0.4235000014305115},
{'train': 1.0, 'pool': 0.4182933270931244, 'test': 0.42604997754096985},
{'train': 0.994615375995636,
'pool': 0.42994651198387146,
'test': 0.4292999804019928},
{'train': 0.9825925827026367,
'pool': 0.416085809469223,
'test': 0.41234999895095825},
{'train': 0.9996428489685059,
'pool': 0.42744624614715576,
'test': 0.4302999973297119},
{'train': 1.0, 'pool': 0.43668463826179504, 'test': 0.4320499897003174},
{'train': 1.0, 'pool': 0.44783782958984375, 'test': 0.4542499780654907},
{'train': 0.9974194169044495,
'pool': 0.438943088054657,
'test': 0.43424999713897705},
{'train': 0.9971874952316284,
'pool': 0.4215760827064514,
'test': 0.4258500039577484},
{'train': 0.9996969699859619,
'pool': 0.44738417863845825,
'test': 0.44314998388290405},
{'train': 0.9952940940856934,
'pool': 0.4482786953449249,
'test': 0.44154998660087585},
{'train': 0.998285710811615,
'pool': 0.4477260112762451,
'test': 0.4491499960422516},
{'train': 0.9972222447395325,
'pool': 0.45956045389175415,
'test': 0.4624499976634979},
{'train': 0.9894594550132751,
'pool': 0.4395592212677002,
'test': 0.43834999203681946},
{'train': 0.9947367906570435,
'pool': 0.45947515964508057,
'test': 0.4607499837875366},
{'train': 0.9987179040908813,
'pool': 0.4682271480560303,
'test': 0.4668499827384949},
{'train': 0.9917500615119934,
'pool': 0.45980554819107056,
'test': 0.45684999227523804},
{'train': 0.9870731830596924,
'pool': 0.4579108655452728,
'test': 0.45534998178482056},
{'train': 0.9933333396911621,
'pool': 0.4729888439178467,
'test': 0.46699997782707214},
{'train': 0.9779070019721985,
'pool': 0.4508403539657593,
'test': 0.44919997453689575},
{'train': 0.991136372089386,
'pool': 0.45373594760894775,
'test': 0.4558499753475189},
{'train': 0.9997777938842773,
'pool': 0.475661963224411,
'test': 0.4786999821662903},
{'train': 0.9958695769309998,
'pool': 0.4827966094017029,
'test': 0.48729997873306274},
{'train': 0.9993616938591003,
'pool': 0.4744192659854889,
'test': 0.4750500023365021},
{'train': 0.9935417175292969,
'pool': 0.48366478085517883,
'test': 0.4810999929904938},
{'train': 0.9981632828712463,
'pool': 0.4806837737560272,
'test': 0.47679999470710754},
{'train': 0.9879999756813049,
'pool': 0.4757428467273712,
'test': 0.4801499843597412},
{'train': 0.994705855846405,
'pool': 0.48103153705596924,
'test': 0.47859999537467957},
{'train': 0.994038462638855,
'pool': 0.47591954469680786,
'test': 0.4760499894618988},
{'train': 0.997358500957489,
'pool': 0.4825936555862427,
'test': 0.4834499955177307},
{'train': 0.9803703427314758,
'pool': 0.4860404431819916,
'test': 0.4843999743461609},
{'train': 0.9945454001426697,
'pool': 0.4887826144695282,
'test': 0.48989999294281006},
{'train': 0.9973214268684387,
'pool': 0.4819476902484894,
'test': 0.48475000262260437},
{'train': 0.9856140613555908,
'pool': 0.4966180622577667,
'test': 0.49449998140335083},
{'train': 0.9972414374351501,
'pool': 0.4989473521709442,
'test': 0.49449998140335083},
{'train': 0.9911864399909973,
'pool': 0.5139882564544678,
'test': 0.5054999589920044},
{'train': 0.9826666712760925,
'pool': 0.5147647261619568,
'test': 0.5054500102996826}]}
import json
file_path = "/home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/accuracy_AL_list_diversity.json"
# Save the accuracy summary to JSON (run once after training, then keep commented out)
# with open(file_path, 'w') as json_file:
#     json.dump(accuracy_AL_list, json_file)
# print(f"Accuracy summary has been saved to {file_path}.")
# Load the accuracy summary from the JSON file
with open(file_path, 'r') as json_file:
    accuracy_AL_list = json.load(json_file)
print(f"Accuracy summary has been loaded from {file_path}.")
Accuracy summary has been loaded from /home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/accuracy_AL_list_diversity.json.
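Each key in `accuracy_AL_list` encodes the strategy and the seed (for example `..._RandomStrategy_seed_3`), and each value is a list of `{'train', 'pool', 'test'}` accuracies, one entry per active learning iteration. Comparing strategies therefore comes down to averaging the per-iteration accuracy over seeds. The following is a minimal sketch, assuming that key naming pattern and that every seed of a given strategy has the same number of iterations; the helper name `summarize_by_strategy` is hypothetical and not part of ASTRA.

```python
import re
from collections import defaultdict

import numpy as np


def summarize_by_strategy(accuracy_AL_list, split="test"):
    """Group runs by strategy name and return per-iteration (mean, std) over seeds.

    Assumes (hypothetically) that keys end in '_{SomethingStrategy}_seed_{k}' and
    that each value is a list of dicts with 'train', 'pool' and 'test' accuracies.
    """
    runs = defaultdict(list)
    for key, history in accuracy_AL_list.items():
        match = re.search(r"_([A-Za-z]+Strategy)_seed_\d+$", key)
        if match is None:
            continue  # skip keys that do not follow the assumed naming pattern
        runs[match.group(1)].append([step[split] for step in history])

    summary = {}
    for strategy, curves in runs.items():
        # Shape (num_seeds, num_iterations); requires equal-length histories.
        curves = np.asarray(curves, dtype=float)
        summary[strategy] = (curves.mean(axis=0), curves.std(axis=0))
    return summary
```

In the notebook this would be called as `summarize_by_strategy(accuracy_AL_list)`, and the returned means can be plotted against the iteration index (with the standard deviation as a shaded band) to compare the diversity and random strategies.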
{'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0': [tensor([102, 96, 97, 114, 101, 86, 107, 103, 83, 111], device='cuda:2'),
tensor([116, 111, 101, 128, 103, 101, 110, 120, 88, 122], device='cuda:2'),
tensor([126, 121, 105, 141, 107, 122, 115, 132, 97, 134], device='cuda:2'),
tensor([146, 145, 107, 155, 109, 133, 117, 137, 106, 145], device='cuda:2'),
tensor([156, 164, 115, 169, 109, 145, 122, 146, 111, 163], device='cuda:2'),
tensor([163, 177, 124, 180, 114, 163, 124, 161, 116, 178], device='cuda:2'),
tensor([171, 195, 132, 195, 117, 177, 139, 170, 118, 186], device='cuda:2'),
tensor([182, 208, 140, 207, 121, 189, 144, 185, 121, 203], device='cuda:2'),
tensor([200, 221, 146, 217, 125, 203, 148, 209, 126, 205], device='cuda:2'),
tensor([210, 231, 151, 228, 128, 210, 152, 240, 129, 221], device='cuda:2'),
tensor([220, 243, 155, 241, 131, 223, 156, 257, 132, 242], device='cuda:2'),
tensor([236, 260, 163, 255, 133, 229, 157, 271, 137, 259], device='cuda:2'),
tensor([244, 271, 171, 269, 137, 245, 163, 290, 141, 269], device='cuda:2'),
tensor([257, 284, 179, 278, 141, 258, 169, 304, 146, 284], device='cuda:2'),
tensor([265, 286, 189, 287, 146, 275, 174, 333, 151, 294], device='cuda:2'),
tensor([279, 314, 190, 288, 149, 285, 174, 344, 166, 311], device='cuda:2'),
tensor([285, 339, 191, 302, 153, 296, 185, 355, 171, 323], device='cuda:2'),
tensor([299, 361, 201, 311, 155, 311, 190, 363, 174, 335], device='cuda:2'),
tensor([308, 381, 210, 320, 157, 323, 195, 374, 177, 355], device='cuda:2'),
tensor([334, 393, 220, 329, 164, 329, 199, 383, 185, 364], device='cuda:2'),
tensor([336, 401, 224, 335, 169, 345, 207, 415, 191, 377], device='cuda:2'),
tensor([348, 413, 228, 346, 173, 359, 215, 431, 195, 392], device='cuda:2'),
tensor([364, 422, 241, 362, 175, 371, 223, 444, 200, 398], device='cuda:2'),
tensor([373, 434, 255, 370, 180, 384, 226, 460, 208, 410], device='cuda:2'),
tensor([387, 451, 260, 383, 183, 394, 229, 470, 220, 423], device='cuda:2'),
tensor([394, 467, 271, 393, 184, 409, 233, 482, 230, 437], device='cuda:2'),
tensor([399, 488, 283, 407, 186, 425, 236, 490, 237, 449], device='cuda:2'),
tensor([400, 493, 288, 423, 192, 451, 239, 512, 239, 463], device='cuda:2'),
tensor([413, 515, 295, 430, 198, 454, 245, 519, 247, 484], device='cuda:2'),
tensor([420, 526, 299, 444, 204, 464, 252, 544, 252, 495], device='cuda:2'),
tensor([427, 550, 308, 447, 207, 470, 263, 550, 263, 515], device='cuda:2'),
tensor([435, 559, 316, 454, 218, 484, 278, 565, 264, 527], device='cuda:2'),
tensor([449, 583, 329, 460, 224, 488, 282, 578, 271, 536], device='cuda:2'),
tensor([454, 594, 341, 471, 230, 497, 290, 605, 277, 541], device='cuda:2'),
tensor([463, 601, 346, 482, 233, 510, 309, 622, 281, 553], device='cuda:2'),
tensor([472, 610, 356, 490, 243, 519, 327, 640, 286, 557], device='cuda:2'),
tensor([480, 634, 363, 502, 247, 534, 329, 644, 294, 573], device='cuda:2'),
tensor([489, 653, 369, 507, 249, 541, 331, 656, 314, 591], device='cuda:2'),
tensor([499, 664, 373, 513, 257, 550, 335, 676, 324, 609], device='cuda:2'),
tensor([504, 669, 379, 526, 261, 568, 343, 698, 328, 624], device='cuda:2'),
tensor([518, 686, 389, 538, 266, 580, 348, 708, 338, 629], device='cuda:2'),
tensor([537, 700, 396, 547, 280, 586, 352, 717, 346, 639], device='cuda:2'),
tensor([550, 727, 403, 555, 281, 590, 356, 723, 352, 663], device='cuda:2'),
tensor([567, 745, 407, 564, 285, 597, 361, 735, 363, 676], device='cuda:2'),
tensor([570, 757, 410, 575, 291, 603, 365, 768, 369, 692], device='cuda:2'),
tensor([584, 772, 413, 583, 298, 612, 379, 784, 378, 697], device='cuda:2'),
tensor([594, 788, 421, 596, 305, 619, 390, 794, 386, 707], device='cuda:2'),
tensor([609, 810, 432, 599, 314, 624, 397, 801, 397, 717], device='cuda:2'),
tensor([618, 826, 437, 609, 317, 629, 405, 810, 410, 739], device='cuda:2'),
tensor([625, 843, 446, 622, 324, 639, 412, 824, 415, 750], device='cuda:2'),
tensor([634, 856, 451, 630, 328, 652, 425, 841, 418, 765], device='cuda:2')]}
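The per-class count history above is a dict of lists of CUDA tensors, which `json.dump` cannot serialize directly (it raises a `TypeError`). The list-valued dict printed next appears to be the same history after converting the tensors to plain Python lists. The following is a minimal sketch of such a conversion; the helper name `to_jsonable` is hypothetical and not part of ASTRA.

```python
import json

import torch


def to_jsonable(obj):
    """Recursively convert tensors (on any device) into plain Python lists/numbers."""
    if isinstance(obj, torch.Tensor):
        # Move to CPU first so tensors on 'cuda:N' convert cleanly.
        return obj.detach().cpu().tolist()
    if isinstance(obj, dict):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    return obj


# Small CPU-only demo with made-up counts
history = {"run_seed_0": [torch.tensor([102, 96]), torch.tensor([116, 111])]}
print(json.dumps(to_jsonable(history)))  # {"run_seed_0": [[102, 96], [116, 111]]}
```

Once converted, the history can be written with `json.dump` exactly like the accuracy summary and reloaded later without a GPU.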
{'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0': [[102,
96,
97,
114,
101,
86,
107,
103,
83,
111],
[116, 111, 101, 128, 103, 101, 110, 120, 88, 122],
[126, 121, 105, 141, 107, 122, 115, 132, 97, 134],
[146, 145, 107, 155, 109, 133, 117, 137, 106, 145],
[156, 164, 115, 169, 109, 145, 122, 146, 111, 163],
[163, 177, 124, 180, 114, 163, 124, 161, 116, 178],
[171, 195, 132, 195, 117, 177, 139, 170, 118, 186],
[182, 208, 140, 207, 121, 189, 144, 185, 121, 203],
[200, 221, 146, 217, 125, 203, 148, 209, 126, 205],
[210, 231, 151, 228, 128, 210, 152, 240, 129, 221],
[220, 243, 155, 241, 131, 223, 156, 257, 132, 242],
[236, 260, 163, 255, 133, 229, 157, 271, 137, 259],
[244, 271, 171, 269, 137, 245, 163, 290, 141, 269],
[257, 284, 179, 278, 141, 258, 169, 304, 146, 284],
[265, 286, 189, 287, 146, 275, 174, 333, 151, 294],
[279, 314, 190, 288, 149, 285, 174, 344, 166, 311],
[285, 339, 191, 302, 153, 296, 185, 355, 171, 323],
[299, 361, 201, 311, 155, 311, 190, 363, 174, 335],
[308, 381, 210, 320, 157, 323, 195, 374, 177, 355],
[334, 393, 220, 329, 164, 329, 199, 383, 185, 364],
[336, 401, 224, 335, 169, 345, 207, 415, 191, 377],
[348, 413, 228, 346, 173, 359, 215, 431, 195, 392],
[364, 422, 241, 362, 175, 371, 223, 444, 200, 398],
[373, 434, 255, 370, 180, 384, 226, 460, 208, 410],
[387, 451, 260, 383, 183, 394, 229, 470, 220, 423],
[394, 467, 271, 393, 184, 409, 233, 482, 230, 437],
[399, 488, 283, 407, 186, 425, 236, 490, 237, 449],
[400, 493, 288, 423, 192, 451, 239, 512, 239, 463],
[413, 515, 295, 430, 198, 454, 245, 519, 247, 484],
[420, 526, 299, 444, 204, 464, 252, 544, 252, 495],
[427, 550, 308, 447, 207, 470, 263, 550, 263, 515],
[435, 559, 316, 454, 218, 484, 278, 565, 264, 527],
[449, 583, 329, 460, 224, 488, 282, 578, 271, 536],
[454, 594, 341, 471, 230, 497, 290, 605, 277, 541],
[463, 601, 346, 482, 233, 510, 309, 622, 281, 553],
[472, 610, 356, 490, 243, 519, 327, 640, 286, 557],
[480, 634, 363, 502, 247, 534, 329, 644, 294, 573],
[489, 653, 369, 507, 249, 541, 331, 656, 314, 591],
[499, 664, 373, 513, 257, 550, 335, 676, 324, 609],
[504, 669, 379, 526, 261, 568, 343, 698, 328, 624],
[518, 686, 389, 538, 266, 580, 348, 708, 338, 629],
[537, 700, 396, 547, 280, 586, 352, 717, 346, 639],
[550, 727, 403, 555, 281, 590, 356, 723, 352, 663],
[567, 745, 407, 564, 285, 597, 361, 735, 363, 676],
[570, 757, 410, 575, 291, 603, 365, 768, 369, 692],
[584, 772, 413, 583, 298, 612, 379, 784, 378, 697],
[594, 788, 421, 596, 305, 619, 390, 794, 386, 707],
[609, 810, 432, 599, 314, 624, 397, 801, 397, 717],
[618, 826, 437, 609, 317, 629, 405, 810, 410, 739],
[625, 843, 446, 622, 324, 639, 412, 824, 415, 750],
[634, 856, 451, 630, 328, 652, 425, 841, 418, 765]],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_0': [tensor([102, 96, 97, 114, 101, 86, 107, 103, 83, 111], device='cuda:2'),
tensor([112, 107, 102, 123, 111, 101, 112, 118, 96, 118], device='cuda:2'),
tensor([118, 115, 114, 132, 121, 108, 126, 132, 102, 132], device='cuda:2'),
tensor([132, 128, 122, 142, 128, 118, 134, 143, 112, 141], device='cuda:2'),
tensor([143, 140, 134, 154, 139, 133, 145, 150, 115, 147], device='cuda:2'),
tensor([155, 151, 139, 161, 149, 143, 153, 163, 128, 158], device='cuda:2'),
tensor([167, 156, 149, 166, 158, 152, 167, 170, 149, 166], device='cuda:2'),
tensor([180, 168, 156, 174, 170, 166, 179, 174, 156, 177], device='cuda:2'),
tensor([189, 178, 167, 184, 179, 177, 185, 185, 167, 189], device='cuda:2'),
tensor([199, 184, 185, 190, 189, 190, 196, 193, 177, 197], device='cuda:2'),
tensor([209, 197, 196, 197, 199, 198, 206, 209, 183, 206], device='cuda:2'),
tensor([223, 205, 201, 212, 211, 204, 213, 216, 195, 220], device='cuda:2'),
tensor([238, 215, 209, 221, 223, 214, 219, 230, 203, 228], device='cuda:2'),
tensor([250, 225, 217, 228, 232, 224, 233, 239, 211, 241], device='cuda:2'),
tensor([262, 237, 227, 237, 245, 234, 240, 246, 222, 250], device='cuda:2'),
tensor([278, 247, 235, 246, 252, 245, 250, 256, 231, 260], device='cuda:2'),
tensor([284, 259, 244, 255, 262, 257, 257, 268, 240, 274], device='cuda:2'),
tensor([292, 273, 250, 268, 271, 269, 266, 280, 249, 282], device='cuda:2'),
tensor([297, 282, 260, 282, 281, 277, 273, 294, 262, 292], device='cuda:2'),
tensor([305, 294, 270, 294, 291, 288, 284, 306, 269, 299], device='cuda:2'),
tensor([316, 299, 279, 300, 301, 299, 297, 319, 280, 310], device='cuda:2'),
tensor([324, 308, 286, 308, 311, 310, 314, 328, 292, 319], device='cuda:2'),
tensor([331, 318, 294, 320, 322, 320, 326, 340, 303, 326], device='cuda:2'),
tensor([341, 329, 305, 331, 329, 339, 337, 349, 308, 332], device='cuda:2'),
tensor([350, 338, 313, 341, 340, 343, 344, 359, 324, 348], device='cuda:2'),
tensor([363, 350, 328, 353, 347, 351, 350, 364, 339, 355], device='cuda:2'),
tensor([370, 358, 338, 362, 356, 361, 356, 380, 350, 369], device='cuda:2'),
tensor([377, 366, 352, 365, 368, 365, 374, 387, 370, 376], device='cuda:2'),
tensor([388, 376, 360, 375, 374, 374, 387, 405, 378, 383], device='cuda:2'),
tensor([397, 387, 371, 387, 388, 379, 398, 418, 384, 391], device='cuda:2'),
tensor([402, 391, 385, 394, 401, 389, 406, 432, 392, 408], device='cuda:2'),
tensor([415, 398, 397, 404, 411, 404, 416, 438, 397, 420], device='cuda:2'),
tensor([429, 413, 404, 412, 426, 415, 430, 442, 404, 425], device='cuda:2'),
tensor([443, 424, 415, 423, 438, 428, 436, 451, 409, 433], device='cuda:2'),
tensor([453, 436, 433, 430, 448, 439, 444, 462, 414, 441], device='cuda:2'),
tensor([466, 441, 446, 438, 460, 449, 455, 472, 421, 452], device='cuda:2'),
tensor([475, 449, 455, 447, 468, 460, 469, 479, 435, 463], device='cuda:2'),
tensor([484, 455, 465, 459, 485, 474, 473, 487, 443, 475], device='cuda:2'),
tensor([492, 464, 476, 468, 492, 488, 478, 504, 453, 485], device='cuda:2'),
tensor([501, 472, 487, 475, 506, 500, 490, 511, 463, 495], device='cuda:2'),
tensor([514, 482, 499, 479, 512, 518, 501, 521, 473, 501], device='cuda:2'),
tensor([522, 495, 510, 493, 519, 532, 507, 534, 480, 508], device='cuda:2'),
tensor([534, 503, 515, 511, 528, 538, 513, 543, 495, 520], device='cuda:2'),
tensor([544, 511, 522, 525, 542, 550, 518, 556, 503, 529], device='cuda:2'),
tensor([551, 521, 533, 537, 557, 560, 526, 565, 510, 540], device='cuda:2'),
tensor([561, 534, 543, 545, 566, 576, 532, 576, 519, 548], device='cuda:2'),
tensor([567, 545, 556, 557, 583, 587, 541, 585, 521, 558], device='cuda:2'),
tensor([575, 559, 568, 565, 595, 595, 550, 592, 529, 572], device='cuda:2'),
tensor([583, 567, 579, 580, 607, 604, 559, 602, 538, 581], device='cuda:2'),
tensor([588, 581, 592, 584, 618, 612, 572, 611, 549, 593], device='cuda:2'),
tensor([596, 592, 601, 595, 623, 621, 587, 622, 556, 607], device='cuda:2')],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_1': [tensor([102, 96, 97, 114, 101, 86, 107, 103, 83, 111], device='cuda:2'),
tensor([109, 106, 106, 121, 116, 100, 114, 110, 95, 123], device='cuda:2'),
tensor([121, 120, 118, 129, 129, 107, 125, 116, 103, 132], device='cuda:2'),
tensor([133, 131, 125, 137, 134, 116, 134, 132, 116, 142], device='cuda:2'),
tensor([146, 141, 134, 143, 143, 123, 146, 140, 129, 155], device='cuda:2'),
tensor([160, 155, 141, 148, 152, 130, 157, 146, 148, 163], device='cuda:2'),
tensor([172, 161, 150, 160, 164, 137, 164, 156, 157, 179], device='cuda:2'),
tensor([181, 170, 157, 170, 177, 147, 173, 165, 169, 191], device='cuda:2'),
tensor([191, 178, 169, 183, 188, 153, 185, 176, 179, 198], device='cuda:2'),
tensor([199, 194, 179, 189, 196, 167, 193, 184, 191, 208], device='cuda:2'),
tensor([207, 202, 187, 201, 216, 180, 206, 186, 201, 214], device='cuda:2'),
tensor([220, 213, 198, 206, 226, 184, 220, 194, 216, 223], device='cuda:2'),
tensor([225, 224, 213, 216, 232, 195, 231, 206, 224, 234], device='cuda:2'),
tensor([232, 232, 222, 230, 243, 209, 243, 218, 231, 240], device='cuda:2'),
tensor([238, 241, 232, 236, 249, 219, 258, 233, 241, 253], device='cuda:2'),
tensor([249, 249, 238, 249, 260, 233, 268, 238, 252, 264], device='cuda:2'),
tensor([255, 254, 250, 261, 268, 251, 278, 250, 262, 271], device='cuda:2'),
tensor([263, 263, 266, 274, 274, 259, 289, 257, 275, 280], device='cuda:2'),
tensor([272, 273, 276, 279, 284, 272, 295, 265, 288, 296], device='cuda:2'),
tensor([280, 286, 290, 289, 295, 286, 303, 275, 294, 302], device='cuda:2'),
tensor([292, 292, 295, 302, 305, 292, 312, 284, 307, 319], device='cuda:2'),
tensor([300, 303, 306, 310, 311, 303, 327, 293, 321, 326], device='cuda:2'),
tensor([309, 309, 315, 321, 323, 308, 336, 307, 333, 339], device='cuda:2'),
tensor([318, 319, 325, 334, 335, 318, 346, 313, 342, 350], device='cuda:2'),
tensor([328, 328, 331, 351, 351, 330, 359, 317, 346, 359], device='cuda:2'),
tensor([342, 334, 345, 365, 360, 337, 368, 330, 352, 367], device='cuda:2'),
tensor([351, 346, 355, 376, 372, 347, 378, 336, 364, 375], device='cuda:2'),
tensor([362, 359, 365, 385, 379, 360, 387, 345, 375, 383], device='cuda:2'),
tensor([370, 371, 372, 393, 388, 372, 396, 358, 386, 394], device='cuda:2'),
tensor([385, 384, 378, 403, 403, 378, 402, 364, 396, 407], device='cuda:2'),
tensor([394, 393, 388, 412, 413, 385, 411, 376, 411, 417], device='cuda:2'),
tensor([409, 399, 399, 418, 424, 396, 420, 385, 425, 425], device='cuda:2'),
tensor([421, 407, 407, 427, 433, 410, 425, 398, 435, 437], device='cuda:2'),
tensor([429, 417, 416, 435, 446, 417, 441, 409, 445, 445], device='cuda:2'),
tensor([435, 428, 431, 446, 454, 427, 444, 422, 457, 456], device='cuda:2'),
tensor([443, 437, 443, 452, 462, 437, 457, 433, 470, 466], device='cuda:2'),
tensor([450, 453, 453, 463, 474, 445, 465, 446, 474, 477], device='cuda:2'),
tensor([459, 464, 462, 474, 485, 458, 473, 453, 486, 486], device='cuda:2'),
tensor([468, 477, 476, 488, 494, 466, 480, 459, 496, 496], device='cuda:2'),
tensor([477, 486, 486, 492, 504, 480, 487, 470, 507, 511], device='cuda:2'),
tensor([489, 491, 501, 506, 516, 488, 496, 481, 516, 516], device='cuda:2'),
tensor([497, 499, 518, 523, 530, 495, 500, 490, 526, 522], device='cuda:2'),
tensor([507, 515, 528, 532, 540, 506, 508, 499, 535, 530], device='cuda:2'),
tensor([516, 525, 535, 543, 553, 517, 522, 508, 544, 537], device='cuda:2'),
tensor([524, 533, 547, 549, 560, 531, 532, 524, 554, 546], device='cuda:2'),
tensor([530, 544, 559, 558, 572, 540, 538, 536, 565, 558], device='cuda:2'),
tensor([538, 554, 571, 570, 579, 556, 546, 549, 571, 566], device='cuda:2'),
tensor([546, 569, 581, 576, 586, 569, 555, 556, 584, 578], device='cuda:2'),
tensor([560, 580, 590, 585, 601, 578, 561, 561, 595, 589], device='cuda:2'),
tensor([566, 591, 604, 593, 612, 590, 572, 569, 604, 599], device='cuda:2'),
tensor([577, 597, 613, 609, 622, 595, 579, 583, 612, 613], device='cuda:2')],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_2': [tensor([102, 96, 97, 114, 101, 86, 107, 103, 83, 111], device='cuda:2'),
tensor([110, 108, 110, 124, 108, 91, 116, 119, 92, 122], device='cuda:2'),
tensor([120, 122, 117, 128, 118, 105, 127, 131, 104, 128], device='cuda:2'),
tensor([131, 130, 124, 147, 126, 113, 134, 145, 114, 136], device='cuda:2'),
tensor([139, 140, 136, 155, 133, 123, 144, 157, 124, 149], device='cuda:2'),
tensor([152, 148, 143, 163, 142, 135, 148, 164, 140, 165], device='cuda:2'),
tensor([158, 161, 155, 175, 153, 145, 156, 171, 151, 175], device='cuda:2'),
tensor([165, 171, 161, 191, 163, 162, 163, 184, 158, 182], device='cuda:2'),
tensor([175, 184, 165, 197, 173, 173, 174, 198, 167, 194], device='cuda:2'),
tensor([190, 194, 171, 210, 180, 183, 185, 209, 174, 204], device='cuda:2'),
tensor([201, 204, 180, 221, 189, 193, 191, 221, 187, 213], device='cuda:2'),
tensor([211, 216, 190, 228, 200, 202, 199, 233, 196, 225], device='cuda:2'),
tensor([221, 229, 199, 239, 209, 213, 206, 242, 205, 237], device='cuda:2'),
tensor([227, 240, 208, 251, 218, 220, 213, 255, 216, 252], device='cuda:2'),
tensor([235, 247, 212, 264, 228, 238, 230, 261, 223, 262], device='cuda:2'),
tensor([245, 260, 223, 276, 238, 244, 238, 270, 233, 273], device='cuda:2'),
tensor([254, 265, 239, 284, 250, 256, 249, 281, 243, 279], device='cuda:2'),
tensor([263, 272, 247, 297, 264, 271, 253, 291, 252, 290], device='cuda:2'),
tensor([277, 283, 255, 313, 274, 281, 264, 295, 262, 296], device='cuda:2'),
tensor([286, 293, 261, 324, 282, 292, 273, 307, 273, 309], device='cuda:2'),
tensor([298, 303, 268, 334, 291, 301, 283, 320, 283, 319], device='cuda:2'),
tensor([312, 312, 274, 338, 302, 309, 291, 334, 291, 337], device='cuda:2'),
tensor([324, 321, 279, 345, 307, 319, 307, 354, 298, 346], device='cuda:2'),
tensor([336, 336, 290, 356, 314, 330, 315, 363, 305, 355], device='cuda:2'),
tensor([344, 344, 304, 365, 323, 340, 326, 374, 317, 363], device='cuda:2'),
tensor([358, 352, 317, 372, 332, 354, 337, 382, 325, 371], device='cuda:2'),
tensor([367, 365, 327, 380, 337, 369, 347, 396, 332, 380], device='cuda:2'),
tensor([372, 376, 338, 390, 347, 374, 359, 406, 343, 395], device='cuda:2'),
tensor([383, 389, 345, 404, 359, 382, 372, 415, 354, 397], device='cuda:2'),
tensor([392, 396, 359, 414, 368, 388, 382, 424, 370, 407], device='cuda:2'),
tensor([401, 404, 363, 430, 377, 395, 402, 433, 377, 418], device='cuda:2'),
tensor([406, 415, 376, 441, 385, 405, 416, 442, 383, 431], device='cuda:2'),
tensor([416, 428, 393, 448, 395, 416, 425, 448, 394, 437], device='cuda:2'),
tensor([425, 436, 405, 461, 407, 423, 436, 458, 403, 446], device='cuda:2'),
tensor([436, 447, 418, 475, 416, 426, 450, 469, 411, 452], device='cuda:2'),
tensor([442, 453, 427, 490, 428, 441, 456, 483, 419, 461], device='cuda:2'),
tensor([458, 461, 437, 499, 437, 450, 463, 494, 428, 473], device='cuda:2'),
tensor([475, 476, 446, 507, 444, 460, 471, 508, 436, 477], device='cuda:2'),
tensor([487, 487, 454, 513, 451, 471, 486, 519, 445, 487], device='cuda:2'),
tensor([498, 494, 467, 523, 467, 482, 493, 526, 455, 495], device='cuda:2'),
tensor([506, 502, 477, 535, 473, 490, 507, 537, 467, 506], device='cuda:2'),
tensor([513, 515, 487, 545, 483, 498, 519, 545, 479, 516], device='cuda:2'),
tensor([521, 526, 500, 555, 492, 508, 528, 551, 491, 528], device='cuda:2'),
tensor([531, 532, 508, 568, 503, 516, 539, 562, 504, 537], device='cuda:2'),
tensor([537, 540, 521, 577, 511, 526, 546, 575, 515, 552], device='cuda:2'),
tensor([546, 546, 531, 589, 523, 534, 556, 584, 530, 561], device='cuda:2'),
tensor([551, 563, 540, 600, 533, 544, 570, 588, 542, 569], device='cuda:2'),
tensor([562, 573, 553, 610, 541, 548, 583, 599, 553, 578], device='cuda:2'),
tensor([574, 583, 564, 618, 550, 554, 591, 613, 566, 587], device='cuda:2'),
tensor([586, 597, 577, 622, 560, 563, 598, 619, 581, 597], device='cuda:2'),
tensor([593, 606, 585, 634, 570, 574, 606, 628, 590, 614], device='cuda:2')],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_3': [tensor([102, 96, 97, 114, 101, 86, 107, 103, 83, 111], device='cuda:2'),
tensor([108, 109, 113, 123, 111, 96, 118, 110, 89, 123], device='cuda:2'),
tensor([113, 119, 122, 137, 120, 106, 124, 122, 101, 136], device='cuda:2'),
tensor([125, 127, 132, 144, 132, 114, 137, 136, 107, 146], device='cuda:2'),
tensor([139, 137, 146, 153, 139, 124, 143, 147, 116, 156], device='cuda:2'),
tensor([151, 147, 158, 161, 146, 139, 148, 151, 128, 171], device='cuda:2'),
tensor([161, 156, 168, 173, 156, 146, 155, 158, 143, 184], device='cuda:2'),
tensor([169, 169, 176, 179, 168, 155, 169, 168, 150, 197], device='cuda:2'),
tensor([182, 182, 186, 191, 175, 165, 181, 178, 154, 206], device='cuda:2'),
tensor([191, 191, 194, 198, 194, 178, 187, 186, 167, 214], device='cuda:2'),
tensor([204, 202, 202, 207, 206, 188, 203, 193, 173, 222], device='cuda:2'),
tensor([205, 211, 216, 223, 218, 193, 215, 202, 180, 237], device='cuda:2'),
tensor([218, 221, 225, 236, 228, 201, 230, 213, 186, 242], device='cuda:2'),
tensor([229, 230, 238, 245, 235, 210, 243, 225, 191, 254], device='cuda:2'),
tensor([241, 242, 249, 256, 245, 214, 252, 236, 200, 265], device='cuda:2'),
tensor([250, 251, 260, 260, 260, 226, 259, 247, 214, 273], device='cuda:2'),
tensor([254, 266, 268, 270, 271, 243, 267, 256, 227, 278], device='cuda:2'),
tensor([260, 280, 281, 277, 284, 248, 276, 263, 240, 291], device='cuda:2'),
tensor([268, 288, 289, 284, 292, 262, 287, 275, 253, 302], device='cuda:2'),
tensor([281, 300, 298, 288, 303, 271, 297, 293, 261, 308], device='cuda:2'),
tensor([291, 306, 308, 296, 314, 283, 309, 302, 272, 319], device='cuda:2'),
tensor([299, 317, 325, 306, 326, 289, 320, 309, 281, 328], device='cuda:2'),
tensor([305, 329, 335, 321, 338, 298, 328, 316, 292, 338], device='cuda:2'),
tensor([314, 340, 344, 334, 343, 305, 341, 329, 303, 347], device='cuda:2'),
tensor([331, 349, 352, 343, 358, 310, 347, 341, 312, 357], device='cuda:2'),
tensor([343, 361, 366, 353, 371, 320, 355, 350, 318, 363], device='cuda:2'),
tensor([354, 378, 375, 361, 380, 326, 363, 358, 330, 375], device='cuda:2'),
tensor([364, 389, 386, 369, 387, 337, 371, 364, 337, 396], device='cuda:2'),
tensor([378, 397, 395, 376, 400, 346, 381, 375, 345, 407], device='cuda:2'),
tensor([390, 407, 407, 383, 410, 353, 390, 386, 354, 420], device='cuda:2'),
tensor([394, 418, 417, 397, 419, 366, 400, 393, 369, 427], device='cuda:2'),
tensor([406, 430, 428, 408, 425, 379, 408, 401, 379, 436], device='cuda:2'),
tensor([416, 440, 440, 418, 434, 392, 416, 408, 390, 446], device='cuda:2'),
tensor([429, 448, 453, 426, 442, 402, 419, 420, 405, 456], device='cuda:2'),
tensor([440, 455, 465, 438, 450, 414, 431, 426, 414, 467], device='cuda:2'),
tensor([445, 465, 473, 448, 462, 425, 444, 434, 424, 480], device='cuda:2'),
tensor([448, 473, 488, 460, 475, 433, 458, 442, 430, 493], device='cuda:2'),
tensor([456, 483, 492, 467, 484, 446, 469, 451, 442, 510], device='cuda:2'),
tensor([474, 487, 505, 476, 492, 455, 477, 464, 452, 518], device='cuda:2'),
tensor([482, 497, 510, 481, 504, 466, 484, 473, 469, 534], device='cuda:2'),
tensor([491, 506, 521, 491, 520, 477, 488, 481, 481, 544], device='cuda:2'),
tensor([501, 520, 530, 501, 535, 484, 496, 488, 492, 553], device='cuda:2'),
tensor([510, 530, 542, 510, 546, 494, 507, 501, 502, 558], device='cuda:2'),
tensor([515, 537, 554, 522, 558, 512, 515, 508, 510, 569], device='cuda:2'),
tensor([526, 547, 561, 535, 565, 524, 527, 516, 520, 579], device='cuda:2'),
tensor([535, 563, 574, 547, 576, 533, 535, 526, 529, 582], device='cuda:2'),
tensor([540, 569, 588, 562, 590, 545, 543, 537, 536, 590], device='cuda:2'),
tensor([547, 581, 594, 572, 599, 550, 556, 553, 547, 601], device='cuda:2'),
tensor([557, 588, 603, 584, 609, 561, 565, 562, 560, 611], device='cuda:2'),
tensor([564, 598, 610, 597, 619, 577, 578, 572, 567, 618], device='cuda:2'),
tensor([572, 607, 621, 608, 629, 587, 587, 579, 577, 633], device='cuda:2')],
'train_1000_pool_query_100_iter_50_RandomStrategy_seed_4': [tensor([102, 96, 97, 114, 101, 86, 107, 103, 83, 111], device='cuda:2'),
tensor([113, 109, 104, 120, 118, 95, 119, 108, 91, 123], device='cuda:2'),
tensor([121, 122, 116, 136, 125, 103, 125, 120, 99, 133], device='cuda:2'),
tensor([131, 132, 121, 152, 137, 112, 138, 129, 110, 138], device='cuda:2'),
tensor([147, 142, 130, 159, 155, 121, 146, 132, 124, 144], device='cuda:2'),
tensor([156, 155, 136, 168, 161, 128, 159, 145, 140, 152], device='cuda:2'),
tensor([168, 164, 145, 176, 171, 142, 167, 155, 149, 163], device='cuda:2'),
tensor([176, 173, 152, 184, 186, 150, 178, 166, 163, 172], device='cuda:2'),
tensor([184, 187, 169, 194, 194, 163, 183, 173, 170, 183], device='cuda:2'),
tensor([193, 191, 181, 204, 201, 174, 194, 184, 186, 192], device='cuda:2'),
tensor([202, 208, 192, 216, 210, 180, 204, 195, 193, 200], device='cuda:2'),
tensor([212, 217, 203, 228, 221, 187, 218, 206, 201, 207], device='cuda:2'),
tensor([224, 226, 218, 238, 229, 195, 231, 215, 209, 215], device='cuda:2'),
tensor([242, 232, 225, 248, 237, 204, 236, 232, 219, 225], device='cuda:2'),
tensor([251, 238, 235, 258, 249, 220, 248, 242, 229, 230], device='cuda:2'),
tensor([263, 249, 253, 266, 258, 226, 255, 255, 236, 239], device='cuda:2'),
tensor([272, 262, 266, 274, 265, 237, 265, 264, 249, 246], device='cuda:2'),
tensor([279, 272, 268, 287, 272, 247, 279, 275, 263, 258], device='cuda:2'),
tensor([294, 278, 281, 291, 288, 255, 289, 283, 272, 269], device='cuda:2'),
tensor([300, 289, 291, 297, 306, 261, 299, 291, 281, 285], device='cuda:2'),
tensor([311, 296, 303, 307, 315, 267, 313, 302, 293, 293], device='cuda:2'),
tensor([324, 308, 310, 320, 323, 279, 321, 312, 301, 302], device='cuda:2'),
tensor([336, 315, 313, 334, 328, 289, 333, 323, 315, 314], device='cuda:2'),
tensor([351, 322, 323, 340, 343, 300, 345, 338, 320, 318], device='cuda:2'),
tensor([359, 332, 329, 353, 354, 310, 353, 353, 329, 328], device='cuda:2'),
tensor([369, 341, 342, 363, 361, 321, 366, 361, 335, 341], device='cuda:2'),
tensor([381, 351, 347, 375, 377, 332, 370, 365, 348, 354], device='cuda:2'),
tensor([397, 365, 362, 382, 388, 339, 378, 371, 354, 364], device='cuda:2'),
tensor([404, 377, 376, 392, 395, 353, 383, 381, 363, 376], device='cuda:2'),
tensor([416, 384, 384, 400, 404, 364, 393, 388, 377, 390], device='cuda:2'),
tensor([431, 393, 389, 409, 415, 379, 405, 393, 386, 400], device='cuda:2'),
tensor([436, 405, 402, 415, 425, 390, 421, 402, 397, 407], device='cuda:2'),
tensor([444, 410, 416, 430, 430, 403, 434, 411, 404, 418], device='cuda:2'),
tensor([454, 421, 422, 437, 444, 406, 444, 421, 421, 430], device='cuda:2'),
tensor([463, 428, 437, 446, 457, 417, 457, 430, 432, 433], device='cuda:2'),
tensor([475, 438, 449, 457, 467, 424, 465, 439, 440, 446], device='cuda:2'),
tensor([485, 447, 462, 466, 477, 435, 470, 447, 456, 455], device='cuda:2'),
tensor([494, 462, 472, 479, 488, 440, 480, 458, 463, 464], device='cuda:2'),
tensor([503, 474, 483, 491, 494, 447, 493, 470, 472, 473], device='cuda:2'),
tensor([514, 482, 491, 498, 507, 455, 502, 478, 485, 488], device='cuda:2'),
tensor([522, 491, 500, 506, 516, 465, 515, 486, 501, 498], device='cuda:2'),
tensor([529, 498, 510, 521, 525, 480, 522, 497, 511, 507], device='cuda:2'),
tensor([540, 510, 518, 530, 535, 492, 528, 509, 520, 518], device='cuda:2'),
tensor([550, 519, 528, 541, 545, 500, 541, 518, 529, 529], device='cuda:2'),
tensor([564, 537, 539, 546, 557, 507, 548, 527, 537, 538], device='cuda:2'),
tensor([573, 545, 550, 559, 567, 516, 561, 536, 547, 546], device='cuda:2'),
tensor([585, 554, 555, 570, 585, 523, 570, 548, 558, 552], device='cuda:2'),
tensor([594, 562, 565, 578, 594, 534, 585, 558, 568, 562], device='cuda:2'),
tensor([607, 574, 575, 588, 601, 542, 600, 568, 574, 571], device='cuda:2'),
tensor([613, 580, 586, 599, 611, 553, 608, 581, 587, 582], device='cuda:2'),
tensor([630, 588, 598, 610, 620, 559, 623, 588, 596, 588], device='cuda:2')]}
file_path = "/home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/class_count_dict.json"
# Save the per-iteration class counts; tensors are converted to lists so JSON can serialise them
# with open(file_path, 'w') as json_file:
#     json.dump(class_count_dict, json_file, default=lambda x: x.tolist())
#     print(f"Accuracy summary has been saved to {file_path}.")
# Load the class counts from the JSON file
with open(file_path, 'r') as json_file:
    class_count_dict = json.load(json_file)
print(f"Accuracy summary has been loaded from {file_path}.")
# Note: each tensor was converted to a list while saving, so the loaded values are plain lists
Accuracy summary has been loaded from /home/jaiswalsuraj/suraj_work/ASTRA/notebooks/al/class_count_dict.json.
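A minimal sketch of the save/load round trip used above, on hypothetical toy data. It assumes NumPy arrays as stand-ins for the CUDA tensors (both expose `.tolist()`, which is all the `default` hook needs), and shows why the loaded values must be turned back into arrays before stacking them into the `(num_classes, num_iterations + 1)` layout used for the stack plots.

```python
import json

import numpy as np

# Hypothetical stand-in for class_count_dict: two AL iterations of cumulative
# per-class counts for a 3-class toy problem (NumPy arrays instead of tensors).
counts = {
    "toy_strategy_seed_0": [np.array([4, 3, 3]), np.array([5, 5, 4])],
}

# json cannot serialise arrays/tensors, so `default` is called for each one
# and converts it to a plain (nested) Python list.
payload = json.dumps(counts, default=lambda x: x.tolist())

# After loading, every entry is a plain list, not an array or tensor.
restored = json.loads(payload)
per_iter = restored["toy_strategy_seed_0"]

# Stack iterations along axis 1 -> shape (num_classes, num_iterations + 1),
# the same layout that is passed to plt.stackplot.
class_count_array = np.stack([np.asarray(c) for c in per_iter], axis=1)
print(class_count_array.shape)     # (3, 2)
print(class_count_array.tolist())  # [[4, 5], [3, 5], [3, 4]]
```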
diversity_accuracy_AL_list_test = []
for i in range(len(accuracy_AL_list['train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0'])):
    diversity_accuracy_AL_list_test.append(accuracy_AL_list['train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0'][i]['test'])
# diversity_accuracy_AL_list_test
# Test accuracy at every AL iteration, one row per random-strategy seed
random_accuracy_list_AL_list_test = []
for i in accuracy_AL_list.keys():
    if i != 'train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0':
        lis = []
        for j in range(len(accuracy_AL_list[i])):
            lis.append(accuracy_AL_list[i][j]['test'])
        random_accuracy_list_AL_list_test.append(lis)
random_accuracy_list_AL_list_test = np.array(random_accuracy_list_AL_list_test)
mean_test_accuracy = np.mean(random_accuracy_list_AL_list_test, axis=0)
std_test_accuracy = np.std(random_accuracy_list_AL_list_test, axis=0)
random_accuracy_list_AL_list_test.shape, mean_test_accuracy.shape, std_test_accuracy.shape
((5, 51), (51,), (51,))
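The per-iteration mean and standard deviation across seeds can also be sketched without NumPy. The accuracy values below are made up purely for illustration; the sketch selects random-strategy runs by name rather than excluding one hard-coded key, and uses `statistics.pstdev` because `np.std` defaults to the population standard deviation (`ddof=0`).

```python
import statistics

# Hypothetical toy results (made-up numbers): test accuracy (%) per AL
# iteration, keyed the same way as accuracy_AL_list.
toy_accuracy = {
    "train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0": [30.0, 35.0, 40.0],
    "train_1000_pool_query_100_iter_50_RandomStrategy_seed_0": [30.0, 36.0, 42.0],
    "train_1000_pool_query_100_iter_50_RandomStrategy_seed_1": [30.0, 34.0, 38.0],
}

# Keep only the random-strategy runs, selected by name.
random_runs = [v for k, v in toy_accuracy.items() if "RandomStrategy" in k]

# zip(*runs) groups the accuracies of all seeds at the same iteration.
mean_acc = [statistics.fmean(step) for step in zip(*random_runs)]
# np.std uses ddof=0 by default; statistics.pstdev is the matching estimator.
std_acc = [statistics.pstdev(step) for step in zip(*random_runs)]

print(mean_acc)  # [30.0, 35.0, 40.0]
print(std_acc)   # [0.0, 1.0, 2.0]
```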
class_counts = class_count_dict['train_1000_pool_query_100_iter_50_DiversityStrategy_seed_0']
# Create a list of tensors from class_count_dict
class_count_tensors = [torch.tensor(count) for count in class_counts]
# Stack the tensors along dimension 1
class_count_tensor = torch.stack(class_count_tensors, dim=1)
class_count_array = class_count_tensor.cpu().numpy()
print('class_count_array shape', class_count_array.shape)
num_iter = 50
# Stacked area plot of the cumulative labeled count per class
plt.figure(figsize=(12, 6))
plt.stackplot(range(num_iter + 1), class_count_array, labels=dataset.classes)
plt.title('Class-wise Count for Diversity Strategy')
plt.xlabel('Active Learning Iteration')
plt.ylabel('Class-wise Count')
plt.legend(loc='upper left')
plt.xlim([1, 3])
plt.xticks(range(0, num_iter + 1))
plt.show()
class_count_array shape (10, 51)
class_counts = class_count_dict['train_1000_pool_query_100_iter_50_RandomStrategy_seed_0']
# Create a list of tensors from class_count_dict
class_count_tensors = [torch.tensor(count) for count in class_counts]
# Stack the tensors along dimension 1
class_count_tensor = torch.stack(class_count_tensors, dim=1)
class_count_array = class_count_tensor.cpu().numpy()
print('class_count_array shape', class_count_array.shape)
num_iter = 50
# Stacked area plot of the cumulative labeled count per class
plt.figure(figsize=(12, 6))
plt.stackplot(range(num_iter + 1), class_count_array, labels=dataset.classes)
plt.title('Class-wise Count for Random Strategy')
plt.xlabel('Active Learning Iteration')
plt.ylabel('Class-wise Count')
plt.legend(loc='upper left')
plt.xlim([1, 3])
plt.xticks(range(0, num_iter + 1))
plt.show()
class_count_array shape (10, 51)
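The stack plots show running totals, so it is harder to see which classes were queried in each round. A minimal sketch, using the first three snapshots of the `train_1000_pool_query_100_iter_50_RandomStrategy_seed_1` run printed above, differences along the iteration axis to recover the per-round acquisitions (each round should add exactly 100 samples).

```python
import numpy as np

# First three snapshots of the RandomStrategy_seed_1 run printed above:
# cumulative labeled count per CIFAR-10 class after each AL iteration.
snapshots = [
    [102, 96, 97, 114, 101, 86, 107, 103, 83, 111],
    [109, 106, 106, 121, 116, 100, 114, 110, 95, 123],
    [121, 120, 118, 129, 129, 107, 125, 116, 103, 132],
]

# Same layout as class_count_array: shape (num_classes, num_snapshots).
cumulative = np.stack([np.asarray(s) for s in snapshots], axis=1)

# Differencing along the iteration axis gives how many samples of each class
# were queried in each round.
per_round = np.diff(cumulative, axis=1)

print(per_round[:, 0].tolist())        # [7, 10, 9, 7, 15, 14, 7, 7, 12, 12]
print(per_round.sum(axis=0).tolist())  # [100, 100]
```

Applying the same differencing to the diversity run makes it easy to check whether that strategy favours particular classes in individual rounds.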
iterations = range(1, len(mean_test_accuracy) + 1)
plt.figure(figsize=(10, 6))
# Fill the region between mean - std and mean + std
plt.fill_between(iterations, np.array(mean_test_accuracy) - np.array(std_test_accuracy), np.array(mean_test_accuracy) + np.array(std_test_accuracy), alpha=0.5, label='Random std')
plt.plot(iterations, mean_test_accuracy, linestyle='-', label='Random mean Accuracy')
plt.plot(iterations, diversity_accuracy_AL_list_test, linestyle='-', label='Diversity Accuracy')
plt.axhline(y=accuracy_summary['untrain_acc']['test'], color='black', linestyle='--', label='untrained')
plt.axhline(y=accuracy_summary['train_1000']['test'], color='red', linestyle='--', label='labeled 1000')
plt.axhline(y=accuracy_summary['train_1000_pool_5000']['test'], color='green', linestyle='--', label='labeled 1000+labeled pool 5000')
# plt.axhline(y=train_1000_pool_query_100_iter_50_AL_acc.mean(), color='orange', linestyle='--', label='labeled 1000+unlabeled pool 5000 with AL')
plt.axhline(y=accuracy_summary['train_1000_pool_39000']['test'], color='blue', linestyle='--', label='labeled 1000+labeled pool(39000)')
# Customize the plot
plt.title('Test Accuracy vs Iterations')
plt.xlabel('Iterations')
plt.ylabel('Test Accuracy')
plt.grid(True)
plt.legend(loc='lower right')
# Show the plot or save it to a file
plt.show()
| Experiment | Accuracy on test set |
|---|---|
| Untrained model | 8.58% |
| Trained model on train(1000) | 36.24% |
| Trained model on train(1000) + pool set(5000), random | 50.54% |
| Trained model on train(1000) + pool set(5000), selected by AL | 48.17% |
| Trained model on train(1000) + pool set(39000) | 61.61% |
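One way to read the table is to ask what fraction of the gap between the 1000-sample baseline and the full 39000-sample pool each 5000-sample addition recovers. The sketch below only does this arithmetic on the numbers in the table; the "gap closed" framing is an illustrative choice, not a metric used elsewhere in this notebook.

```python
# Test accuracies (%) taken from the table above.
baseline = 36.24   # train(1000) only
full_pool = 61.61  # train(1000) + labeled pool(39000)
results = {"random 5000": 50.54, "AL (diversity) 5000": 48.17}

# Fraction of the baseline-to-full-pool gap recovered by each 5000-sample addition.
gap_closed = {
    name: (acc - baseline) / (full_pool - baseline) for name, acc in results.items()
}
for name, frac in gap_closed.items():
    print(f"{name}: {frac:.1%} of the gap")
```

This prints roughly 56.4% for random selection and 47.0% for the AL selection.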
Accuracy on the test set:

- Untrained model: 8.58%
- Trained model on train(1000): 36.24%
- Trained model on train(1000) + pool set(5000), selected at random: 50.54%
- Trained model on train(1000) + pool set(5000), selected from the unlabeled pool by AL: 48.17%
  - Here we used the core-set approach to select the unlabeled data from the pool set.
  - In each of 50 AL iterations, we pick the 100 pool points that are farthest from the train set.
- Trained model on train(1000) + pool set(39000), labeled: 61.61%
All models above were trained for 30 epochs. The diversity acquisition strategy performs comparably to the random acquisition strategy, finishing slightly below it (48.17% vs. 50.54%).
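The core-set selection described above (pool points farthest from the labeled data in latent space) can be sketched as greedy k-center selection. This is an assumption about the rule, not the notebook's actual code: `k_center_greedy` is a hypothetical helper, and the embeddings are assumed to be latent features such as the ResNet penultimate-layer outputs. The greedy variant updates distances after every pick, so it avoids selecting near-duplicates; simply taking the top-k points by distance to the train set is the simpler alternative.

```python
import numpy as np


def k_center_greedy(train_emb, pool_emb, k):
    """Hypothetical helper: greedily pick k pool indices, farthest-first.

    train_emb: (T, D) embeddings of labeled points.
    pool_emb:  (P, D) embeddings of unlabeled pool points.
    """
    # Distance from every pool point to its nearest labeled point, shape (P,).
    dists = np.linalg.norm(pool_emb[:, None, :] - train_emb[None, :, :], axis=-1)
    min_dist = dists.min(axis=1)

    selected = []
    for _ in range(k):
        idx = int(np.argmax(min_dist))  # farthest from the current labeled set
        selected.append(idx)
        # Treat the chosen point as labeled and update nearest-center distances;
        # min_dist[idx] becomes 0, so it is never picked again.
        new_d = np.linalg.norm(pool_emb - pool_emb[idx], axis=1)
        min_dist = np.minimum(min_dist, new_d)
    return selected


# Toy 2-D example: one labeled point at the origin, two near-duplicate far points.
train = np.array([[0.0, 0.0]])
pool = np.array([[1.0, 0.0], [5.0, 0.0], [5.1, 0.0], [0.0, 3.0]])
print(k_center_greedy(train, pool, k=2))  # [2, 3]
```

Plain top-2 by distance would return the near-duplicates at indices 2 and 1; the greedy update picks the point at (0, 3) instead. For real pools, the `(P, T, D)` broadcast in the first distance computation can be large, so computing it in chunks may be necessary.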