From ac6cd8ede972de401c1a5f1c4a7b9aa6c065953e Mon Sep 17 00:00:00 2001 From: Vishal Gupta Date: Tue, 3 Feb 2026 03:35:45 +0000 Subject: [PATCH 1/2] Ingestion pipeline fixes --- .../ingestion/data/GraphReader.java | 64 +++++++++++++++++-- .../pipeline/ImportGroupPipeline.java | 50 +++++++++++---- .../pipeline/IngestionPipelineOptions.java | 2 +- .../ingestion/spanner/SpannerClient.java | 2 + .../src/main/resources/spanner_schema.sql | 30 --------- pipeline/util/pom.xml | 5 ++ .../pipeline/util/PipelineUtils.java | 60 ++++++++++++++--- 7 files changed, 153 insertions(+), 60 deletions(-) diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java index 2a3a3462..224b3791 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java @@ -2,19 +2,23 @@ import com.google.cloud.ByteArray; import com.google.cloud.spanner.Mutation; +import java.io.IOException; import java.io.Serializable; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.TypeDescriptor; import org.datacommons.Storage.Observations; import org.datacommons.ingestion.spanner.SpannerClient; import org.datacommons.pipeline.util.PipelineUtils; @@ -25,13 +29,16 @@ import org.datacommons.proto.Mcf.McfStatVarObsSeries.StatVarObs; import org.datacommons.proto.Mcf.ValueType; import org.datacommons.util.GraphUtils; +import org.datacommons.util.McfUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class GraphReader implements Serializable { private static final Logger LOGGER = LoggerFactory.getLogger(GraphReader.class); + // Maximum size for a single column value in Spanner (10MB) private static final String DC_AGGREGATE = "dcAggregate/"; private static final String DATCOM_AGGREGATE = "DataCommonsAggregate"; + private static final String IMPORT_METADATA_FILE = "import_metadata_mcf.mcf"; public static List graphToNodes(McfGraph graph, Counter mcfNodesWithoutTypeCounter) { List nodes = new ArrayList<>(); @@ -42,9 +49,12 @@ public static List graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp // Generate corresponding node Map pv = pvs.getPvsMap(); Node.Builder node = Node.builder(); - node.subjectId(nodeEntry.getKey()); - node.value(nodeEntry.getKey()); + String dcid = GraphUtils.getPropertyValue(pv, "dcid"); + String subjectId = !dcid.isEmpty() ? dcid : McfUtil.stripNamespace(nodeEntry.getKey()); + node.subjectId(subjectId); + node.value(subjectId); node.name(GraphUtils.getPropertyValue(pv, "name")); + List types = GraphUtils.getPropertyValues(pv, "typeOf"); if (types.isEmpty()) { types = List.of(PipelineUtils.TYPE_THING); @@ -55,9 +65,17 @@ public static List graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp nodes.add(node.build()); // Generate any leaf nodes - for (Map.Entry entry : pv.entrySet()) { // Iterate over properties + for (Map.Entry entry : pv.entrySet()) { for (TypedValue val : entry.getValue().getTypedValuesList()) { if (val.getType() != ValueType.RESOLVED_REF) { + int valSize = val.getValue().getBytes(StandardCharsets.UTF_8).length; + if (valSize > SpannerClient.MAX_SPANNER_COLUMN_SIZE) { + LOGGER.warn( + "Dropping node from {} because value size {} exceeds max size.", + subjectId, + valSize); + continue; + } node = Node.builder(); node.subjectId(PipelineUtils.generateObjectValueKey(val.getValue())); if (PipelineUtils.storeValueAsBytes(entry.getKey())) { @@ -74,22 +92,54 @@ public static List graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp return nodes; } + public static PCollection getProvenanceMcf( + String bucketName, String importName, String latestVersion, Pipeline p) { + String provenanceFile = "gs://" + bucketName + "/" + "provenance/" + importName + ".mcf"; + String metadataFile = latestVersion + "/" + IMPORT_METADATA_FILE; + LOGGER.info("Reading provenance mcf from {} {}", provenanceFile, metadataFile); + List mcfList = new ArrayList<>(); + String defaultProvenance = + "Node: dcid:dc/base/" + importName + "\n" + "typeOf: dcid:Provenance\n"; + mcfList.add(GraphUtils.convertToGraph(defaultProvenance)); + try { + mcfList.addAll(GraphUtils.readMcfString(PipelineUtils.getGcsFileContent(metadataFile))); + } catch (IOException e) { + LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage()); + } + try { + mcfList.addAll(GraphUtils.readMcfString(PipelineUtils.getGcsFileContent(provenanceFile))); + } catch (IOException e) { + LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage()); + } + return p.apply(Create.of(mcfList).withType(TypeDescriptor.of(McfGraph.class))); + } + public static List graphToEdges(McfGraph graph, String provenance) { List edges = new ArrayList<>(); for (Map.Entry nodeEntry : graph.getNodesMap().entrySet()) { PropertyValues pvs = nodeEntry.getValue(); if (!GraphUtils.isObservation(pvs)) { Map pv = pvs.getPvsMap(); - // String provenance = GraphUtils.getPropertyValue(pv, "provenance"); - String subjectId = nodeEntry.getKey(); // Use the map key as the subjectId - for (Map.Entry entry : pv.entrySet()) { // Iterate over properties + String dcid = GraphUtils.getPropertyValue(pv, "dcid"); + String subjectId = !dcid.isEmpty() ? dcid : McfUtil.stripNamespace(nodeEntry.getKey()); + for (Map.Entry entry : pv.entrySet()) { for (TypedValue val : entry.getValue().getTypedValuesList()) { + if (val.getType() != ValueType.RESOLVED_REF) { + int valSize = val.getValue().getBytes(StandardCharsets.UTF_8).length; + if (valSize > SpannerClient.MAX_SPANNER_COLUMN_SIZE) { + LOGGER.warn( + "Dropping edge from {} because value size {} exceeds max size.", + subjectId, + valSize); + continue; + } + } Edge.Builder edge = Edge.builder(); edge.subjectId(subjectId); edge.predicate(entry.getKey()); edge.provenance(provenance); if (val.getType() == ValueType.RESOLVED_REF) { - edge.objectId(val.getValue()); + edge.objectId(McfUtil.stripNamespace(val.getValue())); } else { edge.objectId(PipelineUtils.generateObjectValueKey(val.getValue())); } diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java index 4a5be32e..1b97ebeb 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java @@ -36,6 +36,10 @@ public class ImportGroupPipeline { private static final Counter obsCounter = Metrics.counter(ImportGroupPipeline.class, "graph_observation_count"); + private static boolean isJsonNullOrEmpty(JsonElement element) { + return element == null || element.getAsString().isEmpty(); + } + public static void main(String[] args) { IngestionPipelineOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().as(IngestionPipelineOptions.class); @@ -72,37 +76,55 @@ public static void main(String[] args) { for (JsonElement element : jsonArray) { JsonElement importElement = element.getAsJsonObject().get("importName"); - JsonElement path = element.getAsJsonObject().get("latestVersion"); - if (importElement == null - || path == null - || importElement.getAsString().isEmpty() - || path.getAsString().isEmpty()) { + JsonElement versionElement = element.getAsJsonObject().get("latestVersion"); + JsonElement pathElement = element.getAsJsonObject().get("graphPath"); + + if (isJsonNullOrEmpty(importElement) + || isJsonNullOrEmpty(pathElement) + || isJsonNullOrEmpty(versionElement)) { LOGGER.error("Invalid import input json: {}", element.toString()); continue; } String importName = importElement.getAsString(); + String latestVersion = versionElement.getAsString(); + LOGGER.info("Import: {} Latest version: {}", importName, latestVersion); + + // Populate provenance node/edges. String provenance = "dc/base/" + importName; - LOGGER.info("Import {} graph path {}", importName, path.getAsString()); + PCollection provenanceMcf = + GraphReader.getProvenanceMcf( + options.getStorageBucketId(), importName, latestVersion, pipeline); PCollection deleteMutations = GraphReader.getDeleteMutations(importName, provenance, pipeline, spannerClient); deleteMutationList.add(deleteMutations); // Read schema mcf files and combine MCF nodes, and convert to spanner mutations (Node/Edge). - PCollection nodes = PipelineUtils.readMcfFiles(path.getAsString(), pipeline); + String graphPath = + latestVersion.replaceAll("/+$", "") + + "/" + + pathElement.getAsString().replaceAll("^/+", ""); + PCollection nodes = + pathElement.getAsString().contains("tfrecord") + ? PipelineUtils.readMcfGraph(graphPath, pipeline) + : PipelineUtils.readMcfFiles(graphPath, pipeline); PCollectionTuple graphNodes = PipelineUtils.splitGraph(nodes); PCollection observationNodes = graphNodes.get(PipelineUtils.OBSERVATION_NODES_TAG); PCollection schemaNodes = graphNodes.get(PipelineUtils.SCHEMA_NODES_TAG); - PCollection combinedGraph = PipelineUtils.combineGraphNodes(schemaNodes); - PCollection nodeMutations = - GraphReader.graphToNodes( - combinedGraph, spannerClient, nodeCounter, nodeInvalidTypeCounter) - .apply("ExtractNodeMutations", Values.create()); + PCollection schemaMcf = + PCollectionList.of(schemaNodes) + .and(provenanceMcf) + .apply("FlattenSchema", Flatten.pCollections()); + PCollection edgeMutations = - GraphReader.graphToEdges(combinedGraph, provenance, spannerClient, edgeCounter) + GraphReader.graphToEdges(schemaMcf, provenance, spannerClient, edgeCounter) .apply("ExtractEdgeMutations", Values.create()); + + PCollection nodeMutations = + GraphReader.graphToNodes(schemaMcf, spannerClient, nodeCounter, nodeInvalidTypeCounter) + .apply("ExtractEdgeMutations", Values.create()); + nodeMutationList.add(nodeMutations); edgeMutationList.add(edgeMutations); - // Read observation mcf files, build optimized graph, and convert to spanner mutations // (Observation). PCollection optimizedGraph = diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipelineOptions.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipelineOptions.java index c46b0be1..64baf5a9 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipelineOptions.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipelineOptions.java @@ -25,7 +25,7 @@ public interface IngestionPipelineOptions extends PipelineOptions { void setSpannerDatabaseId(String databaseId); @Description("GCS bucket Id for input data") - @Default.String("datcom-store") + @Default.String("datcom-prod-imports") String getStorageBucketId(); void setStorageBucketId(String bucketId); diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java index 93d03f56..cb11cf1f 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java @@ -45,6 +45,8 @@ public class SpannerClient implements Serializable { // Decrease batch size for observations (bigger rows) private static final int SPANNER_BATCH_SIZE_BYTES = 500 * 1024; + // Maximum size for a single column value in Spanner (10MB) + public static final int MAX_SPANNER_COLUMN_SIZE = 10 * 1024 * 1024; // Increase batch size for Nodes/Edges (smaller rows) private static final int SPANNER_MAX_NUM_ROWS = 2000; // Higher value ensures this limit is not encountered before MaxNumRows diff --git a/pipeline/ingestion/src/main/resources/spanner_schema.sql b/pipeline/ingestion/src/main/resources/spanner_schema.sql index aa67d4de..991d4489 100644 --- a/pipeline/ingestion/src/main/resources/spanner_schema.sql +++ b/pipeline/ingestion/src/main/resources/spanner_schema.sql @@ -32,33 +32,3 @@ CREATE TABLE Observation ( provenance_url STRING(1024), is_dc_aggregate BOOL, ) PRIMARY KEY(observation_about, variable_measured, facet_id) - -CREATE TABLE ImportStatus ( - ImportName STRING(MAX) NOT NULL, - LatestVersion STRING(MAX), - State STRING(1024) NOT NULL, - JobId STRING(1024), - WorkflowId STRING(1024), - ExecutionTime INT64, - DataVolume INT64, - DataImportTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ), - StatusUpdateTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ), - NextRefreshTimestamp TIMESTAMP, -) PRIMARY KEY(ImportName) - -CREATE TABLE IngestionHistory ( - CompletionTimestamp TIMESTAMP NOT NULL OPTIONS ( allow_commit_timestamp = TRUE ), - WorkflowExecutionID STRING(1024) NOT NULL, - DataflowJobID STRING(1024), - IngestedImports ARRAY, - ExecutionTime INT64, - NodeCount INT64, - EdgeCount INT64, - ObservationCount INT64, -) PRIMARY KEY(CompletionTimestamp DESC) - -CREATE TABLE IngestionLock ( - LockID STRING(1024) NOT NULL, - LockOwner STRING(1024), - AcquiredTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ), -) PRIMARY KEY(LockID) diff --git a/pipeline/util/pom.xml b/pipeline/util/pom.xml index 33697cb6..5231f9b7 100644 --- a/pipeline/util/pom.xml +++ b/pipeline/util/pom.xml @@ -36,6 +36,11 @@ beam-sdks-java-core ${beam.version} + + com.google.cloud + google-cloud-storage + 2.0.2 + junit junit diff --git a/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java b/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java index 64121550..0cf40b2f 100644 --- a/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java +++ b/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java @@ -2,6 +2,10 @@ import static org.apache.beam.sdk.io.Compression.GZIP; +import com.google.cloud.storage.Blob; +import com.google.cloud.storage.BlobId; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.StorageOptions; import com.google.common.collect.ImmutableSet; import com.google.common.hash.Hashing; import com.google.protobuf.InvalidProtocolBufferException; @@ -20,8 +24,9 @@ import java.util.stream.StreamSupport; import java.util.zip.GZIPOutputStream; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.Compression; +import org.apache.beam.sdk.io.FileIO; import org.apache.beam.sdk.io.TFRecordIO; -import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.fs.EmptyMatchTreatment; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; @@ -47,6 +52,7 @@ public class PipelineUtils { // Default type for MCF nodes. public static final String TYPE_THING = "Thing"; + public static final String Provisional = "ProvisionalNode"; // Length of prefix of object value to use for key. public static final int OBJECT_VALUE_PREFIX = 16; @@ -65,6 +71,18 @@ public class PipelineUtils { public static final TupleTag OBSERVATION_NODES_TAG = new TupleTag() {}; public static final TupleTag SCHEMA_NODES_TAG = new TupleTag() {}; + public static String getGcsFileContent(String gcsPath) throws IOException { + String[] parts = gcsPath.substring(5).split("/", 2); + String bucketName = parts[0]; + String objectName = parts[1]; + Storage storage = StorageOptions.getDefaultInstance().getService(); + Blob blob = storage.get(BlobId.of(bucketName, objectName)); + if (blob == null) { + throw new IOException("File not found in GCS: " + gcsPath); + } + return new String(blob.getContent(), StandardCharsets.UTF_8); + } + /** * Parses a byte array into an McfOptimizedGraph protocol buffer. * @@ -139,14 +157,24 @@ public void processElement( * @return PCollection of McfGraph proto. */ public static PCollection readMcfFiles(String files, Pipeline p) { - String delimiter = "\n\n"; PCollection nodes = p.apply( - "ReadMcfFiles", - TextIO.read() - .withDelimiter(delimiter.getBytes()) - .from(files) - .withEmptyMatchTreatment(EmptyMatchTreatment.ALLOW)); + "MatchFiles", + FileIO.match() + .filepattern(files) + .withEmptyMatchTreatment(EmptyMatchTreatment.ALLOW)) + .apply("ReadMatches", FileIO.readMatches().withCompression(Compression.AUTO)) + .apply( + "ReadContent", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement( + @Element FileIO.ReadableFile file, OutputReceiver receiver) + throws IOException { + receiver.output(file.readFullyAsUTF8String()); + } + })); PCollection mcf = nodes.apply( @@ -158,7 +186,23 @@ public McfGraph apply(String input) { return GraphUtils.convertToGraph(input); } })); - return mcf; + PCollection graph = + mcf.apply( + "ProcessGraph", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement( + @Element McfGraph element, OutputReceiver receiver) { + for (Map.Entry entry : + element.getNodesMap().entrySet()) { + McfGraph.Builder b = McfGraph.newBuilder(); + b.putNodes(entry.getKey(), entry.getValue()); + receiver.output(b.build()); + } + } + })); + return graph; } public static PCollectionTuple splitGraph(PCollection graph) { From f44043b403c7950145dd6221e9fcf14ce55f78ee Mon Sep 17 00:00:00 2001 From: Vishal Gupta Date: Tue, 3 Feb 2026 03:50:28 +0000 Subject: [PATCH 2/2] Node merge --- .../ingestion/data/CacheReader.java | 2 +- .../org/datacommons/ingestion/data/Edge.java | 3 + .../ingestion/data/GraphReader.java | 166 +++++++++++++++++- .../org/datacommons/ingestion/data/Node.java | 17 +- .../pipeline/ImportGroupPipeline.java | 37 ++-- .../ingestion/spanner/SpannerClient.java | 83 ++++++++- .../ingestion/data/CacheReaderTest.java | 11 +- .../ingestion/data/GraphReaderTest.java | 48 ++++- .../pipeline/util/PipelineUtils.java | 9 +- 9 files changed, 335 insertions(+), 41 deletions(-) diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/CacheReader.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/CacheReader.java index 8ec40a73..edc00adc 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/CacheReader.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/CacheReader.java @@ -141,7 +141,7 @@ public NodesEdges parseArcRow(String row, Counter mcfNodesWithoutTypeCounter) { Node.builder() .subjectId(nodeId) .value(nodeValue) - .bytes(bytes) + .bytes(bytes != null ? bytes.toByteArray() : new byte[0]) .name(entity.getName()) .types(types) .build()); diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Edge.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Edge.java index 428cec42..bc9a0acb 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Edge.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Edge.java @@ -14,6 +14,9 @@ public class Edge implements Serializable { private String objectId; private String provenance; + @SuppressWarnings("unused") + private Edge() {} + // Private constructor to enforce use of Builder private Edge(Builder builder) { this.subjectId = builder.subjectId; diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java index 224b3791..30eedc54 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/GraphReader.java @@ -1,6 +1,5 @@ package org.datacommons.ingestion.data; -import com.google.cloud.ByteArray; import com.google.cloud.spanner.Mutation; import java.io.IOException; import java.io.Serializable; @@ -9,8 +8,10 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; @@ -40,6 +41,118 @@ public class GraphReader implements Serializable { private static final String DATCOM_AGGREGATE = "DataCommonsAggregate"; private static final String IMPORT_METADATA_FILE = "import_metadata_mcf.mcf"; + public static PCollection combineNodes(PCollection nodes) { + return nodes + .apply( + "MapNodesToKV", + ParDo.of( + new DoFn>() { + @ProcessElement + public void processElement( + @Element Node node, OutputReceiver> receiver) { + receiver.output(KV.of(node.getSubjectId(), node)); + } + })) + .apply( + "CombineNodes", + Combine.perKey( + new Combine.CombineFn, Node>() { + @Override + public List createAccumulator() { + return new ArrayList<>(); + } + + @Override + public List addInput(List accumulator, Node input) { + accumulator.add(input); + return accumulator; + } + + @Override + public List mergeAccumulators(Iterable> accumulators) { + List merged = new ArrayList<>(); + for (List acc : accumulators) { + merged.addAll(acc); + } + return merged; + } + + @Override + public Node extractOutput(List accumulator) { + if (accumulator.isEmpty()) return null; + Node first = accumulator.get(0); + Node.Builder builder = + Node.builder() + .subjectId(first.getSubjectId()) + .value(first.getValue()) + .name(first.getName()) + .types(first.getTypes()) + .bytes(first.getBytes()); + + Set types = new java.util.TreeSet<>(); + for (Node n : accumulator) { + types.addAll(n.getTypes()); + if (!n.getValue().isEmpty()) { + builder.value(n.getValue()); + } + if (!n.getName().isEmpty()) { + builder.name(n.getName()); + } + if (n.getBytes().length > 0) { + builder.bytes(n.getBytes()); + } + } + if (types.size() > 1 && types.contains("ProvisionalNode")) { + types.remove("ProvisionalNode"); + } + builder.types(new ArrayList<>(types)); + return builder.build(); + } + })) + .apply( + "ExtractNodes", + ParDo.of( + new DoFn, Node>() { + @ProcessElement + public void processElement( + @Element KV element, OutputReceiver receiver) { + receiver.output(element.getValue()); + } + })); + } + + public static PCollection nodeToMutations( + PCollection nodes, SpannerClient spannerClient) { + return nodes.apply( + "NodesToMutations", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(@Element Node node, OutputReceiver receiver) { + Mutation mutation = spannerClient.toNodeMutation(node); + if (mutation != null) { + receiver.output(mutation); + } + } + })); + } + + public static PCollection edgeToMutations( + PCollection edges, SpannerClient spannerClient) { + return edges.apply( + "EdgesToMutations", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(@Element Edge edge, OutputReceiver receiver) { + Mutation mutation = spannerClient.toEdgeMutation(edge); + if (mutation != null) { + receiver.output(mutation); + } + } + })); + } + public static List graphToNodes(McfGraph graph, Counter mcfNodesWithoutTypeCounter) { List nodes = new ArrayList<>(); for (Map.Entry nodeEntry : graph.getNodesMap().entrySet()) { @@ -79,10 +192,11 @@ public static List graphToNodes(McfGraph graph, Counter mcfNodesWithoutTyp node = Node.builder(); node.subjectId(PipelineUtils.generateObjectValueKey(val.getValue())); if (PipelineUtils.storeValueAsBytes(entry.getKey())) { - node.bytes(ByteArray.copyFrom(PipelineUtils.compressString(val.getValue()))); + node.bytes(PipelineUtils.compressString(val.getValue())); } else { node.value(val.getValue()); } + node.types(List.of(ValueType.TEXT.toString())); nodes.add(node.build()); } } @@ -101,13 +215,13 @@ public static PCollection getProvenanceMcf( String defaultProvenance = "Node: dcid:dc/base/" + importName + "\n" + "typeOf: dcid:Provenance\n"; mcfList.add(GraphUtils.convertToGraph(defaultProvenance)); + // try { + // mcfList.add(GraphUtils.convertToGraph(PipelineUtils.getGCSFileContent(metadataFile))); + // } catch (IOException e) { + // LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage()); + // } try { - mcfList.addAll(GraphUtils.readMcfString(PipelineUtils.getGcsFileContent(metadataFile))); - } catch (IOException e) { - LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage()); - } - try { - mcfList.addAll(GraphUtils.readMcfString(PipelineUtils.getGcsFileContent(provenanceFile))); + mcfList.add(GraphUtils.convertToGraph(PipelineUtils.getGCSFileContent(provenanceFile))); } catch (IOException e) { LOGGER.warn("Failed to read provenance metadata file: " + e.getMessage()); } @@ -214,6 +328,42 @@ public void processElement( })); } + public static PCollection mcfToNodes( + PCollection graph, Counter nodeCounter, Counter mcfNodesWithoutTypeCounter) { + return graph.apply( + "McfToNodes", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(@Element McfGraph element, OutputReceiver receiver) { + List nodes = graphToNodes(element, mcfNodesWithoutTypeCounter); + for (Node node : nodes) { + // LOGGER.info("Node: {}", node.toString()); + receiver.output(node); + } + nodeCounter.inc(nodes.size()); + } + })); + } + + public static PCollection mcfToEdges( + PCollection graph, String provenance, Counter edgeCounter) { + return graph.apply( + "McfToEdges", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(@Element McfGraph element, OutputReceiver receiver) { + List edges = graphToEdges(element, provenance); + for (Edge edge : edges) { + receiver.output(edge); + // LOGGER.info("Edge : {}", edge.toString()); + } + edgeCounter.inc(edges.size()); + } + })); + } + public static PCollection> graphToNodes( PCollection graph, SpannerClient spannerClient, diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Node.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Node.java index dc30a375..975d6e36 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Node.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/data/Node.java @@ -1,7 +1,7 @@ package org.datacommons.ingestion.data; -import com.google.cloud.ByteArray; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import java.util.Objects; import org.apache.beam.sdk.coders.DefaultCoder; @@ -16,10 +16,13 @@ public class Node implements Serializable { private String subjectId; private String value; - private ByteArray bytes; + private byte[] bytes; private String name; private List types; + @SuppressWarnings("unused") + private Node() {} + // Private constructor to enforce use of Builder private Node(Builder builder) { this.subjectId = builder.subjectId; @@ -41,7 +44,7 @@ public String getValue() { return value; } - public ByteArray getBytes() { + public byte[] getBytes() { return bytes; } @@ -60,7 +63,7 @@ public boolean equals(Object o) { Node node = (Node) o; return Objects.equals(subjectId, node.subjectId) && Objects.equals(value, node.value) - && Objects.equals(bytes, node.bytes) + && Arrays.equals(bytes, node.bytes) && Objects.equals(name, node.name) && Objects.equals(types, node.types); } @@ -74,13 +77,13 @@ public int hashCode() { public String toString() { return String.format( "Node{subjectId='%s', value='%s', bytes='%s', name='%s', types=%s}", - subjectId, value, bytes, name, types); + subjectId, value, Arrays.toString(bytes), name, types); } public static class Builder { private String subjectId = ""; private String value = ""; - private ByteArray bytes = null; + private byte[] bytes = new byte[0]; private String name = ""; private List types = List.of(); @@ -96,7 +99,7 @@ public Builder value(String value) { return this; } - public Builder bytes(ByteArray bytes) { + public Builder bytes(byte[] bytes) { this.bytes = bytes; return this; } diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java index 1b97ebeb..90f09f52 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/ImportGroupPipeline.java @@ -11,13 +11,16 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.Wait; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; +import org.datacommons.ingestion.data.Edge; import org.datacommons.ingestion.data.GraphReader; +import org.datacommons.ingestion.data.Node; import org.datacommons.ingestion.spanner.SpannerClient; import org.datacommons.pipeline.util.PipelineUtils; import org.datacommons.proto.Mcf.McfGraph; @@ -95,9 +98,10 @@ public static void main(String[] args) { GraphReader.getProvenanceMcf( options.getStorageBucketId(), importName, latestVersion, pipeline); - PCollection deleteMutations = - GraphReader.getDeleteMutations(importName, provenance, pipeline, spannerClient); - deleteMutationList.add(deleteMutations); + // PCollection deleteMutations = + // GraphReader.getDeleteMutations(importName, provenance, pipeline, spannerClient); + // deleteMutationList.add(deleteMutations); + // Read schema mcf files and combine MCF nodes, and convert to spanner mutations (Node/Edge). String graphPath = latestVersion.replaceAll("/+$", "") @@ -115,13 +119,23 @@ public static void main(String[] args) { .and(provenanceMcf) .apply("FlattenSchema", Flatten.pCollections()); - PCollection edgeMutations = - GraphReader.graphToEdges(schemaMcf, provenance, spannerClient, edgeCounter) - .apply("ExtractEdgeMutations", Values.create()); + PCollection edges = GraphReader.mcfToEdges(schemaMcf, provenance, edgeCounter); + + PCollection edgeMutations = GraphReader.edgeToMutations(edges, spannerClient); + + PCollection newNodes = + GraphReader.mcfToNodes(schemaMcf, nodeCounter, nodeInvalidTypeCounter); + + PCollection existingNodes = spannerClient.readExistingNodes(newNodes); + + PCollection mergedNodes = + PCollectionList.of(newNodes) + .and(existingNodes) + .apply("FlattenNodes", Flatten.pCollections()); + + PCollection finalNodes = GraphReader.combineNodes(mergedNodes); - PCollection nodeMutations = - GraphReader.graphToNodes(schemaMcf, spannerClient, nodeCounter, nodeInvalidTypeCounter) - .apply("ExtractEdgeMutations", Values.create()); + PCollection nodeMutations = GraphReader.nodeToMutations(finalNodes, spannerClient); nodeMutationList.add(nodeMutations); edgeMutationList.add(edgeMutations); @@ -135,8 +149,9 @@ public static void main(String[] args) { obsMutationList.add(observationMutations); } PCollection deleteMutations = - PCollectionList.of(deleteMutationList) - .apply("FlattenDeleteMutations", Flatten.pCollections()); + pipeline.apply(Create.empty(org.apache.beam.sdk.values.TypeDescriptor.of(Mutation.class))); + // PCollectionList.of(deleteMutationList) + // .apply("FlattenDeleteMutations", Flatten.pCollections()); SpannerWriteResult deleted = deleteMutations.apply("DeleteImportsFromSpanner", spannerClient.getWriteTransform()); // Write the mutations to spanner. diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java index cb11cf1f..cddceb30 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java @@ -1,5 +1,6 @@ package org.datacommons.ingestion.spanner; +import com.google.cloud.ByteArray; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.SpannerExceptionFactory; @@ -24,15 +25,23 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.gcp.spanner.ReadOperation; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.Write; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.WriteGrouped; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.Distinct; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.GroupIntoBatches; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptors; import org.datacommons.ingestion.data.Edge; import org.datacommons.ingestion.data.Node; import org.datacommons.ingestion.data.Observation; @@ -241,7 +250,7 @@ public Mutation toNodeMutation(Node node) { .set("value") .to(node.getValue()) .set("bytes") - .to(node.getBytes()) + .to(ByteArray.copyFrom(node.getBytes())) .set("name") .to(node.getName()) .set("types") @@ -350,6 +359,78 @@ public List> filterGraphKVMutations( return filtered; } + public PCollection readExistingNodes(PCollection inputNodes) { + return inputNodes + .apply("FilterTextNodes", Filter.by((Node node) -> !node.getTypes().contains("TEXT"))) + .apply( + "ExtractSubjectIds", + MapElements.into(TypeDescriptors.strings()).via(Node::getSubjectId)) + .apply("DistinctIds", Distinct.create()) + .apply( + "KeyWithRandomShard", + WithKeys.of( + new SimpleFunction() { + @Override + public Integer apply(String input) { + return Math.abs(input.hashCode()) % 100; + } + })) + .apply("BatchIds", GroupIntoBatches.ofSize(1000)) + .apply( + "CreateReadOperations", + MapElements.via( + new SimpleFunction>, ReadOperation>() { + @Override + public ReadOperation apply(KV> input) { + List ids = new ArrayList<>(); + input.getValue().forEach(ids::add); + return ReadOperation.create() + .withQuery( + Statement.newBuilder( + String.format( + "SELECT subject_id, types, name, value, bytes FROM %s WHERE" + + " subject_id IN UNNEST(@ids)", + nodeTableName)) + .bind("ids") + .toStringArray(ids) + .build()); + } + })) + .apply( + "ReadFromSpanner", + SpannerIO.readAll() + .withProjectId(gcpProjectId) + .withInstanceId(spannerInstanceId) + .withDatabaseId(spannerDatabaseId)) + .apply( + "ParseStructToNode", + MapElements.via( + new SimpleFunction() { + @Override + public Node apply(Struct struct) { + String subjectId = struct.getString("subject_id"); + List types = + struct.isNull("types") ? new ArrayList<>() : struct.getStringList("types"); + String name = struct.isNull("name") ? "" : struct.getString("name"); + String value = struct.isNull("value") ? "" : struct.getString("value"); + ByteArray bytes = + struct.isNull("bytes") ? ByteArray.copyFrom("") : struct.getBytes("bytes"); + Node node = + Node.builder() + .subjectId(subjectId) + .types(types) + .name(name) + .value(value) + .bytes(bytes == null ? new byte[0] : bytes.toByteArray()) + .build(); + return node; + } + })) + .apply( + "FilterProvisionalNodes", + Filter.by((Node node) -> !node.getTypes().contains("ProvisionalNode"))); + } + public static String getSubjectId(Mutation mutation) { return getMutationValue(mutation, "subject_id"); } diff --git a/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/CacheReaderTest.java b/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/CacheReaderTest.java index aaba8fa0..c082c82b 100644 --- a/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/CacheReaderTest.java +++ b/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/CacheReaderTest.java @@ -2,7 +2,6 @@ import static org.junit.Assert.assertEquals; -import com.google.cloud.ByteArray; import java.util.List; import org.apache.beam.sdk.metrics.Counter; import org.datacommons.Storage.Observations; @@ -55,7 +54,8 @@ public void testParseArcRowForOutArcWithValueNode() { Node.builder() .subjectId("Percentage Work :c6CV18sK/njghkqgkS/mMaTkKP+oWup0pgYkS6iFpvY=") .value( - "Percentage Work Related Physical Activity, Moderate Activity Or Heavy Activity Among Population") + "Percentage Work Related Physical Activity, Moderate Activity Or Heavy" + + " Activity Among Population") .build()) .addEdge( Edge.builder() @@ -85,9 +85,10 @@ public void testParseArcRowForOutArcWithBytesNode() { Node.builder() .subjectId("{ \"type\": \"Pol:G8RZr2tV3+cSSDVRj8Q4KnMpxDhZyZr438T3Fvq1Zkk=") .bytes( - ByteArray.copyFrom( - PipelineUtils.compressString( - "{ \"type\": \"Polygon\", \"coordinates\": [ [ [9, 7], [9, 6.5], [9.5, 6.5], [9.5, 7], [9, 7] ] ] } "))) + PipelineUtils.compressString( + "{ \"type\": \"Polygon\", \"coordinates\": [ [ [9, 7], " + + " [9, 6.5], [9.5, 6.5], [9.5, 7], [9, 7] " + + " ] ] } ")) .build()) .addEdge( Edge.builder() diff --git a/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/GraphReaderTest.java b/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/GraphReaderTest.java index 770359f7..ab1fb1b7 100644 --- a/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/GraphReaderTest.java +++ b/pipeline/ingestion/src/test/java/org/datacommons/ingestion/data/GraphReaderTest.java @@ -3,11 +3,14 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import com.google.cloud.ByteArray; import java.util.Arrays; import java.util.Comparator; import java.util.List; import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; import org.datacommons.Storage.Observations; import org.datacommons.pipeline.util.PipelineUtils; import org.datacommons.proto.Mcf.McfGraph; @@ -18,10 +21,39 @@ import org.datacommons.proto.Mcf.McfStatVarObsSeries.StatVarObs; import org.datacommons.proto.Mcf.McfType; import org.datacommons.proto.Mcf.ValueType; +import org.junit.Rule; import org.junit.Test; import org.mockito.Mockito; public class GraphReaderTest { + @Rule public final transient TestPipeline p = TestPipeline.create(); + + @Test + public void testCombineNodes() { + Node node1 = + Node.builder() + .subjectId("id1") + .types(List.of("Type1", "ProvisionalNode")) + .name("Name1") + .build(); + Node node2 = Node.builder().subjectId("id1").types(List.of("Type2")).value("Value1").build(); + Node node3 = Node.builder().subjectId("id2").types(List.of("ProvisionalNode")).build(); + + PCollection input = p.apply(Create.of(node1, node2, node3)); + PCollection output = GraphReader.combineNodes(input); + + Node expectedNode1 = + Node.builder() + .subjectId("id1") + .types(List.of("Type1", "Type2")) + .name("Name1") + .value("Value1") + .build(); + Node expectedNode2 = Node.builder().subjectId("id2").types(List.of("ProvisionalNode")).build(); + + PAssert.that(output).containsInAnyOrder(expectedNode1, expectedNode2); + p.run(); + } @Test public void testGraphToNodes() { @@ -119,19 +151,21 @@ public void testGraphToNodes() { Node.builder() .subjectId("Node Zero:kUyRupzrJkxe/HIOIctxlJX4woEGeOTtlVwqyXYnfDE=") .value("Node Zero") + .types(List.of("TEXT")) .build(), Node.builder() .subjectId("{ \"type\": \"Pol:G8RZr2tV3+cSSDVRj8Q4KnMpxDhZyZr438T3Fvq1Zkk=") .bytes( - ByteArray.copyFrom( - PipelineUtils.compressString( - "{ \"type\": \"Polygon\", \"coordinates\": [ [ [9, 7], " - + " [9, 6.5], [9.5, 6.5], [9.5, 7], [9, 7] " - + " ] ] } "))) + PipelineUtils.compressString( + "{ \"type\": \"Polygon\", \"coordinates\": [ [ [9, 7], " + + " [9, 6.5], [9.5, 6.5], [9.5, 7], [9, 7] " + + " ] ] } ")) + .types(List.of("TEXT")) .build(), Node.builder() .subjectId("Node One:J7we8EV8ssChRxBgWot6zDSbHl4xGY7I6mQosc89hFk=") .value("Node One") + .types(List.of("TEXT")) .build()); List actualNodes = GraphReader.graphToNodes(graph, mockMcfNodesWithoutTypeCounter); @@ -147,7 +181,7 @@ public void testGraphToNodes() { Node actual = actualNodes.get(i); assertEquals(expected.getSubjectId(), actual.getSubjectId()); assertEquals(expected.getValue(), actual.getValue()); - assertEquals(expected.getBytes(), actual.getBytes()); + assertArrayEquals(expected.getBytes(), actual.getBytes()); assertEquals(expected.getName(), actual.getName()); assertArrayEquals(expected.getTypes().toArray(), actual.getTypes().toArray()); } diff --git a/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java b/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java index 0cf40b2f..766e0f96 100644 --- a/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java +++ b/pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java @@ -71,7 +71,14 @@ public class PipelineUtils { public static final TupleTag OBSERVATION_NODES_TAG = new TupleTag() {}; public static final TupleTag SCHEMA_NODES_TAG = new TupleTag() {}; - public static String getGcsFileContent(String gcsPath) throws IOException { + /** + * Reads the content of a file from GCS. + * + * @param gcsPath The GCS path of the file (e.g., gs://bucket/path/to/file). + * @return The content of the file as a string. + * @throws IOException If the file is not found or cannot be read. + */ + public static String getGCSFileContent(String gcsPath) throws IOException { String[] parts = gcsPath.substring(5).split("/", 2); String bucketName = parts[0]; String objectName = parts[1];