Model Interface

class torchxrayvision.models.Model

The library is composed of core and baseline classifiers. Core classifiers are trained specifically for this library; baseline classifiers come from other papers and have been adapted to provide the same interface and to work with the same input pixel scaling as the core models. All models automatically resize input images (up or down, using bilinear interpolation) to the resolution they were trained on, so they can easily be swapped out in experiments. Pre-trained models are hosted on GitHub and automatically downloaded to the user’s local ~/.torchxrayvision directory.
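To check which weight files have already been downloaded, the cache directory mentioned above can be inspected directly. The following is a minimal sketch using only the standard library. The helper name list_cached_weights is hypothetical and not part of the library, and the sketch assumes only that downloaded weights are stored as regular files in that directory.

```python
from pathlib import Path


def list_cached_weights(cache_dir=None):
    """Return the sorted names of the files in the local weights cache.

    Hypothetical helper, not part of torchxrayvision. It assumes only that
    downloaded weights are stored as regular files in the cache directory.
    """
    if cache_dir is None:
        cache_dir = Path.home() / ".torchxrayvision"
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        # Nothing has been downloaded yet, or the cache lives elsewhere.
        return []
    return sorted(p.name for p in cache_dir.iterdir() if p.is_file())


for name in list_cached_weights():
    print(name)
```

An empty result simply means that no model has been instantiated with pre-trained weights on this machine yet.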

Core pre-trained classifiers are provided as PyTorch Modules which are fully differentiable in order to work seamlessly with other PyTorch code.

forward(x: Tensor) → Tensor

The model outputs a tensor of shape [batch, pathologies], aligned to the order of the list model.targets.

preds = model(img)
print(dict(zip(model.targets, preds.tolist()[0])))
# {'Atelectasis': 0.5583771,
#  'Consolidation': 0.5279943,
#  'Infiltration': 0.60061914,
#  ...

targets: List[str]

Each classifier provides a field model.targets which aligns to the list of predictions that the model makes. This list changes depending on the weights loaded. The predictions can be aligned to pathology names with dict(zip(model.targets, preds.tolist()[0])), as in the forward example above.
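The alignment between targets and predictions can be sketched without loading a model, using plain Python lists in place of model.targets and preds.tolist()[0]. The helper named_predictions and the 0.55 cut-off are illustrative assumptions, not part of the library.

```python
def named_predictions(targets, preds_row, threshold=None):
    """Pair each target name with its prediction for a single image.

    targets   : list of names, e.g. model.targets
    preds_row : list of floats, e.g. model(img).tolist()[0]
    threshold : if given, keep only targets whose score is >= threshold
    """
    if len(targets) != len(preds_row):
        raise ValueError("targets and predictions must have the same length")
    named = dict(zip(targets, preds_row))
    if threshold is not None:
        named = {name: score for name, score in named.items() if score >= threshold}
    return named


# Values copied from the forward() example; the list is truncated for brevity.
targets = ["Atelectasis", "Consolidation", "Infiltration"]
preds_row = [0.5583771, 0.5279943, 0.60061914]

print(named_predictions(targets, preds_row))
print(named_predictions(targets, preds_row, threshold=0.55))
```

The length check guards against pairing predictions with the target list of a different set of weights, since the list changes with the weights loaded.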

features(x: Tensor) → Tensor

The pre-trained models can also be used as feature extractors for semi-supervised training or transfer learning tasks. A feature vector can be obtained for each image using the model.features function. The resulting size varies with the architecture and the input image size. Some models also provide a model.features2 method that extracts features at a different point of the computation graph.

feats = model.features(img)
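One simple use of such feature vectors is similarity search between images. The sketch below is illustrative only: it uses short hand-written lists in place of real features, and it assumes each feature vector has already been flattened to a 1-D list of floats (for example with feats[0].flatten().tolist()). The helpers cosine_similarity and most_similar are hypothetical names, not library functions.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # similarity is undefined for a zero vector; treat as no match
    return dot / (norm_a * norm_b)


def most_similar(query, gallery):
    """Return the index of the gallery vector most similar to the query."""
    scores = [cosine_similarity(query, g) for g in gallery]
    return max(range(len(gallery)), key=scores.__getitem__)


# Toy 4-dimensional vectors standing in for flattened feature vectors.
gallery = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 0.0],
    [0.5, 0.5, 0.5, 0.5],
]
query = [0.1, 0.9, 1.1, 0.0]

print(most_similar(query, gallery))  # 1
```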

XRV Pathology Classifiers

class torchxrayvision.models.DenseNet(weights=SPECIFY, op_threshs=None, apply_sigmoid=False)

Based on “Densely Connected Convolutional Networks”

Possible weights for this class include:

# 224x224 models
model = xrv.models.DenseNet(weights="densenet121-res224-all")
model = xrv.models.DenseNet(weights="densenet121-res224-rsna") # RSNA Pneumonia Challenge
model = xrv.models.DenseNet(weights="densenet121-res224-nih") # NIH ChestX-ray8
model = xrv.models.DenseNet(weights="densenet121-res224-pc") # PadChest (University of Alicante)
model = xrv.models.DenseNet(weights="densenet121-res224-chex") # CheXpert (Stanford)
model = xrv.models.DenseNet(weights="densenet121-res224-mimic_nb") # MIMIC-CXR (MIT)
model = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch") # MIMIC-CXR (MIT)

Parameters:

  • weights – Name of the pre-trained weights to load (one of the names listed above)

  • op_threshs – Optional per-target operating-point thresholds used to normalize the outputs

  • apply_sigmoid – If True, apply a sigmoid to the model outputs
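An operating-point threshold is typically used to rescale a probability so that the chosen threshold for each target lands at 0.5. The sketch below shows one such piecewise-linear rescaling in plain Python. It is an assumption about how this kind of normalization can work, not a statement of what the library does internally, and normalize_to_threshold is a hypothetical name.

```python
import math


def normalize_to_threshold(score, thresh):
    """Rescale a probability so that the operating point maps to 0.5.

    Scores below thresh are mapped linearly onto [0, 0.5), scores at or
    above thresh onto [0.5, 1]. A NaN threshold (no operating point known
    for that target) yields the neutral value 0.5.
    """
    if math.isnan(thresh):
        return 0.5
    if not 0.0 < thresh < 1.0:
        raise ValueError("thresh must lie strictly between 0 and 1")
    if score < thresh:
        return score / (thresh * 2)
    return 1.0 - (1.0 - score) / ((1.0 - thresh) * 2)


# Three raw scores for targets that share an operating point of 0.2.
scores = [0.1, 0.2, 0.6]
thresholds = [0.2, 0.2, 0.2]
normalized = [normalize_to_threshold(s, t) for s, t in zip(scores, thresholds)]
print(normalized)
```

After such a rescaling, a single cut-off of 0.5 can be applied to every target even though the underlying thresholds differ.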

targets: List[str] = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum']

class torchxrayvision.models.ResNet(weights=SPECIFY, op_threshs=None, apply_sigmoid=False)

Based on “Deep Residual Learning for Image Recognition”

Possible weights for this class include:

# 512x512 models
model = xrv.models.ResNet(weights="resnet50-res512-all")

Parameters:

  • weights – Name of the pre-trained weights to load (one of the names listed above)

  • op_threshs – Optional per-target operating-point thresholds used to normalize the outputs

  • apply_sigmoid – If True, apply a sigmoid to the model outputs
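The weight names listed for the DenseNet and ResNet classes appear to follow an architecture-resolution-dataset pattern, such as resnet50-res512-all. The sketch below splits a name into those three parts. The pattern is inferred from the names on this page rather than documented, and parse_weight_name and WeightSpec are hypothetical names.

```python
from typing import NamedTuple


class WeightSpec(NamedTuple):
    architecture: str
    resolution: int
    dataset: str


def parse_weight_name(name):
    """Split a name like 'densenet121-res224-all' into its three parts.

    Assumes the '<architecture>-res<pixels>-<dataset>' convention seen in
    the weight names listed on this page.
    """
    parts = name.split("-", 2)
    if len(parts) != 3 or not parts[1].startswith("res") or not parts[1][3:].isdigit():
        raise ValueError(f"unexpected weight name: {name!r}")
    architecture, res, dataset = parts
    return WeightSpec(architecture, int(res[3:]), dataset)


print(parse_weight_name("resnet50-res512-all"))
print(parse_weight_name("densenet121-res224-mimic_nb"))
```

Names that do not follow the pattern, such as the autoencoder weight name 101-elastic, raise a ValueError instead of being guessed at.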

targets: List[str] = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum']

XRV ResNet Autoencoder

class torchxrayvision.autoencoders.ResNetAE(weights=SPECIFY)

A ResNet-based autoencoder.

Possible weights for this class include:

ae = xrv.autoencoders.ResNetAE(weights="101-elastic") # trained on PadChest, NIH, CheXpert, and MIMIC
z = ae.encode(image)
image2 = ae.decode(z)
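A common way to judge an encode/decode round trip is the mean squared error between the input and its reconstruction. The sketch below computes it on plain lists. The pixel values are invented for illustration and stand in for the flattened image and image2 tensors, and mean_squared_error is a hypothetical helper.

```python
def mean_squared_error(original, reconstruction):
    """Mean of the squared element-wise differences of two flat lists."""
    if not original or len(original) != len(reconstruction):
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((a - b) ** 2 for a, b in zip(original, reconstruction)) / len(original)


# Made-up pixel values standing in for image.flatten().tolist()
# and image2.flatten().tolist().
original = [0.0, 512.0, -1024.0, 1024.0]
reconstruction = [10.0, 500.0, -1000.0, 1024.0]

print(mean_squared_error(original, reconstruction))  # 205.0
```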

CheXpert Pathology Classifier

class torchxrayvision.baseline_models.chexpert.DenseNet(weights_zip='', num_models=30)

CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison. Irvin, J., et al. (2019). AAAI Conference on Artificial Intelligence.

Setting num_models less than 30 will load a subset of the ensemble.

Modified for TorchXRayVision to maintain the PyTorch gradient tape and to provide the features() method.

Weights can be found:

targets: List[str] = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion']
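The effect of loading only part of an ensemble can be illustrated with a small sketch. It assumes the ensemble members are combined by simple averaging, which this page does not state, and it uses invented scores for the five targets above. The function average_ensemble is a hypothetical helper, not part of the library.

```python
def average_ensemble(member_preds, num_models=None):
    """Average per-target predictions over the first num_models members.

    member_preds : list of per-member prediction lists, one value per target
    num_models   : how many members to use; None uses all of them
    """
    if num_models is not None:
        member_preds = member_preds[:num_models]
    if not member_preds:
        raise ValueError("at least one ensemble member is required")
    n_targets = len(member_preds[0])
    return [
        sum(member[i] for member in member_preds) / len(member_preds)
        for i in range(n_targets)
    ]


targets = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion"]
# Invented scores for three ensemble members.
member_preds = [
    [0.25, 0.5, 0.75, 0.125, 0.0],
    [0.75, 0.5, 0.25, 0.375, 1.0],
    [1.0, 0.5, 0.5, 0.25, 0.5],
]

print(dict(zip(targets, average_ensemble(member_preds, num_models=2))))
print(dict(zip(targets, average_ensemble(member_preds))))
```

Using fewer members reduces loading time and memory at the cost of a noisier average.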

JF Healthcare Pathology Classifier

class torchxrayvision.baseline_models.jfhealthcare.DenseNet(apply_sigmoid=True)

A model trained on the CheXpert data. Apache-2.0 License.

ChestX-Det Segmentation

class torchxrayvision.baseline_models.chestx_det.PSPNet

ChestX-Det Segmentation Model

You can load pretrained anatomical segmentation models. Demo Notebook

seg_model = xrv.baseline_models.chestx_det.PSPNet()
output = seg_model(image)
output.shape # [1, 14, 512, 512]
seg_model.targets # ['Left Clavicle', 'Right Clavicle', 'Left Scapula', 'Right Scapula',
                  #  'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis',
                  #  'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum',  'Weasand', 'Spine']

targets: List[str] = ['Left Clavicle', 'Right Clavicle', 'Left Scapula', 'Right Scapula', 'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis', 'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum', 'Weasand', 'Spine']
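Because the output channels follow the order of targets, the mask for one structure can be selected by looking up its name. The sketch below does this on a tiny nested list standing in for a single image's output (channels, then rows, then columns). The 0.5 cut-off is an assumption, since this page does not say whether the outputs are logits or probabilities, and mask_for and area_fraction are hypothetical helpers.

```python
def mask_for(output, targets, name, threshold=0.5):
    """Return a boolean mask for one named structure.

    output  : nested list indexed as [channel][row][column] for one image
    targets : list of channel names, e.g. seg_model.targets
    """
    channel = output[targets.index(name)]
    return [[value > threshold for value in row] for row in channel]


def area_fraction(mask):
    """Fraction of pixels that are set in a boolean mask."""
    total = sum(len(row) for row in mask)
    return sum(sum(row) for row in mask) / total


# Toy 2-channel, 2x2 output; a real output has 14 channels of 512x512.
targets = ["Left Lung", "Heart"]
output = [
    [[0.9, 0.1], [0.8, 0.2]],  # Left Lung
    [[0.1, 0.7], [0.2, 0.3]],  # Heart
]

heart = mask_for(output, targets, "Heart")
print(heart)                 # [[False, True], [False, False]]
print(area_fraction(heart))  # 0.25
```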

Emory HITI Race

class torchxrayvision.baseline_models.emory_hiti.RaceModel

This model is from the work below and is trained to predict the patient race from a chest X-ray. Public data from the MIMIC dataset is used to train this model. The native resolution of the model is 320x320. Images are scaled automatically.

Demo notebook

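import torch
import torchxrayvision as xrv

model = xrv.baseline_models.emory_hiti.RaceModel()

image = xrv.utils.load_image('00027426_000.png')
image = torch.from_numpy(image)[None,...]

pred = model(image)

# Look up the name of the highest-scoring target
model.targets[torch.argmax(pred)]
# 'White'

targets: List[str] = ['Asian', 'Black', 'White']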

Riken Age Model

class torchxrayvision.baseline_models.riken.AgeModel

This model predicts age. It is trained on the NIH dataset. The publication reports a mean absolute error (MAE) between the estimated age and chronological age of 3.67 years.

The native resolution of the model is 320x320. Images are scaled automatically.

Demo notebook

import torch
import torchxrayvision as xrv

model = xrv.baseline_models.riken.AgeModel()

image = xrv.utils.load_image('00027426_000.png')
image = torch.from_numpy(image)[None,...]

pred = model(image)
# tensor([[50.4033]], grad_fn=<AddmmBackward0>)


targets: List[str] = ['Age']

Xinario View Model

class torchxrayvision.baseline_models.xinario.ViewModel

The native resolution of the model is 320x320. Images are scaled automatically.

Demo notebook

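import torch
import torchxrayvision as xrv

model = xrv.baseline_models.xinario.ViewModel()

image = xrv.utils.load_image('00027426_000.png')
image = torch.from_numpy(image)[None,...]

pred = model(image)
# tensor([[17.3186, 26.7156]], grad_fn=<AddmmBackward0>)

# Look up the name of the highest-scoring target
model.targets[pred.argmax()]
# Lateral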


targets: List[str] = ['Frontal', 'Lateral']
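The two output values can be turned into a label by taking the larger one, and into relative probabilities with a softmax. The sketch below does both in plain Python on the values from the example output above. Treating the outputs as logits is an assumption, since this page does not state how they are scaled.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


targets = ["Frontal", "Lateral"]
logits = [17.3186, 26.7156]  # values from the example output above

probs = softmax(logits)
label = targets[probs.index(max(probs))]
print(label)  # Lateral
```

Subtracting the maximum before exponentiating leaves the result unchanged and avoids overflow for large values.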