Zero-truncated count distributions and their negative log-likelihoods.

Recently, I performed research on single-cell transcriptomic transformers. When porting language-model paradigms to this data modality, one can formulate a pre-training task that predicts gene expression levels from their partly masked versions. The gene expression levels typically consist of discrete counts. These are commonly assumed to follow either a (1) Poisson, (2) negative binomial, or (3) zero-inflated negative binomial distribution.

Using single-cell transformers on gene expression profiles poses challenges w.r.t. data dimensionality. In essence, every expression profile contains roughly 20,000 genes, each making up one input token. This number of tokens poses issues for research-institute-scale hardware, even when using memory-efficient versions of self-attention. Luckily, it is not unreasonable to assume that all the relevant expression information is contained within the non-zero counts of the profile[1]. By removing all zero counts from the expression profiles, however, the usual distributional assumptions no longer hold. Instead, all prior distributions become zero-truncated. This blog post describes the zero-truncated versions of the Poisson and negative binomial distributions.

The zero-truncated Poisson distribution

The Poisson distribution is given by:

$$f_{\text{Pois}}(k;\lambda) = \Pr(X = k) = \frac{\lambda^{k} e^{-\lambda}}{k!}$$

characterizing the probability of observing a (true) count k, given an (estimated) mean λ. Its zero-truncated version is, then:

$$f_{\text{ZT-Pois}}(k;\lambda) = \Pr(X = k \mid X > 0) = \frac{f_{\text{Pois}}(k;\lambda)}{1 - f_{\text{Pois}}(0;\lambda)} = \frac{\lambda^{k} e^{-\lambda}/k!}{1 - \lambda^{0} e^{-\lambda}/0!} = \frac{\lambda^{k} e^{-\lambda}}{\left(1 - e^{-\lambda}\right) k!} = \frac{\lambda^{k}}{\left(e^{\lambda} - 1\right) k!}$$
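
As a quick sanity check, this closed form can be compared against simply renormalizing the default Poisson pmf by $\Pr(X > 0) = 1 - e^{-\lambda}$. A minimal sketch in PyTorch (illustrative, not part of the original post):

```python
import torch

lam = torch.tensor(3.0)                        # an arbitrary example mean
k = torch.arange(1, 10, dtype=torch.float32)   # non-zero counts only

# Renormalize the default Poisson pmf by Pr(X > 0) = 1 - e^{-lambda}.
log_renorm = torch.distributions.Poisson(lam).log_prob(k) - torch.log1p(-torch.exp(-lam))

# Closed form derived above: k*log(lam) - log(e^lam - 1) - log(k!).
log_closed = k * torch.log(lam) - torch.log(torch.expm1(lam)) - torch.lgamma(k + 1)

assert torch.allclose(log_renorm, log_closed, atol=1e-6)
```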

The negative log-likelihood of the zero-truncated Poisson distribution is, hence:

$$\text{NLL}_{\text{ZT-Pois}} = -\log\left(\frac{\lambda^{k}}{\left(e^{\lambda} - 1\right) k!}\right) = -\log\left(\lambda^{k}\right) + \log\left(e^{\lambda} - 1\right) + \log(k!) = -k\log(\lambda) + \log\left(e^{\lambda} - 1\right) + \log(k!)$$

When estimating $\lambda$ for an observed count $k$, the last term $\log(k!)$ can be ignored, as it does not depend on $\lambda$. Further, for large estimated values of $\lambda$, $\log(e^{\lambda} - 1)$ becomes numerically unstable. As $\log(e^{\lambda} - 1) \to \lambda$ for $\lambda \to \infty$, we can use $\lambda$ as an approximation for numerical stability during optimization. Not so incidentally, replacing the one term with the other essentially amounts to switching back to the NLL of the “default” Poisson distribution. This makes sense: at large values of $\lambda$, the two distributions are essentially equal, as the probability of obtaining a zero count under the default Poisson approaches zero. The following heatmaps show the NLL values for differing values of the true count $k$ and the predicted mean $\lambda$, both for the zero-truncated and the default Poisson.

[Figure: NLL heatmaps for the zero-truncated and the default Poisson distributions, over true counts $k$ and predicted means $\lambda$.]

One can see that replacing the second loss term $\log(e^{\lambda} - 1)$ with $\lambda$ when, say, $\lambda > 10$, does not result in significant approximation errors.
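
To make this concrete, here is a minimal sketch of the resulting loss in PyTorch. The function name, signature, and the switching threshold of 10 are illustrative choices; the post's actual implementation is linked at the end.

```python
import torch

def zt_poisson_nll(lam: torch.Tensor, k: torch.Tensor, switch: float = 10.0) -> torch.Tensor:
    """NLL of the zero-truncated Poisson, up to the constant log(k!) term.

    For lam > switch, log(e^lam - 1) is replaced by its asymptote lam,
    which equals falling back to the default Poisson NLL and avoids
    overflow in the exponential.
    """
    # Clamp inside the unselected branch so torch.where never sees inf/nan.
    log_norm = torch.where(
        lam > switch,
        lam,
        torch.log(torch.expm1(lam.clamp(max=switch))),
    )
    return -k * torch.log(lam) + log_norm

# Illustrative usage with predicted means and observed non-zero counts:
print(zt_poisson_nll(torch.tensor([0.5, 2.0, 50.0]), torch.tensor([1.0, 3.0, 47.0])))
```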

The zero-truncated negative binomial distribution

The negative binomial distribution is given by:

$$f_{\text{NB}}(k;\mu,\theta) = \Pr(X = k) = \frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \left(\frac{\theta}{\theta+\mu}\right)^{\theta} \left(\frac{\mu}{\theta+\mu}\right)^{k}$$

characterizing the probability of observing a (true) count k, given an (estimated) mean μ and overdispersion θ. Its zero-truncated version is, then:

$$\begin{aligned}
f_{\text{ZT-NB}}(k;\mu,\theta) &= \Pr(X = k \mid X > 0) = \frac{f_{\text{NB}}(k;\mu,\theta)}{1 - f_{\text{NB}}(0;\mu,\theta)} \\
&= \frac{\frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \left(\frac{\theta}{\theta+\mu}\right)^{\theta} \left(\frac{\mu}{\theta+\mu}\right)^{k}}{1 - \frac{\Gamma(0+\theta)}{0!\,\Gamma(\theta)} \left(\frac{\theta}{\theta+\mu}\right)^{\theta} \left(\frac{\mu}{\theta+\mu}\right)^{0}} \\
&= \frac{\frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \left(\frac{\theta}{\theta+\mu}\right)^{\theta} \left(\frac{\mu}{\theta+\mu}\right)^{k}}{1 - \left(\frac{\theta}{\theta+\mu}\right)^{\theta}} \\
&= \frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \frac{\left(\frac{\theta}{\theta+\mu}\right)^{\theta}}{1 - \left(\frac{\theta}{\theta+\mu}\right)^{\theta}} \left(\frac{\mu}{\theta+\mu}\right)^{k} \\
&= \frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \frac{\theta^{\theta}}{\left(1 - \frac{\theta^{\theta}}{(\theta+\mu)^{\theta}}\right)(\theta+\mu)^{\theta}} \left(\frac{\mu}{\theta+\mu}\right)^{k} \\
&= \frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \frac{\theta^{\theta}}{(\theta+\mu)^{\theta} - \theta^{\theta}} \left(\frac{\mu}{\theta+\mu}\right)^{k}
\end{aligned}$$
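
As before, a quick sanity check can compare this final form against renormalizing the default negative binomial pmf by $\Pr(X > 0) = 1 - \left(\frac{\theta}{\theta+\mu}\right)^{\theta}$. The mapping onto torch.distributions.NegativeBinomial (total_count $= \theta$, probs $= \mu/(\theta+\mu)$) is my own reading of the pmf above:

```python
import torch

mu, theta = torch.tensor(4.0), torch.tensor(2.0)  # arbitrary example values
k = torch.arange(1, 10, dtype=torch.float32)

# Default NB in torch's (total_count, probs) parameterization, renormalized
# by Pr(X > 0) = 1 - (theta / (theta + mu))^theta.
nb = torch.distributions.NegativeBinomial(total_count=theta, probs=mu / (theta + mu))
log_p0 = theta * (torch.log(theta) - torch.log(theta + mu))
log_renorm = nb.log_prob(k) - torch.log1p(-torch.exp(log_p0))

# Closed form derived above.
log_closed = (
    torch.lgamma(k + theta) - torch.lgamma(k + 1) - torch.lgamma(theta)
    + theta * torch.log(theta)
    - torch.log((theta + mu) ** theta - theta ** theta)
    + k * (torch.log(mu) - torch.log(theta + mu))
)

assert torch.allclose(log_renorm, log_closed, atol=1e-5)
```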

The negative log-likelihood of the zero-truncated Negative Binomial distribution is, hence:

$$\begin{aligned}
\text{NLL}_{\text{ZT-NB}} &= -\log\left(\frac{\Gamma(k+\theta)}{k!\,\Gamma(\theta)} \frac{\theta^{\theta}}{(\theta+\mu)^{\theta} - \theta^{\theta}} \left(\frac{\mu}{\theta+\mu}\right)^{k}\right) \\
&= -\log(\Gamma(k+\theta)) + \log(k!) + \log(\Gamma(\theta)) - \theta\log(\theta) + \log\left((\theta+\mu)^{\theta} - \theta^{\theta}\right) - k\log(\mu) + k\log(\theta+\mu)
\end{aligned}$$

Similarly, when estimating $\mu$ and $\theta$ for an observed count $k$, $\log(k!)$ can be ignored. Further, if $\mu$ and $\theta$ are both large, the term $\log\left((\theta+\mu)^{\theta} - \theta^{\theta}\right)$ becomes numerically unstable. As with the Poisson case, we can exchange this term for $\theta\log(\theta+\mu)$, since $(\theta+\mu)^{\theta} - \theta^{\theta} \approx (\theta+\mu)^{\theta}$ in that regime; doing so yields the NLL of the default negative binomial distribution. The following heatmaps show the NLL values for differing values of the predicted overdispersion $\theta$ and the predicted mean $\mu$. The observed count $k$ is fixed at a different level in each row of the plot. The first and second columns show NLL values for the zero-truncated NB and the “default” NB, respectively.

[Figure: NLL heatmaps for the zero-truncated NB (first column) and the default NB (second column), over predicted overdispersion $\theta$ and predicted mean $\mu$, with the observed count $k$ fixed per row.]

One can see that the “default” NB NLL becomes a good approximation of the zero-truncated NB NLL at high values of $\mu$ and, to a lesser extent, large values of $\theta$. In our preprint, we used the approximation when either $\mu + \theta > 15$ or $(\mu - 1)(\theta - 1) > 15$.
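
A sketch of what the corresponding loss could look like in PyTorch, wiring in that switching rule (the function name is mine, and the exact normalizer is rewritten in an equivalent, overflow-free form; the post's actual implementation is linked below):

```python
import torch

def zt_nb_nll(mu: torch.Tensor, theta: torch.Tensor, k: torch.Tensor,
              switch: float = 15.0) -> torch.Tensor:
    """NLL of the zero-truncated negative binomial, up to the log(k!) term.

    Falls back to the default NB normalizer theta*log(theta + mu) when
    mu + theta > switch or (mu - 1)*(theta - 1) > switch.
    """
    log_tm = torch.log(theta + mu)
    # Exact normalizer log((theta+mu)^theta - theta^theta), rewritten as
    # theta*log(theta+mu) + log(1 - (theta/(theta+mu))^theta) so that the
    # unselected branch of torch.where stays finite for large inputs.
    exact = theta * log_tm + torch.log1p(
        -torch.exp(theta * (torch.log(theta) - log_tm))
    )
    use_default = ((mu + theta) > switch) | ((mu - 1.0) * (theta - 1.0) > switch)
    log_norm = torch.where(use_default, theta * log_tm, exact)
    return (
        -torch.lgamma(k + theta) + torch.lgamma(theta)
        - theta * torch.log(theta) + log_norm
        - k * torch.log(mu) + k * log_tm
    )

# Illustrative usage with predicted parameters and observed non-zero counts:
print(zt_nb_nll(torch.tensor([0.8, 30.0]), torch.tensor([2.0, 5.0]), torch.tensor([1.0, 25.0])))
```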

The code for using these as loss functions within PyTorch can be found here.


References and footnotes

  1. Just like we do not need to explicitly tell a transformer which words were absent from a sentence, it might not be necessary to communicate to a transformer that some genes were inactive. This can be conveyed implicitly, simply by the corresponding tokens not being present. ↩︎