diff --git a/ihp-ide/IHP/SchemaCompiler.hs b/ihp-ide/IHP/SchemaCompiler.hs index 4d7140047..fb6c90334 100644 --- a/ihp-ide/IHP/SchemaCompiler.hs +++ b/ihp-ide/IHP/SchemaCompiler.hs @@ -1004,11 +1004,28 @@ compileSetFieldInstances table@(CreateTable { name, columns }) = unlines (map co | otherwise = name' compileUpdateFieldInstances :: (?schema :: Schema) => CreateTable -> Text -compileUpdateFieldInstances table@(CreateTable { name, columns }) = unlines (map compileSetField (dataFields table)) +compileUpdateFieldInstances table@(CreateTable { name, columns, inherits }) = + unlines (map compileSetField (dataFields table)) where modelName = tableNameToModelName name - typeArgs = dataTypeArguments table - compileSetField (name, fieldType) = "instance UpdateField " <> tshow name <> " (" <> compileTypePattern table <> ") (" <> compileTypePattern' name <> ") " <> valueTypeA <> " " <> valueTypeB <> " where\n {-# INLINE updateField #-}\n updateField newValue (" <> compileDataTypePattern table <> ") = " <> modelName <> " " <> (unwords (map compileAttribute (table |> dataFields |> map fst))) + + -- Convert the model name to its plural, lowercase form to match the column name. + colName = modelName |> pluralize |> Text.toLower + + -- Determine the type arguments considering inheritance. + typeArgs = case inherits of + Nothing -> dataTypeArguments table + Just parentTableName -> + let parentTableDef = findTableByName parentTableName + in case parentTableDef of + Just parentTable -> + let parentTypeArgs = dataTypeArguments parentTable.unsafeGetCreateTable + in dataTypeArguments table + <> filter (\fieldName -> Text.toLower fieldName /= colName) parentTypeArgs + Nothing -> error $ "Parent table " <> cs parentTableName <> " not found for table " <> cs name <> "." + + compileSetField (name, fieldType) = + "instance UpdateField " <> tshow name <> " (" <> compileTypePattern table <> ") (" <> compileTypePattern' name <> ") " <> valueTypeA <> " " <> valueTypeB <> " where\n {-# INLINE updateField #-}\n updateField newValue (" <> compileDataTypePattern table <> ") = " <> modelName <> " " <> (unwords (map compileAttribute (table |> dataFields |> map fst))) where (valueTypeA, valueTypeB) = if name `elem` typeArgs @@ -1021,7 +1038,10 @@ compileUpdateFieldInstances table@(CreateTable { name, columns }) = unlines (map | otherwise = name' compileTypePattern' :: Text -> Text - compileTypePattern' name = tableNameToModelName table.name <> "' " <> unwords (map (\f -> if f == name then name <> "'" else f) (dataTypeArguments table)) + compileTypePattern' name = + let filteredArgs = map (\f -> if f == name then name <> "'" else f) typeArgs + in tableNameToModelName table.name <> "' " <> unwords filteredArgs + compileHasFieldId :: (?schema :: Schema) => CreateTable -> Text compileHasFieldId table@CreateTable { name, primaryKeyConstraint } = cs [i|