-
Notifications
You must be signed in to change notification settings - Fork 20
Updating Dimod Sampler #73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
anahitamansouri
wants to merge
3
commits into
dwavesystems:main
Choose a base branch
from
anahitamansouri:feature/dimod-conditional-sampling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,19 +13,20 @@ | |
| # limitations under the License. | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
| from typing import TYPE_CHECKING, Any, Optional | ||
|
|
||
| import torch | ||
| from dimod import Sampler | ||
| from hybrid.composers import AggregatedSamples | ||
| import dimod | ||
| import warnings | ||
| from hybrid.composers import AggregatedSamples | ||
|
|
||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine | ||
| from dwave.plugins.torch.samplers.base import TorchSampler | ||
| from dwave.plugins.torch.utils import sampleset_to_tensor | ||
|
|
||
| if TYPE_CHECKING: | ||
| import dimod | ||
|
|
||
| from dimod import SampleSet | ||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine | ||
|
|
||
|
|
||
|
|
@@ -86,24 +87,128 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: | |
| """Sample from the dimod sampler and return the corresponding tensor. | ||
|
|
||
| The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set` | ||
| which is overwritten by subsequent calls. | ||
| which is overwritten by subsequent calls. When ``x`` is provided (conditional sampling), the method | ||
| expects the underlying sampler to return a SampleSet containing exactly one sample per row of ``x``; | ||
| otherwise, a ValueError is raised. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): A tensor of shape (``batch_size``, ``dim``) or (``batch_size``, ``n_nodes``) | ||
| interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will | ||
| be sampled; entries with +/-1 values will remain constant. | ||
| Raises: | ||
| ValueError: If ``x`` has an invalid shape or contains values other than ±1 or NaN or if the | ||
| sampler returns more than one sample per input row. | ||
|
|
||
| Returns: | ||
| torch.Tensor: A tensor of shape (``batch_size``, ``n_nodes``) containing | ||
| sampled spin configurations with values in ``{-1, +1}``. | ||
| """ | ||
| if x is not None: | ||
| raise NotImplementedError("Support for conditional sampling has not been implemented.") | ||
| device = self._grbm._linear.device | ||
| nodes = self._grbm.nodes | ||
| n_nodes = self._grbm.n_nodes | ||
|
|
||
| h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) | ||
| self._sample_set = AggregatedSamples.spread( | ||
| self._sampler.sample_ising(h, J, **self._sampler_params) | ||
| ) | ||
|
|
||
| # Unconditional sampling | ||
| if x is None: | ||
| self._sample_set = AggregatedSamples.spread( | ||
| self._sampler.sample_ising(h, J, **self._sampler_params) | ||
| ) | ||
| return self._sampleset_to_tensor(self._sample_set, device) | ||
|
|
||
| # Conditional sampling | ||
| if x.shape[1] != n_nodes: | ||
| raise ValueError(f"x must have shape (batch_size, {n_nodes})") | ||
|
|
||
| mask = ~torch.isnan(x) | ||
| if not torch.all(torch.isin(x[mask], torch.tensor([-1, 1], device=x.device))): | ||
| raise ValueError("x must contain only ±1 or NaN") | ||
|
|
||
| results = [] | ||
| for row, row_mask in zip(x, mask): | ||
| # Fresh BQM | ||
| bqm = dimod.BinaryQuadraticModel.from_ising(h, J) | ||
|
|
||
| # Build conditioning dict | ||
| conditioned = {node: int(val.item()) | ||
| for node, val, m in zip(nodes, row, row_mask) if m | ||
| } | ||
|
|
||
| # Apply conditioning | ||
| if conditioned: | ||
| bqm.fix_variables(conditioned) | ||
|
|
||
| # Handle fully clamped case | ||
| if bqm.num_variables == 0: | ||
| full = torch.tensor([conditioned[node] for node in nodes], | ||
| device=device, dtype=torch.float) | ||
| results.append(full) | ||
| continue | ||
|
|
||
| # Clip linear biases for remaining free variables | ||
| if self._linear_range is not None: | ||
| lb, ub = self._linear_range | ||
| for v in bqm.linear: | ||
| if bqm.linear[v] > ub: | ||
| bqm.set_linear(v, ub) | ||
| elif bqm.linear[v] < lb: | ||
| bqm.set_linear(v, lb) | ||
|
|
||
| # Clip quadratic biases | ||
| if self._quadratic_range is not None: | ||
| lb, ub = self._quadratic_range | ||
| for u, v, bias in bqm.iter_quadratic(): | ||
| if bias > ub: | ||
| bqm.set_quadratic(u, v, ub) | ||
| elif bias < lb: | ||
| bqm.set_quadratic(u, v, lb) | ||
|
|
||
| # Sample one configuration per input | ||
| sample_kwargs = dict(self._sampler_params) | ||
|
|
||
| # Storing the latest samples | ||
| self._sample_set = self._sampler.sample(bqm, **sample_kwargs) | ||
|
|
||
| if len(self._sample_set) > 1: | ||
| raise ValueError(f"Expected exactly one sample per input row, but got {len(self._sample_set)}") | ||
|
|
||
| # Extract sampled values | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Raise an error here if sample set has sample size > 1 |
||
| sample = self._sample_set.first.sample | ||
|
|
||
| # Reconstruct full sample | ||
| full = torch.empty(n_nodes, device=device) | ||
| for node, idx in self._grbm._node_to_idx.items(): | ||
| full[idx] = conditioned[node] if node in conditioned else float(sample[node]) | ||
| results.append(full) | ||
|
|
||
| # Stack to get (batch_size, n_nodes) | ||
| samples = torch.stack(results, dim=0) | ||
| return samples | ||
|
|
||
| def _sampleset_to_tensor(self, sample_set: SampleSet, device: Optional[torch.device] = None | ||
| ) -> torch.Tensor: | ||
| """Converts a ``dimod.SampleSet`` to a ``torch.Tensor`` using GRBM node order. | ||
|
|
||
| # use same device as modules linear | ||
| device = self._grbm._linear.device | ||
| return sampleset_to_tensor(self._grbm.nodes, self._sample_set, device) | ||
| Args: | ||
| sample_set (dimod.SampleSet): A sample set. | ||
| device (torch.device, optional): The device of the constructed tensor. | ||
| If ``None`` and data is a tensor then the device of data is used. | ||
| If ``None`` and data is not a tensor then the result tensor is constructed | ||
| on the current device. | ||
|
|
||
| Returns: | ||
| torch.Tensor: The sample set as a ``torch.Tensor``. | ||
| """ | ||
| var_to_sample_i = {v: i for i, v in enumerate(sample_set.variables)} | ||
|
|
||
| # Convert dict -> ordered list by index | ||
| ordered_vars = [v for v, _ in sorted(self._grbm.node_to_idx.items(), key=lambda x: x[1])] | ||
|
|
||
| permutation = [var_to_sample_i[v] for v in ordered_vars] | ||
|
|
||
| sample = sample_set.record.sample[:, permutation] | ||
|
|
||
| return torch.from_numpy(sample).to(device=device, dtype=torch.float32) | ||
|
|
||
| @property | ||
| def sample_set(self) -> dimod.SampleSet: | ||
|
|
||
5 changes: 5 additions & 0 deletions
5
releasenotes/notes/dimod-conditional-sampling-3ce98d4847eedb83.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| --- | ||
| upgrade: | ||
| - | | ||
| Add conditional sampling functionality for the ``DimodSampler``. The sampler | ||
| enforces one sample per input row during conditional sampling. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,14 +15,13 @@ | |
| import unittest | ||
|
|
||
| import torch | ||
| from dimod import SPIN, BinaryQuadraticModel, IdentitySampler, SampleSet, TrackingComposite | ||
| from parameterized import parameterized | ||
| from dimod import IdentitySampler, SampleSet, TrackingComposite | ||
|
|
||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM | ||
| from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| from dwave.samplers import SteepestDescentSampler | ||
| from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple | ||
|
|
||
| from dwave.samplers import SimulatedAnnealingSampler | ||
| from dwave.plugins.torch.utils import sampleset_to_tensor | ||
|
|
||
| class TestDimodSampler(unittest.TestCase): | ||
| def setUp(self) -> None: | ||
|
|
@@ -125,6 +124,51 @@ def test_sample(self): | |
| torch.tensor(list(tracker.input['J'].values())), | ||
| torch.tensor([0, 0, 0, 0.0]) | ||
| ) | ||
|
|
||
| with self.subTest("Conditional sampling preserves clamped variables"): | ||
| sampler = DimodSampler( | ||
| self.bm, | ||
| SimulatedAnnealingSampler(), | ||
| prefactor=1, | ||
| sample_kwargs=dict(num_reads=1) | ||
| ) | ||
|
|
||
| x = torch.tensor([ | ||
| [1.0, float("nan"), -1.0, float("nan")], | ||
| [float("nan"), -1.0, float("nan"), 1.0], | ||
| ]) | ||
|
|
||
| samples = sampler.sample(x) | ||
|
|
||
| # Shape check | ||
| self.assertTupleEqual(samples.shape, x.shape) | ||
|
|
||
| # Check clamped values unchanged | ||
| mask = ~torch.isnan(x) | ||
| self.assertTrue(torch.all(samples[mask] == x[mask])) | ||
|
|
||
| # Check free variables are ±1 | ||
| free_mask = torch.isnan(x) | ||
| free_values = samples[free_mask] | ||
| self.assertTrue(torch.all(torch.isin(free_values, torch.tensor([-1.0, 1.0]))), | ||
| "Free variables should be sampled as ±1") | ||
|
|
||
| with self.subTest("Conditional sampling with all variables clamped returns input unchanged."): | ||
| sampler = DimodSampler( | ||
| self.bm, | ||
| SimulatedAnnealingSampler(), | ||
| prefactor=1, | ||
| sample_kwargs=dict(num_reads=1) | ||
| ) | ||
|
|
||
| x = torch.tensor([ | ||
| [+1.0, -1.0, -1.0, +1.0], | ||
| [-1.0, +1.0, -1.0, -1.0], | ||
| ]) | ||
|
|
||
| samples = sampler.sample(x) | ||
| # All spins clamped, should return identical tensor | ||
| torch.testing.assert_close(samples, x) | ||
|
|
||
| def test_sample_set(self): | ||
| grbm = GRBM(list("abcd"), [("a", "b")]) | ||
|
|
@@ -142,6 +186,5 @@ def test_sample_set(self): | |
| with self.subTest("The `sample_set` attribute should be of type `dimod.SampleSet`."): | ||
| self.assertTrue(isinstance(sampler.sample_set, SampleSet)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be documented in the docstring. It should also raise a warning if
num_readsis supplied and overwritten.