|
98 | 98 |
|
99 | 99 | if idx == len(states_to_plot) - 2: # second to last row, for the legend |
100 | 100 | ax = fig.add_subplot(gs[idx,0]) |
101 | | - ax.plot(uncontrolled_states.index[0:2], np.zeros((2,1)), label = 'Uncontrolled',color='black',alpha=0.6) |
102 | | - ax.plot(level1_states.index[0:2], np.zeros((2,1)), label = 'Level 1',color='blue',alpha=0.6) |
103 | | - ax.plot(level2_states.index[0:2], np.zeros((2,1)), label = 'Level 2',color='green',alpha=0.6) |
104 | | - ax.plot(level3_states.index[0:2], np.zeros((2,1)), label = 'Level 3',color='red',alpha=0.6) |
105 | | - ax.legend(fontsize='x-large') |
| 101 | + # Plot lines for legend (these will be used as legend handles) |
| 102 | + line1, = ax.plot([], [], label='Uncontrolled', color='black', alpha=0.6) |
| 103 | + line2, = ax.plot([], [], label='Level 1', color='blue', alpha=0.6) |
| 104 | + line3, = ax.plot([], [], label='Level 2', color='green', alpha=0.6) |
| 105 | + line4, = ax.plot([], [], label='Level 3', color='red', alpha=0.6) |
| 106 | + # Remove all spines and tick marks |
| 107 | + for spine in ax.spines.values(): |
| 108 | + spine.set_visible(False) |
| 109 | + ax.tick_params(axis='both', which='both', length=0, labelleft=False, labelbottom=False) |
| 110 | + # Draw legend with an opaque background so it hides the underlying plots |
| 111 | + leg = ax.legend(handles=[line1, line2, line3, line4], fontsize='xx-large', frameon=True, loc='center') |
| 112 | + try: |
| 113 | + leg.get_frame().set_facecolor('white') |
| 114 | + leg.get_frame().set_alpha(1.0) |
| 115 | + except Exception: |
| 116 | + pass |
| 117 | + leg.set_zorder(10) |
| 118 | + # Add a white rectangle behind the legend only (figure coords) so nearby plots remain visible |
| 119 | + ax.add_patch(plt.Rectangle((-0.2, -0.2), 1.4, 1.4, transform=ax.transAxes, color='white', zorder=9)) |
| 120 | + ax.set_xlim(0, 1) |
| 121 | + ax.set_ylim(0, 1) |
| 122 | + # Remove any other axes that share this gridspec location |
| 123 | + for other_ax in fig.axes: |
| 124 | + if other_ax is not ax and hasattr(other_ax, 'get_subplotspec'): |
| 125 | + try: |
| 126 | + if other_ax.get_subplotspec() == ax.get_subplotspec(): |
| 127 | + other_ax.remove() |
| 128 | + except Exception: |
| 129 | + pass |
| 130 | + # clear off anything which is on another axis but on this gridspec location |
| 131 | + |
| 132 | + |
| 133 | + |
| 134 | + |
106 | 135 |
|
107 | 136 |
|
108 | 137 | unc_perf = sum(uncontrolled_data_log['performance_measure']) |
|
0 commit comments