pandasのDataFrameを戻り値とする関数をテストする中で、
期待値として用意したDataFrameと、出力値のDataFrameの比較結果が一致しません。
期待値と出力値をそれぞれコンソールに出力してみましたが、
見た目上は同じ値となっています。
ただ、テスト対象の関数にて、浮動小数点の乗算を行っており、
この箇所が原因となり、期待値と出力値が一致しないのではないかと考えています。
もしかすると、pytestの書き方にも原因があるかもしれませんが、
浮動小数点を含むDataFrame同士の比較をどのように書けばよいでしょうか?
- テスト対象の関数(model.py)
import pandas as pd def calc_square(df) -> pd.DataFrame: df_result: pd.DataFrame = df.copy() df_square: pd.DataFrame = df['height'] ** 2 df_result.insert(2, 'height^2', df_square) return df_result
- テストケース(test_model.py)
import pandas as pd import model def test_01(): # 入力値 dict_input = {'name': ['userA'], 'height': [1.11]} df_input: pd.DataFrame = pd.DataFrame(data = dict_input) # 期待値 dict_expected = {'name': ['userA'], 'height': [1.11], 'height^2': [1.2321]} df_expected: pd.DataFrame = pd.DataFrame(data = dict_expected) # 出力値 df_output: pd.DataFrame = model.calc_square(df_input) # 比較 assert df_expected.equals(df_output)
- テスト実行結果(一部抜粋)
E assert False E + where False = <bound method NDFrame.equals of name height height^2\n0 userA 1.11 1.2321>( name height height^2\n0 userA 1.11 1.2321) E + where <bound method NDFrame.equals of name height height^2\n0 userA 1.11 1.2321> = name height height^2\n0 userA 1.11 1.2321.equals tests\src\test_model.py:41: AssertionError =============================================== short test summary info =============================================== FAILED tests/src/test_model.py::test_01 - assert False ============================================= 1 failed in 2.50s =============================================
■追記
DataFrame同士で比較するために、比較関数を作成する必要があるとのことで、
当初の質問のような、DataFrame同士の比較という形ではなく、
DataFrameの列単位での比較という形でテストを実装しました。
- テストケース(test_model.py)
import pandas as pd import model def test_01(): # 入力値 dict_input = {'name': ['userA'], 'height': [1.11]} df_input: pd.DataFrame = pd.DataFrame(data = dict_input) # 期待値 dict_expected = {'name': ['userA'], 'height': [1.11], 'height^2': [1.2321]} df_expected: pd.DataFrame = pd.DataFrame(data = dict_expected) # 出力値 df_output: pd.DataFrame = model.calc_square(df_input) # 比較 assert df_expected.equals(df_output.round(4) # round()にて小数点以下4桁で丸め
回答2件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2021/07/30 14:10