From 6cc656c56da14b10ba5f6d226d190321f75bbd19 Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Fri, 20 Mar 2026 09:04:30 +0530 Subject: [PATCH 1/8] fix joins for missing key columns --- src/DataFrame/Operations/Join.hs | 166 ++++++++++++++++------------- src/DataFrame/Operations/Subset.hs | 8 +- tests/Operations/Join.hs | 45 +++++++- 3 files changed, 141 insertions(+), 78 deletions(-) diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index 0250a08..c89f5c2 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -10,6 +10,7 @@ module DataFrame.Operations.Join where import Control.Applicative ((<|>)) +import Control.Exception (throw) import Control.Monad (forM_, when) import Control.Monad.ST (ST, runST) import qualified Data.HashMap.Strict as HM @@ -24,6 +25,7 @@ import qualified Data.Vector as VB import qualified Data.Vector.Algorithms.Merge as VA import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM +import DataFrame.Errors (DataFrameException (ColumnNotFoundException)) import DataFrame.Internal.Column as D import DataFrame.Internal.DataFrame as D import DataFrame.Operations.Aggregation as D @@ -145,6 +147,15 @@ fillCrossProduct !leftSI !rightSI !lStart !lEnd !rStart !rEnd !lv !rv !pos = goL keyColIndices :: S.Set T.Text -> DataFrame -> [Int] keyColIndices csSet df = M.elems $ M.restrictKeys (D.columnIndices df) csSet +-- | Validate that all requested join keys exist, then return their indices. +validatedKeyColIndices :: T.Text -> S.Set T.Text -> DataFrame -> [Int] +validatedKeyColIndices callPoint csSet df = + let columnIdxs = D.columnIndices df + missingKeys = S.toAscList (csSet `S.difference` M.keysSet columnIdxs) + in case missingKeys of + [] -> M.elems $ M.restrictKeys columnIdxs csSet + missingKey : _ -> throw (ColumnNotFoundException missingKey callPoint (M.keys columnIdxs)) + -- ============================================================ -- Inner Join -- ============================================================ @@ -170,38 +181,38 @@ ghci> D.innerJoin ["key"] df other @ -} innerJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -innerJoin cs left right - | D.null right || D.null left = D.empty - | otherwise = - let - csSet = S.fromList cs - leftRows = fst (D.dimensions left) - rightRows = fst (D.dimensions right) - - leftKeyIdxs = keyColIndices csSet left - rightKeyIdxs = keyColIndices csSet right - leftHashes = D.computeRowHashes leftKeyIdxs left - rightHashes = D.computeRowHashes rightKeyIdxs right - - buildRows = min leftRows rightRows - (leftIxs, rightIxs) - | buildRows > joinStrategyThreshold = - sortMergeInnerKernel leftHashes rightHashes - | rightRows <= leftRows = - -- Build on right (smaller or equal), probe with left - hashInnerKernel leftHashes rightHashes - | otherwise = - -- Build on left (smaller), probe with right, swap result - let (!rIxs, !lIxs) = hashInnerKernel rightHashes leftHashes - in (lIxs, rIxs) - in - assembleInner csSet left right leftIxs rightIxs +innerJoin cs left right = + let + csSet = S.fromList cs + leftRows = fst (D.dimensions left) + rightRows = fst (D.dimensions right) + + leftKeyIdxs = validatedKeyColIndices "innerJoin" csSet left + rightKeyIdxs = validatedKeyColIndices "innerJoin" csSet right + leftHashes = D.computeRowHashes leftKeyIdxs left + rightHashes = D.computeRowHashes rightKeyIdxs right + + buildRows = min leftRows rightRows + (leftIxs, rightIxs) + | buildRows > joinStrategyThreshold = + sortMergeInnerKernel leftHashes rightHashes + | rightRows <= leftRows = + -- Build on right (smaller or equal), probe with left + hashInnerKernel leftHashes rightHashes + | otherwise = + -- Build on left (smaller), probe with right, swap result + let (!rIxs, !lIxs) = hashInnerKernel rightHashes leftHashes + in (lIxs, rIxs) + in + if D.null right || D.null left + then D.empty + else assembleInner csSet left right leftIxs rightIxs -- | Compute hashes for the given key column names in a DataFrame. buildHashColumn :: [T.Text] -> DataFrame -> VU.Vector Int buildHashColumn keys df = let csSet = S.fromList keys - keyIdxs = keyColIndices csSet df + keyIdxs = validatedKeyColIndices "buildHashColumn" csSet df in D.computeRowHashes keyIdxs df {- | Probe one batch of rows against a pre-built 'CompactIndex'. @@ -527,28 +538,34 @@ ghci> D.leftJoin ["key"] df other @ -} leftJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -leftJoin cs left right - | D.null right || D.nRows right == 0 = left - | D.null left || D.nRows left == 0 = D.empty - | otherwise = - let - csSet = S.fromList cs - rightRows = fst (D.dimensions right) - - leftKeyIdxs = keyColIndices csSet left - rightKeyIdxs = keyColIndices csSet right - leftHashes = D.computeRowHashes leftKeyIdxs left - rightHashes = D.computeRowHashes rightKeyIdxs right - - -- Right is always the build side for left join - (leftIxs, rightIxs) - | rightRows > joinStrategyThreshold = - sortMergeLeftKernel leftHashes rightHashes - | otherwise = - hashLeftKernel leftHashes rightHashes - in - -- rightIxs uses -1 as sentinel for "no match" - assembleLeft csSet left right leftIxs rightIxs +leftJoin = leftJoinWithCallPoint "leftJoin" + +leftJoinWithCallPoint :: T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame +leftJoinWithCallPoint callPoint cs left right = + let + csSet = S.fromList cs + rightRows = fst (D.dimensions right) + + leftKeyIdxs = validatedKeyColIndices callPoint csSet left + rightKeyIdxs = validatedKeyColIndices callPoint csSet right + leftHashes = D.computeRowHashes leftKeyIdxs left + rightHashes = D.computeRowHashes rightKeyIdxs right + + -- Right is always the build side for left join + (leftIxs, rightIxs) + | rightRows > joinStrategyThreshold = + sortMergeLeftKernel leftHashes rightHashes + | otherwise = + hashLeftKernel leftHashes rightHashes + in + if D.null right || D.nRows right == 0 + then left + else + if D.null left || D.nRows left == 0 + then D.empty + else + -- rightIxs uses -1 as sentinel for "no match" + assembleLeft csSet left right leftIxs rightIxs {- | Hash-based left join kernel. Returns @(leftExpandedIndices, rightExpandedIndices)@ where @@ -798,33 +815,36 @@ ghci> D.rightJoin ["key"] df other -} rightJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -rightJoin cs left right = leftJoin cs right left +rightJoin cs left right = leftJoinWithCallPoint "rightJoin" cs right left fullOuterJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -fullOuterJoin cs left right - | D.null right || D.nRows right == 0 = left - | D.null left || D.nRows left == 0 = right - | otherwise = - let - csSet = S.fromList cs - leftRows = fst (D.dimensions left) - rightRows = fst (D.dimensions right) - - leftKeyIdxs = keyColIndices csSet left - rightKeyIdxs = keyColIndices csSet right - leftHashes = D.computeRowHashes leftKeyIdxs left - rightHashes = D.computeRowHashes rightKeyIdxs right - - -- Both sides can have nulls in full outer - (leftIxs, rightIxs) - | max leftRows rightRows > joinStrategyThreshold = - sortMergeFullOuterKernel leftHashes rightHashes - | otherwise = - hashFullOuterKernel leftHashes rightHashes - in - -- Both index vectors use -1 as sentinel - assembleFullOuter csSet left right leftIxs rightIxs +fullOuterJoin cs left right = + let + csSet = S.fromList cs + leftRows = fst (D.dimensions left) + rightRows = fst (D.dimensions right) + + leftKeyIdxs = validatedKeyColIndices "fullOuterJoin" csSet left + rightKeyIdxs = validatedKeyColIndices "fullOuterJoin" csSet right + leftHashes = D.computeRowHashes leftKeyIdxs left + rightHashes = D.computeRowHashes rightKeyIdxs right + + -- Both sides can have nulls in full outer + (leftIxs, rightIxs) + | max leftRows rightRows > joinStrategyThreshold = + sortMergeFullOuterKernel leftHashes rightHashes + | otherwise = + hashFullOuterKernel leftHashes rightHashes + in + if D.null right || D.nRows right == 0 + then left + else + if D.null left || D.nRows left == 0 + then right + else + -- Both index vectors use -1 as sentinel + assembleFullOuter csSet left right leftIxs rightIxs {- | Hash-based full outer join kernel. Builds compact indices on both sides. diff --git a/src/DataFrame/Operations/Subset.hs b/src/DataFrame/Operations/Subset.hs index 58b6c77..5ad68df 100644 --- a/src/DataFrame/Operations/Subset.hs +++ b/src/DataFrame/Operations/Subset.hs @@ -513,7 +513,7 @@ ghci> D.stratifiedSample (mkStdGen 42) 0.8 "label" df -} stratifiedSample :: forall a g. - (SplitGen g, RandomGen g, Columnable a) => + (RandomGen g, Columnable a) => g -> Double -> Expr a -> DataFrame -> DataFrame stratifiedSample gen p strataCol df = let col = case strataCol of @@ -523,7 +523,7 @@ stratifiedSample gen p strataCol df = go _ [] = mempty go g (ixs : rest) = let stratum = rowsAtIndices ixs df - (g1, g2) = splitGen g + (g1, g2) = split g in sample g1 p stratum <> go g2 rest in go gen groups @@ -537,7 +537,7 @@ ghci> D.stratifiedSplit (mkStdGen 42) 0.8 "label" df -} stratifiedSplit :: forall a g. - (SplitGen g, RandomGen g, Columnable a) => + (RandomGen g, Columnable a) => g -> Double -> Expr a -> DataFrame -> (DataFrame, DataFrame) stratifiedSplit gen p strataCol df = let col = case strataCol of @@ -547,7 +547,7 @@ stratifiedSplit gen p strataCol df = go _ [] = (mempty, mempty) go g (ixs : rest) = let stratum = rowsAtIndices ixs df - (g1, g2) = splitGen g + (g1, g2) = split g (tr, va) = randomSplit g1 p stratum (trAcc, vaAcc) = go g2 rest in (tr <> trAcc, va <> vaAcc) diff --git a/tests/Operations/Join.hs b/tests/Operations/Join.hs index 8818352..379c57f 100644 --- a/tests/Operations/Join.hs +++ b/tests/Operations/Join.hs @@ -4,7 +4,9 @@ module Operations.Join where -import Data.Text (Text) +import Assertions (assertExpectException) +import Control.Exception (evaluate) +import Data.Text (Text, unpack) import Data.These import qualified DataFrame as D import qualified DataFrame.Functions as F @@ -26,6 +28,11 @@ df2 = , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) ] +assertMissingJoinColumn :: String -> Text -> D.DataFrame -> Assertion +assertMissingJoinColumn preface missingKey result = do + assertExpectException preface "Column not found" (evaluate (D.nRows result)) + assertExpectException preface (unpack missingKey) (evaluate (D.nRows result)) + testInnerJoin :: Test testInnerJoin = TestCase @@ -253,6 +260,38 @@ testOuterJoinWithCollisions = (D.sortBy [D.Asc (F.col @Text "key")] (fullOuterJoin ["key"] dfL dfR)) ) +testInnerJoinMissingKey :: Test +testInnerJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Inner join should fail when the join key is missing" + "Cats" + (innerJoin ["Cats"] df1 df2) + +testLeftJoinMissingKey :: Test +testLeftJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Left join should fail when the join key is missing" + "Cats" + (leftJoin ["Cats"] df1 df2) + +testRightJoinMissingKey :: Test +testRightJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Right join should fail when the join key is missing" + "Animals" + (rightJoin ["Animals"] df1 df2) + +testFullOuterJoinMissingKey :: Test +testFullOuterJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Full outer join should fail when the join key is missing" + "Cats" + (fullOuterJoin ["Cats"] df1 df2) + -- Empty DataFrame fixtures: same schema as df1/df2 but zero rows. emptyDf1 :: D.DataFrame emptyDf1 = @@ -353,6 +392,10 @@ tests = , TestLabel "leftJoinWithCollisions" testLeftJoinWithCollisions , TestLabel "rightJoinWithCollisions" testRightJoinWithCollisions , TestLabel "outerJoinWithCollisions" testOuterJoinWithCollisions + , TestLabel "innerJoinMissingKey" testInnerJoinMissingKey + , TestLabel "leftJoinMissingKey" testLeftJoinMissingKey + , TestLabel "rightJoinMissingKey" testRightJoinMissingKey + , TestLabel "fullOuterJoinMissingKey" testFullOuterJoinMissingKey , TestLabel "innerJoinBothEmpty" testInnerJoinBothEmpty , TestLabel "innerJoinLeftEmpty" testInnerJoinLeftEmpty , TestLabel "innerJoinRightEmpty" testInnerJoinRightEmpty From f6ea70923de309b9fa398ff7d47cf54be9701fe0 Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Sat, 21 Mar 2026 16:38:32 +0530 Subject: [PATCH 2/8] use splitgen when random supports it --- src/DataFrame/Operations/Join.hs | 3 ++- src/DataFrame/Operations/Subset.hs | 22 ++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index c89f5c2..d3e85b1 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -540,7 +540,8 @@ ghci> D.leftJoin ["key"] df other leftJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame leftJoin = leftJoinWithCallPoint "leftJoin" -leftJoinWithCallPoint :: T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame +leftJoinWithCallPoint :: + T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame leftJoinWithCallPoint callPoint cs left right = let csSet = S.fromList cs diff --git a/src/DataFrame/Operations/Subset.hs b/src/DataFrame/Operations/Subset.hs index 5ad68df..313714f 100644 --- a/src/DataFrame/Operations/Subset.hs +++ b/src/DataFrame/Operations/Subset.hs @@ -1,4 +1,6 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ExplicitNamespaces #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -48,6 +50,18 @@ import System.Random import Type.Reflection import Prelude hiding (filter, take) +#if MIN_VERSION_random(1,3,0) +type SplittableGen g = (SplitGen g, RandomGen g) + +splitForStratified :: SplittableGen g => g -> (g, g) +splitForStratified = splitGen +#else +type SplittableGen g = RandomGen g + +splitForStratified :: SplittableGen g => g -> (g, g) +splitForStratified = split +#endif + -- | O(k * n) Take the first n rows of a DataFrame. take :: Int -> DataFrame -> DataFrame take n d = d{columns = V.map (takeColumn n') (columns d), dataframeDimensions = (n', c)} @@ -513,7 +527,7 @@ ghci> D.stratifiedSample (mkStdGen 42) 0.8 "label" df -} stratifiedSample :: forall a g. - (RandomGen g, Columnable a) => + (SplittableGen g, Columnable a) => g -> Double -> Expr a -> DataFrame -> DataFrame stratifiedSample gen p strataCol df = let col = case strataCol of @@ -523,7 +537,7 @@ stratifiedSample gen p strataCol df = go _ [] = mempty go g (ixs : rest) = let stratum = rowsAtIndices ixs df - (g1, g2) = split g + (g1, g2) = splitForStratified g in sample g1 p stratum <> go g2 rest in go gen groups @@ -537,7 +551,7 @@ ghci> D.stratifiedSplit (mkStdGen 42) 0.8 "label" df -} stratifiedSplit :: forall a g. - (RandomGen g, Columnable a) => + (SplittableGen g, Columnable a) => g -> Double -> Expr a -> DataFrame -> (DataFrame, DataFrame) stratifiedSplit gen p strataCol df = let col = case strataCol of @@ -547,7 +561,7 @@ stratifiedSplit gen p strataCol df = go _ [] = (mempty, mempty) go g (ixs : rest) = let stratum = rowsAtIndices ixs df - (g1, g2) = split g + (g1, g2) = splitForStratified g (tr, va) = randomSplit g1 p stratum (trAcc, vaAcc) = go g2 rest in (tr <> trAcc, va <> vaAcc) From a34d062decfe3aaaca71e0511d9efacd17edf3ca Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Sat, 21 Mar 2026 16:48:26 +0530 Subject: [PATCH 3/8] report all missing join keys --- src/DataFrame/Errors.hs | 21 +++++++++++++++++++++ src/DataFrame/Operations/Join.hs | 7 +++++-- tests/Operations/Join.hs | 25 ++++++++++++++++++++++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/DataFrame/Errors.hs b/src/DataFrame/Errors.hs index ab3318c..e8f88e7 100644 --- a/src/DataFrame/Errors.hs +++ b/src/DataFrame/Errors.hs @@ -30,6 +30,7 @@ data DataFrameException where DataFrameException AggregatedAndNonAggregatedException :: T.Text -> T.Text -> DataFrameException ColumnNotFoundException :: T.Text -> T.Text -> [T.Text] -> DataFrameException + ColumnsNotFoundException :: [T.Text] -> T.Text -> [T.Text] -> DataFrameException EmptyDataSetException :: T.Text -> DataFrameException InternalException :: T.Text -> DataFrameException NonColumnReferenceException :: T.Text -> DataFrameException @@ -52,6 +53,7 @@ instance Show DataFrameException where (callingFunctionName context) errorString show (ColumnNotFoundException columnName callPoint availableColumns) = columnNotFound columnName callPoint availableColumns + show (ColumnsNotFoundException columnNames callPoint availableColumns) = columnsNotFound columnNames callPoint availableColumns show (EmptyDataSetException callPoint) = emptyDataSetError callPoint show (WrongQuantileNumberException q) = wrongQuantileNumberError q show (WrongQuantileIndexException qs q) = wrongQuantileIndexError qs q @@ -75,6 +77,25 @@ columnNotFound name callPoint columns = ++ T.unpack (guessColumnName name columns) ++ "?\n\n" +columnsNotFound :: [T.Text] -> T.Text -> [T.Text] -> String +columnsNotFound names callPoint columns = + red "\n\n[ERROR] " + ++ "Columns not found: " + ++ T.unpack (T.intercalate ", " names) + ++ " for operation " + ++ T.unpack callPoint + ++ concatMap formatSuggestion names + ++ "\n\n" + where + formatSuggestion name = case guessColumnName name columns of + "" -> "" + guessed -> + "\n\tDid you mean " + ++ T.unpack guessed + ++ " for " + ++ T.unpack name + ++ "?" + typeMismatchError :: String -> String -> String typeMismatchError givenType expectedType = red $ diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index d3e85b1..af5ca1e 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -25,7 +25,9 @@ import qualified Data.Vector as VB import qualified Data.Vector.Algorithms.Merge as VA import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM -import DataFrame.Errors (DataFrameException (ColumnNotFoundException)) +import DataFrame.Errors ( + DataFrameException (ColumnNotFoundException, ColumnsNotFoundException), + ) import DataFrame.Internal.Column as D import DataFrame.Internal.DataFrame as D import DataFrame.Operations.Aggregation as D @@ -154,7 +156,8 @@ validatedKeyColIndices callPoint csSet df = missingKeys = S.toAscList (csSet `S.difference` M.keysSet columnIdxs) in case missingKeys of [] -> M.elems $ M.restrictKeys columnIdxs csSet - missingKey : _ -> throw (ColumnNotFoundException missingKey callPoint (M.keys columnIdxs)) + [missingKey] -> throw (ColumnNotFoundException missingKey callPoint (M.keys columnIdxs)) + _ -> throw (ColumnsNotFoundException missingKeys callPoint (M.keys columnIdxs)) -- ============================================================ -- Inner Join diff --git a/tests/Operations/Join.hs b/tests/Operations/Join.hs index 379c57f..39bb85c 100644 --- a/tests/Operations/Join.hs +++ b/tests/Operations/Join.hs @@ -28,10 +28,20 @@ df2 = , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) ] +assertMissingJoinColumns :: String -> [Text] -> D.DataFrame -> Assertion +assertMissingJoinColumns preface missingKeys result = do + assertExpectException + preface + (if length missingKeys == 1 then "Column not found" else "Columns not found") + (evaluate (D.nRows result)) + mapM_ + ( \missingKey -> + assertExpectException preface (unpack missingKey) (evaluate (D.nRows result)) + ) + missingKeys + assertMissingJoinColumn :: String -> Text -> D.DataFrame -> Assertion -assertMissingJoinColumn preface missingKey result = do - assertExpectException preface "Column not found" (evaluate (D.nRows result)) - assertExpectException preface (unpack missingKey) (evaluate (D.nRows result)) +assertMissingJoinColumn preface missingKey = assertMissingJoinColumns preface [missingKey] testInnerJoin :: Test testInnerJoin = @@ -292,6 +302,14 @@ testFullOuterJoinMissingKey = "Cats" (fullOuterJoin ["Cats"] df1 df2) +testInnerJoinMissingKeys :: Test +testInnerJoinMissingKeys = + TestCase $ + assertMissingJoinColumns + "Inner join should report every missing join key" + ["Animals", "Cats"] + (innerJoin ["Animals", "Cats"] df1 df2) + -- Empty DataFrame fixtures: same schema as df1/df2 but zero rows. emptyDf1 :: D.DataFrame emptyDf1 = @@ -396,6 +414,7 @@ tests = , TestLabel "leftJoinMissingKey" testLeftJoinMissingKey , TestLabel "rightJoinMissingKey" testRightJoinMissingKey , TestLabel "fullOuterJoinMissingKey" testFullOuterJoinMissingKey + , TestLabel "innerJoinMissingKeys" testInnerJoinMissingKeys , TestLabel "innerJoinBothEmpty" testInnerJoinBothEmpty , TestLabel "innerJoinLeftEmpty" testInnerJoinLeftEmpty , TestLabel "innerJoinRightEmpty" testInnerJoinRightEmpty From f8e8dcb6f3c8bb6be034e187ab62b6529f772293 Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Sat, 21 Mar 2026 16:53:05 +0530 Subject: [PATCH 4/8] use guards for join early exits --- src/DataFrame/Operations/Join.hs | 47 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index af5ca1e..6214c09 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -184,7 +184,12 @@ ghci> D.innerJoin ["key"] df other @ -} innerJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -innerJoin cs left right = +innerJoin cs left right + | D.null right || D.null left = D.empty + | otherwise = innerJoinNonEmpty cs left right + +innerJoinNonEmpty :: [T.Text] -> DataFrame -> DataFrame -> DataFrame +innerJoinNonEmpty cs left right = let csSet = S.fromList cs leftRows = fst (D.dimensions left) @@ -207,9 +212,7 @@ innerJoin cs left right = let (!rIxs, !lIxs) = hashInnerKernel rightHashes leftHashes in (lIxs, rIxs) in - if D.null right || D.null left - then D.empty - else assembleInner csSet left right leftIxs rightIxs + assembleInner csSet left right leftIxs rightIxs -- | Compute hashes for the given key column names in a DataFrame. buildHashColumn :: [T.Text] -> DataFrame -> VU.Vector Int @@ -545,7 +548,13 @@ leftJoin = leftJoinWithCallPoint "leftJoin" leftJoinWithCallPoint :: T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame -leftJoinWithCallPoint callPoint cs left right = +leftJoinWithCallPoint callPoint cs left right + | D.null right || D.nRows right == 0 = left + | D.null left || D.nRows left == 0 = D.empty + | otherwise = leftJoinNonEmpty callPoint cs left right + +leftJoinNonEmpty :: T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame +leftJoinNonEmpty callPoint cs left right = let csSet = S.fromList cs rightRows = fst (D.dimensions right) @@ -562,14 +571,8 @@ leftJoinWithCallPoint callPoint cs left right = | otherwise = hashLeftKernel leftHashes rightHashes in - if D.null right || D.nRows right == 0 - then left - else - if D.null left || D.nRows left == 0 - then D.empty - else - -- rightIxs uses -1 as sentinel for "no match" - assembleLeft csSet left right leftIxs rightIxs + -- rightIxs uses -1 as sentinel for "no match" + assembleLeft csSet left right leftIxs rightIxs {- | Hash-based left join kernel. Returns @(leftExpandedIndices, rightExpandedIndices)@ where @@ -823,7 +826,13 @@ rightJoin cs left right = leftJoinWithCallPoint "rightJoin" cs right left fullOuterJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -fullOuterJoin cs left right = +fullOuterJoin cs left right + | D.null right || D.nRows right == 0 = left + | D.null left || D.nRows left == 0 = right + | otherwise = fullOuterJoinNonEmpty cs left right + +fullOuterJoinNonEmpty :: [T.Text] -> DataFrame -> DataFrame -> DataFrame +fullOuterJoinNonEmpty cs left right = let csSet = S.fromList cs leftRows = fst (D.dimensions left) @@ -841,14 +850,8 @@ fullOuterJoin cs left right = | otherwise = hashFullOuterKernel leftHashes rightHashes in - if D.null right || D.nRows right == 0 - then left - else - if D.null left || D.nRows left == 0 - then right - else - -- Both index vectors use -1 as sentinel - assembleFullOuter csSet left right leftIxs rightIxs + -- Both index vectors use -1 as sentinel + assembleFullOuter csSet left right leftIxs rightIxs {- | Hash-based full outer join kernel. Builds compact indices on both sides. From 5cca1194dec967360ec4581c9ac246cc200b017f Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Sun, 22 Mar 2026 01:50:06 +0530 Subject: [PATCH 5/8] consolidate missing column suggestions --- src/DataFrame/Errors.hs | 28 +++++++++++++++++++--------- tests/Operations/Join.hs | 14 ++++++++++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/DataFrame/Errors.hs b/src/DataFrame/Errors.hs index e8f88e7..ed06b17 100644 --- a/src/DataFrame/Errors.hs +++ b/src/DataFrame/Errors.hs @@ -11,6 +11,7 @@ import qualified Data.Vector.Unboxed as VU import Control.Exception import Data.Array +import qualified Data.List as L import Data.Typeable (Typeable) import DataFrame.Display.Terminal.Colours import Type.Reflection (TypeRep) @@ -84,17 +85,26 @@ columnsNotFound names callPoint columns = ++ T.unpack (T.intercalate ", " names) ++ " for operation " ++ T.unpack callPoint - ++ concatMap formatSuggestion names + ++ formatSuggestions names columns ++ "\n\n" where - formatSuggestion name = case guessColumnName name columns of - "" -> "" - guessed -> - "\n\tDid you mean " - ++ T.unpack guessed - ++ " for " - ++ T.unpack name - ++ "?" + formatSuggestions missingColumns availableColumns = + case traverse (`suggestColumnName` availableColumns) missingColumns of + Just guessedColumns + | not (null guessedColumns) -> + "\n\tDid you mean " + ++ formatColumnSuggestions guessedColumns + ++ "?" + _ -> "" + + suggestColumnName missingColumn availableColumns = case guessColumnName missingColumn availableColumns of + "" -> Nothing + guessed -> Just guessed + + formatColumnSuggestions guessedColumns = + "[" + ++ L.intercalate ", " (map (show . T.unpack) guessedColumns) + ++ "]" typeMismatchError :: String -> String -> String typeMismatchError givenType expectedType = diff --git a/tests/Operations/Join.hs b/tests/Operations/Join.hs index 39bb85c..1718935 100644 --- a/tests/Operations/Join.hs +++ b/tests/Operations/Join.hs @@ -310,6 +310,19 @@ testInnerJoinMissingKeys = ["Animals", "Cats"] (innerJoin ["Animals", "Cats"] df1 df2) +testInnerJoinMissingKeysSuggestion :: Test +testInnerJoinMissingKeysSuggestion = + TestCase $ + let typoDf = + D.fromNamedColumns + [ ("hello", D.fromList ["H" :: Text]) + , ("world", D.fromList ["W" :: Text]) + ] + in assertExpectException + "Inner join should report consolidated suggestions for missing join keys" + "Did you mean [\"hello\", \"world\"]?" + (evaluate (D.nRows (innerJoin ["helo", "wrld"] typoDf typoDf))) + -- Empty DataFrame fixtures: same schema as df1/df2 but zero rows. emptyDf1 :: D.DataFrame emptyDf1 = @@ -415,6 +428,7 @@ tests = , TestLabel "rightJoinMissingKey" testRightJoinMissingKey , TestLabel "fullOuterJoinMissingKey" testFullOuterJoinMissingKey , TestLabel "innerJoinMissingKeys" testInnerJoinMissingKeys + , TestLabel "innerJoinMissingKeysSuggestion" testInnerJoinMissingKeysSuggestion , TestLabel "innerJoinBothEmpty" testInnerJoinBothEmpty , TestLabel "innerJoinLeftEmpty" testInnerJoinLeftEmpty , TestLabel "innerJoinRightEmpty" testInnerJoinRightEmpty From 1c5f364b4c57684fc556bcdea8c4058becfefafd Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Sun, 22 Mar 2026 02:11:10 +0530 Subject: [PATCH 6/8] consolidate missing column error formatting --- src/DataFrame/Errors.hs | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/DataFrame/Errors.hs b/src/DataFrame/Errors.hs index ed06b17..4b8e95a 100644 --- a/src/DataFrame/Errors.hs +++ b/src/DataFrame/Errors.hs @@ -68,28 +68,31 @@ instance Show DataFrameException where ++ T.unpack expr2 columnNotFound :: T.Text -> T.Text -> [T.Text] -> String -columnNotFound name callPoint columns = - red "\n\n[ERROR] " - ++ "Column not found: " - ++ T.unpack name - ++ " for operation " - ++ T.unpack callPoint - ++ "\n\tDid you mean " - ++ T.unpack (guessColumnName name columns) - ++ "?\n\n" +columnNotFound name callPoint = columnsNotFound [name] callPoint columnsNotFound :: [T.Text] -> T.Text -> [T.Text] -> String -columnsNotFound names callPoint columns = +columnsNotFound missingColumns callPoint availableColumns = red "\n\n[ERROR] " - ++ "Columns not found: " - ++ T.unpack (T.intercalate ", " names) + ++ missingColumnsLabel missingColumns + ++ ": " + ++ T.unpack (T.intercalate ", " missingColumns) ++ " for operation " ++ T.unpack callPoint - ++ formatSuggestions names columns + ++ formatSuggestions missingColumns availableColumns ++ "\n\n" where - formatSuggestions missingColumns availableColumns = - case traverse (`suggestColumnName` availableColumns) missingColumns of + missingColumnsLabel [_] = "Column not found" + missingColumnsLabel _ = "Columns not found" + + formatSuggestions [missingColumn] columns = + case guessColumnName missingColumn columns of + "" -> "" + guessed -> + "\n\tDid you mean " + ++ T.unpack guessed + ++ "?" + formatSuggestions names columns = + case traverse (`suggestColumnName` columns) names of Just guessedColumns | not (null guessedColumns) -> "\n\tDid you mean " @@ -97,7 +100,7 @@ columnsNotFound names callPoint columns = ++ "?" _ -> "" - suggestColumnName missingColumn availableColumns = case guessColumnName missingColumn availableColumns of + suggestColumnName missingColumn columns = case guessColumnName missingColumn columns of "" -> Nothing guessed -> Just guessed From a8704592f1505fce1d0ce3a852ef643387be38cc Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Sun, 22 Mar 2026 11:30:45 +0530 Subject: [PATCH 7/8] apply hlint eta reduction --- src/DataFrame/Errors.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DataFrame/Errors.hs b/src/DataFrame/Errors.hs index 4b8e95a..03d7e02 100644 --- a/src/DataFrame/Errors.hs +++ b/src/DataFrame/Errors.hs @@ -68,7 +68,7 @@ instance Show DataFrameException where ++ T.unpack expr2 columnNotFound :: T.Text -> T.Text -> [T.Text] -> String -columnNotFound name callPoint = columnsNotFound [name] callPoint +columnNotFound name = columnsNotFound [name] columnsNotFound :: [T.Text] -> T.Text -> [T.Text] -> String columnsNotFound missingColumns callPoint availableColumns = From be407f8cdf36f5421065057aa7551df2355904e1 Mon Sep 17 00:00:00 2001 From: Anamika AggarwaL Date: Wed, 25 Mar 2026 06:18:58 +0530 Subject: [PATCH 8/8] use one missing columns exception --- src/DataFrame/Errors.hs | 4 +--- src/DataFrame/IO/Parquet.hs | 6 +++--- src/DataFrame/Internal/DataFrame.hs | 6 +++--- src/DataFrame/Internal/Interpreter.hs | 24 ++++++++++----------- src/DataFrame/Internal/Row.hs | 6 +++--- src/DataFrame/Operations/Aggregation.hs | 4 ++-- src/DataFrame/Operations/Core.hs | 24 +++++++++++++++------ src/DataFrame/Operations/Join.hs | 3 +-- src/DataFrame/Operations/Permutation.hs | 4 ++-- src/DataFrame/Operations/Statistics.hs | 4 ++-- src/DataFrame/Operations/Subset.hs | 11 +++++----- src/DataFrame/Operations/Transformations.hs | 12 ++++++----- tests/Operations/GroupBy.hs | 2 +- tests/Operations/Sort.hs | 2 +- tests/Operations/Statistics.hs | 2 +- 15 files changed, 62 insertions(+), 52 deletions(-) diff --git a/src/DataFrame/Errors.hs b/src/DataFrame/Errors.hs index 03d7e02..1d96212 100644 --- a/src/DataFrame/Errors.hs +++ b/src/DataFrame/Errors.hs @@ -30,7 +30,6 @@ data DataFrameException where TypeErrorContext a b -> DataFrameException AggregatedAndNonAggregatedException :: T.Text -> T.Text -> DataFrameException - ColumnNotFoundException :: T.Text -> T.Text -> [T.Text] -> DataFrameException ColumnsNotFoundException :: [T.Text] -> T.Text -> [T.Text] -> DataFrameException EmptyDataSetException :: T.Text -> DataFrameException InternalException :: T.Text -> DataFrameException @@ -53,7 +52,6 @@ instance Show DataFrameException where (errorColumnName context) (callingFunctionName context) errorString - show (ColumnNotFoundException columnName callPoint availableColumns) = columnNotFound columnName callPoint availableColumns show (ColumnsNotFoundException columnNames callPoint availableColumns) = columnsNotFound columnNames callPoint availableColumns show (EmptyDataSetException callPoint) = emptyDataSetError callPoint show (WrongQuantileNumberException q) = wrongQuantileNumberError q @@ -68,7 +66,7 @@ instance Show DataFrameException where ++ T.unpack expr2 columnNotFound :: T.Text -> T.Text -> [T.Text] -> String -columnNotFound name = columnsNotFound [name] +columnNotFound missingColumn = columnsNotFound [missingColumn] columnsNotFound :: [T.Text] -> T.Text -> [T.Text] -> String columnsNotFound missingColumns callPoint availableColumns = diff --git a/src/DataFrame/IO/Parquet.hs b/src/DataFrame/IO/Parquet.hs index 40c5780..43a91d6 100644 --- a/src/DataFrame/IO/Parquet.hs +++ b/src/DataFrame/IO/Parquet.hs @@ -19,7 +19,7 @@ import qualified Data.Text as T import Data.Text.Encoding import Data.Time import Data.Time.Clock.POSIX (posixSecondsToUTCTime) -import DataFrame.Errors (DataFrameException (ColumnNotFoundException)) +import DataFrame.Errors (DataFrameException (ColumnsNotFoundException)) import DataFrame.Internal.Binary (littleEndianWord32) import qualified DataFrame.Internal.Column as DI import DataFrame.Internal.DataFrame (DataFrame) @@ -160,8 +160,8 @@ _readParquetWithOpts extraConfig opts path = withFileBufferedOrSeekable extraCon in unless (L.null missing) ( throw - ( ColumnNotFoundException - (T.pack $ show missing) + ( ColumnsNotFoundException + missing "readParquetWithOpts" availableSelectedColumns ) diff --git a/src/DataFrame/Internal/DataFrame.hs b/src/DataFrame/Internal/DataFrame.hs index 2e3f9a3..954ca83 100644 --- a/src/DataFrame/Internal/DataFrame.hs +++ b/src/DataFrame/Internal/DataFrame.hs @@ -154,16 +154,16 @@ getColumn name df = do {- | Retrieves a column by name from the dataframe, throwing an exception if not found. -This is an unsafe version of 'getColumn' that throws 'ColumnNotFoundException' +This is an unsafe version of 'getColumn' that throws 'ColumnsNotFoundException' if the column does not exist. Use this when you are certain the column exists. ==== __Throws__ -* 'ColumnNotFoundException' - if the column with the given name does not exist +* 'ColumnsNotFoundException' - if the column with the given name does not exist -} unsafeGetColumn :: T.Text -> DataFrame -> Column unsafeGetColumn name df = case getColumn name df of - Nothing -> throw $ ColumnNotFoundException name "" (M.keys $ columnIndices df) + Nothing -> throw $ ColumnsNotFoundException [name] "" (M.keys $ columnIndices df) Just col -> col {- | Checks if the dataframe is empty (has no columns). diff --git a/src/DataFrame/Internal/Interpreter.hs b/src/DataFrame/Internal/Interpreter.hs index 988be75..221a346 100644 --- a/src/DataFrame/Internal/Interpreter.hs +++ b/src/DataFrame/Internal/Interpreter.hs @@ -482,7 +482,7 @@ eval _ (Lit v) = Right (Scalar v) eval (FlatCtx df) (Col name) = case getColumn name df of Nothing -> - Left $ ColumnNotFoundException name "" (M.keys $ columnIndices df) + Left $ ColumnsNotFoundException [name] "" (M.keys $ columnIndices df) Just c | hasElemType @a c -> Right (Flat c) | otherwise -> @@ -500,8 +500,8 @@ eval (GroupCtx gdf) (Col name) = case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just c @@ -524,14 +524,14 @@ eval (FlatCtx df) (CastWith name _tag onResult) = case getColumn name df of Nothing -> Left $ - ColumnNotFoundException name "" (M.keys $ columnIndices df) + ColumnsNotFoundException [name] "" (M.keys $ columnIndices df) Just c -> Flat <$> promoteColumnWith onResult c eval (GroupCtx gdf) (CastWith name _tag onResult) = case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just c -> do @@ -579,8 +579,8 @@ eval (GroupCtx gdf) expr@(Agg (FoldAgg _ (Just seed) (f :: a -> b -> a)) (Col na case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just col -> @@ -599,8 +599,8 @@ eval (GroupCtx gdf) expr@(Agg (FoldAgg _ Nothing (f :: a -> b -> a)) (Col name : case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just col -> @@ -618,8 +618,8 @@ eval case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just col -> diff --git a/src/DataFrame/Internal/Row.hs b/src/DataFrame/Internal/Row.hs index 7671757..847f9cd 100644 --- a/src/DataFrame/Internal/Row.hs +++ b/src/DataFrame/Internal/Row.hs @@ -173,8 +173,8 @@ mkRowFromArgs names df i = V.map get (V.fromList names) get name = case getColumn name df of Nothing -> throw $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "[INTERNAL] mkRowFromArgs" (M.keys $ columnIndices df) Just (BoxedColumn column) -> toAny (column V.! i) @@ -207,7 +207,7 @@ mkRowRep df names i = V.generate (L.length names) (\index -> get (names' V.! ind Just e -> toAny e Nothing -> throwError name Nothing -> - throw $ ColumnNotFoundException name "mkRowRep" (M.keys $ columnIndices df) + throw $ ColumnsNotFoundException [name] "mkRowRep" (M.keys $ columnIndices df) sortedIndexes' :: [Bool] -> V.Vector Row -> VU.Vector Int sortedIndexes' flipCompare rows = runST $ do diff --git a/src/DataFrame/Operations/Aggregation.hs b/src/DataFrame/Operations/Aggregation.hs index dc07151..a7af6dd 100644 --- a/src/DataFrame/Operations/Aggregation.hs +++ b/src/DataFrame/Operations/Aggregation.hs @@ -49,8 +49,8 @@ groupBy :: groupBy names df | any (`notElem` columnNames df) names = throw $ - ColumnNotFoundException - (T.pack $ show $ names L.\\ columnNames df) + ColumnsNotFoundException + (names L.\\ columnNames df) "groupBy" (columnNames df) | nRows df == 0 = diff --git a/src/DataFrame/Operations/Core.hs b/src/DataFrame/Operations/Core.hs index 23c0d49..8a27b65 100644 --- a/src/DataFrame/Operations/Core.hs +++ b/src/DataFrame/Operations/Core.hs @@ -391,7 +391,7 @@ insertColumn name column d = cloneColumn :: T.Text -> T.Text -> DataFrame -> DataFrame cloneColumn original new df = fromMaybe ( throw $ - ColumnNotFoundException original "cloneColumn" (M.keys $ columnIndices df) + ColumnsNotFoundException [original] "cloneColumn" (M.keys $ columnIndices df) ) $ do column <- getColumn original df @@ -480,7 +480,7 @@ renameMany = fold (uncurry rename) renameSafe :: T.Text -> T.Text -> DataFrame -> Either DataFrameException DataFrame renameSafe orig new df = fromMaybe - (Left $ ColumnNotFoundException orig "rename" (M.keys $ columnIndices df)) + (Left $ ColumnsNotFoundException [orig] "rename" (M.keys $ columnIndices df)) $ do columnIndex <- M.lookup orig (columnIndices df) let origRemoved = M.delete orig (columnIndices df) @@ -856,7 +856,8 @@ columnAsVector :: columnAsVector (Col name) df = case getColumn name df of Just col -> toVector col Nothing -> - Left $ ColumnNotFoundException name "columnAsVector" (M.keys $ columnIndices df) + Left $ + ColumnsNotFoundException [name] "columnAsVector" (M.keys $ columnIndices df) columnAsVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toVector col @@ -873,7 +874,7 @@ columnAsIntVector (Col name) df = case getColumn name df of Just col -> toIntVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsIntVector" (M.keys $ columnIndices df) + ColumnsNotFoundException [name] "columnAsIntVector" (M.keys $ columnIndices df) columnAsIntVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toIntVector col @@ -890,7 +891,10 @@ columnAsDoubleVector (Col name) df = case getColumn name df of Just col -> toDoubleVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsDoubleVector" (M.keys $ columnIndices df) + ColumnsNotFoundException + [name] + "columnAsDoubleVector" + (M.keys $ columnIndices df) columnAsDoubleVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toDoubleVector col @@ -907,7 +911,10 @@ columnAsFloatVector (Col name) df = case getColumn name df of Just col -> toFloatVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsFloatVector" (M.keys $ columnIndices df) + ColumnsNotFoundException + [name] + "columnAsFloatVector" + (M.keys $ columnIndices df) columnAsFloatVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toFloatVector col @@ -920,7 +927,10 @@ columnAsUnboxedVector (Col name) df = case getColumn name df of Just col -> toUnboxedVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsFloatVector" (M.keys $ columnIndices df) + ColumnsNotFoundException + [name] + "columnAsFloatVector" + (M.keys $ columnIndices df) columnAsUnboxedVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toUnboxedVector col diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index 6214c09..9eb0d35 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -26,7 +26,7 @@ import qualified Data.Vector.Algorithms.Merge as VA import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM import DataFrame.Errors ( - DataFrameException (ColumnNotFoundException, ColumnsNotFoundException), + DataFrameException (ColumnsNotFoundException), ) import DataFrame.Internal.Column as D import DataFrame.Internal.DataFrame as D @@ -156,7 +156,6 @@ validatedKeyColIndices callPoint csSet df = missingKeys = S.toAscList (csSet `S.difference` M.keysSet columnIdxs) in case missingKeys of [] -> M.elems $ M.restrictKeys columnIdxs csSet - [missingKey] -> throw (ColumnNotFoundException missingKey callPoint (M.keys columnIdxs)) _ -> throw (ColumnsNotFoundException missingKeys callPoint (M.keys columnIdxs)) -- ============================================================ diff --git a/src/DataFrame/Operations/Permutation.hs b/src/DataFrame/Operations/Permutation.hs index e799bfd..58f095d 100644 --- a/src/DataFrame/Operations/Permutation.hs +++ b/src/DataFrame/Operations/Permutation.hs @@ -53,8 +53,8 @@ sortBy :: sortBy sortOrds df | any (`notElem` columnNames df) names = throw $ - ColumnNotFoundException - (T.pack $ show $ names L.\\ columnNames df) + ColumnsNotFoundException + (names L.\\ columnNames df) "sortBy" (columnNames df) | otherwise = diff --git a/src/DataFrame/Operations/Statistics.hs b/src/DataFrame/Operations/Statistics.hs index b7e5f70..ff4d530 100644 --- a/src/DataFrame/Operations/Statistics.hs +++ b/src/DataFrame/Operations/Statistics.hs @@ -222,7 +222,7 @@ _getColumnAsDouble name df = case getColumn name df of SFalse -> Nothing Nothing -> throw $ - ColumnNotFoundException name "_getColumnAsDouble" (M.keys $ columnIndices df) + ColumnsNotFoundException [name] "_getColumnAsDouble" (M.keys $ columnIndices df) _ -> Nothing -- Return a type mismatch error here. {-# INLINE _getColumnAsDouble #-} @@ -237,7 +237,7 @@ optionalToDoubleVector = sum :: forall a. (Columnable a, Num a) => Expr a -> DataFrame -> a sum (Col name) df = case getColumn name df of - Nothing -> throw $ ColumnNotFoundException name "sum" (M.keys $ columnIndices df) + Nothing -> throw $ ColumnsNotFoundException [name] "sum" (M.keys $ columnIndices df) Just ((UnboxedColumn (column :: VU.Vector a'))) -> case testEquality (typeRep @a') (typeRep @a) of Just Refl -> VG.sum column Nothing -> 0 diff --git a/src/DataFrame/Operations/Subset.hs b/src/DataFrame/Operations/Subset.hs index 313714f..f5873cd 100644 --- a/src/DataFrame/Operations/Subset.hs +++ b/src/DataFrame/Operations/Subset.hs @@ -130,7 +130,7 @@ filter :: filter (Col filterColumnName) condition df = case getColumn filterColumnName df of Nothing -> throw $ - ColumnNotFoundException filterColumnName "filter" (M.keys $ columnIndices df) + ColumnsNotFoundException [filterColumnName] "filter" (M.keys $ columnIndices df) Just (BoxedColumn (column :: V.Vector b)) -> filterByVector filterColumnName column condition df Just (OptionalColumn (column :: V.Vector b)) -> filterByVector filterColumnName column condition df Just (UnboxedColumn (column :: VU.Vector b)) -> filterByVector filterColumnName column condition df @@ -207,7 +207,7 @@ filterWhere expr df = filterJust :: T.Text -> DataFrame -> DataFrame filterJust name df = case getColumn name df of Nothing -> - throw $ ColumnNotFoundException name "filterJust" (M.keys $ columnIndices df) + throw $ ColumnsNotFoundException [name] "filterJust" (M.keys $ columnIndices df) Just column@(OptionalColumn (col :: V.Vector (Maybe a))) -> filter (Col @(Maybe a) name) isJust df & apply @(Maybe a) fromJust name Just column -> df @@ -218,7 +218,8 @@ filterJust name df = case getColumn name df of filterNothing :: T.Text -> DataFrame -> DataFrame filterNothing name df = case getColumn name df of Nothing -> - throw $ ColumnNotFoundException name "filterNothing" (M.keys $ columnIndices df) + throw $ + ColumnsNotFoundException [name] "filterNothing" (M.keys $ columnIndices df) Just (OptionalColumn (col :: V.Vector (Maybe a))) -> filter (Col @(Maybe a) name) isNothing df _ -> df @@ -256,8 +257,8 @@ select cs df | L.null cs = empty | any (`notElem` columnNames df) cs = throw $ - ColumnNotFoundException - (T.pack $ show $ cs L.\\ columnNames df) + ColumnsNotFoundException + (cs L.\\ columnNames df) "select" (columnNames df) | otherwise = diff --git a/src/DataFrame/Operations/Transformations.hs b/src/DataFrame/Operations/Transformations.hs index 41401a1..4eae22a 100644 --- a/src/DataFrame/Operations/Transformations.hs +++ b/src/DataFrame/Operations/Transformations.hs @@ -63,7 +63,8 @@ safeApply :: DataFrame -> Either DataFrameException DataFrame safeApply f columnName d = case getColumn columnName d of - Nothing -> Left $ ColumnNotFoundException columnName "apply" (M.keys $ columnIndices d) + Nothing -> + Left $ ColumnsNotFoundException [columnName] "apply" (M.keys $ columnIndices d) Just column -> do column' <- mapColumn f column pure $ insertColumn columnName column' d @@ -163,8 +164,8 @@ applyWhere :: applyWhere condition filterColumnName f columnName df = case getColumn filterColumnName df of Nothing -> throw $ - ColumnNotFoundException - filterColumnName + ColumnsNotFoundException + [filterColumnName] "applyWhere" (M.keys $ columnIndices df) Just column -> case ifoldrColumn @@ -193,7 +194,7 @@ applyAtIndex :: applyAtIndex i f columnName df = case getColumn columnName df of Nothing -> throw $ - ColumnNotFoundException columnName "applyAtIndex" (M.keys $ columnIndices df) + ColumnsNotFoundException [columnName] "applyAtIndex" (M.keys $ columnIndices df) Just column -> case imapColumn (\index value -> if index == i then f value else value) column of Left e -> throw e Right column' -> insertColumn columnName column' df @@ -208,7 +209,8 @@ imputeCore :: DataFrame imputeCore (Col columnName) value df = case getColumn columnName df of Nothing -> - throw $ ColumnNotFoundException columnName "impute" (M.keys $ columnIndices df) + throw $ + ColumnsNotFoundException [columnName] "impute" (M.keys $ columnIndices df) Just (OptionalColumn _) -> case safeApply (fromMaybe value) columnName df of Left (TypeMismatchException context) -> throw $ TypeMismatchException (context{callingFunctionName = Just "impute"}) Left exception -> throw exception diff --git a/tests/Operations/GroupBy.hs b/tests/Operations/GroupBy.hs index 0ad27ed..47c2b9a 100644 --- a/tests/Operations/GroupBy.hs +++ b/tests/Operations/GroupBy.hs @@ -47,7 +47,7 @@ groupByColumnDoesNotExist = TestCase ( assertExpectException "[Error Case]" - (D.columnNotFound "[\"test0\"]" "groupBy" (D.columnNames testData)) + (D.columnsNotFound ["test0"] "groupBy" (D.columnNames testData)) (print $ D.groupBy ["test0"] testData) ) diff --git a/tests/Operations/Sort.hs b/tests/Operations/Sort.hs index 1c821dc..98d70dc 100644 --- a/tests/Operations/Sort.hs +++ b/tests/Operations/Sort.hs @@ -85,7 +85,7 @@ sortByColumnDoesNotExist = TestCase ( assertExpectException "[Error Case]" - (D.columnNotFound "[\"test0\"]" "sortBy" (D.columnNames testData)) + (D.columnsNotFound ["test0"] "sortBy" (D.columnNames testData)) (print $ D.sortBy [D.Asc (F.col @Int "test0")] testData) ) diff --git a/tests/Operations/Statistics.hs b/tests/Operations/Statistics.hs index 81baba9..db90765 100644 --- a/tests/Operations/Statistics.hs +++ b/tests/Operations/Statistics.hs @@ -244,7 +244,7 @@ correlationSelfIdentity = (abs (r - 1.0) < 1e-10) ) --- Requesting a missing column should throw ColumnNotFoundException +-- Requesting a missing column should throw ColumnsNotFoundException correlationMissingColumn :: Test correlationMissingColumn = TestCase