diff --git a/src/DataFrame/Errors.hs b/src/DataFrame/Errors.hs index ab3318c..1d96212 100644 --- a/src/DataFrame/Errors.hs +++ b/src/DataFrame/Errors.hs @@ -11,6 +11,7 @@ import qualified Data.Vector.Unboxed as VU import Control.Exception import Data.Array +import qualified Data.List as L import Data.Typeable (Typeable) import DataFrame.Display.Terminal.Colours import Type.Reflection (TypeRep) @@ -29,7 +30,7 @@ data DataFrameException where TypeErrorContext a b -> DataFrameException AggregatedAndNonAggregatedException :: T.Text -> T.Text -> DataFrameException - ColumnNotFoundException :: T.Text -> T.Text -> [T.Text] -> DataFrameException + ColumnsNotFoundException :: [T.Text] -> T.Text -> [T.Text] -> DataFrameException EmptyDataSetException :: T.Text -> DataFrameException InternalException :: T.Text -> DataFrameException NonColumnReferenceException :: T.Text -> DataFrameException @@ -51,7 +52,7 @@ instance Show DataFrameException where (errorColumnName context) (callingFunctionName context) errorString - show (ColumnNotFoundException columnName callPoint availableColumns) = columnNotFound columnName callPoint availableColumns + show (ColumnsNotFoundException columnNames callPoint availableColumns) = columnsNotFound columnNames callPoint availableColumns show (EmptyDataSetException callPoint) = emptyDataSetError callPoint show (WrongQuantileNumberException q) = wrongQuantileNumberError q show (WrongQuantileIndexException qs q) = wrongQuantileIndexError qs q @@ -65,15 +66,46 @@ instance Show DataFrameException where ++ T.unpack expr2 columnNotFound :: T.Text -> T.Text -> [T.Text] -> String -columnNotFound name callPoint columns = +columnNotFound missingColumn = columnsNotFound [missingColumn] + +columnsNotFound :: [T.Text] -> T.Text -> [T.Text] -> String +columnsNotFound missingColumns callPoint availableColumns = red "\n\n[ERROR] " - ++ "Column not found: " - ++ T.unpack name + ++ missingColumnsLabel missingColumns + ++ ": " + ++ T.unpack (T.intercalate ", " missingColumns) ++ " for operation " ++ T.unpack callPoint - ++ "\n\tDid you mean " - ++ T.unpack (guessColumnName name columns) - ++ "?\n\n" + ++ formatSuggestions missingColumns availableColumns + ++ "\n\n" + where + missingColumnsLabel [_] = "Column not found" + missingColumnsLabel _ = "Columns not found" + + formatSuggestions [missingColumn] columns = + case guessColumnName missingColumn columns of + "" -> "" + guessed -> + "\n\tDid you mean " + ++ T.unpack guessed + ++ "?" + formatSuggestions names columns = + case traverse (`suggestColumnName` columns) names of + Just guessedColumns + | not (null guessedColumns) -> + "\n\tDid you mean " + ++ formatColumnSuggestions guessedColumns + ++ "?" + _ -> "" + + suggestColumnName missingColumn columns = case guessColumnName missingColumn columns of + "" -> Nothing + guessed -> Just guessed + + formatColumnSuggestions guessedColumns = + "[" + ++ L.intercalate ", " (map (show . T.unpack) guessedColumns) + ++ "]" typeMismatchError :: String -> String -> String typeMismatchError givenType expectedType = diff --git a/src/DataFrame/IO/Parquet.hs b/src/DataFrame/IO/Parquet.hs index 61550e5..084f941 100644 --- a/src/DataFrame/IO/Parquet.hs +++ b/src/DataFrame/IO/Parquet.hs @@ -19,7 +19,7 @@ import qualified Data.Text as T import Data.Text.Encoding import Data.Time import Data.Time.Clock.POSIX (posixSecondsToUTCTime) -import DataFrame.Errors (DataFrameException (ColumnNotFoundException)) +import DataFrame.Errors (DataFrameException (ColumnsNotFoundException)) import DataFrame.Internal.Binary (littleEndianWord32) import qualified DataFrame.Internal.Column as DI import DataFrame.Internal.DataFrame (DataFrame) @@ -179,8 +179,8 @@ _readParquetWithOpts extraConfig opts path = withFileBufferedOrSeekable extraCon in unless (L.null missing) ( throw - ( ColumnNotFoundException - (T.pack $ show missing) + ( ColumnsNotFoundException + missing "readParquetWithOpts" availableSelectedColumns ) diff --git a/src/DataFrame/Internal/DataFrame.hs b/src/DataFrame/Internal/DataFrame.hs index d2474d9..44f4685 100644 --- a/src/DataFrame/Internal/DataFrame.hs +++ b/src/DataFrame/Internal/DataFrame.hs @@ -161,16 +161,16 @@ getColumn name df {- | Retrieves a column by name from the dataframe, throwing an exception if not found. -This is an unsafe version of 'getColumn' that throws 'ColumnNotFoundException' +This is an unsafe version of 'getColumn' that throws 'ColumnsNotFoundException' if the column does not exist. Use this when you are certain the column exists. ==== __Throws__ -* 'ColumnNotFoundException' - if the column with the given name does not exist +* 'ColumnsNotFoundException' - if the column with the given name does not exist -} unsafeGetColumn :: T.Text -> DataFrame -> Column unsafeGetColumn name df = case getColumn name df of - Nothing -> throw $ ColumnNotFoundException name "" (M.keys $ columnIndices df) + Nothing -> throw $ ColumnsNotFoundException [name] "" (M.keys $ columnIndices df) Just col -> col {- | Checks if the dataframe is empty (has no columns). diff --git a/src/DataFrame/Internal/Interpreter.hs b/src/DataFrame/Internal/Interpreter.hs index f0372c8..9bfdeeb 100644 --- a/src/DataFrame/Internal/Interpreter.hs +++ b/src/DataFrame/Internal/Interpreter.hs @@ -482,7 +482,7 @@ eval _ (Lit v) = Right (Scalar v) eval (FlatCtx df) (Col name) = case getColumn name df of Nothing -> - Left $ ColumnNotFoundException name "" (M.keys $ columnIndices df) + Left $ ColumnsNotFoundException [name] "" (M.keys $ columnIndices df) Just c | hasElemType @a c -> Right (Flat c) | otherwise -> @@ -500,8 +500,8 @@ eval (GroupCtx gdf) (Col name) = case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just c @@ -524,14 +524,14 @@ eval (FlatCtx df) (CastWith name _tag onResult) = case getColumn name df of Nothing -> Left $ - ColumnNotFoundException name "" (M.keys $ columnIndices df) + ColumnsNotFoundException [name] "" (M.keys $ columnIndices df) Just c -> Flat <$> promoteColumnWith onResult c eval (GroupCtx gdf) (CastWith name _tag onResult) = case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just c -> do @@ -579,8 +579,8 @@ eval (GroupCtx gdf) expr@(Agg (FoldAgg _ (Just seed) (f :: a -> b -> a)) (Col na case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just col -> @@ -599,8 +599,8 @@ eval (GroupCtx gdf) expr@(Agg (FoldAgg _ Nothing (f :: a -> b -> a)) (Col name : case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just col -> @@ -617,8 +617,8 @@ eval case getColumn name (fullDataframe gdf) of Nothing -> Left $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "" (M.keys $ columnIndices $ fullDataframe gdf) Just col -> diff --git a/src/DataFrame/Internal/Row.hs b/src/DataFrame/Internal/Row.hs index f7ed391..f512e50 100644 --- a/src/DataFrame/Internal/Row.hs +++ b/src/DataFrame/Internal/Row.hs @@ -169,8 +169,8 @@ mkRowFromArgs names df i = V.map get (V.fromList names) get name = case getColumn name df of Nothing -> throw $ - ColumnNotFoundException - name + ColumnsNotFoundException + [name] "[INTERNAL] mkRowFromArgs" (M.keys $ columnIndices df) Just (BoxedColumn column) -> toAny (column V.! i) @@ -203,7 +203,7 @@ mkRowRep df names i = V.generate (L.length names) (\index -> get (names' V.! ind Just e -> toAny e Nothing -> throwError name Nothing -> - throw $ ColumnNotFoundException name "mkRowRep" (M.keys $ columnIndices df) + throw $ ColumnsNotFoundException [name] "mkRowRep" (M.keys $ columnIndices df) sortedIndexes' :: [Bool] -> V.Vector Row -> VU.Vector Int sortedIndexes' flipCompare rows = runST $ do diff --git a/src/DataFrame/Operations/Aggregation.hs b/src/DataFrame/Operations/Aggregation.hs index dc07151..a7af6dd 100644 --- a/src/DataFrame/Operations/Aggregation.hs +++ b/src/DataFrame/Operations/Aggregation.hs @@ -49,8 +49,8 @@ groupBy :: groupBy names df | any (`notElem` columnNames df) names = throw $ - ColumnNotFoundException - (T.pack $ show $ names L.\\ columnNames df) + ColumnsNotFoundException + (names L.\\ columnNames df) "groupBy" (columnNames df) | nRows df == 0 = diff --git a/src/DataFrame/Operations/Core.hs b/src/DataFrame/Operations/Core.hs index cec2315..ab0861e 100644 --- a/src/DataFrame/Operations/Core.hs +++ b/src/DataFrame/Operations/Core.hs @@ -394,7 +394,7 @@ cloneColumn original new df | null df = throw (EmptyDataSetException "cloneColumn") | otherwise = fromMaybe ( throw $ - ColumnNotFoundException original "cloneColumn" (M.keys $ columnIndices df) + ColumnsNotFoundException [original] "cloneColumn" (M.keys $ columnIndices df) ) $ do column <- getColumn original df @@ -485,7 +485,7 @@ renameSafe :: renameSafe orig new df | null df = throw (EmptyDataSetException "rename") | otherwise = fromMaybe - (Left $ ColumnNotFoundException orig "rename" (M.keys $ columnIndices df)) + (Left $ ColumnsNotFoundException [orig] "rename" (M.keys $ columnIndices df)) $ do columnIndex <- M.lookup orig (columnIndices df) let origRemoved = M.delete orig (columnIndices df) @@ -859,7 +859,8 @@ columnAsVector expr df (Col name) -> case getColumn name df of Just col -> toVector col Nothing -> - Left $ ColumnNotFoundException name "columnAsVector" (M.keys $ columnIndices df) + Left $ + ColumnsNotFoundException [name] "columnAsVector" (M.keys $ columnIndices df) _ -> case interpret df expr of Left e -> throw e Right (TColumn col) -> toVector col @@ -876,7 +877,7 @@ columnAsIntVector (Col name) df = case getColumn name df of Just col -> toIntVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsIntVector" (M.keys $ columnIndices df) + ColumnsNotFoundException [name] "columnAsIntVector" (M.keys $ columnIndices df) columnAsIntVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toIntVector col @@ -893,7 +894,10 @@ columnAsDoubleVector (Col name) df = case getColumn name df of Just col -> toDoubleVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsDoubleVector" (M.keys $ columnIndices df) + ColumnsNotFoundException + [name] + "columnAsDoubleVector" + (M.keys $ columnIndices df) columnAsDoubleVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toDoubleVector col @@ -910,7 +914,10 @@ columnAsFloatVector (Col name) df = case getColumn name df of Just col -> toFloatVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsFloatVector" (M.keys $ columnIndices df) + ColumnsNotFoundException + [name] + "columnAsFloatVector" + (M.keys $ columnIndices df) columnAsFloatVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toFloatVector col @@ -923,7 +930,10 @@ columnAsUnboxedVector (Col name) df = case getColumn name df of Just col -> toUnboxedVector col Nothing -> Left $ - ColumnNotFoundException name "columnAsFloatVector" (M.keys $ columnIndices df) + ColumnsNotFoundException + [name] + "columnAsFloatVector" + (M.keys $ columnIndices df) columnAsUnboxedVector expr df = case interpret df expr of Left e -> throw e Right (TColumn col) -> toUnboxedVector col diff --git a/src/DataFrame/Operations/Join.hs b/src/DataFrame/Operations/Join.hs index a6db3bd..5938661 100644 --- a/src/DataFrame/Operations/Join.hs +++ b/src/DataFrame/Operations/Join.hs @@ -11,6 +11,7 @@ module DataFrame.Operations.Join where import Control.Applicative ((<|>)) +import Control.Exception (throw) import Control.Monad (forM_, when) import Control.Monad.ST (ST, runST) import qualified Data.HashMap.Strict as HM @@ -27,6 +28,9 @@ import qualified Data.Vector as VB import qualified Data.Vector.Algorithms.Merge as VA import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM +import DataFrame.Errors ( + DataFrameException (ColumnsNotFoundException), + ) import DataFrame.Internal.Column as D import DataFrame.Internal.DataFrame as D import DataFrame.Operations.Aggregation as D @@ -148,6 +152,15 @@ fillCrossProduct !leftSI !rightSI !lStart !lEnd !rStart !rEnd !lv !rv !pos = goL keyColIndices :: S.Set T.Text -> DataFrame -> [Int] keyColIndices csSet df = M.elems $ M.restrictKeys (D.columnIndices df) csSet +-- | Validate that all requested join keys exist, then return their indices. +validatedKeyColIndices :: T.Text -> S.Set T.Text -> DataFrame -> [Int] +validatedKeyColIndices callPoint csSet df = + let columnIdxs = D.columnIndices df + missingKeys = S.toAscList (csSet `S.difference` M.keysSet columnIdxs) + in case missingKeys of + [] -> M.elems $ M.restrictKeys columnIdxs csSet + _ -> throw (ColumnsNotFoundException missingKeys callPoint (M.keys columnIdxs)) + -- ============================================================ -- Inner Join -- ============================================================ @@ -175,36 +188,39 @@ ghci> D.innerJoin ["key"] df other innerJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame innerJoin cs left right | D.null right || D.null left = D.empty - | otherwise = - let - csSet = S.fromList cs - leftRows = fst (D.dimensions left) - rightRows = fst (D.dimensions right) - - leftKeyIdxs = keyColIndices csSet left - rightKeyIdxs = keyColIndices csSet right - leftHashes = D.computeRowHashes leftKeyIdxs left - rightHashes = D.computeRowHashes rightKeyIdxs right - - buildRows = min leftRows rightRows - (leftIxs, rightIxs) - | buildRows > joinStrategyThreshold = - sortMergeInnerKernel leftHashes rightHashes - | rightRows <= leftRows = - -- Build on right (smaller or equal), probe with left - hashInnerKernel leftHashes rightHashes - | otherwise = - -- Build on left (smaller), probe with right, swap result - let (!rIxs, !lIxs) = hashInnerKernel rightHashes leftHashes - in (lIxs, rIxs) - in - assembleInner csSet left right leftIxs rightIxs + | otherwise = innerJoinNonEmpty cs left right + +innerJoinNonEmpty :: [T.Text] -> DataFrame -> DataFrame -> DataFrame +innerJoinNonEmpty cs left right = + let + csSet = S.fromList cs + leftRows = fst (D.dimensions left) + rightRows = fst (D.dimensions right) + + leftKeyIdxs = validatedKeyColIndices "innerJoin" csSet left + rightKeyIdxs = validatedKeyColIndices "innerJoin" csSet right + leftHashes = D.computeRowHashes leftKeyIdxs left + rightHashes = D.computeRowHashes rightKeyIdxs right + + buildRows = min leftRows rightRows + (leftIxs, rightIxs) + | buildRows > joinStrategyThreshold = + sortMergeInnerKernel leftHashes rightHashes + | rightRows <= leftRows = + -- Build on right (smaller or equal), probe with left + hashInnerKernel leftHashes rightHashes + | otherwise = + -- Build on left (smaller), probe with right, swap result + let (!rIxs, !lIxs) = hashInnerKernel rightHashes leftHashes + in (lIxs, rIxs) + in + assembleInner csSet left right leftIxs rightIxs -- | Compute hashes for the given key column names in a DataFrame. buildHashColumn :: [T.Text] -> DataFrame -> VU.Vector Int buildHashColumn keys df = let csSet = S.fromList keys - keyIdxs = keyColIndices csSet df + keyIdxs = validatedKeyColIndices "buildHashColumn" csSet df in D.computeRowHashes keyIdxs df {- | Probe one batch of rows against a pre-built 'CompactIndex'. @@ -530,28 +546,35 @@ ghci> D.leftJoin ["key"] df other @ -} leftJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -leftJoin cs left right +leftJoin = leftJoinWithCallPoint "leftJoin" + +leftJoinWithCallPoint :: + T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame +leftJoinWithCallPoint callPoint cs left right | D.null right || D.nRows right == 0 = left | D.null left || D.nRows left == 0 = D.empty - | otherwise = - let - csSet = S.fromList cs - rightRows = fst (D.dimensions right) - - leftKeyIdxs = keyColIndices csSet left - rightKeyIdxs = keyColIndices csSet right - leftHashes = D.computeRowHashes leftKeyIdxs left - rightHashes = D.computeRowHashes rightKeyIdxs right - - -- Right is always the build side for left join - (leftIxs, rightIxs) - | rightRows > joinStrategyThreshold = - sortMergeLeftKernel leftHashes rightHashes - | otherwise = - hashLeftKernel leftHashes rightHashes - in - -- rightIxs uses -1 as sentinel for "no match" - assembleLeft csSet left right leftIxs rightIxs + | otherwise = leftJoinNonEmpty callPoint cs left right + +leftJoinNonEmpty :: T.Text -> [T.Text] -> DataFrame -> DataFrame -> DataFrame +leftJoinNonEmpty callPoint cs left right = + let + csSet = S.fromList cs + rightRows = fst (D.dimensions right) + + leftKeyIdxs = validatedKeyColIndices callPoint csSet left + rightKeyIdxs = validatedKeyColIndices callPoint csSet right + leftHashes = D.computeRowHashes leftKeyIdxs left + rightHashes = D.computeRowHashes rightKeyIdxs right + + -- Right is always the build side for left join + (leftIxs, rightIxs) + | rightRows > joinStrategyThreshold = + sortMergeLeftKernel leftHashes rightHashes + | otherwise = + hashLeftKernel leftHashes rightHashes + in + -- rightIxs uses -1 as sentinel for "no match" + assembleLeft csSet left right leftIxs rightIxs {- | Hash-based left join kernel. Returns @(leftExpandedIndices, rightExpandedIndices)@ where @@ -801,33 +824,36 @@ ghci> D.rightJoin ["key"] df other -} rightJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame -rightJoin cs left right = leftJoin cs right left +rightJoin cs left right = leftJoinWithCallPoint "rightJoin" cs right left fullOuterJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame fullOuterJoin cs left right | D.null right || D.nRows right == 0 = left | D.null left || D.nRows left == 0 = right - | otherwise = - let - csSet = S.fromList cs - leftRows = fst (D.dimensions left) - rightRows = fst (D.dimensions right) - - leftKeyIdxs = keyColIndices csSet left - rightKeyIdxs = keyColIndices csSet right - leftHashes = D.computeRowHashes leftKeyIdxs left - rightHashes = D.computeRowHashes rightKeyIdxs right - - -- Both sides can have nulls in full outer - (leftIxs, rightIxs) - | max leftRows rightRows > joinStrategyThreshold = - sortMergeFullOuterKernel leftHashes rightHashes - | otherwise = - hashFullOuterKernel leftHashes rightHashes - in - -- Both index vectors use -1 as sentinel - assembleFullOuter csSet left right leftIxs rightIxs + | otherwise = fullOuterJoinNonEmpty cs left right + +fullOuterJoinNonEmpty :: [T.Text] -> DataFrame -> DataFrame -> DataFrame +fullOuterJoinNonEmpty cs left right = + let + csSet = S.fromList cs + leftRows = fst (D.dimensions left) + rightRows = fst (D.dimensions right) + + leftKeyIdxs = validatedKeyColIndices "fullOuterJoin" csSet left + rightKeyIdxs = validatedKeyColIndices "fullOuterJoin" csSet right + leftHashes = D.computeRowHashes leftKeyIdxs left + rightHashes = D.computeRowHashes rightKeyIdxs right + + -- Both sides can have nulls in full outer + (leftIxs, rightIxs) + | max leftRows rightRows > joinStrategyThreshold = + sortMergeFullOuterKernel leftHashes rightHashes + | otherwise = + hashFullOuterKernel leftHashes rightHashes + in + -- Both index vectors use -1 as sentinel + assembleFullOuter csSet left right leftIxs rightIxs {- | Hash-based full outer join kernel. Builds compact indices on both sides. diff --git a/src/DataFrame/Operations/Permutation.hs b/src/DataFrame/Operations/Permutation.hs index e799bfd..58f095d 100644 --- a/src/DataFrame/Operations/Permutation.hs +++ b/src/DataFrame/Operations/Permutation.hs @@ -53,8 +53,8 @@ sortBy :: sortBy sortOrds df | any (`notElem` columnNames df) names = throw $ - ColumnNotFoundException - (T.pack $ show $ names L.\\ columnNames df) + ColumnsNotFoundException + (names L.\\ columnNames df) "sortBy" (columnNames df) | otherwise = diff --git a/src/DataFrame/Operations/Statistics.hs b/src/DataFrame/Operations/Statistics.hs index b7e5f70..ff4d530 100644 --- a/src/DataFrame/Operations/Statistics.hs +++ b/src/DataFrame/Operations/Statistics.hs @@ -222,7 +222,7 @@ _getColumnAsDouble name df = case getColumn name df of SFalse -> Nothing Nothing -> throw $ - ColumnNotFoundException name "_getColumnAsDouble" (M.keys $ columnIndices df) + ColumnsNotFoundException [name] "_getColumnAsDouble" (M.keys $ columnIndices df) _ -> Nothing -- Return a type mismatch error here. {-# INLINE _getColumnAsDouble #-} @@ -237,7 +237,7 @@ optionalToDoubleVector = sum :: forall a. (Columnable a, Num a) => Expr a -> DataFrame -> a sum (Col name) df = case getColumn name df of - Nothing -> throw $ ColumnNotFoundException name "sum" (M.keys $ columnIndices df) + Nothing -> throw $ ColumnsNotFoundException [name] "sum" (M.keys $ columnIndices df) Just ((UnboxedColumn (column :: VU.Vector a'))) -> case testEquality (typeRep @a') (typeRep @a) of Just Refl -> VG.sum column Nothing -> 0 diff --git a/src/DataFrame/Operations/Subset.hs b/src/DataFrame/Operations/Subset.hs index eacb5dc..0d2a1ed 100644 --- a/src/DataFrame/Operations/Subset.hs +++ b/src/DataFrame/Operations/Subset.hs @@ -1,4 +1,6 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ExplicitNamespaces #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -49,6 +51,18 @@ import System.Random import Type.Reflection import Prelude hiding (filter, take) +#if MIN_VERSION_random(1,3,0) +type SplittableGen g = (SplitGen g, RandomGen g) + +splitForStratified :: SplittableGen g => g -> (g, g) +splitForStratified = splitGen +#else +type SplittableGen g = RandomGen g + +splitForStratified :: SplittableGen g => g -> (g, g) +splitForStratified = split +#endif + -- | O(k * n) Take the first n rows of a DataFrame. take :: Int -> DataFrame -> DataFrame take n d = d{columns = V.map (takeColumn n') (columns d), dataframeDimensions = (n', c)} @@ -117,7 +131,7 @@ filter :: filter (Col filterColumnName) condition df = case getColumn filterColumnName df of Nothing -> throw $ - ColumnNotFoundException filterColumnName "filter" (M.keys $ columnIndices df) + ColumnsNotFoundException [filterColumnName] "filter" (M.keys $ columnIndices df) Just (BoxedColumn (column :: V.Vector b)) -> filterByVector filterColumnName column condition df Just (OptionalColumn (column :: V.Vector b)) -> filterByVector filterColumnName column condition df Just (UnboxedColumn (column :: VU.Vector b)) -> filterByVector filterColumnName column condition df @@ -194,7 +208,7 @@ filterWhere expr df = filterJust :: T.Text -> DataFrame -> DataFrame filterJust name df = case getColumn name df of Nothing -> - throw $ ColumnNotFoundException name "filterJust" (M.keys $ columnIndices df) + throw $ ColumnsNotFoundException [name] "filterJust" (M.keys $ columnIndices df) Just column@(OptionalColumn (col :: V.Vector (Maybe a))) -> filter (Col @(Maybe a) name) isJust df & apply @(Maybe a) fromJust name Just column -> df @@ -205,7 +219,8 @@ filterJust name df = case getColumn name df of filterNothing :: T.Text -> DataFrame -> DataFrame filterNothing name df = case getColumn name df of Nothing -> - throw $ ColumnNotFoundException name "filterNothing" (M.keys $ columnIndices df) + throw $ + ColumnsNotFoundException [name] "filterNothing" (M.keys $ columnIndices df) Just (OptionalColumn (col :: V.Vector (Maybe a))) -> filter (Col @(Maybe a) name) isNothing df _ -> df @@ -243,8 +258,8 @@ select cs df | L.null cs = empty | any (`notElem` columnNames df) cs = throw $ - ColumnNotFoundException - (T.pack $ show $ cs L.\\ columnNames df) + ColumnsNotFoundException + (cs L.\\ columnNames df) "select" (columnNames df) | otherwise = @@ -455,7 +470,7 @@ ghci> D.stratifiedSample (mkStdGen 42) 0.8 "label" df -} stratifiedSample :: forall a g. - (SplitGen g, RandomGen g, Columnable a) => + (SplittableGen g, Columnable a) => g -> Double -> Expr a -> DataFrame -> DataFrame stratifiedSample gen p strataCol df = let col = case strataCol of @@ -465,7 +480,7 @@ stratifiedSample gen p strataCol df = go _ [] = mempty go g (ixs : rest) = let stratum = rowsAtIndices ixs df - (g1, g2) = splitGen g + (g1, g2) = splitForStratified g in sample g1 p stratum <> go g2 rest in go gen groups @@ -479,7 +494,7 @@ ghci> D.stratifiedSplit (mkStdGen 42) 0.8 "label" df -} stratifiedSplit :: forall a g. - (SplitGen g, RandomGen g, Columnable a) => + (SplittableGen g, Columnable a) => g -> Double -> Expr a -> DataFrame -> (DataFrame, DataFrame) stratifiedSplit gen p strataCol df = let col = case strataCol of @@ -489,7 +504,7 @@ stratifiedSplit gen p strataCol df = go _ [] = (mempty, mempty) go g (ixs : rest) = let stratum = rowsAtIndices ixs df - (g1, g2) = splitGen g + (g1, g2) = splitForStratified g (tr, va) = randomSplit g1 p stratum (trAcc, vaAcc) = go g2 rest in (tr <> trAcc, va <> vaAcc) diff --git a/src/DataFrame/Operations/Transformations.hs b/src/DataFrame/Operations/Transformations.hs index 41401a1..4eae22a 100644 --- a/src/DataFrame/Operations/Transformations.hs +++ b/src/DataFrame/Operations/Transformations.hs @@ -63,7 +63,8 @@ safeApply :: DataFrame -> Either DataFrameException DataFrame safeApply f columnName d = case getColumn columnName d of - Nothing -> Left $ ColumnNotFoundException columnName "apply" (M.keys $ columnIndices d) + Nothing -> + Left $ ColumnsNotFoundException [columnName] "apply" (M.keys $ columnIndices d) Just column -> do column' <- mapColumn f column pure $ insertColumn columnName column' d @@ -163,8 +164,8 @@ applyWhere :: applyWhere condition filterColumnName f columnName df = case getColumn filterColumnName df of Nothing -> throw $ - ColumnNotFoundException - filterColumnName + ColumnsNotFoundException + [filterColumnName] "applyWhere" (M.keys $ columnIndices df) Just column -> case ifoldrColumn @@ -193,7 +194,7 @@ applyAtIndex :: applyAtIndex i f columnName df = case getColumn columnName df of Nothing -> throw $ - ColumnNotFoundException columnName "applyAtIndex" (M.keys $ columnIndices df) + ColumnsNotFoundException [columnName] "applyAtIndex" (M.keys $ columnIndices df) Just column -> case imapColumn (\index value -> if index == i then f value else value) column of Left e -> throw e Right column' -> insertColumn columnName column' df @@ -208,7 +209,8 @@ imputeCore :: DataFrame imputeCore (Col columnName) value df = case getColumn columnName df of Nothing -> - throw $ ColumnNotFoundException columnName "impute" (M.keys $ columnIndices df) + throw $ + ColumnsNotFoundException [columnName] "impute" (M.keys $ columnIndices df) Just (OptionalColumn _) -> case safeApply (fromMaybe value) columnName df of Left (TypeMismatchException context) -> throw $ TypeMismatchException (context{callingFunctionName = Just "impute"}) Left exception -> throw exception diff --git a/tests/Operations/GroupBy.hs b/tests/Operations/GroupBy.hs index 0ad27ed..47c2b9a 100644 --- a/tests/Operations/GroupBy.hs +++ b/tests/Operations/GroupBy.hs @@ -47,7 +47,7 @@ groupByColumnDoesNotExist = TestCase ( assertExpectException "[Error Case]" - (D.columnNotFound "[\"test0\"]" "groupBy" (D.columnNames testData)) + (D.columnsNotFound ["test0"] "groupBy" (D.columnNames testData)) (print $ D.groupBy ["test0"] testData) ) diff --git a/tests/Operations/Join.hs b/tests/Operations/Join.hs index 8818352..1718935 100644 --- a/tests/Operations/Join.hs +++ b/tests/Operations/Join.hs @@ -4,7 +4,9 @@ module Operations.Join where -import Data.Text (Text) +import Assertions (assertExpectException) +import Control.Exception (evaluate) +import Data.Text (Text, unpack) import Data.These import qualified DataFrame as D import qualified DataFrame.Functions as F @@ -26,6 +28,21 @@ df2 = , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) ] +assertMissingJoinColumns :: String -> [Text] -> D.DataFrame -> Assertion +assertMissingJoinColumns preface missingKeys result = do + assertExpectException + preface + (if length missingKeys == 1 then "Column not found" else "Columns not found") + (evaluate (D.nRows result)) + mapM_ + ( \missingKey -> + assertExpectException preface (unpack missingKey) (evaluate (D.nRows result)) + ) + missingKeys + +assertMissingJoinColumn :: String -> Text -> D.DataFrame -> Assertion +assertMissingJoinColumn preface missingKey = assertMissingJoinColumns preface [missingKey] + testInnerJoin :: Test testInnerJoin = TestCase @@ -253,6 +270,59 @@ testOuterJoinWithCollisions = (D.sortBy [D.Asc (F.col @Text "key")] (fullOuterJoin ["key"] dfL dfR)) ) +testInnerJoinMissingKey :: Test +testInnerJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Inner join should fail when the join key is missing" + "Cats" + (innerJoin ["Cats"] df1 df2) + +testLeftJoinMissingKey :: Test +testLeftJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Left join should fail when the join key is missing" + "Cats" + (leftJoin ["Cats"] df1 df2) + +testRightJoinMissingKey :: Test +testRightJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Right join should fail when the join key is missing" + "Animals" + (rightJoin ["Animals"] df1 df2) + +testFullOuterJoinMissingKey :: Test +testFullOuterJoinMissingKey = + TestCase $ + assertMissingJoinColumn + "Full outer join should fail when the join key is missing" + "Cats" + (fullOuterJoin ["Cats"] df1 df2) + +testInnerJoinMissingKeys :: Test +testInnerJoinMissingKeys = + TestCase $ + assertMissingJoinColumns + "Inner join should report every missing join key" + ["Animals", "Cats"] + (innerJoin ["Animals", "Cats"] df1 df2) + +testInnerJoinMissingKeysSuggestion :: Test +testInnerJoinMissingKeysSuggestion = + TestCase $ + let typoDf = + D.fromNamedColumns + [ ("hello", D.fromList ["H" :: Text]) + , ("world", D.fromList ["W" :: Text]) + ] + in assertExpectException + "Inner join should report consolidated suggestions for missing join keys" + "Did you mean [\"hello\", \"world\"]?" + (evaluate (D.nRows (innerJoin ["helo", "wrld"] typoDf typoDf))) + -- Empty DataFrame fixtures: same schema as df1/df2 but zero rows. emptyDf1 :: D.DataFrame emptyDf1 = @@ -353,6 +423,12 @@ tests = , TestLabel "leftJoinWithCollisions" testLeftJoinWithCollisions , TestLabel "rightJoinWithCollisions" testRightJoinWithCollisions , TestLabel "outerJoinWithCollisions" testOuterJoinWithCollisions + , TestLabel "innerJoinMissingKey" testInnerJoinMissingKey + , TestLabel "leftJoinMissingKey" testLeftJoinMissingKey + , TestLabel "rightJoinMissingKey" testRightJoinMissingKey + , TestLabel "fullOuterJoinMissingKey" testFullOuterJoinMissingKey + , TestLabel "innerJoinMissingKeys" testInnerJoinMissingKeys + , TestLabel "innerJoinMissingKeysSuggestion" testInnerJoinMissingKeysSuggestion , TestLabel "innerJoinBothEmpty" testInnerJoinBothEmpty , TestLabel "innerJoinLeftEmpty" testInnerJoinLeftEmpty , TestLabel "innerJoinRightEmpty" testInnerJoinRightEmpty diff --git a/tests/Operations/Sort.hs b/tests/Operations/Sort.hs index 1c821dc..98d70dc 100644 --- a/tests/Operations/Sort.hs +++ b/tests/Operations/Sort.hs @@ -85,7 +85,7 @@ sortByColumnDoesNotExist = TestCase ( assertExpectException "[Error Case]" - (D.columnNotFound "[\"test0\"]" "sortBy" (D.columnNames testData)) + (D.columnsNotFound ["test0"] "sortBy" (D.columnNames testData)) (print $ D.sortBy [D.Asc (F.col @Int "test0")] testData) ) diff --git a/tests/Operations/Statistics.hs b/tests/Operations/Statistics.hs index 81baba9..db90765 100644 --- a/tests/Operations/Statistics.hs +++ b/tests/Operations/Statistics.hs @@ -244,7 +244,7 @@ correlationSelfIdentity = (abs (r - 1.0) < 1e-10) ) --- Requesting a missing column should throw ColumnNotFoundException +-- Requesting a missing column should throw ColumnsNotFoundException correlationMissingColumn :: Test correlationMissingColumn = TestCase