Causal Inference Study Notes (4): Trying Classic Methods, Weighting (Lalonde's Dataset)
This post uses Lalonde's dataset (NSW + PSID + CPS Data) to try out different weighting methods.
The content below was converted from R Markdown. Git: yishilin14/causal_playground
The blog's Markdown was generated with render("./causality4_playaround_with_the_lalonde_dataset_weighting.Rmd", md_document(variant = "markdown_github")).
Loading the data
Dataset
- Paper: Dehejia R H, Wahba S. Causal effects in nonexperimental studies: Reevaluating the evaluation of training programs. Journal of the American Statistical Association, 1999, 94(448): 1053-1062.
- Download: http://users.nber.org/~rdehejia/nswdata2.html
# Load the datasets
Define a function that computes the ATT; it is used repeatedly below.
estimate.causal.effect <- function(data.name, method.name, unbalance.cnt=0, data, weights) {
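The body of estimate.causal.effect is not shown above. As a minimal sketch of what such a helper could look like, the following fits a weighted regression of the outcome on the treatment indicator and reads off the coefficient and its 95% confidence interval. The function name att.sketch, the column names treat and re78, and the simulated toy data are all assumptions for illustration; this is not the author's implementation.

```r
# Hypothetical sketch: weighted difference in means with a 95% CI.
# Assumes a 0/1 column `treat` and an outcome column `re78` (assumed names).
att.sketch <- function(data, weights) {
  data$w <- weights
  # With a single binary regressor, the slope is the weighted mean difference
  fit <- lm(re78 ~ treat, data = data, weights = w)
  est <- unname(coef(fit)["treat"])
  ci <- unname(confint(fit, "treat", level = 0.95)[1, ])
  c(ATT = est, lower = ci[1], upper = ci[2])
}

# Toy data (simulated, not the Lalonde data)
set.seed(1)
toy <- data.frame(treat = rep(c(0, 1), each = 50))
toy$re78 <- 1000 * toy$treat + rnorm(100, sd = 500)

# Unit weights reproduce the plain difference in means
res <- att.sketch(toy, rep(1, nrow(toy)))
print(round(res))
```

The model-based interval from lm ignores the fact that the weights are estimated, so for weighted analyses a robust or bootstrap interval may be preferable.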
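First, check whether the two groups in the experimental data are comparable. Somewhat oddly, the treatment and control groups are not that close in age (age) or education (educ and nodegree): all three have a standardized difference above the 0.1 threshold. Comparing the two groups directly gives an ATT of $1794, which is treated as the ground-truth ATT in the rest of this post.
t1 <- bal.tab(treat.fml, data = nsw_data_exp, estimand = "ATT", m.threshold = .1)$Balance
Variable | Type | M.0.Un | M.1.Un | Diff.Un | M.Threshold.Un |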
---|---|---|---|---|---|
age | Contin. | 25.0538462 | 25.8162162 | 0.1065504 | Not Balanced, >0.1 |
educ | Contin. | 10.0884615 | 10.3459459 | 0.1280603 | Not Balanced, >0.1 |
black | Binary | 0.8269231 | 0.8432432 | 0.0163202 | Balanced, <0.1 |
hispan | Binary | 0.1076923 | 0.0594595 | -0.0482328 | Balanced, <0.1 |
married | Binary | 0.1538462 | 0.1891892 | 0.0353430 | Balanced, <0.1 |
nodegree | Binary | 0.8346154 | 0.7081081 | -0.1265073 | Not Balanced, >0.1 |
re74 | Contin. | 2107.0266585 | 2095.5736886 | -0.0023437 | Balanced, <0.1 |
re75 | Contin. | 1266.9090025 | 1532.0553138 | 0.0823627 | Balanced, <0.1 |
t2 <- estimate.causal.effect("Experimental", "Ground Truth", 3, nsw_data_exp, rep(1, nrow(nsw_data_exp)))
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Experimental | Ground Truth | 3 | 1794 | (481, 3108) |
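The Diff.Un column in the balance table above can be reproduced by hand. The sketch below assumes the convention that appears to match the printed numbers: for continuous covariates, the mean difference divided by the treated-group standard deviation; for binary covariates, the raw difference in proportions (for example, black: 0.8432432 - 0.8269231 ≈ 0.0163). The function name smd.att and the toy vectors are hypothetical.

```r
# Hypothetical helper: balance statistic for an ATT analysis.
# Continuous: (mean treated - mean control) / sd of the treated group.
# Binary:     raw difference in proportions.
smd.att <- function(x, treat, binary = FALSE) {
  diff <- mean(x[treat == 1]) - mean(x[treat == 0])
  if (binary) diff else diff / sd(x[treat == 1])
}

# Toy vectors (not the Lalonde data)
treat <- c(0, 0, 0, 1, 1, 1)
x.cont <- c(1, 2, 3, 4, 5, 6)
x.bin <- c(0, 0, 1, 1, 1, 0)

d.cont <- smd.att(x.cont, treat)
d.bin <- smd.att(x.bin, treat, binary = TRUE)
```

A covariate is then flagged "Not Balanced" when the absolute value of this statistic exceeds the chosen threshold (0.1 in this post).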
Next, look at the non-experimental data. The two groups differ substantially and cannot be compared directly.
t1 <- bal.tab(treat.fml, data = nsw_data_obs, estimand = "ATT", m.threshold = .1)$Balance
Variable | Type | M.0.Un | M.1.Un | Diff.Un | M.Threshold.Un |
---|---|---|---|---|---|
age | Contin. | 28.0303030 | 25.8162162 | -0.3094453 | Not Balanced, >0.1 |
educ | Contin. | 10.2354312 | 10.3459459 | 0.0549647 | Balanced, <0.1 |
black | Binary | 0.2027972 | 0.8432432 | 0.6404460 | Not Balanced, >0.1 |
hispan | Binary | 0.1421911 | 0.0594595 | -0.0827317 | Balanced, <0.1 |
married | Binary | 0.5128205 | 0.1891892 | -0.3236313 | Not Balanced, >0.1 |
nodegree | Binary | 0.5967366 | 0.7081081 | 0.1113715 | Not Balanced, >0.1 |
re74 | Contin. | 5619.2365064 | 2095.5736886 | -0.7210838 | Not Balanced, >0.1 |
re75 | Contin. | 2466.4844431 | 1532.0553138 | -0.2902629 | Not Balanced, >0.1 |
unbalance.cnt <- sum(t1$M.Threshold.Un=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | Do nothing | 6 | -635 | (-1960, 690) |
Causal inference
In this part, we use the WeightIt package to try out different weighting methods.
WeightIt-PS
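The first method is plain inverse propensity score weighting. Age cannot be brought fully into balance (its standardized difference stays at 0.119, just above the 0.1 threshold); everything else is fine. The estimated ATT is $1214.
W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "ps")
Variable | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |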
---|---|---|---|---|---|
prop.score | Distance | 0.5819524 | 0.5774355 | -0.0205046 | |
age | Contin. | 24.9658450 | 25.8162162 | 0.1188496 | Not Balanced, >0.1 |
educ | Contin. | 10.4030803 | 10.3459459 | -0.0284159 | Balanced, <0.1 |
black | Binary | 0.8454795 | 0.8432432 | -0.0022363 | Balanced, <0.1 |
hispan | Binary | 0.0592923 | 0.0594595 | 0.0001672 | Balanced, <0.1 |
married | Binary | 0.1705802 | 0.1891892 | 0.0186090 | Balanced, <0.1 |
nodegree | Binary | 0.6896872 | 0.7081081 | 0.0184209 | Balanced, <0.1 |
re74 | Contin. | 2106.0448305 | 2095.5736886 | -0.0021428 | Balanced, <0.1 |
re75 | Contin. | 1496.5412337 | 1532.0553138 | 0.0110318 | Balanced, <0.1 |
unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | PS | 1 | 1214 | (-402, 2830) |
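To make the weighting step concrete, here is a minimal sketch of propensity score weights for the ATT, written in base R on simulated data rather than through WeightIt. Treated units keep weight 1 and control units get the odds e(x) / (1 - e(x)), which reweights the controls to resemble the treated group. The data-generating process and variable names are assumptions for illustration.

```r
# Simulated data (an assumption; not the Lalonde data)
set.seed(42)
n <- 2000
x <- rnorm(n)
treat <- rbinom(n, 1, plogis(-1 + x))   # treatment is more likely for large x

# Propensity scores from a logistic regression
ps <- fitted(glm(treat ~ x, family = binomial()))

# ATT weights: 1 for treated, odds of treatment for controls
w <- ifelse(treat == 1, 1, ps / (1 - ps))

m1 <- mean(x[treat == 1])                                  # treated mean
m0.raw <- mean(x[treat == 0])                              # unweighted control mean
m0.adj <- weighted.mean(x[treat == 0], w[treat == 0])      # weighted control mean
print(c(treated = m1, control.raw = m0.raw, control.weighted = m0.adj))
```

The weighted control mean should sit much closer to the treated mean than the raw one, although, as with age in the table above, nothing forces the match to be exact.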
WeightIt-GBM
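The second method is gbm. Two attributes (re74 and re75) are not well balanced, and the estimated ATT ($461) is far from the benchmark. My guess is that this dataset is too small for gbm to show its strength, or that the default parameters lead to overfitting or underfitting; I won't dig into it here.
W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "gbm")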
## Warning: No stop.method was provided. Using "es.mean".
## Warning: Some extreme weights were generated. Examine them with summary()
## and maybe trim them with trim().
t1 <- bal.tab(W.out, m.threshold = .1)$Balance
Variable | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
---|---|---|---|---|---|
prop.score | Distance | 0.6016014 | 0.7355983 | 0.6472613 | |
age | Contin. | 25.4715156 | 25.8162162 | 0.0481761 | Balanced, <0.1 |
educ | Contin. | 10.4469803 | 10.3459459 | -0.0502496 | Balanced, <0.1 |
black | Binary | 0.8287765 | 0.8432432 | 0.0144667 | Balanced, <0.1 |
hispan | Binary | 0.0447290 | 0.0594595 | 0.0147304 | Balanced, <0.1 |
married | Binary | 0.1888474 | 0.1891892 | 0.0003418 | Balanced, <0.1 |
nodegree | Binary | 0.6638038 | 0.7081081 | 0.0443043 | Balanced, <0.1 |
re74 | Contin. | 1510.2575345 | 2095.5736886 | 0.1197793 | Not Balanced, >0.1 |
re75 | Contin. | 1074.0766359 | 1532.0553138 | 0.1422625 | Not Balanced, >0.1 |
unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | GBM | 2 | 461 | (-1422, 2344) |
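The warning about extreme weights in the gbm run can be investigated with two simple diagnostics. The sketch below, in base R on simulated weights, computes the effective sample size (ESS) and caps weights at an upper quantile. The helper names ess and trim.weights and the 95% cutoff are assumptions; WeightIt provides its own summary() and trim() for this purpose, which are not used here.

```r
# Effective sample size: equals length(w) when all weights are equal,
# and shrinks as a few weights start to dominate.
ess <- function(w) sum(w)^2 / sum(w^2)

# Hypothetical helper: cap (winsorize) weights at an upper quantile.
trim.weights <- function(w, at = 0.95) pmin(w, unname(quantile(w, at)))

# Simulated heavy-tailed weights (not from the Lalonde data)
set.seed(7)
w <- rlnorm(400, sdlog = 1.5)

ess.before <- ess(w)
ess.after <- ess(trim.weights(w))
print(c(before = ess.before, after = ess.after))
```

Capping the largest weights raises the ESS and stabilizes the estimate, at the cost of some bias because the trimmed weights no longer target exactly the same population.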
WeightIt-CBPS
With CBPS, every attribute is balanced, and the ATT estimate ($1280) is reasonably close to the benchmark.
W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "cbps")
Variable | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |
---|---|---|---|---|---|
prop.score | Distance | 0.5769492 | 0.5756980 | -0.0056535 | |
age | Contin. | 25.8533261 | 25.8162162 | -0.0051865 | Balanced, <0.1 |
educ | Contin. | 10.3493208 | 10.3459459 | -0.0016785 | Balanced, <0.1 |
black | Binary | 0.8413337 | 0.8432432 | 0.0019096 | Balanced, <0.1 |
hispan | Binary | 0.0596662 | 0.0594595 | -0.0002068 | Balanced, <0.1 |
married | Binary | 0.1920629 | 0.1891892 | -0.0028737 | Balanced, <0.1 |
nodegree | Binary | 0.7039155 | 0.7081081 | 0.0041926 | Balanced, <0.1 |
re74 | Contin. | 2133.8214867 | 2095.5736886 | -0.0078270 | Balanced, <0.1 |
re75 | Contin. | 1512.5767534 | 1532.0553138 | 0.0060507 | Balanced, <0.1 |
unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | CBPS | 0 | 1280 | (-329, 2890) |
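One way to see why CBPS balances better than a plain logistic propensity model is through the condition it imposes while fitting. As a sketch of the idea for the ATT (notation assumed here: T_i is the treatment indicator, X_i the covariate vector including an intercept, \hat{e}(X_i) the fitted propensity score, N_1 the number of treated units), the propensity model is chosen so that the odds-weighted controls reproduce the treated covariate means:

```latex
\frac{1}{N_1} \sum_{i=1}^{N}
  \left( T_i - \frac{(1 - T_i)\,\hat{e}(X_i)}{1 - \hat{e}(X_i)} \right) X_i = 0
```

A plain logistic fit instead solves the likelihood score equations, which do not force this equality, so some residual imbalance (such as age under the PS method) can remain.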
WeightIt-EBAL
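With EBAL, every attribute is balanced, and to a remarkable degree: the standardized differences are all on the order of 1e-6 or smaller. The ATT estimate ($1273) is also reasonably close to the benchmark.
W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "ebal")
Variable | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |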
---|---|---|---|---|---|
age | Contin. | 25.8162278 | 25.8162162 | -1.6e-06 | Balanced, <0.1 |
educ | Contin. | 10.3459478 | 10.3459459 | -9.0e-07 | Balanced, <0.1 |
black | Binary | 0.8432400 | 0.8432432 | 3.3e-06 | Balanced, <0.1 |
hispan | Binary | 0.0594592 | 0.0594595 | 2.0e-07 | Balanced, <0.1 |
married | Binary | 0.1891903 | 0.1891892 | -1.1e-06 | Balanced, <0.1 |
nodegree | Binary | 0.7081077 | 0.7081081 | 4.0e-07 | Balanced, <0.1 |
re74 | Contin. | 2095.5851047 | 2095.5736886 | -2.3e-06 | Balanced, <0.1 |
re75 | Contin. | 1532.0583946 | 1532.0553138 | -1.0e-06 | Balanced, <0.1 |
unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | EBAL | 0 | 1273 | (-344, 2890) |
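The near-exact balance is by construction: entropy balancing solves directly for control weights whose weighted covariate means equal the treated means. The following is a minimal base R sketch of that idea on simulated data, solving the dual problem with optim. The data, the target means, and all variable names are assumptions, and this is not the code that WeightIt or the ebal package runs.

```r
# Simulated control covariates and stand-in treated-group means (assumptions)
set.seed(123)
Xc <- cbind(x1 = rnorm(500), x2 = rnorm(500))
target <- c(x1 = 0.3, x2 = -0.2)

# Center the controls at the target means
Z <- sweep(Xc, 2, target)

# Dual objective: log-sum-exp of Z %*% lambda. Its gradient is the weighted
# mean of Z, so the minimizer yields weights that balance exactly.
dual <- function(lambda) log(sum(exp(Z %*% lambda)))
dual.grad <- function(lambda) {
  w <- exp(Z %*% lambda)
  w <- w / sum(w)
  as.vector(crossprod(Z, w))
}

opt <- optim(c(0, 0), dual, dual.grad, method = "BFGS",
             control = list(reltol = 1e-12))

# Recover normalized weights and check the weighted control means
w <- as.vector(exp(Z %*% opt$par))
w <- w / sum(w)
balanced.means <- colSums(Xc * w)
print(rbind(target = target, weighted = balanced.means))
```

Because balance on the means is a hard constraint rather than a by-product of a propensity model, the remaining differences are only numerical error, which matches the 1e-6 scale in the table above.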
WeightIt-EBCW
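With EBCW, every attribute is balanced, again to a remarkable degree: every standardized difference displays as 0. The ATT estimate ($1273) matches the EBAL result.
W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "ebcw")
Variable | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |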
---|---|---|---|---|---|
age | Contin. | 25.8162162 | 25.8162162 | 0 | Balanced, <0.1 |
educ | Contin. | 10.3459459 | 10.3459459 | 0 | Balanced, <0.1 |
black | Binary | 0.8432432 | 0.8432432 | 0 | Balanced, <0.1 |
hispan | Binary | 0.0594595 | 0.0594595 | 0 | Balanced, <0.1 |
married | Binary | 0.1891892 | 0.1891892 | 0 | Balanced, <0.1 |
nodegree | Binary | 0.7081081 | 0.7081081 | 0 | Balanced, <0.1 |
re74 | Contin. | 2095.5736938 | 2095.5736886 | 0 | Balanced, <0.1 |
re75 | Contin. | 1532.0553150 | 1532.0553138 | 0 | Balanced, <0.1 |
unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | EBCW | 0 | 1273 | (-344, 2890) |
WeightIt-OptWeight
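With optweight, every attribute is balanced, and the balance is essentially perfect. The ATT estimate ($1204) is also reasonably close to the benchmark.
W.out <- weightit(treat.fml, data = nsw_data_obs, estimand = "ATT", method = "optweight")
Variable | Type | M.0.Adj | M.1.Adj | Diff.Adj | M.Threshold |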
---|---|---|---|---|---|
age | Contin. | 25.8162166 | 25.8162162 | 0e+00 | Balanced, <0.1 |
educ | Contin. | 10.3459459 | 10.3459459 | 0e+00 | Balanced, <0.1 |
black | Binary | 0.8432432 | 0.8432432 | 0e+00 | Balanced, <0.1 |
hispan | Binary | 0.0594595 | 0.0594595 | 0e+00 | Balanced, <0.1 |
married | Binary | 0.1891892 | 0.1891892 | 0e+00 | Balanced, <0.1 |
nodegree | Binary | 0.7081081 | 0.7081081 | 0e+00 | Balanced, <0.1 |
re74 | Contin. | 2095.5742123 | 2095.5736886 | -1e-07 | Balanced, <0.1 |
re75 | Contin. | 1532.0554281 | 1532.0553138 | 0e+00 | Balanced, <0.1 |
unbalance.cnt <- sum(t1$M.Threshold=="Not Balanced, >0.1")
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Observational | OptWeight | 0 | 1204 | (-399, 2808) |
Summary of results
Finally, here are all the results in one table.
kable(causal.effect.results.all)
Data name | Method name | # Unbalance Var. | ATT | 95% Conf. Int. |
---|---|---|---|---|
Experimental | Ground Truth | 3 | 1794 | (481, 3108) |
Observational | Do nothing | 6 | -635 | (-1960, 690) |
Observational | PS | 1 | 1214 | (-402, 2830) |
Observational | GBM | 2 | 461 | (-1422, 2344) |
Observational | CBPS | 0 | 1280 | (-329, 2890) |
Observational | EBAL | 0 | 1273 | (-344, 2890) |
Observational | EBCW | 0 | 1273 | (-344, 2890) |
Observational | OptWeight | 0 | 1204 | (-399, 2808) |
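To compare the methods at a glance, the point estimates in the table above can be set against the experimental benchmark of $1794. The sketch below simply re-enters those numbers by hand; the data frame and column names are made up for illustration.

```r
# Point estimates copied from the summary table above
results <- data.frame(
  method = c("Do nothing", "PS", "GBM", "CBPS", "EBAL", "EBCW", "OptWeight"),
  att = c(-635, 1214, 461, 1280, 1273, 1273, 1204)
)
benchmark <- 1794   # experimental estimate, used as the reference

# Gap between each observational estimate and the benchmark
results$gap <- results$att - benchmark

# Sort from closest to furthest
ranked <- results[order(abs(results$gap)), ]
print(ranked)
```

CBPS, EBAL, and EBCW come closest, all roughly $500 below the benchmark. Every observational confidence interval in the table still covers zero, so these differences between methods should be read with caution.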