Abstract
Generative Flow Networks (GFlowNets) are a novel class of flexible amortized samplers for distributions supported on complex objects (e.g., graphs and sequences), achieving significant success in problems such as combinatorial optimization, drug discovery and natural language processing. Nonetheless, training of GFlowNets is challenging, partly because it relies on estimating high-dimensional integrals, including the log-partition function, via stochastic gradient descent (SGD). In particular, for distributions supported on non-discrete spaces, which have received far less attention from the recent literature, every previously proposed learning objective either depends on estimating a log-partition function or is restricted to on-policy training, which is susceptible to mode-collapse. In this context, inspired by the success of contrastive learning for variational inference, we propose the continuous contrastive loss (CCL) as the first objective function natively enabling off-policy training of continuous GFlowNets without reliance on the approximation of high-dimensional integrals via SGD, extending previous work based on discrete distributions. Additionally, we show that minimizing the CCL objective is empirically effective and often leads to faster training convergence than alternatives.
1 Introduction
The evaluation of high-dimensional expectations under unnormalized distributions is one of the main challenges of statistics and machine learning [4]. In Bayesian decision theory, for instance, one seeks to minimize a risk function that is defined as an expectation of a functional under an often intractable and unnormalized posterior distribution provided by Bayes' rule [21]. Similarly, many model-free reinforcement learning algorithms for finding an optimal policy in a contextual Markov decision process rely on integrating out a high-dimensional context variable of the quality function. Under these circumstances, a common approach consists of replacing the unmanageable integration problem with a manageable optimization problem, which is known as variational inference. For this, one introduces a tractable yet large set of parametric distributions, known as the variational family, and finds the family member that most closely matches the unnormalized target by solving a stochastic optimization problem. Then, one uses the solution to this problem as a surrogate for the intractable target and evaluates the desired expectations using Monte Carlo estimators.
Recently, Generative Flow Networks (GFlowNets) [2, 3, 13] were proposed as a flexible variational family parameterized by neural networks to approximate unnormalized distributions with compositional supports [15], with remarkable success in relevant applications including causal inference [5, 24], drug discovery [2], refinement of large language models [9], and combinatorial optimization [28]. To implement a GFlowNet, we first define a flow network over an extension of the target distribution’s support, which thereafter corresponds to the sink nodes, and then search for a consistent flow assignment for which the flow going through each sink node matches its unnormalized probability. During inference, we travel across the defined network by starting at the origin and choosing each transition according to the estimated flow therein. If the flow is correctly estimated, this procedure is guaranteed to yield independent and identically distributed (i.i.d.) samples from the intractable target distribution [2, 13].
To successfully find a consistent and balanced flow assignment, we conventionally parameterize the network's transition policies with versatile models such as multilayer perceptrons or graph neural networks. Then, to estimate the model's parameters, we minimize, via a variant of SGD, a loss function defined as the expectation of squared log-ratios measuring local balance violations under a prescribed exploratory policy. Importantly, this exploratory policy may differ from the learned policy, in which case the training is said to be carried out in an off-policy fashion. We remark that, when dealing with high-dimensional and highly sparse target distributions, an off-policy sampling scheme is paramount for avoiding the collapse of the variational distribution [14, 15]. Nonetheless, as we note in Sect. 2, such a possibility of off-policy learning comes at the cost of introducing additional parameters that often entail the estimation of high-dimensional integrals as a subproblem to the successful training of GFlowNets, potentially hindering the sample efficiency of these models.
Seeking to bypass this problem, recent work [23, 26] developed a contrastive learning objective for GFlowNets. In the spirit of Hinton's contrastive divergence learning [8], these approaches sidestep the estimation of difficult-to-approximate quantities by minimizing the difference between the balance violations of two independently sampled trajectories instead of directly optimizing a single optimality criterion. This ensures that the intractable terms cancel out and that the resulting problem is computationally easier to solve. However, these works are restricted to sampling from unnormalized distributions with finite support; a formal extension and an empirical validation of these objectives in the non-discrete setting are lacking in the literature. In this context, we extend both the contrastive balance condition and the contrastive loss function [23] to train GFlowNets defined on measurable topological spaces [13]. In addition, we empirically show that the resulting continuous contrastive loss (CCL) frequently leads to faster training convergence than competing learning objectives on the standard tasks of approximating a sparse mixture of multivariate Gaussian distributions and a banana-shaped distribution.
In summary, our contributions are:
1. We derive a contrastive balance condition for continuous GFlowNets and rigorously show that it is sufficient for ensuring sampling correctness;
2. We propose the continuous contrastive loss (CCL) as a theoretically sound learning objective for GFlowNets that does not rely on the estimation of high-dimensional integrals via SGD during training;
3. We empirically compare the performance of CCL against alternative loss functions and show that CCL leads to faster convergence in most cases.
The paper is organized as follows. In Sect. 2, we recall the notions of a measurable pointed directed acyclic graph (DAG), transition kernels and flow networks, laying out the theoretical framework upon which GFlowNets are built, and also review popular learning objectives for continuous GFlowNets. Then, in Sect. 3, we rigorously derive the contrastive balance condition and propose CCL as a sound learning objective for GFlowNets, the performance of which is assessed in Sect. 4 for the approximation of a mixture of multivariate Gaussians and of a banana-shaped distribution [16]. Finally, we provide a summary of the relevant literature in Sect. 5 and a discussion of future directions, along with some concluding remarks, in Sect. 6.
2 Background
This section reviews the main components of GFlowNets and their commonly used learning objectives. To build intuition, we first define a GFlowNet for finitely supported distributions in Sect. 2.1. Then, in Sect. 2.2, we review the formal framework of Lahlou et al. [13] for extending the notion of a flow network to arbitrary topological spaces by replacing the adjacency matrix with a transition kernel that describes the continuous network's connectivity.
2.1 GFlowNets in Discrete Spaces
Discrete GFlowNets. Let R be a (potentially unnormalized) probability distribution on a finite space \(\mathcal {X}\) and \(S \supseteq \mathcal {X}\) be an extension of \(\mathcal {X}\) with two distinguished elements, \(s_{o}\) and \(s_{f}\). Our objective is to randomly sample elements \(x \in \mathcal {X}\) proportionally to R(x). For this, we let \(\mathcal {G} = (S, E)\) be a DAG with edges E such that (i) the only nodes with an edge to \(s_{f}\) are those in \(\mathcal {X}\), (ii) there is a path from \(s_{o}\) to every \(x \in \mathcal {X}\), and (iii) there are no outgoing edges from \(s_{f}\) and no incoming edges to \(s_{o}\). We call \(\mathcal {G}\) the state graph. In this setting, let \(p_{F} :S \times S \rightarrow \mathbb {R}_{+}\) be a forward policy on \(\mathcal {G}\), i.e., a function such that \(p_{F}(s, \cdot )\) is a probability measure on S supported on the children of s in \(\mathcal {G}\), and let \(p_{F}(\tau ) = \prod _{1 \le i \le n} p_{F}(s_{i} | s_{i - 1})\) be the corresponding distribution over trajectories \(\tau = (s_{o}, \dots , s_{n - 1}, s_{n} = s_{f})\). To accomplish our objective, we must find a \(p_{F}\) for which
$$ \sum_{\tau \,:\, s_{n - 1} = x} p_{F}(\tau) \propto R(x) \quad \text{for every } x \in \mathcal {X}, \tag{1} $$
where the sum runs over all complete trajectories whose last state before \(s_{f}\) is x.
In practice, the left-hand side of Eq. (1) is computationally intractable, and Eq. (1) is satisfied by a possibly infinite number of \(p_{F}\)'s. To ensure tractability and uniqueness, we also introduce a backward policy \(p_{B} :S \times S \rightarrow \mathbb {R}_{+}\), which is a forward policy over the transpose of \(\mathcal {G}\), and rewrite Eq. (1) as \(p_{F}(\tau ) \propto p_{B}(\tau | x) R(x)\). Malkin et al. [14] showed that such a technique induces a well-posed problem with a unique solution for a fixed \(p_{B}\).
Training GFlowNets. To find a forward policy abiding by Eq. (1), we parameterize \(p_{F}\) with a neural network with parameters \(\theta \) and fix \(p_{B}(s, \cdot )\) as a uniform policy for each \(s \in S\). Then, by introducing an additional learnable function \(F :S \rightarrow \mathbb {R}_{+}\), which is parameterized by a (different) neural network with parameters \(\gamma \), we jointly estimate \(\theta \) and \(\gamma \) by minimizing one of the following objectives. Below, \(p_{e}\) is an exploratory policy, i.e., a forward policy that does not depend on \((\theta , \gamma )\) and has full support over the space of trajectories. Also, we denote by x (resp. \(x'\)) the state immediately preceding \(s_{f}\) in a trajectory \(\tau \) (resp. \(\tau '\)).
1. The detailed balance (DB) loss [3] corresponds to \(\mathcal {L}_{DB}(\theta , \gamma ) =\)
$$\begin{aligned} \begin{aligned} \mathbb {E}_{\tau \sim p_{e}}\left[ \frac{1}{|\tau |}\sum _{(s_{i - 1}, s_{i}) \in \tau } \left( \log \frac{F_{\gamma }(s_{i - 1})p_{F}(s_{i} | s_{i - 1})}{p_{B}(s_{i - 1} | s_{i}) F_{\gamma }(s_{i})}\right) ^{2} + \left( \log \frac{F_{\gamma }(x)}{R(x)} \right) ^{2} \right] \end{aligned} \end{aligned}$$
which enforces the DB condition, \(F(s) p_{F}(s' | s) = F(s') p_{B}(s | s')\) for \((s, s') \in E\), and \(F(x) = R(x)\) for the terminal states \(x \in \mathcal {X}\).
2. The trajectory balance (TB) loss [14] introduces \(\log Z(\gamma ) = \log F_{\gamma }(s_{o})\) as the log-partition function of R and is then defined by
$$\begin{aligned} \mathcal {L}_{TB}(\theta , \gamma ) = \mathbb {E}_{\tau \sim p_{e}} \left[ \left( \log Z(\gamma ) + \log p_{F}(\tau ; \theta ) - \log p_{B}(\tau | x) - \log R(x) \right) ^{2} \right] , \end{aligned}$$
enforcing the TB condition \(F(s_{o}) p_{F}(\tau ) = p_{B}(\tau | x) R(x)\).
3. The contrastive balance (CB) loss [23] avoids parameterizing F by sampling a pair \((\tau , \tau ')\) of independent trajectories from \(p_{e}\), being defined by
$$ \mathcal {L}_{CB}(\theta ) = \mathbb {E}_{(\tau , \tau ') \sim p_{e}} \left[ \left( \log \frac{p_{F}(\tau )}{p_{B}(\tau | x) R(x)} - \log \frac{p_{F}(\tau ')}{p_{B}(\tau ' | x') R(x')} \right) ^{2} \right] , \tag{2} $$
and ensuring the CB condition, \(p_{F}(\tau ) p_{B}(\tau '|x')R(x') = R(x) p_{F}(\tau ') p_{B}(\tau |x)\).
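To make the practical difference between the TB and CB losses concrete, the NumPy sketch below evaluates batch estimates of both. It assumes, hypothetically, that the per-trajectory quantities \(\log p_{F}(\tau)\), \(\log p_{B}(\tau | x)\) and \(\log R(x)\) have already been computed; the function names are illustrative and not taken from the paper. Splitting the batch into two halves is one simple way to obtain independent pairs \((\tau, \tau')\) for Eq. (2).

```python
import numpy as np

def tb_loss(log_z, log_pf, log_pb, log_r):
    """Batch estimate of the TB loss.

    log_pf, log_pb, log_r: arrays of shape (B,) with log p_F(tau),
    log p_B(tau | x) and log R(x) for each sampled trajectory.
    """
    residual = log_z + log_pf - log_pb - log_r
    return np.mean(residual ** 2)

def cb_loss(log_pf, log_pb, log_r):
    """Batch estimate of the CB loss in Eq. (2).

    The first half of the batch is paired with the second half, so each
    pair (tau, tau') consists of two independently sampled trajectories.
    """
    log_beta = log_pf - log_pb - log_r          # log of p_F / (p_B R)
    half = log_beta.shape[0] // 2
    diff = log_beta[:half] - log_beta[half:2 * half]
    return np.mean(diff ** 2)

# Hypothetical log-terms for which the TB identity holds with log Z = 1.7.
rng = np.random.default_rng(0)
log_r = rng.normal(size=8)
log_pb = rng.normal(size=8)
log_pf = log_pb + log_r - 1.7

print(cb_loss(log_pf, log_pb, log_r))       # close to 0: no estimate of Z needed
print(tb_loss(0.0, log_pf, log_pb, log_r))  # approximately 2.89: wrong log Z is penalized
```

The CB estimate vanishes without any knowledge of \(\log Z\), whereas the TB estimate is zero only when the extra parameter \(\log Z\) is estimated correctly.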
Notably, the learning of discrete GFlowNets is widely studied, and many strategies for enhancing training convergence were proposed [3, 11, 17, 18, 22, 25, 28, 30]. In this work, we focus on extending the CB objective to the continuous setting.
2.2 GFlowNets in Continuous Spaces
Notations. We let \((\mathcal {X}, T_{\mathcal {X}}, \varSigma _{\mathcal {X}}, \mu )\) be a topological measurable space with reference measure \(\mu \), topology \(T_{\mathcal {X}}\), \(\sigma \)-algebra \(\varSigma _{\mathcal {X}}\), and underlying space \(\mathcal {X}\), and consider a non-negative function \(r :\mathcal {X} \rightarrow \mathbb {R}_{+}\) over \(\mathcal {X}\) such that \(\int _{\mathcal {X}} r(x) \mu (\textrm{d}x) < \infty \). In what follows, we assume that each measure, denoted by a capital letter, has a density relative to the reference measure \(\mu \), which is denoted by the corresponding lowercase letter. Under these conditions, our objective is to learn a generative process over \(\mathcal {X}\) that samples each measurable subset A proportionally to \(\int _{A} r(x) \mu (\textrm{d}x)\), and we define the target measure \(R :\varSigma _{\mathcal {X}} \rightarrow \mathbb {R}_{+}\) as \(R(A) = \int _{A} r(x) \mu (\textrm{d}x)\), which is absolutely continuous with respect to \(\mu \). Additionally, for a given set S and a corresponding \(\sigma \)-algebra \(\varSigma _{S}\) (which will be specified later), we define a Markov kernel on S as a function \(\kappa :S \times \varSigma _{S} \rightarrow \mathbb {R}_{+}\) such that \(\kappa (s, \cdot )\) is a measure on \((S, \varSigma _{S})\) for every \(s \in S\). Also, we denote by \(S^{\otimes n}\) the nth-order Cartesian product of S and by \(\varSigma _{S}^{\otimes n}\) the corresponding product \(\sigma \)-algebra. A Markov kernel \(\kappa \) on S naturally extends to a kernel \(\kappa ^{\otimes n} :S \times \varSigma _{S}^{\otimes n} \rightarrow \mathbb {R}_{+}\) on \(S^{\otimes n}\) via the recursion \(\kappa ^{\otimes n}(s, A \times A_{n}) = \int _{A} \kappa ^{\otimes (n - 1)}(s, \textrm{d}(s_{1}, \dots , s_{n - 1})) \, \kappa (s_{n - 1}, A_{n})\), with \(A \in \varSigma _{S}^{\otimes (n - 1)}\), \(A_{n} \in \varSigma _{S}\), and \(s \in S\); this extension exists and is unique by Carathéodory's extension theorem. 
Finally, we denote by \(S^{\otimes \le N} = \bigcup _{1 \le i \le N} S^{\otimes i}\) the space of sequences of length at most N and by \(\varSigma _{S}^{\otimes \le N} = \sigma \left( \bigcup _{1 \le i \le N} \varSigma _{S}^{\otimes i}\right) \) the \(\sigma \)-algebra generated by the product \(\sigma \)-algebras of order up to N.
Measurable Pointed DAGs (MP-DAG). We first present the definition of an MP-DAG, originally proposed by Lahlou et al. [13], which formally extends the concept of a flow network to a potentially non-countable state space.
Definition 1 (MP-DAG)
Let \((\bar{S}, \mathcal {T})\) be a topological space and \(\varSigma \) be the Borel \(\sigma \)-algebra on \(\mathcal {T}\). Define the source \(s_{o} \in \bar{S}\) and sink \(s_{f} \in \bar{S}\) states. Also, let \(\kappa \), the reference kernel, and \(\kappa _{b}\), the backward kernel, be \(\sigma \)-finite and continuous Markov kernels in \((\bar{S}, \mathcal {T})\); and let \(\mu \) be a reference measure on \(\varSigma \). Finally, let \(S = \bar{S} \setminus \{s_{f}\}\). A measurable pointed DAG is a tuple \(\mathcal {G} = (\bar{S}, \mathcal {T}, \varSigma , \kappa , \kappa _{b}, \mu )\) such that
1. (finality) \(\kappa (s_{f}, \cdot ) = \delta _{s_{f}}\) and, if \(\kappa (s, \{s_{f}\}) > 0\) for some \(s \in \bar{S}\), then \(\kappa (s, \{s_{f}\}) = 1\);
2. (reachability) \(\forall B \in \mathcal {T} \setminus \{\emptyset \}\), \(\exists n \ge 0\) with \(\kappa ^{\otimes n}(s_{o}, B) > 0\);
3. (initialness) \(\forall B \in \varSigma \), \(\kappa _{b}(s_{o}, B) = 0\); and
4. (consistency) \(\forall B \in \varSigma \otimes \varSigma \), \(\nu \otimes \kappa (B) = \nu \otimes \kappa _{b}(B)\).
Also, an MP-DAG is said to be finitely absorbing if there is an \(N \in \mathbb {N}\) for which \(\kappa ^{\otimes N}(s_{o}, \{s_{f}\}) > 0\) and, by the finality property, \(\kappa ^{\otimes N}(s_{o}, \{s_{f}\}) = 1\). In other words, any iterative generative process starting at \(s_{o}\) and sampling novel states according to \(\kappa \) reaches \(s_{f}\) in at most N steps. We call N the maximum trajectory length of the MP-DAG, and define \(\mathcal {X} = \{x \in S :\kappa (x, \{s_{f}\}) > 0\}\) as the set of terminal states. Under these circumstances, the set \(\{s_{o}\} \times S^{\otimes \le N - 1} \times \{s_{f}\}\) fully characterizes the trajectories starting at \(s_{o}\), and we denote it by \(\mathbb {T}\); also, we denote by \(\mathbb {T}' = \{s_{o}\} \times (S \setminus \mathcal {X})^{\otimes \le N - 1}\) the space of trajectories without terminal states. We highlight that, when S is finite and \(\mu \) is set as the counting measure in S, the definition above recovers the traditional flow network outlined in Sect. 2.1 for the discrete setting, with \(\{\kappa (s, \{s'\})\}_{(s, s') \in S \times S}\) characterizing the network's adjacency matrix and N denoting its (directed) diameter.
Continuous GFlowNets. Similarly to its discrete counterpart, we characterize a continuous GFlowNet by an MP-DAG \(\mathcal {G} = (\bar{S}, \mathcal {T}, \varSigma , \kappa , \kappa _{b}, \mu )\), a measure R on \(\mathcal {X}\), a forward policy \(P_{F} :\bar{S} \times \varSigma _{\bar{S}} \rightarrow \mathbb {R}_{+}\), defined as a Markov kernel that is absolutely continuous relative to \(\kappa \), and a backward policy \(P_{B} :\bar{S} \times \varSigma _{\bar{S}} \rightarrow \mathbb {R}_{+}\), defined as a Markov kernel absolutely continuous relative to \(\kappa _{b}\). Our objective is to find a \(P_{F}\) such that the marginal distribution over \(\mathcal {X}\) of the \(P_{F}\)-Markov chain deterministically starting at \(s_{o}\) matches R, that is, informally,
$$ \mathbb {P}\left[ \text{the } P_{F}\text{-Markov chain started at } s_{o} \text{ terminates in } A \right] = \frac{R(A)}{R(\mathcal {X})} \tag{3} $$
for every measurable subset A of \(\mathcal {X}\).
Training GFlowNets. To achieve this objective, we let \(p_{F}\) and \(p_{B}\) be the densities of \(P_{F}\) and \(P_{B}\) relative to \(\kappa \) and \(\kappa _{b}\), respectively, and parameterize \(p_{F}\) with a neural network with weights \(\theta \). To avoid notational overload, we also denote by \(p_{F}\) and \(p_{B}\) the densities of \(P_{F}^{\otimes i}\) and \(P_{B}^{\otimes i}\) with respect to the corresponding product kernels. Analogously to discrete GFlowNets, we define an auxiliary function \(u :\bar{S} \rightarrow \mathbb {R}_{+}\) parameterized by a neural network with weights \(\gamma \). Additionally, we introduce an exploratory policy \(P_{E} :\bar{S} \times \varSigma _{\bar{S}} \rightarrow \mathbb {R}_{+}\) absolutely continuous with respect to the forward kernel \(\kappa \). Under these conditions, Lahlou et al. [13] showed that the trajectory and detailed balance losses of Sect. 2.1 are sound objectives for learning continuous GFlowNets when we replace the transition probabilities with the corresponding densities. So, for instance, we let \(Z = u_{\gamma }(s_{o})\), and then
$$ \mathcal {L}_{TB}(\theta , \gamma ) = \mathbb {E}_{\tau \sim P_{E}(s_{o}, \cdot )} \left[ \left( \log Z(\gamma ) + \log p_{F}(\tau ; \theta ) - \log p_{B}(\tau | x) - \log r(x) \right) ^{2} \right] \tag{4} $$
becomes the continuous equivalent of the TB loss. Remarkably, both the TB and DB losses rely on the estimation of the auxiliary function u, which often corresponds to a high-dimensional integral (such as the partition function) that is difficult to compute and has no inferential purpose.
Illustration. To clarify the above definitions and emphasize their versatility, we show how to instantiate a continuous GFlowNet to approximate an arbitrary distribution in a Euclidean space \(\mathbb {R}^{n}\), an example that will be thoroughly explored in our experimental campaign in Sect. 4. For this setting, we may define \(S = \mathbb {R}^{n}\) and \(\mathcal {X} = \{(x, \top ) :x \in \mathbb {R}^{n}\}\), with \(\top \) artificially distinguishing the elements of \(\mathcal {X}\) from those of S. Then, \(s_{o} = \textbf{0}\) and a transition corresponds to \(x^{t + 1} = x^{t} + \alpha e^{t}\), with \(\alpha \sim Q_{\theta }\) sampled from an appropriate distribution \(Q_{\theta }\) whose parameters are estimated by a neural network that receives \(x^{t}\) as input, and \(e^{t}\) denoting the tth row of the identity matrix. Critically, this iterative generative process induces a finitely absorbing MP-DAG satisfying all the desired properties when we define \(\kappa \) to satisfy \(\kappa (x, \cdot ) = \delta _{s_{f}}(\cdot )\) for every \(x \in \mathcal {X}\) and \(\kappa _{b}(x^{t}, \cdot ) = \delta _{x^{t} - x_{t}^{t} e_{t}}(\cdot )\).
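The coordinate-wise generative process above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: coordinates are updated in the order 0, ..., n - 1 (0-indexed), and the neural network defining \(Q_{\theta}\) is replaced by a hypothetical Gaussian `toy_policy` whose mean depends on the current state. The returned `log_pf` is the log-density of the sampled increments with respect to the Lebesgue measure, which plays the role of \(\log p_{F}(\tau)\) here.

```python
import numpy as np

def gaussian_logpdf(a, mean, std):
    """Log-density of N(mean, std^2) evaluated at a."""
    return -0.5 * ((a - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

def toy_policy(x, t):
    """Hypothetical stand-in for Q_theta: returns the mean and standard
    deviation of the increment along coordinate t given the state x."""
    return 0.5 * np.tanh(x.sum()), 1.0

def sample_trajectory(n, rng, policy=toy_policy):
    """Samples x^{t+1} = x^t + alpha e_t for t = 0, ..., n - 1 from s_o = 0."""
    x = np.zeros(n)
    states = [x.copy()]
    log_pf = 0.0
    for t in range(n):
        mean, std = policy(x, t)
        alpha = rng.normal(mean, std)
        log_pf += gaussian_logpdf(alpha, mean, std)
        x[t] += alpha                  # only coordinate t changes at step t
        states.append(x.copy())
    # After n moves the chain jumps to s_f with probability one; the backward
    # kernel is a Dirac mass, so log p_B(tau | x) = 0 with respect to kappa_b.
    return states, log_pf

rng = np.random.default_rng(0)
states, log_pf = sample_trajectory(3, rng)
```

Because each coordinate is touched exactly once, every terminal state is reached by a single trajectory, which is what makes the backward kernel deterministic in this construction.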
3 A Contrastive Objective
We now present our main theoretical results regarding the soundness of the contrastive balance objective for training GFlowNets in a non-countable state space. First, Sect. 3.1 defines the CB condition as a provably sufficient requirement for ensuring the sampling correctness of GFlowNets. Second, Sect. 3.2 derives the continuous contrastive loss (CCL) as a realization of the CB condition that, when minimized, ensures that the resulting model is correct almost everywhere. Finally, we analyze the connection between the CCL and the popular variance loss [20] for generic VI.
3.1 Contrastive Balance
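Contrastive Balance. To get a finer understanding of the contrastive balance condition, we first remark that the TB condition may be equivalently stated as
$$ \frac{p_{F}(\tau )}{p_{B}(\tau | x) r(x)} = \frac{1}{Z} \quad \text{for every trajectory } \tau \in \mathbb {T}, \tag{5} $$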
i.e., the quotient \(\beta (\tau ) = \frac{p_{F}(\tau )}{p_{B}(\tau | x) r(x)}\) is constant as a function of \(\tau \). Thus, if \(\beta (\tau ) = \beta (\tau ')\) for every pair of trajectories \((\tau , \tau ')\), the TB condition is satisfied and the GFlowNet samples correctly from the target. Importantly, \(\beta \) does not depend on any intractable quantity, so enforcing the equality \(\beta (\tau ) = \beta (\tau ')\), which is the essence of Definition 2 below, avoids the estimation of high-dimensional integrals such as \(\log Z(\gamma )\) in Eq. (4). Next, we formally define the contrastive balance condition. Notably, we slightly abuse notation by writing \(P_{B}(x, \textrm{d}\tau )\) for the product kernel \(P_{B}(x, \textrm{d}s_{n}) \otimes P_{B}(s_{n}, \textrm{d}s_{n - 1}) \otimes \cdots \otimes P_{B}(s_{1}, \textrm{d}s_{o})\) associated with \(\tau = (s_{o}, s_{1}, \dots , s_{n}, x)\).
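Written out, the cancellation that Definition 2 exploits is a one-line identity. Adding the same (unknown) \(\log Z\) to the log-quotient of each trajectory, which yields the TB residual of Eq. (4), leaves the difference of two such residuals unchanged, so the contrastive residual is free of \(Z\); it vanishes for all pairs exactly when \(\log \beta\) is constant, that is, when the TB condition holds with \(Z = 1/\beta\):

$$ \bigl(\log Z + \log \beta (\tau )\bigr) - \bigl(\log Z + \log \beta (\tau ')\bigr) = \log \frac{p_{F}(\tau )}{p_{B}(\tau | x)\, r(x)} - \log \frac{p_{F}(\tau ')}{p_{B}(\tau ' | x')\, r(x')}. $$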
Definition 2 (Contrastive balance)
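Let \((\mathcal {G}, P_{F}, P_{B}, R, \mu )\) be a GFlowNet. The contrastive balance (CB) condition is defined by
$$ \int f(\tau , x, \tau ', x') \, P_{F}(s_{o}, \textrm{d}(\tau , x)) \, R(\textrm{d}x') \, P_{B}(x', \textrm{d}\tau ') = \int f(\tau , x, \tau ', x') \, R(\textrm{d}x) \, P_{B}(x, \textrm{d}\tau ) \, P_{F}(s_{o}, \textrm{d}(\tau ', x')) \tag{6} $$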
for every bounded measurable function \(f :\mathbb {T}' \times \mathcal {X} \times \mathbb {T}' \times \mathcal {X} \rightarrow \mathbb {R}\).
Illustratively, when S is finite, the reference measure is the counting measure, and f is taken to be the indicator function of a single element of \(\mathbb {T}' \times \mathcal {X} \times \mathbb {T}' \times \mathcal {X}\), the condition above reduces to the discrete contrastive balance condition, namely, \(p_{F}(\tau , x) p_{B}(x', \tau ') r(x') = r(x) p_{B}(x, \tau ) p_{F}(\tau ', x')\) for every pair \((\tau , \tau ')\) of trajectories and \((x, x')\) of terminal states. In light of this, Definition 2 generalizes the discrete CB condition to a significantly broader setting.
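As a sanity check of this discrete reduction, the sketch below builds a hypothetical finite example in which trajectories are indexed abstractly as pairs (x, j), with j enumerating the trajectories that end at terminal state x. The values of r, the number of trajectories per terminal state, and the random backward policy are all illustrative assumptions. A forward policy defined through the TB identity satisfies the pairwise CB condition and has the correct terminal marginal, while a perturbed one violates CB.

```python
import itertools
import numpy as np

rng = np.random.default_rng(1)
r = np.array([1.0, 3.0, 0.5])   # hypothetical unnormalized target on 3 terminal states
n_traj_per_x = 4                # hypothetical: 4 trajectories end at each terminal state

# p_B[x, j]: backward probability of the j-th trajectory given its terminal state x
p_B = rng.random((3, n_traj_per_x))
p_B /= p_B.sum(axis=1, keepdims=True)

Z = r.sum()
p_F = p_B * r[:, None] / Z      # forward probabilities satisfying Z p_F = p_B r

def cb_holds(p_F, p_B, r, tol=1e-12):
    """Checks p_F(tau) p_B(tau'|x') r(x') = r(x) p_B(tau|x) p_F(tau')
    for every pair of trajectories (tau, tau')."""
    trajs = [(x, j) for x in range(p_F.shape[0]) for j in range(p_F.shape[1])]
    for (x, j), (xp, jp) in itertools.product(trajs, repeat=2):
        lhs = p_F[x, j] * p_B[xp, jp] * r[xp]
        rhs = r[x] * p_B[x, j] * p_F[xp, jp]
        if abs(lhs - rhs) > tol:
            return False
    return True

# Moving some forward mass between trajectories breaks contrastive balance.
p_F_bad = p_F.copy()
p_F_bad[0, 0] += 0.01
p_F_bad[1, 0] -= 0.01
```

Summing `p_F` over j recovers `r / Z`, i.e., the terminal marginal matches the normalized target, in line with the sufficiency result stated next.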
Sufficiency of CB. We rigorously show in the theorem below that the continuous variant of the CB condition outlined in Definition 2 is sufficient for ensuring that the corresponding GFlowNet generates samples distributed according to the target measure R. Note, for this, that we may rewrite Eq. (6) in terms of the densities of the corresponding measures.
Theorem 1
Let \((\mathcal {G}, P_{F}, P_{B}, R, \mu )\) be a GFlowNet abiding by the contrastive balance condition. Then, the marginal distribution over \(\mathcal {X}\) induced by the \(P_{F}\)-Markov chain starting at \(s_{o}\) matches R.
Proof
To start with, we define \(t :\mathbb {T} \rightarrow \mathcal {X}\) as the function taking a complete trajectory and returning its terminal state. By the definition of \(\mathcal {X}\), t is well defined almost everywhere with respect to the measure \(\kappa (s_{o}, \cdot )\) over trajectories. Then, we show that Eq. (6) is enough to ensure, for all measurable \(A \subseteq \mathcal {X}\), that
$$ \int _{\mathbb {T}} \mathbb {1}_{A}(t(\tau )) \, P_{F}(s_{o}, \textrm{d}\tau ) = \frac{R(A)}{R(\mathcal {X})}, \tag{7} $$
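i.e., the marginal distribution of \(P_{F}\) on \(\mathcal {X}\) matches R. To this end, we first note that, when the function f does not depend on its last two arguments, Eq. (6) becomes
$$ R(\mathcal {X}) \int f(\tau , x) \, P_{F}(s_{o}, \textrm{d}(\tau , x)) = \int f(\tau , x) \, R(\textrm{d}x) \, P_{B}(x, \textrm{d}\tau ), \tag{8} $$
since \(P_{F}(s_{o}, \cdot )\) and \(P_{B}(x', \cdot )\) are Markov kernels and thus integrate to one,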
which recovers Lahlou et al.'s TB condition [13, Definition 6]. Consequently, the result follows from [13, Theorem 2, (4)]. For completeness, however, note that Eq. (8) ensures that \(P_{F}(s_{o}, \cdot )\) and \(R(\textrm{d}x) P_{B}(x, \cdot )\), when interpreted as measures over \(S^{\otimes \le N - 1} \times \{s_{f}\}\), assign the same value to the integral of every bounded measurable function up to a multiplicative constant and, thus, coincide up to that constant. Hence, by integrating out the elements of \(\tau \) that are not members of \(\mathcal {X}\), we obtain Eq. (7).
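As a concrete instance of the final step, choosing \(f(\tau , x, \tau ', x') = \mathbb {1}_{A}(x)\) in Eq. (6) and using that \(P_{F}(s_{o}, \cdot )\) and \(P_{B}(x, \cdot )\) are Markov kernels (so each integrates to one) gives the terminal-state marginal directly:

$$ \underbrace{\int \mathbb {1}_{A}(x) \, P_{F}(s_{o}, \textrm{d}(\tau , x))}_{\text{probability of terminating in } A} \, R(\mathcal {X}) = \int _{A} R(\textrm{d}x) \, P_{B}(x, \mathbb {T}') \int P_{F}(s_{o}, \textrm{d}(\tau ', x')) = R(A), $$

and dividing both sides by \(R(\mathcal {X})\) yields Eq. (7).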
Finally, before we delve into the algorithmic aspects of contrastive learning of GFlowNets, we again emphasize that the result above is general enough to encompass distributions supported on continuous, discrete, and mixed spaces.
3.2 Contrastive Loss
Contrastive Loss. Realistically, finding a policy \(P_{F}\) satisfying the CB condition is an analytically intractable problem. As a consequence, we parameterize \(P_{F}\) with a neural network with weights \(\theta \) and estimate the optimal \(\theta ^{\star }\) by minimizing a loss function that provably enforces Eq. (6). The corollary below, which is an immediate consequence of Theorem 1, provides such a loss function.
Corollary 1
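(Continuous Contrastive Loss). Under the notations of Theorem 1, let \(p_{F}\), \(p_{B}\) and r be the densities of \(P_{F}\), \(P_{B}\), and R relative to \(\kappa \), \(\kappa _{b}\) and \(\mu \), respectively. Then, we define the continuous contrastive loss (CCL) as
$$ \mathcal {L}_{CB}(\theta ) = \mathbb {E}_{\tau , \tau ' \overset{\text{i.i.d.}}{\sim } P_{E}(s_{o}, \cdot )} \left[ \left( \log \frac{p_{F}(\tau ; \theta )}{p_{B}(\tau | x) r(x)} - \log \frac{p_{F}(\tau '; \theta )}{p_{B}(\tau ' | x') r(x')} \right) ^{2} \right] , \tag{9} $$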
in which \(P_{E}(s_{o}, \cdot )\) is an exploratory policy with full support over the trajectories on \(\varSigma _{\mathcal {S}}^{\otimes \le N}\) starting at \(s_{o}\) and we use x (resp. \(x'\)) to denote the unique terminal state within \(\tau \) (resp. \(\tau '\)). Then, if \(\mathcal {L}_{CB}(\theta ^{\star }) = 0\), the GFlowNet parameterized by \(\theta ^{\star }\) abides by Eq. (8) and samples correctly from the target.
Intuitively, the condition \(\mathcal {L}_{CB}(\theta ) = 0\) ensures that the quotients \(\beta (\tau )\) defined in Sect. 3.1 are constant \(\kappa (s_{o}, \cdot )\)-almost everywhere, which implies the contrastive balance condition and, by Theorem 1, guarantees the GFlowNet's distributional correctness.
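A Monte Carlo estimate of Eq. (9) only requires per-trajectory log-terms evaluated under the current \(p_{F}\). The sketch below applies it to the Euclidean illustration of Sect. 2.2 under assumptions that are not taken from the paper: the target is a hypothetical unnormalized Gaussian \(c \cdot \mathcal {N}(\mu , I)\) in two dimensions, the forward policy draws each increment from a unit-variance Gaussian with a per-coordinate mean, log-densities are taken with respect to the Lebesgue measure, and the exploratory policy draws increments from \(\mathcal {N}(0, 3^{2})\). It illustrates that trajectories may come from a policy other than \(p_{F}\) and that the constant c never needs to be estimated.

```python
import numpy as np

def log_normal(x, mean, std=1.0):
    """Elementwise log-density of N(mean, std^2)."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def ccl_estimate(log_pf, log_pb, log_r):
    """Monte Carlo estimate of Eq. (9) from B trajectories drawn i.i.d. from P_E.

    The two halves of the batch form the pairs (tau, tau'), so each pair
    consists of independent draws.
    """
    log_beta = log_pf - log_pb - log_r
    half = log_beta.shape[0] // 2
    return np.mean((log_beta[:half] - log_beta[half:2 * half]) ** 2)

rng = np.random.default_rng(0)
mu = np.array([2.0, -1.0])      # hypothetical target mean
log_c = np.log(7.0)             # normalizing constant: enters log r but cancels in the loss

# Off-policy: increments come from a broad exploratory Gaussian, not from p_F.
alphas = rng.normal(0.0, 3.0, size=(256, 2))
x = alphas                      # starting at s_o = 0, each coordinate is updated once

def log_terms(policy_mean):
    log_pf = log_normal(alphas, policy_mean).sum(axis=1)  # product of per-step densities
    log_pb = np.zeros(len(x))                             # Dirac backward kernel
    log_r = log_c + log_normal(x, mu).sum(axis=1)         # unnormalized target density
    return log_pf, log_pb, log_r

loss_correct = ccl_estimate(*log_terms(mu))         # p_F proportional to p_B r: loss near 0
loss_wrong = ccl_estimate(*log_terms(np.zeros(2)))  # mismatched policy: loss well above 0
```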
Algorithm. Algorithm 1 illustrates the training of GFlowNets through the minimization of the CCL. In practice, both the sampling and the neural network forward passes can be massively parallelized. Also, for the optimization step, we use Adam [12] to update the parameters based on the stochastic gradients. Importantly, the architecture of the neural network with weights \(\theta \) and the nature of the exploratory policy \(P_{E}\) are highly problem-dependent and are abstracted away from Algorithm 1; we provide some examples in the next section.
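The loop below is a minimal sketch of the steps described for Algorithm 1 (sample from an exploratory policy, evaluate the log-quotients, take a gradient step on the CCL) on a hypothetical one-step problem small enough to differentiate by hand. Assumptions not taken from the paper: the forward policy is \(\mathcal {N}(\theta , 1)\), the target is \(c \cdot \mathcal {N}(2, 1)\), the exploratory policy is \(\mathcal {N}(0, 3^{2})\), and plain gradient descent with a fixed step size replaces Adam.

```python
import numpy as np

def log_normal(x, mean, std=1.0):
    """Elementwise log-density of N(mean, std^2)."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def train_ccl(steps=300, batch=128, lr=0.01, seed=0):
    rng = np.random.default_rng(seed)
    theta = -3.0                           # mean of the forward policy N(theta, 1)
    target_mean, log_c = 2.0, np.log(5.0)  # hypothetical unnormalized target c * N(2, 1)
    half = batch // 2
    for _ in range(steps):
        # 1) sample a batch of terminal states from the exploratory policy N(0, 3^2)
        x = rng.normal(0.0, 3.0, size=batch)
        # 2) log beta = log p_F - log p_B - log r, with log p_B = 0 for one step
        log_beta = log_normal(x, theta) - (log_c + log_normal(x, target_mean))
        score = x - theta                  # d log p_F / d theta for a unit-variance Gaussian
        diff = log_beta[:half] - log_beta[half:]
        # 3) gradient of mean((log beta - log beta')^2) with respect to theta
        grad = np.mean(2.0 * diff * (score[:half] - score[half:]))
        # 4) plain gradient step (a simplification of the Adam update)
        theta -= lr * grad
    return theta

theta_hat = train_ccl()                    # converges towards the target mean 2.0
```

The constant `log_c` cancels inside `diff`, so the gradient never depends on the normalizing constant of the target.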
Connection to Other Losses. To conclude, we remark that the CB loss is tightly connected to the variance loss for variational inference [20, 26]; indeed, the expected squared difference between two independent and identically distributed random variables equals twice their variance, i.e., \(\mathbb {E}_{\tau , \tau '}[(\log \beta (\tau ) - \log \beta (\tau '))^{2}] = 2\,\text {Var}_{\tau }[\log \beta (\tau )]\) for \(\tau , \tau '\) drawn independently from \(P_{E}(s_{o}, \cdot )\), so the CB loss equals twice the variance of \(\log \beta (\tau )\) under the exploratory policy.
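The factor of two in this identity is easy to confirm numerically. The snippet below uses hypothetical Gaussian samples in place of \(\log \beta (\tau )\) values; the distribution is an assumption chosen only so that the expected values are known (variance 4, hence a pairwise mean of about 8).

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical log-quotients log beta(tau) for i.i.d. trajectories
log_beta = rng.normal(loc=-1.0, scale=2.0, size=200_000)
a, b = log_beta[:100_000], log_beta[100_000:]   # disjoint halves give independent pairs

pairwise = np.mean((a - b) ** 2)     # estimates E[(log beta - log beta')^2]
twice_var = 2.0 * np.var(log_beta)   # estimates 2 Var[log beta]
```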
4 Experiments
This section shows that the CCL is a sound and effective learning objective for continuous GFlowNets in varied experimental setups, which are described below.
Tasks. We demonstrate the effectiveness of the CB loss on the tasks of sampling from a mixture of Gaussian distributions (GM) [28] and from a banana-shaped distribution (BS) [16], which are commonly used benchmarks for continuous GFlowNets and for variational inference algorithms in general, respectively. For the GMs, we evaluate the GFlowNet's performance on a sparse 2-dimensional model and a 15-dimensional model, both with isotropic variances. For the BS distribution, we consider the target distribution
Experimental Setup. For each experiment, we consider the iterative generative process described at the end of Sect. 2.2. To parameterize \(Q_{\theta }\), which is here defined as a Gaussian mixture distribution, we employ an MLP with two 64-dimensional hidden layers that receives the current state as input and returns the log-weights, means, and log-variances of the mixture. In Algorithm 1, we fix a batch size of 128 and use Adam with a linearly decaying learning rate of \(10^{-3}\) for optimization. However, when minimizing the TB loss, we follow Malkin et al.'s [14] suggestion of using a fixed learning rate of \(10^{-1}\) for \(\log Z(\gamma )\). Code for reproducing the results will be released upon acceptance.
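The mixture head described above can be sketched in NumPy as follows. Several details are assumptions rather than the authors' code: the ReLU activation, the initialization scheme, a scalar increment \(\alpha\) per step, and the function names. The sketch maps a state to the log-weights, means, and log-variances of a K-component mixture, samples \(\alpha\), and returns \(\log Q_{\theta }(\alpha )\).

```python
import numpy as np

def init_mlp(in_dim, n_components, hidden=64, seed=0):
    """Hypothetical initialization of an MLP with two hidden layers of `hidden`
    units whose output holds 3 * n_components numbers (log-weights, means,
    log-variances) of a Gaussian mixture over a scalar increment."""
    rng = np.random.default_rng(seed)
    sizes = [in_dim, hidden, hidden, 3 * n_components]
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mixture_params(params, state):
    h = state
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)          # ReLU (the activation is an assumption)
    log_w, means, log_var = np.split(h, 3, axis=-1)
    # normalize the mixture weights in log-space
    log_w = log_w - np.logaddexp.reduce(log_w, axis=-1, keepdims=True)
    return log_w, means, log_var

def sample_and_log_prob(params, state, rng):
    """Samples an increment alpha ~ Q_theta(. | state) and returns log Q_theta(alpha)."""
    log_w, means, log_var = mixture_params(params, state)
    k = rng.choice(len(log_w), p=np.exp(log_w))
    alpha = rng.normal(means[k], np.exp(0.5 * log_var[k]))
    comp = (-0.5 * (alpha - means) ** 2 / np.exp(log_var)
            - 0.5 * log_var - 0.5 * np.log(2 * np.pi))
    return alpha, np.logaddexp.reduce(log_w + comp)

params = init_mlp(in_dim=2, n_components=3)
rng = np.random.default_rng(0)
alpha, log_q = sample_and_log_prob(params, np.zeros(2), rng)
```

With zero biases and the all-zero initial state, every component collapses to a standard normal, so `log_q` reduces to \(\log \mathcal {N}(\alpha ; 0, 1)\); this is a convenient check that the mixture log-density is assembled correctly.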
Minimizing the CB loss (bottom row) drastically improves the training convergence when compared to TB minimization (top row) when the target is a 2-dimensional sparse Gaussian Mixture (see Fig. 3). This corroborates our hypothesis that avoiding parameterizing \(\log Z(\gamma )\) facilitates the learning of GFlowNets. Columns represent results of three randomly initialized runs.
Results. Figure 1 highlights that, given a fixed computational budget and a sparse target (illustrated in Fig. 3), a GFlowNet trained by minimizing \(\mathcal {L}_{CB}\) achieves a substantially better approximation than one trained by minimizing \(\mathcal {L}_{TB}\). This validates our hypothesis that the minimization of \(\mathcal {L}_{TB}\) is significantly hindered by the obligatory estimation of the intractable \(\log Z(\gamma )\). For the experiments based on the high-dimensional Gaussian mixture, which cannot be easily visualized, we show in Fig. 4 the Jensen-Shannon divergence between the learned distribution q and the target p,
$$ \mathcal {D}_{JS}[p || q] = \frac{1}{2} \mathcal {D}_{KL}[p || m] + \frac{1}{2} \mathcal {D}_{KL}[q || m], $$
with \(m := \frac{p + q}{2}\) and \(\mathcal {D}_{KL}[p || m] = \mathbb {E}_{x \sim p} \left[ \log \frac{p(x)}{m(x)}\right] \) denoting the Kullback-Leibler divergence between p and m. Importantly, \(\mathcal {D}_{JS}\) is a commonly used metric for assessing GFlowNets [6, 13]. Finally, for the low-dimensional and non-sparse targets, GFlowNets trained by minimizing either the CB or the TB loss yield similar approximations given equivalent computational resources, as Fig. 2 attests, which is potentially a consequence of the relative simplicity of the target distribution.
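For reference, the divergence above can be estimated by Monte Carlo when both densities can be evaluated and sampled. The sketch below assumes normalized log-densities are available as callables, which holds for the coordinate-wise Euclidean construction (each terminal state corresponds to a single trajectory) but may require additional estimation in general; the one-dimensional Gaussians in the usage example are hypothetical.

```python
import numpy as np

def log_normal(x, mean, std=1.0):
    """Elementwise log-density of N(mean, std^2)."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def js_divergence(log_p, log_q, x_p, x_q):
    """Monte Carlo estimate of D_JS[p || q] = 0.5 KL(p || m) + 0.5 KL(q || m),
    with m = (p + q) / 2.

    log_p, log_q: callables returning normalized log-densities;
    x_p, x_q: samples drawn from p and q, respectively.
    """
    def log_m(x):
        return np.logaddexp(log_p(x), log_q(x)) - np.log(2.0)
    kl_pm = np.mean(log_p(x_p) - log_m(x_p))
    kl_qm = np.mean(log_q(x_q) - log_m(x_q))
    return 0.5 * (kl_pm + kl_qm)

def log_p(x):
    return log_normal(x, 0.0)

def log_q_far(x):
    return log_normal(x, 20.0)

rng = np.random.default_rng(0)
x_p = rng.normal(0.0, 1.0, size=10_000)
js_same = js_divergence(log_p, log_p, x_p, rng.normal(0.0, 1.0, size=10_000))
js_far = js_divergence(log_p, log_q_far, x_p, rng.normal(20.0, 1.0, size=10_000))
```

Identical distributions give a divergence of zero, while nearly disjoint ones approach the upper bound \(\log 2\), which makes \(\mathcal {D}_{JS}\) a bounded and easily interpretable metric.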
5 Related Works
Discrete GFlowNets. GFlowNets were originally conceived as a reinforcement learning algorithm tailored to finding a diverse set of high-reward states [2, 3]. Since then, this family of models has been successfully applied to Bayesian structure learning and causal discovery [1, 5, 6], natural language processing [9], robust scheduling [26], phylogenetic inference [30], and probabilistic modelling in general [10, 15, 27, 29]. Correspondingly, a large effort has been devoted to improving the training convergence of discrete GFlowNets [11, 14, 17, 19, 25].
Continuous GFlowNets. In contrast, the literature on continuous GFlowNets, introduced by Lahlou et al. [13], is relatively thin. To the best of our knowledge, the only successful application of these models occurred recently in the context of Bayesian structure learning [6]. We believe the reason for this is the difficulty of efficiently training these models. Indeed, Zhou et al. [30] discretized the continuous component of the state space to avoid training a model to sample from a continuous distribution. Under these conditions, our work is an important step towards improving the usability of continuous GFlowNets.
6 Conclusions
We rigorously introduced a continuous variant of the contrastive balance condition and an accompanying continuous contrastive loss for efficiently training GFlowNets over distributions supported on discrete, continuous, and mixed spaces. All in all, our theoretical analysis established the soundness of the proposed learning objective, and our empirical results suggest that minimizing the CCL often significantly speeds up the model's training convergence.
Yet, the training of continuous GFlowNets remains challenging, and extending the recent advancements in the training of discrete GFlowNets to the continuous setting, along the lines of our work, is a promising endeavor that may greatly diminish these difficulties. For instance, one may build upon Pan et al.'s [19] Generative Augmented Flow Networks to incorporate intermediate reward-related signals within a trajectory for enhanced credit assignment [22]. Similarly, the theoretical developments of Deleu et al. [7] and Tiapkin et al. [25] relating RL and GFlowNets pave the way for the design of learning objectives motivated by control-theoretic techniques for continuous states.
References
1. Atanackovic, L., Tong, A., Hartford, J., Lee, L.J., Wang, B., Bengio, Y.: DynGFN: towards Bayesian inference of gene regulatory networks with GFlowNets. In: Advances in Neural Information Processing Systems (NeurIPS) (2023)
2. Bengio, E., Jain, M., Korablyov, M., Precup, D., Bengio, Y.: Flow network based generative models for non-iterative diverse candidate generation. In: NeurIPS (2021)
3. Bengio, Y., Lahlou, S., Deleu, T., Hu, E.J., Tiwari, M., Bengio, E.: GFlowNet foundations. J. Mach. Learn. Res. (JMLR) 24, 1–55 (2023)
4. Blei, D.M., et al.: Variational inference: a review for statisticians. J. Am. Stat. Assoc. 112, 859–877 (2017)
5. Deleu, T., et al.: Bayesian structure learning with generative flow networks. In: UAI (2022)
6. Deleu, T., Nishikawa-Toomey, M., Subramanian, J., Malkin, N., Charlin, L., Bengio, Y.: Joint Bayesian inference of graphical structure and parameters with a single generative flow network. In: Advances in Neural Information Processing Systems (NeurIPS) (2023)
7. Deleu, T., Nouri, P., Malkin, N., Precup, D., Bengio, Y.: Discrete probabilistic inference as control in multi-path environments (2024)
8. Hinton, G.E., Osindero, S., Teh, Y.W.: A fast learning algorithm for deep belief nets. Neural Comput. 18, 1527–1554 (2006)
9. Hu, E.J., Jain, M., Elmoznino, E., Kaddar, Y., et al.: Amortizing intractable inference in large language models (2023)
10. Hu, E.J., Malkin, N., Jain, M., Everett, K.E., Graikos, A., Bengio, Y.: GFlowNet-EM for learning compositional latent variable models. In: International Conference on Machine Learning (ICML) (2023)
11. Jang, H., Kim, M., Ahn, S.: Learning energy decompositions for partial inference in GFlowNets. In: The Twelfth International Conference on Learning Representations (ICLR) (2024)
12. Kingma, D.P., Ba, J.: Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
13. Lahlou, S., et al.: A theory of continuous generative flow networks. In: International Conference on Machine Learning (ICML). Proceedings of Machine Learning Research, vol. 202, pp. 18269–18300. PMLR (2023)
14. Malkin, N., Jain, M., Bengio, E., Sun, C., Bengio, Y.: Trajectory balance: improved credit assignment in GFlowNets. In: NeurIPS (2022)
15. Malkin, N., et al.: GFlowNets and variational inference. In: International Conference on Learning Representations (ICLR) (2023)
16. Mesquita, D., Blomstedt, P., Kaski, S.: Embarrassingly parallel MCMC using deep invertible transformations. In: UAI (2019)
17. Pan, L., Malkin, N., Zhang, D., Bengio, Y.: Better training of GFlowNets with local credit and incomplete trajectories. In: International Conference on Machine Learning (ICML) (2023)
18. Pan, L., Malkin, N., Zhang, D., Bengio, Y.: Better training of GFlowNets with local credit and incomplete trajectories. arXiv preprint arXiv:2302.01687 (2023)
19. Pan, L., Zhang, D., Courville, A., Huang, L., Bengio, Y.: Generative augmented flow networks. In: International Conference on Learning Representations (ICLR) (2023)
20. Richter, L., Boustati, A., Nüsken, N., Ruiz, F.J.R., Akyildiz, Ö.D.: VarGrad: a low-variance gradient estimator for variational inference (2020)
21. Robert, C.P., et al.: The Bayesian Choice: From Decision-Theoretic Foundations to Computational Implementation, vol. 2. Springer, Heidelberg (2007). https://doi.org/10.1007/0-387-71599-1
22. Shen, M.W., Bengio, E., Hajiramezanali, E., Loukas, A., Cho, K., Biancalani, T.: Towards understanding and improving GFlowNet training. In: International Conference on Machine Learning (ICML) (2023)
23. da Silva, T., Carvalho, L.M., Souza, A., Kaski, S., Mesquita, D.: Embarrassingly parallel GFlowNets (2024)
24. da Silva, T., et al.: Human-in-the-loop causal discovery under latent confounding using ancestral GFlowNets. arXiv preprint arXiv:2309.12032 (2023)
25. Tiapkin, D., Morozov, N., Naumov, A., Vetrov, D.: Generative flow networks as entropy-regularized RL (2024)
26. Zhang, D.W., Rainone, C., Peschl, M., Bondesan, R.: Robust scheduling with GFlowNets. In: International Conference on Learning Representations (ICLR) (2023)
27. Zhang, D., Chen, R.T.Q., Malkin, N., Bengio, Y.: Unifying generative models with GFlowNets and beyond. In: ICML Beyond Bayes Workshop (2022)
28. Zhang, D., Dai, H., Malkin, N., Courville, A., Bengio, Y., Pan, L.: Let the flows tell: solving graph combinatorial optimization problems with GFlowNets. In: NeurIPS (2023)
29. Zhang, D., Malkin, N., Liu, Z., Volokhova, A., Courville, A., Bengio, Y.: Generative flow networks for discrete probabilistic modeling. In: International Conference on Machine Learning (ICML) (2022)
30. Zhou, M.Y., et al.: PhyloGFN: phylogenetic inference with generative flow networks. In: The Twelfth International Conference on Learning Representations (ICLR) (2024)
Acknowledgements
This work was supported by Fundação Carlos Chagas Filho de Amparo à Pesquisa do Estado do Rio de Janeiro FAPERJ (SEI-260003/000709/2023), São Paulo Research Foundation FAPESP (2023/00815-6), and Conselho Nacional de Desenvolvimento Científico e Tecnológico CNPq (404336/2023-0).
Ethics declarations
Disclosure of Interests
The authors have no competing interests to declare that are relevant to the content of this article.
Copyright information
© 2025 The Author(s), under exclusive license to Springer Nature Switzerland AG
About this paper
Cite this paper
da Silva, T., Mesquita, D. (2025). A Contrastive Objective for Training Continuous Generative Flow Networks. In: Paes, A., Verri, F.A.N. (eds) Intelligent Systems. BRACIS 2024. Lecture Notes in Computer Science(), vol 15412. Springer, Cham. https://doi.org/10.1007/978-3-031-79029-4_1
DOI: https://doi.org/10.1007/978-3-031-79029-4_1
Published:
Publisher Name: Springer, Cham
Print ISBN: 978-3-031-79028-7
Online ISBN: 978-3-031-79029-4
eBook Packages: Computer Science, Computer Science (R0)







