First let's get some imports out of the way:
import tensorflow as tf
import fastestimator as fe
from fastestimator import Network, Pipeline, Estimator
from fastestimator.dataset.data import horse2zebra
from fastestimator.op.numpyop.multivariate import Resize
from fastestimator.op.numpyop.univariate import Normalize, ReadImage
from fastestimator.op.tensorop import LambdaOp
from fastestimator.op.tensorop.gradient import GradientOp
from fastestimator.op.tensorop.model import ModelOp
from fastestimator.trace.io import ImageViewer
from fastestimator.trace.xai import GradCAM
from fastestimator.util import BatchDisplay
Example Data and Pipeline
For this tutorial we will use some pictures of zebras with minimal pre-processing. Let's visualize some of the images to see what we're working with.
train_data, eval_data = horse2zebra.load_data(batch_size=5)
test_data = eval_data.split(range(5)) # We will just use the first 5 images for our visualizations
pipeline = Pipeline(test_data=test_data,
                    ops=[ReadImage(inputs="B", outputs="B"),
                         Resize(image_in="B", image_out="B", height=224, width=224),
                         Normalize(inputs="B", outputs="B", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
                         ])
batch = pipeline.get_results(mode='test')
fig = BatchDisplay(image=batch['B'], title="Zebras")
fig.show()
Extracting Intermediate Layer Gradients
We will use a pre-trained ResNet152V2 model for this example, with standard ImageNet weights. We will inspect the model's layers to decide which one we want to analyze with GradCAM.
model = fe.build(model_fn=lambda: tf.keras.applications.ResNet152V2(weights="imagenet"), optimizer_fn="adam")
print("\n".join([f"{idx}: {x.name}" for idx, x in enumerate(model._flatten_layers(include_self=False, recursive=True))]))
0: input_1 1: conv1_pad 2: conv1_conv 3: pool1_pad 4: pool1_pool 5: conv2_block1_preact_bn 6: conv2_block1_preact_relu 7: conv2_block1_1_conv 8: conv2_block1_1_bn 9: conv2_block1_1_relu 10: conv2_block1_2_pad 11: conv2_block1_2_conv 12: conv2_block1_2_bn 13: conv2_block1_2_relu 14: conv2_block1_0_conv 15: conv2_block1_3_conv 16: conv2_block1_out 17: conv2_block2_preact_bn 18: conv2_block2_preact_relu 19: conv2_block2_1_conv 20: conv2_block2_1_bn 21: conv2_block2_1_relu 22: conv2_block2_2_pad 23: conv2_block2_2_conv 24: conv2_block2_2_bn 25: conv2_block2_2_relu 26: conv2_block2_3_conv 27: conv2_block2_out 28: conv2_block3_preact_bn 29: conv2_block3_preact_relu 30: conv2_block3_1_conv 31: conv2_block3_1_bn 32: conv2_block3_1_relu 33: conv2_block3_2_pad 34: conv2_block3_2_conv 35: conv2_block3_2_bn 36: conv2_block3_2_relu 37: max_pooling2d 38: conv2_block3_3_conv 39: conv2_block3_out 40: conv3_block1_preact_bn 41: conv3_block1_preact_relu 42: conv3_block1_1_conv 43: conv3_block1_1_bn 44: conv3_block1_1_relu 45: conv3_block1_2_pad 46: conv3_block1_2_conv 47: conv3_block1_2_bn 48: conv3_block1_2_relu 49: conv3_block1_0_conv 50: conv3_block1_3_conv 51: conv3_block1_out 52: conv3_block2_preact_bn 53: conv3_block2_preact_relu 54: conv3_block2_1_conv 55: conv3_block2_1_bn 56: conv3_block2_1_relu 57: conv3_block2_2_pad 58: conv3_block2_2_conv 59: conv3_block2_2_bn 60: conv3_block2_2_relu 61: conv3_block2_3_conv 62: conv3_block2_out 63: conv3_block3_preact_bn 64: conv3_block3_preact_relu 65: conv3_block3_1_conv 66: conv3_block3_1_bn 67: conv3_block3_1_relu 68: conv3_block3_2_pad 69: conv3_block3_2_conv 70: conv3_block3_2_bn 71: conv3_block3_2_relu 72: conv3_block3_3_conv 73: conv3_block3_out 74: conv3_block4_preact_bn 75: conv3_block4_preact_relu 76: conv3_block4_1_conv 77: conv3_block4_1_bn 78: conv3_block4_1_relu 79: conv3_block4_2_pad 80: conv3_block4_2_conv 81: conv3_block4_2_bn 82: conv3_block4_2_relu 83: conv3_block4_3_conv 84: conv3_block4_out 85: 
conv3_block5_preact_bn 86: conv3_block5_preact_relu 87: conv3_block5_1_conv 88: conv3_block5_1_bn 89: conv3_block5_1_relu 90: conv3_block5_2_pad 91: conv3_block5_2_conv 92: conv3_block5_2_bn 93: conv3_block5_2_relu 94: conv3_block5_3_conv 95: conv3_block5_out 96: conv3_block6_preact_bn 97: conv3_block6_preact_relu 98: conv3_block6_1_conv 99: conv3_block6_1_bn 100: conv3_block6_1_relu 101: conv3_block6_2_pad 102: conv3_block6_2_conv 103: conv3_block6_2_bn 104: conv3_block6_2_relu 105: conv3_block6_3_conv 106: conv3_block6_out 107: conv3_block7_preact_bn 108: conv3_block7_preact_relu 109: conv3_block7_1_conv 110: conv3_block7_1_bn 111: conv3_block7_1_relu 112: conv3_block7_2_pad 113: conv3_block7_2_conv 114: conv3_block7_2_bn 115: conv3_block7_2_relu 116: conv3_block7_3_conv 117: conv3_block7_out 118: conv3_block8_preact_bn 119: conv3_block8_preact_relu 120: conv3_block8_1_conv 121: conv3_block8_1_bn 122: conv3_block8_1_relu 123: conv3_block8_2_pad 124: conv3_block8_2_conv 125: conv3_block8_2_bn 126: conv3_block8_2_relu 127: max_pooling2d_1 128: conv3_block8_3_conv 129: conv3_block8_out 130: conv4_block1_preact_bn 131: conv4_block1_preact_relu 132: conv4_block1_1_conv 133: conv4_block1_1_bn 134: conv4_block1_1_relu 135: conv4_block1_2_pad 136: conv4_block1_2_conv 137: conv4_block1_2_bn 138: conv4_block1_2_relu 139: conv4_block1_0_conv 140: conv4_block1_3_conv 141: conv4_block1_out 142: conv4_block2_preact_bn 143: conv4_block2_preact_relu 144: conv4_block2_1_conv 145: conv4_block2_1_bn 146: conv4_block2_1_relu 147: conv4_block2_2_pad 148: conv4_block2_2_conv 149: conv4_block2_2_bn 150: conv4_block2_2_relu 151: conv4_block2_3_conv 152: conv4_block2_out 153: conv4_block3_preact_bn 154: conv4_block3_preact_relu 155: conv4_block3_1_conv 156: conv4_block3_1_bn 157: conv4_block3_1_relu 158: conv4_block3_2_pad 159: conv4_block3_2_conv 160: conv4_block3_2_bn 161: conv4_block3_2_relu 162: conv4_block3_3_conv 163: conv4_block3_out 164: conv4_block4_preact_bn 165: 
conv4_block4_preact_relu 166: conv4_block4_1_conv 167: conv4_block4_1_bn 168: conv4_block4_1_relu 169: conv4_block4_2_pad 170: conv4_block4_2_conv 171: conv4_block4_2_bn 172: conv4_block4_2_relu 173: conv4_block4_3_conv 174: conv4_block4_out 175: conv4_block5_preact_bn 176: conv4_block5_preact_relu 177: conv4_block5_1_conv 178: conv4_block5_1_bn 179: conv4_block5_1_relu 180: conv4_block5_2_pad 181: conv4_block5_2_conv 182: conv4_block5_2_bn 183: conv4_block5_2_relu 184: conv4_block5_3_conv 185: conv4_block5_out 186: conv4_block6_preact_bn 187: conv4_block6_preact_relu 188: conv4_block6_1_conv 189: conv4_block6_1_bn 190: conv4_block6_1_relu 191: conv4_block6_2_pad 192: conv4_block6_2_conv 193: conv4_block6_2_bn 194: conv4_block6_2_relu 195: conv4_block6_3_conv 196: conv4_block6_out 197: conv4_block7_preact_bn 198: conv4_block7_preact_relu 199: conv4_block7_1_conv 200: conv4_block7_1_bn 201: conv4_block7_1_relu 202: conv4_block7_2_pad 203: conv4_block7_2_conv 204: conv4_block7_2_bn 205: conv4_block7_2_relu 206: conv4_block7_3_conv 207: conv4_block7_out 208: conv4_block8_preact_bn 209: conv4_block8_preact_relu 210: conv4_block8_1_conv 211: conv4_block8_1_bn 212: conv4_block8_1_relu 213: conv4_block8_2_pad 214: conv4_block8_2_conv 215: conv4_block8_2_bn 216: conv4_block8_2_relu 217: conv4_block8_3_conv 218: conv4_block8_out 219: conv4_block9_preact_bn 220: conv4_block9_preact_relu 221: conv4_block9_1_conv 222: conv4_block9_1_bn 223: conv4_block9_1_relu 224: conv4_block9_2_pad 225: conv4_block9_2_conv 226: conv4_block9_2_bn 227: conv4_block9_2_relu 228: conv4_block9_3_conv 229: conv4_block9_out 230: conv4_block10_preact_bn 231: conv4_block10_preact_relu 232: conv4_block10_1_conv 233: conv4_block10_1_bn 234: conv4_block10_1_relu 235: conv4_block10_2_pad 236: conv4_block10_2_conv 237: conv4_block10_2_bn 238: conv4_block10_2_relu 239: conv4_block10_3_conv 240: conv4_block10_out 241: conv4_block11_preact_bn 242: conv4_block11_preact_relu 243: conv4_block11_1_conv 244: 
conv4_block11_1_bn 245: conv4_block11_1_relu 246: conv4_block11_2_pad 247: conv4_block11_2_conv 248: conv4_block11_2_bn 249: conv4_block11_2_relu 250: conv4_block11_3_conv 251: conv4_block11_out 252: conv4_block12_preact_bn 253: conv4_block12_preact_relu 254: conv4_block12_1_conv 255: conv4_block12_1_bn 256: conv4_block12_1_relu 257: conv4_block12_2_pad 258: conv4_block12_2_conv 259: conv4_block12_2_bn 260: conv4_block12_2_relu 261: conv4_block12_3_conv 262: conv4_block12_out 263: conv4_block13_preact_bn 264: conv4_block13_preact_relu 265: conv4_block13_1_conv 266: conv4_block13_1_bn 267: conv4_block13_1_relu 268: conv4_block13_2_pad 269: conv4_block13_2_conv 270: conv4_block13_2_bn 271: conv4_block13_2_relu 272: conv4_block13_3_conv 273: conv4_block13_out 274: conv4_block14_preact_bn 275: conv4_block14_preact_relu 276: conv4_block14_1_conv 277: conv4_block14_1_bn 278: conv4_block14_1_relu 279: conv4_block14_2_pad 280: conv4_block14_2_conv 281: conv4_block14_2_bn 282: conv4_block14_2_relu 283: conv4_block14_3_conv 284: conv4_block14_out 285: conv4_block15_preact_bn 286: conv4_block15_preact_relu 287: conv4_block15_1_conv 288: conv4_block15_1_bn 289: conv4_block15_1_relu 290: conv4_block15_2_pad 291: conv4_block15_2_conv 292: conv4_block15_2_bn 293: conv4_block15_2_relu 294: conv4_block15_3_conv 295: conv4_block15_out 296: conv4_block16_preact_bn 297: conv4_block16_preact_relu 298: conv4_block16_1_conv 299: conv4_block16_1_bn 300: conv4_block16_1_relu 301: conv4_block16_2_pad 302: conv4_block16_2_conv 303: conv4_block16_2_bn 304: conv4_block16_2_relu 305: conv4_block16_3_conv 306: conv4_block16_out 307: conv4_block17_preact_bn 308: conv4_block17_preact_relu 309: conv4_block17_1_conv 310: conv4_block17_1_bn 311: conv4_block17_1_relu 312: conv4_block17_2_pad 313: conv4_block17_2_conv 314: conv4_block17_2_bn 315: conv4_block17_2_relu 316: conv4_block17_3_conv 317: conv4_block17_out 318: conv4_block18_preact_bn 319: conv4_block18_preact_relu 320: conv4_block18_1_conv 
321: conv4_block18_1_bn 322: conv4_block18_1_relu 323: conv4_block18_2_pad 324: conv4_block18_2_conv 325: conv4_block18_2_bn 326: conv4_block18_2_relu 327: conv4_block18_3_conv 328: conv4_block18_out 329: conv4_block19_preact_bn 330: conv4_block19_preact_relu 331: conv4_block19_1_conv 332: conv4_block19_1_bn 333: conv4_block19_1_relu 334: conv4_block19_2_pad 335: conv4_block19_2_conv 336: conv4_block19_2_bn 337: conv4_block19_2_relu 338: conv4_block19_3_conv 339: conv4_block19_out 340: conv4_block20_preact_bn 341: conv4_block20_preact_relu 342: conv4_block20_1_conv 343: conv4_block20_1_bn 344: conv4_block20_1_relu 345: conv4_block20_2_pad 346: conv4_block20_2_conv 347: conv4_block20_2_bn 348: conv4_block20_2_relu 349: conv4_block20_3_conv 350: conv4_block20_out 351: conv4_block21_preact_bn 352: conv4_block21_preact_relu 353: conv4_block21_1_conv 354: conv4_block21_1_bn 355: conv4_block21_1_relu 356: conv4_block21_2_pad 357: conv4_block21_2_conv 358: conv4_block21_2_bn 359: conv4_block21_2_relu 360: conv4_block21_3_conv 361: conv4_block21_out 362: conv4_block22_preact_bn 363: conv4_block22_preact_relu 364: conv4_block22_1_conv 365: conv4_block22_1_bn 366: conv4_block22_1_relu 367: conv4_block22_2_pad 368: conv4_block22_2_conv 369: conv4_block22_2_bn 370: conv4_block22_2_relu 371: conv4_block22_3_conv 372: conv4_block22_out 373: conv4_block23_preact_bn 374: conv4_block23_preact_relu 375: conv4_block23_1_conv 376: conv4_block23_1_bn 377: conv4_block23_1_relu 378: conv4_block23_2_pad 379: conv4_block23_2_conv 380: conv4_block23_2_bn 381: conv4_block23_2_relu 382: conv4_block23_3_conv 383: conv4_block23_out 384: conv4_block24_preact_bn 385: conv4_block24_preact_relu 386: conv4_block24_1_conv 387: conv4_block24_1_bn 388: conv4_block24_1_relu 389: conv4_block24_2_pad 390: conv4_block24_2_conv 391: conv4_block24_2_bn 392: conv4_block24_2_relu 393: conv4_block24_3_conv 394: conv4_block24_out 395: conv4_block25_preact_bn 396: conv4_block25_preact_relu 397: 
conv4_block25_1_conv 398: conv4_block25_1_bn 399: conv4_block25_1_relu 400: conv4_block25_2_pad 401: conv4_block25_2_conv 402: conv4_block25_2_bn 403: conv4_block25_2_relu 404: conv4_block25_3_conv 405: conv4_block25_out 406: conv4_block26_preact_bn 407: conv4_block26_preact_relu 408: conv4_block26_1_conv 409: conv4_block26_1_bn 410: conv4_block26_1_relu 411: conv4_block26_2_pad 412: conv4_block26_2_conv 413: conv4_block26_2_bn 414: conv4_block26_2_relu 415: conv4_block26_3_conv 416: conv4_block26_out 417: conv4_block27_preact_bn 418: conv4_block27_preact_relu 419: conv4_block27_1_conv 420: conv4_block27_1_bn 421: conv4_block27_1_relu 422: conv4_block27_2_pad 423: conv4_block27_2_conv 424: conv4_block27_2_bn 425: conv4_block27_2_relu 426: conv4_block27_3_conv 427: conv4_block27_out 428: conv4_block28_preact_bn 429: conv4_block28_preact_relu 430: conv4_block28_1_conv 431: conv4_block28_1_bn 432: conv4_block28_1_relu 433: conv4_block28_2_pad 434: conv4_block28_2_conv 435: conv4_block28_2_bn 436: conv4_block28_2_relu 437: conv4_block28_3_conv 438: conv4_block28_out 439: conv4_block29_preact_bn 440: conv4_block29_preact_relu 441: conv4_block29_1_conv 442: conv4_block29_1_bn 443: conv4_block29_1_relu 444: conv4_block29_2_pad 445: conv4_block29_2_conv 446: conv4_block29_2_bn 447: conv4_block29_2_relu 448: conv4_block29_3_conv 449: conv4_block29_out 450: conv4_block30_preact_bn 451: conv4_block30_preact_relu 452: conv4_block30_1_conv 453: conv4_block30_1_bn 454: conv4_block30_1_relu 455: conv4_block30_2_pad 456: conv4_block30_2_conv 457: conv4_block30_2_bn 458: conv4_block30_2_relu 459: conv4_block30_3_conv 460: conv4_block30_out 461: conv4_block31_preact_bn 462: conv4_block31_preact_relu 463: conv4_block31_1_conv 464: conv4_block31_1_bn 465: conv4_block31_1_relu 466: conv4_block31_2_pad 467: conv4_block31_2_conv 468: conv4_block31_2_bn 469: conv4_block31_2_relu 470: conv4_block31_3_conv 471: conv4_block31_out 472: conv4_block32_preact_bn 473: conv4_block32_preact_relu 
474: conv4_block32_1_conv 475: conv4_block32_1_bn 476: conv4_block32_1_relu 477: conv4_block32_2_pad 478: conv4_block32_2_conv 479: conv4_block32_2_bn 480: conv4_block32_2_relu 481: conv4_block32_3_conv 482: conv4_block32_out 483: conv4_block33_preact_bn 484: conv4_block33_preact_relu 485: conv4_block33_1_conv 486: conv4_block33_1_bn 487: conv4_block33_1_relu 488: conv4_block33_2_pad 489: conv4_block33_2_conv 490: conv4_block33_2_bn 491: conv4_block33_2_relu 492: conv4_block33_3_conv 493: conv4_block33_out 494: conv4_block34_preact_bn 495: conv4_block34_preact_relu 496: conv4_block34_1_conv 497: conv4_block34_1_bn 498: conv4_block34_1_relu 499: conv4_block34_2_pad 500: conv4_block34_2_conv 501: conv4_block34_2_bn 502: conv4_block34_2_relu 503: conv4_block34_3_conv 504: conv4_block34_out 505: conv4_block35_preact_bn 506: conv4_block35_preact_relu 507: conv4_block35_1_conv 508: conv4_block35_1_bn 509: conv4_block35_1_relu 510: conv4_block35_2_pad 511: conv4_block35_2_conv 512: conv4_block35_2_bn 513: conv4_block35_2_relu 514: conv4_block35_3_conv 515: conv4_block35_out 516: conv4_block36_preact_bn 517: conv4_block36_preact_relu 518: conv4_block36_1_conv 519: conv4_block36_1_bn 520: conv4_block36_1_relu 521: conv4_block36_2_pad 522: conv4_block36_2_conv 523: conv4_block36_2_bn 524: conv4_block36_2_relu 525: max_pooling2d_2 526: conv4_block36_3_conv 527: conv4_block36_out 528: conv5_block1_preact_bn 529: conv5_block1_preact_relu 530: conv5_block1_1_conv 531: conv5_block1_1_bn 532: conv5_block1_1_relu 533: conv5_block1_2_pad 534: conv5_block1_2_conv 535: conv5_block1_2_bn 536: conv5_block1_2_relu 537: conv5_block1_0_conv 538: conv5_block1_3_conv 539: conv5_block1_out 540: conv5_block2_preact_bn 541: conv5_block2_preact_relu 542: conv5_block2_1_conv 543: conv5_block2_1_bn 544: conv5_block2_1_relu 545: conv5_block2_2_pad 546: conv5_block2_2_conv 547: conv5_block2_2_bn 548: conv5_block2_2_relu 549: conv5_block2_3_conv 550: conv5_block2_out 551: conv5_block3_preact_bn 552: 
conv5_block3_preact_relu 553: conv5_block3_1_conv 554: conv5_block3_1_bn 555: conv5_block3_1_relu 556: conv5_block3_2_pad 557: conv5_block3_2_conv 558: conv5_block3_2_bn 559: conv5_block3_2_relu 560: conv5_block3_3_conv 561: conv5_block3_out 562: post_bn 563: post_relu 564: avg_pool 565: predictions
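If you don't want to scan the full printout, you can shortlist candidate layers programmatically. A minimal sketch, relying on the Keras ResNet152V2 naming convention where residual-block outputs end in `_out`:

```python
import tensorflow as tf

# Build the architecture only (weights=None avoids the pre-trained download).
model = tf.keras.applications.ResNet152V2(weights=None)

# Residual-block output layers are natural GradCAM targets; in this naming
# scheme they all end with '_out'.
candidates = [layer.name for layer in model.layers if layer.name.endswith('_out')]
print(candidates[-3:])  # the last few block outputs, near the end of the model
```

The late-stage block outputs (the `conv5_block*_out` layers) are usually the most informative for GradCAM, since they retain spatial structure while encoding high-level features.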
This model has quite a few layers to choose from. We will try a convolution block near the end of the model. In order to get a GradCAM image we need the gradients of the model's prediction with respect to that intermediate layer's outputs. This can be done using a combination of the ModelOp and GradientOp. We will also use a LambdaOp to isolate the model's most confident prediction, so that only that score gets differentiated.
network = Network(ops=[
    ModelOp(model=model, inputs="B", outputs=["y_pred", "embedding"], intermediate_layers='conv5_block1_out'),
    LambdaOp(inputs="y_pred", outputs="y_pred_max", fn=lambda x: tf.reduce_max(x, axis=-1)),
    GradientOp(finals="y_pred_max", inputs="embedding", outputs="grads", mode="!train")
])
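For intuition, the ModelOp + LambdaOp + GradientOp chain above can be sketched in raw TensorFlow using a two-output model and a `GradientTape` (a sketch of the general pattern, not FastEstimator's internals):

```python
import tensorflow as tf

base = tf.keras.applications.ResNet152V2(weights=None)  # weights=None: architecture only
# A model returning both the intermediate feature map and the final prediction.
grad_model = tf.keras.Model(inputs=base.input,
                            outputs=[base.get_layer('conv5_block1_out').output, base.output])

images = tf.random.uniform((1, 224, 224, 3))  # stand-in for a pipeline batch
with tf.GradientTape() as tape:
    embedding, y_pred = grad_model(images)
    y_pred_max = tf.reduce_max(y_pred, axis=-1)  # score of the most confident class
# Gradient of that score with respect to the intermediate feature map.
grads = tape.gradient(y_pred_max, embedding)
print(grads.shape)  # matches the feature map shape: (1, 7, 7, 2048)
```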
traces = [
    GradCAM(images="B", grads="grads", preds="y_pred"),
    ImageViewer(inputs="gradcam", mode="test")
]
estimator = Estimator(pipeline=pipeline,
                      network=network,
                      traces=traces,
                      epochs=1)
estimator.test()
FastEstimator-Warn: the key 'A' is being pruned since it is unused outside of the Pipeline. To prevent this, you can declare the key as an input of a Trace or TensorOp.
FastEstimator-Test: step: None; epoch: 1;
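Under the hood, a GradCAM heatmap is typically assembled from the gradients and the feature map roughly as follows. This is a minimal sketch of the standard GradCAM algorithm using stand-in tensors; the GradCAM trace's exact implementation may differ:

```python
import tensorflow as tf

# Stand-ins for the "embedding" (intermediate activations) and "grads" keys.
embedding = tf.random.uniform((1, 7, 7, 2048))
grads = tf.random.uniform((1, 7, 7, 2048))

# Global-average-pool the gradients to get one importance weight per channel.
weights = tf.reduce_mean(grads, axis=(1, 2), keepdims=True)
# Weighted sum of feature maps; ReLU keeps only positively-contributing regions.
cam = tf.nn.relu(tf.reduce_sum(weights * embedding, axis=-1))
# Normalize to [0, 1] before upsampling and overlaying on the input image.
cam = cam / (tf.reduce_max(cam) + 1e-8)
print(cam.shape)  # (1, 7, 7); resized to the image resolution when displayed
```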
ImageNet class 340 refers to zebras, so the model is correct for all of our data here. The GradCAM output also shows that the model sometimes seems to care more about the background of the image than the zebras themselves. This might indicate that the model is using features that humans would deem non-robust in order to make its decisions. Let's compare this with an untrained model:
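You can double-check class indices against human-readable labels with Keras's `decode_predictions` helper. Here we use a dummy one-hot prediction for class 340; in practice you would pass the real "y_pred" array from the network:

```python
import numpy as np
import tensorflow as tf

# Dummy prediction putting all confidence on ImageNet class index 340.
y_pred = np.zeros((1, 1000), dtype='float32')
y_pred[0, 340] = 1.0

# Each entry is a (wordnet_id, label, score) tuple.
decoded = tf.keras.applications.imagenet_utils.decode_predictions(y_pred, top=1)
print(decoded[0][0][1])  # 'zebra'
```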
model = fe.build(model_fn=lambda: tf.keras.applications.ResNet152V2(weights=None), optimizer_fn="adam")
network = Network(ops=[
    ModelOp(model=model, inputs="B", outputs=["y_pred", "embedding"], intermediate_layers='conv5_block1_out'),
    LambdaOp(inputs="y_pred", outputs="y_pred_max", fn=lambda x: tf.reduce_max(x, axis=-1)),
    GradientOp(finals="y_pred_max", inputs="embedding", outputs="grads", mode="!train")
])
traces = [
    GradCAM(images="B", grads="grads", preds="y_pred"),
    ImageViewer(inputs="gradcam", mode="test")
]
estimator = Estimator(pipeline=pipeline,
                      network=network,
                      traces=traces,
                      epochs=1)
estimator.test()
FastEstimator-Test: step: None; epoch: 1;
As we can see from the images above, the untrained model 'focuses' all over the images with no apparent correlation to the zebras whatsoever. The training process clearly helps the network focus on more specific parts of the images when making its predictions.