• Matrix heatmap
import matplotlib.pyplot as plt
import numpy as np
def _party_label(index):
    """Return the display label for the party at zero-based ``index``.

    Labels follow spreadsheet-column lettering: 0 -> 'Party A', 25 -> 'Party Z',
    26 -> 'Party AA', and so on, so any number of parties gets a unique label.
    """
    letters = ""
    index += 1
    while index:
        index, rem = divmod(index - 1, 26)
        letters = chr(ord("A") + rem) + letters
    return f"Party {letters}"


def plot_matrix(color_type, num_parties=10, num_classes=10, seed=None, show=True):
    """Plot a heatmap of synthetic per-class sample counts for each party.

    Each party receives ``num_classes`` random integer counts drawn from
    [10, 100) (``randint``'s upper bound is exclusive). Parties run along the
    x-axis and classes down the y-axis, illustrating a non-IID data split.

    Args:
        color_type: Matplotlib colormap name (e.g. ``"viridis"``, ``"OrRd"``),
            passed to ``imshow`` as ``cmap``.
        num_parties: Number of parties (heatmap columns). Defaults to 10.
        num_classes: Number of classes (heatmap rows). Defaults to 10.
        seed: Optional integer seed. When given, counts come from a private
            ``np.random.RandomState(seed)`` so the figure is reproducible and
            the global NumPy random state is left untouched. When ``None``
            (the default) the global ``np.random`` state is used.
        show: If True (default), call ``plt.show()`` before returning.

    Returns:
        The ``(fig, ax)`` pair of the created figure.

    Raises:
        ValueError: If ``num_parties`` or ``num_classes`` is less than 1.
    """
    if num_parties < 1 or num_classes < 1:
        raise ValueError("num_parties and num_classes must both be >= 1")

    # Both the np.random module and a RandomState instance expose .randint.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # One row of num_classes counts per party; dict order fixes the x-axis order.
    party_data = {
        _party_label(i): rng.randint(10, 100, size=num_classes)
        for i in range(num_parties)
    }

    # Shape (num_parties, num_classes).
    data_matrix = np.array(list(party_data.values()))

    class_labels = [f'Class {i}' for i in range(1, num_classes + 1)]

    fig, ax = plt.subplots(figsize=(10, 6))
    # Transpose so classes are rows (y-axis) and parties are columns (x-axis).
    heatmap = ax.imshow(data_matrix.T, cmap=color_type, aspect='auto')

    ax.set_xticks(np.arange(num_parties))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(list(party_data))
    ax.set_yticklabels(class_labels)
    ax.set_xlabel('Party ID')
    ax.set_ylabel('Class')
    ax.set_title('Non-IID Data Distribution Heatmap')

    cbar = fig.colorbar(heatmap, ax=ax, orientation='vertical')
    cbar.set_label('Number of Data Samples')

    if show:
        plt.show()
    return fig, ax

if __name__ == "__main__":
    # Colormaps that work well for this heatmap: Accent_r, Blues, GnBu, OrRd,
    # RdYlGn_r.
    plot_matrix("OrRd")
• Learning curve
# Learning-curve comparison: test accuracy per round for several numbers of
# virtual clients (CIFAR-10, beta=0.5, seed 0).
#
# NOTE(review): the fedvc_seed0_beta_05_cifar10_virtual_* sequences are not
# defined in this snippet; they must already exist in scope, each with one
# value per entry of `rounds`. They are multiplied by 100 below, so they are
# presumably accuracies in [0, 1] -- confirm.
import os

from matplotlib.ticker import MaxNLocator

rounds = range(1, 21)

# Half-width of the shaded band drawn around every curve, in percentage points.
# NOTE(review): despite its name this is a fixed constant, not a measured
# standard deviation.
train_loss_std = 0.01 * 100

# (accuracy series, marker, legend label, colour) for each curve, in legend order.
learning_curves = [
    (fedvc_seed0_beta_05_cifar10_virtual_1, 'v', "NumVirtual_1", 'blue'),
    (fedvc_seed0_beta_05_cifar10_virtual_5, '^', "NumVirtual_5", 'orange'),
    (fedvc_seed0_beta_05_cifar10_virtual_10, 's', "NumVirtual_10", 'green'),
    (fedvc_seed0_beta_05_cifar10_virtual_20, 'o', "NumVirtual_20", 'red'),
    (fedvc_seed0_beta_05_cifar10_virtual_30, 'd', "NumVirtual_30", 'purple'),
]

plt.figure(figsize=(10, 6))

# Each curve and its shaded band share one colour so they read as a pair.
for series, marker, label, color in learning_curves:
    accuracy_pct = np.array(series) * 100
    plt.plot(rounds, accuracy_pct, marker=marker, label=label, color=color)
    plt.fill_between(rounds, accuracy_pct - train_loss_std,
                     accuracy_pct + train_loss_std, color=color, alpha=0.2)

plt.title("Comparison of Virtual Client Num - CIFAR-10 (beta=0.5, NumReal=20)")
plt.xlabel("Rounds")
plt.ylabel("Test Accuracy (%)")

# Rounds are whole numbers, so only place x-axis ticks at integer positions.
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
plt.legend(loc='lower right', prop={'size': 16})
plt.grid(True, linestyle='--', alpha=0.7)

# savefig does not create missing directories; make sure "plot/" exists first.
os.makedirs("plot", exist_ok=True)
plt.savefig("plot/HyperParam.png", dpi=330)
plt.show()
打赏作者

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

CAPTCHA