1 Introduction

The evaluation of high-dimensional expectations under unnormalized distributions is one of the main challenges of statistics and machine learning [4]. In Bayesian decision theory, for instance, one seeks to minimize a risk function that is defined as the expectation of a functional under an often intractable and unnormalized posterior distribution provided by Bayes’ rule [21]. Similarly, many model-free reinforcement learning algorithms for searching for an optimal policy in a contextual Markov decision process rely on integrating out a high-dimensional context variable of the quality function. Under these circumstances, a common approach consists of replacing the unmanageable integration problem with a manageable optimization problem, which is known as variational inference. For this, one introduces a large and tractable set of parametric distributions, known as the variational family, and finds the member of the family that most closely matches the unnormalized target by solving a stochastic optimization problem. Then, one uses the solution to this problem as a surrogate for the intractable target and evaluates the desired expectations using Monte Carlo estimators.

Recently, Generative Flow Networks (GFlowNets) [2, 3, 13] were proposed as a flexible variational family parameterized by neural networks to approximate unnormalized distributions with compositional supports [15], with remarkable success in relevant applications including causal inference [5, 24], drug discovery [2], refinement of large language models [9], and combinatorial optimization [28]. To implement a GFlowNet, we first define a flow network over an extension of the target distribution’s support, which thereafter corresponds to the sink nodes, and then search for a consistent flow assignment for which the flow going through each sink node matches its unnormalized probability. During inference, we travel across the defined network by starting at the origin and choosing each transition according to the estimated flow therein. If the flow is correctly estimated, this procedure is guaranteed to yield independent and identically distributed (i.i.d.) samples from the intractable target distribution [2, 13].

To successfully find a consistent and balanced flow assignment, we conventionally parameterize the network’s transition policies with versatile models such as multilayer perceptrons or graph neural networks. Then, to estimate the model’s parameters, we minimize a loss function consisting of the expectation of log-squared local balance violations under a prescribed exploratory policy via a variant of stochastic gradient descent (SGD). Importantly, this exploratory policy may differ from the learned policy, in which case the training is said to be carried out in an off-policy fashion. We remark that, when dealing with high-dimensional and highly sparse target distributions, an off-policy sampling scheme is paramount for avoiding the collapse of the variational distribution [14, 15]. Nonetheless, as we note in Sect. 2, off-policy learning comes at the cost of introducing additional parameters that often entail the estimation of high-dimensional integrals as a subproblem of training GFlowNets, potentially hindering the sample efficiency of these models.

Seeking to bypass this problem, recent work [23, 26] developed a contrastive learning objective for GFlowNets. In the spirit of Hinton’s contrastive divergence learning [8], these approaches sidestep the estimation of difficult-to-approximate quantities by minimizing the difference between optimality criteria instead of optimizing each criterion directly. This ensures that the intractable terms cancel out and that the resulting problem is computationally easier to solve. Notably, however, these works are constrained to sampling from unnormalized distributions with finite support; a formal extension of these objectives to the non-discrete setting, together with an empirical validation, is lacking in the literature. In this context, we extend both the contrastive balance condition and the contrastive loss function [23] to train GFlowNets defined on measurable topological spaces [13]. In addition, we empirically show that the resulting continuous contrastive loss (CCL) frequently leads to faster training convergence than competing learning objectives for the standard tasks of approximating a sparse mixture of multivariate Gaussian distributions and a banana-shaped distribution.

In summary, our contributions are:

  1. We derive a contrastive balance condition for continuous GFlowNets and rigorously show that it is sufficient for ensuring sampling correctness;

  2. We propose the continuous contrastive loss (CCL) as a theoretically sound learning objective for GFlowNets that does not rely on the estimation of high-dimensional integrals via SGD during training;

  3. We empirically compare the performance of CCL against alternative loss functions and show that CCL leads to faster convergence in most cases.

The paper is organized as follows. In Sect. 2, we recall the notions of a measurable pointed directed acyclic graph (DAG), transition kernels and flow networks, laying out the theoretical framework upon which GFlowNets are built, and also review popular learning objectives for continuous GFlowNets. Then, in Sect. 3, we rigorously derive the contrastive balance condition and propose CCL as a sound learning objective for GFlowNets, the performance of which is assessed in Sect. 4 for the approximation of a mixture of multivariate Gaussians and of a banana-shaped distribution [16]. Finally, we provide a summary of the relevant literature in Sect. 5 and a discussion of future directions, along with some concluding remarks, in Sect. 6.

2 Background

This section reviews the main components of GFlowNets and their commonly used learning objectives. To build intuition, we first define a GFlowNet for finitely supported distributions in Sect. 2.1. Then, we review the formal framework of Lahlou et al. [13] for extending the notion of a flow network to arbitrary topological spaces by replacing the adjacency matrix with a transition kernel that describes the continuous network’s connectivity.

2.1 GFlowNets in Discrete Spaces

Discrete GFlowNets. Let R be a (potentially unnormalized) probability distribution on a finite space \(\mathcal {X}\) and \(S \supseteq \mathcal {X}\) be an extension of \(\mathcal {X}\) with two distinguished elements, \(s_{o}\) and \(s_{f}\). Our objective is to randomly sample elements \(x \in \mathcal {X}\) proportionally to R(x). For this, we let \(\mathcal {G} = (S, E)\) be a DAG with edges E such that (i) the only nodes connected to \(s_{f}\) are those in \(\mathcal {X}\), (ii) there is a path from \(s_{o}\) to every \(x \in \mathcal {X}\), and (iii) there are no outgoing edges from \(s_{f}\) and no incoming edges to \(s_{o}\). We call \(\mathcal {G}\) the state graph. In this setting, let \(p_{F} :S \times S \rightarrow \mathbb {R}_{+}\) be a forward policy on \(\mathcal {G}\), i.e., a function such that \(p_{F}(s, \cdot )\) is a probability measure on S supported on the children of s in \(\mathcal {G}\), and let \(p_{F}(\tau ) = \prod _{1 \le i \le n} p_{F}(s_{i} | s_{i - 1})\) be the corresponding distribution over trajectories \(\tau = (s_{o}, \dots , s_{n - 1}, s_{n} = s_{f})\). To accomplish our objective, we must find a \(p_{F}\) for which

$$\begin{aligned} \sum _{\tau :\tau \rightsquigarrow x} p_{F}(\tau ) \propto R(x). \end{aligned}$$
(1)

In practice, the left-hand side of Eq. (1) is computationally intractable and is satisfied by a possibly infinite number of \(p_{F}\)’s. To ensure tractability and uniqueness, we also introduce a backward policy \(p_{B} :S \times S \rightarrow \mathbb {R}_{+}\), which is a forward policy over the transpose of \(\mathcal {G}\), and rewrite Eq. (1) as \(p_{F}(\tau ) \propto p_{B}(\tau | x) R(x)\). Malkin et al. [14] showed that this technique induces a well-posed problem with a unique solution.
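To make Eq. (1) concrete, the following minimal sketch builds a forward policy on a small, hypothetical state graph (the node names, the rewards, and the helper functions are illustrative assumptions, not part of the original setup). It propagates R backwards under a uniform backward policy, normalizes the resulting edge flows into a forward policy, and checks by exhaustive enumeration that the marginal over terminal states is proportional to R. The final transition into \(s_{f}\) is omitted because it happens with probability one.

```python
from fractions import Fraction

# Hypothetical toy state graph: s_o -> {a, b}; a -> {x1, x2}; b -> {x2, x3}.
children = {"s_o": ["a", "b"], "a": ["x1", "x2"], "b": ["x2", "x3"]}
# Unnormalized target over the terminal states (illustrative values).
R = {"x1": Fraction(1), "x2": Fraction(2), "x3": Fraction(3)}


def parents(node):
    return [s for s, kids in children.items() if node in kids]


def state_flows():
    """Propagate R backwards, splitting each flow uniformly among parents."""
    flow = dict(R)
    edge_flow = {}
    for node in ["x1", "x2", "x3", "a", "b"]:  # reverse topological order
        pars = parents(node)
        for par in pars:
            edge_flow[(par, node)] = flow[node] / len(pars)
            flow[par] = flow.get(par, Fraction(0)) + edge_flow[(par, node)]
    return flow, edge_flow


def forward_policy(flow, edge_flow):
    """p_F(child | state) = edge flow / state flow."""
    return {(s, c): edge_flow[(s, c)] / flow[s]
            for s in children for c in children[s]}


def terminal_marginal(p_f):
    """Sum p_F(tau) over all trajectories ending at each terminal state."""
    marginal = {x: Fraction(0) for x in R}

    def walk(state, prob):
        if state in R:
            marginal[state] += prob
            return
        for child in children[state]:
            walk(child, prob * p_f[(state, child)])

    walk("s_o", Fraction(1))
    return marginal


flow, edge_flow = state_flows()
p_f = forward_policy(flow, edge_flow)
marginal = terminal_marginal(p_f)
Z = flow["s_o"]  # total flow, i.e., the partition function of R
```

In this toy case the enumeration is cheap; the intractability mentioned above arises because the number of trajectories grows quickly with the size of the state graph.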

Training GFlowNets. To find a forward policy abiding by Eq. (1), we parameterize \(p_{F}\) with a neural network with parameters \(\theta \) and fix \(p_{B}(s, \cdot )\) as a uniform policy for each \(s \in S\). Then, by introducing an additional learnable function \(F :S \rightarrow \mathbb {R}_{+}\), which is parameterized by a (different) neural network with parameters \(\gamma \), we jointly estimate \(\theta \) and \(\gamma \) by minimizing one of the following objectives. Below, \(p_{e}\) is an exploratory policy, i.e., a forward policy that does not depend on \((\theta , \gamma )\) and has full support over the space of trajectories. Also, we denote by x (resp. \(x'\)) the state immediately preceding \(s_{f}\) in a trajectory \(\tau \) (resp. \(\tau '\)).

  1. The detailed balance (DB) loss [3] corresponds to \(\mathcal {L}_{DB}(\theta , \gamma ) =\)

    $$\begin{aligned} \begin{aligned} \mathbb {E}_{\tau \sim p_{e}}\left[ \frac{1}{|\tau |}\sum _{(s_{i - 1}, s_{i}) \in \tau } \left( \log \frac{F_{\gamma }(s_{i - 1})p_{F}(s_{i} | s_{i - 1})}{p_{B}(s_{i - 1} | s_{i}) F_{\gamma }(s_{i})}\right) ^{2} + \left( \log \frac{F_{\gamma }(x)}{R(x)} \right) ^{2} \right] \end{aligned} \end{aligned}$$

    which enforces the DB condition, \(F(s) p_{F}(s' | s) = F(s') p_{B}(s | s')\) for \((s, s') \in E\) and \(F(x) = R(x)\) for the terminal states \(x \in \mathcal {X}\).

  2. The trajectory balance (TB) loss [14] introduces \(\log Z(\gamma ) = \log F_{\gamma }(s_{o})\) as the log-partition function of R and is then defined by

    $$\begin{aligned} \mathcal {L}_{TB}(\theta , \gamma ) = \mathbb {E}_{\tau \sim p_{e}} \left[ \left( \log Z(\gamma ) + \log p_{F}(\tau ; \theta ) - \log p_{B}(\tau | x) - \log R(x) \right) ^{2} \right] , \end{aligned}$$

    enforcing the TB condition \(F(s_{o}) p_{F}(\tau ) = p_{B}(\tau | x) R(x)\).

  3. The contrastive balance (CB) loss [23] avoids parameterizing F by sampling a pair \((\tau , \tau ')\) of trajectories from \(p_{e}\), being defined by

    $$\begin{aligned} \mathcal {L}_{CB}(\theta ) = \mathbb {E}_{(\tau , \tau ') \sim p_{e}} \left[ \left( \log \frac{p_{F}(\tau )}{p_{B}(\tau | x) R(x)} - \log \frac{p_{F}(\tau ')}{p_{B}(\tau ' | x') R(x')} \right) ^{2} \right] , \end{aligned}$$
    (2)

    and ensuring the CB condition, \(p_{F}(\tau ) p_{B}(\tau '|x')R(x') = R(x) p_{F}(\tau ') p_{B}(\tau |x)\).
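The difference between the TB and CB objectives can be illustrated numerically. The sketch below is a toy example under stated assumptions: it hard-codes four hypothetical trajectories with their forward and backward probabilities (chosen so that the forward policy is balanced with partition function 6), and it takes the exploratory policy to be uniform over these trajectories so that the expectations become plain averages. The function names are illustrative only.

```python
import math
from itertools import product

# Hypothetical unnormalized target over three terminal states.
R = {"x1": 1.0, "x2": 2.0, "x3": 3.0}

# Each entry: (terminal state x, p_F(tau), p_B(tau | x)) for one trajectory.
# The values describe a balanced forward policy on a small two-level DAG.
trajectories = [
    ("x1", (1 / 3) * (1 / 2), 1.0),
    ("x2", (1 / 3) * (1 / 2), 0.5),
    ("x2", (2 / 3) * (1 / 4), 0.5),
    ("x3", (2 / 3) * (3 / 4), 1.0),
]


def log_beta(traj):
    """log of p_F(tau) / (p_B(tau | x) R(x)); no partition function needed."""
    x, p_f, p_b = traj
    return math.log(p_f) - math.log(p_b) - math.log(R[x])


def tb_loss(log_z):
    """TB loss under a uniform exploratory policy; depends on log Z."""
    return sum((log_z + log_beta(t)) ** 2 for t in trajectories) / len(trajectories)


def cb_loss():
    """CB loss of Eq. (2) under a uniform exploratory policy over pairs."""
    pairs = list(product(trajectories, repeat=2))
    return sum((log_beta(t) - log_beta(u)) ** 2 for t, u in pairs) / len(pairs)


loss_cb = cb_loss()
loss_tb_exact = tb_loss(math.log(6.0))  # log Z set to its true value
loss_tb_wrong = tb_loss(0.0)            # log Z poorly estimated
```

For this balanced policy the CB loss vanishes (up to floating-point error) without any estimate of the partition function, whereas the TB loss is zero only when \(\log Z\) is also correct.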

Notably, the learning of discrete GFlowNets is widely studied, and many strategies for enhancing training convergence were proposed [3, 11, 17, 18, 22, 25, 28, 30]. In this work, we focus on extending the CB objective to the continuous setting.

2.2 GFlowNets in Continuous Spaces

Notations. We let \((\mathcal {X}, T_{\mathcal {X}}, \varSigma _{\mathcal {X}}, \mu )\) be a topological measurable space with reference measure \(\mu \), topology \(T_{\mathcal {X}}\), \(\sigma \)-algebra \(\varSigma _{\mathcal {X}}\), and underlying space \(\mathcal {X}\), and consider a non-negative function \(r :\mathcal {X} \rightarrow \mathbb {R}_{+}\) over \(\mathcal {X}\) such that \(\int _{\mathcal {X}} r(x) \mu (\textrm{d}x) < \infty \). In what follows, we will assume that each measure, denoted by a capital letter, has a density with respect to the reference measure \(\mu \), which will be denoted by the corresponding lowercase letter. Under these conditions, our objective is to learn a generative process over \(\mathcal {X}\) that samples each measurable subset A proportionally to \(\int _{A} r(x) \mu (\textrm{d}x)\), and we define the target measure \(R :\varSigma _{\mathcal {X}} \rightarrow \mathbb {R}_{+}\) as \(R(A) = \int _{A} r(x) \mu (\textrm{d}x)\), which is absolutely continuous with respect to \(\mu \). Additionally, for a given set S and a corresponding \(\sigma \)-algebra \(\varSigma _{S}\) (which will be specified later), we define a Markov kernel on S as a function \(\kappa :S \times \varSigma _{S} \rightarrow \mathbb {R}_{+}\) such that \(\kappa (s, \cdot )\) is a measure on \((S, \varSigma _{S})\). Also, we denote by \(S^{\otimes n}\) the nth-order Cartesian product of S and by \(\varSigma _{S}^{\otimes n}\) the product \(\sigma \)-algebra of \(\varSigma _{S}\). A transition kernel \(\kappa \) on S is naturally extended to a transition kernel \(\kappa ^{\otimes n} :S \times \varSigma _{S}^{\otimes n} \rightarrow \mathbb {R}_{+}\) on \(S^{\otimes n}\) recursively via \(\kappa ^{\otimes n}(s, (A, A_{n})) = \kappa ^{\otimes (n - 1)}(s, A) \kappa (s, A_{n})\), with \(A \in \varSigma _{S}^{\otimes (n - 1)}\), \(A_{n} \in \varSigma _{S}\), \(s \in S\), which exists uniquely due to Carathéodory’s extension theorem.
Finally, we denote by \(S^{\otimes \le N} = \bigcup _{1 \le i \le N} S^{\otimes i}\) and \(\varSigma _{S}^{\otimes \le N} = \sigma \left( S^{\otimes \le N}\right) \) the \(\sigma \)-algebra generated by the space of up-to-Nth order product \(\sigma \)-algebras.

Measurable Pointed DAGs (MP-DAG). We first present the definition of an MP-DAG, originally proposed by Lahlou et al. [13], which formally extends the concept of a flow network to a potentially non-countable state space.

Definition 1 (MP-DAG)

Let \((\bar{S}, \mathcal {T})\) be a topological space and \(\varSigma \) be the Borel \(\sigma \)-algebra generated by \(\mathcal {T}\). Define the source \(s_{o} \in \bar{S}\) and sink \(s_{f} \in \bar{S}\) states. Also, let \(\kappa \), the reference kernel, and \(\kappa _{b}\), the backward kernel, be \(\sigma \)-finite and continuous Markov kernels on \((\bar{S}, \varSigma )\); and let \(\mu \) be a reference measure on \(\varSigma \). Also, let \(S = \bar{S} \setminus \{s_{f}\}\). A measurable pointed DAG is a tuple \(\mathcal {G} = (\bar{S}, \mathcal {T}, \varSigma , \kappa , \kappa _{b}, \mu )\) such that

  1. (finality) \(\kappa (s_{f}, \cdot ) = \delta _{s_{f}}\) and, if \(\kappa (s, \{s_{f}\}) > 0\) for some \(s \in \bar{S}\), then \(\kappa (s, \{s_{f}\}) = 1\);

  2. (reachability) \(\forall B \in \mathcal {T} \setminus \{\emptyset \}\), \(\exists n \ge 0\) with \(\kappa ^{\otimes n}(s_{o}, B) > 0\);

  3. (initialness) \(\forall B \in \varSigma \), \(\kappa _{b}(s_{o}, B) = 0\); and

  4. (consistency) \(\forall B \in \varSigma \times \varSigma \), \(\mu \otimes \kappa (B) = \mu \otimes \kappa _{b}(B)\).

Also, an MP-DAG is said to be finitely absorbing if there is an \(N \in \mathbb {N}\) for which \(\kappa ^{\otimes N}(s_{o}, \{s_{f}\}) > 0\) and, by the finality property, \(\kappa ^{\otimes N}(s_{o}, \{s_{f}\}) = 1\). In other words, any iterative generative process starting at \(s_{o}\) and sampling novel states according to \(\kappa \) will reach \(s_{f}\) in at most N steps. We call N the maximum trajectory length of the MP-DAG, and define \(\mathcal {X} = \{x \in S :\kappa (x, \{s_{f}\}) > 0\}\) as the set of terminal states. Under these circumstances, the set \(\{s_{o}\} \times S^{\otimes \le N - 1} \times \{s_{f}\}\) fully characterizes the trajectories starting at \(s_{o}\), and we will denote it by \(\mathbb {T}\); also, we denote by \(\mathbb {T}' = \{s_{o}\} \times (S \setminus \mathcal {X})^{\otimes \le N - 1}\) the space of trajectories without terminal states. We highlight that, when S is finite and \(\mu \) is set as the counting measure on S, the definition above recovers the traditional flow network outlined in Sect. 2.1 for the discrete setting, with \(\{\kappa (s, \{s'\})\}_{(s, s') \in S \times S}\) characterizing the network’s adjacency matrix and N denoting its (directed) diameter.

Continuous GFlowNets. Similarly to its discrete counterpart, we characterize a continuous GFlowNet by an MP-DAG \(\mathcal {G} = (\bar{S}, \mathcal {T}, \varSigma , \kappa , \kappa _{b}, \mu )\), a measure R on \(\mathcal {X}\), a forward policy \(P_{F} :\bar{S} \times \varSigma _{\bar{S}} \rightarrow \mathbb {R}_{+}\), defined as a Markov kernel that is absolutely continuous with respect to \(\kappa \), and a backward policy \(P_{B} :\bar{S} \times \varSigma _{\bar{S}} \rightarrow \mathbb {R}_{+}\), defined as a Markov kernel that is absolutely continuous with respect to \(\kappa _{b}\). Our objective is to find a \(P_{F}\) such that the marginal distribution over \(\mathcal {X}\) of the \(P_{F}\)-Markov chain deterministically starting at \(s_{o}\) matches R, that is, informally,

$$\begin{aligned} \int _{\mathbb {T}} \mathbbm {1}_{\{\tau \text { finishes at } A\}} P_{F}^{\otimes N}(s_{o}, \textrm{d}\tau ) \propto R(A) \end{aligned}$$
(3)

for every measurable subset A of \(\mathcal {X}\).

Training GFlowNets. To achieve this objective, we let \(p_{F}\) and \(p_{B}\) be the densities of \(P_{F}\) and \(P_{B}\) with respect to \(\kappa \) and \(\kappa _{b}\), respectively, and parameterize \(p_{F}\) with a neural network with parameters \(\theta \). To avoid notational overload, we also denote by \(p_{F}\) and \(p_{B}\) the densities of \(P_{F}^{\otimes i}\) and \(P_{B}^{\otimes i}\) with respect to the corresponding product kernels. Analogously to discrete GFlowNets, we define an auxiliary function \(u :\bar{S} \rightarrow \mathbb {R}_{+}\) parameterized by a neural network with parameters \(\gamma \). Additionally, we introduce an exploratory policy \(P_{E} :\bar{S} \times \varSigma _{\bar{S}} \rightarrow \mathbb {R}_{+}\) that is absolutely continuous with respect to the reference kernel \(\kappa \). Under these conditions, Lahlou et al. [13] showed that the trajectory and detailed balance losses of Sect. 2.1 are sound objectives for learning continuous GFlowNets when we substitute the transition kernels by the corresponding densities. So, for instance, we let \(Z(\gamma ) = u_{\gamma }(s_{o})\) and then

$$\begin{aligned} \mathcal {L}_{TB}(\theta , \gamma ) = \mathbb {E}_{\tau \sim P_{E}(s_{o}, \cdot )} \left[ \left( \log Z(\gamma ) + \log p_{F}(\tau ) - \log p_{B}(\tau | x) - \log r(x) \right) ^{2} \right] \end{aligned}$$
(4)

becomes the continuous equivalent of the TB loss. Notably, both the TB and DB losses rely on the estimation of the auxiliary function u, which often corresponds to a high-dimensional integral (such as the partition function) that is difficult to compute and has no inferential purpose.

Illustration. To clarify the above definitions and emphasize their versatility, we show how to instantiate a continuous GFlowNet to approximate an arbitrary distribution in a Euclidean space \(\mathbb {R}^{n}\), an example that will be thoroughly explored in our experimental campaign in Sect. 4. For this setting, we may define \(S = \mathbb {R}^{n}\) and \(\mathcal {X} = \{(x, \top ) :x \in \mathbb {R}^{n}\}\), with \(\top \) artificially distinguishing the elements of \(\mathcal {X}\) from those of S. Then, \(s_{o} = \textbf{0}\) and a transition corresponds to \(x^{t + 1} = x^{t} + \alpha e^{t}\), with \(\alpha \sim Q_{\theta }\) sampled from an appropriate distribution \(Q_{\theta }\) whose parameters are estimated by a neural network that receives \(x^{t}\) as input, and with \(e^{t}\) denoting the tth row of the identity matrix. Critically, this iterative generative process induces a finitely absorbing MP-DAG satisfying all the desired properties when we define \(\kappa \) to satisfy \(\kappa (x, \cdot ) = \delta _{s_{f}}(\cdot )\) for every \(x \in \mathcal {X}\) and \(\kappa _{b}(x^{t}, \cdot ) = \delta _{x^{t} - x_{t}^{t} e^{t}}(\cdot )\).
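The coordinate-wise generative process described above can be sketched in a few lines. This is a minimal illustration rather than the authors' implementation: the function names are hypothetical, coordinates are indexed from zero, and `gaussian_proposal` is a hand-written stand-in for the neural network that would output the parameters of \(Q_{\theta }\).

```python
import random


def sample_trajectory(n, propose, rng):
    """Build a point in R^n one coordinate at a time, starting at the origin."""
    state = [0.0] * n
    trajectory = [tuple(state)]
    for t in range(n):
        alpha = propose(state, t, rng)  # alpha ~ Q_theta(. | x^t)
        state[t] += alpha               # x^{t+1} = x^t + alpha * e^t
        trajectory.append(tuple(state))
    return trajectory


def backward_step(state, t):
    """Deterministic backward move: reset coordinate t to zero."""
    prev = list(state)
    prev[t] = 0.0
    return tuple(prev)


def gaussian_proposal(state, t, rng):
    # Assumption: a toy proposal whose mean depends on the previous
    # coordinate; a learned model would replace this function.
    mean = 0.5 * state[t - 1] if t > 0 else 0.0
    return rng.gauss(mean, 1.0)


rng = random.Random(0)
traj = sample_trajectory(3, gaussian_proposal, rng)
```

Because each forward step modifies a single, known coordinate, the backward kernel is a Dirac measure, so the backward density contributes no learnable term in this construction.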

3 A Contrastive Objective

We now present our main theoretical results regarding the soundness of the contrastive balance objective for training GFlowNets in a non-countable state space. First, Sect. 3.1 defines the CB condition as a provably sufficient requirement for ensuring the sampling correctness of GFlowNets. Second, Sect. 3.2 derives the continuous contrastive loss (CCL) as a realization of the CB condition that, when minimized, ensures that the resulting model is correct almost everywhere. Finally, we analyze the connection between the CCL and the popular variance loss [20] for carrying out generic VI.

3.1 Contrastive Balance

Contrastive Balance. To get a finer understanding of the contrastive balance condition, we first remark that the TB condition may be equivalently stated as

$$\begin{aligned} \text {Var}\left[ \frac{p_{F}(\tau )}{p_{B}(\tau | x) r(x)} \right] = 0, \end{aligned}$$
(5)

i.e., the quotient \(\beta (\tau ) = \nicefrac {p_{F}(\tau )}{p_{B}(\tau | x) r(x)}\) is constant as a function of \(\tau \). Thus, if \(\beta (\tau ) = \beta (\tau ')\) for any pair of trajectories \((\tau , \tau ')\), the TB condition must be satisfied and the GFlowNet samples correctly from the target. Importantly, \(\beta \) does not depend on intractable quantities, and enforcing the equality \(\beta (\tau ) = \beta (\tau ')\), which is the essence of Definition 2 below, avoids the estimation of high-dimensional integrals such as \(\log Z(\gamma )\) in Eq. (4). Next, we formally define the contrastive balance condition. We slightly abuse notation by writing \(P_{B}(x, \textrm{d}\tau )\) for the product kernel \(P_{B}(x, \textrm{d}s_{n}) \otimes P_{B}(s_{n}, \textrm{d}s_{n - 1}) \otimes \cdots \otimes P_{B}(s_{1}, \textrm{d}s_{o})\) associated with \(\tau = (s_{o}, s_{1}, \dots , s_{n}, x)\).

Definition 2 (Contrastive balance)

Let \((\mathcal {G}, P_{F}, P_{B}, R, \mu )\) be a GFlowNet. The contrastive balance (CB) condition is defined by

$$\begin{aligned} \begin{aligned} \int _{\mathcal {X}} \int _{\mathcal {X}} \int _{\mathbb {T}'} \int _{\mathbb {T}'} f(\tau , x, \tau ', x') P_{F}(s_{o}, \textrm{d}\tau \textrm{d}x) P_{B}(x', \textrm{d}\tau ') R(\textrm{d}x') \\ = \int _{\mathcal {X}} \int _{\mathcal {X}} \int _{\mathbb {T}'} \int _{\mathbb {T}'} f(\tau , x, \tau ', x') R(\textrm{d}x) P_{B}(x, \textrm{d}\tau ) P_{F}(s_{o}, \textrm{d}\tau '\textrm{d}x') \end{aligned} \end{aligned}$$
(6)

for every bounded measurable function \(f :\mathbb {T}' \times \mathcal {X} \times \mathbb {T}' \times \mathcal {X} \rightarrow \mathbb {R}\).

For illustration, when S is finite, the reference measure is the counting measure, and we let f range over the indicator functions of elements in the space \(\mathbb {T}' \times \mathcal {X} \times \mathbb {T}' \times \mathcal {X}\), the condition above reduces to the discrete contrastive balance condition, namely, \(p_{F}(\tau , x) p_{B}(x', \tau ') r(x') = r(x) p_{B}(x, \tau ) p_{F}(\tau ', x')\) for every pair \((\tau , \tau ')\) of trajectories and \((x, x')\) of terminal states. In light of this, Definition 2 generalizes the discrete CB condition to a significantly broader setting.

Sufficiency of CB. We rigorously show in the theorem below that the continuous variant of the CB condition outlined in Definition 2 is sufficient for ensuring that the corresponding GFlowNet generates samples distributed according to the target measure R. Note, for this, that we may rewrite Eq. (6) in terms of the densities of the corresponding measures.

Theorem 1

Let \((\mathcal {G}, P_{F}, P_{B}, R, \mu )\) be a GFlowNet abiding by the contrastive balance condition. Then, the marginal distribution over \(\mathcal {X}\) induced by the \(P_{F}\)-Markov chain starting at \(s_{o}\) matches R.

Proof

To start with, we define \(t :\mathbb {T} \rightarrow \mathcal {X}\) as the function taking a complete trajectory and returning its terminal state. By the definition of \(\mathcal {X}\), t is well-defined almost everywhere with respect to the measure \(\kappa (s_{o}, \cdot )\) over trajectories. Then, we show that Eq. (6) is enough to ensure, for all measurable \(A \subseteq \mathcal {X}\), that

$$\begin{aligned} \int _{\mathbb {T}} \mathbbm {1}_{\{t(\tau ) \in A\}} P_{F}(s_{o}, \textrm{d}\tau ) \propto R(A), \end{aligned}$$
(7)

i.e., the marginal distribution of \(P_{F}\) on \(\mathcal {X}\) matches R. Under these conditions, we first note that, when the function f does not depend on the last two parameters, Eq. (6) becomes

$$\begin{aligned} \begin{aligned} \int _{\mathcal {X}} \int _{\mathbb {T}'} f(\tau , x) P_{F}(s_{o}, \textrm{d}\tau \textrm{d}x) \underbrace{\left( \int _{\mathcal {X}} \int _{\mathbb {T}'} R(\textrm{d}x') P_{B}(x', \textrm{d}\tau ') \right) }_{= Z} \\ = \underbrace{\left( \int _{\mathcal {X}} \int _{\mathbb {T}'} P_{F}(s_{o}, \textrm{d}\tau ' \textrm{d}x') \right) }_{= 1} \int _{\mathcal {X}} \int _{\mathbb {T}'} f(\tau , x) R(\textrm{d}x) P_{B}(x, \textrm{d}\tau ), \end{aligned} \end{aligned}$$
(8)

which recovers Lahlou et al.’s TB condition [13, Definition 6]. Consequently, the result follows from [13, Theorem 2, (4)]. For completeness, however, it is easy to see that Eq. (8) ensures that \(P_{F}(s_{o}, \cdot )\) and \(R(\textrm{d}x) P_{B}(x, \cdot )\), when interpreted as distributions over \(S^{\otimes \le N - 1} \times \{s_{f}\}\), are equivalent in terms of expectations of bounded measurable functions and, thus, must be indistinguishable up to a multiplicative constant. Hence, by integrating out the elements of \(\tau \) that are not members of \(\mathcal {X}\), we obtain Eq. (7).

Finally, before we delve into the algorithmic aspects of contrastive learning of GFlowNets, we again emphasize that the result above is general enough to cover distributions supported on continuous, discrete, and mixed spaces.

3.2 Contrastive Loss

Contrastive Loss. Realistically, finding a policy \(P_{F}\) satisfying the CB condition is an analytically intractable problem. As a consequence, we parameterize \(P_{F}\) with a neural network with weights \(\theta \) and estimate the optimal \(\theta ^{\star }\) by minimizing a loss function that provably enforces Eq. (6). The corollary below, which is an immediate consequence of Theorem 1, provides such a loss function.

Corollary 1

(Continuous Contrastive Loss). Under the notations of Theorem 1, let \(p_{F}\), \(p_{B}\) and r be the densities of \(P_{F}\), \(P_{B}\), and R relatively to \(\kappa \), \(\kappa _{b}\) and \(\mu \), respectively. Then, we define the continuous contrastive loss (CCL) as

$$\begin{aligned} \mathcal {L}_{CB}(\theta ) = \mathbb {E}_{\tau , \tau ' \sim P_{E}(s_{o}, \cdot ) \otimes P_{E}(s_{o}, \cdot )} \left[ \left( \frac{p_{F}(\tau )}{p_{B}(\tau | x) r(x)} - \frac{p_{F}(\tau ')}{p_{B}(\tau ' | x') r(x')}\right) ^{2} \right] , \end{aligned}$$
(9)

in which \(P_{E}(s_{o}, \cdot )\) is an exploratory policy with full support over the trajectories on \(\varSigma _{S}^{\otimes \le N}\) starting at \(s_{o}\) and we use x (resp. \(x'\)) to denote the unique terminal state within \(\tau \) (resp. \(\tau '\)). Then, if \(\mathcal {L}_{CB}(\theta ^{\star }) = 0\), the GFlowNet parameterized by \(\theta ^{\star }\) abides by Eq. (8) and samples correctly from the target.

Intuitively, the condition \(\mathcal {L}_{CB}(\theta ) = 0\) ensures that the quotients \(\beta (\tau )\) defined in Sect. 3.1 are constant \(\kappa (s_{o}, \cdot )\)-almost everywhere, so that the hypothesis of Theorem 1 holds, which ultimately guarantees the GFlowNet’s distributional correctness.

Algorithm 1

Contrastive learning of GFlowNets

Algorithm. Algorithm 1 illustrates the training of GFlowNets through the minimization of the CCL. In practice, both the sampling and the neural network forward passes can be massively parallelized. Also, for the optimization step, we use Adam [12] to update the parameters from the stochastic gradients. Importantly, the architecture of the parameterizing neural network with weights \(\theta \) and the nature of the exploratory policy \(P_{E}\) are highly problem-dependent and are abstracted away from Algorithm 1; we provide some examples in the next section.
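As a rough, self-contained illustration of this training loop, the sketch below fits a one-step GFlowNet on the real line. It is a toy under several assumptions that are not taken from the paper: the state graph is \(s_{o} \rightarrow x \rightarrow s_{f}\) (so the backward policy is trivial), the forward policy is a single Gaussian with learnable mean and log-standard-deviation instead of a neural network, the loss uses differences of log-ratios as in Eq. (2), gradients are computed by hand, and plain SGD replaces Adam. The target, the exploratory policy, and all names are hypothetical.

```python
import math
import random


def log_r(x):
    # Unnormalized log-density of the target N(2, 1); the constant log(5)
    # plays the role of an unknown normalizer and cancels in the loss.
    return math.log(5.0) - 0.5 * (x - 2.0) ** 2


def log_pf(x, m, log_s):
    """Log-density of the Gaussian forward policy N(m, exp(log_s)^2)."""
    s = math.exp(log_s)
    return -log_s - 0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) / s) ** 2


def grad_log_pf(x, m, log_s):
    """Gradient of log_pf with respect to (m, log_s)."""
    s2 = math.exp(2 * log_s)
    return (x - m) / s2, (x - m) ** 2 / s2 - 1.0


def train_ccl(steps=4000, batch=64, lr=1e-3, seed=0):
    rng = random.Random(seed)
    m, log_s = 0.0, 0.0
    for _ in range(steps):
        gm = gl = 0.0
        for _ in range(batch):
            # Off-policy pair drawn from a fixed exploratory policy N(0, 2^2).
            x, xp = rng.gauss(0.0, 2.0), rng.gauss(0.0, 2.0)
            # Contrast of log-ratios log beta(x) - log beta(x'); no log Z.
            delta = (log_pf(x, m, log_s) - log_r(x)) \
                - (log_pf(xp, m, log_s) - log_r(xp))
            dm, dl = grad_log_pf(x, m, log_s)
            dmp, dlp = grad_log_pf(xp, m, log_s)
            gm += 2 * delta * (dm - dmp)
            gl += 2 * delta * (dl - dlp)
        m -= lr * gm / batch
        log_s -= lr * gl / batch
    return m, math.exp(log_s)


m_hat, s_hat = train_ccl()
```

In this setting the learned mean and standard deviation should approach those of the target (2 and 1), even though samples come from a mismatched exploratory policy and the normalizer is never estimated.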

Connection to Other Losses. To conclude, we remark that the CB loss is tightly connected to the variance loss for variational inference [20, 26]; indeed, the expected squared difference between two independent and identically distributed random variables equals twice their variance: expanding the square gives \(\mathbb {E}_{\tau , \tau '}[(\beta (\tau ) - \beta (\tau '))^{2}] = 2\mathbb {E}_{\tau }[\beta (\tau )^{2}] - 2\mathbb {E}_{\tau }[\beta (\tau )]^{2} = 2 \text {Var}_{\tau }[\beta (\tau )]\), so both losses share the same minimizers.

4 Experiments

This section shows that the CCL is a sound and effective learning objective for continuous GFlowNets in varied experimental setups, which are described below.

Tasks. We demonstrate the effectiveness of the CB loss in the tasks of sampling from a mixture of Gaussian distributions (GM) [28] and from a banana-shaped distribution (BS) [16], which are common benchmarks for continuous GFlowNets and for variational inference algorithms in general, respectively. For the GMs, we evaluate the GFlowNet’s performance for a sparse 2-dimensional model and a 15-dimensional model, both with isotropic variances. For the BS distribution, we consider the target distribution

$$\begin{aligned} \mathbb {R}^{2} \ni \textbf{x} \sim \mathcal {N}\left( \begin{bmatrix} x_{1} \\ x_{2} + x_{1}^{2} + 1 \end{bmatrix} \bigg | \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & 0.9 \\ 0.9 & 1 \end{bmatrix}\right) . \end{aligned}$$
(10)
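For reference, Eq. (10) can be evaluated directly. The sketch below is one reading of it, assuming that the density of \(\textbf{x}\) is the bivariate Gaussian density evaluated at the transformed point \((x_{1}, x_{2} + x_{1}^{2} + 1)\); since this transformation has unit Jacobian, the Gaussian normalizing constant carries over unchanged. The function name is illustrative.

```python
import math

RHO = 0.9  # off-diagonal covariance entry in Eq. (10)


def banana_log_density(x1, x2):
    """Log-density of the banana-shaped target at (x1, x2)."""
    # Transformed coordinates fed to the zero-mean Gaussian.
    y1, y2 = x1, x2 + x1 ** 2 + 1.0
    # Quadratic form y^T Sigma^{-1} y for Sigma = [[1, RHO], [RHO, 1]].
    quad = (y1 ** 2 - 2 * RHO * y1 * y2 + y2 ** 2) / (1 - RHO ** 2)
    log_norm = -math.log(2 * math.pi) - 0.5 * math.log(1 - RHO ** 2)
    return log_norm - 0.5 * quad


mode_value = banana_log_density(0.0, -1.0)  # transformed point is the origin
```

A GFlowNet only needs this quantity up to an additive constant, so `log_norm` could be dropped when the function is used as \(\log r\).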

Experimental Setup. For each experiment, we consider the iterative generative process described at the end of Sect. 2.2. To parameterize \(Q_{\theta }\), which is here defined as a Gaussian mixture distribution, we employ an MLP with two 64-dimensional layers that receives the current state as input and returns the log-weights, means, and log-variances of the mixture. In Algorithm 1, we fix a batch size of 128 and use Adam with a linearly decaying learning rate of \(10^{-3}\) for optimization. However, when minimizing the TB loss, we adopt the suggestion of Malkin et al. [14] to use a fixed learning rate of \(10^{-1}\) for \(\log Z(\gamma )\). Code for reproducing the results will be released upon acceptance.

Fig. 1.

Minimizing the CB loss (bottom row) drastically improves the training convergence when compared to TB minimization (top row) when the target is a 2-dimensional sparse Gaussian Mixture (see Fig. 3). This corroborates our hypothesis that avoiding parameterizing \(\log Z(\gamma )\) facilitates the learning of GFlowNets. Columns represent results of three randomly initialized runs.

Fig. 2.

CB and TB minimization perform similarly for the (relatively easy-to-approximate) banana-shaped target. We consider HMC (right plot) as a gold-standard reference for sample generation due to its accuracy and reliability.

Fig. 3.

A sparse Gaussian mixture.

Results. Figure 1 highlights that a GFlowNet trained by minimizing \(\mathcal {L}_{CB}\) achieves a substantially better approximation than one trained by minimizing \(\mathcal {L}_{TB}\), given a fixed computational budget and a sparse target, which is illustrated in Fig. 3. This supports our hypothesis that the minimization of \(\mathcal {L}_{TB}\) is significantly constrained by the obligatory estimation of the intractable \(\log Z(\gamma )\). For the experiments based on the high-dimensional Gaussian mixture, which cannot be easily visualized, we show in Fig. 4 the Jensen-Shannon divergence between the learned distribution q and the target p,

$$\begin{aligned} \mathcal {D}_{JS}[p || q] = \frac{1}{2} \left( \mathcal {D}_{KL}[p || m] + \mathcal {D}_{KL}[q || m] \right) , \end{aligned}$$
(11)

with \(m := \nicefrac {(p + q)}{2}\) and \(\mathcal {D}_{KL}[p || m] = \mathbb {E}_{x \sim p} \left[ \log \nicefrac {p(x)}{m(x)}\right] \) representing the Kullback-Leibler divergence between p and m. Importantly, \(\mathcal {D}_{JS}\) is a frequently used metric for assessing GFlowNets [6, 13]. Finally, for the low-dimensional and non-sparse targets, training GFlowNets by minimizing either CB or TB yields similar approximations given equivalent computational resources, as Fig. 2 attests, which is potentially a consequence of the relative simplicity of the target distribution.
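As a small worked instance of Eq. (11), the sketch below computes the Jensen-Shannon divergence between two distributions given as probability vectors over a common finite grid. This discretized setting is an assumption made for illustration; the evaluation protocol used for the continuous targets is not specified here, and the function names are hypothetical.

```python
import math


def kl(p, q):
    """Kullback-Leibler divergence between two probability vectors."""
    # Terms with p_i = 0 contribute zero by convention.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js(p, q):
    """Jensen-Shannon divergence of Eq. (11) with m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * (kl(p, m) + kl(q, m))


p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
divergence = js(p, q)
```

Unlike the Kullback-Leibler divergence, this quantity is symmetric in its arguments and bounded above by \(\log 2\), which is attained when the two distributions have disjoint supports.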

5 Related Works

Discrete GFlowNets. GFlowNets were originally conceived as a reinforcement learning algorithm tailored to the search of a diversity of high-valued states [2, 3]. Since then, this family of models was successfully applied to Bayesian structure learning and causal discovery [1, 5, 6], natural language processing [9], robust scheduling [26], phylogenetic inference [30], and probabilistic modelling in general [10, 15, 27, 29]. Correspondingly, a large effort was devoted to improving the training convergence of discrete GFlowNets [11, 14, 17, 19, 25].

Fig. 4.

Learning curve for / .

Continuous GFlowNets. In contrast, the literature on continuous GFlowNets, introduced by Lahlou et al. [13], is relatively thin. To the best of our knowledge, the only fruitful application of these models occurred recently in the context of Bayesian structure learning [6]. The reason for this, we believe, is the difficulty of efficiently training these models. Indeed, Zhou et al. [30] discretized the continuous component of the state space to avoid training a model to sample from a continuous distribution. Under these conditions, our work is an important step towards improving the usability of continuous GFlowNets.

6 Conclusions

We rigorously introduced a continuous variant of the contrastive balance condition and an accompanying continuous contrastive loss for efficiently training GFlowNets over distributions supported on discrete, continuous, and mixed spaces. All in all, our theoretical analysis assured the reliability of the proposed learning objective and our empirical results suggested that minimizing the CCL often significantly speeds up the model’s training convergence.

Yet, the training of continuous GFlowNets remains challenging, and extending the recent advancements in the training of discrete GFlowNets to the continuous setting, along the lines of our work, is a promising endeavor that may greatly diminish these difficulties. For instance, one may build upon the Generative Augmented Flow Networks of Pan et al. [19] to incorporate intermediate reward-related signals within a trajectory for enhanced credit assignment [22]. Similarly, the theoretical developments of Tiapkin et al. [7, 25] showing the relationship between RL and GFlowNets pave the way for the design of learning objectives motivated by techniques from control theory for continuous states.