def preprocess(img):
    """Resize an image to 256x256 and convert it to a channel-first array.

    Args:
        img: Image-like object exposing ``resize((width, height))`` whose
            result can be converted with ``np.array`` (presumably a
            ``PIL.Image.Image`` -- TODO confirm against callers).

    Returns:
        A NumPy array with the channel axis moved to the front. Assuming
        ``np.array`` yields ``(height, width[, channels])``, the shape is
        ``(channels, 256, 256)``; 2-D (single-channel) input gains a
        channel axis of length 1. Values are divided by 255 when the
        array maximum exceeds 1, otherwise they are returned unscaled.
    """
    # One resize is enough: the original resized twice to the same size.
    pixels = np.array(img.resize((256, 256)))

    # Single-channel images come back 2-D; add a trailing channel axis so
    # the transpose below always sees three dimensions.
    if pixels.ndim == 2:
        pixels = np.expand_dims(pixels, axis=-1)

    # Move the last axis (channels) to the front.
    pixels = pixels.transpose((2, 0, 1))

    # Heuristic: a maximum above 1 is treated as 0-255 data and rescaled;
    # anything already within [0, 1] is left untouched.
    if pixels.max() > 1:
        pixels = pixels / 255.
    return pixels